package db import ( "context" "embed" "fmt" "io/fs" "log/slog" "sort" "strings" "time" "github.com/jackc/pgx/v5/pgxpool" ) //go:embed migrations/*.sql var MigrationsFS embed.FS type DB struct { Pool *pgxpool.Pool } func New(ctx context.Context, databaseURL string) (*DB, error) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() config, err := pgxpool.ParseConfig(databaseURL) if err != nil { return nil, fmt.Errorf("parse database url: %w", err) } config.MaxConns = 25 config.MinConns = 3 config.MaxConnLifetime = 1 * time.Hour config.MaxConnIdleTime = 30 * time.Minute config.HealthCheckPeriod = 30 * time.Second pool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { return nil, fmt.Errorf("create connection pool: %w", err) } if err := pool.Ping(ctx); err != nil { pool.Close() return nil, fmt.Errorf("ping database: %w", err) } return &DB{Pool: pool}, nil } func (d *DB) Ping(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() return d.Pool.Ping(ctx) } func (d *DB) Close() { d.Pool.Close() } func (d *DB) RunMigrations(ctx context.Context, migrationsFS embed.FS) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() var count int if err := d.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'schema_migrations'").Scan(&count); err != nil { return fmt.Errorf("check migrations table: %w", err) } if count == 0 { if _, err := d.Pool.Exec(ctx, "CREATE TABLE schema_migrations (version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())"); err != nil { return fmt.Errorf("create schema_migrations table: %w", err) } slog.Info("created schema_migrations table") } files, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return fmt.Errorf("read migrations directory: %w", err) } sort.Slice(files, func(i, j int) bool { return files[i].Name() < files[j].Name() }) for _, f := range files { if f.IsDir() || !strings.HasSuffix(f.Name(), ".sql") { continue } version := f.Name() var applied bool if err := d.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&applied); err != nil { return fmt.Errorf("check migration %s: %w", version, err) } if applied { continue } content, err := migrationsFS.ReadFile("migrations/" + version) if err != nil { return fmt.Errorf("read migration %s: %w", version, err) } tx, err := d.Pool.Begin(ctx) if err != nil { return fmt.Errorf("begin transaction for %s: %w", version, err) } if _, err := tx.Exec(ctx, string(content)); err != nil { tx.Rollback(ctx) return fmt.Errorf("execute migration %s: %w", version, err) } if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil { tx.Rollback(ctx) return fmt.Errorf("record migration %s: %w", version, err) } if err := tx.Commit(ctx); err != nil { return fmt.Errorf("commit migration %s: %w", version, err) } slog.Info("applied migration", "file", version) } return nil }