ulflow_phattt2901 23ec4d7bd2
Some checks failed
CI Pipeline / Lint (push) Failing after 5m30s
CI Pipeline / Test (push) Has been skipped
CI Pipeline / Security Scan (push) Successful in 6m6s
CI Pipeline / Build (push) Has been skipped
CI Pipeline / Notification (push) Successful in 2s
feat: implement auth middleware and unit tests with JWT validation
2025-06-03 21:31:18 +07:00

335 lines
9.7 KiB
Go

package middleware_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"starter-kit/internal/domain/user"
"starter-kit/internal/service"
"starter-kit/internal/transport/http/middleware"
)
// MockAuthService is a testify-backed stand-in for the auth service passed to
// middleware.NewAuthMiddleware. Every method records its invocation on the
// embedded mock.Mock and replays the values configured via On(...).Return(...).
type MockAuthService struct {
	mock.Mock
}

// Register records the call and returns the configured *user.User (an untyped
// nil is allowed and yields a nil pointer) together with the configured error.
func (m *MockAuthService) Register(ctx context.Context, req service.RegisterRequest) (*user.User, error) {
	ret := m.Called(ctx, req)
	if u := ret.Get(0); u != nil {
		return u.(*user.User), ret.Error(1)
	}
	return nil, ret.Error(1)
}

// Login records the call and returns the two configured strings and error.
func (m *MockAuthService) Login(ctx context.Context, username, password string) (string, string, error) {
	ret := m.Called(ctx, username, password)
	first, second := ret.String(0), ret.String(1)
	return first, second, ret.Error(2)
}

// RefreshToken records the call and returns the two configured strings and error.
func (m *MockAuthService) RefreshToken(refreshToken string) (string, string, error) {
	ret := m.Called(refreshToken)
	first, second := ret.String(0), ret.String(1)
	return first, second, ret.Error(2)
}

