From 0997d4befa96fccc0325b6220d505d978276a15b Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sun, 26 Apr 2026 11:54:37 +0200 Subject: [PATCH] feat(auth): D1 non-breaking security foundations - CORS: rewrite middleware with Vary: Origin, regex origin patterns, startup validation, and prod boot-fail on empty allowlist; shared CORSConfig exported for CSRF reuse - CSRF: new Origin/Referer check middleware sharing CORS allowlist; Bearer-token clients exempt; mounts globally after CORS - Argon2id: new password package with PHC format, bcrypt dispatch, and NeedsRehash; lazy upgrade on login in auth service - Rate limiting: add RateLimitByKey with custom key function; apply per-route limits to /auth/login, /refresh, /2fa/verify, /auth/magic-link, and /auth/password - apierror: add CSRFMismatch and RefreshReuse error constructors - Migrations: 000027 (session model schema columns for D2/D3), 000028 (TOTP secret_v2 column + totp_backup_codes table) - cmd/totp-encrypt: one-shot job to encrypt existing TOTP secrets --- backend/cmd/totp-encrypt/main.go | 127 ++++++++++++++ backend/internal/config/config.go | 46 +++++- backend/internal/domain/auth/magiclink.go | 4 +- backend/internal/domain/auth/routes.go | 10 +- backend/internal/domain/auth/service.go | 26 ++- backend/internal/middleware/cors.go | 77 +++++++-- backend/internal/middleware/cors_test.go | 155 ++++++++++++++++++ backend/internal/middleware/csrf.go | 50 ++++++ backend/internal/middleware/csrf_test.go | 141 ++++++++++++++++ backend/internal/middleware/ratelimit.go | 14 +- backend/internal/middleware/ratelimit_test.go | 79 +++++++++ backend/internal/pkg/apierror/error.go | 8 + backend/internal/pkg/password/password.go | 94 +++++++++++ .../internal/pkg/password/password_test.go | 91 ++++++++++ backend/internal/server/routes.go | 18 +- backend/internal/server/server.go | 6 +- ...027_sessions_session_model_rework.down.sql | 12 ++ ...00027_sessions_session_model_rework.up.sql | 16 ++ ..._encrypt_secrets_and_backup_codes.down.sql | 4 + ...tp_encrypt_secrets_and_backup_codes.up.sql | 21 +++ 20 files changed, 957 insertions(+), 42 deletions(-) create mode 100644 backend/cmd/totp-encrypt/main.go create mode 100644 backend/internal/middleware/cors_test.go create mode 100644 backend/internal/middleware/csrf.go create mode 100644 backend/internal/middleware/csrf_test.go create mode 100644 backend/internal/middleware/ratelimit_test.go create mode 100644 backend/internal/pkg/password/password.go create mode 100644 backend/internal/pkg/password/password_test.go create mode 100644 backend/migrations/000027_sessions_session_model_rework.down.sql create mode 100644 backend/migrations/000027_sessions_session_model_rework.up.sql create mode 100644 backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.down.sql create mode 100644 backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.up.sql diff --git a/backend/cmd/totp-encrypt/main.go b/backend/cmd/totp-encrypt/main.go new file mode 100644 index 0000000..65a50f2 --- /dev/null +++ b/backend/cmd/totp-encrypt/main.go @@ -0,0 +1,127 @@ +// totp-encrypt is a one-shot migration job that encrypts plaintext TOTP secrets +// in the totp_secrets table. Run once after deploying migration 000028. +// Set TOTP_ENCRYPTION_KEY (32-byte AES key, base64-encoded) before running. +// +// Usage: TOTP_ENCRYPTION_KEY= DB_HOST=... ./totp-encrypt [--dry-run] +package main + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "marktvogt.de/backend/internal/pkg/crypto" +) + +func main() { + if err := run(); err != nil { + slog.Error("totp-encrypt failed", "error", err) + os.Exit(1) + } +} + +func run() error { + dryRun := len(os.Args) > 1 && os.Args[1] == "--dry-run" + + keyB64 := os.Getenv("TOTP_ENCRYPTION_KEY") + if keyB64 == "" { + return fmt.Errorf("TOTP_ENCRYPTION_KEY is required") + } + keyBytes, err := base64.StdEncoding.DecodeString(keyB64) + if err != nil { + return fmt.Errorf("TOTP_ENCRYPTION_KEY: invalid base64: %w", err) + } + if len(keyBytes) != 32 { + return fmt.Errorf("TOTP_ENCRYPTION_KEY must be exactly 32 bytes (got %d)", len(keyBytes)) + } + var key [32]byte + copy(key[:], keyBytes) + + dsn := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s?sslmode=%s", + getenv("DB_USER", "marktvogt"), + getenv("DB_PASSWORD", "marktvogt"), + getenv("DB_HOST", "localhost"), + getenv("DB_PORT", "5432"), + getenv("DB_NAME", "marktvogt"), + getenv("DB_SSLMODE", "disable"), + ) + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + return fmt.Errorf("connecting to database: %w", err) + } + defer pool.Close() + + rows, err := pool.Query(ctx, + `SELECT id, secret FROM totp_secrets WHERE secret_v2 IS NULL AND secret != ''`) + if err != nil { + return fmt.Errorf("querying totp_secrets: %w", err) + } + + type row struct { + id string + secret string + } + var pending []row + for rows.Next() { + var r row + if scanErr := rows.Scan(&r.id, &r.secret); scanErr != nil { + return fmt.Errorf("scanning row: %w", scanErr) + } + pending = append(pending, r) + } + rows.Close() + + slog.Info("rows to encrypt", "count", len(pending), "dry_run", dryRun) + + if dryRun { + slog.Info("dry run — no changes written") + return nil + } + + ok, failed := 0, 0 + for _, r := range pending { + ciphertext, sealErr := crypto.Seal(key, []byte(r.secret)) + if sealErr != nil { + slog.Error("encrypting secret", "id", r.id, "error", sealErr) + failed++ + continue + } + encoded := "v1:" + base64.StdEncoding.EncodeToString(ciphertext) + + _, execErr := pool.Exec(ctx, + `UPDATE totp_secrets SET secret_v2 = $1 WHERE id = $2`, + encoded, r.id) + if execErr != nil { + if errors.Is(execErr, pgx.ErrNoRows) { + slog.Warn("row disappeared during migration", "id", r.id) + } else { + slog.Error("updating secret_v2", "id", r.id, "error", execErr) + failed++ + } + continue + } + ok++ + } + + slog.Info("encryption complete", "encrypted", ok, "failed", failed) + if failed > 0 { + return fmt.Errorf("%d rows failed to encrypt", failed) + } + return nil +} + +func getenv(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4453e50..3ec0f0e 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -4,7 +4,9 @@ import ( "fmt" "log/slog" "os" + "regexp" "strconv" + "strings" "time" ) @@ -85,7 +87,8 @@ type JWTConfig struct { } type CORSConfig struct { - Origins string + Origins []string + OriginPatterns []string } type RateConfig struct { @@ -132,7 +135,26 @@ type NotificationConfig struct { FrontendURL string } +const envDevelopment = "development" + func Load() (*Config, error) { + appEnv := envStr("APP_ENV", envDevelopment) + + corsOrigins := envStrSlice("CORS_ORIGINS") + if len(corsOrigins) == 0 && appEnv == envDevelopment { + corsOrigins = []string{"http://localhost:5173"} + } + corsPatterns := envStrSlice("CORS_ORIGIN_PATTERNS") + if len(corsOrigins) == 0 && len(corsPatterns) == 0 && appEnv != envDevelopment { + return nil, fmt.Errorf("CORS_ORIGINS or CORS_ORIGIN_PATTERNS is required in non-dev environments") + } + for _, p := range corsPatterns { + if _, err := regexp.Compile(p); err != nil { + return nil, fmt.Errorf("CORS_ORIGIN_PATTERNS: invalid regex %q: %w", p, err) + } + } + slog.Info("cors allowlist loaded", "exact_origins", corsOrigins, "pattern_count", len(corsPatterns)) + port, err := envInt("APP_PORT", 8080) if err != nil { return nil, fmt.Errorf("APP_PORT: %w", err) @@ -205,7 +227,7 @@ func Load() (*Config, error) { return &Config{ App: AppConfig{ - Env: envStr("APP_ENV", "development"), + Env: appEnv, Host: envStr("APP_HOST", "0.0.0.0"), Port: port, }, @@ -230,7 +252,8 @@ func Load() (*Config, error) { SessionTTL: sessionTTL, }, CORS: CORSConfig{ - Origins: envStr("CORS_ORIGINS", "http://localhost:5173"), + Origins: corsOrigins, + OriginPatterns: corsPatterns, }, Rate: RateConfig{ RPS: rps, @@ -297,7 +320,7 @@ func (c *Config) Addr() string { } func (c *Config) IsDev() bool { - return c.App.Env == "development" + return c.App.Env == envDevelopment } func envStr(key, fallback string) string { @@ -331,6 +354,21 @@ func envFloat(key string, fallback float64) (float64, error) { return f, nil } +func envStrSlice(key string) []string { + v := os.Getenv(key) + if v == "" { + return nil + } + var result []string + for _, s := range strings.Split(v, ",") { + s = strings.TrimSpace(s) + if s != "" { + result = append(result, s) + } + } + return result +} + func envDuration(key string, fallback time.Duration) (time.Duration, error) { v := os.Getenv(key) if v == "" { diff --git a/backend/internal/domain/auth/magiclink.go b/backend/internal/domain/auth/magiclink.go index f6ad1a4..a58730e 100644 --- a/backend/internal/domain/auth/magiclink.go +++ b/backend/internal/domain/auth/magiclink.go @@ -168,7 +168,7 @@ func (h *MagicLinkHandler) findOrCreateUser(ctx context.Context, email string) ( return h.userRepo.CreateOAuthUser(ctx, email, user.GenerateDisplayName(), true) } -func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler) { - rg.POST("/auth/magic-link", h.RequestMagicLink) +func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler, requestLimit gin.HandlerFunc) { + rg.POST("/auth/magic-link", requestLimit, h.RequestMagicLink) rg.GET("/auth/magic-link/verify", h.VerifyMagicLink) } diff --git a/backend/internal/domain/auth/routes.go b/backend/internal/domain/auth/routes.go index daf5d49..91427d9 100644 --- a/backend/internal/domain/auth/routes.go +++ b/backend/internal/domain/auth/routes.go @@ -2,20 +2,20 @@ package auth import "github.com/gin-gonic/gin" -func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth gin.HandlerFunc) { +func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit gin.HandlerFunc) { auth := rg.Group("/auth") { auth.POST("/register", h.Register) - auth.POST("/login", h.Login) + auth.POST("/login", loginLimit, h.Login) auth.POST("/logout", requireAuth, h.Logout) - auth.POST("/refresh", h.Refresh) + auth.POST("/refresh", refreshLimit, h.Refresh) // Password - auth.PUT("/password", requireAuth, h.ChangePassword) + auth.PUT("/password", requireAuth, passwordLimit, h.ChangePassword) // 2FA auth.POST("/2fa/setup", requireAuth, h.SetupTOTP) - auth.POST("/2fa/verify", requireAuth, h.VerifyTOTP) + auth.POST("/2fa/verify", requireAuth, twoFALimit, h.VerifyTOTP) auth.DELETE("/2fa", requireAuth, h.DisableTOTP) } } diff --git a/backend/internal/domain/auth/service.go b/backend/internal/domain/auth/service.go index 340f9a6..6c3e1f5 100644 --- a/backend/internal/domain/auth/service.go +++ b/backend/internal/domain/auth/service.go @@ -4,12 +4,13 @@ import ( "context" "errors" "fmt" + "log/slog" "time" "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" "marktvogt.de/backend/internal/domain/user" + "marktvogt.de/backend/internal/pkg/password" ) type Service struct { @@ -29,12 +30,12 @@ func NewService(authRepo Repository, userRepo user.Repository, tokenSvc *TokenSe } func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua string) (AuthData, error) { - hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + hash, err := password.Hash(req.Password) if err != nil { return AuthData{}, fmt.Errorf("hashing password: %w", err) } - u, err := s.userRepo.Create(ctx, req.Email, string(hash), req.DisplayName) + u, err := s.userRepo.Create(ctx, req.Email, hash, req.DisplayName) if err != nil { if errors.Is(err, user.ErrEmailAlreadyTaken) { return AuthData{}, err @@ -58,10 +59,20 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A return AuthData{}, fmt.Errorf("invalid credentials") } - if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.Password)); err != nil { + ok, err := password.Verify(req.Password, *u.PasswordHash) + if err != nil || !ok { return AuthData{}, fmt.Errorf("invalid credentials") } + // Lazily upgrade bcrypt hashes to Argon2id on successful login. + if password.NeedsRehash(*u.PasswordHash) { + if newHash, hashErr := password.Hash(req.Password); hashErr == nil { + if _, updateErr := s.userRepo.Update(ctx, u.ID, map[string]any{"password_hash": newHash}); updateErr != nil { + slog.Warn("password rehash failed", "user_id", u.ID, "error", updateErr) + } + } + } + // Check 2FA if enabled if req.TOTPCode != "" { if err := s.validateTOTP(ctx, u.ID, req.TOTPCode); err != nil { @@ -165,17 +176,18 @@ func (s *Service) ChangePassword(ctx context.Context, userID uuid.UUID, req Chan if req.CurrentPassword == "" { return fmt.Errorf("current password required") } - if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.CurrentPassword)); err != nil { + ok, err := password.Verify(req.CurrentPassword, *u.PasswordHash) + if err != nil || !ok { return fmt.Errorf("current password incorrect") } } - hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + hash, err := password.Hash(req.NewPassword) if err != nil { return fmt.Errorf("hashing password: %w", err) } - _, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": string(hash)}) + _, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": hash}) if err != nil { return fmt.Errorf("updating password: %w", err) } diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go index a069159..5734e13 100644 --- a/backend/internal/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -1,36 +1,77 @@ package middleware import ( + "log/slog" "net/http" - "strings" + "regexp" "github.com/gin-gonic/gin" ) -func CORS(allowedOrigins string) gin.HandlerFunc { - origins := make(map[string]bool) - for _, o := range strings.Split(allowedOrigins, ",") { - o = strings.TrimSpace(o) - if o != "" { - origins[o] = true +// CORSConfig holds the parsed CORS allowlist. Build via NewCORSConfig to +// enable regex-pattern support; direct struct literals support exact Origins only. +type CORSConfig struct { + Origins []string + compiled []*regexp.Regexp +} + +// NewCORSConfig compiles regex patterns and returns a ready CORSConfig. +// Returns an error if any pattern fails to compile. +func NewCORSConfig(origins []string, patterns []string) (CORSConfig, error) { + cfg := CORSConfig{Origins: origins} + for _, p := range patterns { + re, err := regexp.Compile(p) + if err != nil { + return CORSConfig{}, err + } + cfg.compiled = append(cfg.compiled, re) + } + return cfg, nil +} + +// IsAllowedOrigin returns true if origin matches an exact entry or a compiled pattern. +func (c CORSConfig) IsAllowedOrigin(origin string) bool { + if origin == "" { + return false + } + for _, o := range c.Origins { + if o == origin { + return true } } + for _, re := range c.compiled { + if re.MatchString(origin) { + return true + } + } + return false +} +// CORS returns a middleware that sets CORS headers for allowed origins. +// Vary: Origin is always set so caches key on origin correctly. +func CORS(cfg CORSConfig) gin.HandlerFunc { return func(c *gin.Context) { origin := c.GetHeader("Origin") - if origins[origin] { - c.Header("Access-Control-Allow-Origin", origin) - c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID") - c.Header("Access-Control-Expose-Headers", "X-Request-ID") - c.Header("Access-Control-Allow-Credentials", "true") - c.Header("Access-Control-Max-Age", "86400") - } + // Always set Vary: Origin so caches don't serve the wrong response across origins. + c.Header("Vary", "Origin") - if c.Request.Method == http.MethodOptions { - c.AbortWithStatus(http.StatusNoContent) - return + if cfg.IsAllowedOrigin(origin) { + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Access-Control-Allow-Credentials", "true") + c.Header("Access-Control-Expose-Headers", "X-Request-ID") + + if c.Request.Method == http.MethodOptions { + c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID") + c.Header("Access-Control-Max-Age", "86400") + // Extend Vary for preflight so caches differentiate by request method/headers too. + c.Header("Vary", "Origin, Access-Control-Request-Method, Access-Control-Request-Headers") + c.AbortWithStatus(http.StatusNoContent) + return + } + } else if origin != "" { + slog.Warn("cors origin rejected", "origin", origin, "path", c.Request.URL.Path, "method", c.Request.Method) } c.Next() diff --git a/backend/internal/middleware/cors_test.go b/backend/internal/middleware/cors_test.go new file mode 100644 index 0000000..c145710 --- /dev/null +++ b/backend/internal/middleware/cors_test.go @@ -0,0 +1,155 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "marktvogt.de/backend/internal/middleware" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func corsRouter(cfg middleware.CORSConfig) *gin.Engine { + r := gin.New() + r.Use(middleware.CORS(cfg)) + r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + return r +} + +func TestCORS_AllowedOriginSetsHeaders(t *testing.T) { + r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Origin", "https://marktvogt.de") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de") + } + if got := w.Header().Get("Vary"); got == "" { + t.Error("Vary header missing") + } + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } +} + +func TestCORS_VaryOriginAlwaysPresent(t *testing.T) { + r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}) + + // Even for requests without an Origin header, Vary: Origin must be set + // so caches know the response differs by origin. + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if got := w.Header().Get("Vary"); got == "" { + t.Error("Vary header should be set even when no Origin is present") + } +} + +func TestCORS_DisallowedOriginNoAccessControlHeaders(t *testing.T) { + r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Origin", "https://evil.example") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("Access-Control-Allow-Origin should be empty for disallowed origin, got %q", got) + } + if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "" { + t.Errorf("Access-Control-Allow-Credentials should be empty for disallowed origin, got %q", got) + } + // Request still executes — CORS only controls browser behaviour + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } +} + +func TestCORS_PreflightReturns204WithVary(t *testing.T) { + r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + req.Header.Set("Origin", "https://marktvogt.de") + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Errorf("preflight status = %d, want 204", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de") + } + vary := w.Header().Get("Vary") + if vary == "" { + t.Error("Vary header missing on preflight") + } +} + +func TestCORS_RegexPatternAllowedOrigin(t *testing.T) { + cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`}) + if err != nil { + t.Fatalf("NewCORSConfig error: %v", err) + } + r := corsRouter(cfg) + + for _, origin := range []string{"https://admin.marktvogt.de", "https://staging.marktvogt.de"} { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Origin", origin) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != origin { + t.Errorf("origin %q: Access-Control-Allow-Origin = %q, want %q", origin, got, origin) + } + } +} + +func TestCORS_RegexPatternRejectsNonMatch(t *testing.T) { + cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`}) + if err != nil { + t.Fatalf("NewCORSConfig error: %v", err) + } + r := corsRouter(cfg) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Origin", "https://evil.marktvogt.de.example.com") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("expected no ACAO header for non-matching origin, got %q", got) + } +} + +func TestCORS_InvalidRegexPatternErrors(t *testing.T) { + _, err := middleware.NewCORSConfig(nil, []string{"[invalid-regex"}) + if err == nil { + t.Error("expected error for invalid regex pattern, got nil") + } +} + +func TestCORS_IsAllowedOriginHelper(t *testing.T) { + cfg := middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}} + + if !cfg.IsAllowedOrigin("https://marktvogt.de") { + t.Error("expected true for allowed origin") + } + if cfg.IsAllowedOrigin("https://evil.example") { + t.Error("expected false for disallowed origin") + } + if cfg.IsAllowedOrigin("") { + t.Error("expected false for empty origin") + } +} diff --git a/backend/internal/middleware/csrf.go b/backend/internal/middleware/csrf.go new file mode 100644 index 0000000..e300eac --- /dev/null +++ b/backend/internal/middleware/csrf.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + + "marktvogt.de/backend/internal/pkg/apierror" +) + +// CSRF returns middleware that validates the Origin (or Referer) header for +// state-changing cookie-authed requests. Bearer-token requests (mobile/API) +// are exempt — they're not CSRF-vulnerable. +func CSRF(cfg CORSConfig) gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method + + // Safe methods and preflight are not CSRF-vulnerable. + if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions { + c.Next() + return + } + + // Bearer-token clients (mobile, third-party API) are exempt. + if strings.HasPrefix(c.GetHeader("Authorization"), "Bearer ") { + c.Next() + return + } + + origin := c.GetHeader("Origin") + if origin == "" { + // Fall back to Referer — extract scheme+host as origin. + if ref := c.GetHeader("Referer"); ref != "" { + if u, err := url.Parse(ref); err == nil && u.Host != "" { + origin = u.Scheme + "://" + u.Host + } + } + } + + if !cfg.IsAllowedOrigin(origin) { + apiErr := apierror.CSRFMismatch() + c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr)) + return + } + + c.Next() + } +} diff --git a/backend/internal/middleware/csrf_test.go b/backend/internal/middleware/csrf_test.go new file mode 100644 index 0000000..b061a71 --- /dev/null +++ b/backend/internal/middleware/csrf_test.go @@ -0,0 +1,141 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "marktvogt.de/backend/internal/middleware" +) + +func csrfRouter(cfg middleware.CORSConfig) *gin.Engine { + r := gin.New() + r.Use(middleware.CSRF(cfg)) + r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + r.PUT("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + r.DELETE("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + r.Handle(http.MethodOptions, "/test", func(c *gin.Context) { c.Status(http.StatusNoContent) }) + return r +} + +var csrfCfg = middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}} + +func TestCSRF_AllowedOriginPasses(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("Origin", "https://marktvogt.de") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } +} + +func TestCSRF_DisallowedOriginReturns403(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("Origin", "https://evil.example") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", w.Code) + } +} + +func TestCSRF_NoOriginNoRefererReturns403(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", w.Code) + } +} + +func TestCSRF_AllowedRefererFallback(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("Referer", "https://marktvogt.de/some/page") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } +} + +func TestCSRF_DisallowedRefererReturns403(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("Referer", "https://evil.example/page") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", w.Code) + } +} + +func TestCSRF_GetAlwaysPasses(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // no Origin, no Referer — still passes + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("GET status = %d, want 200", w.Code) + } +} + +func TestCSRF_OptionsAlwaysPasses(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // OPTIONS is handled by CORS; CSRF skips it + if w.Code == http.StatusForbidden { + t.Error("OPTIONS should not be blocked by CSRF") + } +} + +func TestCSRF_BearerTokenExempt(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("Authorization", "Bearer some.token.here") + // no Origin — would normally fail CSRF + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Bearer request status = %d, want 200 (exempt from CSRF)", w.Code) + } +} + +func TestCSRF_DeleteWithAllowedOriginPasses(t *testing.T) { + r := csrfRouter(csrfCfg) + + req := httptest.NewRequest(http.MethodDelete, "/test", nil) + req.Header.Set("Origin", "https://marktvogt.de") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("DELETE status = %d, want 200", w.Code) + } +} diff --git a/backend/internal/middleware/ratelimit.go b/backend/internal/middleware/ratelimit.go index 7e2d50b..eeec153 100644 --- a/backend/internal/middleware/ratelimit.go +++ b/backend/internal/middleware/ratelimit.go @@ -47,12 +47,16 @@ func (l *ipLimiter) get(ip string) *rate.Limiter { return limiter } -func RateLimit(rps float64, burst int) gin.HandlerFunc { +// IPKey is a KeyFn that returns the client IP — the default key for RateLimitByKey. +func IPKey(c *gin.Context) string { return c.ClientIP() } + +// RateLimitByKey creates a per-key rate limiter. keyFn extracts the rate-limit +// key from the request (e.g. IPKey, or a composite IP+email key). +func RateLimitByKey(rps float64, burst int, keyFn func(*gin.Context) string) gin.HandlerFunc { limiter := newIPLimiter(rps, burst) return func(c *gin.Context) { - ip := c.ClientIP() - if !limiter.get(ip).Allow() { + if !limiter.get(keyFn(c)).Allow() { apiErr := apierror.TooManyRequests() c.AbortWithStatusJSON(http.StatusTooManyRequests, apierror.NewResponse(apiErr)) return @@ -60,3 +64,7 @@ func RateLimit(rps float64, burst int) gin.HandlerFunc { c.Next() } } + +func RateLimit(rps float64, burst int) gin.HandlerFunc { + return RateLimitByKey(rps, burst, IPKey) +} diff --git a/backend/internal/middleware/ratelimit_test.go b/backend/internal/middleware/ratelimit_test.go new file mode 100644 index 0000000..68390e1 --- /dev/null +++ b/backend/internal/middleware/ratelimit_test.go @@ -0,0 +1,79 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "marktvogt.de/backend/internal/middleware" +) + +func rateLimitRouter(mw gin.HandlerFunc) *gin.Engine { + r := gin.New() + r.Use(mw) + r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + return r +} + +func TestRateLimitByKey_BlocksAfterBurst(t *testing.T) { + // burst=2, zero refill rate — allows exactly 2 requests then blocks + keyConst := func(*gin.Context) string { return "same-key" } + r := rateLimitRouter(middleware.RateLimitByKey(0, 2, keyConst)) + + for i := 1; i <= 2; i++ { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("request %d: status = %d, want 200 (within burst)", i, w.Code) + } + } + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Errorf("request 3: status = %d, want 429 (burst exhausted)", w.Code) + } +} + +func TestRateLimitByKey_DifferentKeysAreIndependent(t *testing.T) { + // burst=1 — each unique key gets its own limiter with 1 token + counter := 0 + keyFn := func(*gin.Context) string { + counter++ + return string(rune('A' + counter - 1)) // A, B, C, ... + } + r := rateLimitRouter(middleware.RateLimitByKey(0, 1, keyFn)) + + // Each request uses a new key → each should get through + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("request %d: status = %d, want 200 (independent keys)", i+1, w.Code) + } + } +} + +func TestRateLimitByKey_IPKeyFunctionMatchesRateLimit(t *testing.T) { + // IPKey is the same as standard RateLimit key function + r := rateLimitRouter(middleware.RateLimitByKey(0, 1, middleware.IPKey)) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("first request status = %d, want 200", w.Code) + } + + req2 := httptest.NewRequest(http.MethodPost, "/test", nil) + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + if w2.Code != http.StatusTooManyRequests { + t.Errorf("second request status = %d, want 429", w2.Code) + } +} diff --git a/backend/internal/pkg/apierror/error.go b/backend/internal/pkg/apierror/error.go index d80d58e..11a3691 100644 --- a/backend/internal/pkg/apierror/error.go +++ b/backend/internal/pkg/apierror/error.go @@ -87,3 +87,11 @@ func Validation(message string) *Error { func Gone(message string) *Error { return &Error{Status: http.StatusGone, Code: "gone", Message: message} } + +func CSRFMismatch() *Error { + return &Error{Status: http.StatusForbidden, Code: "auth.csrf_mismatch", Message: "CSRF validation failed"} +} + +func RefreshReuse() *Error { + return &Error{Status: http.StatusUnauthorized, Code: "auth.refresh_reuse_detected", Message: "session token reuse detected"} +} diff --git a/backend/internal/pkg/password/password.go b/backend/internal/pkg/password/password.go new file mode 100644 index 0000000..e9d5308 --- /dev/null +++ b/backend/internal/pkg/password/password.go @@ -0,0 +1,94 @@ +package password + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/bcrypt" +) + +// Argon2id parameters (RFC 9106 recommended minimum for interactive logins). +const ( + argonTime = 3 + argonMemory = 64 * 1024 // 64 MiB + argonThreads = 2 + argonKeyLen = 32 + argonSaltLen = 16 +) + +// Hash produces a PHC-format Argon2id hash of the plaintext password. +func Hash(plain string) (string, error) { + salt := make([]byte, argonSaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("generating salt: %w", err) + } + key := argon2.IDKey([]byte(plain), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + encoded := fmt.Sprintf( + "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, + argonMemory, argonTime, argonThreads, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(key), + ) + return encoded, nil +} + +// Verify checks plain against an encoded hash. Supports both Argon2id (PHC +// format) and bcrypt ($2a$/$2b$ prefix). Returns (false, nil) for wrong +// passwords, (false, err) for unrecognised formats. +func Verify(plain, encoded string) (bool, error) { + switch { + case strings.HasPrefix(encoded, "$argon2id$"): + return verifyArgon2id(plain, encoded) + case strings.HasPrefix(encoded, "$2a$"), strings.HasPrefix(encoded, "$2b$"), strings.HasPrefix(encoded, "$2y$"): + err := bcrypt.CompareHashAndPassword([]byte(encoded), []byte(plain)) + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return false, nil + } + return err == nil, err + default: + return false, fmt.Errorf("unrecognised hash format") + } +} + +// NeedsRehash returns true if the hash was produced by an older algorithm +// (currently bcrypt) and should be upgraded to Argon2id. +func NeedsRehash(encoded string) bool { + return strings.HasPrefix(encoded, "$2a$") || + strings.HasPrefix(encoded, "$2b$") || + strings.HasPrefix(encoded, "$2y$") +} + +func verifyArgon2id(plain, encoded string) (bool, error) { + parts := strings.Split(encoded, "$") + // Expected: ["", "argon2id", "v=19", "m=...,t=...,p=...", "", ""] + if len(parts) != 6 { + return false, fmt.Errorf("invalid argon2id hash format") + } + + var mem, t uint32 + var p uint8 + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &t, &p); err != nil { + return false, fmt.Errorf("parsing argon2id params: %w", err) + } + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false, fmt.Errorf("decoding argon2id salt: %w", err) + } + hash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false, fmt.Errorf("decoding argon2id hash: %w", err) + } + + candidate := argon2.IDKey([]byte(plain), salt, t, mem, p, uint32(len(hash))) + if subtle.ConstantTimeCompare(candidate, hash) != 1 { + return false, nil + } + return true, nil +} diff --git a/backend/internal/pkg/password/password_test.go b/backend/internal/pkg/password/password_test.go new file mode 100644 index 0000000..5664c4a --- /dev/null +++ b/backend/internal/pkg/password/password_test.go @@ -0,0 +1,91 @@ +package password_test + +import ( + "strings" + "testing" + + "golang.org/x/crypto/bcrypt" + + "marktvogt.de/backend/internal/pkg/password" +) + +func TestHash_ProducesArgon2idPrefix(t *testing.T) { + h, err := password.Hash("hunter2") + if err != nil { + t.Fatalf("Hash error: %v", err) + } + if !strings.HasPrefix(h, "$argon2id$") { + t.Errorf("hash %q does not start with $argon2id$", h) + } +} + +func TestHash_RoundTrip(t *testing.T) { + h, err := password.Hash("correct horse battery staple") + if err != nil { + t.Fatalf("Hash error: %v", err) + } + ok, err := password.Verify("correct horse battery staple", h) + if err != nil { + t.Fatalf("Verify error: %v", err) + } + if !ok { + t.Error("expected Verify to return true for matching password") + } +} + +func TestVerify_WrongPassword(t *testing.T) { + h, err := password.Hash("correct") + if err != nil { + t.Fatalf("Hash error: %v", err) + } + ok, err := password.Verify("wrong", h) + if err != nil { + t.Fatalf("Verify error: %v", err) + } + if ok { + t.Error("expected Verify to return false for wrong password") + } +} + +func TestHash_UniquePerCall(t *testing.T) { + h1, _ := password.Hash("same") + h2, _ := password.Hash("same") + if h1 == h2 { + t.Error("two hashes of the same password should differ (different salts)") + } +} + +func TestNeedsRehash_Argon2idReturnsFalse(t *testing.T) { + h, _ := password.Hash("test") + if password.NeedsRehash(h) { + t.Error("NeedsRehash should return false for argon2id hashes") + } +} + +func TestNeedsRehash_BcryptReturnsTrue(t *testing.T) { + h, _ := bcrypt.GenerateFromPassword([]byte("x"), bcrypt.MinCost) + if !password.NeedsRehash(string(h)) { + t.Error("NeedsRehash should return true for bcrypt hashes") + } +} + +func TestVerify_BcryptHashVerifiesCorrectly(t *testing.T) { + h, err := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.MinCost) + if err != nil { + t.Fatalf("bcrypt.GenerateFromPassword error: %v", err) + } + ok, verifyErr := password.Verify("test", string(h)) + if verifyErr != nil { + t.Fatalf("Verify error: %v", verifyErr) + } + if !ok { + t.Error("expected Verify to return true for correct bcrypt password") + } +} + +func TestVerify_UnknownFormatErrors(t *testing.T) { + _, err := password.Verify("test", "not-a-valid-hash") + if err == nil { + t.Error("expected error for unknown hash format") + } +} diff --git a/backend/internal/server/routes.go b/backend/internal/server/routes.go index 8985324..fb3b077 100644 --- a/backend/internal/server/routes.go +++ b/backend/internal/server/routes.go @@ -37,7 +37,21 @@ func (s *Server) registerRoutes() { authSvc := auth.NewService(authRepo, userRepo, tokenSvc, s.cfg.JWT.SessionTTL) authHandler := auth.NewHandler(authSvc, userRepo) requireAuth := middleware.RequireAuth(tokenSvc) - auth.RegisterRoutes(v1, authHandler, requireAuth) + + // Per-route auth rate limiters (keyed by IP; user_id unavailable before auth completes) + userIDKey := func(c *gin.Context) string { + if uid, ok := c.Get("user_id"); ok { + return fmt.Sprintf("%v", uid) + } + return c.ClientIP() + } + loginLimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5 + refreshLimit := middleware.RateLimitByKey(1, 10, middleware.IPKey) // 1/s, burst 10 + twoFALimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5 + passwordLimit := middleware.RateLimitByKey(0.1, 3, userIDKey) // 1 per 10s, burst 3 + magicLinkLimit := middleware.RateLimitByKey(0.1, 3, middleware.IPKey) // 1 per 10s, burst 3 + + auth.RegisterRoutes(v1, authHandler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit) // OAuth routes — disabled until provider apps are configured // oauthHandler := auth.NewOAuthHandler(s.cfg.OAuth, authSvc, userRepo, authRepo) @@ -51,7 +65,7 @@ func (s *Server) registerRoutes() { // Magic link routes magicLinkHandler := auth.NewMagicLinkHandler(authRepo, userRepo, authSvc, s.cfg.Magic, emailSender, s.cfg.Notification.FrontendURL) - auth.RegisterMagicLinkRoutes(v1, magicLinkHandler) + auth.RegisterMagicLinkRoutes(v1, magicLinkHandler, magicLinkLimit) // User profile routes userSvc := user.NewService(userRepo) diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index 6b76c92..e149f35 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -29,11 +29,15 @@ func New(cfg *config.Config, db *pgxpool.Pool, vk valkey.Client) *Server { router := gin.New() + // NewCORSConfig only errors on bad regexes; config.Load already validates them. + corsCfg, _ := middleware.NewCORSConfig(cfg.CORS.Origins, cfg.CORS.OriginPatterns) + router.Use( middleware.Recovery(), middleware.RequestID(), middleware.Logging(), - middleware.CORS(cfg.CORS.Origins), + middleware.CORS(corsCfg), + middleware.CSRF(corsCfg), middleware.RateLimit(cfg.Rate.RPS, cfg.Rate.Burst), ) diff --git a/backend/migrations/000027_sessions_session_model_rework.down.sql b/backend/migrations/000027_sessions_session_model_rework.down.sql new file mode 100644 index 0000000..bd1b0a6 --- /dev/null +++ b/backend/migrations/000027_sessions_session_model_rework.down.sql @@ -0,0 +1,12 @@ +DROP INDEX IF EXISTS idx_sessions_refresh_hash; +DROP INDEX IF EXISTS idx_sessions_access_hash; +DROP INDEX IF EXISTS idx_sessions_family_id; + +ALTER TABLE sessions + DROP COLUMN IF EXISTS revoked_at, + DROP COLUMN IF EXISTS last_used_at, + DROP COLUMN IF EXISTS absolute_expires_at, + DROP COLUMN IF EXISTS parent_session_id, + DROP COLUMN IF EXISTS family_id, + DROP COLUMN IF EXISTS refresh_token_hash, + DROP COLUMN IF EXISTS access_token_hash; diff --git a/backend/migrations/000027_sessions_session_model_rework.up.sql b/backend/migrations/000027_sessions_session_model_rework.up.sql new file mode 100644 index 0000000..37c36cb --- /dev/null +++ b/backend/migrations/000027_sessions_session_model_rework.up.sql @@ -0,0 +1,16 @@ +-- Workstream 1: extend sessions table for opaque-token session model. +-- Adds new columns needed for the upcoming JWT drop (D2/D3 deploy). +-- Old token_hash column is kept until cleanup migration 000029. + +ALTER TABLE sessions + ADD COLUMN IF NOT EXISTS access_token_hash TEXT UNIQUE, + ADD COLUMN IF NOT EXISTS refresh_token_hash TEXT UNIQUE, + ADD COLUMN IF NOT EXISTS family_id UUID NOT NULL DEFAULT uuid_generate_v4(), + ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES sessions(id) ON DELETE SET NULL, + ADD COLUMN IF NOT EXISTS absolute_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 days', + ADD COLUMN IF NOT EXISTS last_used_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ; + +CREATE INDEX IF NOT EXISTS idx_sessions_family_id ON sessions (family_id); +CREATE INDEX IF NOT EXISTS idx_sessions_access_hash ON sessions (access_token_hash) WHERE access_token_hash IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_sessions_refresh_hash ON sessions (refresh_token_hash) WHERE refresh_token_hash IS NOT NULL; diff --git a/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.down.sql b/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.down.sql new file mode 100644 index 0000000..2dd0832 --- /dev/null +++ b/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.down.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS totp_backup_codes; + +ALTER TABLE totp_secrets + DROP COLUMN IF EXISTS secret_v2; diff --git a/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.up.sql b/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.up.sql new file mode 100644 index 0000000..99d42e7 --- /dev/null +++ b/backend/migrations/000028_totp_encrypt_secrets_and_backup_codes.up.sql @@ -0,0 +1,21 @@ +-- Workstream 7: TOTP encryption + backup codes schema. +-- secret_v2 stores AES-256-GCM encrypted secret (format: v1:). +-- The migration job (cmd/totp-encrypt) populates secret_v2 for existing rows. +-- Column secret is dropped in migration 000030 once all rows have secret_v2. + +ALTER TABLE totp_secrets + ADD COLUMN IF NOT EXISTS secret_v2 TEXT; + +CREATE TABLE IF NOT EXISTS totp_backup_codes ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + code_hash TEXT NOT NULL, + used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_totp_backup_codes_user_id + ON totp_backup_codes (user_id); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_totp_backup_codes_user_code + ON totp_backup_codes (user_id, code_hash);