feat(auth): D1 non-breaking security foundations

- CORS: rewrite middleware with Vary: Origin, regex origin patterns,
  startup validation, and prod boot-fail on empty allowlist; shared
  CORSConfig exported for CSRF reuse
- CSRF: new Origin/Referer check middleware sharing CORS allowlist;
  Bearer-token clients exempt; mounts globally after CORS
- Argon2id: new password package with PHC format, bcrypt dispatch, and
  NeedsRehash; lazy upgrade on login in auth service
- Rate limiting: add RateLimitByKey with custom key function; apply
  per-route limits to /auth/login, /refresh, /2fa/verify,
  /auth/magic-link, and /auth/password
- apierror: add CSRFMismatch and RefreshReuse error constructors
- Migrations: 000027 (session model schema columns for D2/D3),
  000028 (TOTP secret_v2 column + totp_backup_codes table)
- cmd/totp-encrypt: one-shot job to encrypt existing TOTP secrets
This commit is contained in:
2026-04-26 11:54:37 +02:00
parent 49a31bca02
commit 0997d4befa
20 changed files with 957 additions and 42 deletions

View File

@@ -0,0 +1,127 @@
// totp-encrypt is a one-shot migration job that encrypts plaintext TOTP secrets
// in the totp_secrets table. Run once after deploying migration 000028.
// Set TOTP_ENCRYPTION_KEY (32-byte AES key, base64-encoded) before running.
//
// Usage: TOTP_ENCRYPTION_KEY=<b64> DB_HOST=... ./totp-encrypt [--dry-run]
package main
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log/slog"
"os"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"marktvogt.de/backend/internal/pkg/crypto"
)
func main() {
if err := run(); err != nil {
slog.Error("totp-encrypt failed", "error", err)
os.Exit(1)
}
}
func run() error {
dryRun := len(os.Args) > 1 && os.Args[1] == "--dry-run"
keyB64 := os.Getenv("TOTP_ENCRYPTION_KEY")
if keyB64 == "" {
return fmt.Errorf("TOTP_ENCRYPTION_KEY is required")
}
keyBytes, err := base64.StdEncoding.DecodeString(keyB64)
if err != nil {
return fmt.Errorf("TOTP_ENCRYPTION_KEY: invalid base64: %w", err)
}
if len(keyBytes) != 32 {
return fmt.Errorf("TOTP_ENCRYPTION_KEY must be exactly 32 bytes (got %d)", len(keyBytes))
}
var key [32]byte
copy(key[:], keyBytes)
dsn := fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s?sslmode=%s",
getenv("DB_USER", "marktvogt"),
getenv("DB_PASSWORD", "marktvogt"),
getenv("DB_HOST", "localhost"),
getenv("DB_PORT", "5432"),
getenv("DB_NAME", "marktvogt"),
getenv("DB_SSLMODE", "disable"),
)
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
return fmt.Errorf("connecting to database: %w", err)
}
defer pool.Close()
rows, err := pool.Query(ctx,
`SELECT id, secret FROM totp_secrets WHERE secret_v2 IS NULL AND secret != ''`)
if err != nil {
return fmt.Errorf("querying totp_secrets: %w", err)
}
type row struct {
id string
secret string
}
var pending []row
for rows.Next() {
var r row
if scanErr := rows.Scan(&r.id, &r.secret); scanErr != nil {
return fmt.Errorf("scanning row: %w", scanErr)
}
pending = append(pending, r)
}
rows.Close()
slog.Info("rows to encrypt", "count", len(pending), "dry_run", dryRun)
if dryRun {
slog.Info("dry run — no changes written")
return nil
}
ok, failed := 0, 0
for _, r := range pending {
ciphertext, sealErr := crypto.Seal(key, []byte(r.secret))
if sealErr != nil {
slog.Error("encrypting secret", "id", r.id, "error", sealErr)
failed++
continue
}
encoded := "v1:" + base64.StdEncoding.EncodeToString(ciphertext)
_, execErr := pool.Exec(ctx,
`UPDATE totp_secrets SET secret_v2 = $1 WHERE id = $2`,
encoded, r.id)
if execErr != nil {
if errors.Is(execErr, pgx.ErrNoRows) {
slog.Warn("row disappeared during migration", "id", r.id)
} else {
slog.Error("updating secret_v2", "id", r.id, "error", execErr)
failed++
}
continue
}
ok++
}
slog.Info("encryption complete", "encrypted", ok, "failed", failed)
if failed > 0 {
return fmt.Errorf("%d rows failed to encrypt", failed)
}
return nil
}
func getenv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}

