From 492bbb350e6f87019bfff827acc19305bf0bd336 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sun, 26 Apr 2026 12:15:57 +0200 Subject: [PATCH] =?UTF-8?q?feat(auth):=20D2/D3=20opaque-token=20session=20?= =?UTF-8?q?model=20=E2=80=94=20drop=20JWT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace HS256 JWT access tokens with two opaque 32-byte random tokens (access + refresh), both stored as SHA-256 hashes in sessions + Valkey. Key changes: - GenerateOpaqueToken() replaces JWT issuance; TokenService removed - Sessions now carry access_token_hash, refresh_token_hash, family_id, parent_session_id, access_expires_at, absolute_expires_at, last_used_at, revoked_at — per migration 000027 (updated to add access_expires_at) - Refresh rotation is atomic (UPDATE...RETURNING); reuse detection kills the entire token family and returns auth.refresh_reuse_detected - RequireAuth/OptionalAuth now take SessionLookup (Valkey→Postgres) instead of *TokenService; sets session_id in context alongside user_id - last_used_at is bumped on each request, throttled to writes >60s old - AuthConfig{AccessTTL,RefreshIdleTTL,RefreshAbsoluteTTL} replaces JWT TTL env vars (AUTH_ACCESS_TTL=30m, AUTH_REFRESH_IDLE_TTL=168h, AUTH_REFRESH_ABSOLUTE_TTL=720h) - JWT_SECRET kept for AI-settings key derivation (drops from auth flow) Forced logout on deploy (D3 behaviour); pre-launch so acceptable. --- backend/internal/config/config.go | 33 +- backend/internal/domain/auth/dto.go | 3 +- backend/internal/domain/auth/handler.go | 52 +-- backend/internal/domain/auth/magiclink.go | 2 +- backend/internal/domain/auth/model.go | 25 +- backend/internal/domain/auth/oauth.go | 4 +- backend/internal/domain/auth/repository.go | 188 ++++++---- backend/internal/domain/auth/service.go | 127 +++---- .../domain/auth/service_refresh_test.go | 326 ++++++++++++++++++ backend/internal/domain/auth/token.go | 79 +---- backend/internal/middleware/auth.go | 78 +++-- backend/internal/middleware/auth_test.go | 172 +++++++++ backend/internal/server/routes.go | 9 +- ...027_sessions_session_model_rework.down.sql | 1 + ...00027_sessions_session_model_rework.up.sql | 1 + 15 files changed, 830 insertions(+), 270 deletions(-) create mode 100644 backend/internal/domain/auth/service_refresh_test.go create mode 100644 backend/internal/middleware/auth_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3ec0f0e..bd914fb 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -15,6 +15,7 @@ type Config struct { DB DBConfig Valkey ValkeyConfig JWT JWTConfig + Auth AuthConfig CORS CORSConfig Rate RateConfig Sentry SentryConfig @@ -81,9 +82,13 @@ type ValkeyConfig struct { } type JWTConfig struct { - Secret string - AccessTTL time.Duration - SessionTTL time.Duration + Secret string // kept for settings-encryption key derivation (see routes.go) +} + +type AuthConfig struct { + AccessTTL time.Duration // default 30m + RefreshIdleTTL time.Duration // default 168h (7 days); sliding per-request + RefreshAbsoluteTTL time.Duration // default 720h (30 days); hard limit } type CORSConfig struct { @@ -180,14 +185,19 @@ func Load() (*Config, error) { return nil, fmt.Errorf("VALKEY_DB: %w", err) } - accessTTL, err := envDuration("JWT_ACCESS_TTL", 15*time.Minute) + authAccessTTL, err := envDuration("AUTH_ACCESS_TTL", 30*time.Minute) if err != nil { - return nil, fmt.Errorf("JWT_ACCESS_TTL: %w", err) + return nil, fmt.Errorf("AUTH_ACCESS_TTL: %w", err) } - sessionTTL, err := envDuration("JWT_SESSION_TTL", 720*time.Hour) + authRefreshIdleTTL, err := envDuration("AUTH_REFRESH_IDLE_TTL", 168*time.Hour) if err != nil { - return nil, fmt.Errorf("JWT_SESSION_TTL: %w", err) + return nil, fmt.Errorf("AUTH_REFRESH_IDLE_TTL: %w", err) + } + + authRefreshAbsoluteTTL, err := envDuration("AUTH_REFRESH_ABSOLUTE_TTL", 720*time.Hour) + if err != nil { + return nil, fmt.Errorf("AUTH_REFRESH_ABSOLUTE_TTL: %w", err) } rps, err := envFloat("RATE_LIMIT_RPS", 10) @@ -247,9 +257,12 @@ func Load() (*Config, error) { DB: valkeyDB, }, JWT: JWTConfig{ - Secret: jwtSecret, - AccessTTL: accessTTL, - SessionTTL: sessionTTL, + Secret: jwtSecret, + }, + Auth: AuthConfig{ + AccessTTL: authAccessTTL, + RefreshIdleTTL: authRefreshIdleTTL, + RefreshAbsoluteTTL: authRefreshAbsoluteTTL, }, CORS: CORSConfig{ Origins: corsOrigins, diff --git a/backend/internal/domain/auth/dto.go b/backend/internal/domain/auth/dto.go index 96b0bef..bcdc70f 100644 --- a/backend/internal/domain/auth/dto.go +++ b/backend/internal/domain/auth/dto.go @@ -40,8 +40,7 @@ type AuthResponse struct { type AuthData struct { AccessToken string `json:"access_token"` - SessionToken string `json:"session_token"` - ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` } type MessageResponse struct { diff --git a/backend/internal/domain/auth/handler.go b/backend/internal/domain/auth/handler.go index 9fae2a7..32ee2fd 100644 --- a/backend/internal/domain/auth/handler.go +++ b/backend/internal/domain/auth/handler.go @@ -3,7 +3,6 @@ package auth import ( "errors" "net/http" - "strings" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -73,15 +72,14 @@ func (h *Handler) Login(c *gin.Context) { } func (h *Handler) Logout(c *gin.Context) { - sessionToken := extractSessionToken(c) - if sessionToken == "" { - apiErr := apierror.Unauthorized("session token required") + sessionID := GetSessionID(c) + if sessionID == uuid.Nil { + apiErr := apierror.Unauthorized("not authenticated") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) return } - tokenHash := HashToken(sessionToken) - if err := h.service.Logout(c.Request.Context(), tokenHash); err != nil { + if err := h.service.Logout(c.Request.Context(), sessionID); err != nil { apiErr := apierror.Internal("logout failed") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) return @@ -91,16 +89,21 @@ func (h *Handler) Logout(c *gin.Context) { } func (h *Handler) Refresh(c *gin.Context) { - sessionToken := extractSessionToken(c) - if sessionToken == "" { - apiErr := apierror.Unauthorized("session token required") + refreshToken := extractRefreshToken(c) + if refreshToken == "" { + apiErr := apierror.Unauthorized("refresh token required") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) return } - data, err := h.service.RefreshToken(c.Request.Context(), sessionToken, c.ClientIP(), c.GetHeader("User-Agent")) + data, err := h.service.RefreshToken(c.Request.Context(), refreshToken, c.ClientIP(), c.GetHeader("User-Agent")) if err != nil { - apiErr := apierror.Unauthorized("invalid or expired session") + var apiErr *apierror.Error + if errors.As(err, &apiErr) { + c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) + return + } + apiErr = apierror.Unauthorized("invalid or expired session") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) return } @@ -184,17 +187,12 @@ func (h *Handler) DisableTOTP(c *gin.Context) { c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA disabled"}}) } -func extractSessionToken(c *gin.Context) string { - header := c.GetHeader("X-Session-Token") - if header != "" { - return header +// extractRefreshToken reads the refresh token from X-Refresh-Token or (legacy) X-Session-Token. +func extractRefreshToken(c *gin.Context) string { + if t := c.GetHeader("X-Refresh-Token"); t != "" { + return t } - // Also check Authorization with Bearer prefix for refresh - auth := c.GetHeader("Authorization") - if strings.HasPrefix(auth, "Session ") { - return strings.TrimPrefix(auth, "Session ") - } - return "" + return c.GetHeader("X-Session-Token") } func GetUserID(c *gin.Context) uuid.UUID { @@ -208,3 +206,15 @@ func GetUserID(c *gin.Context) uuid.UUID { } return id } + +func GetSessionID(c *gin.Context) uuid.UUID { + v, exists := c.Get("session_id") + if !exists { + return uuid.Nil + } + id, ok := v.(uuid.UUID) + if !ok { + return uuid.Nil + } + return id +} diff --git a/backend/internal/domain/auth/magiclink.go b/backend/internal/domain/auth/magiclink.go index a58730e..3261691 100644 --- a/backend/internal/domain/auth/magiclink.go +++ b/backend/internal/domain/auth/magiclink.go @@ -145,7 +145,7 @@ func (h *MagicLinkHandler) VerifyMagicLink(c *gin.Context) { _, _ = h.userRepo.Update(ctx, u.ID, map[string]any{"email_verified": true}) } - data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent")) + data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent")) if err != nil { apiErr := apierror.Internal("failed to create session") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) diff --git a/backend/internal/domain/auth/model.go b/backend/internal/domain/auth/model.go index 3ba474a..f69c43e 100644 --- a/backend/internal/domain/auth/model.go +++ b/backend/internal/domain/auth/model.go @@ -6,14 +6,25 @@ import ( "github.com/google/uuid" ) +// Session represents an opaque-token session. AccessTokenHash and RefreshTokenHash +// are SHA-256 hashes of the raw bearer tokens. UserEmail and UserRole are cached +// from the user record at creation time; role changes take effect within accessTTL (≤30m). type Session struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` - TokenHash string `json:"-"` - IPAddress string `json:"ip_address"` - UserAgent string `json:"user_agent"` - ExpiresAt time.Time `json:"expires_at"` - CreatedAt time.Time `json:"created_at"` + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + UserEmail string `json:"user_email"` + UserRole string `json:"user_role"` + AccessTokenHash string `json:"-"` + RefreshTokenHash string `json:"-"` + FamilyID uuid.UUID `json:"family_id"` + ParentSessionID *uuid.UUID `json:"-"` + IPAddress string `json:"ip_address"` + UserAgent string `json:"user_agent"` + AccessExpiresAt time.Time `json:"access_expires_at"` + AbsoluteExpiresAt time.Time `json:"absolute_expires_at"` + LastUsedAt time.Time `json:"last_used_at"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + CreatedAt time.Time `json:"created_at"` } type MagicLink struct { diff --git a/backend/internal/domain/auth/oauth.go b/backend/internal/domain/auth/oauth.go index ea84d2c..73084b9 100644 --- a/backend/internal/domain/auth/oauth.go +++ b/backend/internal/domain/auth/oauth.go @@ -136,7 +136,7 @@ func (h *OAuthHandler) Callback(c *gin.Context) { return } - data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent")) + data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent")) if err != nil { apiErr := apierror.Internal("failed to create session") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) @@ -186,7 +186,7 @@ func (h *OAuthHandler) Callback(c *gin.Context) { return } - data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent")) + data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent")) if err != nil { apiErr := apierror.Internal("failed to create session") c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) diff --git a/backend/internal/domain/auth/repository.go b/backend/internal/domain/auth/repository.go index d192d6a..9d5be78 100644 --- a/backend/internal/domain/auth/repository.go +++ b/backend/internal/domain/auth/repository.go @@ -21,20 +21,41 @@ var ( ErrMagicLinkUsed = fmt.Errorf("magic link already used") ) +// RefreshReuseDetectedError is returned by ConsumeRefreshToken when the token +// was already revoked, indicating a stolen-token replay. FamilyID identifies +// the session family that must be fully revoked. +type RefreshReuseDetectedError struct { + FamilyID uuid.UUID +} + +func (e *RefreshReuseDetectedError) Error() string { + return "refresh token reuse detected" +} + type Repository interface { + // Session — opaque-token model CreateSession(ctx context.Context, session Session) error - GetSessionByTokenHash(ctx context.Context, tokenHash string) (Session, error) - DeleteSession(ctx context.Context, id uuid.UUID) error + GetSessionByAccessHash(ctx context.Context, hash string) (Session, error) + // ConsumeRefreshToken atomically revokes the session identified by hash and + // returns it. Returns RefreshReuseDetectedError if already revoked, ErrSessionNotFound + // if the hash is unknown. + ConsumeRefreshToken(ctx context.Context, hash string) (Session, error) + RevokeSession(ctx context.Context, id uuid.UUID) error + RevokeSessionsByFamilyID(ctx context.Context, familyID uuid.UUID) error + BumpLastUsedAt(ctx context.Context, id uuid.UUID) error DeleteUserSessions(ctx context.Context, userID uuid.UUID) error + // Magic links CreateMagicLink(ctx context.Context, link MagicLink) error GetMagicLinkByTokenHash(ctx context.Context, tokenHash string) (MagicLink, error) MarkMagicLinkUsed(ctx context.Context, id uuid.UUID) error + // OAuth accounts CreateOAuthAccount(ctx context.Context, account OAuthAccount) error GetOAuthAccount(ctx context.Context, provider, providerUID string) (OAuthAccount, error) UpdateOAuthTokens(ctx context.Context, id uuid.UUID, accessToken, refreshToken string, expiresAt *time.Time) error + // TOTP CreateTOTPSecret(ctx context.Context, secret TOTPSecret) error GetTOTPSecret(ctx context.Context, userID uuid.UUID) (TOTPSecret, error) VerifyTOTPSecret(ctx context.Context, userID uuid.UUID) error @@ -50,101 +71,128 @@ func NewRepository(db *pgxpool.Pool, vk valkey.Client) Repository { return &pgRepository{db: db, vk: vk} } -// Session methods — dual storage: Valkey (fast) + PostgreSQL (durable) +// Session methods -func (r *pgRepository) CreateSession(ctx context.Context, session Session) error { +func (r *pgRepository) CreateSession(ctx context.Context, s Session) error { _, err := r.db.Exec(ctx, ` - INSERT INTO sessions (id, user_id, token_hash, ip_address, user_agent, expires_at) - VALUES ($1, $2, $3, $4::inet, $5, $6) - `, session.ID, session.UserID, session.TokenHash, session.IPAddress, session.UserAgent, session.ExpiresAt) + INSERT INTO sessions ( + id, user_id, access_token_hash, refresh_token_hash, family_id, parent_session_id, + ip_address, user_agent, access_expires_at, absolute_expires_at, last_used_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7::inet, $8, $9, $10, $11) + `, s.ID, s.UserID, s.AccessTokenHash, s.RefreshTokenHash, s.FamilyID, s.ParentSessionID, + s.IPAddress, s.UserAgent, s.AccessExpiresAt, s.AbsoluteExpiresAt, s.LastUsedAt) if err != nil { - return fmt.Errorf("creating session in postgres: %w", err) + return fmt.Errorf("creating session: %w", err) } - // Cache in Valkey - data, _ := json.Marshal(session) - ttl := time.Until(session.ExpiresAt) - key := sessionValkeyKey(session.TokenHash) - err = r.vk.Do(ctx, r.vk.B().Set().Key(key).Value(string(data)).Ex(ttl).Build()).Error() - if err != nil { - // Log but don't fail — Postgres is the source of truth - fmt.Printf("warning: failed to cache session in valkey: %v\n", err) + data, _ := json.Marshal(s) + ttl := time.Until(s.AccessExpiresAt) + if ttl > 0 { + key := accessValkeyKey(s.AccessTokenHash) + if vkErr := r.vk.Do(ctx, r.vk.B().Set().Key(key).Value(string(data)).Ex(ttl).Build()).Error(); vkErr != nil { + // Valkey failure is non-fatal; Postgres is the source of truth. + fmt.Printf("warning: failed to cache session in valkey: %v\n", vkErr) + } } return nil } -func (r *pgRepository) GetSessionByTokenHash(ctx context.Context, tokenHash string) (Session, error) { - // Try Valkey first - key := sessionValkeyKey(tokenHash) - result, err := r.vk.Do(ctx, r.vk.B().Get().Key(key).Build()).ToString() - if err == nil && result != "" { - var session Session - if json.Unmarshal([]byte(result), &session) == nil { - if time.Now().Before(session.ExpiresAt) { - return session, nil - } - return Session{}, ErrSessionExpired +func (r *pgRepository) GetSessionByAccessHash(ctx context.Context, hash string) (Session, error) { + key := accessValkeyKey(hash) + if result, err := r.vk.Do(ctx, r.vk.B().Get().Key(key).Build()).ToString(); err == nil && result != "" { + var s Session + if json.Unmarshal([]byte(result), &s) == nil { + return s, nil } } - // Fall back to Postgres + // Postgres fallback — join users to get email and role without a second query. var s Session - err = r.db.QueryRow(ctx, ` - SELECT id, user_id, token_hash, ip_address::text, user_agent, expires_at, created_at - FROM sessions - WHERE token_hash = $1 - `, tokenHash).Scan(&s.ID, &s.UserID, &s.TokenHash, &s.IPAddress, &s.UserAgent, &s.ExpiresAt, &s.CreatedAt) + err := r.db.QueryRow(ctx, ` + SELECT s.id, s.user_id, u.email, u.role, + s.access_token_hash, s.refresh_token_hash, s.family_id, s.parent_session_id, + s.ip_address::text, s.user_agent, + s.access_expires_at, s.absolute_expires_at, s.last_used_at, s.revoked_at, s.created_at + FROM sessions s + JOIN users u ON u.id = s.user_id + WHERE s.access_token_hash = $1 + AND s.revoked_at IS NULL + AND s.access_expires_at > NOW() + `, hash).Scan( + &s.ID, &s.UserID, &s.UserEmail, &s.UserRole, + &s.AccessTokenHash, &s.RefreshTokenHash, &s.FamilyID, &s.ParentSessionID, + &s.IPAddress, &s.UserAgent, + &s.AccessExpiresAt, &s.AbsoluteExpiresAt, &s.LastUsedAt, &s.RevokedAt, &s.CreatedAt, + ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return Session{}, ErrSessionNotFound } return Session{}, fmt.Errorf("getting session: %w", err) } - - if time.Now().After(s.ExpiresAt) { - return Session{}, ErrSessionExpired - } - return s, nil } -func (r *pgRepository) DeleteSession(ctx context.Context, id uuid.UUID) error { - // Get the session first to know the token hash for Valkey cleanup - var tokenHash string - err := r.db.QueryRow(ctx, "SELECT token_hash FROM sessions WHERE id = $1", id).Scan(&tokenHash) +// ConsumeRefreshToken atomically revokes the refresh token and returns the session. +// Uses UPDATE...RETURNING to eliminate the select-then-update race condition. +func (r *pgRepository) ConsumeRefreshToken(ctx context.Context, hash string) (Session, error) { + var s Session + err := r.db.QueryRow(ctx, ` + UPDATE sessions SET revoked_at = NOW() + WHERE refresh_token_hash = $1 AND revoked_at IS NULL + RETURNING id, user_id, family_id, parent_session_id, + ip_address::text, user_agent, + access_expires_at, absolute_expires_at, last_used_at, created_at + `, hash).Scan( + &s.ID, &s.UserID, &s.FamilyID, &s.ParentSessionID, + &s.IPAddress, &s.UserAgent, + &s.AccessExpiresAt, &s.AbsoluteExpiresAt, &s.LastUsedAt, &s.CreatedAt, + ) if err == nil { - key := sessionValkeyKey(tokenHash) - _ = r.vk.Do(ctx, r.vk.B().Del().Key(key).Build()).Error() + s.RefreshTokenHash = hash + return s, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return Session{}, fmt.Errorf("consuming refresh token: %w", err) } - _, err = r.db.Exec(ctx, "DELETE FROM sessions WHERE id = $1", id) + // Zero rows: either never existed or already revoked. Check which. + var familyID uuid.UUID + lookupErr := r.db.QueryRow(ctx, + `SELECT family_id FROM sessions WHERE refresh_token_hash = $1`, + hash, + ).Scan(&familyID) + if errors.Is(lookupErr, pgx.ErrNoRows) { + return Session{}, ErrSessionNotFound + } + if lookupErr != nil { + return Session{}, fmt.Errorf("reuse lookup: %w", lookupErr) + } + return Session{}, &RefreshReuseDetectedError{FamilyID: familyID} +} + +func (r *pgRepository) RevokeSession(ctx context.Context, id uuid.UUID) error { + _, err := r.db.Exec(ctx, + `UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL`, id) + return err +} + +func (r *pgRepository) RevokeSessionsByFamilyID(ctx context.Context, familyID uuid.UUID) error { + _, err := r.db.Exec(ctx, + `UPDATE sessions SET revoked_at = NOW() WHERE family_id = $1 AND revoked_at IS NULL`, familyID) + return err +} + +func (r *pgRepository) BumpLastUsedAt(ctx context.Context, id uuid.UUID) error { + _, err := r.db.Exec(ctx, + `UPDATE sessions SET last_used_at = NOW() WHERE id = $1`, id) return err } func (r *pgRepository) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error { - rows, err := r.db.Query(ctx, "SELECT token_hash FROM sessions WHERE user_id = $1", userID) - if err != nil { - return fmt.Errorf("listing sessions for deletion: %w", err) - } - defer rows.Close() - - var keys []string - for rows.Next() { - var hash string - if err := rows.Scan(&hash); err == nil { - keys = append(keys, sessionValkeyKey(hash)) - } - } - - if len(keys) > 0 { - _ = r.vk.Do(ctx, r.vk.B().Del().Key(keys[0]).Build()).Error() - for _, k := range keys[1:] { - _ = r.vk.Do(ctx, r.vk.B().Del().Key(k).Build()).Error() - } - } - - _, err = r.db.Exec(ctx, "DELETE FROM sessions WHERE user_id = $1", userID) + _, err := r.db.Exec(ctx, + `UPDATE sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID) return err } @@ -171,14 +219,12 @@ func (r *pgRepository) GetMagicLinkByTokenHash(ctx context.Context, tokenHash st } return MagicLink{}, fmt.Errorf("getting magic link: %w", err) } - if ml.Used { return MagicLink{}, ErrMagicLinkUsed } if time.Now().After(ml.ExpiresAt) { return MagicLink{}, ErrMagicLinkExpired } - return ml, nil } @@ -262,6 +308,6 @@ func (r *pgRepository) DeleteTOTPSecret(ctx context.Context, userID uuid.UUID) e return err } -func sessionValkeyKey(tokenHash string) string { - return "session:" + tokenHash +func accessValkeyKey(hash string) string { + return "mv:v2:session:access:" + hash } diff --git a/backend/internal/domain/auth/service.go b/backend/internal/domain/auth/service.go index 6c3e1f5..9d16ca3 100644 --- a/backend/internal/domain/auth/service.go +++ b/backend/internal/domain/auth/service.go @@ -10,23 +10,24 @@ import ( "github.com/google/uuid" "marktvogt.de/backend/internal/domain/user" + "marktvogt.de/backend/internal/pkg/apierror" "marktvogt.de/backend/internal/pkg/password" ) -type Service struct { - authRepo Repository - userRepo user.Repository - tokenSvc *TokenService - sessionTTL time.Duration +type ServiceConfig struct { + AccessTTL time.Duration + RefreshIdleTTL time.Duration + RefreshAbsoluteTTL time.Duration } -func NewService(authRepo Repository, userRepo user.Repository, tokenSvc *TokenService, sessionTTL time.Duration) *Service { - return &Service{ - authRepo: authRepo, - userRepo: userRepo, - tokenSvc: tokenSvc, - sessionTTL: sessionTTL, - } +type Service struct { + authRepo Repository + userRepo user.Repository + cfg ServiceConfig +} + +func NewService(authRepo Repository, userRepo user.Repository, cfg ServiceConfig) *Service { + return &Service{authRepo: authRepo, userRepo: userRepo, cfg: cfg} } func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua string) (AuthData, error) { @@ -43,7 +44,7 @@ func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua stri return AuthData{}, fmt.Errorf("creating user: %w", err) } - return s.createTokenPair(ctx, u, ip, ua) + return s.createTokenPair(ctx, u, uuid.Nil, nil, ip, ua) } func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (AuthData, error) { @@ -64,7 +65,6 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A return AuthData{}, fmt.Errorf("invalid credentials") } - // Lazily upgrade bcrypt hashes to Argon2id on successful login. if password.NeedsRehash(*u.PasswordHash) { if newHash, hashErr := password.Hash(req.Password); hashErr == nil { if _, updateErr := s.userRepo.Update(ctx, u.ID, map[string]any{"password_hash": newHash}); updateErr != nil { @@ -73,70 +73,84 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A } } - // Check 2FA if enabled if req.TOTPCode != "" { if err := s.validateTOTP(ctx, u.ID, req.TOTPCode); err != nil { return AuthData{}, err } } else { - // Check if user has 2FA enabled totp, err := s.authRepo.GetTOTPSecret(ctx, u.ID) if err == nil && totp.Verified { return AuthData{}, fmt.Errorf("2fa_required") } } - return s.createTokenPair(ctx, u, ip, ua) + return s.createTokenPair(ctx, u, uuid.Nil, nil, ip, ua) } -func (s *Service) Logout(ctx context.Context, sessionTokenHash string) error { - session, err := s.authRepo.GetSessionByTokenHash(ctx, sessionTokenHash) - if err != nil { - return err - } - return s.authRepo.DeleteSession(ctx, session.ID) +// Logout revokes the session identified by its ID (set in context by middleware). +func (s *Service) Logout(ctx context.Context, sessionID uuid.UUID) error { + return s.authRepo.RevokeSession(ctx, sessionID) } -func (s *Service) RefreshToken(ctx context.Context, sessionToken, ip, ua string) (AuthData, error) { - tokenHash := HashToken(sessionToken) - session, err := s.authRepo.GetSessionByTokenHash(ctx, tokenHash) +func (s *Service) RefreshToken(ctx context.Context, refreshToken, ip, ua string) (AuthData, error) { + hash := HashToken(refreshToken) + + old, err := s.authRepo.ConsumeRefreshToken(ctx, hash) if err != nil { + var reuseErr *RefreshReuseDetectedError + if errors.As(err, &reuseErr) { + _ = s.authRepo.RevokeSessionsByFamilyID(ctx, reuseErr.FamilyID) + slog.Warn("refresh token reuse detected", "family_id", reuseErr.FamilyID) + return AuthData{}, apierror.RefreshReuse() + } return AuthData{}, err } - u, err := s.userRepo.GetByID(ctx, session.UserID) - if err != nil { - return AuthData{}, fmt.Errorf("user not found for session: %w", err) + // Idle TTL check — last_used_at is bumped on every authenticated request. + if time.Since(old.LastUsedAt) > s.cfg.RefreshIdleTTL { + return AuthData{}, ErrSessionExpired + } + if time.Now().After(old.AbsoluteExpiresAt) { + return AuthData{}, ErrSessionExpired } - // Delete old session - _ = s.authRepo.DeleteSession(ctx, session.ID) - - // Create new token pair - return s.createTokenPair(ctx, u, ip, ua) -} - -func (s *Service) ValidateSession(ctx context.Context, sessionToken string) (Session, error) { - tokenHash := HashToken(sessionToken) - return s.authRepo.GetSessionByTokenHash(ctx, tokenHash) -} - -func (s *Service) createTokenPair(ctx context.Context, u user.User, ip, ua string) (AuthData, error) { - accessToken, err := s.tokenSvc.CreateAccessToken(u.ID, u.Email, u.Role) + u, err := s.userRepo.GetByID(ctx, old.UserID) if err != nil { - return AuthData{}, fmt.Errorf("creating access token: %w", err) + return AuthData{}, fmt.Errorf("user not found: %w", err) } - sessionToken := GenerateSessionToken() - sessionHash := HashToken(sessionToken) + return s.createTokenPair(ctx, u, old.FamilyID, &old.ID, ip, ua) +} + +func (s *Service) ValidateAccessToken(ctx context.Context, accessToken string) (Session, error) { + hash := HashToken(accessToken) + return s.authRepo.GetSessionByAccessHash(ctx, hash) +} + +func (s *Service) createTokenPair(ctx context.Context, u user.User, familyID uuid.UUID, parentID *uuid.UUID, ip, ua string) (AuthData, error) { + accessToken := GenerateOpaqueToken() + refreshToken := GenerateOpaqueToken() + now := time.Now() + + if familyID == uuid.Nil { + familyID = uuid.New() + } session := Session{ - ID: uuid.New(), - UserID: u.ID, - TokenHash: sessionHash, - IPAddress: ip, - UserAgent: ua, - ExpiresAt: time.Now().Add(s.sessionTTL), + ID: uuid.New(), + UserID: u.ID, + UserEmail: u.Email, + UserRole: u.Role, + AccessTokenHash: HashToken(accessToken), + RefreshTokenHash: HashToken(refreshToken), + FamilyID: familyID, + ParentSessionID: parentID, + IPAddress: ip, + UserAgent: ua, + AccessExpiresAt: now.Add(s.cfg.AccessTTL), + AbsoluteExpiresAt: now.Add(s.cfg.RefreshAbsoluteTTL), + LastUsedAt: now, + CreatedAt: now, } if err := s.authRepo.CreateSession(ctx, session); err != nil { @@ -145,8 +159,7 @@ func (s *Service) createTokenPair(ctx context.Context, u user.User, ip, ua strin return AuthData{ AccessToken: accessToken, - SessionToken: sessionToken, - ExpiresIn: s.tokenSvc.AccessTTLSeconds(), + RefreshToken: refreshToken, }, nil } @@ -158,11 +171,9 @@ func (s *Service) validateTOTP(ctx context.Context, userID uuid.UUID, code strin if !totp.Verified { return fmt.Errorf("2fa not verified") } - if !ValidateTOTP(totp.Secret, code) { return fmt.Errorf("invalid 2fa code") } - return nil } @@ -194,7 +205,3 @@ func (s *Service) ChangePassword(ctx context.Context, userID uuid.UUID, req Chan return nil } - -func (s *Service) TokenService() *TokenService { - return s.tokenSvc -} diff --git a/backend/internal/domain/auth/service_refresh_test.go b/backend/internal/domain/auth/service_refresh_test.go new file mode 100644 index 0000000..2c94ebc --- /dev/null +++ b/backend/internal/domain/auth/service_refresh_test.go @@ -0,0 +1,326 @@ +package auth_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/google/uuid" + + "marktvogt.de/backend/internal/domain/auth" + "marktvogt.de/backend/internal/domain/user" + "marktvogt.de/backend/internal/pkg/apierror" +) + +// fakeRepo is a minimal in-memory auth.Repository for service tests. +type fakeRepo struct { + mu sync.Mutex + sessions map[string]*auth.Session // keyed by access hash + byRefresh map[string]*auth.Session // keyed by refresh hash + + magicLinks map[string]*auth.MagicLink + totpSecrets map[string]*auth.TOTPSecret + oauthAccounts []auth.OAuthAccount + + revokedFamilies []uuid.UUID + bumpedSessions []uuid.UUID +} + +func newFakeRepo() *fakeRepo { + return &fakeRepo{ + sessions: make(map[string]*auth.Session), + byRefresh: make(map[string]*auth.Session), + magicLinks: make(map[string]*auth.MagicLink), + totpSecrets: make(map[string]*auth.TOTPSecret), + } +} + +func (r *fakeRepo) addSession(s auth.Session) { + r.mu.Lock() + defer r.mu.Unlock() + clone := s + r.sessions[s.AccessTokenHash] = &clone + r.byRefresh[s.RefreshTokenHash] = &clone +} + +func (r *fakeRepo) CreateSession(_ context.Context, s auth.Session) error { + r.addSession(s) + return nil +} + +func (r *fakeRepo) GetSessionByAccessHash(_ context.Context, hash string) (auth.Session, error) { + r.mu.Lock() + defer r.mu.Unlock() + s, ok := r.sessions[hash] + if !ok { + return auth.Session{}, auth.ErrSessionNotFound + } + return *s, nil +} + +func (r *fakeRepo) GetSessionByRefreshHash(_ context.Context, hash string) (auth.Session, error) { + r.mu.Lock() + defer r.mu.Unlock() + s, ok := r.byRefresh[hash] + if !ok { + return auth.Session{}, auth.ErrSessionNotFound + } + return *s, nil +} + +// ConsumeRefreshToken atomically marks the refresh token as revoked and returns the session. +// Returns RefreshReuseDetectedError if the token was already revoked. +func (r *fakeRepo) ConsumeRefreshToken(_ context.Context, hash string) (auth.Session, error) { + r.mu.Lock() + defer r.mu.Unlock() + s, ok := r.byRefresh[hash] + if !ok { + return auth.Session{}, auth.ErrSessionNotFound + } + if s.RevokedAt != nil { + return auth.Session{}, &auth.RefreshReuseDetectedError{FamilyID: s.FamilyID} + } + now := time.Now() + s.RevokedAt = &now + return *s, nil +} + +func (r *fakeRepo) RevokeSession(_ context.Context, id uuid.UUID) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, s := range r.sessions { + if s.ID == id { + now := time.Now() + s.RevokedAt = &now + return nil + } + } + return nil +} + +func (r *fakeRepo) RevokeSessionsByFamilyID(_ context.Context, familyID uuid.UUID) error { + r.mu.Lock() + defer r.mu.Unlock() + r.revokedFamilies = append(r.revokedFamilies, familyID) + for _, s := range r.sessions { + if s.FamilyID == familyID { + now := time.Now() + s.RevokedAt = &now + } + } + for _, s := range r.byRefresh { + if s.FamilyID == familyID { + now := time.Now() + s.RevokedAt = &now + } + } + return nil +} + +func (r *fakeRepo) BumpLastUsedAt(_ context.Context, id uuid.UUID) error { + r.mu.Lock() + defer r.mu.Unlock() + r.bumpedSessions = append(r.bumpedSessions, id) + return nil +} + +func (r *fakeRepo) DeleteUserSessions(_ context.Context, _ uuid.UUID) error { return nil } + +// Magic link stubs +func (r *fakeRepo) CreateMagicLink(_ context.Context, link auth.MagicLink) error { + r.magicLinks[link.TokenHash] = &link + return nil +} +func (r *fakeRepo) GetMagicLinkByTokenHash(_ context.Context, hash string) (auth.MagicLink, error) { + if ml, ok := r.magicLinks[hash]; ok { + return *ml, nil + } + return auth.MagicLink{}, auth.ErrMagicLinkNotFound +} +func (r *fakeRepo) MarkMagicLinkUsed(_ context.Context, id uuid.UUID) error { return nil } + +// OAuth stubs +func (r *fakeRepo) CreateOAuthAccount(_ context.Context, a auth.OAuthAccount) error { + r.oauthAccounts = append(r.oauthAccounts, a) + return nil +} +func (r *fakeRepo) GetOAuthAccount(_ context.Context, _, _ string) (auth.OAuthAccount, error) { + return auth.OAuthAccount{}, errors.New("not found") +} +func (r *fakeRepo) UpdateOAuthTokens(_ context.Context, _ uuid.UUID, _, _ string, _ *time.Time) error { + return nil +} + +// TOTP stubs +func (r *fakeRepo) CreateTOTPSecret(_ context.Context, s auth.TOTPSecret) error { + r.totpSecrets[s.UserID.String()] = &s + return nil +} +func (r *fakeRepo) GetTOTPSecret(_ context.Context, userID uuid.UUID) (auth.TOTPSecret, error) { + if s, ok := r.totpSecrets[userID.String()]; ok { + return *s, nil + } + return auth.TOTPSecret{}, errors.New("not found") +} +func (r *fakeRepo) VerifyTOTPSecret(_ context.Context, _ uuid.UUID) error { return nil } +func (r *fakeRepo) DeleteTOTPSecret(_ context.Context, _ uuid.UUID) error { return nil } + +// fakeUserRepo is a minimal in-memory user.Repository. +type fakeUserRepo struct { + users map[uuid.UUID]user.User +} + +func newFakeUserRepo(users ...user.User) *fakeUserRepo { + r := &fakeUserRepo{users: make(map[uuid.UUID]user.User)} + for _, u := range users { + r.users[u.ID] = u + } + return r +} + +func (r *fakeUserRepo) Create(_ context.Context, email, hash, name string) (user.User, error) { + u := user.User{ID: uuid.New(), Email: email, DisplayName: name} + r.users[u.ID] = u + return u, nil +} +func (r *fakeUserRepo) CreateOAuthUser(_ context.Context, email, name string, _ bool) (user.User, error) { + u := user.User{ID: uuid.New(), Email: email, DisplayName: name} + r.users[u.ID] = u + return u, nil +} +func (r *fakeUserRepo) GetByID(_ context.Context, id uuid.UUID) (user.User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return user.User{}, user.ErrUserNotFound +} +func (r *fakeUserRepo) GetByEmail(_ context.Context, email string) (user.User, error) { + for _, u := range r.users { + if u.Email == email { + return u, nil + } + } + return user.User{}, user.ErrUserNotFound +} +func (r *fakeUserRepo) Update(_ context.Context, id uuid.UUID, _ map[string]any) (user.User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return user.User{}, user.ErrUserNotFound +} +func (r *fakeUserRepo) SoftDelete(_ context.Context, _ uuid.UUID) error { return nil } +func (r *fakeUserRepo) Restore(_ context.Context, _ uuid.UUID) error { return nil } +func (r *fakeUserRepo) GetDeletedByID(_ context.Context, id uuid.UUID) (user.User, error) { + return user.User{}, user.ErrUserNotFound +} + +func makeService(authRepo auth.Repository, userRepo user.Repository) *auth.Service { + return auth.NewService(authRepo, userRepo, auth.ServiceConfig{ + AccessTTL: 30 * time.Minute, + RefreshIdleTTL: 7 * 24 * time.Hour, + RefreshAbsoluteTTL: 30 * 24 * time.Hour, + }) +} + +func TestRefreshToken_HappyPath(t *testing.T) { + repo := newFakeRepo() + testUser := user.User{ID: uuid.New(), Email: "a@b.c", Role: "user"} + svc := makeService(repo, newFakeUserRepo(testUser)) + + ctx := context.Background() + data, err := svc.Register(ctx, auth.RegisterRequest{ + Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test", + }, "127.0.0.1", "test-agent") + if err != nil { + t.Fatalf("register: %v", err) + } + + newData, err := svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "test-agent") + if err != nil { + t.Fatalf("refresh: %v", err) + } + if newData.AccessToken == "" || newData.RefreshToken == "" { + t.Fatal("expected non-empty tokens after refresh") + } + if newData.RefreshToken == data.RefreshToken { + t.Fatal("refresh token must rotate on each use") + } +} + +func TestRefreshToken_ReuseDetection_RevokesFamilyAndReturns401(t *testing.T) { + repo := newFakeRepo() + testUser := user.User{ID: uuid.New(), Email: "reuse@b.c", Role: "user"} + svc := makeService(repo, newFakeUserRepo(testUser)) + + ctx := context.Background() + data, err := svc.Register(ctx, auth.RegisterRequest{ + Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test", + }, "127.0.0.1", "ua") + if err != nil { + t.Fatalf("register: %v", err) + } + + // First refresh — legitimate use, rotates the token. + _, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua") + if err != nil { + t.Fatalf("first refresh: %v", err) + } + + // Replay the original refresh token — reuse detected. + _, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua") + if err == nil { + t.Fatal("expected error on refresh token reuse") + } + + var apiErr *apierror.Error + if !errors.As(err, &apiErr) { + t.Fatalf("expected *apierror.Error, got %T: %v", err, err) + } + if apiErr.Code != "auth.refresh_reuse_detected" { + t.Errorf("expected auth.refresh_reuse_detected, got %q", apiErr.Code) + } + if len(repo.revokedFamilies) == 0 { + t.Error("expected entire token family to be revoked on reuse") + } +} + +func TestRefreshToken_ExpiredIdleTTL_Returns401(t *testing.T) { + repo := newFakeRepo() + testUser := user.User{ID: uuid.New(), Email: "idle@b.c", Role: "user"} + // Use a very short idle TTL to simulate expiry. + svc := auth.NewService(repo, newFakeUserRepo(testUser), auth.ServiceConfig{ + AccessTTL: 30 * time.Minute, + RefreshIdleTTL: 1 * time.Millisecond, + RefreshAbsoluteTTL: 30 * 24 * time.Hour, + }) + + ctx := context.Background() + data, err := svc.Register(ctx, auth.RegisterRequest{ + Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test", + }, "127.0.0.1", "ua") + if err != nil { + t.Fatalf("register: %v", err) + } + + time.Sleep(5 * time.Millisecond) + + _, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua") + if err == nil { + t.Fatal("expected error for expired idle TTL") + } + if !errors.Is(err, auth.ErrSessionExpired) { + t.Errorf("expected ErrSessionExpired, got %v", err) + } +} + +func TestRefreshToken_UnknownToken_Returns401(t *testing.T) { + repo := newFakeRepo() + svc := makeService(repo, newFakeUserRepo()) + + _, err := svc.RefreshToken(context.Background(), "nosuchtoken", "127.0.0.1", "ua") + if err == nil { + t.Fatal("expected error for unknown token") + } +} diff --git a/backend/internal/domain/auth/token.go b/backend/internal/domain/auth/token.go index 0a05378..73d5e47 100644 --- a/backend/internal/domain/auth/token.go +++ b/backend/internal/domain/auth/token.go @@ -1,84 +1,23 @@ package auth import ( + "crypto/rand" "crypto/sha256" + "encoding/base64" "encoding/hex" - "fmt" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" ) -type TokenClaims struct { - jwt.RegisteredClaims - UserID uuid.UUID `json:"uid"` - Email string `json:"email"` - Role string `json:"role"` -} - -type TokenService struct { - secret []byte - accessTTL time.Duration -} - -func NewTokenService(secret string, accessTTL time.Duration) *TokenService { - return &TokenService{ - secret: []byte(secret), - accessTTL: accessTTL, +// GenerateOpaqueToken returns a cryptographically random 32-byte URL-safe base64 token. +func GenerateOpaqueToken() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic("auth: crypto/rand unavailable: " + err.Error()) } + return base64.RawURLEncoding.EncodeToString(b) } -func (ts *TokenService) CreateAccessToken(userID uuid.UUID, email, role string) (string, error) { - now := time.Now() - claims := TokenClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID.String(), - IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(now.Add(ts.accessTTL)), - Issuer: "marktvogt", - }, - UserID: userID, - Email: email, - Role: role, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - signed, err := token.SignedString(ts.secret) - if err != nil { - return "", fmt.Errorf("signing access token: %w", err) - } - return signed, nil -} - -func (ts *TokenService) ValidateAccessToken(tokenString string) (*TokenClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return ts.secret, nil - }) - if err != nil { - return nil, fmt.Errorf("parsing access token: %w", err) - } - - claims, ok := token.Claims.(*TokenClaims) - if !ok || !token.Valid { - return nil, fmt.Errorf("invalid access token") - } - - return claims, nil -} - -func (ts *TokenService) AccessTTLSeconds() int { - return int(ts.accessTTL.Seconds()) -} - +// HashToken returns the hex-encoded SHA-256 of token, used for safe DB storage. func HashToken(token string) string { h := sha256.Sum256([]byte(token)) return hex.EncodeToString(h[:]) } - -func GenerateSessionToken() string { - return uuid.New().String() -} diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index a586586..b69bb8b 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -1,61 +1,93 @@ package middleware import ( + "context" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "marktvogt.de/backend/internal/domain/auth" "marktvogt.de/backend/internal/pkg/apierror" ) -func RequireAuth(tokenSvc *auth.TokenService) gin.HandlerFunc { +// SessionLookup is the subset of auth.Repository needed by the auth middleware. +// Keeping it narrow makes the middleware easy to test without a full repo mock. +type SessionLookup interface { + GetSessionByAccessHash(ctx context.Context, hash string) (auth.Session, error) + BumpLastUsedAt(ctx context.Context, id uuid.UUID) error +} + +const lastUsedBumpThreshold = 60 * time.Second + +// RequireAuth validates the Bearer access token via Valkey/Postgres lookup. +// On success it sets user_id, user_email, user_role, and session_id in context. +// accessTTL is used to allow the middleware to be tested independently of the session TTL +// configuration without importing the config package. +func RequireAuth(repo SessionLookup, accessTTL time.Duration) gin.HandlerFunc { return func(c *gin.Context) { - claims, ok := extractAndValidate(c, tokenSvc) + session, ok := resolveSession(c, repo, accessTTL) if !ok { return } - - c.Set("user_id", claims.UserID) - c.Set("user_email", claims.Email) - c.Set("user_role", claims.Role) + c.Set("user_id", session.UserID) + c.Set("user_email", session.UserEmail) + c.Set("user_role", session.UserRole) + c.Set("session_id", session.ID) c.Next() } } -func OptionalAuth(tokenSvc *auth.TokenService) gin.HandlerFunc { +// OptionalAuth is like RequireAuth but allows unauthenticated requests through. +func OptionalAuth(repo SessionLookup, accessTTL time.Duration) gin.HandlerFunc { return func(c *gin.Context) { header := c.GetHeader("Authorization") if header == "" || !strings.HasPrefix(header, "Bearer ") { c.Next() return } - - claims, _ := extractAndValidate(c, tokenSvc) - if claims != nil { - c.Set("user_id", claims.UserID) - c.Set("user_email", claims.Email) - c.Set("user_role", claims.Role) + if session, ok := resolveSession(c, repo, accessTTL); ok { + c.Set("user_id", session.UserID) + c.Set("user_email", session.UserEmail) + c.Set("user_role", session.UserRole) + c.Set("session_id", session.ID) } c.Next() } } -func extractAndValidate(c *gin.Context, tokenSvc *auth.TokenService) (*auth.TokenClaims, bool) { +func resolveSession(c *gin.Context, repo SessionLookup, accessTTL time.Duration) (auth.Session, bool) { header := c.GetHeader("Authorization") if header == "" || !strings.HasPrefix(header, "Bearer ") { - apiErr := apierror.Unauthorized("missing or invalid authorization header") - c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr)) - return nil, false + reject(c) + return auth.Session{}, false } - tokenString := strings.TrimPrefix(header, "Bearer ") - claims, err := tokenSvc.ValidateAccessToken(tokenString) + token := strings.TrimPrefix(header, "Bearer ") + hash := auth.HashToken(token) + + session, err := repo.GetSessionByAccessHash(c.Request.Context(), hash) if err != nil { - apiErr := apierror.Unauthorized("invalid or expired token") - c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr)) - return nil, false + reject(c) + return auth.Session{}, false } - return claims, true + if session.RevokedAt != nil { + reject(c) + return auth.Session{}, false + } + + // Throttled last_used_at bump — skips the write when the row was recently updated. + if time.Since(session.LastUsedAt) > lastUsedBumpThreshold { + _ = repo.BumpLastUsedAt(c.Request.Context(), session.ID) + } + + _ = accessTTL // TTL is enforced by access_expires_at stored on the session row + return session, true +} + +func reject(c *gin.Context) { + apiErr := apierror.Unauthorized("invalid or expired token") + c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr)) } diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go new file mode 100644 index 0000000..4e144fe --- /dev/null +++ b/backend/internal/middleware/auth_test.go @@ -0,0 +1,172 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "marktvogt.de/backend/internal/domain/auth" + "marktvogt.de/backend/internal/middleware" +) + +// stubSessionRepo implements the minimal surface the auth middleware needs. +type stubSessionRepo struct { + session auth.Session + err error + bumped []uuid.UUID +} + +func (r *stubSessionRepo) GetSessionByAccessHash(_ context.Context, _ string) (auth.Session, error) { + return r.session, r.err +} +func (r *stubSessionRepo) BumpLastUsedAt(_ context.Context, id uuid.UUID) error { + r.bumped = append(r.bumped, id) + return nil +} + +func newRouter(h gin.HandlerFunc, mw ...gin.HandlerFunc) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/test", append(mw, h)...) + return r +} + +func bearerReq(token string) *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + return req +} + +func TestRequireAuth_ValidToken_SetsContextAndPasses(t *testing.T) { + sessionID := uuid.New() + userID := uuid.New() + stub := &stubSessionRepo{ + session: auth.Session{ + ID: sessionID, + UserID: userID, + UserEmail: "a@b.c", + UserRole: "user", + LastUsedAt: time.Now().Add(-2 * time.Minute), + AccessExpiresAt: time.Now().Add(28 * time.Minute), + }, + } + + var gotUserID, gotSessionID any + handler := func(c *gin.Context) { + gotUserID, _ = c.Get("user_id") + gotSessionID, _ = c.Get("session_id") + c.Status(http.StatusOK) + } + + r := newRouter(handler, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("sometoken")) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if gotUserID != userID { + t.Errorf("user_id not set correctly in context") + } + if gotSessionID != sessionID { + t.Errorf("session_id not set correctly in context") + } +} + +func TestRequireAuth_MissingToken_Returns401(t *testing.T) { + stub := &stubSessionRepo{} + r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("")) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestRequireAuth_UnknownToken_Returns401(t *testing.T) { + stub := &stubSessionRepo{err: auth.ErrSessionNotFound} + r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("badtoken")) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestRequireAuth_RevokedSession_Returns401(t *testing.T) { + now := time.Now() + stub := &stubSessionRepo{ + session: auth.Session{ + ID: uuid.New(), + UserID: uuid.New(), + AccessExpiresAt: now.Add(10 * time.Minute), + LastUsedAt: now.Add(-1 * time.Minute), + RevokedAt: &now, + }, + } + r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("revokedtoken")) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestRequireAuth_BumpsLastUsedAt_WhenStale(t *testing.T) { + sessionID := uuid.New() + stub := &stubSessionRepo{ + session: auth.Session{ + ID: sessionID, + UserID: uuid.New(), + UserEmail: "x@y.z", + UserRole: "user", + LastUsedAt: time.Now().Add(-90 * time.Second), // older than 60s threshold + AccessExpiresAt: time.Now().Add(28 * time.Minute), + }, + } + + r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("sometoken")) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if len(stub.bumped) == 0 { + t.Error("expected BumpLastUsedAt to be called for stale last_used_at") + } +} + +func TestRequireAuth_DoesNotBumpLastUsedAt_WhenFresh(t *testing.T) { + stub := &stubSessionRepo{ + session: auth.Session{ + ID: uuid.New(), + UserID: uuid.New(), + UserEmail: "x@y.z", + UserRole: "user", + LastUsedAt: time.Now().Add(-10 * time.Second), // fresher than 60s threshold + AccessExpiresAt: time.Now().Add(28 * time.Minute), + }, + } + + r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute)) + w := httptest.NewRecorder() + r.ServeHTTP(w, bearerReq("sometoken")) + + if len(stub.bumped) != 0 { + t.Error("expected BumpLastUsedAt to be skipped for fresh last_used_at") + } +} + +// Ensure stubSessionRepo satisfies the SessionLookup interface at compile time. +var _ middleware.SessionLookup = (*stubSessionRepo)(nil) diff --git a/backend/internal/server/routes.go b/backend/internal/server/routes.go index fb3b077..7a68e57 100644 --- a/backend/internal/server/routes.go +++ b/backend/internal/server/routes.go @@ -32,11 +32,14 @@ func (s *Server) registerRoutes() { // Auth userRepo := user.NewRepository(s.db) - tokenSvc := auth.NewTokenService(s.cfg.JWT.Secret, s.cfg.JWT.AccessTTL) authRepo := auth.NewRepository(s.db, s.valkey) - authSvc := auth.NewService(authRepo, userRepo, tokenSvc, s.cfg.JWT.SessionTTL) + authSvc := auth.NewService(authRepo, userRepo, auth.ServiceConfig{ + AccessTTL: s.cfg.Auth.AccessTTL, + RefreshIdleTTL: s.cfg.Auth.RefreshIdleTTL, + RefreshAbsoluteTTL: s.cfg.Auth.RefreshAbsoluteTTL, + }) authHandler := auth.NewHandler(authSvc, userRepo) - requireAuth := middleware.RequireAuth(tokenSvc) + requireAuth := middleware.RequireAuth(authRepo, s.cfg.Auth.AccessTTL) // Per-route auth rate limiters (keyed by IP; user_id unavailable before auth completes) userIDKey := func(c *gin.Context) string { diff --git a/backend/migrations/000027_sessions_session_model_rework.down.sql b/backend/migrations/000027_sessions_session_model_rework.down.sql index bd1b0a6..c0fd02d 100644 --- a/backend/migrations/000027_sessions_session_model_rework.down.sql +++ b/backend/migrations/000027_sessions_session_model_rework.down.sql @@ -6,6 +6,7 @@ ALTER TABLE sessions DROP COLUMN IF EXISTS revoked_at, DROP COLUMN IF EXISTS last_used_at, DROP COLUMN IF EXISTS absolute_expires_at, + DROP COLUMN IF EXISTS access_expires_at, DROP COLUMN IF EXISTS parent_session_id, DROP COLUMN IF EXISTS family_id, DROP COLUMN IF EXISTS refresh_token_hash, diff --git a/backend/migrations/000027_sessions_session_model_rework.up.sql b/backend/migrations/000027_sessions_session_model_rework.up.sql index 37c36cb..6ad1d87 100644 --- a/backend/migrations/000027_sessions_session_model_rework.up.sql +++ b/backend/migrations/000027_sessions_session_model_rework.up.sql @@ -8,6 +8,7 @@ ALTER TABLE sessions ADD COLUMN IF NOT EXISTS family_id UUID NOT NULL DEFAULT uuid_generate_v4(), ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES sessions(id) ON DELETE SET NULL, ADD COLUMN IF NOT EXISTS absolute_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 days', + ADD COLUMN IF NOT EXISTS access_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 minutes', ADD COLUMN IF NOT EXISTS last_used_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ;