Files
vikingowl 515a72e6e8 feat(auth): D4 TOTP backup codes + session management
- Backup codes: 10 × Crockford base32 (XXXXX-XXXXX), SHA-256 hashed,
  single-use; regenerate requires current TOTP code
- Login accepts BackupCode field alongside TOTPCode
- Session management: list, revoke-by-id (ownership-checked),
  revoke-all-except-current; password change revokes other sessions
- New routes: POST /auth/2fa/backup-codes/regenerate,
  GET /auth/sessions, DELETE /auth/sessions, DELETE /auth/sessions/:id
- fakeRepo extended with backup code + session management stubs
- Tests cover: code format/count, hash storage, regen invalidates old,
  login with valid/used code, session list isolation, revoke ownership,
  password change session revocation
2026-04-26 12:33:47 +02:00

315 lines
8.7 KiB
Go

package auth
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"marktvogt.de/backend/internal/domain/user"
"marktvogt.de/backend/internal/pkg/apierror"
"marktvogt.de/backend/internal/pkg/validate"
)
// Handler groups the HTTP endpoints of the auth package. Its methods are gin
// handler functions that translate requests into calls on the auth service.
type Handler struct {
	// service carries out the auth operations the handlers delegate to.
	service *Service

	// userRepo is used to load the user record (see SetupTOTP).
	userRepo user.Repository
}
// NewHandler returns a Handler wired to the given auth service and user
// repository. Both values are stored as-is; no validation is performed.
func NewHandler(service *Service, userRepo user.Repository) *Handler {
	h := new(Handler)
	h.service = service
	h.userRepo = userRepo
	return h
}
// Register binds a RegisterRequest from the JSON body and passes it to the
// auth service together with the client IP and User-Agent header.
//
// Responses:
//   - binding failure: the API error returned by validate.BindJSON
//   - user.ErrEmailAlreadyTaken: conflict error "email already registered"
//   - any other service error: internal error "registration failed"
//   - success: http.StatusCreated with an AuthResponse
func (h *Handler) Register(c *gin.Context) {
	var body RegisterRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	ctx := c.Request.Context()
	clientIP, userAgent := c.ClientIP(), c.GetHeader("User-Agent")

	result, err := h.service.Register(ctx, body, clientIP, userAgent)
	switch {
	case err == nil:
		c.JSON(http.StatusCreated, AuthResponse{Data: result})
	case errors.Is(err, user.ErrEmailAlreadyTaken):
		conflict := apierror.Conflict("email already registered")
		c.JSON(conflict.Status, apierror.NewResponse(conflict))
	default:
		internal := apierror.Internal("registration failed")
		c.JSON(internal.Status, apierror.NewResponse(internal))
	}
}
// Login binds a LoginRequest from the JSON body and passes it to the auth
// service together with the client IP and User-Agent header.
//
// Service errors are mapped by their message text:
//   - "invalid credentials": unauthorized error "invalid email or password"
//   - "2fa_required": bad-request error with code "2fa_required"
//   - anything else: internal error "login failed"
//
// On success it answers http.StatusOK with an AuthResponse.
//
// NOTE(review): the mapping compares err.Error() verbatim, so it stops
// matching if the service ever wraps or rewords these errors. Sentinel errors
// checked with errors.Is would be sturdier — confirm what the service exports.
func (h *Handler) Login(c *gin.Context) {
	var body LoginRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	ctx := c.Request.Context()
	result, err := h.service.Login(ctx, body, c.ClientIP(), c.GetHeader("User-Agent"))
	if err == nil {
		c.JSON(http.StatusOK, AuthResponse{Data: result})
		return
	}

	switch err.Error() {
	case "invalid credentials":
		denied := apierror.Unauthorized("invalid email or password")
		c.JSON(denied.Status, apierror.NewResponse(denied))
	case "2fa_required":
		needCode := apierror.BadRequest("2fa_required", "2FA code required")
		c.JSON(needCode.Status, apierror.NewResponse(needCode))
	default:
		internal := apierror.Internal("login failed")
		c.JSON(internal.Status, apierror.NewResponse(internal))
	}
}
// Logout ends the session whose ID is stored in the gin context.
//
// When no session ID is present (GetSessionID yields uuid.Nil) it answers with
// an unauthorized error; a failing service call yields an internal error;
// otherwise it answers http.StatusOK with the message "logged out".
func (h *Handler) Logout(c *gin.Context) {
	current := GetSessionID(c)
	if current != uuid.Nil {
		if err := h.service.Logout(c.Request.Context(), current); err != nil {
			internal := apierror.Internal("logout failed")
			c.JSON(internal.Status, apierror.NewResponse(internal))
			return
		}
		c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "logged out"}})
		return
	}

	// No session in the context: the caller is not authenticated.
	denied := apierror.Unauthorized("not authenticated")
	c.JSON(denied.Status, apierror.NewResponse(denied))
}
// Refresh exchanges the refresh token found in the request headers (see
// extractRefreshToken) for new auth data.
//
// A missing token yields an unauthorized error. When the service fails with an
// error that is (or wraps) an *apierror.Error, that error is sent unchanged;
// any other failure is reported as unauthorized "invalid or expired session".
// On success it answers http.StatusOK with an AuthResponse.
func (h *Handler) Refresh(c *gin.Context) {
	token := extractRefreshToken(c)
	if len(token) == 0 {
		missing := apierror.Unauthorized("refresh token required")
		c.JSON(missing.Status, apierror.NewResponse(missing))
		return
	}

	ctx := c.Request.Context()
	result, err := h.service.RefreshToken(ctx, token, c.ClientIP(), c.GetHeader("User-Agent"))
	if err == nil {
		c.JSON(http.StatusOK, AuthResponse{Data: result})
		return
	}

	// Prefer an API error supplied by the service; otherwise fall back to a
	// generic unauthorized response.
	var apiErr *apierror.Error
	if !errors.As(err, &apiErr) {
		apiErr = apierror.Unauthorized("invalid or expired session")
	}
	c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
}
// SetupTOTP starts 2FA enrolment for the user identified by the gin context.
//
// It loads the user through userRepo to obtain the email address, then asks
// the service to set up TOTP for that user ID and email. A failure in either
// step yields an internal error; success answers http.StatusOK with a
// TOTPSetupResponse.
func (h *Handler) SetupTOTP(c *gin.Context) {
	uid := GetUserID(c)
	ctx := c.Request.Context()

	account, err := h.userRepo.GetByID(ctx, uid)
	if err != nil {
		lookupFailed := apierror.Internal("failed to get user")
		c.JSON(lookupFailed.Status, apierror.NewResponse(lookupFailed))
		return
	}

	setup, err := h.service.SetupTOTP(c.Request.Context(), uid, account.Email)
	if err != nil {
		setupFailed := apierror.Internal("failed to setup 2FA")
		c.JSON(setupFailed.Status, apierror.NewResponse(setupFailed))
		return
	}

	c.JSON(http.StatusOK, TOTPSetupResponse{Data: setup})
}
// VerifyTOTP binds a TOTPVerifyRequest and asks the service to verify the
// submitted code for the user identified by the gin context.
//
// Any service error is reported as a bad-request error with code
// "invalid_code"; success answers http.StatusOK with the message "2FA enabled".
//
// NOTE(review): the service error text is sent to the client verbatim. If the
// service can return wrapped infrastructure errors this would expose them —
// confirm which errors VerifyTOTPSetup returns.
func (h *Handler) VerifyTOTP(c *gin.Context) {
	var body TOTPVerifyRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	uid := GetUserID(c)
	err := h.service.VerifyTOTPSetup(c.Request.Context(), uid, body.Code)
	if err == nil {
		c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA enabled"}})
		return
	}

	rejected := apierror.BadRequest("invalid_code", err.Error())
	c.JSON(rejected.Status, apierror.NewResponse(rejected))
}
// ChangePassword binds a ChangePasswordRequest and forwards it to the service
// together with the user ID and session ID taken from the gin context.
//
// Service errors whose message is "current password required" or
// "current password incorrect" become a bad-request error with code
// "invalid_password" carrying that message; any other error becomes an
// internal error. Success answers http.StatusOK with "password updated".
func (h *Handler) ChangePassword(c *gin.Context) {
	var body ChangePasswordRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	uid, current := GetUserID(c), GetSessionID(c)
	err := h.service.ChangePassword(c.Request.Context(), uid, current, body)
	if err == nil {
		c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "password updated"}})
		return
	}

	switch reason := err.Error(); reason {
	case "current password required", "current password incorrect":
		rejected := apierror.BadRequest("invalid_password", reason)
		c.JSON(rejected.Status, apierror.NewResponse(rejected))
	default:
		internal := apierror.Internal("failed to change password")
		c.JSON(internal.Status, apierror.NewResponse(internal))
	}
}
// RegenerateBackupCodes issues a fresh set of backup codes for the user
// identified by the gin context.
//
// The request must carry a TOTP code, which is checked through the service's
// validateTOTP before any codes are generated. A failed check yields a
// bad-request error with code "invalid_code"; a generation failure yields an
// internal error; success answers http.StatusOK with a BackupCodesResponse.
func (h *Handler) RegenerateBackupCodes(c *gin.Context) {
	var body BackupCodeRegenerateRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	uid := GetUserID(c)

	// Require proof of the second factor before replacing the codes.
	totpErr := h.service.validateTOTP(c.Request.Context(), uid, body.TOTPCode)
	if totpErr != nil {
		rejected := apierror.BadRequest("invalid_code", "invalid TOTP code")
		c.JSON(rejected.Status, apierror.NewResponse(rejected))
		return
	}

	fresh, genErr := h.service.GenerateBackupCodes(c.Request.Context(), uid)
	if genErr != nil {
		internal := apierror.Internal("failed to generate backup codes")
		c.JSON(internal.Status, apierror.NewResponse(internal))
		return
	}

	payload := BackupCodesData{Codes: fresh}
	c.JSON(http.StatusOK, BackupCodesResponse{Data: payload})
}
// ListSessions returns the sessions the service reports for the user
// identified by the gin context.
//
// Each session is converted to a SessionInfo; Current is true for the entry
// whose ID equals the session ID stored in the gin context. A service failure
// yields an internal error; success answers http.StatusOK with a
// SessionListResponse.
func (h *Handler) ListSessions(c *gin.Context) {
	uid := GetUserID(c)
	activeID := GetSessionID(c)

	found, err := h.service.ListSessions(c.Request.Context(), uid)
	if err != nil {
		internal := apierror.Internal("failed to list sessions")
		c.JSON(internal.Status, apierror.NewResponse(internal))
		return
	}

	// Allocate up front so an empty result is an empty slice rather than nil.
	out := make([]SessionInfo, 0, len(found))
	for _, sess := range found {
		entry := SessionInfo{
			ID:         sess.ID,
			IPAddress:  sess.IPAddress,
			UserAgent:  sess.UserAgent,
			LastUsedAt: sess.LastUsedAt,
			CreatedAt:  sess.CreatedAt,
			Current:    sess.ID == activeID,
		}
		out = append(out, entry)
	}

	c.JSON(http.StatusOK, SessionListResponse{Data: out})
}
// RevokeSession revokes the session named by the "id" path parameter on behalf
// of the user identified by the gin context.
//
// Responses:
//   - no user ID in the context: unauthorized error "not authenticated"
//   - "id" is not a valid UUID: bad-request error with code "invalid_id"
//   - service error "session not found": not-found error
//   - service error "not your session": forbidden error
//   - any other service error: internal error
//   - success: http.StatusOK with the message "session revoked"
//
// NOTE(review): service errors are matched by message text; sentinel errors
// checked with errors.Is would be sturdier — confirm what the service exports.
func (h *Handler) RevokeSession(c *gin.Context) {
	userID := GetUserID(c)
	// Mirror Logout: a request without an authenticated identity is rejected
	// here instead of reaching the service with the zero UUID, where it would
	// surface as a misleading not-found/forbidden/internal error.
	if userID == uuid.Nil {
		apiErr := apierror.Unauthorized("not authenticated")
		c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
		return
	}

	sessionID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		apiErr := apierror.BadRequest("invalid_id", "invalid session id")
		c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
		return
	}

	if err := h.service.RevokeSessionByID(c.Request.Context(), userID, sessionID); err != nil {
		switch err.Error() {
		case "session not found":
			apiErr := apierror.NotFound("session not found")
			c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
		case "not your session":
			apiErr := apierror.Forbidden("not your session")
			c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
		default:
			apiErr := apierror.Internal("failed to revoke session")
			c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
		}
		return
	}

	c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "session revoked"}})
}
// RevokeOtherSessions asks the service to revoke the sessions of the user
// identified by the gin context, passing the context's session ID as the one
// to keep.
//
// A service failure yields an internal error; success answers http.StatusOK
// with the message "other sessions revoked".
func (h *Handler) RevokeOtherSessions(c *gin.Context) {
	uid, keep := GetUserID(c), GetSessionID(c)

	err := h.service.RevokeOtherSessions(c.Request.Context(), uid, keep)
	if err == nil {
		c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "other sessions revoked"}})
		return
	}

	internal := apierror.Internal("failed to revoke sessions")
	c.JSON(internal.Status, apierror.NewResponse(internal))
}
// DisableTOTP binds a TOTPVerifyRequest and asks the service to disable 2FA
// for the user identified by the gin context, passing the submitted code.
//
// Any service error is reported as a bad-request error with code
// "invalid_code"; success answers http.StatusOK with the message
// "2FA disabled".
//
// NOTE(review): the service error text is sent to the client verbatim. If the
// service can return wrapped infrastructure errors this would expose them —
// confirm which errors DisableTOTP returns.
func (h *Handler) DisableTOTP(c *gin.Context) {
	var body TOTPVerifyRequest
	if bindErr := validate.BindJSON(c, &body); bindErr != nil {
		c.JSON(bindErr.Status, apierror.NewResponse(bindErr))
		return
	}

	uid := GetUserID(c)
	err := h.service.DisableTOTP(c.Request.Context(), uid, body.Code)
	if err == nil {
		c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA disabled"}})
		return
	}

	rejected := apierror.BadRequest("invalid_code", err.Error())
	c.JSON(rejected.Status, apierror.NewResponse(rejected))
}
// extractRefreshToken returns the value of the X-Refresh-Token header, falling
// back to the legacy X-Session-Token header when the former is empty. The
// result is "" when neither header carries a value.
func extractRefreshToken(c *gin.Context) string {
	token := c.GetHeader("X-Refresh-Token")
	if token == "" {
		// Legacy clients still send the token under the old header name.
		token = c.GetHeader("X-Session-Token")
	}
	return token
}
// GetUserID returns the uuid.UUID stored under the "user_id" key of the gin
// context. It returns uuid.Nil when the key is absent or holds a value of a
// different type.
func GetUserID(c *gin.Context) uuid.UUID {
	if raw, found := c.Get("user_id"); found {
		if userID, isUUID := raw.(uuid.UUID); isUUID {
			return userID
		}
	}
	return uuid.Nil
}
// GetSessionID returns the uuid.UUID stored under the "session_id" key of the
// gin context. It returns uuid.Nil when the key is absent or holds a value of
// a different type.
func GetSessionID(c *gin.Context) uuid.UUID {
	// A missing key yields a nil value, which falls through to the default
	// case exactly like a value of the wrong type.
	raw, _ := c.Get("session_id")
	switch sessionID := raw.(type) {
	case uuid.UUID:
		return sessionID
	default:
		return uuid.Nil
	}
}