229 lines
5.5 KiB
Go
229 lines
5.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
"github.com/ajaxray/geek-life/model"
|
|
"github.com/ajaxray/geek-life/repository"
|
|
)
|
|
|
|
// Sentinel errors returned by AuthService. They are returned unwrapped, so
// callers can match them with == or errors.Is.
var (
	// ErrInvalidCredentials is returned by Login for an unknown tenant, an
	// unknown account, or a password that does not match.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrUserExists is returned during registration when the username or the
	// email is already taken inside the target tenant.
	ErrUserExists = errors.New("user already exists")

	// ErrTenantExists is returned by RegisterTenant when a tenant with the
	// requested name can already be looked up.
	ErrTenantExists = errors.New("tenant already exists")
)
|
|
|
|
// AuthService handles authentication operations
|
|
type AuthService struct {
|
|
userRepo repository.UserRepository
|
|
tenantRepo repository.TenantRepository
|
|
sessionRepo repository.SessionRepository
|
|
sessionDuration int // in hours
|
|
}
|
|
|
|
// NewAuthService creates a new authentication service
|
|
func NewAuthService(userRepo repository.UserRepository, tenantRepo repository.TenantRepository, sessionRepo repository.SessionRepository, sessionDuration int) *AuthService {
|
|
return &AuthService{
|
|
userRepo: userRepo,
|
|
tenantRepo: tenantRepo,
|
|
sessionRepo: sessionRepo,
|
|
sessionDuration: sessionDuration,
|
|
}
|
|
}
|
|
|
|
// RegisterTenant creates a new tenant with an admin user
|
|
func (s *AuthService) RegisterTenant(tenantName, username, email, password string) (*model.UserContext, string, error) {
|
|
// Check if tenant already exists
|
|
if _, err := s.tenantRepo.GetByName(tenantName); err == nil {
|
|
return nil, "", ErrTenantExists
|
|
}
|
|
|
|
// Create tenant
|
|
tenant, err := s.tenantRepo.Create(tenantName)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Check if user already exists in this tenant
|
|
if _, err := s.userRepo.GetByUsername(tenant.ID, username); err == nil {
|
|
return nil, "", ErrUserExists
|
|
}
|
|
if _, err := s.userRepo.GetByEmail(tenant.ID, email); err == nil {
|
|
return nil, "", ErrUserExists
|
|
}
|
|
|
|
// Hash password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Create user
|
|
user, err := s.userRepo.Create(tenant.ID, username, email, string(hashedPassword))
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Create session
|
|
token, err := s.generateToken()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Duration(s.sessionDuration) * time.Hour).Unix()
|
|
_, err = s.sessionRepo.Create(user.ID, token, expiresAt)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
ctx := &model.UserContext{
|
|
User: user,
|
|
Tenant: tenant,
|
|
}
|
|
|
|
return ctx, token, nil
|
|
}
|
|
|
|
// RegisterUser creates a new user in an existing tenant
|
|
func (s *AuthService) RegisterUser(tenantName, username, email, password string) (*model.UserContext, string, error) {
|
|
// Get tenant
|
|
tenant, err := s.tenantRepo.GetByName(tenantName)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Check if user already exists in this tenant
|
|
if _, err := s.userRepo.GetByUsername(tenant.ID, username); err == nil {
|
|
return nil, "", ErrUserExists
|
|
}
|
|
if _, err := s.userRepo.GetByEmail(tenant.ID, email); err == nil {
|
|
return nil, "", ErrUserExists
|
|
}
|
|
|
|
// Hash password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Create user
|
|
user, err := s.userRepo.Create(tenant.ID, username, email, string(hashedPassword))
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Create session
|
|
token, err := s.generateToken()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Duration(s.sessionDuration) * time.Hour).Unix()
|
|
_, err = s.sessionRepo.Create(user.ID, token, expiresAt)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
ctx := &model.UserContext{
|
|
User: user,
|
|
Tenant: tenant,
|
|
}
|
|
|
|
return ctx, token, nil
|
|
}
|
|
|
|
// Login authenticates a user and returns a session token
|
|
func (s *AuthService) Login(tenantName, username, password string) (*model.UserContext, string, error) {
|
|
// Get tenant
|
|
tenant, err := s.tenantRepo.GetByName(tenantName)
|
|
if err != nil {
|
|
return nil, "", ErrInvalidCredentials
|
|
}
|
|
|
|
// Get user by username or email
|
|
var user *model.User
|
|
user, err = s.userRepo.GetByUsername(tenant.ID, username)
|
|
if err != nil {
|
|
// Try by email
|
|
user, err = s.userRepo.GetByEmail(tenant.ID, username)
|
|
if err != nil {
|
|
return nil, "", ErrInvalidCredentials
|
|
}
|
|
}
|
|
|
|
// Verify password
|
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
|
if err != nil {
|
|
return nil, "", ErrInvalidCredentials
|
|
}
|
|
|
|
// Create session
|
|
token, err := s.generateToken()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Duration(s.sessionDuration) * time.Hour).Unix()
|
|
_, err = s.sessionRepo.Create(user.ID, token, expiresAt)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
ctx := &model.UserContext{
|
|
User: user,
|
|
Tenant: tenant,
|
|
}
|
|
|
|
return ctx, token, nil
|
|
}
|
|
|
|
// ValidateSession validates a session token and returns user context
|
|
func (s *AuthService) ValidateSession(token string) (*model.UserContext, error) {
|
|
// Get session
|
|
session, err := s.sessionRepo.GetByToken(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get user
|
|
user, err := s.userRepo.GetByID(session.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get tenant
|
|
tenant, err := s.tenantRepo.GetByID(user.TenantID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx := &model.UserContext{
|
|
User: user,
|
|
Tenant: tenant,
|
|
}
|
|
|
|
return ctx, nil
|
|
}
|
|
|
|
// Logout invalidates a session
|
|
func (s *AuthService) Logout(token string) error {
|
|
session, err := s.sessionRepo.GetByToken(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.sessionRepo.Delete(session)
|
|
}
|
|
|
|
// generateToken generates a random session token
|
|
func (s *AuthService) generateToken() (string, error) {
|
|
bytes := make([]byte, 32)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", err
|
|
}
|
|
return hex.EncodeToString(bytes), nil
|
|
} |