2025-06-05 21:21:28 +07:00

127 lines
3.1 KiB
Go

package middleware
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"starter-kit/internal/service"
)
const (
// ContextKeyUser là key dùng để lưu thông tin user trong context
ContextKeyUser = "user"
)
// AuthMiddleware xác thực JWT token
type AuthMiddleware struct {
authSvc service.AuthService
}
// NewAuthMiddleware tạo mới AuthMiddleware
func NewAuthMiddleware(authSvc service.AuthService) *AuthMiddleware {
return &AuthMiddleware{
authSvc: authSvc,
}
}
// Authenticate xác thực JWT token
func (m *AuthMiddleware) Authenticate() gin.HandlerFunc {
return func(c *gin.Context) {
// Lấy token từ header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
return
}
// Kiểm tra định dạng token
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization header format"})
return
}
tokenString := parts[1]
// Check for empty token
if tokenString == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token cannot be empty"})
return
}
// Xác thực token
claims, err := m.authSvc.ValidateToken(tokenString)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
return
}
// Lưu thông tin user vào context
c.Set(ContextKeyUser, claims)
// Tiếp tục xử lý request
c.Next()
}
}
// RequireRole kiểm tra user có vai trò được yêu cầu không
func (m *AuthMiddleware) RequireRole(roles ...string) gin.HandlerFunc {
return func(c *gin.Context) {
// Lấy thông tin user từ context
userValue, exists := c.Get(ContextKeyUser)
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
// Ép kiểu về Claims
claims, ok := userValue.(*service.Claims)
if !ok {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Invalid user data"})
return
}
// Kiểm tra vai trò
for _, role := range roles {
for _, userRole := range claims.Roles {
if userRole == role {
// Có quyền, tiếp tục xử lý
c.Next()
return
}
}
}
// Không có quyền
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("Require one of these roles: %v", roles),
})
}
}
// GetUserFromContext lấy thông tin user từ context
func GetUserFromContext(c *gin.Context) (*service.Claims, error) {
userValue, exists := c.Get(ContextKeyUser)
if !exists {
return nil, fmt.Errorf("user not found in context")
}
claims, ok := userValue.(*service.Claims)
if !ok {
return nil, fmt.Errorf("invalid user data in context")
}
return claims, nil
}
// GetUserIDFromContext lấy user ID từ context
func GetUserIDFromContext(c *gin.Context) (string, error) {
claims, err := GetUserFromContext(c)
if err != nil {
return "", err
}
return claims.UserID, nil
}