View File

@@ -4,7 +4,9 @@ import (
"fmt"
"log/slog"
"os"
"regexp"
"strconv"
"strings"
"time"
)
@@ -85,7 +87,8 @@ type JWTConfig struct {
}
type CORSConfig struct {
Origins string
Origins []string
OriginPatterns []string
}
type RateConfig struct {
@@ -132,7 +135,26 @@ type NotificationConfig struct {
FrontendURL string
}
const envDevelopment = "development"
func Load() (*Config, error) {
appEnv := envStr("APP_ENV", envDevelopment)
corsOrigins := envStrSlice("CORS_ORIGINS")
if len(corsOrigins) == 0 && appEnv == envDevelopment {
corsOrigins = []string{"http://localhost:5173"}
}
corsPatterns := envStrSlice("CORS_ORIGIN_PATTERNS")
if len(corsOrigins) == 0 && len(corsPatterns) == 0 && appEnv != envDevelopment {
return nil, fmt.Errorf("CORS_ORIGINS or CORS_ORIGIN_PATTERNS is required in non-dev environments")
}
for _, p := range corsPatterns {
if _, err := regexp.Compile(p); err != nil {
return nil, fmt.Errorf("CORS_ORIGIN_PATTERNS: invalid regex %q: %w", p, err)
}
}
slog.Info("cors allowlist loaded", "exact_origins", corsOrigins, "pattern_count", len(corsPatterns))
port, err := envInt("APP_PORT", 8080)
if err != nil {
return nil, fmt.Errorf("APP_PORT: %w", err)
@@ -205,7 +227,7 @@ func Load() (*Config, error) {
return &Config{
App: AppConfig{
Env: envStr("APP_ENV", "development"),
Env: appEnv,
Host: envStr("APP_HOST", "0.0.0.0"),
Port: port,
},
@@ -230,7 +252,8 @@ func Load() (*Config, error) {
SessionTTL: sessionTTL,
},
CORS: CORSConfig{
Origins: envStr("CORS_ORIGINS", "http://localhost:5173"),
Origins: corsOrigins,
OriginPatterns: corsPatterns,
},
Rate: RateConfig{
RPS: rps,
@@ -297,7 +320,7 @@ func (c *Config) Addr() string {
}
func (c *Config) IsDev() bool {
return c.App.Env == "development"
return c.App.Env == envDevelopment
}
func envStr(key, fallback string) string {
@@ -331,6 +354,21 @@ func envFloat(key string, fallback float64) (float64, error) {
return f, nil
}
func envStrSlice(key string) []string {
v := os.Getenv(key)
if v == "" {
return nil
}
var result []string
for _, s := range strings.Split(v, ",") {
s = strings.TrimSpace(s)
if s != "" {
result = append(result, s)
}
}
return result
}
func envDuration(key string, fallback time.Duration) (time.Duration, error) {
v := os.Getenv(key)
if v == "" {

View File

@@ -168,7 +168,7 @@ func (h *MagicLinkHandler) findOrCreateUser(ctx context.Context, email string) (
return h.userRepo.CreateOAuthUser(ctx, email, user.GenerateDisplayName(), true)
}
func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler) {
rg.POST("/auth/magic-link", h.RequestMagicLink)
func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler, requestLimit gin.HandlerFunc) {
rg.POST("/auth/magic-link", requestLimit, h.RequestMagicLink)
rg.GET("/auth/magic-link/verify", h.VerifyMagicLink)
}

View File

@@ -2,20 +2,20 @@ package auth
import "github.com/gin-gonic/gin"
func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth gin.HandlerFunc) {
func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit gin.HandlerFunc) {
auth := rg.Group("/auth")
{
auth.POST("/register", h.Register)
auth.POST("/login", h.Login)
auth.POST("/login", loginLimit, h.Login)
auth.POST("/logout", requireAuth, h.Logout)
auth.POST("/refresh", h.Refresh)
auth.POST("/refresh", refreshLimit, h.Refresh)
// Password
auth.PUT("/password", requireAuth, h.ChangePassword)
auth.PUT("/password", requireAuth, passwordLimit, h.ChangePassword)
// 2FA
auth.POST("/2fa/setup", requireAuth, h.SetupTOTP)
auth.POST("/2fa/verify", requireAuth, h.VerifyTOTP)
auth.POST("/2fa/verify", requireAuth, twoFALimit, h.VerifyTOTP)
auth.DELETE("/2fa", requireAuth, h.DisableTOTP)
}
}

View File

@@ -4,12 +4,13 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"marktvogt.de/backend/internal/domain/user"
"marktvogt.de/backend/internal/pkg/password"
)
type Service struct {
@@ -29,12 +30,12 @@ func NewService(authRepo Repository, userRepo user.Repository, tokenSvc *TokenSe
}
func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua string) (AuthData, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
hash, err := password.Hash(req.Password)
if err != nil {
return AuthData{}, fmt.Errorf("hashing password: %w", err)
}
u, err := s.userRepo.Create(ctx, req.Email, string(hash), req.DisplayName)
u, err := s.userRepo.Create(ctx, req.Email, hash, req.DisplayName)
if err != nil {
if errors.Is(err, user.ErrEmailAlreadyTaken) {
return AuthData{}, err
@@ -58,10 +59,20 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A
return AuthData{}, fmt.Errorf("invalid credentials")
}
if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.Password)); err != nil {
ok, err := password.Verify(req.Password, *u.PasswordHash)
if err != nil || !ok {
return AuthData{}, fmt.Errorf("invalid credentials")
}
// Lazily upgrade bcrypt hashes to Argon2id on successful login.
if password.NeedsRehash(*u.PasswordHash) {
if newHash, hashErr := password.Hash(req.Password); hashErr == nil {
if _, updateErr := s.userRepo.Update(ctx, u.ID, map[string]any{"password_hash": newHash}); updateErr != nil {
slog.Warn("password rehash failed", "user_id", u.ID, "error", updateErr)
}
}
}
// Check 2FA if enabled
if req.TOTPCode != "" {
if err := s.validateTOTP(ctx, u.ID, req.TOTPCode); err != nil {
@@ -165,17 +176,18 @@ func (s *Service) ChangePassword(ctx context.Context, userID uuid.UUID, req Chan
if req.CurrentPassword == "" {
return fmt.Errorf("current password required")
}
if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.CurrentPassword)); err != nil {
ok, err := password.Verify(req.CurrentPassword, *u.PasswordHash)
if err != nil || !ok {
return fmt.Errorf("current password incorrect")
}
}
hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
hash, err := password.Hash(req.NewPassword)
if err != nil {
return fmt.Errorf("hashing password: %w", err)
}
_, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": string(hash)})
_, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": hash})
if err != nil {
return fmt.Errorf("updating password: %w", err)
}

View File

@@ -1,36 +1,77 @@
package middleware
import (
"log/slog"
"net/http"
"strings"
"regexp"
"github.com/gin-gonic/gin"
)
func CORS(allowedOrigins string) gin.HandlerFunc {
origins := make(map[string]bool)
for _, o := range strings.Split(allowedOrigins, ",") {
o = strings.TrimSpace(o)
if o != "" {
origins[o] = true
// CORSConfig holds the parsed CORS allowlist. Build via NewCORSConfig to
// enable regex-pattern support; direct struct literals support exact Origins only.
type CORSConfig struct {
Origins []string
compiled []*regexp.Regexp
}
// NewCORSConfig compiles regex patterns and returns a ready CORSConfig.
// Returns an error if any pattern fails to compile.
func NewCORSConfig(origins []string, patterns []string) (CORSConfig, error) {
cfg := CORSConfig{Origins: origins}
for _, p := range patterns {
re, err := regexp.Compile(p)
if err != nil {
return CORSConfig{}, err
}
cfg.compiled = append(cfg.compiled, re)
}
return cfg, nil
}
// IsAllowedOrigin returns true if origin matches an exact entry or a compiled pattern.
func (c CORSConfig) IsAllowedOrigin(origin string) bool {
if origin == "" {
return false
}
for _, o := range c.Origins {
if o == origin {
return true
}
}
for _, re := range c.compiled {
if re.MatchString(origin) {
return true
}
}
return false
}
// CORS returns a middleware that sets CORS headers for allowed origins.
// Vary: Origin is always set so caches key on origin correctly.
func CORS(cfg CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.GetHeader("Origin")
if origins[origin] {
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID")
c.Header("Access-Control-Expose-Headers", "X-Request-ID")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
}
// Always set Vary: Origin so caches don't serve the wrong response across origins.
c.Header("Vary", "Origin")
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
if cfg.IsAllowedOrigin(origin) {
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Expose-Headers", "X-Request-ID")
if c.Request.Method == http.MethodOptions {
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID")
c.Header("Access-Control-Max-Age", "86400")
// Extend Vary for preflight so caches differentiate by request method/headers too.
c.Header("Vary", "Origin, Access-Control-Request-Method, Access-Control-Request-Headers")
c.AbortWithStatus(http.StatusNoContent)
return
}
} else if origin != "" {
slog.Warn("cors origin rejected", "origin", origin, "path", c.Request.URL.Path, "method", c.Request.Method)
}
c.Next()

View File

@@ -0,0 +1,155 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"marktvogt.de/backend/internal/middleware"
)
func init() {
gin.SetMode(gin.TestMode)
}
func corsRouter(cfg middleware.CORSConfig) *gin.Engine {
r := gin.New()
r.Use(middleware.CORS(cfg))
r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
return r
}
func TestCORS_AllowedOriginSetsHeaders(t *testing.T) {
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Origin", "https://marktvogt.de")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" {
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de")
}
if got := w.Header().Get("Vary"); got == "" {
t.Error("Vary header missing")
}
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
}
func TestCORS_VaryOriginAlwaysPresent(t *testing.T) {
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
// Even for requests without an Origin header, Vary: Origin must be set
// so caches know the response differs by origin.
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Vary"); got == "" {
t.Error("Vary header should be set even when no Origin is present")
}
}
func TestCORS_DisallowedOriginNoAccessControlHeaders(t *testing.T) {
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Origin", "https://evil.example")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
t.Errorf("Access-Control-Allow-Origin should be empty for disallowed origin, got %q", got)
}
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "" {
t.Errorf("Access-Control-Allow-Credentials should be empty for disallowed origin, got %q", got)
}
// Request still executes — CORS only controls browser behaviour
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
}
func TestCORS_PreflightReturns204WithVary(t *testing.T) {
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
req.Header.Set("Origin", "https://marktvogt.de")
req.Header.Set("Access-Control-Request-Method", "POST")
req.Header.Set("Access-Control-Request-Headers", "Content-Type")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Errorf("preflight status = %d, want 204", w.Code)
}
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" {
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de")
}
vary := w.Header().Get("Vary")
if vary == "" {
t.Error("Vary header missing on preflight")
}
}
func TestCORS_RegexPatternAllowedOrigin(t *testing.T) {
cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`})
if err != nil {
t.Fatalf("NewCORSConfig error: %v", err)
}
r := corsRouter(cfg)
for _, origin := range []string{"https://admin.marktvogt.de", "https://staging.marktvogt.de"} {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Origin", origin)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != origin {
t.Errorf("origin %q: Access-Control-Allow-Origin = %q, want %q", origin, got, origin)
}
}
}
func TestCORS_RegexPatternRejectsNonMatch(t *testing.T) {
cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`})
if err != nil {
t.Fatalf("NewCORSConfig error: %v", err)
}
r := corsRouter(cfg)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Origin", "https://evil.marktvogt.de.example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
t.Errorf("expected no ACAO header for non-matching origin, got %q", got)
}
}
func TestCORS_InvalidRegexPatternErrors(t *testing.T) {
_, err := middleware.NewCORSConfig(nil, []string{"[invalid-regex"})
if err == nil {
t.Error("expected error for invalid regex pattern, got nil")
}
}
func TestCORS_IsAllowedOriginHelper(t *testing.T) {
cfg := middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}
if !cfg.IsAllowedOrigin("https://marktvogt.de") {
t.Error("expected true for allowed origin")
}
if cfg.IsAllowedOrigin("https://evil.example") {
t.Error("expected false for disallowed origin")
}
if cfg.IsAllowedOrigin("") {
t.Error("expected false for empty origin")
}
}

View File

@@ -0,0 +1,50 @@
package middleware
import (
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"marktvogt.de/backend/internal/pkg/apierror"
)
// CSRF returns middleware that validates the Origin (or Referer) header for
// state-changing cookie-authed requests. Bearer-token requests (mobile/API)
// are exempt — they're not CSRF-vulnerable.
func CSRF(cfg CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
// Safe methods and preflight are not CSRF-vulnerable.
if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions {
c.Next()
return
}
// Bearer-token clients (mobile, third-party API) are exempt.
if strings.HasPrefix(c.GetHeader("Authorization"), "Bearer ") {
c.Next()
return
}
origin := c.GetHeader("Origin")
if origin == "" {
// Fall back to Referer — extract scheme+host as origin.
if ref := c.GetHeader("Referer"); ref != "" {
if u, err := url.Parse(ref); err == nil && u.Host != "" {
origin = u.Scheme + "://" + u.Host
}
}
}
if !cfg.IsAllowedOrigin(origin) {
apiErr := apierror.CSRFMismatch()
c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
c.Next()
}
}

View File

@@ -0,0 +1,141 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"marktvogt.de/backend/internal/middleware"
)
func csrfRouter(cfg middleware.CORSConfig) *gin.Engine {
r := gin.New()
r.Use(middleware.CSRF(cfg))
r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
r.PUT("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
r.DELETE("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
r.Handle(http.MethodOptions, "/test", func(c *gin.Context) { c.Status(http.StatusNoContent) })
return r
}
var csrfCfg = middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}
func TestCSRF_AllowedOriginPasses(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
req.Header.Set("Origin", "https://marktvogt.de")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
}
func TestCSRF_DisallowedOriginReturns403(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
req.Header.Set("Origin", "https://evil.example")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", w.Code)
}
}
func TestCSRF_NoOriginNoRefererReturns403(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", w.Code)
}
}
func TestCSRF_AllowedRefererFallback(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
req.Header.Set("Referer", "https://marktvogt.de/some/page")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
}
func TestCSRF_DisallowedRefererReturns403(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
req.Header.Set("Referer", "https://evil.example/page")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", w.Code)
}
}
func TestCSRF_GetAlwaysPasses(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
// no Origin, no Referer — still passes
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("GET status = %d, want 200", w.Code)
}
}
func TestCSRF_OptionsAlwaysPasses(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// OPTIONS is handled by CORS; CSRF skips it
if w.Code == http.StatusForbidden {
t.Error("OPTIONS should not be blocked by CSRF")
}
}
func TestCSRF_BearerTokenExempt(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodPost, "/test", nil)
req.Header.Set("Authorization", "Bearer some.token.here")
// no Origin — would normally fail CSRF
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Bearer request status = %d, want 200 (exempt from CSRF)", w.Code)
}
}
func TestCSRF_DeleteWithAllowedOriginPasses(t *testing.T) {
r := csrfRouter(csrfCfg)
req := httptest.NewRequest(http.MethodDelete, "/test", nil)
req.Header.Set("Origin", "https://marktvogt.de")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("DELETE status = %d, want 200", w.Code)
}
}

