feat(auth): D1 non-breaking security foundations
- CORS: rewrite middleware with Vary: Origin, regex origin patterns, startup validation, and prod boot-fail on empty allowlist; shared CORSConfig exported for CSRF reuse - CSRF: new Origin/Referer check middleware sharing CORS allowlist; Bearer-token clients exempt; mounts globally after CORS - Argon2id: new password package with PHC format, bcrypt dispatch, and NeedsRehash; lazy upgrade on login in auth service - Rate limiting: add RateLimitByKey with custom key function; apply per-route limits to /auth/login, /refresh, /2fa/verify, /auth/magic-link, and /auth/password - apierror: add CSRFMismatch and RefreshReuse error constructors - Migrations: 000027 (session model schema columns for D2/D3), 000028 (TOTP secret_v2 column + totp_backup_codes table) - cmd/totp-encrypt: one-shot job to encrypt existing TOTP secrets
This commit is contained in:
127
backend/cmd/totp-encrypt/main.go
Normal file
127
backend/cmd/totp-encrypt/main.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// totp-encrypt is a one-shot migration job that encrypts plaintext TOTP secrets
|
||||
// in the totp_secrets table. Run once after deploying migration 000028.
|
||||
// Set TOTP_ENCRYPTION_KEY (32-byte AES key, base64-encoded) before running.
|
||||
//
|
||||
// Usage: TOTP_ENCRYPTION_KEY=<b64> DB_HOST=... ./totp-encrypt [--dry-run]
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
slog.Error("totp-encrypt failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
dryRun := len(os.Args) > 1 && os.Args[1] == "--dry-run"
|
||||
|
||||
keyB64 := os.Getenv("TOTP_ENCRYPTION_KEY")
|
||||
if keyB64 == "" {
|
||||
return fmt.Errorf("TOTP_ENCRYPTION_KEY is required")
|
||||
}
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(keyB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("TOTP_ENCRYPTION_KEY: invalid base64: %w", err)
|
||||
}
|
||||
if len(keyBytes) != 32 {
|
||||
return fmt.Errorf("TOTP_ENCRYPTION_KEY must be exactly 32 bytes (got %d)", len(keyBytes))
|
||||
}
|
||||
var key [32]byte
|
||||
copy(key[:], keyBytes)
|
||||
|
||||
dsn := fmt.Sprintf(
|
||||
"postgres://%s:%s@%s:%s/%s?sslmode=%s",
|
||||
getenv("DB_USER", "marktvogt"),
|
||||
getenv("DB_PASSWORD", "marktvogt"),
|
||||
getenv("DB_HOST", "localhost"),
|
||||
getenv("DB_PORT", "5432"),
|
||||
getenv("DB_NAME", "marktvogt"),
|
||||
getenv("DB_SSLMODE", "disable"),
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to database: %w", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
rows, err := pool.Query(ctx,
|
||||
`SELECT id, secret FROM totp_secrets WHERE secret_v2 IS NULL AND secret != ''`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("querying totp_secrets: %w", err)
|
||||
}
|
||||
|
||||
type row struct {
|
||||
id string
|
||||
secret string
|
||||
}
|
||||
var pending []row
|
||||
for rows.Next() {
|
||||
var r row
|
||||
if scanErr := rows.Scan(&r.id, &r.secret); scanErr != nil {
|
||||
return fmt.Errorf("scanning row: %w", scanErr)
|
||||
}
|
||||
pending = append(pending, r)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
slog.Info("rows to encrypt", "count", len(pending), "dry_run", dryRun)
|
||||
|
||||
if dryRun {
|
||||
slog.Info("dry run — no changes written")
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, failed := 0, 0
|
||||
for _, r := range pending {
|
||||
ciphertext, sealErr := crypto.Seal(key, []byte(r.secret))
|
||||
if sealErr != nil {
|
||||
slog.Error("encrypting secret", "id", r.id, "error", sealErr)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
encoded := "v1:" + base64.StdEncoding.EncodeToString(ciphertext)
|
||||
|
||||
_, execErr := pool.Exec(ctx,
|
||||
`UPDATE totp_secrets SET secret_v2 = $1 WHERE id = $2`,
|
||||
encoded, r.id)
|
||||
if execErr != nil {
|
||||
if errors.Is(execErr, pgx.ErrNoRows) {
|
||||
slog.Warn("row disappeared during migration", "id", r.id)
|
||||
} else {
|
||||
slog.Error("updating secret_v2", "id", r.id, "error", execErr)
|
||||
failed++
|
||||
}
|
||||
continue
|
||||
}
|
||||
ok++
|
||||
}
|
||||
|
||||
slog.Info("encryption complete", "encrypted", ok, "failed", failed)
|
||||
if failed > 0 {
|
||||
return fmt.Errorf("%d rows failed to encrypt", failed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getenv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -85,7 +87,8 @@ type JWTConfig struct {
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
Origins string
|
||||
Origins []string
|
||||
OriginPatterns []string
|
||||
}
|
||||
|
||||
type RateConfig struct {
|
||||
@@ -132,7 +135,26 @@ type NotificationConfig struct {
|
||||
FrontendURL string
|
||||
}
|
||||
|
||||
const envDevelopment = "development"
|
||||
|
||||
func Load() (*Config, error) {
|
||||
appEnv := envStr("APP_ENV", envDevelopment)
|
||||
|
||||
corsOrigins := envStrSlice("CORS_ORIGINS")
|
||||
if len(corsOrigins) == 0 && appEnv == envDevelopment {
|
||||
corsOrigins = []string{"http://localhost:5173"}
|
||||
}
|
||||
corsPatterns := envStrSlice("CORS_ORIGIN_PATTERNS")
|
||||
if len(corsOrigins) == 0 && len(corsPatterns) == 0 && appEnv != envDevelopment {
|
||||
return nil, fmt.Errorf("CORS_ORIGINS or CORS_ORIGIN_PATTERNS is required in non-dev environments")
|
||||
}
|
||||
for _, p := range corsPatterns {
|
||||
if _, err := regexp.Compile(p); err != nil {
|
||||
return nil, fmt.Errorf("CORS_ORIGIN_PATTERNS: invalid regex %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
slog.Info("cors allowlist loaded", "exact_origins", corsOrigins, "pattern_count", len(corsPatterns))
|
||||
|
||||
port, err := envInt("APP_PORT", 8080)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("APP_PORT: %w", err)
|
||||
@@ -205,7 +227,7 @@ func Load() (*Config, error) {
|
||||
|
||||
return &Config{
|
||||
App: AppConfig{
|
||||
Env: envStr("APP_ENV", "development"),
|
||||
Env: appEnv,
|
||||
Host: envStr("APP_HOST", "0.0.0.0"),
|
||||
Port: port,
|
||||
},
|
||||
@@ -230,7 +252,8 @@ func Load() (*Config, error) {
|
||||
SessionTTL: sessionTTL,
|
||||
},
|
||||
CORS: CORSConfig{
|
||||
Origins: envStr("CORS_ORIGINS", "http://localhost:5173"),
|
||||
Origins: corsOrigins,
|
||||
OriginPatterns: corsPatterns,
|
||||
},
|
||||
Rate: RateConfig{
|
||||
RPS: rps,
|
||||
@@ -297,7 +320,7 @@ func (c *Config) Addr() string {
|
||||
}
|
||||
|
||||
func (c *Config) IsDev() bool {
|
||||
return c.App.Env == "development"
|
||||
return c.App.Env == envDevelopment
|
||||
}
|
||||
|
||||
func envStr(key, fallback string) string {
|
||||
@@ -331,6 +354,21 @@ func envFloat(key string, fallback float64) (float64, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func envStrSlice(key string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
var result []string
|
||||
for _, s := range strings.Split(v, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func envDuration(key string, fallback time.Duration) (time.Duration, error) {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *MagicLinkHandler) findOrCreateUser(ctx context.Context, email string) (
|
||||
return h.userRepo.CreateOAuthUser(ctx, email, user.GenerateDisplayName(), true)
|
||||
}
|
||||
|
||||
func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler) {
|
||||
rg.POST("/auth/magic-link", h.RequestMagicLink)
|
||||
func RegisterMagicLinkRoutes(rg *gin.RouterGroup, h *MagicLinkHandler, requestLimit gin.HandlerFunc) {
|
||||
rg.POST("/auth/magic-link", requestLimit, h.RequestMagicLink)
|
||||
rg.GET("/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
}
|
||||
|
||||
@@ -2,20 +2,20 @@ package auth
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth gin.HandlerFunc) {
|
||||
func RegisterRoutes(rg *gin.RouterGroup, h *Handler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit gin.HandlerFunc) {
|
||||
auth := rg.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Register)
|
||||
auth.POST("/login", h.Login)
|
||||
auth.POST("/login", loginLimit, h.Login)
|
||||
auth.POST("/logout", requireAuth, h.Logout)
|
||||
auth.POST("/refresh", h.Refresh)
|
||||
auth.POST("/refresh", refreshLimit, h.Refresh)
|
||||
|
||||
// Password
|
||||
auth.PUT("/password", requireAuth, h.ChangePassword)
|
||||
auth.PUT("/password", requireAuth, passwordLimit, h.ChangePassword)
|
||||
|
||||
// 2FA
|
||||
auth.POST("/2fa/setup", requireAuth, h.SetupTOTP)
|
||||
auth.POST("/2fa/verify", requireAuth, h.VerifyTOTP)
|
||||
auth.POST("/2fa/verify", requireAuth, twoFALimit, h.VerifyTOTP)
|
||||
auth.DELETE("/2fa", requireAuth, h.DisableTOTP)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"marktvogt.de/backend/internal/domain/user"
|
||||
"marktvogt.de/backend/internal/pkg/password"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@@ -29,12 +30,12 @@ func NewService(authRepo Repository, userRepo user.Repository, tokenSvc *TokenSe
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua string) (AuthData, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
hash, err := password.Hash(req.Password)
|
||||
if err != nil {
|
||||
return AuthData{}, fmt.Errorf("hashing password: %w", err)
|
||||
}
|
||||
|
||||
u, err := s.userRepo.Create(ctx, req.Email, string(hash), req.DisplayName)
|
||||
u, err := s.userRepo.Create(ctx, req.Email, hash, req.DisplayName)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrEmailAlreadyTaken) {
|
||||
return AuthData{}, err
|
||||
@@ -58,10 +59,20 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A
|
||||
return AuthData{}, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.Password)); err != nil {
|
||||
ok, err := password.Verify(req.Password, *u.PasswordHash)
|
||||
if err != nil || !ok {
|
||||
return AuthData{}, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Lazily upgrade bcrypt hashes to Argon2id on successful login.
|
||||
if password.NeedsRehash(*u.PasswordHash) {
|
||||
if newHash, hashErr := password.Hash(req.Password); hashErr == nil {
|
||||
if _, updateErr := s.userRepo.Update(ctx, u.ID, map[string]any{"password_hash": newHash}); updateErr != nil {
|
||||
slog.Warn("password rehash failed", "user_id", u.ID, "error", updateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check 2FA if enabled
|
||||
if req.TOTPCode != "" {
|
||||
if err := s.validateTOTP(ctx, u.ID, req.TOTPCode); err != nil {
|
||||
@@ -165,17 +176,18 @@ func (s *Service) ChangePassword(ctx context.Context, userID uuid.UUID, req Chan
|
||||
if req.CurrentPassword == "" {
|
||||
return fmt.Errorf("current password required")
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(*u.PasswordHash), []byte(req.CurrentPassword)); err != nil {
|
||||
ok, err := password.Verify(req.CurrentPassword, *u.PasswordHash)
|
||||
if err != nil || !ok {
|
||||
return fmt.Errorf("current password incorrect")
|
||||
}
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||
hash, err := password.Hash(req.NewPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hashing password: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": string(hash)})
|
||||
_, err = s.userRepo.Update(ctx, userID, map[string]any{"password_hash": hash})
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating password: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,36 +1,77 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"regexp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORS(allowedOrigins string) gin.HandlerFunc {
|
||||
origins := make(map[string]bool)
|
||||
for _, o := range strings.Split(allowedOrigins, ",") {
|
||||
o = strings.TrimSpace(o)
|
||||
if o != "" {
|
||||
origins[o] = true
|
||||
// CORSConfig holds the parsed CORS allowlist. Build via NewCORSConfig to
|
||||
// enable regex-pattern support; direct struct literals support exact Origins only.
|
||||
type CORSConfig struct {
|
||||
Origins []string
|
||||
compiled []*regexp.Regexp
|
||||
}
|
||||
|
||||
// NewCORSConfig compiles regex patterns and returns a ready CORSConfig.
|
||||
// Returns an error if any pattern fails to compile.
|
||||
func NewCORSConfig(origins []string, patterns []string) (CORSConfig, error) {
|
||||
cfg := CORSConfig{Origins: origins}
|
||||
for _, p := range patterns {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
return CORSConfig{}, err
|
||||
}
|
||||
cfg.compiled = append(cfg.compiled, re)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsAllowedOrigin returns true if origin matches an exact entry or a compiled pattern.
|
||||
func (c CORSConfig) IsAllowedOrigin(origin string) bool {
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
for _, o := range c.Origins {
|
||||
if o == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, re := range c.compiled {
|
||||
if re.MatchString(origin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CORS returns a middleware that sets CORS headers for allowed origins.
|
||||
// Vary: Origin is always set so caches key on origin correctly.
|
||||
func CORS(cfg CORSConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.GetHeader("Origin")
|
||||
|
||||
if origins[origin] {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID")
|
||||
c.Header("Access-Control-Expose-Headers", "X-Request-ID")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
// Always set Vary: Origin so caches don't serve the wrong response across origins.
|
||||
c.Header("Vary", "Origin")
|
||||
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
if cfg.IsAllowedOrigin(origin) {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Expose-Headers", "X-Request-ID")
|
||||
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Request-ID")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
// Extend Vary for preflight so caches differentiate by request method/headers too.
|
||||
c.Header("Vary", "Origin, Access-Control-Request-Method, Access-Control-Request-Headers")
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
} else if origin != "" {
|
||||
slog.Warn("cors origin rejected", "origin", origin, "path", c.Request.URL.Path, "method", c.Request.Method)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
155
backend/internal/middleware/cors_test.go
Normal file
155
backend/internal/middleware/cors_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"marktvogt.de/backend/internal/middleware"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func corsRouter(cfg middleware.CORSConfig) *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(middleware.CORS(cfg))
|
||||
r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCORS_AllowedOriginSetsHeaders(t *testing.T) {
|
||||
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Origin", "https://marktvogt.de")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de")
|
||||
}
|
||||
if got := w.Header().Get("Vary"); got == "" {
|
||||
t.Error("Vary header missing")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_VaryOriginAlwaysPresent(t *testing.T) {
|
||||
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
|
||||
|
||||
// Even for requests without an Origin header, Vary: Origin must be set
|
||||
// so caches know the response differs by origin.
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Vary"); got == "" {
|
||||
t.Error("Vary header should be set even when no Origin is present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_DisallowedOriginNoAccessControlHeaders(t *testing.T) {
|
||||
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Origin", "https://evil.example")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||
t.Errorf("Access-Control-Allow-Origin should be empty for disallowed origin, got %q", got)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "" {
|
||||
t.Errorf("Access-Control-Allow-Credentials should be empty for disallowed origin, got %q", got)
|
||||
}
|
||||
// Request still executes — CORS only controls browser behaviour
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_PreflightReturns204WithVary(t *testing.T) {
|
||||
r := corsRouter(middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
req.Header.Set("Origin", "https://marktvogt.de")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
req.Header.Set("Access-Control-Request-Headers", "Content-Type")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("preflight status = %d, want 204", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://marktvogt.de" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "https://marktvogt.de")
|
||||
}
|
||||
vary := w.Header().Get("Vary")
|
||||
if vary == "" {
|
||||
t.Error("Vary header missing on preflight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_RegexPatternAllowedOrigin(t *testing.T) {
|
||||
cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`})
|
||||
if err != nil {
|
||||
t.Fatalf("NewCORSConfig error: %v", err)
|
||||
}
|
||||
r := corsRouter(cfg)
|
||||
|
||||
for _, origin := range []string{"https://admin.marktvogt.de", "https://staging.marktvogt.de"} {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Origin", origin)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != origin {
|
||||
t.Errorf("origin %q: Access-Control-Allow-Origin = %q, want %q", origin, got, origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_RegexPatternRejectsNonMatch(t *testing.T) {
|
||||
cfg, err := middleware.NewCORSConfig(nil, []string{`^https://[a-z0-9-]+\.marktvogt\.de$`})
|
||||
if err != nil {
|
||||
t.Fatalf("NewCORSConfig error: %v", err)
|
||||
}
|
||||
r := corsRouter(cfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Origin", "https://evil.marktvogt.de.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||
t.Errorf("expected no ACAO header for non-matching origin, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_InvalidRegexPatternErrors(t *testing.T) {
|
||||
_, err := middleware.NewCORSConfig(nil, []string{"[invalid-regex"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid regex pattern, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_IsAllowedOriginHelper(t *testing.T) {
|
||||
cfg := middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}
|
||||
|
||||
if !cfg.IsAllowedOrigin("https://marktvogt.de") {
|
||||
t.Error("expected true for allowed origin")
|
||||
}
|
||||
if cfg.IsAllowedOrigin("https://evil.example") {
|
||||
t.Error("expected false for disallowed origin")
|
||||
}
|
||||
if cfg.IsAllowedOrigin("") {
|
||||
t.Error("expected false for empty origin")
|
||||
}
|
||||
}
|
||||
50
backend/internal/middleware/csrf.go
Normal file
50
backend/internal/middleware/csrf.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/apierror"
|
||||
)
|
||||
|
||||
// CSRF returns middleware that validates the Origin (or Referer) header for
|
||||
// state-changing cookie-authed requests. Bearer-token requests (mobile/API)
|
||||
// are exempt — they're not CSRF-vulnerable.
|
||||
func CSRF(cfg CORSConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
|
||||
// Safe methods and preflight are not CSRF-vulnerable.
|
||||
if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer-token clients (mobile, third-party API) are exempt.
|
||||
if strings.HasPrefix(c.GetHeader("Authorization"), "Bearer ") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
origin := c.GetHeader("Origin")
|
||||
if origin == "" {
|
||||
// Fall back to Referer — extract scheme+host as origin.
|
||||
if ref := c.GetHeader("Referer"); ref != "" {
|
||||
if u, err := url.Parse(ref); err == nil && u.Host != "" {
|
||||
origin = u.Scheme + "://" + u.Host
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !cfg.IsAllowedOrigin(origin) {
|
||||
apiErr := apierror.CSRFMismatch()
|
||||
c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
141
backend/internal/middleware/csrf_test.go
Normal file
141
backend/internal/middleware/csrf_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"marktvogt.de/backend/internal/middleware"
|
||||
)
|
||||
|
||||
func csrfRouter(cfg middleware.CORSConfig) *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(middleware.CSRF(cfg))
|
||||
r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.PUT("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.DELETE("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.Handle(http.MethodOptions, "/test", func(c *gin.Context) { c.Status(http.StatusNoContent) })
|
||||
return r
|
||||
}
|
||||
|
||||
var csrfCfg = middleware.CORSConfig{Origins: []string{"https://marktvogt.de"}}
|
||||
|
||||
func TestCSRF_AllowedOriginPasses(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Origin", "https://marktvogt.de")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_DisallowedOriginReturns403(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Origin", "https://evil.example")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_NoOriginNoRefererReturns403(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_AllowedRefererFallback(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Referer", "https://marktvogt.de/some/page")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_DisallowedRefererReturns403(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Referer", "https://evil.example/page")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_GetAlwaysPasses(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
// no Origin, no Referer — still passes
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("GET status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_OptionsAlwaysPasses(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// OPTIONS is handled by CORS; CSRF skips it
|
||||
if w.Code == http.StatusForbidden {
|
||||
t.Error("OPTIONS should not be blocked by CSRF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_BearerTokenExempt(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer some.token.here")
|
||||
// no Origin — would normally fail CSRF
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Bearer request status = %d, want 200 (exempt from CSRF)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_DeleteWithAllowedOriginPasses(t *testing.T) {
|
||||
r := csrfRouter(csrfCfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/test", nil)
|
||||
req.Header.Set("Origin", "https://marktvogt.de")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("DELETE status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -47,12 +47,16 @@ func (l *ipLimiter) get(ip string) *rate.Limiter {
|
||||
return limiter
|
||||
}
|
||||
|
||||
func RateLimit(rps float64, burst int) gin.HandlerFunc {
|
||||
// IPKey is a KeyFn that returns the client IP — the default key for RateLimitByKey.
|
||||
func IPKey(c *gin.Context) string { return c.ClientIP() }
|
||||
|
||||
// RateLimitByKey creates a per-key rate limiter. keyFn extracts the rate-limit
|
||||
// key from the request (e.g. IPKey, or a composite IP+email key).
|
||||
func RateLimitByKey(rps float64, burst int, keyFn func(*gin.Context) string) gin.HandlerFunc {
|
||||
limiter := newIPLimiter(rps, burst)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
if !limiter.get(ip).Allow() {
|
||||
if !limiter.get(keyFn(c)).Allow() {
|
||||
apiErr := apierror.TooManyRequests()
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, apierror.NewResponse(apiErr))
|
||||
return
|
||||
@@ -60,3 +64,7 @@ func RateLimit(rps float64, burst int) gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimit(rps float64, burst int) gin.HandlerFunc {
|
||||
return RateLimitByKey(rps, burst, IPKey)
|
||||
}
|
||||
|
||||
79
backend/internal/middleware/ratelimit_test.go
Normal file
79
backend/internal/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"marktvogt.de/backend/internal/middleware"
|
||||
)
|
||||
|
||||
func rateLimitRouter(mw gin.HandlerFunc) *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(mw)
|
||||
r.POST("/test", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
return r
|
||||
}
|
||||
|
||||
func TestRateLimitByKey_BlocksAfterBurst(t *testing.T) {
|
||||
// burst=2, zero refill rate — allows exactly 2 requests then blocks
|
||||
keyConst := func(*gin.Context) string { return "same-key" }
|
||||
r := rateLimitRouter(middleware.RateLimitByKey(0, 2, keyConst))
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("request %d: status = %d, want 200 (within burst)", i, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("request 3: status = %d, want 429 (burst exhausted)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitByKey_DifferentKeysAreIndependent(t *testing.T) {
|
||||
// burst=1 — each unique key gets its own limiter with 1 token
|
||||
counter := 0
|
||||
keyFn := func(*gin.Context) string {
|
||||
counter++
|
||||
return string(rune('A' + counter - 1)) // A, B, C, ...
|
||||
}
|
||||
r := rateLimitRouter(middleware.RateLimitByKey(0, 1, keyFn))
|
||||
|
||||
// Each request uses a new key → each should get through
|
||||
for i := 0; i < 3; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("request %d: status = %d, want 200 (independent keys)", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitByKey_IPKeyFunctionMatchesRateLimit(t *testing.T) {
|
||||
// IPKey is the same as standard RateLimit key function
|
||||
r := rateLimitRouter(middleware.RateLimitByKey(0, 1, middleware.IPKey))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("first request status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("second request status = %d, want 429", w2.Code)
|
||||
}
|
||||
}
|
||||
@@ -87,3 +87,11 @@ func Validation(message string) *Error {
|
||||
func Gone(message string) *Error {
|
||||
return &Error{Status: http.StatusGone, Code: "gone", Message: message}
|
||||
}
|
||||
|
||||
func CSRFMismatch() *Error {
|
||||
return &Error{Status: http.StatusForbidden, Code: "auth.csrf_mismatch", Message: "CSRF validation failed"}
|
||||
}
|
||||
|
||||
func RefreshReuse() *Error {
|
||||
return &Error{Status: http.StatusUnauthorized, Code: "auth.refresh_reuse_detected", Message: "session token reuse detected"}
|
||||
}
|
||||
|
||||
94
backend/internal/pkg/password/password.go
Normal file
94
backend/internal/pkg/password/password.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Argon2id parameters (RFC 9106 recommended minimum for interactive logins).
|
||||
const (
|
||||
argonTime = 3
|
||||
argonMemory = 64 * 1024 // 64 MiB
|
||||
argonThreads = 2
|
||||
argonKeyLen = 32
|
||||
argonSaltLen = 16
|
||||
)
|
||||
|
||||
// Hash produces a PHC-format Argon2id hash of the plaintext password.
|
||||
func Hash(plain string) (string, error) {
|
||||
salt := make([]byte, argonSaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("generating salt: %w", err)
|
||||
}
|
||||
key := argon2.IDKey([]byte(plain), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
|
||||
encoded := fmt.Sprintf(
|
||||
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
argonMemory, argonTime, argonThreads,
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(key),
|
||||
)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// Verify checks plain against an encoded hash. Supports both Argon2id (PHC
|
||||
// format) and bcrypt ($2a$/$2b$ prefix). Returns (false, nil) for wrong
|
||||
// passwords, (false, err) for unrecognised formats.
|
||||
func Verify(plain, encoded string) (bool, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(encoded, "$argon2id$"):
|
||||
return verifyArgon2id(plain, encoded)
|
||||
case strings.HasPrefix(encoded, "$2a$"), strings.HasPrefix(encoded, "$2b$"), strings.HasPrefix(encoded, "$2y$"):
|
||||
err := bcrypt.CompareHashAndPassword([]byte(encoded), []byte(plain))
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
default:
|
||||
return false, fmt.Errorf("unrecognised hash format")
|
||||
}
|
||||
}
|
||||
|
||||
// NeedsRehash returns true if the hash was produced by an older algorithm
|
||||
// (currently bcrypt) and should be upgraded to Argon2id.
|
||||
func NeedsRehash(encoded string) bool {
|
||||
return strings.HasPrefix(encoded, "$2a$") ||
|
||||
strings.HasPrefix(encoded, "$2b$") ||
|
||||
strings.HasPrefix(encoded, "$2y$")
|
||||
}
|
||||
|
||||
func verifyArgon2id(plain, encoded string) (bool, error) {
|
||||
parts := strings.Split(encoded, "$")
|
||||
// Expected: ["", "argon2id", "v=19", "m=...,t=...,p=...", "<salt>", "<hash>"]
|
||||
if len(parts) != 6 {
|
||||
return false, fmt.Errorf("invalid argon2id hash format")
|
||||
}
|
||||
|
||||
var mem, t uint32
|
||||
var p uint8
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &t, &p); err != nil {
|
||||
return false, fmt.Errorf("parsing argon2id params: %w", err)
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("decoding argon2id salt: %w", err)
|
||||
}
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("decoding argon2id hash: %w", err)
|
||||
}
|
||||
|
||||
candidate := argon2.IDKey([]byte(plain), salt, t, mem, p, uint32(len(hash)))
|
||||
if subtle.ConstantTimeCompare(candidate, hash) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
91
backend/internal/pkg/password/password_test.go
Normal file
91
backend/internal/pkg/password/password_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package password_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/password"
|
||||
)
|
||||
|
||||
func TestHash_ProducesArgon2idPrefix(t *testing.T) {
|
||||
h, err := password.Hash("hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(h, "$argon2id$") {
|
||||
t.Errorf("hash %q does not start with $argon2id$", h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash_RoundTrip(t *testing.T) {
|
||||
h, err := password.Hash("correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash error: %v", err)
|
||||
}
|
||||
ok, err := password.Verify("correct horse battery staple", h)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Error("expected Verify to return true for matching password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_WrongPassword(t *testing.T) {
|
||||
h, err := password.Hash("correct")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash error: %v", err)
|
||||
}
|
||||
ok, err := password.Verify("wrong", h)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify error: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected Verify to return false for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash_UniquePerCall(t *testing.T) {
|
||||
h1, _ := password.Hash("same")
|
||||
h2, _ := password.Hash("same")
|
||||
if h1 == h2 {
|
||||
t.Error("two hashes of the same password should differ (different salts)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsRehash_Argon2idReturnsFalse(t *testing.T) {
|
||||
h, _ := password.Hash("test")
|
||||
if password.NeedsRehash(h) {
|
||||
t.Error("NeedsRehash should return false for argon2id hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsRehash_BcryptReturnsTrue(t *testing.T) {
|
||||
h, _ := bcrypt.GenerateFromPassword([]byte("x"), bcrypt.MinCost)
|
||||
if !password.NeedsRehash(string(h)) {
|
||||
t.Error("NeedsRehash should return true for bcrypt hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_BcryptHashVerifiesCorrectly(t *testing.T) {
|
||||
h, err := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.MinCost)
|
||||
if err != nil {
|
||||
t.Fatalf("bcrypt.GenerateFromPassword error: %v", err)
|
||||
}
|
||||
ok, verifyErr := password.Verify("test", string(h))
|
||||
if verifyErr != nil {
|
||||
t.Fatalf("Verify error: %v", verifyErr)
|
||||
}
|
||||
if !ok {
|
||||
t.Error("expected Verify to return true for correct bcrypt password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_UnknownFormatErrors(t *testing.T) {
|
||||
_, err := password.Verify("test", "not-a-valid-hash")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown hash format")
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,21 @@ func (s *Server) registerRoutes() {
|
||||
authSvc := auth.NewService(authRepo, userRepo, tokenSvc, s.cfg.JWT.SessionTTL)
|
||||
authHandler := auth.NewHandler(authSvc, userRepo)
|
||||
requireAuth := middleware.RequireAuth(tokenSvc)
|
||||
auth.RegisterRoutes(v1, authHandler, requireAuth)
|
||||
|
||||
// Per-route auth rate limiters (keyed by IP; user_id unavailable before auth completes)
|
||||
userIDKey := func(c *gin.Context) string {
|
||||
if uid, ok := c.Get("user_id"); ok {
|
||||
return fmt.Sprintf("%v", uid)
|
||||
}
|
||||
return c.ClientIP()
|
||||
}
|
||||
loginLimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5
|
||||
refreshLimit := middleware.RateLimitByKey(1, 10, middleware.IPKey) // 1/s, burst 10
|
||||
twoFALimit := middleware.RateLimitByKey(0.2, 5, middleware.IPKey) // 1 per 5s, burst 5
|
||||
passwordLimit := middleware.RateLimitByKey(0.1, 3, userIDKey) // 1 per 10s, burst 3
|
||||
magicLinkLimit := middleware.RateLimitByKey(0.1, 3, middleware.IPKey) // 1 per 10s, burst 3
|
||||
|
||||
auth.RegisterRoutes(v1, authHandler, requireAuth, loginLimit, refreshLimit, twoFALimit, passwordLimit)
|
||||
|
||||
// OAuth routes — disabled until provider apps are configured
|
||||
// oauthHandler := auth.NewOAuthHandler(s.cfg.OAuth, authSvc, userRepo, authRepo)
|
||||
@@ -51,7 +65,7 @@ func (s *Server) registerRoutes() {
|
||||
|
||||
// Magic link routes
|
||||
magicLinkHandler := auth.NewMagicLinkHandler(authRepo, userRepo, authSvc, s.cfg.Magic, emailSender, s.cfg.Notification.FrontendURL)
|
||||
auth.RegisterMagicLinkRoutes(v1, magicLinkHandler)
|
||||
auth.RegisterMagicLinkRoutes(v1, magicLinkHandler, magicLinkLimit)
|
||||
|
||||
// User profile routes
|
||||
userSvc := user.NewService(userRepo)
|
||||
|
||||
@@ -29,11 +29,15 @@ func New(cfg *config.Config, db *pgxpool.Pool, vk valkey.Client) *Server {
|
||||
|
||||
router := gin.New()
|
||||
|
||||
// NewCORSConfig only errors on bad regexes; config.Load already validates them.
|
||||
corsCfg, _ := middleware.NewCORSConfig(cfg.CORS.Origins, cfg.CORS.OriginPatterns)
|
||||
|
||||
router.Use(
|
||||
middleware.Recovery(),
|
||||
middleware.RequestID(),
|
||||
middleware.Logging(),
|
||||
middleware.CORS(cfg.CORS.Origins),
|
||||
middleware.CORS(corsCfg),
|
||||
middleware.CSRF(corsCfg),
|
||||
middleware.RateLimit(cfg.Rate.RPS, cfg.Rate.Burst),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
DROP INDEX IF EXISTS idx_sessions_refresh_hash;
|
||||
DROP INDEX IF EXISTS idx_sessions_access_hash;
|
||||
DROP INDEX IF EXISTS idx_sessions_family_id;
|
||||
|
||||
ALTER TABLE sessions
|
||||
DROP COLUMN IF EXISTS revoked_at,
|
||||
DROP COLUMN IF EXISTS last_used_at,
|
||||
DROP COLUMN IF EXISTS absolute_expires_at,
|
||||
DROP COLUMN IF EXISTS parent_session_id,
|
||||
DROP COLUMN IF EXISTS family_id,
|
||||
DROP COLUMN IF EXISTS refresh_token_hash,
|
||||
DROP COLUMN IF EXISTS access_token_hash;
|
||||
@@ -0,0 +1,16 @@
|
||||
-- Workstream 1: extend sessions table for opaque-token session model.
|
||||
-- Adds new columns needed for the upcoming JWT drop (D2/D3 deploy).
|
||||
-- Old token_hash column is kept until cleanup migration 000029.
|
||||
|
||||
ALTER TABLE sessions
|
||||
ADD COLUMN IF NOT EXISTS access_token_hash TEXT UNIQUE,
|
||||
ADD COLUMN IF NOT EXISTS refresh_token_hash TEXT UNIQUE,
|
||||
ADD COLUMN IF NOT EXISTS family_id UUID NOT NULL DEFAULT uuid_generate_v4(),
|
||||
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES sessions(id) ON DELETE SET NULL,
|
||||
ADD COLUMN IF NOT EXISTS absolute_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 days',
|
||||
ADD COLUMN IF NOT EXISTS last_used_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_family_id ON sessions (family_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_access_hash ON sessions (access_token_hash) WHERE access_token_hash IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_refresh_hash ON sessions (refresh_token_hash) WHERE refresh_token_hash IS NOT NULL;
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP TABLE IF EXISTS totp_backup_codes;
|
||||
|
||||
ALTER TABLE totp_secrets
|
||||
DROP COLUMN IF EXISTS secret_v2;
|
||||
@@ -0,0 +1,21 @@
|
||||
-- Workstream 7: TOTP encryption + backup codes schema.
|
||||
-- secret_v2 stores AES-256-GCM encrypted secret (format: v1:<base64-nonce-ciphertext>).
|
||||
-- The migration job (cmd/totp-encrypt) populates secret_v2 for existing rows.
|
||||
-- Column secret is dropped in migration 000030 once all rows have secret_v2.
|
||||
|
||||
ALTER TABLE totp_secrets
|
||||
ADD COLUMN IF NOT EXISTS secret_v2 TEXT;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS totp_backup_codes (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
code_hash TEXT NOT NULL,
|
||||
used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_totp_backup_codes_user_id
|
||||
ON totp_backup_codes (user_id);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_totp_backup_codes_user_code
|
||||
ON totp_backup_codes (user_id, code_hash);
|
||||
Reference in New Issue
Block a user