339 lines
8.0 KiB
Go
339 lines
8.0 KiB
Go
package service_test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"zee/internal/resource/role"
|
|
"zee/internal/resource/user"
|
|
"zee/internal/service"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// MockUserRepo là mock cho user.Repository
|
|
type MockUserRepo struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockUserRepo) Create(ctx context.Context, user *user.User) error {
|
|
args := m.Called(ctx, user)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockUserRepo) GetByID(ctx context.Context, id string) (*user.User, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*user.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepo) GetByUsername(ctx context.Context, username string) (*user.User, error) {
|
|
args := m.Called(ctx, username)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*user.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepo) GetByEmail(ctx context.Context, email string) (*user.User, error) {
|
|
args := m.Called(ctx, email)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*user.User), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepo) UpdateLastLogin(ctx context.Context, id string) error {
|
|
args := m.Called(ctx, id)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockUserRepo) AddRole(ctx context.Context, userID string, roleID int) error {
|
|
args := m.Called(ctx, userID, roleID)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockUserRepo) RemoveRole(ctx context.Context, userID string, roleID int) error {
|
|
args := m.Called(ctx, userID, roleID)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockUserRepo) HasRole(ctx context.Context, userID string, roleID int) (bool, error) {
|
|
args := m.Called(ctx, userID, roleID)
|
|
return args.Bool(0), args.Error(1)
|
|
}
|
|
|
|
func (m *MockUserRepo) Update(ctx context.Context, user *user.User) error {
|
|
args := m.Called(ctx, user)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockUserRepo) Delete(ctx context.Context, id string) error {
|
|
args := m.Called(ctx, id)
|
|
return args.Error(0)
|
|
}
|
|
|
|
// MockRoleRepo là mock cho role.Repository
|
|
type MockRoleRepo struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockRoleRepo) GetByName(ctx context.Context, name string) (*role.Role, error) {
|
|
args := m.Called(ctx, name)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*role.Role), args.Error(1)
|
|
}
|
|
|
|
func (m *MockRoleRepo) Create(ctx context.Context, role *role.Role) error {
|
|
args := m.Called(ctx, role)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockRoleRepo) GetByID(ctx context.Context, id int) (*role.Role, error) {
|
|
args := m.Called(ctx, id)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(*role.Role), args.Error(1)
|
|
}
|
|
|
|
func (m *MockRoleRepo) List(ctx context.Context) ([]*role.Role, error) {
|
|
args := m.Called(ctx)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]*role.Role), args.Error(1)
|
|
}
|
|
|
|
func (m *MockRoleRepo) Update(ctx context.Context, role *role.Role) error {
|
|
args := m.Called(ctx, role)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockRoleRepo) Delete(ctx context.Context, id int) error {
|
|
args := m.Called(ctx, id)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func TestAuthService_Register(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setup func(*MockUserRepo, *MockRoleRepo)
|
|
req service.RegisterRequest
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "successful registration",
|
|
setup: func(mu *MockUserRepo, mr *MockRoleRepo) {
|
|
// Mock GetByUsername - user not exists
|
|
mu.On("GetByUsername", mock.Anything, "testuser").
|
|
Return((*user.User)(nil), nil)
|
|
|
|
// Mock GetByEmail - email not exists
|
|
mu.On("GetByEmail", mock.Anything, "test@example.com").
|
|
Return((*user.User)(nil), nil)
|
|
|
|
// Mock GetByName - role exists
|
|
mr.On("GetByName", mock.Anything, role.User).
|
|
Return(&role.Role{ID: 1, Name: role.User}, nil)
|
|
|
|
// Mock AddRole
|
|
mu.On("AddRole", mock.Anything, mock.Anything, mock.Anything).
|
|
Return(nil)
|
|
|
|
// Mock Create - success
|
|
mu.On("Create", mock.Anything, mock.AnythingOfType("*user.User")).
|
|
Return(nil)
|
|
|
|
// Mock GetByID - return created user
|
|
mu.On("GetByID", mock.Anything, mock.Anything).
|
|
Return(&user.User{
|
|
ID: "123",
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
FullName: "Test User",
|
|
}, nil)
|
|
|
|
},
|
|
req: service.RegisterRequest{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Password: "password123",
|
|
FullName: "Test User",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Setup mocks
|
|
mockUserRepo := new(MockUserRepo)
|
|
mockRoleRepo := new(MockRoleRepo)
|
|
tt.setup(mockUserRepo, mockRoleRepo)
|
|
|
|
// Create service with mocks
|
|
svc := service.NewAuthService(
|
|
mockUserRepo,
|
|
mockRoleRepo,
|
|
"test-secret",
|
|
time.Hour,
|
|
)
|
|
|
|
// Call method
|
|
_, err := svc.Register(context.Background(), tt.req)
|
|
|
|
// Assertions
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify all expectations were met
|
|
mockUserRepo.AssertExpectations(t)
|
|
mockRoleRepo.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthService_Login(t *testing.T) {
|
|
// Create a test user with hashed password
|
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
|
testUser := &user.User{
|
|
ID: "123",
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
PasswordHash: string(hashedPassword),
|
|
IsActive: true,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
setup func(*MockUserRepo)
|
|
username string
|
|
password string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "successful login",
|
|
setup: func(mu *MockUserRepo) {
|
|
// Mock GetByUsername - user exists
|
|
mu.On("GetByUsername", mock.Anything, "testuser").
|
|
Return(testUser, nil)
|
|
|
|
// Mock UpdateLastLogin
|
|
mu.On("UpdateLastLogin", mock.Anything, "123").
|
|
Return(nil)
|
|
},
|
|
username: "testuser",
|
|
password: "password123",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid password",
|
|
setup: func(mu *MockUserRepo) {
|
|
// Mock GetByUsername - user exists
|
|
mu.On("GetByUsername", mock.Anything, "testuser").
|
|
Return(testUser, nil)
|
|
},
|
|
username: "testuser",
|
|
password: "wrongpassword",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Setup mocks
|
|
mockUserRepo := new(MockUserRepo)
|
|
tt.setup(mockUserRepo)
|
|
|
|
// Create service with mocks
|
|
svc := service.NewAuthService(
|
|
mockUserRepo,
|
|
nil, // Role repo not needed for login
|
|
"test-secret",
|
|
time.Hour,
|
|
)
|
|
|
|
// Call method
|
|
_, _, err := svc.Login(context.Background(), tt.username, tt.password)
|
|
|
|
// Assertions
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify all expectations were met
|
|
mockUserRepo.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthService_ValidateToken(t *testing.T) {
|
|
// Create a test service
|
|
svc := service.NewAuthService(
|
|
nil, // Repos not needed for this test
|
|
nil,
|
|
"test-secret",
|
|
time.Hour,
|
|
)
|
|
|
|
// Create a valid token
|
|
claims := &service.Claims{
|
|
UserID: "123",
|
|
Username: "testuser",
|
|
Roles: []string{"user"},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
},
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, _ := token.SignedString([]byte("test-secret"))
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
wantClaims *service.Claims
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid token",
|
|
token: tokenString,
|
|
wantClaims: claims,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid signature",
|
|
token: tokenString[:len(tokenString)-2] + "xx", // Corrupt the signature
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
claims, err := svc.ValidateToken(tt.token)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantClaims.UserID, claims.UserID)
|
|
assert.Equal(t, tt.wantClaims.Username, claims.Username)
|
|
}
|
|
})
|
|
}
|
|
}
|