View File

@@ -47,12 +47,16 @@ func (l *ipLimiter) get(ip string) *rate.Limiter {
return limiter
}
func RateLimit(rps float64, burst int) gin.HandlerFunc {
// IPKey is a KeyFn that returns the client IP — the default key for RateLimitByKey.
func IPKey(c *gin.Context) string { return c.ClientIP() }
// RateLimitByKey creates a per-key rate limiter. keyFn extracts the rate-limit
// key from the request (e.g. IPKey, or a composite IP+email key).
func RateLimitByKey(rps float64, burst int, keyFn func(*gin.Context) string) gin.HandlerFunc {
limiter := newIPLimiter(rps, burst)
return func(c *gin.Context) {
ip := c.ClientIP()
if !limiter.get(ip).Allow() {
if !limiter.get(keyFn(c)).Allow() {
apiErr := apierror.TooManyRequests()
c.AbortWithStatusJSON(http.StatusTooManyRequests, apierror.NewResponse(apiErr))
return
@@ -60,3 +64,7 @@ func RateLimit(rps float64, burst int) gin.HandlerFunc {
c.Next()
}
}
func RateLimit(rps float64, burst int) gin.HandlerFunc {
return RateLimitByKey(rps, burst, IPKey)
}

View File

@@ -0,0 +1,79 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"marktvogt.de/backend/internal/middleware"
)
func rateLimitRouter(mw gin.HandlerFunc) *gin.Engine {
r := gin.New()
r.Use(mw)
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
return r
}
func TestRateLimitByKey_BlocksAfterBurst(t *testing.T) {
// burst=2, zero refill rate — allows exactly 2 requests then blocks
keyConst := func(*gin.Context) string { return "same-key" }
r := rateLimitRouter(middleware.RateLimitByKey(0, 2, keyConst))
for i := 1; i <= 2; i++ {
req := httptest.NewRequest(http.MethodPost, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("request %d: status = %d, want 200 (within burst)", i, w.Code)
}
}
req := httptest.NewRequest(http.MethodPost, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("request 3: status = %d, want 429 (burst exhausted)", w.Code)
}
}
func TestRateLimitByKey_DifferentKeysAreIndependent(t *testing.T) {
// burst=1 — each unique key gets its own limiter with 1 token
counter := 0
keyFn := func(*gin.Context) string {
counter++
return string(rune('A' + counter - 1)) // A, B, C, ...
}
r := rateLimitRouter(middleware.RateLimitByKey(0, 1, keyFn))
// Each request uses a new key → each should get through
for i := 0; i < 3; i++ {
req := httptest.NewRequest(http.MethodPost, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("request %d: status = %d, want 200 (independent keys)", i+1, w.Code)
}
}
}
func TestRateLimitByKey_IPKeyFunctionMatchesRateLimit(t *testing.T) {
// IPKey is the same as standard RateLimit key function
r := rateLimitRouter(middleware.RateLimitByKey(0, 1, middleware.IPKey))
req := httptest.NewRequest(http.MethodPost, "/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("first request status = %d, want 200", w.Code)
}
req2 := httptest.NewRequest(http.MethodPost, "/test", nil)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusTooManyRequests {
t.Errorf("second request status = %d, want 429", w2.Code)
}
}

View File

@@ -87,3 +87,11 @@ func Validation(message string) *Error {
func Gone(message string) *Error {
return &Error{Status: http.StatusGone, Code: "gone", Message: message}
}
func CSRFMismatch() *Error {
return &Error{Status: http.StatusForbidden, Code: "auth.csrf_mismatch", Message: "CSRF validation failed"}
}
func RefreshReuse() *Error {
return &Error{Status: http.StatusUnauthorized, Code: "auth.refresh_reuse_detected", Message: "session token reuse detected"}
}

View File

@@ -0,0 +1,94 @@
package password
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
)
// Argon2id parameters (RFC 9106 recommended minimum for interactive logins).
const (
argonTime = 3
argonMemory = 64 * 1024 // 64 MiB
argonThreads = 2
argonKeyLen = 32
argonSaltLen = 16
)
// Hash produces a PHC-format Argon2id hash of the plaintext password.
func Hash(plain string) (string, error) {
salt := make([]byte, argonSaltLen)
if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("generating salt: %w", err)
}
key := argon2.IDKey([]byte(plain), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
encoded := fmt.Sprintf(
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version,
argonMemory, argonTime, argonThreads,
base64.RawStdEncoding.EncodeToString(salt),
base64.RawStdEncoding.EncodeToString(key),
)
return encoded, nil
}
// Verify checks plain against an encoded hash. Supports both Argon2id (PHC
// format) and bcrypt ($2a$/$2b$ prefix). Returns (false, nil) for wrong
// passwords, (false, err) for unrecognised formats.
func Verify(plain, encoded string) (bool, error) {
switch {
case strings.HasPrefix(encoded, "$argon2id$"):
return verifyArgon2id(plain, encoded)
case strings.HasPrefix(encoded, "$2a$"), strings.HasPrefix(encoded, "$2b$"), strings.HasPrefix(encoded, "$2y$"):
err := bcrypt.CompareHashAndPassword([]byte(encoded), []byte(plain))
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return false, nil
}
return err == nil, err
default:
return false, fmt.Errorf("unrecognised hash format")
}
}
// NeedsRehash returns true if the hash was produced by an older algorithm
// (currently bcrypt) and should be upgraded to Argon2id.
func NeedsRehash(encoded string) bool {
return strings.HasPrefix(encoded, "$2a$") ||
strings.HasPrefix(encoded, "$2b$") ||
strings.HasPrefix(encoded, "$2y$")
}
func verifyArgon2id(plain, encoded string) (bool, error) {
parts := strings.Split(encoded, "$")
// Expected: ["", "argon2id", "v=19", "m=...,t=...,p=...", "<salt>", "<hash>"]
if len(parts) != 6 {
return false, fmt.Errorf("invalid argon2id hash format")
}
var mem, t uint32
var p uint8
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &t, &p); err != nil {
return false, fmt.Errorf("parsing argon2id params: %w", err)
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, fmt.Errorf("decoding argon2id salt: %w", err)
}
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, fmt.Errorf("decoding argon2id hash: %w", err)
}
candidate := argon2.IDKey([]byte(plain), salt, t, mem, p, uint32(len(hash)))
if subtle.ConstantTimeCompare(candidate, hash) != 1 {
return false, nil
}
return true, nil
}

