- Backup codes: 10 × Crockford base32 (XXXXX-XXXXX), SHA-256 hashed, single-use; regenerating them requires a current TOTP code
- Login accepts a BackupCode field alongside TOTPCode
- Session management: list, revoke by ID (ownership-checked), revoke all except the current session; a password change revokes all other sessions
- New routes: POST /auth/2fa/backup-codes/regenerate, GET /auth/sessions, DELETE /auth/sessions, DELETE /auth/sessions/:id
- fakeRepo extended with backup-code and session-management stubs
- Tests cover: code format/count, hash storage, regeneration invalidating old codes, login with a valid/used code, session-list isolation, revoke ownership, and session revocation on password change
315 lines
8.7 KiB
Go
315 lines
8.7 KiB
Go
package auth
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
|
|
"marktvogt.de/backend/internal/domain/user"
|
|
"marktvogt.de/backend/internal/pkg/apierror"
|
|
"marktvogt.de/backend/internal/pkg/validate"
|
|
)
|
|
|
|
type Handler struct {
|
|
service *Service
|
|
userRepo user.Repository
|
|
}
|
|
|
|
func NewHandler(service *Service, userRepo user.Repository) *Handler {
|
|
return &Handler{service: service, userRepo: userRepo}
|
|
}
|
|
|
|
func (h *Handler) Register(c *gin.Context) {
|
|
var req RegisterRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
data, err := h.service.Register(c.Request.Context(), req, c.ClientIP(), c.GetHeader("User-Agent"))
|
|
if err != nil {
|
|
if errors.Is(err, user.ErrEmailAlreadyTaken) {
|
|
apiErr := apierror.Conflict("email already registered")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
apiErr := apierror.Internal("registration failed")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusCreated, AuthResponse{Data: data})
|
|
}
|
|
|
|
func (h *Handler) Login(c *gin.Context) {
|
|
var req LoginRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
data, err := h.service.Login(c.Request.Context(), req, c.ClientIP(), c.GetHeader("User-Agent"))
|
|
if err != nil {
|
|
msg := err.Error()
|
|
if msg == "invalid credentials" {
|
|
apiErr := apierror.Unauthorized("invalid email or password")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
if msg == "2fa_required" {
|
|
apiErr := apierror.BadRequest("2fa_required", "2FA code required")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
apiErr := apierror.Internal("login failed")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, AuthResponse{Data: data})
|
|
}
|
|
|
|
func (h *Handler) Logout(c *gin.Context) {
|
|
sessionID := GetSessionID(c)
|
|
if sessionID == uuid.Nil {
|
|
apiErr := apierror.Unauthorized("not authenticated")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
if err := h.service.Logout(c.Request.Context(), sessionID); err != nil {
|
|
apiErr := apierror.Internal("logout failed")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "logged out"}})
|
|
}
|
|
|
|
func (h *Handler) Refresh(c *gin.Context) {
|
|
refreshToken := extractRefreshToken(c)
|
|
if refreshToken == "" {
|
|
apiErr := apierror.Unauthorized("refresh token required")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
data, err := h.service.RefreshToken(c.Request.Context(), refreshToken, c.ClientIP(), c.GetHeader("User-Agent"))
|
|
if err != nil {
|
|
var apiErr *apierror.Error
|
|
if errors.As(err, &apiErr) {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
apiErr = apierror.Unauthorized("invalid or expired session")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, AuthResponse{Data: data})
|
|
}
|
|
|
|
func (h *Handler) SetupTOTP(c *gin.Context) {
|
|
userID := GetUserID(c)
|
|
u, err := h.userRepo.GetByID(c.Request.Context(), userID)
|
|
if err != nil {
|
|
apiErr := apierror.Internal("failed to get user")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
data, err := h.service.SetupTOTP(c.Request.Context(), userID, u.Email)
|
|
if err != nil {
|
|
apiErr := apierror.Internal("failed to setup 2FA")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, TOTPSetupResponse{Data: data})
|
|
}
|
|
|
|
func (h *Handler) VerifyTOTP(c *gin.Context) {
|
|
var req TOTPVerifyRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
userID := GetUserID(c)
|
|
if err := h.service.VerifyTOTPSetup(c.Request.Context(), userID, req.Code); err != nil {
|
|
apiErr := apierror.BadRequest("invalid_code", err.Error())
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA enabled"}})
|
|
}
|
|
|
|
func (h *Handler) ChangePassword(c *gin.Context) {
|
|
var req ChangePasswordRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
userID := GetUserID(c)
|
|
sessionID := GetSessionID(c)
|
|
if err := h.service.ChangePassword(c.Request.Context(), userID, sessionID, req); err != nil {
|
|
msg := err.Error()
|
|
if msg == "current password required" || msg == "current password incorrect" {
|
|
apiErr := apierror.BadRequest("invalid_password", msg)
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
apiErr := apierror.Internal("failed to change password")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "password updated"}})
|
|
}
|
|
|
|
func (h *Handler) RegenerateBackupCodes(c *gin.Context) {
|
|
var req BackupCodeRegenerateRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
userID := GetUserID(c)
|
|
if err := h.service.validateTOTP(c.Request.Context(), userID, req.TOTPCode); err != nil {
|
|
apiErr := apierror.BadRequest("invalid_code", "invalid TOTP code")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
codes, err := h.service.GenerateBackupCodes(c.Request.Context(), userID)
|
|
if err != nil {
|
|
apiErr := apierror.Internal("failed to generate backup codes")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, BackupCodesResponse{Data: BackupCodesData{Codes: codes}})
|
|
}
|
|
|
|
func (h *Handler) ListSessions(c *gin.Context) {
|
|
userID := GetUserID(c)
|
|
currentSessionID := GetSessionID(c)
|
|
|
|
sessions, err := h.service.ListSessions(c.Request.Context(), userID)
|
|
if err != nil {
|
|
apiErr := apierror.Internal("failed to list sessions")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
infos := make([]SessionInfo, 0, len(sessions))
|
|
for _, s := range sessions {
|
|
infos = append(infos, SessionInfo{
|
|
ID: s.ID,
|
|
IPAddress: s.IPAddress,
|
|
UserAgent: s.UserAgent,
|
|
LastUsedAt: s.LastUsedAt,
|
|
CreatedAt: s.CreatedAt,
|
|
Current: s.ID == currentSessionID,
|
|
})
|
|
}
|
|
c.JSON(http.StatusOK, SessionListResponse{Data: infos})
|
|
}
|
|
|
|
func (h *Handler) RevokeSession(c *gin.Context) {
|
|
userID := GetUserID(c)
|
|
rawID := c.Param("id")
|
|
|
|
sessionID, err := uuid.Parse(rawID)
|
|
if err != nil {
|
|
apiErr := apierror.BadRequest("invalid_id", "invalid session id")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
if err := h.service.RevokeSessionByID(c.Request.Context(), userID, sessionID); err != nil {
|
|
msg := err.Error()
|
|
if msg == "session not found" {
|
|
apiErr := apierror.NotFound("session not found")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
if msg == "not your session" {
|
|
apiErr := apierror.Forbidden("not your session")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
apiErr := apierror.Internal("failed to revoke session")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "session revoked"}})
|
|
}
|
|
|
|
func (h *Handler) RevokeOtherSessions(c *gin.Context) {
|
|
userID := GetUserID(c)
|
|
sessionID := GetSessionID(c)
|
|
|
|
if err := h.service.RevokeOtherSessions(c.Request.Context(), userID, sessionID); err != nil {
|
|
apiErr := apierror.Internal("failed to revoke sessions")
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "other sessions revoked"}})
|
|
}
|
|
|
|
func (h *Handler) DisableTOTP(c *gin.Context) {
|
|
var req TOTPVerifyRequest
|
|
if apiErr := validate.BindJSON(c, &req); apiErr != nil {
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
userID := GetUserID(c)
|
|
if err := h.service.DisableTOTP(c.Request.Context(), userID, req.Code); err != nil {
|
|
apiErr := apierror.BadRequest("invalid_code", err.Error())
|
|
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, MessageResponse{Data: MessageData{Message: "2FA disabled"}})
|
|
}
|
|
|
|
// extractRefreshToken reads the refresh token from X-Refresh-Token or (legacy) X-Session-Token.
|
|
func extractRefreshToken(c *gin.Context) string {
|
|
if t := c.GetHeader("X-Refresh-Token"); t != "" {
|
|
return t
|
|
}
|
|
return c.GetHeader("X-Session-Token")
|
|
}
|
|
|
|
func GetUserID(c *gin.Context) uuid.UUID {
|
|
v, exists := c.Get("user_id")
|
|
if !exists {
|
|
return uuid.Nil
|
|
}
|
|
id, ok := v.(uuid.UUID)
|
|
if !ok {
|
|
return uuid.Nil
|
|
}
|
|
return id
|
|
}
|
|
|
|
func GetSessionID(c *gin.Context) uuid.UUID {
|
|
v, exists := c.Get("session_id")
|
|
if !exists {
|
|
return uuid.Nil
|
|
}
|
|
id, ok := v.(uuid.UUID)
|
|
if !ok {
|
|
return uuid.Nil
|
|
}
|
|
return id
|
|
}
|