Files
2026-04-24 10:45:19 -07:00

130 lines
3.1 KiB
Go

// Package db wraps a pgx connection pool and applies embedded SQL migrations.
package db
import (
"context"
"embed"
"fmt"
"io/fs"
"log/slog"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// MigrationsFS holds the files matching migrations/*.sql, embedded into the
// binary at build time by the go:embed directive below.
//
//go:embed migrations/*.sql
var MigrationsFS embed.FS
// DB is a thin wrapper around a pgx connection pool.
type DB struct {
	// Pool is the underlying pgxpool connection pool.
	Pool *pgxpool.Pool
}
// New parses databaseURL, opens a pgx connection pool using this package's
// fixed sizing and lifetime settings, and pings the database once.
//
// Pool creation and the ping share a 10-second timeout derived from ctx. If
// the ping fails, the pool is closed before the error is returned. Every
// returned error wraps the underlying cause.
func New(ctx context.Context, databaseURL string) (*DB, error) {
	connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	cfg, err := pgxpool.ParseConfig(databaseURL)
	if err != nil {
		return nil, fmt.Errorf("parse database url: %w", err)
	}

	// Pool size bounds.
	cfg.MaxConns, cfg.MinConns = 25, 3
	// Connection recycling and health-check cadence.
	cfg.MaxConnLifetime = time.Hour
	cfg.MaxConnIdleTime = 30 * time.Minute
	cfg.HealthCheckPeriod = 30 * time.Second

	pool, err := pgxpool.NewWithConfig(connectCtx, cfg)
	if err != nil {
		return nil, fmt.Errorf("create connection pool: %w", err)
	}

	err = pool.Ping(connectCtx)
	if err != nil {
		// Do not hand back (or leak) a pool that cannot reach the database.
		pool.Close()
		return nil, fmt.Errorf("ping database: %w", err)
	}

	return &DB{Pool: pool}, nil
}
// Ping calls Pool.Ping with a 3-second timeout layered on top of ctx and
// returns its result unchanged.
func (d *DB) Ping(ctx context.Context) error {
	pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	err := d.Pool.Ping(pingCtx)
	return err
}
// Close closes the underlying connection pool.
func (d *DB) Close() {
	d.Pool.Close()
}
// RunMigrations applies every not-yet-applied *.sql file found in the
// "migrations" directory of migrationsFS, in ascending file-name order.
//
// Applied migrations are tracked by file name in the schema_migrations table,
// which is created when the initial check does not find it. Each migration is
// executed in its own transaction together with its bookkeeping row, so a
// failing file is rolled back and leaves no record. Processing stops at the
// first error; migrations committed earlier in the same call stay applied.
//
// The whole run shares a single 30-second timeout derived from ctx. Every
// returned error wraps the underlying cause and names the step (and file, where
// applicable) that failed.
func (d *DB) RunMigrations(ctx context.Context, migrationsFS embed.FS) error {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	var count int
	if err := d.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'schema_migrations'").Scan(&count); err != nil {
		return fmt.Errorf("check migrations table: %w", err)
	}
	if count == 0 {
		// IF NOT EXISTS keeps this step from failing with "already exists" when
		// the table was created between the check above and this statement, or
		// when it lives in a schema other than the one the check inspects.
		if _, err := d.Pool.Exec(ctx, "CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())"); err != nil {
			return fmt.Errorf("create schema_migrations table: %w", err)
		}
		slog.Info("created schema_migrations table")
	}

	files, err := fs.ReadDir(migrationsFS, "migrations")
	if err != nil {
		return fmt.Errorf("read migrations directory: %w", err)
	}
	// Migrations are applied in plain lexicographic file-name order.
	sort.Slice(files, func(i, j int) bool {
		return files[i].Name() < files[j].Name()
	})

	for _, f := range files {
		if f.IsDir() || !strings.HasSuffix(f.Name(), ".sql") {
			continue
		}
		// The full file name (including ".sql") is the migration's identity.
		version := f.Name()

		var applied bool
		if err := d.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&applied); err != nil {
			return fmt.Errorf("check migration %s: %w", version, err)
		}
		if applied {
			continue
		}

		content, err := migrationsFS.ReadFile("migrations/" + version)
		if err != nil {
			return fmt.Errorf("read migration %s: %w", version, err)
		}

		if err := d.applyMigration(ctx, version, string(content)); err != nil {
			return err
		}
		slog.Info("applied migration", "file", version)
	}
	return nil
}

// applyMigration executes statements and records version in schema_migrations
// inside a single transaction, rolling back if either step or the commit fails.
func (d *DB) applyMigration(ctx context.Context, version, statements string) error {
	tx, err := d.Pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction for %s: %w", version, err)
	}
	// One deferred rollback covers every failure path. After a successful
	// Commit the transaction is already closed and Rollback does nothing
	// useful, so its error is deliberately discarded.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(ctx, statements); err != nil {
		return fmt.Errorf("execute migration %s: %w", version, err)
	}
	if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil {
		return fmt.Errorf("record migration %s: %w", version, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit migration %s: %w", version, err)
	}
	return nil
}