Add RBAC support with roles, permissions, and session management. Implement middleware for authentication and CSRF protection. Enhance audit logging with additional fields. Update HTTP handlers and routes for new features.
This commit is contained in:
89
internal/auth/password.go
Normal file
89
internal/auth/password.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Argon2id parameters
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Iterations = 3
|
||||
argon2Parallelism = 2
|
||||
argon2SaltLength = 16
|
||||
argon2KeyLength = 32
|
||||
)
|
||||
|
||||
// HashPassword hashes a password using Argon2id
|
||||
func HashPassword(password string) (string, error) {
|
||||
// Generate a random salt
|
||||
salt := make([]byte, argon2SaltLength)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hash := argon2.IDKey([]byte(password), salt, argon2Iterations, argon2Memory, argon2Parallelism, argon2KeyLength)
|
||||
|
||||
// Encode the hash and salt
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return the encoded hash in the format: $argon2id$v=19$m=65536,t=3,p=2$salt$hash
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, argon2Memory, argon2Iterations, argon2Parallelism, b64Salt, b64Hash), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against a hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
// Parse the encoded hash
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, errors.New("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return false, errors.New("unsupported hash algorithm")
|
||||
}
|
||||
|
||||
// Parse version
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return false, errors.New("incompatible version")
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
var memory, iterations, parallelism int
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Decode salt and hash
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Compute the hash of the password
|
||||
otherHash := argon2.IDKey([]byte(password), salt, uint32(iterations), uint32(memory), uint8(parallelism), uint32(len(hash)))
|
||||
|
||||
// Compare hashes in constant time
|
||||
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
183
internal/auth/rbac.go
Normal file
183
internal/auth/rbac.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Permission struct {
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type RBACStore struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func NewRBACStore(db *sql.DB) *RBACStore {
|
||||
return &RBACStore{DB: db}
|
||||
}
|
||||
|
||||
// GetUserRoles retrieves all roles for a user
|
||||
func (s *RBACStore) GetUserRoles(ctx context.Context, userID string) ([]Role, error) {
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT r.id, r.name, r.description FROM roles r
|
||||
INNER JOIN user_roles ur ON r.id = ur.role_id
|
||||
WHERE ur.user_id = ?`,
|
||||
userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var roles []Role
|
||||
for rows.Next() {
|
||||
var role Role
|
||||
if err := rows.Scan(&role.ID, &role.Name, &role.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles, rows.Err()
|
||||
}
|
||||
|
||||
// GetRolePermissions retrieves all permissions for a role
|
||||
func (s *RBACStore) GetRolePermissions(ctx context.Context, roleID string) ([]Permission, error) {
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT p.id, p.name, p.description FROM permissions p
|
||||
INNER JOIN role_permissions rp ON p.id = rp.permission_id
|
||||
WHERE rp.role_id = ?`,
|
||||
roleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Name, &perm.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
return permissions, rows.Err()
|
||||
}
|
||||
|
||||
// GetUserPermissions retrieves all permissions for a user (through their roles)
|
||||
func (s *RBACStore) GetUserPermissions(ctx context.Context, userID string) ([]Permission, error) {
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT DISTINCT p.id, p.name, p.description FROM permissions p
|
||||
INNER JOIN role_permissions rp ON p.id = rp.permission_id
|
||||
INNER JOIN user_roles ur ON rp.role_id = ur.role_id
|
||||
WHERE ur.user_id = ?`,
|
||||
userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Name, &perm.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
return permissions, rows.Err()
|
||||
}
|
||||
|
||||
// UserHasPermission checks if a user has a specific permission
|
||||
func (s *RBACStore) UserHasPermission(ctx context.Context, userID, permission string) (bool, error) {
|
||||
var count int
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM permissions p
|
||||
INNER JOIN role_permissions rp ON p.id = rp.permission_id
|
||||
INNER JOIN user_roles ur ON rp.role_id = ur.role_id
|
||||
WHERE ur.user_id = ? AND p.name = ?`,
|
||||
userID, permission).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// AssignRoleToUser assigns a role to a user
|
||||
func (s *RBACStore) AssignRoleToUser(ctx context.Context, userID, roleID string) error {
|
||||
_, err := s.DB.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO user_roles (user_id, role_id) VALUES (?, ?)`,
|
||||
userID, roleID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveRoleFromUser removes a role from a user
|
||||
func (s *RBACStore) RemoveRoleFromUser(ctx context.Context, userID, roleID string) error {
|
||||
_, err := s.DB.ExecContext(ctx,
|
||||
`DELETE FROM user_roles WHERE user_id = ? AND role_id = ?`,
|
||||
userID, roleID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllRoles retrieves all roles
|
||||
func (s *RBACStore) GetAllRoles(ctx context.Context) ([]Role, error) {
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT id, name, description FROM roles ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var roles []Role
|
||||
for rows.Next() {
|
||||
var role Role
|
||||
if err := rows.Scan(&role.ID, &role.Name, &role.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles, rows.Err()
|
||||
}
|
||||
|
||||
// GetAllPermissions retrieves all permissions
|
||||
func (s *RBACStore) GetAllPermissions(ctx context.Context) ([]Permission, error) {
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT id, name, description FROM permissions ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Name, &perm.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
return permissions, rows.Err()
|
||||
}
|
||||
|
||||
// AssignPermissionToRole assigns a permission to a role
|
||||
func (s *RBACStore) AssignPermissionToRole(ctx context.Context, roleID, permissionID string) error {
|
||||
_, err := s.DB.ExecContext(ctx,
|
||||
`INSERT OR IGNORE INTO role_permissions (role_id, permission_id) VALUES (?, ?)`,
|
||||
roleID, permissionID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemovePermissionFromRole removes a permission from a role
|
||||
func (s *RBACStore) RemovePermissionFromRole(ctx context.Context, roleID, permissionID string) error {
|
||||
_, err := s.DB.ExecContext(ctx,
|
||||
`DELETE FROM role_permissions WHERE role_id = ? AND permission_id = ?`,
|
||||
roleID, permissionID)
|
||||
return err
|
||||
}
|
||||
108
internal/auth/session.go
Normal file
108
internal/auth/session.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionCookieName = "session_token"
|
||||
SessionDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string
|
||||
UserID string
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func NewSessionStore(db *sql.DB) *SessionStore {
|
||||
return &SessionStore{DB: db}
|
||||
}
|
||||
|
||||
// GenerateToken generates a secure random token
|
||||
func GenerateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// CreateSession creates a new session for a user
|
||||
func (s *SessionStore) CreateSession(ctx context.Context, userID string) (*Session, error) {
|
||||
token, err := GenerateToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionID := uuid.New().String()
|
||||
expiresAt := time.Now().Add(SessionDuration)
|
||||
|
||||
_, err = s.DB.ExecContext(ctx,
|
||||
`INSERT INTO sessions (id, user_id, token, expires_at) VALUES (?, ?, ?, ?)`,
|
||||
sessionID, userID, token, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Session{
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by token
|
||||
func (s *SessionStore) GetSession(ctx context.Context, token string) (*Session, error) {
|
||||
var session Session
|
||||
var expiresAtStr string
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
`SELECT id, user_id, token, expires_at, created_at FROM sessions WHERE token = ? AND expires_at > ?`,
|
||||
token, time.Now()).Scan(&session.ID, &session.UserID, &session.Token, &expiresAtStr, &session.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.ExpiresAt, err = time.Parse("2006-01-02 15:04:05", expiresAtStr)
|
||||
if err != nil {
|
||||
// Try with timezone
|
||||
session.ExpiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session by token
|
||||
func (s *SessionStore) DeleteSession(ctx context.Context, token string) error {
|
||||
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, token)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUserSessions deletes all sessions for a user
|
||||
func (s *SessionStore) DeleteUserSessions(ctx context.Context, userID string) error {
|
||||
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE user_id = ?`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes expired sessions
|
||||
func (s *SessionStore) CleanupExpiredSessions(ctx context.Context) error {
|
||||
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at < ?`, time.Now())
|
||||
return err
|
||||
}
|
||||
102
internal/auth/user.go
Normal file
102
internal/auth/user.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Username string
|
||||
PasswordHash string
|
||||
Role string // Legacy field, kept for backward compatibility
|
||||
CreatedAt string
|
||||
}
|
||||
|
||||
type UserStore struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func NewUserStore(db *sql.DB) *UserStore {
|
||||
return &UserStore{DB: db}
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves a user by username
|
||||
func (s *UserStore) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
var user User
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
`SELECT id, username, password_hash, role, created_at FROM users WHERE username = ?`,
|
||||
username).Scan(&user.ID, &user.Username, &user.PasswordHash, &user.Role, &user.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (s *UserStore) GetUserByID(ctx context.Context, userID string) (*User, error) {
|
||||
var user User
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
`SELECT id, username, password_hash, role, created_at FROM users WHERE id = ?`,
|
||||
userID).Scan(&user.ID, &user.Username, &user.PasswordHash, &user.Role, &user.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user
|
||||
func (s *UserStore) CreateUser(ctx context.Context, username, password string) (*User, error) {
|
||||
passwordHash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID := username // Using username as ID for simplicity, could use UUID
|
||||
_, err = s.DB.ExecContext(ctx,
|
||||
`INSERT INTO users (id, username, password_hash) VALUES (?, ?, ?)`,
|
||||
userID, username, passwordHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// UpdatePassword updates a user's password
|
||||
func (s *UserStore) UpdatePassword(ctx context.Context, userID, newPassword string) error {
|
||||
passwordHash, err := HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.DB.ExecContext(ctx,
|
||||
`UPDATE users SET password_hash = ? WHERE id = ?`,
|
||||
passwordHash, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Authenticate verifies username and password
|
||||
func (s *UserStore) Authenticate(ctx context.Context, username, password string) (*User, error) {
|
||||
user, err := s.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
valid, err := VerifyPassword(password, user.PasswordHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("invalid password")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
Reference in New Issue
Block a user