View File

@@ -0,0 +1,91 @@
package password_test
import (
"strings"
"testing"
"golang.org/x/crypto/bcrypt"
"marktvogt.de/backend/internal/pkg/password"
)
func TestHash_ProducesArgon2idPrefix(t *testing.T) {
h, err := password.Hash("hunter2")
if err != nil {
t.Fatalf("Hash error: %v", err)
}
if !strings.HasPrefix(h, "$argon2id$") {
t.Errorf("hash %q does not start with $argon2id$", h)
}
}
func TestHash_RoundTrip(t *testing.T) {
h, err := password.Hash("correct horse battery staple")
if err != nil {
t.Fatalf("Hash error: %v", err)
}
ok, err := password.Verify("correct horse battery staple", h)
if err != nil {
t.Fatalf("Verify error: %v", err)
}
if !ok {
t.Error("expected Verify to return true for matching password")
}
}
func TestVerify_WrongPassword(t *testing.T) {
h, err := password.Hash("correct")
if err != nil {
t.Fatalf("Hash error: %v", err)
}
ok, err := password.Verify("wrong", h)
if err != nil {
t.Fatalf("Verify error: %v", err)
}
if ok {
t.Error("expected Verify to return false for wrong password")
}
}
func TestHash_UniquePerCall(t *testing.T) {
h1, _ := password.Hash("same")
h2, _ := password.Hash("same")
if h1 == h2 {
t.Error("two hashes of the same password should differ (different salts)")
}
}
func TestNeedsRehash_Argon2idReturnsFalse(t *testing.T) {
h, _ := password.Hash("test")
if password.NeedsRehash(h) {
t.Error("NeedsRehash should return false for argon2id hashes")
}
}
func TestNeedsRehash_BcryptReturnsTrue(t *testing.T) {
h, _ := bcrypt.GenerateFromPassword([]byte("x"), bcrypt.MinCost)
if !password.NeedsRehash(string(h)) {
t.Error("NeedsRehash should return true for bcrypt hashes")
}
}
func TestVerify_BcryptHashVerifiesCorrectly(t *testing.T) {
h, err := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.MinCost)
if err != nil {
t.Fatalf("bcrypt.GenerateFromPassword error: %v", err)
}
ok, verifyErr := password.Verify("test", string(h))
if verifyErr != nil {
t.Fatalf("Verify error: %v", verifyErr)
}
if !ok {
t.Error("expected Verify to return true for correct bcrypt password")
}
}
func TestVerify_UnknownFormatErrors(t *testing.T) {
_, err := password.Verify("test", "not-a-valid-hash")
if err == nil {
t.Error("expected error for unknown hash format")
}
}

