109 lines
2.6 KiB
Go
109 lines
2.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Session configuration constants.
const (
	// SessionCookieName is the cookie name reserved for the session token.
	// It is not referenced in the code shown here; presumably HTTP handlers
	// elsewhere in the package use it — confirm.
	SessionCookieName = "session_token"

	// SessionDuration is the lifetime given to a new session: CreateSession
	// sets ExpiresAt to the creation time plus this value.
	SessionDuration = time.Hour * 24
)
|
|
|
|
// Session is a single row of the sessions table as written by CreateSession
// and read back by GetSession.
type Session struct {
	// ID is the value stored in the id column; CreateSession fills it with a
	// random UUID string.
	// UserID is the user the session belongs to.
	// Token is the secret produced by GenerateToken; GetSession and
	// DeleteSession look rows up by it.
	ID, UserID, Token string

	// ExpiresAt is the instant after which GetSession no longer returns the
	// session.
	// CreatedAt records when the session was created.
	ExpiresAt, CreatedAt time.Time
}
|
|
|
|
// SessionStore reads and writes sessions in the "sessions" table reachable
// through DB. The queries in this file expect the columns id, user_id, token,
// expires_at and created_at.
type SessionStore struct {
	// DB is the database handle every SessionStore query is executed on.
	DB *sql.DB
}
|
|
|
|
func NewSessionStore(db *sql.DB) *SessionStore {
|
|
return &SessionStore{DB: db}
|
|
}
|
|
|
|
// GenerateToken returns 32 bytes read from crypto/rand, encoded with padded,
// URL-safe base64 (a 44-character string). The only error it can return is
// the one reported by rand.Read.
func GenerateToken() (string, error) {
	var raw [32]byte
	if _, readErr := rand.Read(raw[:]); readErr != nil {
		return "", readErr
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
|
|
|
|
// CreateSession creates a new session for a user
|
|
func (s *SessionStore) CreateSession(ctx context.Context, userID string) (*Session, error) {
|
|
token, err := GenerateToken()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sessionID := uuid.New().String()
|
|
expiresAt := time.Now().Add(SessionDuration)
|
|
|
|
_, err = s.DB.ExecContext(ctx,
|
|
`INSERT INTO sessions (id, user_id, token, expires_at) VALUES (?, ?, ?, ?)`,
|
|
sessionID, userID, token, expiresAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Session{
|
|
ID: sessionID,
|
|
UserID: userID,
|
|
Token: token,
|
|
ExpiresAt: expiresAt,
|
|
CreatedAt: time.Now(),
|
|
}, nil
|
|
}
|
|
|
|
// GetSession retrieves a session by token
|
|
func (s *SessionStore) GetSession(ctx context.Context, token string) (*Session, error) {
|
|
var session Session
|
|
var expiresAtStr string
|
|
err := s.DB.QueryRowContext(ctx,
|
|
`SELECT id, user_id, token, expires_at, created_at FROM sessions WHERE token = ? AND expires_at > ?`,
|
|
token, time.Now()).Scan(&session.ID, &session.UserID, &session.Token, &expiresAtStr, &session.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session.ExpiresAt, err = time.Parse("2006-01-02 15:04:05", expiresAtStr)
|
|
if err != nil {
|
|
// Try with timezone
|
|
session.ExpiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// DeleteSession deletes a session by token
|
|
func (s *SessionStore) DeleteSession(ctx context.Context, token string) error {
|
|
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, token)
|
|
return err
|
|
}
|
|
|
|
// DeleteUserSessions deletes all sessions for a user
|
|
func (s *SessionStore) DeleteUserSessions(ctx context.Context, userID string) error {
|
|
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE user_id = ?`, userID)
|
|
return err
|
|
}
|
|
|
|
// CleanupExpiredSessions removes expired sessions
|
|
func (s *SessionStore) CleanupExpiredSessions(ctx context.Context) error {
|
|
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at < ?`, time.Now())
|
|
return err
|
|
}
|