Implements the remediation pass described in planning/19-security-audit-2026-04-30.md. All Critical findings and the Wave 1-4 High findings are closed; PoC tests added; full backend test suite green; helm chart lints clean. Wave 1 - Auth & identity - C1 OAuth state nonce: PutOAuthState / ConsumeOAuthState (valkey, GETDEL single-use, 15min TTL); Callback rejects missing/forged/cross-provider state before token exchange. - C2 OAuth identity linking: refuse silent linking to existing user unless info.EmailVerified is true. fetchGitHubUser now consults the /user/emails endpoint for the verified flag (no more hardcoded true); fetchFacebookUser sets EmailVerified=false (FB exposes no per-email verification flag). - H1 Magic-link verify: replaced Get + MarkUsed with a single atomic UPDATE...RETURNING (ConsumeMagicLink) - TOCTOU-free. - H2 TOTP code replay: MarkTOTPCodeConsumed (valkey SET NX, 120s TTL) prevents replay of a successfully validated code; fails closed on transient store errors. - H3 Backup-code orphan: DisableTOTP now also wipes totp_backup_codes. Wave 2 - Middleware & network - C3 CORS/CSRF regex anchoring: NewCORSConfig wraps each pattern with \A...\z so substring spoofing of origins is impossible. - H4 ClientIP: server reads APP_TRUSTED_PROXIES; gin SetTrustedProxies is called explicitly (empty default = no proxy trust). - H11 Body limit + DisallowUnknownFields: BodyLimitBytes middleware (1 MiB default) wraps every request; validate.BindJSON now uses a json.Decoder with DisallowUnknownFields and rejects trailing tokens; 413 envelope on body-limit overflow. - H16 NetworkPolicy: backend.networkPolicy.enabled defaults to true; new web-networkpolicy.yaml restricts web pod ingress to nginx-gateway and egress to backend service + DNS + 443. Wave 3 - Encryption at rest - C4 TOTP secrets: CreateTOTPSecret writes encrypted secret_v2; GetTOTPSecret prefers v2 with legacy fallback. 
- C5 OAuth tokens: migration 000033 adds *_v2 columns; CreateOAuthAccount and UpdateOAuthTokens write encrypted; GetOAuthAccount reads v2 with legacy fallback. - M1 Domain separation: crypto.DeriveKeyFor(secret, purpose) replaces single-purpose DeriveKey; settings, totp, oauth each use a distinct HKDF-derived subkey. DeriveKey kept as back-compat alias for settings. Wave 4 - Input & AI safety - C6 SSRF: new pkg/safehttp refuses to dial RFC1918, loopback, link-local, ULA, multicast, unspecified, or cloud-metadata IPs; scheme allowlist (http/https). Wired into pkg/scrape, discovery LinkChecker, and imageURLReachable. NewForTesting opt-in for httptest. - H13 PromptGuard German + Unicode: NFKC + Cf-class strip pre-pass closes zero-width and full-width-homoglyph bypasses; new German rules for ignoriere/missachte/vergiss/role-escalation/prompt-exfil/verbatim; Gemma-style and pipe-delimited chat-template tokens covered; source-fence rule prevents '=== Quelle:' splice in scraped text. - H14 BudgetGate: new ai.BudgetGate interface; UsageRepo.CheckBudget reads today's SUM(estimated_cost_usd) (10s cache) and refuses calls when AI_DAILY_CAP_USD is exceeded; GeminiProvider.Chat checks the gate before contacting Gemini. OAuth routes remain disabled in server/routes.go, so C1/C2 are not actively reachable today; fixes ensure correctness when re-enabled.
204 lines
6.2 KiB
Go
204 lines
6.2 KiB
Go
package settings
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"marktvogt.de/backend/internal/pkg/ai"
|
|
)
|
|
|
|
// UsageRepo persists and queries AI call records.
|
|
type UsageRepo struct {
|
|
db *pgxpool.Pool
|
|
|
|
// budget caching (audit H14): the daily-cap check runs on every AI call,
|
|
// so we cache today's SUM(estimated_cost_usd) for capCacheTTL to avoid a
|
|
// hot Postgres path under bursts.
|
|
capUSD float64
|
|
capCacheTTL time.Duration
|
|
capCacheMu sync.RWMutex
|
|
cachedCost float64
|
|
cachedAtUnix int64
|
|
}
|
|
|
|
func NewUsageRepo(db *pgxpool.Pool) *UsageRepo {
|
|
return &UsageRepo{db: db, capCacheTTL: 10 * time.Second}
|
|
}
|
|
|
|
// SetDailyCap configures the per-day AI spend cap in USD. Zero disables the
|
|
// gate. Audit H14.
|
|
func (r *UsageRepo) SetDailyCap(usd float64) {
|
|
r.capCacheMu.Lock()
|
|
defer r.capCacheMu.Unlock()
|
|
r.capUSD = usd
|
|
}
|
|
|
|
// CheckBudget refuses calls when today's spend exceeds the configured cap.
|
|
// Implements ai.BudgetGate. The daily window is calendar-day in UTC.
|
|
func (r *UsageRepo) CheckBudget(ctx context.Context) error {
|
|
r.capCacheMu.RLock()
|
|
limit := r.capUSD
|
|
cached := r.cachedCost
|
|
cachedAt := r.cachedAtUnix
|
|
ttl := r.capCacheTTL
|
|
r.capCacheMu.RUnlock()
|
|
if limit <= 0 {
|
|
return nil
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
if now-cachedAt < int64(ttl.Seconds()) {
|
|
if cached >= limit {
|
|
return &ai.ProviderError{
|
|
Code: ai.ErrBudgetExceeded,
|
|
Message: fmt.Sprintf("daily AI budget exceeded: %.4f >= %.4f USD", cached, limit),
|
|
Retryable: false,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
stats, err := r.Today(ctx)
|
|
if err != nil {
|
|
// Fail open on transient stat errors — refusing all AI calls because
|
|
// Postgres briefly hiccuped is a worse outcome than letting one
|
|
// over-cap call through. The same call's Record will catch up the
|
|
// counter on the next check. The error is logged so an operator can
|
|
// still notice when the gate is silently bypassed.
|
|
slog.Warn("budget gate: today query failed; allowing request", "error", err)
|
|
return nil
|
|
}
|
|
|
|
r.capCacheMu.Lock()
|
|
r.cachedCost = stats.EstimatedCostUSD
|
|
r.cachedAtUnix = now
|
|
r.capCacheMu.Unlock()
|
|
|
|
if stats.EstimatedCostUSD >= limit {
|
|
return &ai.ProviderError{
|
|
Code: ai.ErrBudgetExceeded,
|
|
Message: fmt.Sprintf("daily AI budget exceeded: %.4f >= %.4f USD", stats.EstimatedCostUSD, limit),
|
|
Retryable: false,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Record writes a single usage event — implements ai.UsageRecorder.
|
|
func (r *UsageRepo) Record(ctx context.Context, e ai.UsageEvent) error {
|
|
var errStr *string
|
|
if e.Error != "" {
|
|
errStr = &e.Error
|
|
}
|
|
var promptVersion *string
|
|
if e.PromptVersion != "" {
|
|
promptVersion = &e.PromptVersion
|
|
}
|
|
_, err := r.db.Exec(ctx, `
|
|
INSERT INTO ai_usage
|
|
(provider, model, call_type, input_tokens, output_tokens, thinking_tokens,
|
|
grounded, duration_ms, estimated_cost_usd, error, prompt_version)
|
|
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
|
|
`, e.Provider, e.Model, e.CallType, e.InputTokens, e.OutputTokens, e.ThinkingTokens,
|
|
e.Grounded, e.DurationMs, e.EstimatedCostUSD, errStr, promptVersion)
|
|
if err != nil {
|
|
return fmt.Errorf("usage: record: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UsageStats holds the aggregate figures statsWindow computes over the
// ai_usage rows inside one time window. Field order is kept stable because it
// determines the order of keys in the JSON encoding.
type UsageStats struct {
	// Calls is COUNT(*) over the window.
	Calls int `json:"calls"`
	// InputTokens, OutputTokens and ThinkingTokens are the column sums, with
	// an empty window reported as 0.
	InputTokens    int `json:"input_tokens"`
	OutputTokens   int `json:"output_tokens"`
	ThinkingTokens int `json:"thinking_tokens"`
	// GroundingCalls counts the rows whose grounded flag is set.
	GroundingCalls int `json:"grounding_calls"`
	// EstimatedCostUSD is SUM(estimated_cost_usd), 0 for an empty window.
	EstimatedCostUSD float64 `json:"estimated_cost_usd"`
}
|
|
|
|
func (r *UsageRepo) Today(ctx context.Context) (UsageStats, error) {
|
|
return r.statsWindow(ctx, "1 day")
|
|
}
|
|
|
|
func (r *UsageRepo) Month(ctx context.Context) (UsageStats, error) {
|
|
return r.statsWindow(ctx, "30 days")
|
|
}
|
|
|
|
func (r *UsageRepo) GroundingToday(ctx context.Context) (int, error) {
|
|
row := r.db.QueryRow(ctx, `
|
|
SELECT COUNT(*) FROM ai_usage
|
|
WHERE grounded AND created_at >= now() - INTERVAL '1 day'
|
|
`)
|
|
var n int
|
|
return n, row.Scan(&n)
|
|
}
|
|
|
|
func (r *UsageRepo) statsWindow(ctx context.Context, interval string) (UsageStats, error) {
|
|
row := r.db.QueryRow(ctx, fmt.Sprintf(`
|
|
SELECT
|
|
COUNT(*) AS calls,
|
|
COALESCE(SUM(input_tokens),0) AS input_tokens,
|
|
COALESCE(SUM(output_tokens),0) AS output_tokens,
|
|
COALESCE(SUM(thinking_tokens),0) AS thinking_tokens,
|
|
COALESCE(SUM(CASE WHEN grounded THEN 1 ELSE 0 END),0) AS grounding_calls,
|
|
COALESCE(SUM(estimated_cost_usd),0) AS cost
|
|
FROM ai_usage
|
|
WHERE created_at >= now() - INTERVAL '%s'
|
|
`, interval))
|
|
var s UsageStats
|
|
if err := row.Scan(&s.Calls, &s.InputTokens, &s.OutputTokens, &s.ThinkingTokens, &s.GroundingCalls, &s.EstimatedCostUSD); err != nil {
|
|
return s, fmt.Errorf("usage: stats(%s): %w", interval, err)
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// UsageEvent is one ai_usage row as read back by Recent. Field order is kept
// stable because it determines the order of keys in the JSON encoding.
type UsageEvent struct {
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`

	// Which provider, model and kind of call produced the row.
	Provider string `json:"provider"`
	Model    string `json:"model"`
	CallType string `json:"call_type"`

	// Token counts recorded for the call.
	InputTokens    int `json:"input_tokens"`
	OutputTokens   int `json:"output_tokens"`
	ThinkingTokens int `json:"thinking_tokens"`

	Grounded         bool    `json:"grounded"`
	DurationMs       int     `json:"duration_ms"`
	EstimatedCostUSD float64 `json:"estimated_cost_usd"`

	// Error and PromptVersion are pointers and carry omitempty, so a nil value
	// is left out of the JSON (presumably nil corresponds to a NULL column).
	Error         *string `json:"error,omitempty"`
	PromptVersion *string `json:"prompt_version,omitempty"`
}
|
|
|
|
func (r *UsageRepo) Recent(ctx context.Context, limit int) ([]UsageEvent, error) {
|
|
rows, err := r.db.Query(ctx, `
|
|
SELECT id, created_at, provider, model, call_type,
|
|
input_tokens, output_tokens, thinking_tokens, grounded, duration_ms,
|
|
estimated_cost_usd, error, prompt_version
|
|
FROM ai_usage
|
|
ORDER BY created_at DESC
|
|
LIMIT $1
|
|
`, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("usage: recent: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []UsageEvent
|
|
for rows.Next() {
|
|
var e UsageEvent
|
|
if err := rows.Scan(&e.ID, &e.CreatedAt, &e.Provider, &e.Model, &e.CallType,
|
|
&e.InputTokens, &e.OutputTokens, &e.ThinkingTokens, &e.Grounded, &e.DurationMs,
|
|
&e.EstimatedCostUSD, &e.Error, &e.PromptVersion); err != nil {
|
|
return nil, fmt.Errorf("usage: scan: %w", err)
|
|
}
|
|
out = append(out, e)
|
|
}
|
|
return out, rows.Err()
|
|
}
|