package tasks

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/atlasos/calypso/internal/common/database"
	"github.com/atlasos/calypso/internal/common/logger"
	"github.com/google/uuid"
)

// Engine manages async task execution by recording task state transitions
// in the tasks table.
type Engine struct {
	db     *database.DB
	logger *logger.Logger
}

// NewEngine creates a new task engine that persists through db and logs
// through log.
func NewEngine(db *database.DB, log *logger.Logger) *Engine {
	return &Engine{
		db:     db,
		logger: log,
	}
}

// TaskStatus represents the state of a task as stored in tasks.status.
type TaskStatus string

const (
	TaskStatusPending   TaskStatus = "pending"
	TaskStatusRunning   TaskStatus = "running"
	TaskStatusCompleted TaskStatus = "completed"
	TaskStatusFailed    TaskStatus = "failed"
	TaskStatusCancelled TaskStatus = "cancelled"
)

// TaskType represents the kind of work a task performs, as stored in
// tasks.type.
type TaskType string

const (
	TaskTypeInventory     TaskType = "inventory"
	TaskTypeLoadUnload    TaskType = "load_unload"
	TaskTypeRescan        TaskType = "rescan"
	TaskTypeApplySCST     TaskType = "apply_scst"
	TaskTypeSupportBundle TaskType = "support_bundle"
)

// CreateTask inserts a new pending task with progress 0 and returns its
// freshly generated UUID.
//
// If metadata is non-nil it is JSON-encoded and stored in the metadata
// column; a nil map stores NULL. An error is returned if the metadata cannot
// be marshalled or the insert fails.
func (e *Engine) CreateTask(ctx context.Context, taskType TaskType, createdBy string, metadata map[string]interface{}) (string, error) {
	taskID := uuid.New().String()

	// A nil *string is sent to the driver as SQL NULL.
	var metadataJSON *string
	if metadata != nil {
		bytes, err := json.Marshal(metadata)
		if err != nil {
			return "", fmt.Errorf("failed to marshal metadata: %w", err)
		}
		jsonStr := string(bytes)
		metadataJSON = &jsonStr
	}

	query := `
		INSERT INTO tasks (id, type, status, progress, created_by, metadata)
		VALUES ($1, $2, $3, $4, $5, $6)
	`

	_, err := e.db.ExecContext(ctx, query,
		taskID, string(taskType), string(TaskStatusPending), 0, createdBy, metadataJSON,
	)
	if err != nil {
		return "", fmt.Errorf("failed to create task: %w", err)
	}

	e.logger.Info("Task created", "task_id", taskID, "type", taskType)
	return taskID, nil
}

// StartTask moves a pending task to running. It resets progress to 0 and
// stamps started_at and updated_at.
//
// The status guard in the WHERE clause makes this a compare-and-set. An
// error is returned if no row matched, i.e. the task does not exist or is
// no longer pending.
func (e *Engine) StartTask(ctx context.Context, taskID string) error {
	query := `
		UPDATE tasks
		SET status = $1, progress = 0, started_at = NOW(), updated_at = NOW()
		WHERE id = $2 AND status = $3
	`

	result, err := e.db.ExecContext(ctx, query, string(TaskStatusRunning), taskID, string(TaskStatusPending))
	if err != nil {
		return fmt.Errorf("failed to start task: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("task not found or already started")
	}

	e.logger.Info("Task started", "task_id", taskID)
	return nil
}

// UpdateProgress sets the progress percentage (0-100 inclusive) and the
// status message of a task, and bumps updated_at.
//
// An error is returned if progress is out of range, the update fails, or
// no task with taskID exists.
func (e *Engine) UpdateProgress(ctx context.Context, taskID string, progress int, message string) error {
	if progress < 0 || progress > 100 {
		return fmt.Errorf("progress must be between 0 and 100")
	}

	query := `
		UPDATE tasks
		SET progress = $1, message = $2, updated_at = NOW()
		WHERE id = $3
	`

	result, err := e.db.ExecContext(ctx, query, progress, message, taskID)
	if err != nil {
		return fmt.Errorf("failed to update progress: %w", err)
	}

	// Report a missing task, matching the other state-changing methods
	// instead of silently succeeding.
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("task not found")
	}

	return nil
}

