Sync from /srv/compose/unified-media-manager
This commit is contained in:
236
internal/worker/scheduler.go
Normal file
236
internal/worker/scheduler.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/TopherMayor/unified-media-manager/internal/db"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
type Worker interface {
|
||||
Name() string
|
||||
CronExpr() string
|
||||
Run(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
cron *cron.Cron
|
||||
database *db.DB
|
||||
workers map[string]Worker
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewScheduler(database *db.DB) *Scheduler {
|
||||
return &Scheduler{
|
||||
cron: cron.New(cron.WithSeconds()),
|
||||
database: database,
|
||||
workers: make(map[string]Worker),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) Register(w Worker) {
|
||||
s.workers[w.Name()] = w
|
||||
|
||||
_, err := s.database.Pool.Exec(context.Background(),
|
||||
`INSERT INTO scheduled_tasks (name, cron_expr, enabled)
|
||||
VALUES ($1, $2, true)
|
||||
ON CONFLICT (name) DO UPDATE SET cron_expr = EXCLUDED.cron_expr`,
|
||||
w.Name(), w.CronExpr())
|
||||
if err != nil {
|
||||
slog.Error("failed to seed scheduled task", "worker", w.Name(), "error", err)
|
||||
}
|
||||
|
||||
wrapper := s.runWithLogging(w)
|
||||
_, err = s.cron.AddFunc(w.CronExpr(), wrapper)
|
||||
if err != nil {
|
||||
slog.Error("failed to schedule worker", "worker", w.Name(), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) runWithLogging(w Worker) func() {
|
||||
return func() {
|
||||
// Check if task is enabled before running
|
||||
var enabled bool
|
||||
err := s.database.Pool.QueryRow(context.Background(),
|
||||
"SELECT enabled FROM scheduled_tasks WHERE name = $1", w.Name()).Scan(&enabled)
|
||||
if err != nil {
|
||||
slog.Error("failed to check task enabled status", "worker", w.Name(), "error", err)
|
||||
return
|
||||
}
|
||||
if !enabled {
|
||||
slog.Debug("skipping disabled task", "worker", w.Name())
|
||||
return
|
||||
}
|
||||
|
||||
var taskID int
|
||||
err = s.database.Pool.QueryRow(context.Background(),
|
||||
"SELECT id FROM scheduled_tasks WHERE name = $1", w.Name()).Scan(&taskID)
|
||||
if err != nil {
|
||||
slog.Error("failed to get task id", "worker", w.Name(), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
var execID int64
|
||||
err = s.database.Pool.QueryRow(context.Background(),
|
||||
"INSERT INTO task_executions (task_id, status, started_at) VALUES ($1, 'running', NOW()) RETURNING id",
|
||||
taskID).Scan(&execID)
|
||||
if err != nil {
|
||||
slog.Error("failed to create execution record", "worker", w.Name(), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
runErr := w.Run(s.ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if runErr != nil {
|
||||
slog.Error("worker execution failed", "worker", w.Name(), "error", runErr, "duration_ms", duration.Milliseconds())
|
||||
_, _ = s.database.Pool.Exec(context.Background(),
|
||||
"UPDATE task_executions SET status = 'failed', ended_at = NOW(), duration_ms = $1, error = $2 WHERE id = $3",
|
||||
duration.Milliseconds(), runErr.Error(), execID)
|
||||
} else {
|
||||
slog.Info("worker execution completed", "worker", w.Name(), "duration_ms", duration.Milliseconds())
|
||||
_, _ = s.database.Pool.Exec(context.Background(),
|
||||
"UPDATE task_executions SET status = 'success', ended_at = NOW(), duration_ms = $1 WHERE id = $2",
|
||||
duration.Milliseconds(), execID)
|
||||
}
|
||||
|
||||
schedule, parseErr := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow).Parse(w.CronExpr())
|
||||
var nextRunAt time.Time
|
||||
if parseErr == nil {
|
||||
nextRunAt = schedule.Next(time.Now())
|
||||
}
|
||||
|
||||
_, _ = s.database.Pool.Exec(context.Background(),
|
||||
"UPDATE scheduled_tasks SET last_run_at = NOW(), next_run_at = $1 WHERE id = $2",
|
||||
nextRunAt, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) TriggerWorker(name string) error {
|
||||
w, ok := s.workers[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("worker not found: %s", name)
|
||||
}
|
||||
|
||||
go s.runWithLogging(w)()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) GetWorkers() []ScheduledTaskInfo {
|
||||
rows, err := s.database.Pool.Query(context.Background(),
|
||||
"SELECT id, name, cron_expr, enabled, last_run_at, next_run_at FROM scheduled_tasks ORDER BY name")
|
||||
if err != nil {
|
||||
slog.Error("failed to query scheduled tasks", "error", err)
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []ScheduledTaskInfo
|
||||
for rows.Next() {
|
||||
var t ScheduledTaskInfo
|
||||
var lastRunAt, nextRunAt *time.Time
|
||||
if err := rows.Scan(&t.ID, &t.Name, &t.CronExpr, &t.Enabled, &lastRunAt, &nextRunAt); err != nil {
|
||||
slog.Error("failed to scan scheduled task", "error", err)
|
||||
continue
|
||||
}
|
||||
if lastRunAt != nil {
|
||||
t.LastRunAt = lastRunAt
|
||||
}
|
||||
if nextRunAt != nil {
|
||||
t.NextRunAt = nextRunAt
|
||||
}
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
func (s *Scheduler) GetHistory(ctx context.Context, name string, page, pageSize int) ([]TaskExecution, int, error) {
|
||||
var total int
|
||||
err := s.database.Pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM task_executions te JOIN scheduled_tasks st ON te.task_id = st.id WHERE st.name = $1`,
|
||||
name).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count task executions: %w", err)
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
rows, err := s.database.Pool.Query(ctx,
|
||||
`SELECT te.id, te.status, te.started_at, te.ended_at, te.duration_ms, te.result, te.error
|
||||
FROM task_executions te JOIN scheduled_tasks st ON te.task_id = st.id
|
||||
WHERE st.name = $1 ORDER BY te.started_at DESC LIMIT $2 OFFSET $3`,
|
||||
name, pageSize, offset)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("query task executions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var executions []TaskExecution
|
||||
for rows.Next() {
|
||||
var e TaskExecution
|
||||
var result []byte
|
||||
var execError *string
|
||||
if err := rows.Scan(&e.ID, &e.Status, &e.StartedAt, &e.EndedAt, &e.DurationMS, &result, &execError); err != nil {
|
||||
slog.Error("failed to scan task execution", "error", err)
|
||||
continue
|
||||
}
|
||||
if result != nil {
|
||||
e.Result = json.RawMessage(result)
|
||||
}
|
||||
if execError != nil {
|
||||
e.Error = *execError
|
||||
}
|
||||
executions = append(executions, e)
|
||||
}
|
||||
return executions, total, nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) SetEnabled(name string, enabled bool) error {
|
||||
tag, err := s.database.Pool.Exec(context.Background(),
|
||||
"UPDATE scheduled_tasks SET enabled = $1 WHERE name = $2", enabled, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update scheduled task enabled: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("scheduled task not found: %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) Start(ctx context.Context) {
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
s.cron.Start()
|
||||
slog.Info("worker scheduler started")
|
||||
}
|
||||
|
||||
func (s *Scheduler) Stop() {
|
||||
s.cron.Stop()
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
slog.Info("worker scheduler stopped")
|
||||
}
|
||||
|
||||
type ScheduledTaskInfo struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CronExpr string `json:"cron_expr"`
|
||||
Enabled bool `json:"enabled"`
|
||||
LastRunAt *time.Time `json:"last_run_at,omitempty"`
|
||||
NextRunAt *time.Time `json:"next_run_at,omitempty"`
|
||||
}
|
||||
|
||||
type TaskExecution struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty"`
|
||||
DurationMS *int64 `json:"duration_ms,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user