Skip to content

Commit

Permalink
🐛 fix semaphore concurrent jobs negative wait group bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zcubbs committed May 12, 2024
1 parent 053166c commit 323947e
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 86 deletions.
2 changes: 1 addition & 1 deletion cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func init() {
}
utils.CheckTimeZone()

log.Info("loaded configuration")
log.Info("loaded configuration", "path", *configPath)
}

func main() {
Expand Down
32 changes: 24 additions & 8 deletions cmd/server/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,22 @@ func NewHandler(k8sRunner *k8sJobs.Runner) (*Handler, error) {
}

func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/static/", h.HandleGetStaticFiles)
mux.HandleFunc("/logs/", h.HandleGetLogs)
mux.HandleFunc("/tasks", h.HandleGetTasks)
mux.HandleFunc("/", h.HandleIndex)
mux.HandleFunc("/static/", h.handleGetStaticFiles)
mux.HandleFunc("/command/", h.handleGetCommand)
mux.HandleFunc("/logs/", h.handleGetLogs)
mux.HandleFunc("/tasks", h.handleGetTasks)
mux.HandleFunc("/", h.handleIndex)
}

func (h *Handler) HandleIndex(w http.ResponseWriter, _ *http.Request) {
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request) {
err := h.templates.ExecuteTemplate(w, "index.html", nil)
if err != nil {
log.Error("failed to execute template", "error", err)
http.Error(w, "Internal Error 500", http.StatusInternalServerError)
}
}

