Files
unified-media-manager/internal/service/notification.go
2026-04-24 10:45:19 -07:00

674 lines
20 KiB
Go

package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"math"
"net/http"
"net/url"
"strings"
"time"
"github.com/TopherMayor/unified-media-manager/internal/db"
"github.com/jackc/pgx/v5"
)
// NotificationChannel is a configured delivery target for notifications.
// It is serialized to JSON using the field names in the struct tags.
type NotificationChannel struct {
ID int64 `json:"id"`
Name string `json:"name"`
// Type selects the delivery mechanism; this file handles "webhook" and
// "telegram".
Type string `json:"type"`
// Enabled gates delivery: only enabled channels are returned by
// GetSubscribersForEvent.
Enabled bool `json:"enabled"`
// Config is the raw, type-specific JSON object (e.g. "url" for webhooks,
// "bot_token" and "chat_id" for Telegram).
Config json.RawMessage `json:"config"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// EventTypes populated by JOIN (not a DB column)
EventTypes []string `json:"event_types,omitempty"`
}
// QueueEntry is one row of the notification_queue table: a single notification
// waiting to be (or already) delivered to one channel.
type QueueEntry struct {
ID int64 `json:"id"`
ChannelID int64 `json:"channel_id"`
EventType string `json:"event_type"`
Title string `json:"title"`
// Message is the raw JSON payload stored with the entry.
Message json.RawMessage `json:"message"`
// Status values used in this file are "pending", "failed", "delivered"
// and "dead".
Status string `json:"status"`
// Attempts counts delivery attempts made so far; entries are only picked up
// while Attempts is below MaxAttempts.
Attempts int `json:"attempts"`
MaxAttempts int `json:"max_attempts"`
// LastError, NextRetryAt and DeliveredAt map to nullable columns and are
// omitted from JSON when nil.
LastError *string `json:"last_error,omitempty"`
NextRetryAt *time.Time `json:"next_retry_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
DeliveredAt *time.Time `json:"delivered_at,omitempty"`
}
// NotificationService manages notification channels, their event
// subscriptions and the delivery queue, and runs the background dispatcher.
type NotificationService struct {
// db provides the connection pool used for every query in this service.
db *db.DB
// http is the client used for webhook and Telegram deliveries.
http *http.Client
telegramBaseURL string // override for testing
// done is closed by StopDispatcher to stop the dispatcher goroutines.
done chan struct{}
}
// NewNotificationService builds a NotificationService backed by database.
//
// The returned service uses an HTTP client with a 10 second timeout, targets
// the public Telegram Bot API host, and has a fresh stop channel for the
// dispatcher goroutines.
func NewNotificationService(database *db.DB) *NotificationService {
	svc := new(NotificationService)
	svc.db = database
	svc.http = &http.Client{Timeout: 10 * time.Second}
	svc.telegramBaseURL = "https://api.telegram.org"
	svc.done = make(chan struct{})
	return svc
}
// ValidateChannelConfig checks that config is a JSON object carrying the
// fields required by channelType.
//
// Supported channel types:
//   - "webhook": requires a non-empty string "url" that parses as an http or
//     https URL with a host.
//   - "telegram": requires non-empty string "bot_token" and "chat_id" fields.
//
// Required fields holding a non-string JSON value are treated as missing.
// It returns an error for malformed JSON, a missing or invalid required field,
// or an unknown channel type, and nil when the config is acceptable.
func (s *NotificationService) ValidateChannelConfig(channelType string, config json.RawMessage) error {
	var m map[string]interface{}
	if err := json.Unmarshal(config, &m); err != nil {
		return fmt.Errorf("invalid config JSON: %w", err)
	}

	switch channelType {
	case "webhook":
		urlStr, _ := m["url"].(string)
		if urlStr == "" {
			return fmt.Errorf("webhook config requires 'url' field")
		}
		u, err := url.Parse(urlStr)
		if err != nil {
			return fmt.Errorf("webhook url is invalid: %w", err)
		}
		if u.Scheme != "http" && u.Scheme != "https" {
			return fmt.Errorf("webhook url must use http or https scheme")
		}
		// "http://" and "https:///path" parse without error but have nowhere
		// to send the request, so reject them here instead of at delivery.
		if u.Hostname() == "" {
			return fmt.Errorf("webhook url must include a host")
		}
	case "telegram":
		botToken, _ := m["bot_token"].(string)
		chatID, _ := m["chat_id"].(string)
		if botToken == "" {
			return fmt.Errorf("telegram config requires 'bot_token' field")
		}
		if chatID == "" {
			return fmt.Errorf("telegram config requires 'chat_id' field")
		}
	default:
		return fmt.Errorf("unknown channel type: %s", channelType)
	}
	return nil
}
// ListChannels returns all channels, ordered by name, with masked configs and
// their event subscriptions.
//
// Each channel's Config is passed through maskConfig before being returned,
// and EventTypes is filled from the aggregated subscription rows. A failure
// while querying, scanning, or iterating is returned as an error rather than
// silently producing a shortened list.
func (s *NotificationService) ListChannels(ctx context.Context) ([]NotificationChannel, error) {
	rows, err := s.db.Pool.Query(ctx,
		`SELECT c.id, c.name, c.type, c.enabled, c.config, c.created_at, c.updated_at,
		COALESCE(json_agg(s.event_type) FILTER (WHERE s.event_type IS NOT NULL), '[]') AS event_types
		FROM notification_channels c
		LEFT JOIN notification_subscriptions s ON c.id = s.channel_id
		GROUP BY c.id, c.name, c.type, c.enabled, c.config, c.created_at, c.updated_at
		ORDER BY c.name`)
	if err != nil {
		return nil, fmt.Errorf("list notification channels: %w", err)
	}
	defer rows.Close()

	var channels []NotificationChannel
	for rows.Next() {
		var ch NotificationChannel
		var eventTypesJSON []byte
		if err := rows.Scan(&ch.ID, &ch.Name, &ch.Type, &ch.Enabled, &ch.Config,
			&ch.CreatedAt, &ch.UpdatedAt, &eventTypesJSON); err != nil {
			return nil, fmt.Errorf("scan notification channel: %w", err)
		}
		ch.Config = maskConfig(ch.Type, ch.Config)

		// The aggregate is expected to be a JSON array of strings; if it is
		// not, keep the channel but report why its event types are missing.
		var types []string
		if err := json.Unmarshal(eventTypesJSON, &types); err != nil {
			slog.Warn("failed to decode channel event types", "error", err, "channel", ch.Name)
		} else {
			ch.EventTypes = types
		}
		channels = append(channels, ch)
	}
	// Next returning false can also mean iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list notification channels: %w", err)
	}
	return channels, nil
}
// CreateChannel validates config for channelType, inserts a new row into
// notification_channels and returns the generated ID.
//
// Validation errors are returned unchanged; database errors are wrapped.
func (s *NotificationService) CreateChannel(ctx context.Context, name, channelType string, config json.RawMessage) (int64, error) {
	if validationErr := s.ValidateChannelConfig(channelType, config); validationErr != nil {
		return 0, validationErr
	}

	const insertSQL = `INSERT INTO notification_channels (name, type, config) VALUES ($1, $2, $3) RETURNING id`

	var newID int64
	row := s.db.Pool.QueryRow(ctx, insertSQL, name, channelType, config)
	if scanErr := row.Scan(&newID); scanErr != nil {
		return 0, fmt.Errorf("create notification channel: %w", scanErr)
	}

	slog.Info("created notification channel", "name", name, "type", channelType)
	return newID, nil
}
// UpdateChannel updates a notification channel's fields.
//
// Only the arguments that are non-nil are written: name, enabled and config
// are each optional. When at least one field is set, updated_at is bumped to
// NOW(). If nothing is set the call is a no-op and returns nil.
//
// config must decode as a JSON object. Type-specific validation is not
// performed here because the channel type is not known to this method.
// An error is returned when config is malformed, the UPDATE fails, or no row
// with the given id exists.
func (s *NotificationService) UpdateChannel(ctx context.Context, id int64, name *string, enabled *bool, config json.RawMessage) error {
	// Reject malformed config before building any SQL. This is the same
	// object-shape check ValidateChannelConfig applies first.
	if config != nil {
		var probe map[string]interface{}
		if err := json.Unmarshal(config, &probe); err != nil {
			return fmt.Errorf("invalid config JSON: %w", err)
		}
	}

	qb := NewQueryBuilder(1)
	var setClauses []string
	if name != nil {
		setClauses = append(setClauses, fmt.Sprintf("name = $%d", qb.Idx()))
		qb.Add("", *name)
	}
	if enabled != nil {
		setClauses = append(setClauses, fmt.Sprintf("enabled = $%d", qb.Idx()))
		qb.Add("", *enabled)
	}
	if config != nil {
		setClauses = append(setClauses, fmt.Sprintf("config = $%d", qb.Idx()))
		qb.Add("", config)
	}
	if len(setClauses) == 0 {
		return nil
	}
	setClauses = append(setClauses, "updated_at = NOW()")

	// The id placeholder takes the next free index after the SET arguments.
	query := fmt.Sprintf("UPDATE notification_channels SET %s WHERE id = $%d",
		strings.Join(setClauses, ", "), qb.Idx())
	qb.Add("", id)

	tag, err := s.db.Pool.Exec(ctx, query, qb.Args()...)
	if err != nil {
		return fmt.Errorf("update notification channel: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return fmt.Errorf("notification channel not found: %d", id)
	}
	return nil
}
// DeleteChannel deletes the notification channel with the given id.
//
// It returns an error when the DELETE fails or when no row matched.
// NOTE(review): only the channel row is deleted here; removal of its
// subscriptions presumably relies on a cascading foreign key — confirm
// against the schema.
func (s *NotificationService) DeleteChannel(ctx context.Context, id int64) error {
	result, execErr := s.db.Pool.Exec(ctx, `DELETE FROM notification_channels WHERE id = $1`, id)
	switch {
	case execErr != nil:
		return fmt.Errorf("delete notification channel: %w", execErr)
	case result.RowsAffected() == 0:
		return fmt.Errorf("notification channel not found: %d", id)
	}

	slog.Info("deleted notification channel", "id", id)
	return nil
}
// GetChannelWithConfig loads a single channel by id and returns it with its
// config exactly as stored (no masking), which is what delivery needs.
//
// EventTypes is not populated. Any query or scan failure, including a missing
// row, is returned as a wrapped error.
func (s *NotificationService) GetChannelWithConfig(ctx context.Context, id int64) (*NotificationChannel, error) {
	const selectSQL = `SELECT id, name, type, enabled, config, created_at, updated_at FROM notification_channels WHERE id = $1`

	channel := &NotificationChannel{}
	row := s.db.Pool.QueryRow(ctx, selectSQL, id)
	scanErr := row.Scan(
		&channel.ID,
		&channel.Name,
		&channel.Type,
		&channel.Enabled,
		&channel.Config,
		&channel.CreatedAt,
		&channel.UpdatedAt,
	)
	if scanErr != nil {
		return nil, fmt.Errorf("get notification channel: %w", scanErr)
	}
	return channel, nil
}
// ListSubscriptions returns the event types that channelID is subscribed to.
//
// The slice is nil when the channel has no subscriptions. A failure while
// querying, scanning, or iterating is returned as an error instead of
// yielding a partial list.
func (s *NotificationService) ListSubscriptions(ctx context.Context, channelID int64) ([]string, error) {
	rows, err := s.db.Pool.Query(ctx,
		`SELECT event_type FROM notification_subscriptions WHERE channel_id = $1`, channelID)
	if err != nil {
		return nil, fmt.Errorf("list subscriptions: %w", err)
	}
	defer rows.Close()

	var types []string
	for rows.Next() {
		var t string
		if err := rows.Scan(&t); err != nil {
			return nil, fmt.Errorf("scan subscription: %w", err)
		}
		types = append(types, t)
	}
	// Next returning false can also mean iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list subscriptions: %w", err)
	}
	return types, nil
}
// UpdateSubscriptions replaces the full set of event subscriptions for
// channelID with eventTypes.
//
// The existing rows are deleted and the new ones inserted inside a single
// transaction, so a failure at any step leaves the previous subscriptions
// in place.
func (s *NotificationService) UpdateSubscriptions(ctx context.Context, channelID int64, eventTypes []string) error {
	const (
		clearSQL  = `DELETE FROM notification_subscriptions WHERE channel_id = $1`
		insertSQL = `INSERT INTO notification_subscriptions (channel_id, event_type) VALUES ($1, $2)`
	)

	tx, beginErr := s.db.Pool.Begin(ctx)
	if beginErr != nil {
		return fmt.Errorf("begin transaction: %w", beginErr)
	}
	// Undo everything on any early return; after a successful Commit the
	// transaction is already closed and the result is ignored.
	defer tx.Rollback(ctx)

	_, clearErr := tx.Exec(ctx, clearSQL, channelID)
	if clearErr != nil {
		return fmt.Errorf("delete subscriptions: %w", clearErr)
	}

	for i := range eventTypes {
		_, insertErr := tx.Exec(ctx, insertSQL, channelID, eventTypes[i])
		if insertErr != nil {
			return fmt.Errorf("insert subscription: %w", insertErr)
		}
	}

	commitErr := tx.Commit(ctx)
	if commitErr != nil {
		return fmt.Errorf("commit subscriptions: %w", commitErr)
	}
	return nil
}
// GetSubscribersForEvent returns the enabled channels subscribed to eventType.
//
// Configs are returned as stored (not masked) and EventTypes is left empty.
// The slice is nil when nothing is subscribed. A failure while querying,
// scanning, or iterating is returned as an error so callers never act on a
// partial subscriber list.
func (s *NotificationService) GetSubscribersForEvent(ctx context.Context, eventType string) ([]NotificationChannel, error) {
	rows, err := s.db.Pool.Query(ctx,
		`SELECT DISTINCT c.id, c.name, c.type, c.enabled, c.config, c.created_at, c.updated_at
		FROM notification_channels c
		JOIN notification_subscriptions s ON c.id = s.channel_id
		WHERE s.event_type = $1 AND c.enabled = true`, eventType)
	if err != nil {
		return nil, fmt.Errorf("get subscribers for event: %w", err)
	}
	defer rows.Close()

	var channels []NotificationChannel
	for rows.Next() {
		var ch NotificationChannel
		if err := rows.Scan(&ch.ID, &ch.Name, &ch.Type, &ch.Enabled, &ch.Config,
			&ch.CreatedAt, &ch.UpdatedAt); err != nil {
			return nil, fmt.Errorf("scan subscriber channel: %w", err)
		}
		channels = append(channels, ch)
	}
	// Next returning false can also mean iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("get subscribers for event: %w", err)
	}
	return channels, nil
}
// DeliverWebhook POSTs payload, encoded as JSON, to webhookURL.
//
// The request carries a Content-Type of application/json and is bound to ctx.
// Delivery succeeds only for a 2xx response; any other status, or a failure
// to encode, build, or send the request, is returned as an error.
func (s *NotificationService) DeliverWebhook(ctx context.Context, webhookURL string, payload map[string]interface{}) error {
	encoded, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		return fmt.Errorf("marshal webhook payload: %w", marshalErr)
	}

	request, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(encoded))
	if buildErr != nil {
		return fmt.Errorf("create webhook request: %w", buildErr)
	}
	request.Header.Set("Content-Type", "application/json")

	response, sendErr := s.http.Do(request)
	if sendErr != nil {
		return fmt.Errorf("webhook delivery failed: %w", sendErr)
	}
	defer response.Body.Close()

	if code := response.StatusCode; code >= 200 && code < 300 {
		return nil
	}
	return fmt.Errorf("webhook returned status %d", response.StatusCode)
}
// DeliverTelegram sends text to chatID through the Telegram Bot API's
// sendMessage method, using HTML parse mode.
//
// The request goes to s.telegramBaseURL and is bound to ctx. Delivery succeeds
// only for a 2xx response. Because the bot token is part of the request URL,
// errors from building or sending the request are passed through
// redactTelegramToken so the token never appears in the returned error.
func (s *NotificationService) DeliverTelegram(ctx context.Context, botToken, chatID, text string) error {
	apiURL := fmt.Sprintf("%s/bot%s/sendMessage", s.telegramBaseURL, botToken)
	payload := map[string]interface{}{
		"chat_id":    chatID,
		"text":       text,
		"parse_mode": "HTML",
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal telegram payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("create telegram request: %w", redactTelegramToken(err, botToken))
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.http.Do(req)
	if err != nil {
		return fmt.Errorf("telegram delivery failed: %w", redactTelegramToken(err, botToken))
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("telegram returned status %d", resp.StatusCode)
	}
	return nil
}

// redactTelegramToken returns err with the bot token removed from its message.
//
// A *url.Error is replaced by its underlying cause, since its message embeds
// the full request URL; the cause stays in the chain for errors.Is/As. If the
// token still appears in the message it is replaced with "***", at the cost of
// returning a plain error that no longer wraps the original.
func redactTelegramToken(err error, botToken string) error {
	if urlErr, ok := err.(*url.Error); ok && urlErr.Err != nil {
		err = urlErr.Err
	}
	if botToken != "" && strings.Contains(err.Error(), botToken) {
		return fmt.Errorf("%s", strings.ReplaceAll(err.Error(), botToken, "***"))
	}
	return err
}
// calculateBackoff returns exponential backoff: 30s * 2^attempt, capped at 480s.
//
// Negative attempts are treated as 0, so the result is always between 30s and
// 480s. The multiplier is compared against the cap before it is converted to
// a Duration, so arbitrarily large attempts cannot overflow into a zero or
// negative delay.
func calculateBackoff(attempt int) time.Duration {
	const (
		baseDelay = 30 * time.Second
		maxDelay  = 480 * time.Second
	)
	if attempt < 0 {
		attempt = 0
	}
	// factor may be huge or +Inf for large attempts; that is fine because it
	// is only compared here, never converted, once it reaches the cap.
	factor := math.Pow(2, float64(attempt))
	if factor >= float64(maxDelay/baseDelay) {
		return maxDelay
	}
	return baseDelay * time.Duration(factor)
}
// maskConfig returns a copy of raw that is safe to expose in API responses.
//
// For "telegram" channels an existing "bot_token" value is replaced with
// "***"; other channel types have no fields masked. The config is decoded
// into a map and re-encoded, so the output is the re-marshalled form of the
// input. If raw cannot be decoded or re-encoded it is returned unchanged.
func maskConfig(channelType string, raw json.RawMessage) json.RawMessage {
	var fields map[string]interface{}
	if json.Unmarshal(raw, &fields) != nil {
		return raw
	}

	if channelType == "telegram" {
		// Only overwrite a key that is already there; never add one.
		if _, present := fields["bot_token"]; present {
			fields["bot_token"] = "***"
		}
	}

	if encoded, encodeErr := json.Marshal(fields); encodeErr == nil {
		return encoded
	}
	return raw
}
// ListQueue returns one page of notification queue entries, newest first,
// together with the total number of entries matching the filter.
//
// status filters by exact queue status; an empty string disables the filter.
// page is 1-based and values below 1 are treated as 1. pageSize is passed to
// LIMIT as given. A failure while counting, querying, scanning, or iterating
// is returned as an error rather than producing a shortened page.
func (s *NotificationService) ListQueue(ctx context.Context, status string, page, pageSize int) ([]QueueEntry, int, error) {
	qb := NewQueryBuilder(1)
	if status != "" {
		qb.Add("status = $%d", status)
	}
	where := qb.Where()

	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notification_queue%s", where)
	if err := s.db.Pool.QueryRow(ctx, countQuery, qb.Args()...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count notification queue: %w", err)
	}

	// Clamp so that page 0 or a negative page cannot produce a negative
	// OFFSET, which the database rejects.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	// LIMIT and OFFSET take the next two placeholder indexes after the filter.
	dataQuery := fmt.Sprintf(
		`SELECT id, channel_id, event_type, title, message, status, attempts, max_attempts,
		last_error, next_retry_at, created_at, delivered_at
		FROM notification_queue%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d`,
		where, qb.Idx(), qb.Idx()+1)
	args := append(qb.Args(), pageSize, offset)

	rows, err := s.db.Pool.Query(ctx, dataQuery, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("list notification queue: %w", err)
	}
	defer rows.Close()

	var entries []QueueEntry
	for rows.Next() {
		var e QueueEntry
		// Nullable columns are scanned straight into the pointer fields:
		// NULL leaves the field nil, any other value is stored behind it.
		if err := rows.Scan(&e.ID, &e.ChannelID, &e.EventType, &e.Title, &e.Message,
			&e.Status, &e.Attempts, &e.MaxAttempts,
			&e.LastError, &e.NextRetryAt, &e.CreatedAt, &e.DeliveredAt); err != nil {
			return nil, 0, fmt.Errorf("scan notification queue entry: %w", err)
		}
		// Treat an empty error string the same as no error so it is omitted
		// from the JSON output.
		if e.LastError != nil && *e.LastError == "" {
			e.LastError = nil
		}
		entries = append(entries, e)
	}
	// Next returning false can also mean iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("list notification queue: %w", err)
	}
	return entries, total, nil
}
// StartDispatcher starts the two background workers, each in its own
// goroutine: the activity event poller and the queue processor. Both receive
// ctx and run until it is cancelled or StopDispatcher is called.
func (s *NotificationService) StartDispatcher(ctx context.Context) {
	workers := []func(context.Context){
		s.pollActivityEvents,
		s.processQueue,
	}
	for _, worker := range workers {
		go worker(ctx)
	}
}
// StopDispatcher signals both dispatcher goroutines to stop by closing the
// done channel.
//
// Calling it again after the dispatcher has been stopped is a no-op rather
// than a panic. It is not synchronized for simultaneous calls from multiple
// goroutines.
func (s *NotificationService) StopDispatcher() {
	select {
	case <-s.done:
		// Already closed by an earlier call; closing again would panic.
	default:
		close(s.done)
	}
}
// pollActivityEvents runs the event poller loop: every 5 seconds it calls
// pollOnce, until ctx is cancelled or the done channel is closed. Start and
// stop are logged.
func (s *NotificationService) pollActivityEvents(ctx context.Context) {
	const pollInterval = 5 * time.Second

	slog.Info("notification event poller started")
	defer slog.Info("notification event poller stopped")

	tick := time.NewTicker(pollInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			s.pollOnce(ctx)
		case <-ctx.Done():
			return
		case <-s.done:
			return
		}
	}
}
// pollOnce performs a single polling pass over activity_events.
//
// It reads the persisted cursor from notification_state, fetches up to 100
// events after that cursor ordered by (created_at, id), enqueues one
// notification_queue row per subscribed enabled channel for each event, and
// then advances the cursor. The whole pass is bounded by a 10 second timeout.
//
// The cursor only moves to the last event that was fully handled, in query
// order. If looking up subscribers fails, processing stops there so that the
// event and everything after it are retried on the next pass. If nothing was
// handled the cursor is left untouched. Errors are logged, not returned.
func (s *NotificationService) pollOnce(ctx context.Context) {
	pollCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	// Get cursor. last_event_created_at is nullable, hence the pointer.
	var lastID int64
	var lastCreatedAt *time.Time
	err := s.db.Pool.QueryRow(pollCtx,
		`SELECT last_event_id, last_event_created_at FROM notification_state WHERE id = 1`).
		Scan(&lastID, &lastCreatedAt)
	if err != nil {
		slog.Error("failed to read notification state", "error", err)
		return
	}

	type eventRow struct {
		ID        int64
		EventType string
		Title     string
		CreatedAt time.Time
	}

	var rows pgx.Rows
	if lastCreatedAt != nil {
		// Keyset pagination on the same (created_at, id) pair used for ordering.
		rows, err = s.db.Pool.Query(pollCtx,
			`SELECT id, event_type, title, created_at FROM activity_events
			WHERE (created_at, id) > ($1, $2) ORDER BY created_at, id LIMIT 100`,
			*lastCreatedAt, lastID)
	} else {
		// No timestamp stored yet: fall back to the ID alone.
		rows, err = s.db.Pool.Query(pollCtx,
			`SELECT id, event_type, title, created_at FROM activity_events
			WHERE id > $1 ORDER BY created_at, id LIMIT 100`,
			lastID)
	}
	if err != nil {
		slog.Error("failed to poll activity events", "error", err)
		return
	}
	defer rows.Close()

	var events []eventRow
	var readErr error
	for rows.Next() {
		var e eventRow
		if readErr = rows.Scan(&e.ID, &e.EventType, &e.Title, &e.CreatedAt); readErr != nil {
			break
		}
		events = append(events, e)
	}
	if readErr == nil {
		readErr = rows.Err()
	}
	if readErr != nil {
		// Events read before the failure form an in-order prefix, so they can
		// still be processed safely below.
		slog.Error("failed to read activity events", "error", readErr)
	}
	if len(events) == 0 {
		return
	}

	// For each event, find subscribers and create queue entries.
	var cursor *eventRow
	for i := range events {
		e := &events[i]
		subscribers, subErr := s.GetSubscribersForEvent(pollCtx, e.EventType)
		if subErr != nil {
			slog.Error("failed to get subscribers", "error", subErr, "event_type", e.EventType)
			// Stop here so the cursor never moves past an unhandled event.
			break
		}
		// Marshalling a map of two strings cannot fail, so the error is ignored.
		message, _ := json.Marshal(map[string]interface{}{
			"event_type": e.EventType,
			"title":      e.Title,
		})
		for _, ch := range subscribers {
			_, qErr := s.db.Pool.Exec(pollCtx,
				`INSERT INTO notification_queue (channel_id, event_type, title, message)
				VALUES ($1, $2, $3, $4)`, ch.ID, e.EventType, e.Title, message)
			if qErr != nil {
				slog.Error("failed to enqueue notification", "error", qErr, "channel", ch.Name)
			}
		}
		cursor = e
	}
	if cursor == nil {
		// Nothing was handled; writing a zero cursor would replay all history.
		return
	}

	// Update cursor.
	_, err = s.db.Pool.Exec(pollCtx,
		`UPDATE notification_state SET last_event_id = $1, last_event_created_at = $2 WHERE id = 1`,
		cursor.ID, cursor.CreatedAt)
	if err != nil {
		slog.Error("failed to update notification state", "error", err)
	}
}
// processQueue runs the queue processor loop: every 10 seconds it calls
// processQueueBatch, until ctx is cancelled or the done channel is closed.
// Start and stop are logged.
func (s *NotificationService) processQueue(ctx context.Context) {
	const batchInterval = 10 * time.Second

	slog.Info("notification queue processor started")
	defer slog.Info("notification queue processor stopped")

	tick := time.NewTicker(batchInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			s.processQueueBatch(ctx)
		case <-ctx.Done():
			return
		case <-s.done:
			return
		}
	}
}
// processQueueBatch delivers up to 50 due queue entries and records the
// outcome of each.
//
// An entry is due when its status is "pending" or "failed", its next_retry_at
// is unset or in the past, and it has attempts left. Each delivery gets up to
// 15 seconds inside a 30 second budget for the whole batch.
//
// On success the entry becomes "delivered". On failure it becomes "failed"
// with a backoff-based next_retry_at, or "dead" once its own max_attempts is
// reached. Undecodable config or message JSON and unknown channel types count
// as failures. Deliveries cut short by the batch budget or by shutdown are
// left untouched for the next pass. Errors are logged, not returned.
func (s *NotificationService) processQueueBatch(ctx context.Context) {
	batchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	rows, err := s.db.Pool.Query(batchCtx,
		`SELECT q.id, q.channel_id, q.event_type, q.title, q.message, q.status, q.attempts, q.max_attempts,
		c.name AS channel_name, c.type AS channel_type, c.config AS channel_config
		FROM notification_queue q
		JOIN notification_channels c ON q.channel_id = c.id
		WHERE q.status IN ('pending', 'failed')
		AND (q.next_retry_at IS NULL OR q.next_retry_at <= NOW())
		AND q.attempts < q.max_attempts
		LIMIT 50`)
	if err != nil {
		slog.Error("failed to query notification queue", "error", err)
		return
	}
	defer rows.Close()

	type queueItem struct {
		ID            int64
		ChannelID     int64
		EventType     string
		Title         string
		Message       json.RawMessage
		Status        string
		Attempts      int
		MaxAttempts   int
		ChannelName   string
		ChannelType   string
		ChannelConfig json.RawMessage
	}

	var items []queueItem
	var readErr error
	for rows.Next() {
		var q queueItem
		if readErr = rows.Scan(&q.ID, &q.ChannelID, &q.EventType, &q.Title, &q.Message,
			&q.Status, &q.Attempts, &q.MaxAttempts,
			&q.ChannelName, &q.ChannelType, &q.ChannelConfig); readErr != nil {
			break
		}
		items = append(items, q)
	}
	if readErr == nil {
		readErr = rows.Err()
	}
	if readErr != nil {
		// Entries read before the failure are still processed below.
		slog.Error("failed to read notification queue", "error", readErr)
	}

	// Telegram's HTML parse mode requires &, < and > in plain text to be
	// written as entities; otherwise the message is rejected or misparsed.
	escapeHTML := strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;")

	// deliver sends one entry through its channel and returns the delivery
	// error, if any.
	deliver := func(deliverCtx context.Context, q queueItem) error {
		var config map[string]interface{}
		if err := json.Unmarshal(q.ChannelConfig, &config); err != nil {
			return fmt.Errorf("decode channel config: %w", err)
		}
		var message map[string]interface{}
		if err := json.Unmarshal(q.Message, &message); err != nil {
			return fmt.Errorf("decode queued message: %w", err)
		}

		switch q.ChannelType {
		case "webhook":
			webhookURL, _ := config["url"].(string)
			return s.DeliverWebhook(deliverCtx, webhookURL, message)
		case "telegram":
			botToken, _ := config["bot_token"].(string)
			chatID, _ := config["chat_id"].(string)
			title, _ := message["title"].(string)
			text := fmt.Sprintf("<b>%s</b>\n%s",
				escapeHTML.Replace(q.EventType), escapeHTML.Replace(title))
			return s.DeliverTelegram(deliverCtx, botToken, chatID, text)
		default:
			// Without this, an unknown type would be recorded as delivered
			// even though nothing was sent.
			return fmt.Errorf("unsupported channel type %q", q.ChannelType)
		}
	}

	// updateEntry records an outcome using its own short timeout, so a result
	// obtained just before the batch deadline can still be persisted.
	updateEntry := func(query string, args ...interface{}) {
		updateCtx, updateCancel := context.WithTimeout(ctx, 5*time.Second)
		defer updateCancel()
		if _, err := s.db.Pool.Exec(updateCtx, query, args...); err != nil {
			slog.Error("failed to update queue entry", "error", err)
		}
	}

	for _, q := range items {
		// Budget exhausted or shutting down: leave the rest for the next pass.
		if batchCtx.Err() != nil {
			break
		}

		deliverCtx, deliverCancel := context.WithTimeout(batchCtx, 15*time.Second)
		deliverErr := deliver(deliverCtx, q)
		deliverCancel()

		// A failure caused by the batch context ending says nothing about the
		// remote end, so do not count it as an attempt.
		if deliverErr != nil && batchCtx.Err() != nil {
			break
		}

		newAttempts := q.Attempts + 1
		if deliverErr == nil {
			updateEntry(
				`UPDATE notification_queue SET status = 'delivered', attempts = $1, delivered_at = NOW() WHERE id = $2`,
				newAttempts, q.ID)
			slog.Info("notification delivered", "channel", q.ChannelName, "event", q.EventType)
			continue
		}

		newStatus := "failed"
		var nextRetry *time.Time
		if newAttempts >= q.MaxAttempts {
			// Use the entry's own limit, matching the attempts < max_attempts
			// filter in the query above.
			newStatus = "dead"
			slog.Warn("notification dead-lettered", "channel", q.ChannelName, "attempts", newAttempts)
		} else {
			t := time.Now().Add(calculateBackoff(newAttempts))
			nextRetry = &t
		}
		updateEntry(
			`UPDATE notification_queue SET status = $1, attempts = $2, last_error = $3, next_retry_at = $4 WHERE id = $5`,
			newStatus, newAttempts, deliverErr.Error(), nextRetry, q.ID)
		slog.Error("notification delivery failed", "channel", q.ChannelName, "type", q.ChannelType, "attempts", newAttempts)
	}
}