View File

@@ -37,7 +37,21 @@ func (s *Server) registerRoutes() {
authSvc := auth.NewService(authRepo, userRepo, tokenSvc, s.cfg.JWT.SessionTTL)
authHandler := auth.NewHandler(authSvc, userRepo)
requireAuth := middleware.RequireAuth(tokenSvc)
auth.RegisterRoutes(v1, authHandler, requireAuth)
// Per-route auth rate limiters (keyed by IP; user_id unavailable before auth completes)
userIDKey := func(c *gin.Context) string {
if uid, ok := c.Get("user_id"); ok {
return fmt.Sprintf("%v", uid)
}
return c.ClientIP()
}
loginLimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5
refreshLimit := middleware.RateLimitByKey(1, 10, middleware.IPKey) // 1/s, burst 10
twoFALimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5
passwordLimit := middleware.RateLimitByKey(0.1, 3, userIDKey) // 1 per 10s, burst 3
magicLinkLimit := middleware.RateLimitByKey(0.1, 3, middleware.IPKey) // 1 per 10s, burst 3
auth.RegisterRoutes(v1, authHandler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit)
// OAuth routes — disabled until provider apps are configured
// oauthHandler := auth.NewOAuthHandler(s.cfg.OAuth, authSvc, userRepo, authRepo)
@@ -51,7 +65,7 @@ func (s *Server) registerRoutes() {
// Magic link routes
magicLinkHandler := auth.NewMagicLinkHandler(authRepo, userRepo, authSvc, s.cfg.Magic, emailSender, s.cfg.Notification.FrontendURL)
auth.RegisterMagicLinkRoutes(v1, magicLinkHandler)
auth.RegisterMagicLinkRoutes(v1, magicLinkHandler, magicLinkLimit)
// User profile routes
userSvc := user.NewService(userRepo)

View File

@@ -29,11 +29,15 @@ func New(cfg *config.Config, db *pgxpool.Pool, vk valkey.Client) *Server {
router := gin.New()
// NewCORSConfig only errors on bad regexes; config.Load already validates them.
corsCfg, _ := middleware.NewCORSConfig(cfg.CORS.Origins, cfg.CORS.OriginPatterns)
router.Use(
middleware.Recovery(),
middleware.RequestID(),
middleware.Logging(),
middleware.CORS(cfg.CORS.Origins),
middleware.CORS(corsCfg),
middleware.CSRF(corsCfg),
middleware.RateLimit(cfg.Rate.RPS, cfg.Rate.Burst),
)

View File

@@ -0,0 +1,12 @@
DROP INDEX IF EXISTS idx_sessions_refresh_hash;
DROP INDEX IF EXISTS idx_sessions_access_hash;
DROP INDEX IF EXISTS idx_sessions_family_id;
ALTER TABLE sessions
DROP COLUMN IF EXISTS revoked_at,
DROP COLUMN IF EXISTS last_used_at,
DROP COLUMN IF EXISTS absolute_expires_at,
DROP COLUMN IF EXISTS parent_session_id,
DROP COLUMN IF EXISTS family_id,
DROP COLUMN IF EXISTS refresh_token_hash,
DROP COLUMN IF EXISTS access_token_hash;

View File

@@ -0,0 +1,16 @@
-- Workstream 1: extend sessions table for opaque-token session model.
-- Adds new columns needed for the upcoming JWT drop (D2/D3 deploy).
-- Old token_hash column is kept until cleanup migration 000029.
ALTER TABLE sessions
ADD COLUMN IF NOT EXISTS access_token_hash TEXT UNIQUE,
ADD COLUMN IF NOT EXISTS refresh_token_hash TEXT UNIQUE,
ADD COLUMN IF NOT EXISTS family_id UUID NOT NULL DEFAULT uuid_generate_v4(),
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES sessions(id) ON DELETE SET NULL,
ADD COLUMN IF NOT EXISTS absolute_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 days',
ADD COLUMN IF NOT EXISTS last_used_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ;
CREATE INDEX IF NOT EXISTS idx_sessions_family_id ON sessions (family_id);
CREATE INDEX IF NOT EXISTS idx_sessions_access_hash ON sessions (access_token_hash) WHERE access_token_hash IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_sessions_refresh_hash ON sessions (refresh_token_hash) WHERE refresh_token_hash IS NOT NULL;

View File

@@ -0,0 +1,4 @@
DROP TABLE IF EXISTS totp_backup_codes;
ALTER TABLE totp_secrets
DROP COLUMN IF EXISTS secret_v2;

View File

@@ -0,0 +1,21 @@
-- Workstream 7: TOTP encryption + backup codes schema.
-- secret_v2 stores AES-256-GCM encrypted secret (format: v1:<base64-nonce-ciphertext>).
-- The migration job (cmd/totp-encrypt) populates secret_v2 for existing rows.
-- Column secret is dropped in migration 000030 once all rows have secret_v2.
ALTER TABLE totp_secrets
ADD COLUMN IF NOT EXISTS secret_v2 TEXT;
CREATE TABLE IF NOT EXISTS totp_backup_codes (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
code_hash TEXT NOT NULL,
used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_totp_backup_codes_user_id
ON totp_backup_codes (user_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_totp_backup_codes_user_code
ON totp_backup_codes (user_id, code_hash);