func (h *Handler) HandleGetTasks(w http.ResponseWriter, _ *http.Request) {
func (h *Handler) handleGetTasks(w http.ResponseWriter, _ *http.Request) {
// Fetch tasks directly as structs from the DB
tasks, err := h.k8sRunner.GetAllTasksFromDB()
if err != nil {
Expand All @@ -62,13 +63,13 @@ func (h *Handler) HandleGetTasks(w http.ResponseWriter, _ *http.Request) {
}
}

func (h *Handler) HandleGetStaticFiles(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleGetStaticFiles(w http.ResponseWriter, r *http.Request) {
log.Debug("serving static file", "path", r.URL.Path)
staticFileHandler := http.FileServer(http.FS(FsStaticFiles))
staticFileHandler.ServeHTTP(w, r)
}

func (h *Handler) HandleGetLogs(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleGetLogs(w http.ResponseWriter, r *http.Request) {
taskID := strings.TrimPrefix(r.URL.Path, "/logs/")
logs, err := h.k8sRunner.GetLogsForTaskFromDB(taskID)
if err != nil {
Expand All @@ -83,6 +84,21 @@ func (h *Handler) HandleGetLogs(w http.ResponseWriter, r *http.Request) {
}
}

func (h *Handler) handleGetCommand(w http.ResponseWriter, r *http.Request) {
taskID := strings.TrimPrefix(r.URL.Path, "/command/")
command, err := h.k8sRunner.GetCommandForTaskFromDB(taskID)
if err != nil {
http.Error(w, "Failed to fetch command", http.StatusInternalServerError)
return
}
_, err = w.Write([]byte("<html><body><pre>" + html.EscapeString(join(command, " ")) + "</pre></body></html>"))
if err != nil {
log.Error("failed to write command", "error", err)
http.Error(w, "Internal Error 500", http.StatusInternalServerError)
return
}
}

func join(strs []string, sep string) string {
return strings.Join(strs, sep)
}
8 changes: 5 additions & 3 deletions cmd/server/web/templates/tasks.html
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
<table class="min-w-full divide-y divide-gray-700">
<thead class="bg-gray-700">
<tr>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Status</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Job ID</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Image</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Command</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Status</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Logs</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Created At</th>
<th class="px-4 py-3 text-left text-xs font-medium text-white uppercase tracking-wider">Started At</th>
Expand All @@ -14,10 +14,12 @@
<tbody class="bg-gray-800 text-white divide-y divide-gray-700">
{{ range . }}
<tr>
<td class="px-4 py-3 whitespace-nowrap">{{ .Status }}</td>
<td class="px-4 py-3 whitespace-nowrap">{{ .ID }}</td>
<td class="px-4 py-3 whitespace-nowrap">{{ .Image }}</td>
<td class="px-4 py-3 whitespace-pre">{{ join .Command " " }}</td>
<td class="px-4 py-3 whitespace-nowrap">{{ .Status }}</td>
<td class="px-4 py-3 whitespace-pre-wrap">
<button onclick="window.open('/command/{{ .ID }}', '_blank')">View Command</button>
</td>
<td class="px-4 py-3 whitespace-pre-wrap">
<button onclick="window.open('/logs/{{ .ID }}', '_blank')">View Logs</button>
</td>
Expand Down
60 changes: 28 additions & 32 deletions pkg/k8s/jobs/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,19 @@ func (r *Runner) createAndMonitorJob(ctx context.Context, namespace string, task
RestartPolicy: corev1.RestartPolicyOnFailure,
},
},
ActiveDeadlineSeconds: ptr.To(int64(task.Timeout)), // Set the job timeout (in seconds),
ActiveDeadlineSeconds: ptr.To(int64(task.Timeout)), // Ensure jobs respect the task timeout
},
}

log.Debug("Creating job", "jobId", task.ID, "image", task.Image, "command", task.Command)
// Create the Kubernetes job
job, err := r.cs.BatchV1().Jobs(namespace).Create(ctx, job, metav1.CreateOptions{})
if err != nil {
log.Error("Failed to create job", "error", err)
return nil, err
}

// Monitor the job status until completion or failure
log.Debug("Waiting for job to complete...", "jobId", task.ID, "image", task.Image, "command", task.Command)
err = r.waitForJobCompletion(ctx, job)
if err != nil {
if err = r.waitForJobCompletion(ctx, job); err != nil {
log.Error("Failed to monitor job", "error", err)
return nil, fmt.Errorf("job monitoring failed: %v", err)
}
Expand All @@ -64,47 +61,46 @@ func (r *Runner) waitForJobCompletion(ctx context.Context, job *batchv1.Job) err
}
defer watcher.Stop()

return r.processJobEvents(watcher)
}

func (r *Runner) processJobEvents(watcher watch.Interface) error {
for event := range watcher.ResultChan() {
if err := r.handleJobEvent(event); err != nil {
return err
for {
select {
case event, ok := <-watcher.ResultChan():
if !ok {
return fmt.Errorf("job watch channel closed")
}
switch event.Type {
case watch.Error:
return fmt.Errorf("error watching job: %v", event.Object)
case watch.Deleted:
return fmt.Errorf("job deleted unexpectedly")
case watch.Added, watch.Modified:
job, ok := event.Object.(*batchv1.Job)
if !ok {
return fmt.Errorf("unexpected object type")
}
if done, err := r.evaluateJobStatus(job); done {
return err
}
}
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}

func (r *Runner) handleJobEvent(event watch.Event) error {
switch event.Type {
case watch.Added, watch.Modified:
return r.evaluateJobStatus(event)
case watch.Deleted:
return fmt.Errorf("job deleted")
}
return nil
}

func (r *Runner) evaluateJobStatus(event watch.Event) error {
job, ok := event.Object.(*batchv1.Job)
if !ok {
return fmt.Errorf("unexpected type")
}

func (r *Runner) evaluateJobStatus(job *batchv1.Job) (bool, error) {
for _, condition := range job.Status.Conditions {
switch condition.Type {
case batchv1.JobComplete:
if condition.Status == corev1.ConditionTrue {
return nil
return true, nil
}
case batchv1.JobFailed:
if condition.Status == corev1.ConditionTrue {
return fmt.Errorf("job failed")
return true, fmt.Errorf("job failed: %s", condition.Message)
}
}
}
return nil
return false, nil // Job has not yet reached a definitive state
}

// Delete deletes a pod in a Kubernetes cluster.
Expand Down
20 changes: 20 additions & 0 deletions pkg/k8s/jobs/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,26 @@ func (r *Runner) GetStatusForTaskFromDB(taskId string) (string, error) {
return status, nil
}

func (r *Runner) GetCommandForTaskFromDB(taskId string) ([]string, error) {
var command []string
err := r.db.View(func(tx *buntdb.Tx) error {
val, err := tx.Get(taskId)
if err != nil {
return err
}
var task Task
if err := json.Unmarshal([]byte(val), &task); err != nil {
return err
}
command = task.Command
return nil
})
if err != nil {
return nil, err
}
return command, nil
}

func (r *Runner) GetTaskFromDB(taskId string) (*Task, error) {
var task Task
err := r.db.View(func(tx *buntdb.Tx) error {
Expand Down
70 changes: 28 additions & 42 deletions pkg/k8s/jobs/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package k8sJobs
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/charmbracelet/log"
"github.com/tidwall/buntdb"
Expand Down Expand Up @@ -54,11 +53,8 @@ func (r *Runner) processTasks(ctx context.Context) {
select {
case task := <-r.taskChan:
r.wg.Add(1)
// Acquire a semaphore slot before processing a task
r.currentJobs <- struct{}{} // This will block if the limit is reached
go func(t Task) {
defer r.wg.Done()
defer func() { <-r.currentJobs }() // Release the semaphore slot after the task is processed
r.handleTask(ctx, t)
}(task)
case <-r.quit:
Expand All @@ -67,55 +63,43 @@ func (r *Runner) processTasks(ctx context.Context) {
}
}

func (r *Runner) handleTask(ctx context.Context, t Task) {
defer r.wg.Done()
// handleTask handles a task by creating a job, monitoring it, and updating the task status.
func (r *Runner) handleTask(parentCtx context.Context, t Task) {
// Acquire a semaphore slot before processing a task
r.currentJobs <- struct{}{} // This will block if the limit is reached
defer func() { <-r.currentJobs }() // Release the semaphore slot after the task is processed

// set default timeout
if t.Timeout == 0 {
t.Timeout = r.defaultJobTimeout
}

t.StartedAt = time.Now() // Record when the task processing starts
log.Debug("Processing task", "jobId", t.ID, "image", t.Image, "command", t.Command)

// Update task status to RUNNING
if !r.updateTaskStatus(t, "RUNNING", "") {
return
}
// Create a new context with a timeout for the task
taskCtx, cancel := context.WithTimeout(parentCtx, time.Duration(t.Timeout)*time.Second+5*time.Second)
defer cancel()

// Attempt to create and monitor the Kubernetes job
_, err := r.createAndMonitorJob(ctx, r.namespace, t)
if err != nil {
t.EndedAt = time.Now() // Set end time when task finishes or fails
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
r.handleTimeout(t)
} else {
log.Error("Failed to create and monitor job", "error", err, "jobId", t.ID)
r.updateTaskStatus(t, "FAILED", fmt.Sprintf("Job failed: %v", err))
t.StartedAt = time.Now()
log.Debug("Processing task", "jobId", t.ID, "image", t.Image, "command", t.Command)
t.Status = "RUNNING"
r.updateTaskStatus(t)

if _, err := r.createAndMonitorJob(taskCtx, r.namespace, t); err != nil {
log.Error("Failed to create or monitor job", "error", err, "jobId", t.ID)
t.Status = "FAILED"
t.Logs = fmt.Sprintf("Job monitoring failed: %v", err)
} else {
t.Status = "SUCCEEDED"
t.Logs, err = r.getLogs(taskCtx, t.ID) // Retrieve logs
if err != nil {
log.Error("Failed to get logs", "error", err, "jobId", t.ID)
}
r.delete(ctx, t.ID) // Cleanup after failure or timeout
return
}

// Successfully completed the task
t.EndedAt = time.Now()
log.Debug("Retrieving logs", "jobId", t.ID)
logs, err := r.getLogs(ctx, t.ID)
finalStatus := "SUCCEEDED"
if err != nil {
finalStatus = "FAILED"
logs = fmt.Sprintf("Failed to get logs: %v", err)
log.Error("Failed to get logs", "error", err, "jobId", t.ID)
}

r.updateTaskStatus(t, finalStatus, logs)
r.delete(ctx, t.ID) // Cleanup after successful completion
r.updateTaskStatus(t)
r.delete(taskCtx, t.ID) // Use task-specific context for deletion
}

func (r *Runner) updateTaskStatus(t Task, status, logs string) bool {
t.Status = status
t.Logs = logs

func (r *Runner) updateTaskStatus(t Task) bool {
taskData, err := json.Marshal(t)
if err != nil {
log.Error("Failed to serialize task", "error", err, "jobId", t.ID)
Expand All @@ -136,6 +120,8 @@ func (r *Runner) updateTaskStatus(t Task, status, logs string) bool {

func (r *Runner) handleTimeout(t Task) {

Check failure on line 121 in pkg/k8s/jobs/task.go

View workflow job for this annotation

GitHub Actions / lint

func `(*Runner).handleTimeout` is unused (unused)
t.EndedAt = time.Now() // Set the timeout end time
r.updateTaskStatus(t, "TIMED OUT", "Task timed out")
t.Status = "TIMEOUT"
t.Logs = "Task timed out"
r.updateTaskStatus(t)
r.delete(context.Background(), t.ID) // Ensure context.Background() to avoid passing a canceled context
}

0 comments on commit 323947e

Please sign in to comment.