130 lines
3.1 KiB
Go
130 lines
3.1 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"io/fs"
|
|
"log/slog"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// MigrationsFS holds every file matching migrations/*.sql, embedded into the
// binary at build time. Its layout (SQL files under a "migrations" directory)
// is the layout RunMigrations reads; presumably callers pass it there —
// verify at call sites.
//
//go:embed migrations/*.sql
var MigrationsFS embed.FS
|
|
|
|
type DB struct {
|
|
Pool *pgxpool.Pool
|
|
}
|
|
|
|
func New(ctx context.Context, databaseURL string) (*DB, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgxpool.ParseConfig(databaseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse database url: %w", err)
|
|
}
|
|
|
|
config.MaxConns = 25
|
|
config.MinConns = 3
|
|
config.MaxConnLifetime = 1 * time.Hour
|
|
config.MaxConnIdleTime = 30 * time.Minute
|
|
config.HealthCheckPeriod = 30 * time.Second
|
|
|
|
pool, err := pgxpool.NewWithConfig(ctx, config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create connection pool: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("ping database: %w", err)
|
|
}
|
|
|
|
return &DB{Pool: pool}, nil
|
|
}
|
|
|
|
func (d *DB) Ping(ctx context.Context) error {
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
return d.Pool.Ping(ctx)
|
|
}
|
|
|
|
func (d *DB) Close() {
|
|
d.Pool.Close()
|
|
}
|
|
|
|
func (d *DB) RunMigrations(ctx context.Context, migrationsFS embed.FS) error {
|
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
var count int
|
|
if err := d.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'schema_migrations'").Scan(&count); err != nil {
|
|
return fmt.Errorf("check migrations table: %w", err)
|
|
}
|
|
|
|
if count == 0 {
|
|
if _, err := d.Pool.Exec(ctx, "CREATE TABLE schema_migrations (version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())"); err != nil {
|
|
return fmt.Errorf("create schema_migrations table: %w", err)
|
|
}
|
|
slog.Info("created schema_migrations table")
|
|
}
|
|
|
|
files, err := fs.ReadDir(migrationsFS, "migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("read migrations directory: %w", err)
|
|
}
|
|
|
|
sort.Slice(files, func(i, j int) bool {
|
|
return files[i].Name() < files[j].Name()
|
|
})
|
|
|
|
for _, f := range files {
|
|
if f.IsDir() || !strings.HasSuffix(f.Name(), ".sql") {
|
|
continue
|
|
}
|
|
|
|
version := f.Name()
|
|
|
|
var applied bool
|
|
if err := d.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&applied); err != nil {
|
|
return fmt.Errorf("check migration %s: %w", version, err)
|
|
}
|
|
if applied {
|
|
continue
|
|
}
|
|
|
|
content, err := migrationsFS.ReadFile("migrations/" + version)
|
|
if err != nil {
|
|
return fmt.Errorf("read migration %s: %w", version, err)
|
|
}
|
|
|
|
tx, err := d.Pool.Begin(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("begin transaction for %s: %w", version, err)
|
|
}
|
|
|
|
if _, err := tx.Exec(ctx, string(content)); err != nil {
|
|
tx.Rollback(ctx)
|
|
return fmt.Errorf("execute migration %s: %w", version, err)
|
|
}
|
|
|
|
if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil {
|
|
tx.Rollback(ctx)
|
|
return fmt.Errorf("record migration %s: %w", version, err)
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return fmt.Errorf("commit migration %s: %w", version, err)
|
|
}
|
|
|
|
slog.Info("applied migration", "file", version)
|
|
}
|
|
|
|
return nil
|
|
}
|