// CompleteTask marks a task as completed. It sets progress to 100, records
// message, and stamps completed_at and updated_at.
//
// An error is returned if the update fails or no task with taskID exists.
//
// NOTE(review): the update is not guarded by the current status, so it will
// overwrite a task that is already failed or cancelled. Confirm that is
// intended.
func (e *Engine) CompleteTask(ctx context.Context, taskID string, message string) error {
	query := `
		UPDATE tasks
		SET status = $1, progress = 100, message = $2, completed_at = NOW(), updated_at = NOW()
		WHERE id = $3
	`

	result, err := e.db.ExecContext(ctx, query, string(TaskStatusCompleted), message, taskID)
	if err != nil {
		return fmt.Errorf("failed to complete task: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("task not found")
	}

	e.logger.Info("Task completed", "task_id", taskID)
	return nil
}

// FailTask marks a task as failed. It records errorMessage in error_message
// and stamps completed_at and updated_at; progress is left unchanged.
//
// An error is returned if the update fails or no task with taskID exists.
//
// NOTE(review): like CompleteTask, this does not check the current status
// and can overwrite a completed or cancelled task.
func (e *Engine) FailTask(ctx context.Context, taskID string, errorMessage string) error {
	query := `
		UPDATE tasks
		SET status = $1, error_message = $2, completed_at = NOW(), updated_at = NOW()
		WHERE id = $3
	`

	result, err := e.db.ExecContext(ctx, query, string(TaskStatusFailed), errorMessage, taskID)
	if err != nil {
		return fmt.Errorf("failed to fail task: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("task not found")
	}

	e.logger.Error("Task failed", "task_id", taskID, "error", errorMessage)
	return nil
}

// GetTask loads a task by ID.
//
// NULL message, error_message and created_by columns become empty strings.
// NULL started_at and completed_at leave the corresponding pointers nil.
// An error is returned if the task does not exist or the query fails.
// Undecodable metadata is logged rather than returned as an error, so a
// corrupt metadata column does not make the task unreadable.
func (e *Engine) GetTask(ctx context.Context, taskID string) (*Task, error) {
	query := `
		SELECT id, type, status, progress, message, error_message, created_by,
		       started_at, completed_at, created_at, updated_at, metadata
		FROM tasks
		WHERE id = $1
	`

	var task Task
	// message is scanned through NullString as well: CreateTask never sets
	// it, so it may be NULL for tasks that have not reported progress yet.
	var message, errorMsg, createdBy, metadata sql.NullString
	var startedAt, completedAt sql.NullTime

	err := e.db.QueryRowContext(ctx, query, taskID).Scan(
		&task.ID, &task.Type, &task.Status, &task.Progress,
		&message, &errorMsg, &createdBy, &startedAt, &completedAt,
		&task.CreatedAt, &task.UpdatedAt, &metadata,
	)
	if err != nil {
		// errors.Is also matches drivers or wrappers that wrap the sentinel.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("task not found")
		}
		return nil, fmt.Errorf("failed to get task: %w", err)
	}

	// NullString.String is "" when the column was NULL.
	task.Message = message.String
	task.ErrorMessage = errorMsg.String
	task.CreatedBy = createdBy.String

	if startedAt.Valid {
		task.StartedAt = &startedAt.Time
	}
	if completedAt.Valid {
		task.CompletedAt = &completedAt.Time
	}

	if metadata.Valid && metadata.String != "" {
		if err := json.Unmarshal([]byte(metadata.String), &task.Metadata); err != nil {
			// Best effort: keep returning the task, but surface the corruption.
			// task.Metadata may be empty or partially populated here.
			e.logger.Error("Failed to unmarshal task metadata", "task_id", taskID, "error", err)
		}
	}

	return &task, nil
}