338 lines
9.8 KiB
Go
338 lines
9.8 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"zee/internal/resource/user"
|
|
"zee/internal/service"
|
|
"zee/internal/transport/http/middleware"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// MockAuthService is a mock implementation of AuthService
|
|
type MockAuthService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockAuthService) Register(ctx context.Context, req service.RegisterRequest) (*user.User, error) {
|
|
args := m.Called(ctx, req)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*user.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockAuthService) Login(ctx context.Context, username, password string) (string, string, error) {
|
|
args := m.Called(ctx, username, password)
|
|
return args.String(0), args.String(1), args.Error(2)
|
|
}
|
|
|
|
func (m *MockAuthService) RefreshToken(refreshToken string) (string, string, error) {
|
|
args := m.Called(refreshToken)
|
|
return args.String(0), args.String(1), args.Error(2)
|
|
}
|
|
|
|
func (m *MockAuthService) ValidateToken(tokenString string) (*service.Claims, error) {
|
|
args := m.Called(tokenString)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*service.Claims), args.Error(1)
|
|
}
|
|
|
|
func TestNewAuthMiddleware(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
middleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
assert.NotNil(t, middleware)
|
|
}
|
|
|
|
func TestAuthenticate_Success(t *testing.T) {
|
|
// Setup
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
// Mock token validation
|
|
claims := &service.Claims{
|
|
UserID: "user123",
|
|
Username: "testuser",
|
|
Roles: []string{"user"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
|
|
},
|
|
}
|
|
mockAuthSvc.On("ValidateToken", "valid.token.here").Return(claims, nil)
|
|
|
|
// Create test router
|
|
r := gin.New()
|
|
r.GET("/protected", authMiddleware.Authenticate(), func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
|
|
// Create test request with valid token
|
|
req, _ := http.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("Authorization", "Bearer valid.token.here")
|
|
|
|
// Execute request
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
mockAuthSvc.AssertExpectations(t)
|
|
}
|
|
|
|
func TestAuthenticate_NoAuthHeader(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
r := gin.New()
|
|
r.GET("/protected", authMiddleware.Authenticate())
|
|
|
|
req, _ := http.NewRequest("GET", "/protected", nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Authorization header is required")
|
|
}
|
|
|
|
func TestAuthenticate_InvalidTokenFormat(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
authHeader string
|
|
expectedError string
|
|
shouldCallValidate bool
|
|
}{
|
|
{
|
|
name: "no bearer",
|
|
authHeader: "invalid",
|
|
expectedError: "Invalid authorization header format",
|
|
shouldCallValidate: false,
|
|
},
|
|
{
|
|
name: "empty token",
|
|
authHeader: "Bearer ",
|
|
expectedError: "Token cannot be empty",
|
|
shouldCallValidate: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
// Create a test server with the middleware and a simple handler
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// This handler should not be called for invalid token formats
|
|
t.Error("Handler should not be called for invalid token formats")
|
|
w.WriteHeader(http.StatusOK)
|
|
if _, err := w.Write([]byte("should not be called")); err != nil {
|
|
t.Errorf("failed to write response in unexpected handler call: %v", err)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create a request with the test auth header
|
|
req, _ := http.NewRequest("GET", server.URL, nil)
|
|
req.Header.Set("Authorization", tt.authHeader)
|
|
|
|
// Create a response recorder
|
|
w := httptest.NewRecorder()
|
|
|
|
// Create a Gin context with the request and response
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = req
|
|
|
|
// Call the middleware directly
|
|
authMiddleware.Authenticate()(c)
|
|
|
|
// Check if the response has the expected status code and error message
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
|
|
var resp map[string]string
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("Failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
if resp["error"] != tt.expectedError {
|
|
t.Errorf("Expected error message '%s', got '%s'", tt.expectedError, resp["error"])
|
|
}
|
|
|
|
// Verify that ValidateToken was not called when it shouldn't be
|
|
if !tt.shouldCallValidate {
|
|
mockAuthSvc.AssertNotCalled(t, "ValidateToken")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthenticate_InvalidToken(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
// Mock token validation to fail
|
|
errInvalidToken := errors.New("invalid token")
|
|
mockAuthSvc.On("ValidateToken", "invalid.token").Return((*service.Claims)(nil), errInvalidToken)
|
|
|
|
r := gin.New()
|
|
r.GET("/protected", authMiddleware.Authenticate())
|
|
|
|
req, _ := http.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("Authorization", "Bearer invalid.token")
|
|
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Invalid or expired token")
|
|
mockAuthSvc.AssertExpectations(t)
|
|
}
|
|
|
|
func TestRequireRole_Success(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
// Create a test router with role-based auth
|
|
r := gin.New()
|
|
// Add a route that requires admin role
|
|
r.GET("/admin", authMiddleware.Authenticate(), authMiddleware.RequireRole("admin"), func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "admin access granted"})
|
|
})
|
|
|
|
// Create a request with a valid token that has admin role
|
|
req, _ := http.NewRequest("GET", "/admin", nil)
|
|
req.Header.Set("Authorization", "Bearer admin.token")
|
|
|
|
// Mock the token validation to return a user with admin role
|
|
claims := &service.Claims{
|
|
UserID: "admin123",
|
|
Username: "adminuser",
|
|
Roles: []string{"admin"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
|
|
},
|
|
}
|
|
mockAuthSvc.On("ValidateToken", "admin.token").Return(claims, nil)
|
|
|
|
// Execute request
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "admin access granted")
|
|
mockAuthSvc.AssertExpectations(t)
|
|
}
|
|
|
|
func TestRequireRole_Unauthenticated(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
r := gin.New()
|
|
r.GET("/admin", authMiddleware.RequireRole("admin"))
|
|
|
|
req, _ := http.NewRequest("GET", "/admin", nil)
|
|
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Authentication required")
|
|
}
|
|
|
|
func TestRequireRole_Forbidden(t *testing.T) {
|
|
mockAuthSvc := new(MockAuthService)
|
|
authMiddleware := middleware.NewAuthMiddleware(mockAuthSvc)
|
|
|
|
r := gin.New()
|
|
r.GET("/admin", authMiddleware.Authenticate(), authMiddleware.RequireRole("admin"))
|
|
|
|
// Create a request with a valid token that doesn't have admin role
|
|
req, _ := http.NewRequest("GET", "/admin", nil)
|
|
req.Header.Set("Authorization", "Bearer user.token")
|
|
|
|
// Mock the token validation to return a user without admin role
|
|
claims := &service.Claims{
|
|
UserID: "user123",
|
|
Username: "regularuser",
|
|
Roles: []string{"user"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
|
|
},
|
|
}
|
|
mockAuthSvc.On("ValidateToken", "user.token").Return(claims, nil)
|
|
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Require one of these roles: [admin]")
|
|
mockAuthSvc.AssertExpectations(t)
|
|
}
|
|
|
|
func TestGetUserFromContext(t *testing.T) {
|
|
// Setup test context with user
|
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
claims := &service.Claims{
|
|
UserID: "user123",
|
|
Username: "testuser",
|
|
Roles: []string{"user"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
|
|
},
|
|
}
|
|
c.Set(middleware.ContextKeyUser, claims)
|
|
|
|
// Test GetUserFromContext
|
|
user, err := middleware.GetUserFromContext(c)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "user123", user.UserID)
|
|
assert.Equal(t, "testuser", user.Username)
|
|
}
|
|
|
|
func TestGetUserFromContext_NotFound(t *testing.T) {
|
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
_, err := middleware.GetUserFromContext(c)
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestGetUserIDFromContext(t *testing.T) {
|
|
// Setup test context with user
|
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
claims := &service.Claims{
|
|
UserID: "user123",
|
|
Username: "testuser",
|
|
Roles: []string{"user"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)),
|
|
},
|
|
}
|
|
c.Set(middleware.ContextKeyUser, claims)
|
|
|
|
// Test GetUserIDFromContext
|
|
userID, err := middleware.GetUserIDFromContext(c)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "user123", userID)
|
|
}
|
|
|
|
func TestGetUserIDFromContext_InvalidType(t *testing.T) {
|
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
c.Set(middleware.ContextKeyUser, "not a claims object")
|
|
|
|
_, err := middleware.GetUserIDFromContext(c)
|
|
assert.Error(t, err)
|
|
}
|