add feature license management
This commit is contained in:
@@ -13,24 +13,30 @@ import (
|
||||
// authMiddleware validates JWT tokens and sets user context
|
||||
func authMiddleware(authHandler *auth.Handler) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract token from Authorization header
|
||||
var token string
|
||||
|
||||
// Try to extract token from Authorization header first
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
if authHeader != "" {
|
||||
// Parse Bearer token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
token = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// If no token from header, try query parameter (for WebSocket)
|
||||
if token == "" {
|
||||
token = c.Query("token")
|
||||
}
|
||||
|
||||
// If still no token, return error
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Bearer token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid authorization header format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
// Validate token and get user
|
||||
user, err := authHandler.ValidateToken(token)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/atlasos/calypso/internal/iam"
|
||||
"github.com/atlasos/calypso/internal/monitoring"
|
||||
"github.com/atlasos/calypso/internal/scst"
|
||||
"github.com/atlasos/calypso/internal/shares"
|
||||
"github.com/atlasos/calypso/internal/storage"
|
||||
"github.com/atlasos/calypso/internal/system"
|
||||
"github.com/atlasos/calypso/internal/tape_physical"
|
||||
@@ -198,6 +199,18 @@ func NewRouter(cfg *config.Config, db *database.DB, log *logger.Logger) *gin.Eng
|
||||
storageGroup.GET("/zfs/arc/stats", storageHandler.GetARCStats)
|
||||
}
|
||||
|
||||
// Shares (CIFS/NFS)
|
||||
sharesHandler := shares.NewHandler(db, log)
|
||||
sharesGroup := protected.Group("/shares")
|
||||
sharesGroup.Use(requirePermission("storage", "read"))
|
||||
{
|
||||
sharesGroup.GET("", sharesHandler.ListShares)
|
||||
sharesGroup.GET("/:id", sharesHandler.GetShare)
|
||||
sharesGroup.POST("", requirePermission("storage", "write"), sharesHandler.CreateShare)
|
||||
sharesGroup.PUT("/:id", requirePermission("storage", "write"), sharesHandler.UpdateShare)
|
||||
sharesGroup.DELETE("/:id", requirePermission("storage", "write"), sharesHandler.DeleteShare)
|
||||
}
|
||||
|
||||
// SCST
|
||||
scstHandler := scst.NewHandler(db, log)
|
||||
scstGroup := protected.Group("/scst")
|
||||
@@ -232,6 +245,9 @@ func NewRouter(cfg *config.Config, db *database.DB, log *logger.Logger) *gin.Eng
|
||||
scstGroup.PUT("/initiator-groups/:id", requirePermission("iscsi", "write"), scstHandler.UpdateInitiatorGroup)
|
||||
scstGroup.DELETE("/initiator-groups/:id", requirePermission("iscsi", "write"), scstHandler.DeleteInitiatorGroup)
|
||||
scstGroup.POST("/initiator-groups/:id/initiators", requirePermission("iscsi", "write"), scstHandler.AddInitiatorToGroup)
|
||||
// Config file management
|
||||
scstGroup.GET("/config/file", requirePermission("iscsi", "read"), scstHandler.GetConfigFile)
|
||||
scstGroup.PUT("/config/file", requirePermission("iscsi", "write"), scstHandler.UpdateConfigFile)
|
||||
}
|
||||
|
||||
// Physical Tape Libraries
|
||||
@@ -295,6 +311,7 @@ func NewRouter(cfg *config.Config, db *database.DB, log *logger.Logger) *gin.Eng
|
||||
systemGroup.PUT("/interfaces/:name", systemHandler.UpdateNetworkInterface)
|
||||
systemGroup.GET("/ntp", systemHandler.GetNTPSettings)
|
||||
systemGroup.POST("/ntp", systemHandler.SaveNTPSettings)
|
||||
systemGroup.POST("/execute", requirePermission("system", "write"), systemHandler.ExecuteCommand)
|
||||
}
|
||||
|
||||
// IAM routes - GetUser can be accessed by user viewing own profile or admin
|
||||
|
||||
@@ -745,3 +745,49 @@ func (h *Handler) ListAllInitiatorGroups(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"groups": groups})
|
||||
}
|
||||
|
||||
// GetConfigFile reads the SCST configuration file content
|
||||
func (h *Handler) GetConfigFile(c *gin.Context) {
|
||||
configPath := c.DefaultQuery("path", "/etc/scst.conf")
|
||||
|
||||
content, err := h.service.ReadConfigFile(c.Request.Context(), configPath)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to read config file", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"content": content,
|
||||
"path": configPath,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConfigFile writes content to SCST configuration file
|
||||
func (h *Handler) UpdateConfigFile(c *gin.Context) {
|
||||
var req struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
configPath := req.Path
|
||||
if configPath == "" {
|
||||
configPath = "/etc/scst.conf"
|
||||
}
|
||||
|
||||
if err := h.service.WriteConfigFile(c.Request.Context(), configPath, req.Content); err != nil {
|
||||
h.logger.Error("Failed to write config file", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Configuration file updated successfully",
|
||||
"path": configPath,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1830,6 +1830,59 @@ func (s *Service) WriteConfig(ctx context.Context, configPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadConfigFile reads the SCST configuration file content
|
||||
func (s *Service) ReadConfigFile(ctx context.Context, configPath string) (string, error) {
|
||||
// First, write current config to temp file to get the actual config
|
||||
tempPath := "/tmp/scst_config_read.conf"
|
||||
cmd := exec.CommandContext(ctx, "sudo", "scstadmin", "-write_config", tempPath)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write SCST config: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
// Read the config file
|
||||
configData, err := os.ReadFile(tempPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
return string(configData), nil
|
||||
}
|
||||
|
||||
// WriteConfigFile writes content to SCST configuration file
|
||||
func (s *Service) WriteConfigFile(ctx context.Context, configPath string, content string) error {
|
||||
// Write content to temp file first
|
||||
tempPath := "/tmp/scst_config_write.conf"
|
||||
if err := os.WriteFile(tempPath, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp config file: %w", err)
|
||||
}
|
||||
|
||||
// Use scstadmin to load the config (this validates and applies it)
|
||||
cmd := exec.CommandContext(ctx, "sudo", "scstadmin", "-config", tempPath)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load SCST config: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
// Write to the actual config path using sudo
|
||||
if configPath != tempPath {
|
||||
// Use sudo cp to copy temp file to actual config path
|
||||
cpCmd := exec.CommandContext(ctx, "sudo", "cp", tempPath, configPath)
|
||||
cpOutput, cpErr := cpCmd.CombinedOutput()
|
||||
if cpErr != nil {
|
||||
return fmt.Errorf("failed to copy config file: %s: %w", string(cpOutput), cpErr)
|
||||
}
|
||||
// Set proper permissions
|
||||
chmodCmd := exec.CommandContext(ctx, "sudo", "chmod", "644", configPath)
|
||||
if chmodErr := chmodCmd.Run(); chmodErr != nil {
|
||||
s.logger.Warn("Failed to set config file permissions", "error", chmodErr)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("SCST configuration file written", "path", configPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandlerInfo represents SCST handler information
|
||||
type HandlerInfo struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
147
backend/internal/shares/handler.go
Normal file
147
backend/internal/shares/handler.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package shares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/atlasos/calypso/internal/common/database"
|
||||
"github.com/atlasos/calypso/internal/common/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Handler handles Shares-related API requests
|
||||
type Handler struct {
|
||||
service *Service
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewHandler creates a new Shares handler
|
||||
func NewHandler(db *database.DB, log *logger.Logger) *Handler {
|
||||
return &Handler{
|
||||
service: NewService(db, log),
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// ListShares lists all shares
|
||||
func (h *Handler) ListShares(c *gin.Context) {
|
||||
shares, err := h.service.ListShares(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to list shares", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list shares"})
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return an empty array instead of null
|
||||
if shares == nil {
|
||||
shares = []*Share{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"shares": shares})
|
||||
}
|
||||
|
||||
// GetShare retrieves a share by ID
|
||||
func (h *Handler) GetShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
share, err := h.service.GetShare(c.Request.Context(), shareID)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to get share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get share"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, share)
|
||||
}
|
||||
|
||||
// CreateShare creates a new share
|
||||
func (h *Handler) CreateShare(c *gin.Context) {
|
||||
var req CreateShareRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("Invalid create share request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
validate := validator.New()
|
||||
if err := validate.Struct(req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "validation failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
share, err := h.service.CreateShare(c.Request.Context(), &req, userID.(string))
|
||||
if err != nil {
|
||||
if err.Error() == "dataset not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "dataset not found"})
|
||||
return
|
||||
}
|
||||
if err.Error() == "only filesystem datasets can be shared (not volumes)" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err.Error() == "at least one protocol (NFS or SMB) must be enabled" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to create share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, share)
|
||||
}
|
||||
|
||||
// UpdateShare updates an existing share
|
||||
func (h *Handler) UpdateShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
var req UpdateShareRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("Invalid update share request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
share, err := h.service.UpdateShare(c.Request.Context(), shareID, &req)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to update share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, share)
|
||||
}
|
||||
|
||||
// DeleteShare deletes a share
|
||||
func (h *Handler) DeleteShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
err := h.service.DeleteShare(c.Request.Context(), shareID)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to delete share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "share deleted successfully"})
|
||||
}
|
||||
806
backend/internal/shares/service.go
Normal file
806
backend/internal/shares/service.go
Normal file
@@ -0,0 +1,806 @@
|
||||
package shares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/atlasos/calypso/internal/common/database"
|
||||
"github.com/atlasos/calypso/internal/common/logger"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Service handles Shares (CIFS/NFS) operations
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewService creates a new Shares service
|
||||
func NewService(db *database.DB, log *logger.Logger) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Share represents a filesystem share (NFS/SMB)
|
||||
type Share struct {
|
||||
ID string `json:"id"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
DatasetName string `json:"dataset_name"`
|
||||
MountPoint string `json:"mount_point"`
|
||||
ShareType string `json:"share_type"` // 'nfs', 'smb', 'both'
|
||||
NFSEnabled bool `json:"nfs_enabled"`
|
||||
NFSOptions string `json:"nfs_options,omitempty"`
|
||||
NFSClients []string `json:"nfs_clients,omitempty"`
|
||||
SMBEnabled bool `json:"smb_enabled"`
|
||||
SMBShareName string `json:"smb_share_name,omitempty"`
|
||||
SMBPath string `json:"smb_path,omitempty"`
|
||||
SMBComment string `json:"smb_comment,omitempty"`
|
||||
SMBGuestOK bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly bool `json:"smb_read_only"`
|
||||
SMBBrowseable bool `json:"smb_browseable"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
}
|
||||
|
||||
// ListShares lists all shares
|
||||
func (s *Service) ListShares(ctx context.Context) ([]*Share, error) {
|
||||
query := `
|
||||
SELECT
|
||||
zs.id, zs.dataset_id, zd.name as dataset_name, zd.mount_point,
|
||||
zs.share_type, zs.nfs_enabled, zs.nfs_options, zs.nfs_clients,
|
||||
zs.smb_enabled, zs.smb_share_name, zs.smb_path, zs.smb_comment,
|
||||
zs.smb_guest_ok, zs.smb_read_only, zs.smb_browseable,
|
||||
zs.is_active, zs.created_at, zs.updated_at, zs.created_by
|
||||
FROM zfs_shares zs
|
||||
JOIN zfs_datasets zd ON zs.dataset_id = zd.id
|
||||
ORDER BY zd.name
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not exist") {
|
||||
s.logger.Warn("zfs_shares table does not exist, returning empty list")
|
||||
return []*Share{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list shares: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var shares []*Share
|
||||
for rows.Next() {
|
||||
var share Share
|
||||
var mountPoint sql.NullString
|
||||
var nfsOptions sql.NullString
|
||||
var smbShareName sql.NullString
|
||||
var smbPath sql.NullString
|
||||
var smbComment sql.NullString
|
||||
var nfsClients []string
|
||||
|
||||
err := rows.Scan(
|
||||
&share.ID, &share.DatasetID, &share.DatasetName, &mountPoint,
|
||||
&share.ShareType, &share.NFSEnabled, &nfsOptions, pq.Array(&nfsClients),
|
||||
&share.SMBEnabled, &smbShareName, &smbPath, &smbComment,
|
||||
&share.SMBGuestOK, &share.SMBReadOnly, &share.SMBBrowseable,
|
||||
&share.IsActive, &share.CreatedAt, &share.UpdatedAt, &share.CreatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to scan share row", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
share.NFSClients = nfsClients
|
||||
|
||||
if mountPoint.Valid {
|
||||
share.MountPoint = mountPoint.String
|
||||
}
|
||||
if nfsOptions.Valid {
|
||||
share.NFSOptions = nfsOptions.String
|
||||
}
|
||||
if smbShareName.Valid {
|
||||
share.SMBShareName = smbShareName.String
|
||||
}
|
||||
if smbPath.Valid {
|
||||
share.SMBPath = smbPath.String
|
||||
}
|
||||
if smbComment.Valid {
|
||||
share.SMBComment = smbComment.String
|
||||
}
|
||||
|
||||
shares = append(shares, &share)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating share rows: %w", err)
|
||||
}
|
||||
|
||||
return shares, nil
|
||||
}
|
||||
|
||||
// GetShare retrieves a share by ID
|
||||
func (s *Service) GetShare(ctx context.Context, shareID string) (*Share, error) {
|
||||
query := `
|
||||
SELECT
|
||||
zs.id, zs.dataset_id, zd.name as dataset_name, zd.mount_point,
|
||||
zs.share_type, zs.nfs_enabled, zs.nfs_options, zs.nfs_clients,
|
||||
zs.smb_enabled, zs.smb_share_name, zs.smb_path, zs.smb_comment,
|
||||
zs.smb_guest_ok, zs.smb_read_only, zs.smb_browseable,
|
||||
zs.is_active, zs.created_at, zs.updated_at, zs.created_by
|
||||
FROM zfs_shares zs
|
||||
JOIN zfs_datasets zd ON zs.dataset_id = zd.id
|
||||
WHERE zs.id = $1
|
||||
`
|
||||
|
||||
var share Share
|
||||
var mountPoint sql.NullString
|
||||
var nfsOptions sql.NullString
|
||||
var smbShareName sql.NullString
|
||||
var smbPath sql.NullString
|
||||
var smbComment sql.NullString
|
||||
var nfsClients []string
|
||||
|
||||
err := s.db.QueryRowContext(ctx, query, shareID).Scan(
|
||||
&share.ID, &share.DatasetID, &share.DatasetName, &mountPoint,
|
||||
&share.ShareType, &share.NFSEnabled, &nfsOptions, pq.Array(&nfsClients),
|
||||
&share.SMBEnabled, &smbShareName, &smbPath, &smbComment,
|
||||
&share.SMBGuestOK, &share.SMBReadOnly, &share.SMBBrowseable,
|
||||
&share.IsActive, &share.CreatedAt, &share.UpdatedAt, &share.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("share not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get share: %w", err)
|
||||
}
|
||||
|
||||
share.NFSClients = nfsClients
|
||||
|
||||
if mountPoint.Valid {
|
||||
share.MountPoint = mountPoint.String
|
||||
}
|
||||
if nfsOptions.Valid {
|
||||
share.NFSOptions = nfsOptions.String
|
||||
}
|
||||
if smbShareName.Valid {
|
||||
share.SMBShareName = smbShareName.String
|
||||
}
|
||||
if smbPath.Valid {
|
||||
share.SMBPath = smbPath.String
|
||||
}
|
||||
if smbComment.Valid {
|
||||
share.SMBComment = smbComment.String
|
||||
}
|
||||
|
||||
return &share, nil
|
||||
}
|
||||
|
||||
// CreateShareRequest represents a share creation request
|
||||
type CreateShareRequest struct {
|
||||
DatasetID string `json:"dataset_id" binding:"required"`
|
||||
NFSEnabled bool `json:"nfs_enabled"`
|
||||
NFSOptions string `json:"nfs_options"`
|
||||
NFSClients []string `json:"nfs_clients"`
|
||||
SMBEnabled bool `json:"smb_enabled"`
|
||||
SMBShareName string `json:"smb_share_name"`
|
||||
SMBPath string `json:"smb_path"`
|
||||
SMBComment string `json:"smb_comment"`
|
||||
SMBGuestOK bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly bool `json:"smb_read_only"`
|
||||
SMBBrowseable bool `json:"smb_browseable"`
|
||||
}
|
||||
|
||||
// CreateShare creates a new share
|
||||
func (s *Service) CreateShare(ctx context.Context, req *CreateShareRequest, userID string) (*Share, error) {
|
||||
// Validate dataset exists and is a filesystem (not volume)
|
||||
// req.DatasetID can be either UUID or dataset name
|
||||
var datasetID, datasetType, datasetName, mountPoint string
|
||||
var mountPointNull sql.NullString
|
||||
|
||||
// Try to find by ID first (UUID)
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
"SELECT id, type, name, mount_point FROM zfs_datasets WHERE id = $1",
|
||||
req.DatasetID,
|
||||
).Scan(&datasetID, &datasetType, &datasetName, &mountPointNull)
|
||||
|
||||
// If not found by ID, try by name
|
||||
if err == sql.ErrNoRows {
|
||||
err = s.db.QueryRowContext(ctx,
|
||||
"SELECT id, type, name, mount_point FROM zfs_datasets WHERE name = $1",
|
||||
req.DatasetID,
|
||||
).Scan(&datasetID, &datasetType, &datasetName, &mountPointNull)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("dataset not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to validate dataset: %w", err)
|
||||
}
|
||||
|
||||
if mountPointNull.Valid {
|
||||
mountPoint = mountPointNull.String
|
||||
} else {
|
||||
mountPoint = "none"
|
||||
}
|
||||
|
||||
if datasetType != "filesystem" {
|
||||
return nil, fmt.Errorf("only filesystem datasets can be shared (not volumes)")
|
||||
}
|
||||
|
||||
// Determine share type
|
||||
shareType := "none"
|
||||
if req.NFSEnabled && req.SMBEnabled {
|
||||
shareType = "both"
|
||||
} else if req.NFSEnabled {
|
||||
shareType = "nfs"
|
||||
} else if req.SMBEnabled {
|
||||
shareType = "smb"
|
||||
} else {
|
||||
return nil, fmt.Errorf("at least one protocol (NFS or SMB) must be enabled")
|
||||
}
|
||||
|
||||
// Set default NFS options if not provided
|
||||
nfsOptions := req.NFSOptions
|
||||
if nfsOptions == "" {
|
||||
nfsOptions = "rw,sync,no_subtree_check"
|
||||
}
|
||||
|
||||
// Set default SMB share name if not provided
|
||||
smbShareName := req.SMBShareName
|
||||
if smbShareName == "" {
|
||||
// Extract dataset name from full path (e.g., "pool/dataset" -> "dataset")
|
||||
parts := strings.Split(datasetName, "/")
|
||||
smbShareName = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
// Set SMB path (use mount_point if available, otherwise use dataset name)
|
||||
smbPath := req.SMBPath
|
||||
if smbPath == "" {
|
||||
if mountPoint != "" && mountPoint != "none" {
|
||||
smbPath = mountPoint
|
||||
} else {
|
||||
smbPath = fmt.Sprintf("/mnt/%s", strings.ReplaceAll(datasetName, "/", "_"))
|
||||
}
|
||||
}
|
||||
|
||||
// Insert into database
|
||||
query := `
|
||||
INSERT INTO zfs_shares (
|
||||
dataset_id, share_type, nfs_enabled, nfs_options, nfs_clients,
|
||||
smb_enabled, smb_share_name, smb_path, smb_comment,
|
||||
smb_guest_ok, smb_read_only, smb_browseable, is_active, created_by
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var shareID string
|
||||
var createdAt, updatedAt time.Time
|
||||
|
||||
// Handle nfs_clients array - use empty array if nil
|
||||
nfsClients := req.NFSClients
|
||||
if nfsClients == nil {
|
||||
nfsClients = []string{}
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, query,
|
||||
datasetID, shareType, req.NFSEnabled, nfsOptions, pq.Array(nfsClients),
|
||||
req.SMBEnabled, smbShareName, smbPath, req.SMBComment,
|
||||
req.SMBGuestOK, req.SMBReadOnly, req.SMBBrowseable, true, userID,
|
||||
).Scan(&shareID, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create share: %w", err)
|
||||
}
|
||||
|
||||
// Apply NFS export if enabled
|
||||
if req.NFSEnabled {
|
||||
if err := s.applyNFSExport(ctx, mountPoint, nfsOptions, req.NFSClients); err != nil {
|
||||
s.logger.Error("Failed to apply NFS export", "error", err, "share_id", shareID)
|
||||
// Don't fail the creation, but log the error
|
||||
}
|
||||
}
|
||||
|
||||
// Apply SMB share if enabled
|
||||
if req.SMBEnabled {
|
||||
if err := s.applySMBShare(ctx, smbShareName, smbPath, req.SMBComment, req.SMBGuestOK, req.SMBReadOnly, req.SMBBrowseable); err != nil {
|
||||
s.logger.Error("Failed to apply SMB share", "error", err, "share_id", shareID)
|
||||
// Don't fail the creation, but log the error
|
||||
}
|
||||
}
|
||||
|
||||
// Return the created share
|
||||
return s.GetShare(ctx, shareID)
|
||||
}
|
||||
|
||||
// UpdateShareRequest represents a share update request
|
||||
type UpdateShareRequest struct {
|
||||
NFSEnabled *bool `json:"nfs_enabled"`
|
||||
NFSOptions *string `json:"nfs_options"`
|
||||
NFSClients *[]string `json:"nfs_clients"`
|
||||
SMBEnabled *bool `json:"smb_enabled"`
|
||||
SMBShareName *string `json:"smb_share_name"`
|
||||
SMBComment *string `json:"smb_comment"`
|
||||
SMBGuestOK *bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly *bool `json:"smb_read_only"`
|
||||
SMBBrowseable *bool `json:"smb_browseable"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// UpdateShare updates an existing share
|
||||
func (s *Service) UpdateShare(ctx context.Context, shareID string, req *UpdateShareRequest) (*Share, error) {
|
||||
// Get current share
|
||||
share, err := s.GetShare(ctx, shareID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build update query dynamically
|
||||
updates := []string{}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if req.NFSEnabled != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_enabled = $%d", argIndex))
|
||||
args = append(args, *req.NFSEnabled)
|
||||
argIndex++
|
||||
}
|
||||
if req.NFSOptions != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_options = $%d", argIndex))
|
||||
args = append(args, *req.NFSOptions)
|
||||
argIndex++
|
||||
}
|
||||
if req.NFSClients != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_clients = $%d", argIndex))
|
||||
args = append(args, pq.Array(*req.NFSClients))
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBEnabled != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_enabled = $%d", argIndex))
|
||||
args = append(args, *req.SMBEnabled)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBShareName != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_share_name = $%d", argIndex))
|
||||
args = append(args, *req.SMBShareName)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBComment != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_comment = $%d", argIndex))
|
||||
args = append(args, *req.SMBComment)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBGuestOK != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_guest_ok = $%d", argIndex))
|
||||
args = append(args, *req.SMBGuestOK)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBReadOnly != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_read_only = $%d", argIndex))
|
||||
args = append(args, *req.SMBReadOnly)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBBrowseable != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_browseable = $%d", argIndex))
|
||||
args = append(args, *req.SMBBrowseable)
|
||||
argIndex++
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
updates = append(updates, fmt.Sprintf("is_active = $%d", argIndex))
|
||||
args = append(args, *req.IsActive)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(updates) == 0 {
|
||||
return share, nil // No changes
|
||||
}
|
||||
|
||||
// Update share_type based on enabled protocols
|
||||
nfsEnabled := share.NFSEnabled
|
||||
smbEnabled := share.SMBEnabled
|
||||
if req.NFSEnabled != nil {
|
||||
nfsEnabled = *req.NFSEnabled
|
||||
}
|
||||
if req.SMBEnabled != nil {
|
||||
smbEnabled = *req.SMBEnabled
|
||||
}
|
||||
|
||||
shareType := "none"
|
||||
if nfsEnabled && smbEnabled {
|
||||
shareType = "both"
|
||||
} else if nfsEnabled {
|
||||
shareType = "nfs"
|
||||
} else if smbEnabled {
|
||||
shareType = "smb"
|
||||
}
|
||||
|
||||
updates = append(updates, fmt.Sprintf("share_type = $%d", argIndex))
|
||||
args = append(args, shareType)
|
||||
argIndex++
|
||||
|
||||
updates = append(updates, fmt.Sprintf("updated_at = NOW()"))
|
||||
args = append(args, shareID)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE zfs_shares
|
||||
SET %s
|
||||
WHERE id = $%d
|
||||
`, strings.Join(updates, ", "), argIndex)
|
||||
|
||||
_, err = s.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update share: %w", err)
|
||||
}
|
||||
|
||||
// Re-apply NFS export if NFS is enabled
|
||||
if nfsEnabled {
|
||||
nfsOptions := share.NFSOptions
|
||||
if req.NFSOptions != nil {
|
||||
nfsOptions = *req.NFSOptions
|
||||
}
|
||||
nfsClients := share.NFSClients
|
||||
if req.NFSClients != nil {
|
||||
nfsClients = *req.NFSClients
|
||||
}
|
||||
if err := s.applyNFSExport(ctx, share.MountPoint, nfsOptions, nfsClients); err != nil {
|
||||
s.logger.Error("Failed to apply NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
} else {
|
||||
// Remove NFS export if disabled
|
||||
if err := s.removeNFSExport(ctx, share.MountPoint); err != nil {
|
||||
s.logger.Error("Failed to remove NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-apply SMB share if SMB is enabled
|
||||
if smbEnabled {
|
||||
smbShareName := share.SMBShareName
|
||||
if req.SMBShareName != nil {
|
||||
smbShareName = *req.SMBShareName
|
||||
}
|
||||
smbPath := share.SMBPath
|
||||
smbComment := share.SMBComment
|
||||
if req.SMBComment != nil {
|
||||
smbComment = *req.SMBComment
|
||||
}
|
||||
smbGuestOK := share.SMBGuestOK
|
||||
if req.SMBGuestOK != nil {
|
||||
smbGuestOK = *req.SMBGuestOK
|
||||
}
|
||||
smbReadOnly := share.SMBReadOnly
|
||||
if req.SMBReadOnly != nil {
|
||||
smbReadOnly = *req.SMBReadOnly
|
||||
}
|
||||
smbBrowseable := share.SMBBrowseable
|
||||
if req.SMBBrowseable != nil {
|
||||
smbBrowseable = *req.SMBBrowseable
|
||||
}
|
||||
if err := s.applySMBShare(ctx, smbShareName, smbPath, smbComment, smbGuestOK, smbReadOnly, smbBrowseable); err != nil {
|
||||
s.logger.Error("Failed to apply SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
} else {
|
||||
// Remove SMB share if disabled
|
||||
if err := s.removeSMBShare(ctx, share.SMBShareName); err != nil {
|
||||
s.logger.Error("Failed to remove SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetShare(ctx, shareID)
|
||||
}
|
||||
|
||||
// DeleteShare deletes a share
|
||||
func (s *Service) DeleteShare(ctx context.Context, shareID string) error {
|
||||
// Get share to get mount point and share name
|
||||
share, err := s.GetShare(ctx, shareID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove NFS export
|
||||
if share.NFSEnabled {
|
||||
if err := s.removeNFSExport(ctx, share.MountPoint); err != nil {
|
||||
s.logger.Error("Failed to remove NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove SMB share
|
||||
if share.SMBEnabled {
|
||||
if err := s.removeSMBShare(ctx, share.SMBShareName); err != nil {
|
||||
s.logger.Error("Failed to remove SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
_, err = s.db.ExecContext(ctx, "DELETE FROM zfs_shares WHERE id = $1", shareID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete share: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNFSExport adds or updates an NFS export in /etc/exports
|
||||
func (s *Service) applyNFSExport(ctx context.Context, mountPoint, options string, clients []string) error {
|
||||
if mountPoint == "" || mountPoint == "none" {
|
||||
return fmt.Errorf("mount point is required for NFS export")
|
||||
}
|
||||
|
||||
// Build client list (default to * if empty)
|
||||
clientList := "*"
|
||||
if len(clients) > 0 {
|
||||
clientList = strings.Join(clients, " ")
|
||||
}
|
||||
|
||||
// Build export line
|
||||
exportLine := fmt.Sprintf("%s %s(%s)", mountPoint, clientList, options)
|
||||
|
||||
// Read current /etc/exports
|
||||
exportsPath := "/etc/exports"
|
||||
exportsContent, err := os.ReadFile(exportsPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to read exports file: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(exportsContent), "\n")
|
||||
var newLines []string
|
||||
found := false
|
||||
|
||||
// Check if this mount point already exists
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this line is for our mount point
|
||||
if strings.HasPrefix(line, mountPoint+" ") {
|
||||
newLines = append(newLines, exportLine)
|
||||
found = true
|
||||
} else {
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// Add if not found
|
||||
if !found {
|
||||
newLines = append(newLines, exportLine)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n") + "\n"
|
||||
if err := os.WriteFile(exportsPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write exports file: %w", err)
|
||||
}
|
||||
|
||||
// Apply exports
|
||||
cmd := exec.CommandContext(ctx, "sudo", "exportfs", "-ra")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to apply exports: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
s.logger.Info("NFS export applied", "mount_point", mountPoint, "clients", clientList)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNFSExport removes an NFS export from /etc/exports
|
||||
func (s *Service) removeNFSExport(ctx context.Context, mountPoint string) error {
|
||||
if mountPoint == "" || mountPoint == "none" {
|
||||
return nil // Nothing to remove
|
||||
}
|
||||
|
||||
exportsPath := "/etc/exports"
|
||||
exportsContent, err := os.ReadFile(exportsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to remove
|
||||
}
|
||||
return fmt.Errorf("failed to read exports file: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(exportsContent), "\n")
|
||||
var newLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip lines for this mount point
|
||||
if strings.HasPrefix(line, mountPoint+" ") {
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if newContent != "" && !strings.HasSuffix(newContent, "\n") {
|
||||
newContent += "\n"
|
||||
}
|
||||
if err := os.WriteFile(exportsPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write exports file: %w", err)
|
||||
}
|
||||
|
||||
// Apply exports
|
||||
cmd := exec.CommandContext(ctx, "sudo", "exportfs", "-ra")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to apply exports: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
s.logger.Info("NFS export removed", "mount_point", mountPoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applySMBShare adds or updates an SMB share in /etc/samba/smb.conf
|
||||
func (s *Service) applySMBShare(ctx context.Context, shareName, path, comment string, guestOK, readOnly, browseable bool) error {
|
||||
if shareName == "" {
|
||||
return fmt.Errorf("SMB share name is required")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("SMB path is required")
|
||||
}
|
||||
|
||||
smbConfPath := "/etc/samba/smb.conf"
|
||||
smbContent, err := os.ReadFile(smbConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Parse and update smb.conf
|
||||
lines := strings.Split(string(smbContent), "\n")
|
||||
var newLines []string
|
||||
inShare := false
|
||||
shareStart := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Check if we're entering our share section
|
||||
if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
|
||||
sectionName := trimmed[1 : len(trimmed)-1]
|
||||
if sectionName == shareName {
|
||||
inShare = true
|
||||
shareStart = i
|
||||
continue
|
||||
} else if inShare {
|
||||
// We've left our share section, insert the share config here
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
inShare = false
|
||||
}
|
||||
}
|
||||
|
||||
if inShare {
|
||||
// Skip lines until we find the next section or end of file
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// If we were still in the share at the end, add it
|
||||
if inShare {
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
} else if shareStart == -1 {
|
||||
// Share doesn't exist, add it at the end
|
||||
newLines = append(newLines, "")
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if err := os.WriteFile(smbConfPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Reload Samba
|
||||
cmd := exec.CommandContext(ctx, "sudo", "systemctl", "reload", "smbd")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// Try restart if reload fails
|
||||
cmd = exec.CommandContext(ctx, "sudo", "systemctl", "restart", "smbd")
|
||||
if output2, err2 := cmd.CombinedOutput(); err2 != nil {
|
||||
return fmt.Errorf("failed to reload/restart smbd: %s / %s: %w", string(output), string(output2), err2)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("SMB share applied", "share_name", shareName, "path", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildSMBShareConfig builds the SMB share configuration block
|
||||
func (s *Service) buildSMBShareConfig(shareName, path, comment string, guestOK, readOnly, browseable bool) string {
|
||||
var config []string
|
||||
config = append(config, fmt.Sprintf("[%s]", shareName))
|
||||
if comment != "" {
|
||||
config = append(config, fmt.Sprintf(" comment = %s", comment))
|
||||
}
|
||||
config = append(config, fmt.Sprintf(" path = %s", path))
|
||||
if guestOK {
|
||||
config = append(config, " guest ok = yes")
|
||||
} else {
|
||||
config = append(config, " guest ok = no")
|
||||
}
|
||||
if readOnly {
|
||||
config = append(config, " read only = yes")
|
||||
} else {
|
||||
config = append(config, " read only = no")
|
||||
}
|
||||
if browseable {
|
||||
config = append(config, " browseable = yes")
|
||||
} else {
|
||||
config = append(config, " browseable = no")
|
||||
}
|
||||
return strings.Join(config, "\n")
|
||||
}
|
||||
|
||||
// removeSMBShare removes an SMB share from /etc/samba/smb.conf
|
||||
func (s *Service) removeSMBShare(ctx context.Context, shareName string) error {
|
||||
if shareName == "" {
|
||||
return nil // Nothing to remove
|
||||
}
|
||||
|
||||
smbConfPath := "/etc/samba/smb.conf"
|
||||
smbContent, err := os.ReadFile(smbConfPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to remove
|
||||
}
|
||||
return fmt.Errorf("failed to read smb.conf: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(smbContent), "\n")
|
||||
var newLines []string
|
||||
inShare := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Check if we're entering our share section
|
||||
if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
|
||||
sectionName := trimmed[1 : len(trimmed)-1]
|
||||
if sectionName == shareName {
|
||||
inShare = true
|
||||
continue
|
||||
} else if inShare {
|
||||
// We've left our share section
|
||||
inShare = false
|
||||
}
|
||||
}
|
||||
|
||||
if inShare {
|
||||
// Skip lines in this share section
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if err := os.WriteFile(smbConfPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Reload Samba
|
||||
cmd := exec.CommandContext(ctx, "sudo", "systemctl", "reload", "smbd")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// Try restart if reload fails
|
||||
cmd = exec.CommandContext(ctx, "sudo", "systemctl", "restart", "smbd")
|
||||
if output2, err2 := cmd.CombinedOutput(); err2 != nil {
|
||||
return fmt.Errorf("failed to reload/restart smbd: %s / %s: %w", string(output), string(output2), err2)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("SMB share removed", "share_name", shareName)
|
||||
return nil
|
||||
}
|
||||
@@ -610,6 +610,7 @@ func (s *ZFSService) AddSpareDisk(ctx context.Context, poolID string, diskPaths
|
||||
|
||||
// ZFSDataset represents a ZFS dataset
|
||||
type ZFSDataset struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Pool string `json:"pool"`
|
||||
Type string `json:"type"` // filesystem, volume, snapshot
|
||||
@@ -628,7 +629,7 @@ type ZFSDataset struct {
|
||||
func (s *ZFSService) ListDatasets(ctx context.Context, poolName string) ([]*ZFSDataset, error) {
|
||||
// Get datasets from database
|
||||
query := `
|
||||
SELECT name, pool_name, type, mount_point,
|
||||
SELECT id, name, pool_name, type, mount_point,
|
||||
used_bytes, available_bytes, referenced_bytes,
|
||||
compression, deduplication, quota, reservation,
|
||||
created_at
|
||||
@@ -654,7 +655,7 @@ func (s *ZFSService) ListDatasets(ctx context.Context, poolName string) ([]*ZFSD
|
||||
var mountPoint sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&ds.Name, &ds.Pool, &ds.Type, &mountPoint,
|
||||
&ds.ID, &ds.Name, &ds.Pool, &ds.Type, &mountPoint,
|
||||
&ds.UsedBytes, &ds.AvailableBytes, &ds.ReferencedBytes,
|
||||
&ds.Compression, &ds.Deduplication, &ds.Quota, &ds.Reservation,
|
||||
&ds.CreatedAt,
|
||||
|
||||
@@ -253,3 +253,30 @@ func (h *Handler) GetNetworkThroughput(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": data})
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a shell command
|
||||
func (h *Handler) ExecuteCommand(c *gin.Context) {
|
||||
var req struct {
|
||||
Command string `json:"command" binding:"required"`
|
||||
Service string `json:"service,omitempty"` // Optional: system, scst, storage, backup, tape
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("Invalid request body", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "command is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Execute command based on service context
|
||||
output, err := h.service.ExecuteCommand(c.Request.Context(), req.Command, req.Service)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to execute command", "error", err, "command", req.Command, "service", req.Service)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
"output": output, // Include output even on error
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"output": output})
|
||||
}
|
||||
|
||||
@@ -871,3 +871,143 @@ func (s *Service) GetNTPSettings(ctx context.Context) (*NTPSettings, error) {
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a shell command and returns the output
|
||||
// service parameter is optional and can be: system, scst, storage, backup, tape
|
||||
func (s *Service) ExecuteCommand(ctx context.Context, command string, service string) (string, error) {
|
||||
// Sanitize command - basic security check
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return "", fmt.Errorf("command cannot be empty")
|
||||
}
|
||||
|
||||
// Block dangerous commands that could harm the system
|
||||
dangerousCommands := []string{
|
||||
"rm -rf /",
|
||||
"dd if=",
|
||||
":(){ :|:& };:",
|
||||
"mkfs",
|
||||
"fdisk",
|
||||
"parted",
|
||||
"format",
|
||||
"> /dev/sd",
|
||||
"mkfs.ext",
|
||||
"mkfs.xfs",
|
||||
"mkfs.btrfs",
|
||||
"wipefs",
|
||||
}
|
||||
|
||||
commandLower := strings.ToLower(command)
|
||||
for _, dangerous := range dangerousCommands {
|
||||
if strings.Contains(commandLower, dangerous) {
|
||||
return "", fmt.Errorf("command blocked for security reasons")
|
||||
}
|
||||
}
|
||||
|
||||
// Service-specific command handling
|
||||
switch service {
|
||||
case "scst":
|
||||
// Allow SCST admin commands
|
||||
if strings.HasPrefix(command, "scstadmin") {
|
||||
// SCST commands are safe
|
||||
break
|
||||
}
|
||||
case "backup":
|
||||
// Allow bconsole commands
|
||||
if strings.HasPrefix(command, "bconsole") {
|
||||
// Backup console commands are safe
|
||||
break
|
||||
}
|
||||
case "storage":
|
||||
// Allow ZFS and storage commands
|
||||
if strings.HasPrefix(command, "zfs") || strings.HasPrefix(command, "zpool") || strings.HasPrefix(command, "lsblk") {
|
||||
// Storage commands are safe
|
||||
break
|
||||
}
|
||||
case "tape":
|
||||
// Allow tape library commands
|
||||
if strings.HasPrefix(command, "mtx") || strings.HasPrefix(command, "lsscsi") || strings.HasPrefix(command, "sg_") {
|
||||
// Tape commands are safe
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Execute command with timeout (30 seconds)
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check if command already has sudo (reuse commandLower from above)
|
||||
hasSudo := strings.HasPrefix(commandLower, "sudo ")
|
||||
|
||||
// Determine if command needs sudo based on service and command type
|
||||
needsSudo := false
|
||||
|
||||
if !hasSudo {
|
||||
// Commands that typically need sudo
|
||||
sudoCommands := []string{
|
||||
"scstadmin",
|
||||
"systemctl",
|
||||
"zfs",
|
||||
"zpool",
|
||||
"mount",
|
||||
"umount",
|
||||
"ip link",
|
||||
"ip addr",
|
||||
"iptables",
|
||||
"journalctl",
|
||||
}
|
||||
|
||||
for _, sudoCmd := range sudoCommands {
|
||||
if strings.HasPrefix(commandLower, sudoCmd) {
|
||||
needsSudo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Service-specific sudo requirements
|
||||
switch service {
|
||||
case "scst":
|
||||
// All SCST admin commands need sudo
|
||||
if strings.HasPrefix(commandLower, "scstadmin") {
|
||||
needsSudo = true
|
||||
}
|
||||
case "storage":
|
||||
// ZFS commands typically need sudo
|
||||
if strings.HasPrefix(commandLower, "zfs") || strings.HasPrefix(commandLower, "zpool") {
|
||||
needsSudo = true
|
||||
}
|
||||
case "system":
|
||||
// System commands like systemctl need sudo
|
||||
if strings.HasPrefix(commandLower, "systemctl") || strings.HasPrefix(commandLower, "journalctl") {
|
||||
needsSudo = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build command with or without sudo
|
||||
var cmd *exec.Cmd
|
||||
if needsSudo && !hasSudo {
|
||||
// Use sudo for privileged commands (if not already present)
|
||||
cmd = exec.CommandContext(ctx, "sudo", "sh", "-c", command)
|
||||
} else {
|
||||
// Regular command (or already has sudo)
|
||||
cmd = exec.CommandContext(ctx, "sh", "-c", command)
|
||||
}
|
||||
|
||||
cmd.Env = append(os.Environ(), "TERM=xterm-256color")
|
||||
|
||||
cmd.Env = append(os.Environ(), "TERM=xterm-256color")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
// Return output even if there's an error (some commands return non-zero exit codes)
|
||||
outputStr := string(output)
|
||||
if len(outputStr) > 0 {
|
||||
return outputStr, nil
|
||||
}
|
||||
return "", fmt.Errorf("command execution failed: %w", err)
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
328
backend/internal/system/terminal.go
Normal file
328
backend/internal/system/terminal.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/atlasos/calypso/internal/common/logger"
|
||||
"github.com/creack/pty"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// WebSocket timeouts
|
||||
writeWait = 10 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Allow all origins - in production, validate against allowed domains
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// TerminalSession manages a single terminal session
|
||||
type TerminalSession struct {
|
||||
conn *websocket.Conn
|
||||
pty *os.File
|
||||
cmd *exec.Cmd
|
||||
logger *logger.Logger
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
username string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// HandleTerminalWebSocket handles WebSocket connection for terminal
|
||||
func HandleTerminalWebSocket(c *gin.Context, log *logger.Logger) {
|
||||
// Verify authentication
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
log.Warn("Terminal WebSocket: unauthorized access", "ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
if username == nil {
|
||||
username = userID
|
||||
}
|
||||
|
||||
log.Info("Terminal WebSocket: connection attempt", "username", username, "ip", c.ClientIP())
|
||||
|
||||
// Upgrade connection
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Error("Terminal WebSocket: upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Terminal WebSocket: connection upgraded", "username", username)
|
||||
|
||||
// Create session
|
||||
session := &TerminalSession{
|
||||
conn: conn,
|
||||
logger: log,
|
||||
username: username.(string),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start terminal
|
||||
if err := session.startPTY(); err != nil {
|
||||
log.Error("Terminal WebSocket: failed to start PTY", "error", err, "username", username)
|
||||
session.sendError(err.Error())
|
||||
session.close()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle messages and PTY output
|
||||
go session.handleRead()
|
||||
go session.handleWrite()
|
||||
}
|
||||
|
||||
// startPTY starts the PTY session
|
||||
func (s *TerminalSession) startPTY() error {
|
||||
// Get user info
|
||||
currentUser, err := user.Lookup(s.username)
|
||||
if err != nil {
|
||||
// Fallback to current user
|
||||
currentUser, err = user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Determine shell
|
||||
shell := os.Getenv("SHELL")
|
||||
if shell == "" {
|
||||
shell = "/bin/bash"
|
||||
}
|
||||
|
||||
// Create command
|
||||
s.cmd = exec.Command(shell)
|
||||
s.cmd.Env = append(os.Environ(),
|
||||
"TERM=xterm-256color",
|
||||
"HOME="+currentUser.HomeDir,
|
||||
"USER="+currentUser.Username,
|
||||
"USERNAME="+currentUser.Username,
|
||||
)
|
||||
s.cmd.Dir = currentUser.HomeDir
|
||||
|
||||
// Start PTY
|
||||
ptyFile, err := pty.Start(s.cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.pty = ptyFile
|
||||
|
||||
// Set initial size
|
||||
pty.Setsize(ptyFile, &pty.Winsize{
|
||||
Rows: 24,
|
||||
Cols: 80,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRead handles incoming WebSocket messages
|
||||
func (s *TerminalSession) handleRead() {
|
||||
defer s.close()
|
||||
|
||||
// Set read deadline and pong handler
|
||||
s.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
s.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
messageType, data, err := s.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
s.logger.Error("Terminal WebSocket: read error", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle binary messages (raw input)
|
||||
if messageType == websocket.BinaryMessage {
|
||||
s.writeToPTY(data)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle text messages (JSON commands)
|
||||
if messageType == websocket.TextMessage {
|
||||
var msg map[string]interface{}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg["type"] {
|
||||
case "input":
|
||||
if data, ok := msg["data"].(string); ok {
|
||||
s.writeToPTY([]byte(data))
|
||||
}
|
||||
|
||||
case "resize":
|
||||
if cols, ok1 := msg["cols"].(float64); ok1 {
|
||||
if rows, ok2 := msg["rows"].(float64); ok2 {
|
||||
s.resizePTY(uint16(cols), uint16(rows))
|
||||
}
|
||||
}
|
||||
|
||||
case "ping":
|
||||
s.writeWS(websocket.TextMessage, []byte(`{"type":"pong"}`))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleWrite handles PTY output to WebSocket
|
||||
func (s *TerminalSession) handleWrite() {
|
||||
defer s.close()
|
||||
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Read from PTY and write to WebSocket
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send ping
|
||||
if err := s.writeWS(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Read from PTY
|
||||
if s.pty != nil {
|
||||
n, err := s.pty.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
s.logger.Error("Terminal WebSocket: PTY read error", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
// Write binary data to WebSocket
|
||||
if err := s.writeWS(websocket.BinaryMessage, buffer[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeToPTY writes data to PTY
|
||||
func (s *TerminalSession) writeToPTY(data []byte) {
|
||||
s.mu.RLock()
|
||||
closed := s.closed
|
||||
pty := s.pty
|
||||
s.mu.RUnlock()
|
||||
|
||||
if closed || pty == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := pty.Write(data); err != nil {
|
||||
s.logger.Error("Terminal WebSocket: PTY write error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// resizePTY resizes the PTY
|
||||
func (s *TerminalSession) resizePTY(cols, rows uint16) {
|
||||
s.mu.RLock()
|
||||
closed := s.closed
|
||||
ptyFile := s.pty
|
||||
s.mu.RUnlock()
|
||||
|
||||
if closed || ptyFile == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use pty.Setsize from package, not method from variable
|
||||
pty.Setsize(ptyFile, &pty.Winsize{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
})
|
||||
}
|
||||
|
||||
// writeWS writes message to WebSocket
|
||||
func (s *TerminalSession) writeWS(messageType int, data []byte) error {
|
||||
s.mu.RLock()
|
||||
closed := s.closed
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if closed || conn == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
return conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// sendError sends error message
|
||||
func (s *TerminalSession) sendError(errMsg string) {
|
||||
msg := map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": errMsg,
|
||||
}
|
||||
data, _ := json.Marshal(msg)
|
||||
s.writeWS(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// close closes the terminal session
|
||||
func (s *TerminalSession) close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
|
||||
s.closed = true
|
||||
close(s.done)
|
||||
|
||||
// Close PTY
|
||||
if s.pty != nil {
|
||||
s.pty.Close()
|
||||
}
|
||||
|
||||
// Kill process
|
||||
if s.cmd != nil && s.cmd.Process != nil {
|
||||
s.cmd.Process.Signal(syscall.SIGTERM)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if s.cmd.ProcessState == nil || !s.cmd.ProcessState.Exited() {
|
||||
s.cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
|
||||
// Close WebSocket
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
s.logger.Info("Terminal WebSocket: session closed", "username", s.username)
|
||||
}
|
||||
Reference in New Issue
Block a user