Commit 899c0c47 authored by Lukas Burgey's avatar Lukas Burgey
Browse files

Add the backchannel to the backend

parent 89b4a300
......@@ -31,7 +31,7 @@ type (
FetchIntervalString string `json:"fetch_interval"` // string parsed by time.ParseDuration
ReconnectTimeoutString string `json:"reconnect_timeout"` // string parsed by time.ParseDuration
NewTasks chan task
DoneTasks chan task
DoneTasks chan taskExecution
FetchInterval time.Duration
ReconnectTimeout time.Duration
......@@ -170,7 +170,7 @@ func getConfig(configFile string) (c config, err error) {
// initialize the task queues
c.NewTasks = make(chan task)
c.DoneTasks = make(chan task)
c.DoneTasks = make(chan taskExecution)
return
}
......@@ -199,14 +199,12 @@ func main() {
return
}
// start task handler and acker
// start task handler and responder
go c.taskHandler()
go c.taskAcker()
go c.taskResponder()
consumer := c.consumer()
defer func() {
consumer.close()
}()
defer consumer.close()
consumer.startConsuming()
......
package main
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
......@@ -26,80 +27,33 @@ type (
Key scripts.SSHKey `json:"key"`
}
response struct {
OK bool `json:"ok"`
taskExecution struct {
// ID the according task.ID
ID int `json:"id"`
Output scripts.Output `json:"output"`
}
)
func (c *config) handleTask(ti task) (err error) {
var output scripts.Output
// encode input as json
input := scripts.Input{
Action: ti.Action,
User: ti.User,
Key: ti.Key,
}
iBytes, err := input.Marshal()
if err != nil {
return
}
commandName := c.Services[ti.Service.Name].Command
log.Println("[Task] Executing: ", commandName)
log.Printf("[Task] Input: %s", input)
cmd := exec.Command(commandName)
stdin, err := cmd.StdinPipe()
if err != nil {
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return
}
stderr, err := cmd.StderrPipe()
if err != nil {
return
}
cmd.Start()
if err != nil {
return
}
stdin.Write(iBytes)
stdin.Close()
// decode json output
outputBytes, err := ioutil.ReadAll(stdout)
if err != nil {
return
}
logOutputBytes, err := ioutil.ReadAll(stderr)
if err != nil {
return
}
log.Printf("[Task] Logs:\n%s", logOutputBytes)
err = cmd.Wait()
if err != nil {
return
}
// fetches tasks (which rabbitmq missed) manually
func (c *config) taskFetcher() {
// start ticker
ticker := time.NewTicker(c.FetchInterval)
err = json.Unmarshal(outputBytes, &output)
if err != nil {
return
}
log.Printf("[Task] Output: %s", output)
go func() {
for {
go func() {
if err := c.fetchTasks(); err != nil {
log.Printf("[Fetch] error fetching: %s", err)
}
}()
if output.Status {
c.DoneTasks <- ti
}
return
// wait for a tick
<-ticker.C
}
}()
}
// handles either the tasks from rest and from rabbitmq
// handles tasks in c.NewTasks
func (c *config) taskHandler() {
for newTask := range c.NewTasks {
// handle tasks asynchronously
......@@ -111,60 +65,28 @@ func (c *config) taskHandler() {
}
}
func (c *config) ackTask(ti task) (ok bool, err error) {
url := fmt.Sprintf(
"https://%s/backend/clientapi/ack/%d/",
c.Host,
ti.ID)
req, err := http.NewRequest("DELETE", url, nil)
if err != nil {
// TODO retransmit ACK
return
}
req.SetBasicAuth(c.Username, c.Password)
resp, err := client.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
respBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
var ackResponse response
err = json.Unmarshal(respBytes, &ackResponse)
if err != nil {
log.Printf("[Task] Invalid ACK response: %s", respBytes)
}
ok = ackResponse.OK
return
}
// acks all done tasks
func (c *config) taskAcker() {
// acks tasks in c.DoneTasks
func (c *config) taskResponder() {
for doneTask := range c.DoneTasks {
// ack tasks asynchronously
go func(t task) {
ok, err := c.ackTask(t)
go func(te taskExecution) {
err := c.respondToTask(te)
if err != nil {
log.Printf("[Task] %v: ACK-ERROR: %s", t, err)
log.Printf("[Task] %v: ACK-ERROR: %s", te.ID, err)
// reschedule failed responses
// TODO does this work?
go func(te taskExecution) {
time.Sleep(time.Minute)
c.DoneTasks <- te
}(te)
return
}
if !ok {
log.Printf("[Task] %v: ACK-FAILED", t)
return
}
log.Printf("[Task] %v: ACK", t)
}(doneTask)
}
}
// responds to tasks (either ack/"reject"/fail)
func (c *config) taskResponder() {
}
// IMPLEMENTATIONS
func (c *config) fetchTasks() (err error) {
if len(c.Services) == 0 {
......@@ -210,22 +132,103 @@ func (c *config) fetchTasks() (err error) {
}
return
}
func (c *config) respondToTask(te taskExecution) (err error) {
responseBytes, err := json.Marshal(te)
if err != nil {
return
}
log.Printf("[Task] Sending ACK %v:\n%s", te.ID, responseBytes)
// fetches tasks (which rabbitmq missed) manually
func (c *config) taskFetcher() {
// start ticker
ticker := time.NewTicker(c.FetchInterval)
url := fmt.Sprintf("https://%s/backend/clientapi/response", c.Host)
req, err := http.NewRequest("POST", url, bytes.NewReader(responseBytes))
if err != nil {
// TODO retransmit ACK
return
}
req.SetBasicAuth(c.Username, c.Password)
req.Header.Set("Content-Type", "application/json")
go func() {
for {
go func() {
if err := c.fetchTasks(); err != nil {
log.Printf("[Fetch] error fetching: %s", err)
}
}()
// execute request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
// wait for a tick
<-ticker.C
}
}()
if resp.StatusCode == 200 {
log.Printf("[Task] Backend received ACK %v", te.ID)
}
respBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
log.Printf("[Task] Backend response to ACK %v:\n%s", te.ID, respBytes)
return
}
func (c *config) handleTask(ti task) (err error) {
var output scripts.Output
// encode input as json
input := scripts.Input{
Action: ti.Action,
User: ti.User,
Key: ti.Key,
}
iBytes, err := input.Marshal()
if err != nil {
return
}
// execute the script
commandName := c.Services[ti.Service.Name].Command
log.Println("[Task] Executing: ", commandName)
log.Printf("[Task] Input: %s", input)
cmd := exec.Command(commandName)
stdin, err := cmd.StdinPipe()
if err != nil {
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return
}
stderr, err := cmd.StderrPipe()
if err != nil {
return
}
cmd.Start()
if err != nil {
return
}
stdin.Write(iBytes)
stdin.Close()
// decode json output
outputBytes, err := ioutil.ReadAll(stdout)
if err != nil {
return
}
logOutputBytes, err := ioutil.ReadAll(stderr)
if err != nil {
return
}
log.Printf("[Task] Logs:\n%s", logOutputBytes)
err = cmd.Wait()
if err != nil {
return
}
err = json.Unmarshal(outputBytes, &output)
if err != nil {
return
}
log.Printf("[Task] Output: %s", output)
c.DoneTasks <- taskExecution{ti.ID, output}
return
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment