Files
vikingowl 5821547a73 feat(security): close audit waves 1-4 (C1-C6, H1, H2, H4, H11, H13, H14, H16)
Implements the remediation pass described in
planning/19-security-audit-2026-04-30.md. All Critical findings and the
Wave 1-4 High findings are closed; PoC tests added; full backend test
suite green; helm chart lints clean.

Wave 1 - Auth & identity
- C1 OAuth state nonce: PutOAuthState / ConsumeOAuthState (valkey,
  GETDEL single-use, 15min TTL); Callback rejects missing/forged/cross-
  provider state before token exchange.
- C2 OAuth identity linking: refuse silent linking to existing user
  unless info.EmailVerified is true. fetchGitHubUser now consults the
  /user/emails endpoint for the verified flag (no more hardcoded true);
  fetchFacebookUser sets EmailVerified=false (FB exposes no per-email
  verification flag).
- H1 Magic-link verify: replaced Get + MarkUsed with a single atomic
  UPDATE...RETURNING (ConsumeMagicLink) - TOCTOU-free.
- H2 TOTP code replay: MarkTOTPCodeConsumed (valkey SET NX, 120s TTL)
  prevents replay of a successfully validated code; fails closed on
  transient store errors.
- H3 Backup-code orphan: DisableTOTP now also wipes totp_backup_codes.

Wave 2 - Middleware & network
- C3 CORS/CSRF regex anchoring: NewCORSConfig wraps each pattern with
  \A...\z so substring spoofing of origins is impossible.
- H4 ClientIP: server reads APP_TRUSTED_PROXIES; gin SetTrustedProxies
  is called explicitly (empty default = no proxy trust).
- H11 Body limit + DisallowUnknownFields: BodyLimitBytes middleware
  (1 MiB default) wraps every request; validate.BindJSON now uses a
  json.Decoder with DisallowUnknownFields and rejects trailing tokens;
  413 envelope on body-limit overflow.
- H16 NetworkPolicy: backend.networkPolicy.enabled defaults to true;
  new web-networkpolicy.yaml restricts web pod ingress to nginx-gateway
  and egress to backend service + DNS + 443.

Wave 3 - Encryption at rest
- C4 TOTP secrets: CreateTOTPSecret writes encrypted secret_v2;
  GetTOTPSecret prefers v2 with legacy fallback.
- C5 OAuth tokens: migration 000033 adds *_v2 columns; CreateOAuthAccount
  and UpdateOAuthTokens write encrypted; GetOAuthAccount reads v2 with
  legacy fallback.
- M1 Domain separation: crypto.DeriveKeyFor(secret, purpose) replaces
  single-purpose DeriveKey; settings, totp, oauth each use a distinct
  HKDF-derived subkey. DeriveKey kept as back-compat alias for settings.

Wave 4 - Input & AI safety
- C6 SSRF: new pkg/safehttp refuses to dial RFC1918, loopback, link-
  local, ULA, multicast, unspecified, or cloud-metadata IPs; scheme
  allowlist (http/https). Wired into pkg/scrape, discovery LinkChecker,
  and imageURLReachable. NewForTesting opt-in for httptest.
- H13 PromptGuard German + Unicode: NFKC + Cf-class strip pre-pass
  closes zero-width and full-width-homoglyph bypasses; new German rules
  for ignoriere/missachte/vergiss/role-escalation/prompt-exfil/verbatim;
  Gemma-style and pipe-delimited chat-template tokens covered;
  source-fence rule prevents '=== Quelle:' splice in scraped text.
- H14 BudgetGate: new ai.BudgetGate interface; UsageRepo.CheckBudget
  reads today's SUM(estimated_cost_usd) (10s cache) and refuses calls
  when AI_DAILY_CAP_USD is exceeded; GeminiProvider.Chat checks the
  gate before contacting Gemini.