// ValidateToken records the call and returns the configured *service.Claims
// (an untyped nil is allowed and yields a nil pointer) and error.
func (m *MockAuthService) ValidateToken(tokenString string) (*service.Claims, error) {
	ret := m.Called(tokenString)
	if c := ret.Get(0); c != nil {
		return c.(*service.Claims), ret.Error(1)
	}
	return nil, ret.Error(1)
}
// TestNewAuthMiddleware checks that the constructor returns a non-nil value.
func TestNewAuthMiddleware(t *testing.T) {
	mockAuthSvc := new(MockAuthService)
	// Named authMiddleware so the local does not shadow the imported
	// middleware package.
	authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
	assert.NotNil(t, authMiddleware)
}
// TestAuthenticate_Success sends a bearer token that the mocked service
// accepts. It expects three things:
//   - the protected handler runs and responds 200;
//   - the validated claims are retrievable downstream via
//     middleware.GetUserFromContext;
//   - ValidateToken was called with the raw token.
func TestAuthenticate_Success(t *testing.T) {
	mockAuthSvc := new(MockAuthService)
	authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)

	claims := &service.Claims{
		UserID:   "user123",
		Username: "testuser",
		Roles:    []string{"user"},
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
		},
	}
	mockAuthSvc.On("ValidateToken", "valid.token.here").Return(claims, nil)

	r := gin.New()
	r.GET("/protected", authMiddleware.Authenticate(), func(c *gin.Context) {
		// Verify the middleware's key side effect: the validated claims must be
		// available to handlers further down the chain.
		got, err := middleware.GetUserFromContext(c)
		if assert.NoError(t, err) {
			assert.Equal(t, "user123", got.UserID)
		}
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})

	// httptest.NewRequest panics on a malformed target instead of returning an
	// error that would otherwise be silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer valid.token.here")

	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	mockAuthSvc.AssertExpectations(t)
}
// TestAuthenticate_NoAuthHeader expects a 401 carrying an explanatory message
// when the request has no Authorization header at all.
func TestAuthenticate_NoAuthHeader(t *testing.T) {
	svc := new(MockAuthService)
	mw := middleware.NewAuthMiddleware(svc)

	engine := gin.New()
	engine.GET("/protected", mw.Authenticate())

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/protected", nil))

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "Authorization header is required")
}
// TestAuthenticate_InvalidTokenFormat covers malformed Authorization headers.
// For each case the middleware must:
//   - respond 401 with the expected "error" message;
//   - stop the handler chain, so the downstream handler never runs;
//   - never consult the auth service.
func TestAuthenticate_InvalidTokenFormat(t *testing.T) {
	tests := []struct {
		name               string
		authHeader         string
		expectedError      string
		shouldCallValidate bool
	}{
		{
			name:               "no bearer",
			authHeader:         "invalid",
			expectedError:      "Invalid authorization header format",
			shouldCallValidate: false,
		},
		{
			name:               "empty token",
			authHeader:         "Bearer ",
			expectedError:      "Token cannot be empty",
			shouldCallValidate: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockAuthSvc := new(MockAuthService)
			authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)

			// Put a real downstream handler behind the middleware. If the
			// middleware fails to abort, this handler runs and fails the test.
			r := gin.New()
			r.GET("/protected", authMiddleware.Authenticate(), func(c *gin.Context) {
				t.Error("Handler should not be called for invalid token formats")
				c.String(http.StatusOK, "should not be called")
			})

			req := httptest.NewRequest(http.MethodGet, "/protected", nil)
			req.Header.Set("Authorization", tt.authHeader)
			w := httptest.NewRecorder()
			r.ServeHTTP(w, req)

			if w.Code != http.StatusUnauthorized {
				t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, w.Code)
			}

			var resp map[string]string
			if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
				t.Fatalf("Failed to unmarshal response: %v", err)
			}
			if resp["error"] != tt.expectedError {
				t.Errorf("Expected error message '%s', got '%s'", tt.expectedError, resp["error"])
			}

			// AssertNotCalled with no arguments only matches zero-argument
			// calls, so it cannot detect ValidateToken(token). Count calls by
			// method name instead.
			if !tt.shouldCallValidate {
				mockAuthSvc.AssertNumberOfCalls(t, "ValidateToken", 0)
			}
		})
	}
}
// TestAuthenticate_InvalidToken expects a 401 when the auth service rejects an
// otherwise well-formed bearer token, and checks the service was consulted.
func TestAuthenticate_InvalidToken(t *testing.T) {
	svc := new(MockAuthService)
	mw := middleware.NewAuthMiddleware(svc)

	rejection := errors.New("invalid token")
	svc.On("ValidateToken", "invalid.token").Return((*service.Claims)(nil), rejection)

	engine := gin.New()
	engine.GET("/protected", mw.Authenticate())

	request := httptest.NewRequest(http.MethodGet, "/protected", nil)
	request.Header.Set("Authorization", "Bearer invalid.token")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "Invalid or expired token")
	svc.AssertExpectations(t)
}
// TestRequireRole_Success chains Authenticate and RequireRole("admin"). It uses
// a token whose claims include the admin role and expects the final handler
// to respond 200.
func TestRequireRole_Success(t *testing.T) {
	svc := new(MockAuthService)
	mw := middleware.NewAuthMiddleware(svc)

	adminClaims := &service.Claims{
		UserID:   "admin123",
		Username: "adminuser",
		Roles:    []string{"admin"},
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
		},
	}
	svc.On("ValidateToken", "admin.token").Return(adminClaims, nil)

	engine := gin.New()
	adminOnly := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "admin access granted"})
	}
	engine.GET("/admin", mw.Authenticate(), mw.RequireRole("admin"), adminOnly)

	request := httptest.NewRequest(http.MethodGet, "/admin", nil)
	request.Header.Set("Authorization", "Bearer admin.token")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "admin access granted")
	svc.AssertExpectations(t)
}
// TestRequireRole_Unauthenticated mounts RequireRole without Authenticate in
// front of it, so no user is present in the context, and expects a 401.
func TestRequireRole_Unauthenticated(t *testing.T) {
	svc := new(MockAuthService)
	mw := middleware.NewAuthMiddleware(svc)

	engine := gin.New()
	engine.GET("/admin", mw.RequireRole("admin"))

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/admin", nil))

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "Authentication required")
}
// TestRequireRole_Forbidden authenticates a user who only has the "user" role
// and expects RequireRole("admin") to reject the request with a 403 naming the
// required roles.
func TestRequireRole_Forbidden(t *testing.T) {
	svc := new(MockAuthService)
	mw := middleware.NewAuthMiddleware(svc)

	engine := gin.New()
	engine.GET("/admin", mw.Authenticate(), mw.RequireRole("admin"))

	regularClaims := &service.Claims{
		UserID:   "user123",
		Username: "regularuser",
		Roles:    []string{"user"},
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
		},
	}
	svc.On("ValidateToken", "user.token").Return(regularClaims, nil)

	request := httptest.NewRequest(http.MethodGet, "/admin", nil)
	request.Header.Set("Authorization", "Bearer user.token")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusForbidden, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "Require one of these roles: [admin]")
	svc.AssertExpectations(t)
}
// TestGetUserFromContext stores claims under middleware.ContextKeyUser and
// expects GetUserFromContext to return them without error.
func TestGetUserFromContext(t *testing.T) {
	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	claims := &service.Claims{
		UserID:   "user123",
		Username: "testuser",
		Roles:    []string{"user"},
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
		},
	}
	c.Set(middleware.ContextKeyUser, claims)

	// Named got (not user) so it does not shadow the imported user package.
	got, err := middleware.GetUserFromContext(c)
	// Stop before dereferencing got if the lookup failed; otherwise a
	// regression would surface as a nil-pointer panic, not a test failure.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "user123", got.UserID)
	assert.Equal(t, "testuser", got.Username)
}
// TestGetUserFromContext_NotFound expects an error when nothing was stored
// under the user context key.
func TestGetUserFromContext_NotFound(t *testing.T) {
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)

	_, lookupErr := middleware.GetUserFromContext(ctx)
	assert.Error(t, lookupErr)
}
// TestGetUserIDFromContext stores claims under middleware.ContextKeyUser and
// expects GetUserIDFromContext to return their UserID.
func TestGetUserIDFromContext(t *testing.T) {
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Set(middleware.ContextKeyUser, &service.Claims{
		UserID:   "user123",
		Username: "testuser",
		Roles:    []string{"user"},
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
		},
	})

	id, lookupErr := middleware.GetUserIDFromContext(ctx)
	assert.NoError(t, lookupErr)
	assert.Equal(t, "user123", id)
}
// TestGetUserIDFromContext_InvalidType stores a value of the wrong type under
// middleware.ContextKeyUser and expects GetUserIDFromContext to report an error
// rather than return an ID.
func TestGetUserIDFromContext_InvalidType(t *testing.T) {
	const wrongType = "not a claims object"

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Set(middleware.ContextKeyUser, wrongType)

	_, lookupErr := middleware.GetUserIDFromContext(ctx)
	assert.Error(t, lookupErr)
}