From 323947e432394aeb2441a07b95e6386878feb7d5 Mon Sep 17 00:00:00 2001
From: zcubbs
Date: Sun, 12 May 2024 20:55:03 +0200
Subject: [PATCH] :bug: fix semaphore concurrent jobs negative wait group bug
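
The runner decremented its WaitGroup twice for every task: processTasks
deferred wg.Done() in the goroutine it spawned, and handleTask deferred
wg.Done() again, so the counter could drop below zero, which panics at
runtime with "sync: negative WaitGroup counter". The dispatch loop also
acquired the semaphore slot (r.currentJobs) itself, so it sat blocked on a
full semaphore instead of continuing to select on the task and quit channels.
Illustrative sketch of the broken pattern (not code from this repository):

    wg.Add(1)
    go func() {
        defer wg.Done() // Done #1, in processTasks
        handleTask()    // handleTask also deferred wg.Done() -> Done #2, counter goes negative
    }()

With this patch, processTasks only manages the WaitGroup; handleTask owns the
semaphore acquire/release and runs each job under a per-task context with a
timeout; job monitoring becomes a single for/select loop that honors context
cancellation and reports the failing condition message; task status and logs
are carried on the Task struct and persisted via updateTaskStatus; and a new
/command/ endpoint (backed by GetCommandForTaskFromDB) exposes a task's
command in the UI, with the handler methods unexported.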
" + html.EscapeString(join(command, " ")) + "
")) + if err != nil { + log.Error("failed to write command", "error", err) + http.Error(w, "Internal Error 500", http.StatusInternalServerError) + return + } +} + func join(strs []string, sep string) string { return strings.Join(strs, sep) } diff --git a/cmd/server/web/templates/tasks.html b/cmd/server/web/templates/tasks.html index 92cecbc..cfc48fd 100644 --- a/cmd/server/web/templates/tasks.html +++ b/cmd/server/web/templates/tasks.html @@ -1,10 +1,10 @@ + - @@ -14,10 +14,12 @@ {{ range . }} + - - + diff --git a/pkg/k8s/jobs/job.go b/pkg/k8s/jobs/job.go index d96944e..ac50198 100644 --- a/pkg/k8s/jobs/job.go +++ b/pkg/k8s/jobs/job.go @@ -32,22 +32,19 @@ func (r *Runner) createAndMonitorJob(ctx context.Context, namespace string, task RestartPolicy: corev1.RestartPolicyOnFailure, }, }, - ActiveDeadlineSeconds: ptr.To(int64(task.Timeout)), // Set the job timeout (in seconds), + ActiveDeadlineSeconds: ptr.To(int64(task.Timeout)), // Ensure jobs respect the task timeout }, } log.Debug("Creating job", "jobId", task.ID, "image", task.Image, "command", task.Command) - // Create the Kubernetes job job, err := r.cs.BatchV1().Jobs(namespace).Create(ctx, job, metav1.CreateOptions{}) if err != nil { log.Error("Failed to create job", "error", err) return nil, err } - // Monitor the job status until completion or failure log.Debug("Waiting for job to complete...", "jobId", task.ID, "image", task.Image, "command", task.Command) - err = r.waitForJobCompletion(ctx, job) - if err != nil { + if err = r.waitForJobCompletion(ctx, job); err != nil { log.Error("Failed to monitor job", "error", err) return nil, fmt.Errorf("job monitoring failed: %v", err) } @@ -64,47 +61,46 @@ func (r *Runner) waitForJobCompletion(ctx context.Context, job *batchv1.Job) err } defer watcher.Stop() - return r.processJobEvents(watcher) -} - -func (r *Runner) processJobEvents(watcher watch.Interface) error { - for event := range watcher.ResultChan() { - if err := r.handleJobEvent(event); err != nil { - return err + for { + select { + case event, ok := <-watcher.ResultChan(): + if !ok { + return fmt.Errorf("job watch channel closed") + } + switch event.Type { + case watch.Error: + return fmt.Errorf("error watching job: %v", event.Object) + case watch.Deleted: + return fmt.Errorf("job deleted unexpectedly") + case watch.Added, watch.Modified: + job, ok := event.Object.(*batchv1.Job) + if !ok { + return fmt.Errorf("unexpected object type") + } + if done, err := r.evaluateJobStatus(job); done { + return err + } + } + case <-ctx.Done(): + return ctx.Err() } } - return nil } -func (r *Runner) handleJobEvent(event watch.Event) error { - switch event.Type { - case watch.Added, watch.Modified: - return r.evaluateJobStatus(event) - case watch.Deleted: - return fmt.Errorf("job deleted") - } - return nil -} - -func (r *Runner) evaluateJobStatus(event watch.Event) error { - job, ok := event.Object.(*batchv1.Job) - if !ok { - return fmt.Errorf("unexpected type") - } - +func (r *Runner) evaluateJobStatus(job *batchv1.Job) (bool, error) { for _, condition := range job.Status.Conditions { switch condition.Type { case batchv1.JobComplete: if condition.Status == corev1.ConditionTrue { - return nil + return true, nil } case batchv1.JobFailed: if condition.Status == corev1.ConditionTrue { - return fmt.Errorf("job failed") + return true, fmt.Errorf("job failed: %s", condition.Message) } } } - return nil + return false, nil // Job has not yet reached a definitive state } // Delete deletes a pod in a Kubernetes cluster. 
diff --git a/pkg/k8s/jobs/queries.go b/pkg/k8s/jobs/queries.go
index 408610b..a604f68 100644
--- a/pkg/k8s/jobs/queries.go
+++ b/pkg/k8s/jobs/queries.go
@@ -46,6 +46,26 @@ func (r *Runner) GetStatusForTaskFromDB(taskId string) (string, error) {
 	return status, nil
 }
 
+func (r *Runner) GetCommandForTaskFromDB(taskId string) ([]string, error) {
+	var command []string
+	err := r.db.View(func(tx *buntdb.Tx) error {
+		val, err := tx.Get(taskId)
+		if err != nil {
+			return err
+		}
+		var task Task
+		if err := json.Unmarshal([]byte(val), &task); err != nil {
+			return err
+		}
+		command = task.Command
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return command, nil
+}
+
 func (r *Runner) GetTaskFromDB(taskId string) (*Task, error) {
 	var task Task
 	err := r.db.View(func(tx *buntdb.Tx) error {
diff --git a/pkg/k8s/jobs/task.go b/pkg/k8s/jobs/task.go
index 1692d8b..f5e5590 100644
--- a/pkg/k8s/jobs/task.go
+++ b/pkg/k8s/jobs/task.go
@@ -3,7 +3,6 @@ package k8sJobs
 import (
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"github.com/charmbracelet/log"
 	"github.com/tidwall/buntdb"
@@ -54,11 +53,8 @@
 		select {
 		case task := <-r.taskChan:
 			r.wg.Add(1)
-			// Acquire a semaphore slot before processing a task
-			r.currentJobs <- struct{}{} // This will block if the limit is reached
 			go func(t Task) {
 				defer r.wg.Done()
-				defer func() { <-r.currentJobs }() // Release the semaphore slot after the task is processed
 				r.handleTask(ctx, t)
 			}(task)
 		case <-r.quit:
@@ -67,55 +63,43 @@
 	}
 }
 
-func (r *Runner) handleTask(ctx context.Context, t Task) {
-	defer r.wg.Done()
+// handleTask handles a task by creating a job, monitoring it, and updating the task status.
+func (r *Runner) handleTask(parentCtx context.Context, t Task) {
+	// Acquire a semaphore slot before processing a task
+	r.currentJobs <- struct{}{}        // This will block if the limit is reached
+	defer func() { <-r.currentJobs }() // Release the semaphore slot after the task is processed
 
-	// set default timeout
 	if t.Timeout == 0 {
 		t.Timeout = r.defaultJobTimeout
 	}
 
-	t.StartedAt = time.Now() // Record when the task processing starts
-	log.Debug("Processing task", "jobId", t.ID, "image", t.Image, "command", t.Command)
-
-	// Update task status to RUNNING
-	if !r.updateTaskStatus(t, "RUNNING", "") {
-		return
-	}
+	// Create a new context with a timeout for the task
+	taskCtx, cancel := context.WithTimeout(parentCtx, time.Duration(t.Timeout)*time.Second+5*time.Second)
+	defer cancel()
 
-	// Attempt to create and monitor the Kubernetes job
-	_, err := r.createAndMonitorJob(ctx, r.namespace, t)
-	if err != nil {
-		t.EndedAt = time.Now() // Set end time when task finishes or fails
-		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
-			r.handleTimeout(t)
-		} else {
-			log.Error("Failed to create and monitor job", "error", err, "jobId", t.ID)
-			r.updateTaskStatus(t, "FAILED", fmt.Sprintf("Job failed: %v", err))
+	t.StartedAt = time.Now()
+	log.Debug("Processing task", "jobId", t.ID, "image", t.Image, "command", t.Command)
+	t.Status = "RUNNING"
+	r.updateTaskStatus(t)
+
+	if _, err := r.createAndMonitorJob(taskCtx, r.namespace, t); err != nil {
+		log.Error("Failed to create or monitor job", "error", err, "jobId", t.ID)
+		t.Status = "FAILED"
+		t.Logs = fmt.Sprintf("Job monitoring failed: %v", err)
+	} else {
+		t.Status = "SUCCEEDED"
+		t.Logs, err = r.getLogs(taskCtx, t.ID) // Retrieve logs
+		if err != nil {
+			log.Error("Failed to get logs", "error", err, "jobId", t.ID)
+		}
-		r.delete(ctx, t.ID) // Cleanup after failure or timeout
-		return
 	}
 
-	// Successfully completed the task
 	t.EndedAt = time.Now()
-	log.Debug("Retrieving logs", "jobId", t.ID)
-	logs, err := r.getLogs(ctx, t.ID)
-	finalStatus := "SUCCEEDED"
-	if err != nil {
-		finalStatus = "FAILED"
-		logs = fmt.Sprintf("Failed to get logs: %v", err)
-		log.Error("Failed to get logs", "error", err, "jobId", t.ID)
-	}
-
-	r.updateTaskStatus(t, finalStatus, logs)
-	r.delete(ctx, t.ID) // Cleanup after successful completion
+	r.updateTaskStatus(t)
+	r.delete(taskCtx, t.ID) // Use task-specific context for deletion
 }
 
-func (r *Runner) updateTaskStatus(t Task, status, logs string) bool {
-	t.Status = status
-	t.Logs = logs
-
+func (r *Runner) updateTaskStatus(t Task) bool {
 	taskData, err := json.Marshal(t)
 	if err != nil {
 		log.Error("Failed to serialize task", "error", err, "jobId", t.ID)
@@ -136,6 +120,8 @@
 	}
 func (r *Runner) handleTimeout(t Task) {
 	t.EndedAt = time.Now() // Set the timeout end time
-	r.updateTaskStatus(t, "TIMED OUT", "Task timed out")
+	t.Status = "TIMEOUT"
+	t.Logs = "Task timed out"
+	r.updateTaskStatus(t)
 	r.delete(context.Background(), t.ID) // Ensure context.Background() to avoid passing a canceled context
 }
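
For readers skimming the patch, this is the WaitGroup-plus-semaphore shape the
runner converges on after the fix: exactly one wg.Add per spawned goroutine,
exactly one deferred wg.Done, and the semaphore acquired and released inside
the worker rather than in the dispatch loop. The snippet below is an
illustrative, self-contained sketch with made-up names (runJob, sem), not code
from this repository.

package main

import (
	"fmt"
	"sync"
	"time"
)

// runJob plays the role of handleTask: it acquires a concurrency slot, does
// the work, and releases the slot. It never touches the WaitGroup.
func runJob(id int, sem chan struct{}) {
	sem <- struct{}{}        // acquire a slot (blocks while the limit is reached)
	defer func() { <-sem }() // release the slot when the job finishes
	time.Sleep(50 * time.Millisecond)
	fmt.Println("job", id, "finished")
}

func main() {
	var wg sync.WaitGroup
	sem := make(chan struct{}, 2) // at most 2 jobs run concurrently

	for i := 0; i < 5; i++ {
		wg.Add(1) // one Add per goroutine...
		go func(id int) {
			defer wg.Done() // ...and exactly one Done, so the counter never goes negative
			runJob(id, sem)
		}(i)
	}
	wg.Wait() // returns once all five jobs have finished
}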