OAuth routes remain disabled in server/routes.go, so C1/C2 are not
actively reachable today; fixes ensure correctness when re-enabled.
2026-04-30 23:41:48 +02:00

204 lines
6.2 KiB
Go

package settings
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"marktvogt.de/backend/internal/pkg/ai"
)
// UsageRepo persists and queries AI call records stored in the ai_usage
// table. It also acts as the AI spend gate: CheckBudget implements
// ai.BudgetGate.
type UsageRepo struct {
	db *pgxpool.Pool
	// budget caching (audit H14): the daily-cap check runs on every AI call,
	// so we cache the spend total reported by Today (SUM(estimated_cost_usd)
	// over a rolling one-day window) for capCacheTTL to avoid a hot Postgres
	// path under bursts.
	//
	// capUSD, cachedCost and cachedAtUnix are read and written under
	// capCacheMu; capCacheTTL is set once by NewUsageRepo.

	// capUSD is the spend cap in USD; any value <= 0 disables the gate.
	capUSD float64
	// capCacheTTL is how long cachedCost is trusted before re-querying.
	capCacheTTL time.Duration
	capCacheMu  sync.RWMutex
	// cachedCost is the most recently fetched spend total.
	cachedCost float64
	// cachedAtUnix is the Unix time (seconds) of that fetch; zero until the
	// first refresh, which forces an immediate query.
	cachedAtUnix int64
}
// NewUsageRepo builds a UsageRepo on top of the given connection pool.
// The budget gate starts disabled (no cap set); once a cap is configured
// via SetDailyCap, the spend total is cached for 10 seconds between queries.
func NewUsageRepo(db *pgxpool.Pool) *UsageRepo {
	const defaultCapCacheTTL = 10 * time.Second
	repo := &UsageRepo{db: db}
	repo.capCacheTTL = defaultCapCacheTTL
	return repo
}
// SetDailyCap sets the AI spend cap, in USD, that CheckBudget enforces.
// Passing zero (or any non-positive amount) switches the gate off.
// Safe to call concurrently with CheckBudget. Audit H14.
func (r *UsageRepo) SetDailyCap(usd float64) {
	mu := &r.capCacheMu
	mu.Lock()
	r.capUSD = usd
	mu.Unlock()
}
// CheckBudget refuses AI calls once recent spend has reached the configured
// cap. It implements ai.BudgetGate (audit H14).
//
// Spend is the EstimatedCostUSD reported by Today: a sum over a rolling
// one-day window ending at the database's now(). It is NOT reset at a
// calendar-day boundary. The total is cached for capCacheTTL, tracked at
// whole-second resolution. Spend recorded after the last refresh is therefore
// not seen until the cache expires.
//
// It returns nil in three cases:
//   - the gate is disabled (cap <= 0);
//   - spend is below the cap;
//   - the stats query fails (fail-open, see below).
//
// Otherwise it returns a non-retryable *ai.ProviderError with Code
// ai.ErrBudgetExceeded.
func (r *UsageRepo) CheckBudget(ctx context.Context) error {
	// Snapshot all cache state under a single read lock so the decisions
	// below work from one consistent view.
	r.capCacheMu.RLock()
	limit := r.capUSD
	cached := r.cachedCost
	cachedAt := r.cachedAtUnix
	ttl := r.capCacheTTL
	r.capCacheMu.RUnlock()

	if limit <= 0 {
		return nil
	}

	// checkSpent turns a spend total into the gate decision. It is shared by
	// the cached and freshly-queried paths so both report identically.
	checkSpent := func(spent float64) error {
		if spent >= limit {
			return &ai.ProviderError{
				Code:      ai.ErrBudgetExceeded,
				Message:   fmt.Sprintf("daily AI budget exceeded: %.4f >= %.4f USD", spent, limit),
				Retryable: false,
			}
		}
		return nil
	}

	now := time.Now().Unix()
	if now-cachedAt < int64(ttl.Seconds()) {
		return checkSpent(cached)
	}

	stats, err := r.Today(ctx)
	if err != nil {
		// Fail open on transient stat errors. Refusing all AI calls because
		// Postgres briefly hiccuped is a worse outcome than letting one
		// over-cap call through. The failure is not cached, so the next call
		// re-queries and picks up any spend recorded in the meantime. The error
		// is logged so an operator can notice when the gate is being bypassed.
		slog.WarnContext(ctx, "budget gate: today query failed; allowing request", "error", err)
		return nil
	}

	r.capCacheMu.Lock()
	r.cachedCost = stats.EstimatedCostUSD
	r.cachedAtUnix = now
	r.capCacheMu.Unlock()

	return checkSpent(stats.EstimatedCostUSD)
}
// Record stores one AI call as a row in ai_usage; it implements
// ai.UsageRecorder. Empty Error and PromptVersion strings are written as
// SQL NULL rather than as empty text.
func (r *UsageRepo) Record(ctx context.Context, e ai.UsageEvent) error {
	// nullable maps "" to a nil pointer (stored as NULL) and otherwise passes
	// the pointer through unchanged.
	nullable := func(s *string) *string {
		if *s == "" {
			return nil
		}
		return s
	}

	const insertUsageSQL = `
INSERT INTO ai_usage
(provider, model, call_type, input_tokens, output_tokens, thinking_tokens,
grounded, duration_ms, estimated_cost_usd, error, prompt_version)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
`
	if _, err := r.db.Exec(ctx, insertUsageSQL,
		e.Provider, e.Model, e.CallType,
		e.InputTokens, e.OutputTokens, e.ThinkingTokens,
		e.Grounded, e.DurationMs, e.EstimatedCostUSD,
		nullable(&e.Error), nullable(&e.PromptVersion),
	); err != nil {
		return fmt.Errorf("usage: record: %w", err)
	}
	return nil
}
// UsageStats is a rollup of ai_usage rows over a time window, as returned by
// Today and Month. Every field is 0 when the window contains no rows.
type UsageStats struct {
	// Calls is the number of ai_usage rows in the window.
	Calls          int `json:"calls"`
	InputTokens    int `json:"input_tokens"`
	OutputTokens   int `json:"output_tokens"`
	ThinkingTokens int `json:"thinking_tokens"`
	// GroundingCalls counts rows whose grounded flag is true.
	GroundingCalls int `json:"grounding_calls"`
	// EstimatedCostUSD is SUM(estimated_cost_usd) over the window.
	EstimatedCostUSD float64 `json:"estimated_cost_usd"`
}
// Today returns usage totals for rows created within the last "1 day"
// interval of the database's now(). This is a rolling window, not the
// current calendar day.
func (r *UsageRepo) Today(ctx context.Context) (UsageStats, error) {
	const window = "1 day"
	return r.statsWindow(ctx, window)
}
// Month returns usage totals for rows created within the last "30 days"
// interval of the database's now(). This is a rolling window, not the
// current calendar month.
func (r *UsageRepo) Month(ctx context.Context) (UsageStats, error) {
	const window = "30 days"
	return r.statsWindow(ctx, window)
}
// GroundingToday counts grounded AI calls created within the last day.
// The window is created_at >= now() - INTERVAL '1 day', a rolling window
// matching Today rather than the current calendar day.
func (r *UsageRepo) GroundingToday(ctx context.Context) (int, error) {
	var n int
	// Scan before building the return values. `return n, row.Scan(&n)` leaves
	// the relative evaluation order of n and the Scan call unspecified (Go
	// spec), so it could legally return the pre-scan zero value.
	err := r.db.QueryRow(ctx, `
SELECT COUNT(*) FROM ai_usage
WHERE grounded AND created_at >= now() - INTERVAL '1 day'
`).Scan(&n)
	if err != nil {
		return 0, fmt.Errorf("usage: grounding today: %w", err)
	}
	return n, nil
}
// statsWindow aggregates ai_usage rows with created_at inside the trailing
// interval (a PostgreSQL interval string such as "1 day" or "30 days")
// relative to the database's now().
//
// interval is sent as a bind parameter and cast to interval server-side
// instead of being formatted into the SQL text. A malformed or hostile value
// therefore surfaces as a query error and can never alter the statement, and
// the SQL text stays constant across windows.
//
// On error the zero UsageStats is returned.
func (r *UsageRepo) statsWindow(ctx context.Context, interval string) (UsageStats, error) {
	const statsSQL = `
SELECT
COUNT(*) AS calls,
COALESCE(SUM(input_tokens),0) AS input_tokens,
COALESCE(SUM(output_tokens),0) AS output_tokens,
COALESCE(SUM(thinking_tokens),0) AS thinking_tokens,
COALESCE(SUM(CASE WHEN grounded THEN 1 ELSE 0 END),0) AS grounding_calls,
COALESCE(SUM(estimated_cost_usd),0) AS cost
FROM ai_usage
WHERE created_at >= now() - $1::text::interval
`
	var s UsageStats
	if err := r.db.QueryRow(ctx, statsSQL, interval).Scan(
		&s.Calls, &s.InputTokens, &s.OutputTokens, &s.ThinkingTokens,
		&s.GroundingCalls, &s.EstimatedCostUSD,
	); err != nil {
		return UsageStats{}, fmt.Errorf("usage: stats(%s): %w", interval, err)
	}
	return s, nil
}
// UsageEvent is a single row read back from ai_usage (see Recent). It is
// distinct from ai.UsageEvent, which is the value Record writes.
type UsageEvent struct {
	ID               int64     `json:"id"`
	CreatedAt        time.Time `json:"created_at"`
	Provider         string    `json:"provider"`
	Model            string    `json:"model"`
	CallType         string    `json:"call_type"`
	InputTokens      int       `json:"input_tokens"`
	OutputTokens     int       `json:"output_tokens"`
	ThinkingTokens   int       `json:"thinking_tokens"`
	Grounded         bool      `json:"grounded"`
	DurationMs       int       `json:"duration_ms"`
	EstimatedCostUSD float64   `json:"estimated_cost_usd"`
	// Error and PromptVersion are nil when the stored column is NULL.
	// Record writes NULL whenever the corresponding string was empty.
	Error         *string `json:"error,omitempty"`
	PromptVersion *string `json:"prompt_version,omitempty"`
}
// Recent returns up to limit ai_usage rows, newest first by created_at.
// A nil slice with a nil error means the table has no rows (or limit is 0).
// On any error the slice is nil; a partially read result is never returned.
func (r *UsageRepo) Recent(ctx context.Context, limit int) ([]UsageEvent, error) {
	rows, err := r.db.Query(ctx, `
SELECT id, created_at, provider, model, call_type,
input_tokens, output_tokens, thinking_tokens, grounded, duration_ms,
estimated_cost_usd, error, prompt_version
FROM ai_usage
ORDER BY created_at DESC
LIMIT $1
`, limit)
	if err != nil {
		return nil, fmt.Errorf("usage: recent: %w", err)
	}
	defer rows.Close()

	var out []UsageEvent
	for rows.Next() {
		var e UsageEvent
		if err := rows.Scan(&e.ID, &e.CreatedAt, &e.Provider, &e.Model, &e.CallType,
			&e.InputTokens, &e.OutputTokens, &e.ThinkingTokens, &e.Grounded, &e.DurationMs,
			&e.EstimatedCostUSD, &e.Error, &e.PromptVersion); err != nil {
			return nil, fmt.Errorf("usage: scan: %w", err)
		}
		out = append(out, e)
	}
	// rows.Err reports any error that stopped iteration early. Without this
	// check a truncated list would look complete to the caller.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("usage: recent: %w", err)
	}
	return out, nil
}