feat(auth): D2/D3 opaque-token session model — drop JWT
Replace HS256 JWT access tokens with two opaque 32-byte random tokens
(access + refresh), both stored as SHA-256 hashes in sessions + Valkey.
Key changes:
- GenerateOpaqueToken() replaces JWT issuance; TokenService removed
- Sessions now carry access_token_hash, refresh_token_hash, family_id,
parent_session_id, access_expires_at, absolute_expires_at, last_used_at,
revoked_at — per migration 000027 (updated to add access_expires_at)
- Refresh rotation is atomic (UPDATE...RETURNING); reuse detection kills
the entire token family and returns auth.refresh_reuse_detected
- RequireAuth/OptionalAuth now take SessionLookup (Valkey→Postgres) instead
of *TokenService; sets session_id in context alongside user_id
- last_used_at is bumped on each request, throttled to writes >60s old
- AuthConfig{AccessTTL,RefreshIdleTTL,RefreshAbsoluteTTL} replaces JWT TTL env
vars (AUTH_ACCESS_TTL=30m, AUTH_REFRESH_IDLE_TTL=168h, AUTH_REFRESH_ABSOLUTE_TTL=720h)
- JWT_SECRET kept for AI-settings key derivation (drops from auth flow)
Forced logout on deploy (D3 behaviour); pre-launch so acceptable.
This commit is contained in:
@@ -15,6 +15,7 @@ type Config struct {
|
||||
DB DBConfig
|
||||
Valkey ValkeyConfig
|
||||
JWT JWTConfig
|
||||
Auth AuthConfig
|
||||
CORS CORSConfig
|
||||
Rate RateConfig
|
||||
Sentry SentryConfig
|
||||
@@ -81,9 +82,13 @@ type ValkeyConfig struct {
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string
|
||||
AccessTTL time.Duration
|
||||
SessionTTL time.Duration
|
||||
Secret string // kept for settings-encryption key derivation (see routes.go)
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
AccessTTL time.Duration // default 30m
|
||||
RefreshIdleTTL time.Duration // default 168h (7 days); sliding per-request
|
||||
RefreshAbsoluteTTL time.Duration // default 720h (30 days); hard limit
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
@@ -180,14 +185,19 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("VALKEY_DB: %w", err)
|
||||
}
|
||||
|
||||
accessTTL, err := envDuration("JWT_ACCESS_TTL", 15*time.Minute)
|
||||
authAccessTTL, err := envDuration("AUTH_ACCESS_TTL", 30*time.Minute)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JWT_ACCESS_TTL: %w", err)
|
||||
return nil, fmt.Errorf("AUTH_ACCESS_TTL: %w", err)
|
||||
}
|
||||
|
||||
sessionTTL, err := envDuration("JWT_SESSION_TTL", 720*time.Hour)
|
||||
authRefreshIdleTTL, err := envDuration("AUTH_REFRESH_IDLE_TTL", 168*time.Hour)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JWT_SESSION_TTL: %w", err)
|
||||
return nil, fmt.Errorf("AUTH_REFRESH_IDLE_TTL: %w", err)
|
||||
}
|
||||
|
||||
authRefreshAbsoluteTTL, err := envDuration("AUTH_REFRESH_ABSOLUTE_TTL", 720*time.Hour)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AUTH_REFRESH_ABSOLUTE_TTL: %w", err)
|
||||
}
|
||||
|
||||
rps, err := envFloat("RATE_LIMIT_RPS", 10)
|
||||
@@ -247,9 +257,12 @@ func Load() (*Config, error) {
|
||||
DB: valkeyDB,
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
Secret: jwtSecret,
|
||||
AccessTTL: accessTTL,
|
||||
SessionTTL: sessionTTL,
|
||||
Secret: jwtSecret,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
AccessTTL: authAccessTTL,
|
||||
RefreshIdleTTL: authRefreshIdleTTL,
|
||||
RefreshAbsoluteTTL: authRefreshAbsoluteTTL,
|
||||
},
|
||||
CORS: CORSConfig{
|
||||
Origins: corsOrigins,
|
||||
|
||||
@@ -40,8 +40,7 @@ type AuthResponse struct {
|
||||
|
||||
type AuthData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
SessionToken string `json:"session_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type MessageResponse struct {
|
||||
|
||||
@@ -3,7 +3,6 @@ package auth
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -73,15 +72,14 @@ func (h *Handler) Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *Handler) Logout(c *gin.Context) {
|
||||
sessionToken := extractSessionToken(c)
|
||||
if sessionToken == "" {
|
||||
apiErr := apierror.Unauthorized("session token required")
|
||||
sessionID := GetSessionID(c)
|
||||
if sessionID == uuid.Nil {
|
||||
apiErr := apierror.Unauthorized("not authenticated")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
tokenHash := HashToken(sessionToken)
|
||||
if err := h.service.Logout(c.Request.Context(), tokenHash); err != nil {
|
||||
if err := h.service.Logout(c.Request.Context(), sessionID); err != nil {
|
||||
apiErr := apierror.Internal("logout failed")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
@@ -91,16 +89,21 @@ func (h *Handler) Logout(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *Handler) Refresh(c *gin.Context) {
|
||||
sessionToken := extractSessionToken(c)
|
||||
if sessionToken == "" {
|
||||
apiErr := apierror.Unauthorized("session token required")
|
||||
refreshToken := extractRefreshToken(c)
|
||||
if refreshToken == "" {
|
||||
apiErr := apierror.Unauthorized("refresh token required")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.RefreshToken(c.Request.Context(), sessionToken, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
data, err := h.service.RefreshToken(c.Request.Context(), refreshToken, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
apiErr := apierror.Unauthorized("invalid or expired session")
|
||||
var apiErr *apierror.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
apiErr = apierror.Unauthorized("invalid or expired session")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
@@ -184,17 +187,12 @@ func (h *Handler) DisableTOTP(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA disabled"}})
|
||||
}
|
||||
|
||||
func extractSessionToken(c *gin.Context) string {
|
||||
header := c.GetHeader("X-Session-Token")
|
||||
if header != "" {
|
||||
return header
|
||||
// extractRefreshToken reads the refresh token from X-Refresh-Token or (legacy) X-Session-Token.
|
||||
func extractRefreshToken(c *gin.Context) string {
|
||||
if t := c.GetHeader("X-Refresh-Token"); t != "" {
|
||||
return t
|
||||
}
|
||||
// Also check Authorization with Bearer prefix for refresh
|
||||
auth := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(auth, "Session ") {
|
||||
return strings.TrimPrefix(auth, "Session ")
|
||||
}
|
||||
return ""
|
||||
return c.GetHeader("X-Session-Token")
|
||||
}
|
||||
|
||||
func GetUserID(c *gin.Context) uuid.UUID {
|
||||
@@ -208,3 +206,15 @@ func GetUserID(c *gin.Context) uuid.UUID {
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func GetSessionID(c *gin.Context) uuid.UUID {
|
||||
v, exists := c.Get("session_id")
|
||||
if !exists {
|
||||
return uuid.Nil
|
||||
}
|
||||
id, ok := v.(uuid.UUID)
|
||||
if !ok {
|
||||
return uuid.Nil
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func (h *MagicLinkHandler) VerifyMagicLink(c *gin.Context) {
|
||||
_, _ = h.userRepo.Update(ctx, u.ID, map[string]any{"email_verified": true})
|
||||
}
|
||||
|
||||
data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
apiErr := apierror.Internal("failed to create session")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
|
||||
@@ -6,14 +6,25 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session represents an opaque-token session. AccessTokenHash and RefreshTokenHash
|
||||
// are SHA-256 hashes of the raw bearer tokens. UserEmail and UserRole are cached
|
||||
// from the user record at creation time; role changes take effect within accessTTL (≤30m).
|
||||
type Session struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"-"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserRole string `json:"user_role"`
|
||||
AccessTokenHash string `json:"-"`
|
||||
RefreshTokenHash string `json:"-"`
|
||||
FamilyID uuid.UUID `json:"family_id"`
|
||||
ParentSessionID *uuid.UUID `json:"-"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
AccessExpiresAt time.Time `json:"access_expires_at"`
|
||||
AbsoluteExpiresAt time.Time `json:"absolute_expires_at"`
|
||||
LastUsedAt time.Time `json:"last_used_at"`
|
||||
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type MagicLink struct {
|
||||
|
||||
@@ -136,7 +136,7 @@ func (h *OAuthHandler) Callback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
apiErr := apierror.Internal("failed to create session")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
@@ -186,7 +186,7 @@ func (h *OAuthHandler) Callback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.createTokenPair(ctx, u, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
data, err := h.service.createTokenPair(ctx, u, uuid.Nil, nil, c.ClientIP(), c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
apiErr := apierror.Internal("failed to create session")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
|
||||
@@ -21,20 +21,41 @@ var (
|
||||
ErrMagicLinkUsed = fmt.Errorf("magic link already used")
|
||||
)
|
||||
|
||||
// RefreshReuseDetectedError is returned by ConsumeRefreshToken when the token
|
||||
// was already revoked, indicating a stolen-token replay. FamilyID identifies
|
||||
// the session family that must be fully revoked.
|
||||
type RefreshReuseDetectedError struct {
|
||||
FamilyID uuid.UUID
|
||||
}
|
||||
|
||||
func (e *RefreshReuseDetectedError) Error() string {
|
||||
return "refresh token reuse detected"
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
// Session — opaque-token model
|
||||
CreateSession(ctx context.Context, session Session) error
|
||||
GetSessionByTokenHash(ctx context.Context, tokenHash string) (Session, error)
|
||||
DeleteSession(ctx context.Context, id uuid.UUID) error
|
||||
GetSessionByAccessHash(ctx context.Context, hash string) (Session, error)
|
||||
// ConsumeRefreshToken atomically revokes the session identified by hash and
|
||||
// returns it. Returns RefreshReuseDetectedError if already revoked, ErrSessionNotFound
|
||||
// if the hash is unknown.
|
||||
ConsumeRefreshToken(ctx context.Context, hash string) (Session, error)
|
||||
RevokeSession(ctx context.Context, id uuid.UUID) error
|
||||
RevokeSessionsByFamilyID(ctx context.Context, familyID uuid.UUID) error
|
||||
BumpLastUsedAt(ctx context.Context, id uuid.UUID) error
|
||||
DeleteUserSessions(ctx context.Context, userID uuid.UUID) error
|
||||
|
||||
// Magic links
|
||||
CreateMagicLink(ctx context.Context, link MagicLink) error
|
||||
GetMagicLinkByTokenHash(ctx context.Context, tokenHash string) (MagicLink, error)
|
||||
MarkMagicLinkUsed(ctx context.Context, id uuid.UUID) error
|
||||
|
||||
// OAuth accounts
|
||||
CreateOAuthAccount(ctx context.Context, account OAuthAccount) error
|
||||
GetOAuthAccount(ctx context.Context, provider, providerUID string) (OAuthAccount, error)
|
||||
UpdateOAuthTokens(ctx context.Context, id uuid.UUID, accessToken, refreshToken string, expiresAt *time.Time) error
|
||||
|
||||
// TOTP
|
||||
CreateTOTPSecret(ctx context.Context, secret TOTPSecret) error
|
||||
GetTOTPSecret(ctx context.Context, userID uuid.UUID) (TOTPSecret, error)
|
||||
VerifyTOTPSecret(ctx context.Context, userID uuid.UUID) error
|
||||
@@ -50,101 +71,128 @@ func NewRepository(db *pgxpool.Pool, vk valkey.Client) Repository {
|
||||
return &pgRepository{db: db, vk: vk}
|
||||
}
|
||||
|
||||
// Session methods — dual storage: Valkey (fast) + PostgreSQL (durable)
|
||||
// Session methods
|
||||
|
||||
func (r *pgRepository) CreateSession(ctx context.Context, session Session) error {
|
||||
func (r *pgRepository) CreateSession(ctx context.Context, s Session) error {
|
||||
_, err := r.db.Exec(ctx, `
|
||||
INSERT INTO sessions (id, user_id, token_hash, ip_address, user_agent, expires_at)
|
||||
VALUES ($1, $2, $3, $4::inet, $5, $6)
|
||||
`, session.ID, session.UserID, session.TokenHash, session.IPAddress, session.UserAgent, session.ExpiresAt)
|
||||
INSERT INTO sessions (
|
||||
id, user_id, access_token_hash, refresh_token_hash, family_id, parent_session_id,
|
||||
ip_address, user_agent, access_expires_at, absolute_expires_at, last_used_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7::inet, $8, $9, $10, $11)
|
||||
`, s.ID, s.UserID, s.AccessTokenHash, s.RefreshTokenHash, s.FamilyID, s.ParentSessionID,
|
||||
s.IPAddress, s.UserAgent, s.AccessExpiresAt, s.AbsoluteExpiresAt, s.LastUsedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating session in postgres: %w", err)
|
||||
return fmt.Errorf("creating session: %w", err)
|
||||
}
|
||||
|
||||
// Cache in Valkey
|
||||
data, _ := json.Marshal(session)
|
||||
ttl := time.Until(session.ExpiresAt)
|
||||
key := sessionValkeyKey(session.TokenHash)
|
||||
err = r.vk.Do(ctx, r.vk.B().Set().Key(key).Value(string(data)).Ex(ttl).Build()).Error()
|
||||
if err != nil {
|
||||
// Log but don't fail — Postgres is the source of truth
|
||||
fmt.Printf("warning: failed to cache session in valkey: %v\n", err)
|
||||
data, _ := json.Marshal(s)
|
||||
ttl := time.Until(s.AccessExpiresAt)
|
||||
if ttl > 0 {
|
||||
key := accessValkeyKey(s.AccessTokenHash)
|
||||
if vkErr := r.vk.Do(ctx, r.vk.B().Set().Key(key).Value(string(data)).Ex(ttl).Build()).Error(); vkErr != nil {
|
||||
// Valkey failure is non-fatal; Postgres is the source of truth.
|
||||
fmt.Printf("warning: failed to cache session in valkey: %v\n", vkErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *pgRepository) GetSessionByTokenHash(ctx context.Context, tokenHash string) (Session, error) {
|
||||
// Try Valkey first
|
||||
key := sessionValkeyKey(tokenHash)
|
||||
result, err := r.vk.Do(ctx, r.vk.B().Get().Key(key).Build()).ToString()
|
||||
if err == nil && result != "" {
|
||||
var session Session
|
||||
if json.Unmarshal([]byte(result), &session) == nil {
|
||||
if time.Now().Before(session.ExpiresAt) {
|
||||
return session, nil
|
||||
}
|
||||
return Session{}, ErrSessionExpired
|
||||
func (r *pgRepository) GetSessionByAccessHash(ctx context.Context, hash string) (Session, error) {
|
||||
key := accessValkeyKey(hash)
|
||||
if result, err := r.vk.Do(ctx, r.vk.B().Get().Key(key).Build()).ToString(); err == nil && result != "" {
|
||||
var s Session
|
||||
if json.Unmarshal([]byte(result), &s) == nil {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Postgres
|
||||
// Postgres fallback — join users to get email and role without a second query.
|
||||
var s Session
|
||||
err = r.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, token_hash, ip_address::text, user_agent, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE token_hash = $1
|
||||
`, tokenHash).Scan(&s.ID, &s.UserID, &s.TokenHash, &s.IPAddress, &s.UserAgent, &s.ExpiresAt, &s.CreatedAt)
|
||||
err := r.db.QueryRow(ctx, `
|
||||
SELECT s.id, s.user_id, u.email, u.role,
|
||||
s.access_token_hash, s.refresh_token_hash, s.family_id, s.parent_session_id,
|
||||
s.ip_address::text, s.user_agent,
|
||||
s.access_expires_at, s.absolute_expires_at, s.last_used_at, s.revoked_at, s.created_at
|
||||
FROM sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.access_token_hash = $1
|
||||
AND s.revoked_at IS NULL
|
||||
AND s.access_expires_at > NOW()
|
||||
`, hash).Scan(
|
||||
&s.ID, &s.UserID, &s.UserEmail, &s.UserRole,
|
||||
&s.AccessTokenHash, &s.RefreshTokenHash, &s.FamilyID, &s.ParentSessionID,
|
||||
&s.IPAddress, &s.UserAgent,
|
||||
&s.AccessExpiresAt, &s.AbsoluteExpiresAt, &s.LastUsedAt, &s.RevokedAt, &s.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
return Session{}, fmt.Errorf("getting session: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(s.ExpiresAt) {
|
||||
return Session{}, ErrSessionExpired
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (r *pgRepository) DeleteSession(ctx context.Context, id uuid.UUID) error {
|
||||
// Get the session first to know the token hash for Valkey cleanup
|
||||
var tokenHash string
|
||||
err := r.db.QueryRow(ctx, "SELECT token_hash FROM sessions WHERE id = $1", id).Scan(&tokenHash)
|
||||
// ConsumeRefreshToken atomically revokes the refresh token and returns the session.
|
||||
// Uses UPDATE...RETURNING to eliminate the select-then-update race condition.
|
||||
func (r *pgRepository) ConsumeRefreshToken(ctx context.Context, hash string) (Session, error) {
|
||||
var s Session
|
||||
err := r.db.QueryRow(ctx, `
|
||||
UPDATE sessions SET revoked_at = NOW()
|
||||
WHERE refresh_token_hash = $1 AND revoked_at IS NULL
|
||||
RETURNING id, user_id, family_id, parent_session_id,
|
||||
ip_address::text, user_agent,
|
||||
access_expires_at, absolute_expires_at, last_used_at, created_at
|
||||
`, hash).Scan(
|
||||
&s.ID, &s.UserID, &s.FamilyID, &s.ParentSessionID,
|
||||
&s.IPAddress, &s.UserAgent,
|
||||
&s.AccessExpiresAt, &s.AbsoluteExpiresAt, &s.LastUsedAt, &s.CreatedAt,
|
||||
)
|
||||
if err == nil {
|
||||
key := sessionValkeyKey(tokenHash)
|
||||
_ = r.vk.Do(ctx, r.vk.B().Del().Key(key).Build()).Error()
|
||||
s.RefreshTokenHash = hash
|
||||
return s, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return Session{}, fmt.Errorf("consuming refresh token: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(ctx, "DELETE FROM sessions WHERE id = $1", id)
|
||||
// Zero rows: either never existed or already revoked. Check which.
|
||||
var familyID uuid.UUID
|
||||
lookupErr := r.db.QueryRow(ctx,
|
||||
`SELECT family_id FROM sessions WHERE refresh_token_hash = $1`,
|
||||
hash,
|
||||
).Scan(&familyID)
|
||||
if errors.Is(lookupErr, pgx.ErrNoRows) {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
if lookupErr != nil {
|
||||
return Session{}, fmt.Errorf("reuse lookup: %w", lookupErr)
|
||||
}
|
||||
return Session{}, &RefreshReuseDetectedError{FamilyID: familyID}
|
||||
}
|
||||
|
||||
func (r *pgRepository) RevokeSession(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := r.db.Exec(ctx,
|
||||
`UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *pgRepository) RevokeSessionsByFamilyID(ctx context.Context, familyID uuid.UUID) error {
|
||||
_, err := r.db.Exec(ctx,
|
||||
`UPDATE sessions SET revoked_at = NOW() WHERE family_id = $1 AND revoked_at IS NULL`, familyID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *pgRepository) BumpLastUsedAt(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := r.db.Exec(ctx,
|
||||
`UPDATE sessions SET last_used_at = NOW() WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *pgRepository) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
|
||||
rows, err := r.db.Query(ctx, "SELECT token_hash FROM sessions WHERE user_id = $1", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing sessions for deletion: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var hash string
|
||||
if err := rows.Scan(&hash); err == nil {
|
||||
keys = append(keys, sessionValkeyKey(hash))
|
||||
}
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
_ = r.vk.Do(ctx, r.vk.B().Del().Key(keys[0]).Build()).Error()
|
||||
for _, k := range keys[1:] {
|
||||
_ = r.vk.Do(ctx, r.vk.B().Del().Key(k).Build()).Error()
|
||||
}
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(ctx, "DELETE FROM sessions WHERE user_id = $1", userID)
|
||||
_, err := r.db.Exec(ctx,
|
||||
`UPDATE sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -171,14 +219,12 @@ func (r *pgRepository) GetMagicLinkByTokenHash(ctx context.Context, tokenHash st
|
||||
}
|
||||
return MagicLink{}, fmt.Errorf("getting magic link: %w", err)
|
||||
}
|
||||
|
||||
if ml.Used {
|
||||
return MagicLink{}, ErrMagicLinkUsed
|
||||
}
|
||||
if time.Now().After(ml.ExpiresAt) {
|
||||
return MagicLink{}, ErrMagicLinkExpired
|
||||
}
|
||||
|
||||
return ml, nil
|
||||
}
|
||||
|
||||
@@ -262,6 +308,6 @@ func (r *pgRepository) DeleteTOTPSecret(ctx context.Context, userID uuid.UUID) e
|
||||
return err
|
||||
}
|
||||
|
||||
func sessionValkeyKey(tokenHash string) string {
|
||||
return "session:" + tokenHash
|
||||
func accessValkeyKey(hash string) string {
|
||||
return "mv:v2:session:access:" + hash
|
||||
}
|
||||
|
||||
@@ -10,23 +10,24 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"marktvogt.de/backend/internal/domain/user"
|
||||
"marktvogt.de/backend/internal/pkg/apierror"
|
||||
"marktvogt.de/backend/internal/pkg/password"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
authRepo Repository
|
||||
userRepo user.Repository
|
||||
tokenSvc *TokenService
|
||||
sessionTTL time.Duration
|
||||
type ServiceConfig struct {
|
||||
AccessTTL time.Duration
|
||||
RefreshIdleTTL time.Duration
|
||||
RefreshAbsoluteTTL time.Duration
|
||||
}
|
||||
|
||||
func NewService(authRepo Repository, userRepo user.Repository, tokenSvc *TokenService, sessionTTL time.Duration) *Service {
|
||||
return &Service{
|
||||
authRepo: authRepo,
|
||||
userRepo: userRepo,
|
||||
tokenSvc: tokenSvc,
|
||||
sessionTTL: sessionTTL,
|
||||
}
|
||||
type Service struct {
|
||||
authRepo Repository
|
||||
userRepo user.Repository
|
||||
cfg ServiceConfig
|
||||
}
|
||||
|
||||
func NewService(authRepo Repository, userRepo user.Repository, cfg ServiceConfig) *Service {
|
||||
return &Service{authRepo: authRepo, userRepo: userRepo, cfg: cfg}
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua string) (AuthData, error) {
|
||||
@@ -43,7 +44,7 @@ func (s *Service) Register(ctx context.Context, req RegisterRequest, ip, ua stri
|
||||
return AuthData{}, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
||||
return s.createTokenPair(ctx, u, ip, ua)
|
||||
return s.createTokenPair(ctx, u, uuid.Nil, nil, ip, ua)
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (AuthData, error) {
|
||||
@@ -64,7 +65,6 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A
|
||||
return AuthData{}, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Lazily upgrade bcrypt hashes to Argon2id on successful login.
|
||||
if password.NeedsRehash(*u.PasswordHash) {
|
||||
if newHash, hashErr := password.Hash(req.Password); hashErr == nil {
|
||||
if _, updateErr := s.userRepo.Update(ctx, u.ID, map[string]any{"password_hash": newHash}); updateErr != nil {
|
||||
@@ -73,70 +73,84 @@ func (s *Service) Login(ctx context.Context, req LoginRequest, ip, ua string) (A
|
||||
}
|
||||
}
|
||||
|
||||
// Check 2FA if enabled
|
||||
if req.TOTPCode != "" {
|
||||
if err := s.validateTOTP(ctx, u.ID, req.TOTPCode); err != nil {
|
||||
return AuthData{}, err
|
||||
}
|
||||
} else {
|
||||
// Check if user has 2FA enabled
|
||||
totp, err := s.authRepo.GetTOTPSecret(ctx, u.ID)
|
||||
if err == nil && totp.Verified {
|
||||
return AuthData{}, fmt.Errorf("2fa_required")
|
||||
}
|
||||
}
|
||||
|
||||
return s.createTokenPair(ctx, u, ip, ua)
|
||||
return s.createTokenPair(ctx, u, uuid.Nil, nil, ip, ua)
|
||||
}
|
||||
|
||||
func (s *Service) Logout(ctx context.Context, sessionTokenHash string) error {
|
||||
session, err := s.authRepo.GetSessionByTokenHash(ctx, sessionTokenHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.authRepo.DeleteSession(ctx, session.ID)
|
||||
// Logout revokes the session identified by its ID (set in context by middleware).
|
||||
func (s *Service) Logout(ctx context.Context, sessionID uuid.UUID) error {
|
||||
return s.authRepo.RevokeSession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (s *Service) RefreshToken(ctx context.Context, sessionToken, ip, ua string) (AuthData, error) {
|
||||
tokenHash := HashToken(sessionToken)
|
||||
session, err := s.authRepo.GetSessionByTokenHash(ctx, tokenHash)
|
||||
func (s *Service) RefreshToken(ctx context.Context, refreshToken, ip, ua string) (AuthData, error) {
|
||||
hash := HashToken(refreshToken)
|
||||
|
||||
old, err := s.authRepo.ConsumeRefreshToken(ctx, hash)
|
||||
if err != nil {
|
||||
var reuseErr *RefreshReuseDetectedError
|
||||
if errors.As(err, &reuseErr) {
|
||||
_ = s.authRepo.RevokeSessionsByFamilyID(ctx, reuseErr.FamilyID)
|
||||
slog.Warn("refresh token reuse detected", "family_id", reuseErr.FamilyID)
|
||||
return AuthData{}, apierror.RefreshReuse()
|
||||
}
|
||||
return AuthData{}, err
|
||||
}
|
||||
|
||||
u, err := s.userRepo.GetByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return AuthData{}, fmt.Errorf("user not found for session: %w", err)
|
||||
// Idle TTL check — last_used_at is bumped on every authenticated request.
|
||||
if time.Since(old.LastUsedAt) > s.cfg.RefreshIdleTTL {
|
||||
return AuthData{}, ErrSessionExpired
|
||||
}
|
||||
if time.Now().After(old.AbsoluteExpiresAt) {
|
||||
return AuthData{}, ErrSessionExpired
|
||||
}
|
||||
|
||||
// Delete old session
|
||||
_ = s.authRepo.DeleteSession(ctx, session.ID)
|
||||
|
||||
// Create new token pair
|
||||
return s.createTokenPair(ctx, u, ip, ua)
|
||||
}
|
||||
|
||||
func (s *Service) ValidateSession(ctx context.Context, sessionToken string) (Session, error) {
|
||||
tokenHash := HashToken(sessionToken)
|
||||
return s.authRepo.GetSessionByTokenHash(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (s *Service) createTokenPair(ctx context.Context, u user.User, ip, ua string) (AuthData, error) {
|
||||
accessToken, err := s.tokenSvc.CreateAccessToken(u.ID, u.Email, u.Role)
|
||||
u, err := s.userRepo.GetByID(ctx, old.UserID)
|
||||
if err != nil {
|
||||
return AuthData{}, fmt.Errorf("creating access token: %w", err)
|
||||
return AuthData{}, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
sessionToken := GenerateSessionToken()
|
||||
sessionHash := HashToken(sessionToken)
|
||||
return s.createTokenPair(ctx, u, old.FamilyID, &old.ID, ip, ua)
|
||||
}
|
||||
|
||||
func (s *Service) ValidateAccessToken(ctx context.Context, accessToken string) (Session, error) {
|
||||
hash := HashToken(accessToken)
|
||||
return s.authRepo.GetSessionByAccessHash(ctx, hash)
|
||||
}
|
||||
|
||||
func (s *Service) createTokenPair(ctx context.Context, u user.User, familyID uuid.UUID, parentID *uuid.UUID, ip, ua string) (AuthData, error) {
|
||||
accessToken := GenerateOpaqueToken()
|
||||
refreshToken := GenerateOpaqueToken()
|
||||
now := time.Now()
|
||||
|
||||
if familyID == uuid.Nil {
|
||||
familyID = uuid.New()
|
||||
}
|
||||
|
||||
session := Session{
|
||||
ID: uuid.New(),
|
||||
UserID: u.ID,
|
||||
TokenHash: sessionHash,
|
||||
IPAddress: ip,
|
||||
UserAgent: ua,
|
||||
ExpiresAt: time.Now().Add(s.sessionTTL),
|
||||
ID: uuid.New(),
|
||||
UserID: u.ID,
|
||||
UserEmail: u.Email,
|
||||
UserRole: u.Role,
|
||||
AccessTokenHash: HashToken(accessToken),
|
||||
RefreshTokenHash: HashToken(refreshToken),
|
||||
FamilyID: familyID,
|
||||
ParentSessionID: parentID,
|
||||
IPAddress: ip,
|
||||
UserAgent: ua,
|
||||
AccessExpiresAt: now.Add(s.cfg.AccessTTL),
|
||||
AbsoluteExpiresAt: now.Add(s.cfg.RefreshAbsoluteTTL),
|
||||
LastUsedAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if err := s.authRepo.CreateSession(ctx, session); err != nil {
|
||||
@@ -145,8 +159,7 @@ func (s *Service) createTokenPair(ctx context.Context, u user.User, ip, ua strin
|
||||
|
||||
return AuthData{
|
||||
AccessToken: accessToken,
|
||||
SessionToken: sessionToken,
|
||||
ExpiresIn: s.tokenSvc.AccessTTLSeconds(),
|
||||
RefreshToken: refreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -158,11 +171,9 @@ func (s *Service) validateTOTP(ctx context.Context, userID uuid.UUID, code strin
|
||||
if !totp.Verified {
|
||||
return fmt.Errorf("2fa not verified")
|
||||
}
|
||||
|
||||
if !ValidateTOTP(totp.Secret, code) {
|
||||
return fmt.Errorf("invalid 2fa code")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,7 +205,3 @@ func (s *Service) ChangePassword(ctx context.Context, userID uuid.UUID, req Chan
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) TokenService() *TokenService {
|
||||
return s.tokenSvc
|
||||
}
|
||||
|
||||
326
backend/internal/domain/auth/service_refresh_test.go
Normal file
326
backend/internal/domain/auth/service_refresh_test.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"marktvogt.de/backend/internal/domain/auth"
|
||||
"marktvogt.de/backend/internal/domain/user"
|
||||
"marktvogt.de/backend/internal/pkg/apierror"
|
||||
)
|
||||
|
||||
// fakeRepo is a minimal in-memory auth.Repository for service tests.
|
||||
type fakeRepo struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*auth.Session // keyed by access hash
|
||||
byRefresh map[string]*auth.Session // keyed by refresh hash
|
||||
|
||||
magicLinks map[string]*auth.MagicLink
|
||||
totpSecrets map[string]*auth.TOTPSecret
|
||||
oauthAccounts []auth.OAuthAccount
|
||||
|
||||
revokedFamilies []uuid.UUID
|
||||
bumpedSessions []uuid.UUID
|
||||
}
|
||||
|
||||
func newFakeRepo() *fakeRepo {
|
||||
return &fakeRepo{
|
||||
sessions: make(map[string]*auth.Session),
|
||||
byRefresh: make(map[string]*auth.Session),
|
||||
magicLinks: make(map[string]*auth.MagicLink),
|
||||
totpSecrets: make(map[string]*auth.TOTPSecret),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fakeRepo) addSession(s auth.Session) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
clone := s
|
||||
r.sessions[s.AccessTokenHash] = &clone
|
||||
r.byRefresh[s.RefreshTokenHash] = &clone
|
||||
}
|
||||
|
||||
func (r *fakeRepo) CreateSession(_ context.Context, s auth.Session) error {
|
||||
r.addSession(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) GetSessionByAccessHash(_ context.Context, hash string) (auth.Session, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
s, ok := r.sessions[hash]
|
||||
if !ok {
|
||||
return auth.Session{}, auth.ErrSessionNotFound
|
||||
}
|
||||
return *s, nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) GetSessionByRefreshHash(_ context.Context, hash string) (auth.Session, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
s, ok := r.byRefresh[hash]
|
||||
if !ok {
|
||||
return auth.Session{}, auth.ErrSessionNotFound
|
||||
}
|
||||
return *s, nil
|
||||
}
|
||||
|
||||
// ConsumeRefreshToken atomically marks the refresh token as revoked and returns the session.
|
||||
// Returns RefreshReuseDetectedError if the token was already revoked.
|
||||
func (r *fakeRepo) ConsumeRefreshToken(_ context.Context, hash string) (auth.Session, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
s, ok := r.byRefresh[hash]
|
||||
if !ok {
|
||||
return auth.Session{}, auth.ErrSessionNotFound
|
||||
}
|
||||
if s.RevokedAt != nil {
|
||||
return auth.Session{}, &auth.RefreshReuseDetectedError{FamilyID: s.FamilyID}
|
||||
}
|
||||
now := time.Now()
|
||||
s.RevokedAt = &now
|
||||
return *s, nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) RevokeSession(_ context.Context, id uuid.UUID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, s := range r.sessions {
|
||||
if s.ID == id {
|
||||
now := time.Now()
|
||||
s.RevokedAt = &now
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) RevokeSessionsByFamilyID(_ context.Context, familyID uuid.UUID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.revokedFamilies = append(r.revokedFamilies, familyID)
|
||||
for _, s := range r.sessions {
|
||||
if s.FamilyID == familyID {
|
||||
now := time.Now()
|
||||
s.RevokedAt = &now
|
||||
}
|
||||
}
|
||||
for _, s := range r.byRefresh {
|
||||
if s.FamilyID == familyID {
|
||||
now := time.Now()
|
||||
s.RevokedAt = &now
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) BumpLastUsedAt(_ context.Context, id uuid.UUID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.bumpedSessions = append(r.bumpedSessions, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRepo) DeleteUserSessions(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
|
||||
// Magic link stubs
|
||||
func (r *fakeRepo) CreateMagicLink(_ context.Context, link auth.MagicLink) error {
|
||||
r.magicLinks[link.TokenHash] = &link
|
||||
return nil
|
||||
}
|
||||
func (r *fakeRepo) GetMagicLinkByTokenHash(_ context.Context, hash string) (auth.MagicLink, error) {
|
||||
if ml, ok := r.magicLinks[hash]; ok {
|
||||
return *ml, nil
|
||||
}
|
||||
return auth.MagicLink{}, auth.ErrMagicLinkNotFound
|
||||
}
|
||||
func (r *fakeRepo) MarkMagicLinkUsed(_ context.Context, id uuid.UUID) error { return nil }
|
||||
|
||||
// OAuth stubs
|
||||
func (r *fakeRepo) CreateOAuthAccount(_ context.Context, a auth.OAuthAccount) error {
|
||||
r.oauthAccounts = append(r.oauthAccounts, a)
|
||||
return nil
|
||||
}
|
||||
func (r *fakeRepo) GetOAuthAccount(_ context.Context, _, _ string) (auth.OAuthAccount, error) {
|
||||
return auth.OAuthAccount{}, errors.New("not found")
|
||||
}
|
||||
func (r *fakeRepo) UpdateOAuthTokens(_ context.Context, _ uuid.UUID, _, _ string, _ *time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TOTP stubs
|
||||
func (r *fakeRepo) CreateTOTPSecret(_ context.Context, s auth.TOTPSecret) error {
|
||||
r.totpSecrets[s.UserID.String()] = &s
|
||||
return nil
|
||||
}
|
||||
func (r *fakeRepo) GetTOTPSecret(_ context.Context, userID uuid.UUID) (auth.TOTPSecret, error) {
|
||||
if s, ok := r.totpSecrets[userID.String()]; ok {
|
||||
return *s, nil
|
||||
}
|
||||
return auth.TOTPSecret{}, errors.New("not found")
|
||||
}
|
||||
func (r *fakeRepo) VerifyTOTPSecret(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (r *fakeRepo) DeleteTOTPSecret(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
|
||||
// fakeUserRepo is a minimal in-memory user.Repository.
|
||||
type fakeUserRepo struct {
|
||||
users map[uuid.UUID]user.User
|
||||
}
|
||||
|
||||
func newFakeUserRepo(users ...user.User) *fakeUserRepo {
|
||||
r := &fakeUserRepo{users: make(map[uuid.UUID]user.User)}
|
||||
for _, u := range users {
|
||||
r.users[u.ID] = u
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *fakeUserRepo) Create(_ context.Context, email, hash, name string) (user.User, error) {
|
||||
u := user.User{ID: uuid.New(), Email: email, DisplayName: name}
|
||||
r.users[u.ID] = u
|
||||
return u, nil
|
||||
}
|
||||
func (r *fakeUserRepo) CreateOAuthUser(_ context.Context, email, name string, _ bool) (user.User, error) {
|
||||
u := user.User{ID: uuid.New(), Email: email, DisplayName: name}
|
||||
r.users[u.ID] = u
|
||||
return u, nil
|
||||
}
|
||||
func (r *fakeUserRepo) GetByID(_ context.Context, id uuid.UUID) (user.User, error) {
|
||||
if u, ok := r.users[id]; ok {
|
||||
return u, nil
|
||||
}
|
||||
return user.User{}, user.ErrUserNotFound
|
||||
}
|
||||
func (r *fakeUserRepo) GetByEmail(_ context.Context, email string) (user.User, error) {
|
||||
for _, u := range r.users {
|
||||
if u.Email == email {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
return user.User{}, user.ErrUserNotFound
|
||||
}
|
||||
func (r *fakeUserRepo) Update(_ context.Context, id uuid.UUID, _ map[string]any) (user.User, error) {
|
||||
if u, ok := r.users[id]; ok {
|
||||
return u, nil
|
||||
}
|
||||
return user.User{}, user.ErrUserNotFound
|
||||
}
|
||||
func (r *fakeUserRepo) SoftDelete(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (r *fakeUserRepo) Restore(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (r *fakeUserRepo) GetDeletedByID(_ context.Context, id uuid.UUID) (user.User, error) {
|
||||
return user.User{}, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
func makeService(authRepo auth.Repository, userRepo user.Repository) *auth.Service {
|
||||
return auth.NewService(authRepo, userRepo, auth.ServiceConfig{
|
||||
AccessTTL: 30 * time.Minute,
|
||||
RefreshIdleTTL: 7 * 24 * time.Hour,
|
||||
RefreshAbsoluteTTL: 30 * 24 * time.Hour,
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshToken_HappyPath(t *testing.T) {
|
||||
repo := newFakeRepo()
|
||||
testUser := user.User{ID: uuid.New(), Email: "a@b.c", Role: "user"}
|
||||
svc := makeService(repo, newFakeUserRepo(testUser))
|
||||
|
||||
ctx := context.Background()
|
||||
data, err := svc.Register(ctx, auth.RegisterRequest{
|
||||
Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test",
|
||||
}, "127.0.0.1", "test-agent")
|
||||
if err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
|
||||
newData, err := svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "test-agent")
|
||||
if err != nil {
|
||||
t.Fatalf("refresh: %v", err)
|
||||
}
|
||||
if newData.AccessToken == "" || newData.RefreshToken == "" {
|
||||
t.Fatal("expected non-empty tokens after refresh")
|
||||
}
|
||||
if newData.RefreshToken == data.RefreshToken {
|
||||
t.Fatal("refresh token must rotate on each use")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_ReuseDetection_RevokesFamilyAndReturns401(t *testing.T) {
|
||||
repo := newFakeRepo()
|
||||
testUser := user.User{ID: uuid.New(), Email: "reuse@b.c", Role: "user"}
|
||||
svc := makeService(repo, newFakeUserRepo(testUser))
|
||||
|
||||
ctx := context.Background()
|
||||
data, err := svc.Register(ctx, auth.RegisterRequest{
|
||||
Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test",
|
||||
}, "127.0.0.1", "ua")
|
||||
if err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
|
||||
// First refresh — legitimate use, rotates the token.
|
||||
_, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua")
|
||||
if err != nil {
|
||||
t.Fatalf("first refresh: %v", err)
|
||||
}
|
||||
|
||||
// Replay the original refresh token — reuse detected.
|
||||
_, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua")
|
||||
if err == nil {
|
||||
t.Fatal("expected error on refresh token reuse")
|
||||
}
|
||||
|
||||
var apiErr *apierror.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("expected *apierror.Error, got %T: %v", err, err)
|
||||
}
|
||||
if apiErr.Code != "auth.refresh_reuse_detected" {
|
||||
t.Errorf("expected auth.refresh_reuse_detected, got %q", apiErr.Code)
|
||||
}
|
||||
if len(repo.revokedFamilies) == 0 {
|
||||
t.Error("expected entire token family to be revoked on reuse")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_ExpiredIdleTTL_Returns401(t *testing.T) {
|
||||
repo := newFakeRepo()
|
||||
testUser := user.User{ID: uuid.New(), Email: "idle@b.c", Role: "user"}
|
||||
// Use a very short idle TTL to simulate expiry.
|
||||
svc := auth.NewService(repo, newFakeUserRepo(testUser), auth.ServiceConfig{
|
||||
AccessTTL: 30 * time.Minute,
|
||||
RefreshIdleTTL: 1 * time.Millisecond,
|
||||
RefreshAbsoluteTTL: 30 * 24 * time.Hour,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
data, err := svc.Register(ctx, auth.RegisterRequest{
|
||||
Email: testUser.Email, Password: "correct-horse-battery", DisplayName: "Test",
|
||||
}, "127.0.0.1", "ua")
|
||||
if err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, err = svc.RefreshToken(ctx, data.RefreshToken, "127.0.0.1", "ua")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired idle TTL")
|
||||
}
|
||||
if !errors.Is(err, auth.ErrSessionExpired) {
|
||||
t.Errorf("expected ErrSessionExpired, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_UnknownToken_Returns401(t *testing.T) {
|
||||
repo := newFakeRepo()
|
||||
svc := makeService(repo, newFakeUserRepo())
|
||||
|
||||
_, err := svc.RefreshToken(context.Background(), "nosuchtoken", "127.0.0.1", "ua")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown token")
|
||||
}
|
||||
}
|
||||
@@ -1,84 +1,23 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
UserID uuid.UUID `json:"uid"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type TokenService struct {
|
||||
secret []byte
|
||||
accessTTL time.Duration
|
||||
}
|
||||
|
||||
func NewTokenService(secret string, accessTTL time.Duration) *TokenService {
|
||||
return &TokenService{
|
||||
secret: []byte(secret),
|
||||
accessTTL: accessTTL,
|
||||
// GenerateOpaqueToken returns a cryptographically random 32-byte URL-safe base64 token.
|
||||
func GenerateOpaqueToken() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("auth: crypto/rand unavailable: " + err.Error())
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func (ts *TokenService) CreateAccessToken(userID uuid.UUID, email, role string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := TokenClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(ts.accessTTL)),
|
||||
Issuer: "marktvogt",
|
||||
},
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString(ts.secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("signing access token: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func (ts *TokenService) ValidateAccessToken(tokenString string) (*TokenClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return ts.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing access token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*TokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid access token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (ts *TokenService) AccessTTLSeconds() int {
|
||||
return int(ts.accessTTL.Seconds())
|
||||
}
|
||||
|
||||
// HashToken returns the hex-encoded SHA-256 of token, used for safe DB storage.
|
||||
func HashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func GenerateSessionToken() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
@@ -1,61 +1,93 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"marktvogt.de/backend/internal/domain/auth"
|
||||
"marktvogt.de/backend/internal/pkg/apierror"
|
||||
)
|
||||
|
||||
func RequireAuth(tokenSvc *auth.TokenService) gin.HandlerFunc {
|
||||
// SessionLookup is the subset of auth.Repository needed by the auth middleware.
|
||||
// Keeping it narrow makes the middleware easy to test without a full repo mock.
|
||||
type SessionLookup interface {
|
||||
GetSessionByAccessHash(ctx context.Context, hash string) (auth.Session, error)
|
||||
BumpLastUsedAt(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
const lastUsedBumpThreshold = 60 * time.Second
|
||||
|
||||
// RequireAuth validates the Bearer access token via Valkey/Postgres lookup.
|
||||
// On success it sets user_id, user_email, user_role, and session_id in context.
|
||||
// accessTTL is used to allow the middleware to be tested independently of the session TTL
|
||||
// configuration without importing the config package.
|
||||
func RequireAuth(repo SessionLookup, accessTTL time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
claims, ok := extractAndValidate(c, tokenSvc)
|
||||
session, ok := resolveSession(c, repo, accessTTL)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("user_email", claims.Email)
|
||||
c.Set("user_role", claims.Role)
|
||||
c.Set("user_id", session.UserID)
|
||||
c.Set("user_email", session.UserEmail)
|
||||
c.Set("user_role", session.UserRole)
|
||||
c.Set("session_id", session.ID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func OptionalAuth(tokenSvc *auth.TokenService) gin.HandlerFunc {
|
||||
// OptionalAuth is like RequireAuth but allows unauthenticated requests through.
|
||||
func OptionalAuth(repo SessionLookup, accessTTL time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
header := c.GetHeader("Authorization")
|
||||
if header == "" || !strings.HasPrefix(header, "Bearer ") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
claims, _ := extractAndValidate(c, tokenSvc)
|
||||
if claims != nil {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("user_email", claims.Email)
|
||||
c.Set("user_role", claims.Role)
|
||||
if session, ok := resolveSession(c, repo, accessTTL); ok {
|
||||
c.Set("user_id", session.UserID)
|
||||
c.Set("user_email", session.UserEmail)
|
||||
c.Set("user_role", session.UserRole)
|
||||
c.Set("session_id", session.ID)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractAndValidate(c *gin.Context, tokenSvc *auth.TokenService) (*auth.TokenClaims, bool) {
|
||||
func resolveSession(c *gin.Context, repo SessionLookup, accessTTL time.Duration) (auth.Session, bool) {
|
||||
header := c.GetHeader("Authorization")
|
||||
if header == "" || !strings.HasPrefix(header, "Bearer ") {
|
||||
apiErr := apierror.Unauthorized("missing or invalid authorization header")
|
||||
c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return nil, false
|
||||
reject(c)
|
||||
return auth.Session{}, false
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := tokenSvc.ValidateAccessToken(tokenString)
|
||||
token := strings.TrimPrefix(header, "Bearer ")
|
||||
hash := auth.HashToken(token)
|
||||
|
||||
session, err := repo.GetSessionByAccessHash(c.Request.Context(), hash)
|
||||
if err != nil {
|
||||
apiErr := apierror.Unauthorized("invalid or expired token")
|
||||
c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return nil, false
|
||||
reject(c)
|
||||
return auth.Session{}, false
|
||||
}
|
||||
|
||||
return claims, true
|
||||
if session.RevokedAt != nil {
|
||||
reject(c)
|
||||
return auth.Session{}, false
|
||||
}
|
||||
|
||||
// Throttled last_used_at bump — skips the write when the row was recently updated.
|
||||
if time.Since(session.LastUsedAt) > lastUsedBumpThreshold {
|
||||
_ = repo.BumpLastUsedAt(c.Request.Context(), session.ID)
|
||||
}
|
||||
|
||||
_ = accessTTL // TTL is enforced by access_expires_at stored on the session row
|
||||
return session, true
|
||||
}
|
||||
|
||||
func reject(c *gin.Context) {
|
||||
apiErr := apierror.Unauthorized("invalid or expired token")
|
||||
c.AbortWithStatusJSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
}
|
||||
|
||||
172
backend/internal/middleware/auth_test.go
Normal file
172
backend/internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"marktvogt.de/backend/internal/domain/auth"
|
||||
"marktvogt.de/backend/internal/middleware"
|
||||
)
|
||||
|
||||
// stubSessionRepo implements the minimal surface the auth middleware needs.
|
||||
type stubSessionRepo struct {
|
||||
session auth.Session
|
||||
err error
|
||||
bumped []uuid.UUID
|
||||
}
|
||||
|
||||
func (r *stubSessionRepo) GetSessionByAccessHash(_ context.Context, _ string) (auth.Session, error) {
|
||||
return r.session, r.err
|
||||
}
|
||||
func (r *stubSessionRepo) BumpLastUsedAt(_ context.Context, id uuid.UUID) error {
|
||||
r.bumped = append(r.bumped, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newRouter(h gin.HandlerFunc, mw ...gin.HandlerFunc) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.GET("/test", append(mw, h)...)
|
||||
return r
|
||||
}
|
||||
|
||||
func bearerReq(token string) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func TestRequireAuth_ValidToken_SetsContextAndPasses(t *testing.T) {
|
||||
sessionID := uuid.New()
|
||||
userID := uuid.New()
|
||||
stub := &stubSessionRepo{
|
||||
session: auth.Session{
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
UserEmail: "a@b.c",
|
||||
UserRole: "user",
|
||||
LastUsedAt: time.Now().Add(-2 * time.Minute),
|
||||
AccessExpiresAt: time.Now().Add(28 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
var gotUserID, gotSessionID any
|
||||
handler := func(c *gin.Context) {
|
||||
gotUserID, _ = c.Get("user_id")
|
||||
gotSessionID, _ = c.Get("session_id")
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
r := newRouter(handler, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq("sometoken"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if gotUserID != userID {
|
||||
t.Errorf("user_id not set correctly in context")
|
||||
}
|
||||
if gotSessionID != sessionID {
|
||||
t.Errorf("session_id not set correctly in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_MissingToken_Returns401(t *testing.T) {
|
||||
stub := &stubSessionRepo{}
|
||||
r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq(""))
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_UnknownToken_Returns401(t *testing.T) {
|
||||
stub := &stubSessionRepo{err: auth.ErrSessionNotFound}
|
||||
r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq("badtoken"))
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_RevokedSession_Returns401(t *testing.T) {
|
||||
now := time.Now()
|
||||
stub := &stubSessionRepo{
|
||||
session: auth.Session{
|
||||
ID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
AccessExpiresAt: now.Add(10 * time.Minute),
|
||||
LastUsedAt: now.Add(-1 * time.Minute),
|
||||
RevokedAt: &now,
|
||||
},
|
||||
}
|
||||
r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq("revokedtoken"))
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_BumpsLastUsedAt_WhenStale(t *testing.T) {
|
||||
sessionID := uuid.New()
|
||||
stub := &stubSessionRepo{
|
||||
session: auth.Session{
|
||||
ID: sessionID,
|
||||
UserID: uuid.New(),
|
||||
UserEmail: "x@y.z",
|
||||
UserRole: "user",
|
||||
LastUsedAt: time.Now().Add(-90 * time.Second), // older than 60s threshold
|
||||
AccessExpiresAt: time.Now().Add(28 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq("sometoken"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if len(stub.bumped) == 0 {
|
||||
t.Error("expected BumpLastUsedAt to be called for stale last_used_at")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_DoesNotBumpLastUsedAt_WhenFresh(t *testing.T) {
|
||||
stub := &stubSessionRepo{
|
||||
session: auth.Session{
|
||||
ID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
UserEmail: "x@y.z",
|
||||
UserRole: "user",
|
||||
LastUsedAt: time.Now().Add(-10 * time.Second), // fresher than 60s threshold
|
||||
AccessExpiresAt: time.Now().Add(28 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
r := newRouter(func(c *gin.Context) { c.Status(200) }, middleware.RequireAuth(stub, 30*time.Minute))
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, bearerReq("sometoken"))
|
||||
|
||||
if len(stub.bumped) != 0 {
|
||||
t.Error("expected BumpLastUsedAt to be skipped for fresh last_used_at")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure stubSessionRepo satisfies the SessionLookup interface at compile time.
|
||||
var _ middleware.SessionLookup = (*stubSessionRepo)(nil)
|
||||
@@ -32,11 +32,14 @@ func (s *Server) registerRoutes() {
|
||||
|
||||
// Auth
|
||||
userRepo := user.NewRepository(s.db)
|
||||
tokenSvc := auth.NewTokenService(s.cfg.JWT.Secret, s.cfg.JWT.AccessTTL)
|
||||
authRepo := auth.NewRepository(s.db, s.valkey)
|
||||
authSvc := auth.NewService(authRepo, userRepo, tokenSvc, s.cfg.JWT.SessionTTL)
|
||||
authSvc := auth.NewService(authRepo, userRepo, auth.ServiceConfig{
|
||||
AccessTTL: s.cfg.Auth.AccessTTL,
|
||||
RefreshIdleTTL: s.cfg.Auth.RefreshIdleTTL,
|
||||
RefreshAbsoluteTTL: s.cfg.Auth.RefreshAbsoluteTTL,
|
||||
})
|
||||
authHandler := auth.NewHandler(authSvc, userRepo)
|
||||
requireAuth := middleware.RequireAuth(tokenSvc)
|
||||
requireAuth := middleware.RequireAuth(authRepo, s.cfg.Auth.AccessTTL)
|
||||
|
||||
// Per-route auth rate limiters (keyed by IP; user_id unavailable before auth completes)
|
||||
userIDKey := func(c *gin.Context) string {
|
||||
|
||||
@@ -6,6 +6,7 @@ ALTER TABLE sessions
|
||||
DROP COLUMN IF EXISTS revoked_at,
|
||||
DROP COLUMN IF EXISTS last_used_at,
|
||||
DROP COLUMN IF EXISTS absolute_expires_at,
|
||||
DROP COLUMN IF EXISTS access_expires_at,
|
||||
DROP COLUMN IF EXISTS parent_session_id,
|
||||
DROP COLUMN IF EXISTS family_id,
|
||||
DROP COLUMN IF EXISTS refresh_token_hash,
|
||||
|
||||
@@ -8,6 +8,7 @@ ALTER TABLE sessions
|
||||
ADD COLUMN IF NOT EXISTS family_id UUID NOT NULL DEFAULT uuid_generate_v4(),
|
||||
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES sessions(id) ON DELETE SET NULL,
|
||||
ADD COLUMN IF NOT EXISTS absolute_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 days',
|
||||
ADD COLUMN IF NOT EXISTS access_expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '30 minutes',
|
||||
ADD COLUMN IF NOT EXISTS last_used_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user