127 lines
3.1 KiB
Go
127 lines
3.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"starter-kit/internal/service"
|
|
)
|
|
|
|
const (
|
|
// ContextKeyUser là key dùng để lưu thông tin user trong context
|
|
ContextKeyUser = "user"
|
|
)
|
|
|
|
// AuthMiddleware xác thực JWT token
|
|
type AuthMiddleware struct {
|
|
authSvc service.AuthService
|
|
}
|
|
|
|
// NewAuthMiddleware tạo mới AuthMiddleware
|
|
func NewAuthMiddleware(authSvc service.AuthService) *AuthMiddleware {
|
|
return &AuthMiddleware{
|
|
authSvc: authSvc,
|
|
}
|
|
}
|
|
|
|
// Authenticate xác thực JWT token
|
|
func (m *AuthMiddleware) Authenticate() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Lấy token từ header
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
|
|
return
|
|
}
|
|
|
|
// Kiểm tra định dạng token
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization header format"})
|
|
return
|
|
}
|
|
|
|
tokenString := parts[1]
|
|
|
|
// Check for empty token
|
|
if tokenString == "" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token cannot be empty"})
|
|
return
|
|
}
|
|
|
|
// Xác thực token
|
|
claims, err := m.authSvc.ValidateToken(tokenString)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
|
return
|
|
}
|
|
|
|
// Lưu thông tin user vào context
|
|
c.Set(ContextKeyUser, claims)
|
|
|
|
// Tiếp tục xử lý request
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireRole kiểm tra user có vai trò được yêu cầu không
|
|
func (m *AuthMiddleware) RequireRole(roles ...string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Lấy thông tin user từ context
|
|
userValue, exists := c.Get(ContextKeyUser)
|
|
if !exists {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
|
return
|
|
}
|
|
|
|
// Ép kiểu về Claims
|
|
claims, ok := userValue.(*service.Claims)
|
|
if !ok {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Invalid user data"})
|
|
return
|
|
}
|
|
|
|
// Kiểm tra vai trò
|
|
for _, role := range roles {
|
|
for _, userRole := range claims.Roles {
|
|
if userRole == role {
|
|
// Có quyền, tiếp tục xử lý
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Không có quyền
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": fmt.Sprintf("Require one of these roles: %v", roles),
|
|
})
|
|
}
|
|
}
|
|
|
|
// GetUserFromContext lấy thông tin user từ context
|
|
func GetUserFromContext(c *gin.Context) (*service.Claims, error) {
|
|
userValue, exists := c.Get(ContextKeyUser)
|
|
if !exists {
|
|
return nil, fmt.Errorf("user not found in context")
|
|
}
|
|
|
|
claims, ok := userValue.(*service.Claims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid user data in context")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// GetUserIDFromContext lấy user ID từ context
|
|
func GetUserIDFromContext(c *gin.Context) (string, error) {
|
|
claims, err := GetUserFromContext(c)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return claims.UserID, nil
|
|
}
|