add feature license management
This commit is contained in:
147
backend/internal/shares/handler.go
Normal file
147
backend/internal/shares/handler.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package shares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/atlasos/calypso/internal/common/database"
|
||||
"github.com/atlasos/calypso/internal/common/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Handler handles Shares-related API requests
|
||||
type Handler struct {
|
||||
service *Service
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewHandler creates a new Shares handler
|
||||
func NewHandler(db *database.DB, log *logger.Logger) *Handler {
|
||||
return &Handler{
|
||||
service: NewService(db, log),
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// ListShares lists all shares
|
||||
func (h *Handler) ListShares(c *gin.Context) {
|
||||
shares, err := h.service.ListShares(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to list shares", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list shares"})
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return an empty array instead of null
|
||||
if shares == nil {
|
||||
shares = []*Share{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"shares": shares})
|
||||
}
|
||||
|
||||
// GetShare retrieves a share by ID
|
||||
func (h *Handler) GetShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
share, err := h.service.GetShare(c.Request.Context(), shareID)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to get share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get share"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, share)
|
||||
}
|
||||
|
||||
// CreateShare creates a new share
|
||||
func (h *Handler) CreateShare(c *gin.Context) {
|
||||
var req CreateShareRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("Invalid create share request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
validate := validator.New()
|
||||
if err := validate.Struct(req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "validation failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
share, err := h.service.CreateShare(c.Request.Context(), &req, userID.(string))
|
||||
if err != nil {
|
||||
if err.Error() == "dataset not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "dataset not found"})
|
||||
return
|
||||
}
|
||||
if err.Error() == "only filesystem datasets can be shared (not volumes)" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err.Error() == "at least one protocol (NFS or SMB) must be enabled" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to create share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, share)
|
||||
}
|
||||
|
||||
// UpdateShare updates an existing share
|
||||
func (h *Handler) UpdateShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
var req UpdateShareRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("Invalid update share request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
share, err := h.service.UpdateShare(c.Request.Context(), shareID, &req)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to update share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, share)
|
||||
}
|
||||
|
||||
// DeleteShare deletes a share
|
||||
func (h *Handler) DeleteShare(c *gin.Context) {
|
||||
shareID := c.Param("id")
|
||||
|
||||
err := h.service.DeleteShare(c.Request.Context(), shareID)
|
||||
if err != nil {
|
||||
if err.Error() == "share not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "share not found"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to delete share", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "share deleted successfully"})
|
||||
}
|
||||
806
backend/internal/shares/service.go
Normal file
806
backend/internal/shares/service.go
Normal file
@@ -0,0 +1,806 @@
|
||||
package shares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/atlasos/calypso/internal/common/database"
|
||||
"github.com/atlasos/calypso/internal/common/logger"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Service handles Shares (CIFS/NFS) operations
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewService creates a new Shares service
|
||||
func NewService(db *database.DB, log *logger.Logger) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Share represents a filesystem share (NFS/SMB)
|
||||
type Share struct {
|
||||
ID string `json:"id"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
DatasetName string `json:"dataset_name"`
|
||||
MountPoint string `json:"mount_point"`
|
||||
ShareType string `json:"share_type"` // 'nfs', 'smb', 'both'
|
||||
NFSEnabled bool `json:"nfs_enabled"`
|
||||
NFSOptions string `json:"nfs_options,omitempty"`
|
||||
NFSClients []string `json:"nfs_clients,omitempty"`
|
||||
SMBEnabled bool `json:"smb_enabled"`
|
||||
SMBShareName string `json:"smb_share_name,omitempty"`
|
||||
SMBPath string `json:"smb_path,omitempty"`
|
||||
SMBComment string `json:"smb_comment,omitempty"`
|
||||
SMBGuestOK bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly bool `json:"smb_read_only"`
|
||||
SMBBrowseable bool `json:"smb_browseable"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
}
|
||||
|
||||
// ListShares lists all shares
|
||||
func (s *Service) ListShares(ctx context.Context) ([]*Share, error) {
|
||||
query := `
|
||||
SELECT
|
||||
zs.id, zs.dataset_id, zd.name as dataset_name, zd.mount_point,
|
||||
zs.share_type, zs.nfs_enabled, zs.nfs_options, zs.nfs_clients,
|
||||
zs.smb_enabled, zs.smb_share_name, zs.smb_path, zs.smb_comment,
|
||||
zs.smb_guest_ok, zs.smb_read_only, zs.smb_browseable,
|
||||
zs.is_active, zs.created_at, zs.updated_at, zs.created_by
|
||||
FROM zfs_shares zs
|
||||
JOIN zfs_datasets zd ON zs.dataset_id = zd.id
|
||||
ORDER BY zd.name
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not exist") {
|
||||
s.logger.Warn("zfs_shares table does not exist, returning empty list")
|
||||
return []*Share{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list shares: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var shares []*Share
|
||||
for rows.Next() {
|
||||
var share Share
|
||||
var mountPoint sql.NullString
|
||||
var nfsOptions sql.NullString
|
||||
var smbShareName sql.NullString
|
||||
var smbPath sql.NullString
|
||||
var smbComment sql.NullString
|
||||
var nfsClients []string
|
||||
|
||||
err := rows.Scan(
|
||||
&share.ID, &share.DatasetID, &share.DatasetName, &mountPoint,
|
||||
&share.ShareType, &share.NFSEnabled, &nfsOptions, pq.Array(&nfsClients),
|
||||
&share.SMBEnabled, &smbShareName, &smbPath, &smbComment,
|
||||
&share.SMBGuestOK, &share.SMBReadOnly, &share.SMBBrowseable,
|
||||
&share.IsActive, &share.CreatedAt, &share.UpdatedAt, &share.CreatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to scan share row", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
share.NFSClients = nfsClients
|
||||
|
||||
if mountPoint.Valid {
|
||||
share.MountPoint = mountPoint.String
|
||||
}
|
||||
if nfsOptions.Valid {
|
||||
share.NFSOptions = nfsOptions.String
|
||||
}
|
||||
if smbShareName.Valid {
|
||||
share.SMBShareName = smbShareName.String
|
||||
}
|
||||
if smbPath.Valid {
|
||||
share.SMBPath = smbPath.String
|
||||
}
|
||||
if smbComment.Valid {
|
||||
share.SMBComment = smbComment.String
|
||||
}
|
||||
|
||||
shares = append(shares, &share)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating share rows: %w", err)
|
||||
}
|
||||
|
||||
return shares, nil
|
||||
}
|
||||
|
||||
// GetShare retrieves a share by ID
|
||||
func (s *Service) GetShare(ctx context.Context, shareID string) (*Share, error) {
|
||||
query := `
|
||||
SELECT
|
||||
zs.id, zs.dataset_id, zd.name as dataset_name, zd.mount_point,
|
||||
zs.share_type, zs.nfs_enabled, zs.nfs_options, zs.nfs_clients,
|
||||
zs.smb_enabled, zs.smb_share_name, zs.smb_path, zs.smb_comment,
|
||||
zs.smb_guest_ok, zs.smb_read_only, zs.smb_browseable,
|
||||
zs.is_active, zs.created_at, zs.updated_at, zs.created_by
|
||||
FROM zfs_shares zs
|
||||
JOIN zfs_datasets zd ON zs.dataset_id = zd.id
|
||||
WHERE zs.id = $1
|
||||
`
|
||||
|
||||
var share Share
|
||||
var mountPoint sql.NullString
|
||||
var nfsOptions sql.NullString
|
||||
var smbShareName sql.NullString
|
||||
var smbPath sql.NullString
|
||||
var smbComment sql.NullString
|
||||
var nfsClients []string
|
||||
|
||||
err := s.db.QueryRowContext(ctx, query, shareID).Scan(
|
||||
&share.ID, &share.DatasetID, &share.DatasetName, &mountPoint,
|
||||
&share.ShareType, &share.NFSEnabled, &nfsOptions, pq.Array(&nfsClients),
|
||||
&share.SMBEnabled, &smbShareName, &smbPath, &smbComment,
|
||||
&share.SMBGuestOK, &share.SMBReadOnly, &share.SMBBrowseable,
|
||||
&share.IsActive, &share.CreatedAt, &share.UpdatedAt, &share.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("share not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get share: %w", err)
|
||||
}
|
||||
|
||||
share.NFSClients = nfsClients
|
||||
|
||||
if mountPoint.Valid {
|
||||
share.MountPoint = mountPoint.String
|
||||
}
|
||||
if nfsOptions.Valid {
|
||||
share.NFSOptions = nfsOptions.String
|
||||
}
|
||||
if smbShareName.Valid {
|
||||
share.SMBShareName = smbShareName.String
|
||||
}
|
||||
if smbPath.Valid {
|
||||
share.SMBPath = smbPath.String
|
||||
}
|
||||
if smbComment.Valid {
|
||||
share.SMBComment = smbComment.String
|
||||
}
|
||||
|
||||
return &share, nil
|
||||
}
|
||||
|
||||
// CreateShareRequest represents a share creation request
|
||||
type CreateShareRequest struct {
|
||||
DatasetID string `json:"dataset_id" binding:"required"`
|
||||
NFSEnabled bool `json:"nfs_enabled"`
|
||||
NFSOptions string `json:"nfs_options"`
|
||||
NFSClients []string `json:"nfs_clients"`
|
||||
SMBEnabled bool `json:"smb_enabled"`
|
||||
SMBShareName string `json:"smb_share_name"`
|
||||
SMBPath string `json:"smb_path"`
|
||||
SMBComment string `json:"smb_comment"`
|
||||
SMBGuestOK bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly bool `json:"smb_read_only"`
|
||||
SMBBrowseable bool `json:"smb_browseable"`
|
||||
}
|
||||
|
||||
// CreateShare creates a new share
|
||||
func (s *Service) CreateShare(ctx context.Context, req *CreateShareRequest, userID string) (*Share, error) {
|
||||
// Validate dataset exists and is a filesystem (not volume)
|
||||
// req.DatasetID can be either UUID or dataset name
|
||||
var datasetID, datasetType, datasetName, mountPoint string
|
||||
var mountPointNull sql.NullString
|
||||
|
||||
// Try to find by ID first (UUID)
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
"SELECT id, type, name, mount_point FROM zfs_datasets WHERE id = $1",
|
||||
req.DatasetID,
|
||||
).Scan(&datasetID, &datasetType, &datasetName, &mountPointNull)
|
||||
|
||||
// If not found by ID, try by name
|
||||
if err == sql.ErrNoRows {
|
||||
err = s.db.QueryRowContext(ctx,
|
||||
"SELECT id, type, name, mount_point FROM zfs_datasets WHERE name = $1",
|
||||
req.DatasetID,
|
||||
).Scan(&datasetID, &datasetType, &datasetName, &mountPointNull)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("dataset not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to validate dataset: %w", err)
|
||||
}
|
||||
|
||||
if mountPointNull.Valid {
|
||||
mountPoint = mountPointNull.String
|
||||
} else {
|
||||
mountPoint = "none"
|
||||
}
|
||||
|
||||
if datasetType != "filesystem" {
|
||||
return nil, fmt.Errorf("only filesystem datasets can be shared (not volumes)")
|
||||
}
|
||||
|
||||
// Determine share type
|
||||
shareType := "none"
|
||||
if req.NFSEnabled && req.SMBEnabled {
|
||||
shareType = "both"
|
||||
} else if req.NFSEnabled {
|
||||
shareType = "nfs"
|
||||
} else if req.SMBEnabled {
|
||||
shareType = "smb"
|
||||
} else {
|
||||
return nil, fmt.Errorf("at least one protocol (NFS or SMB) must be enabled")
|
||||
}
|
||||
|
||||
// Set default NFS options if not provided
|
||||
nfsOptions := req.NFSOptions
|
||||
if nfsOptions == "" {
|
||||
nfsOptions = "rw,sync,no_subtree_check"
|
||||
}
|
||||
|
||||
// Set default SMB share name if not provided
|
||||
smbShareName := req.SMBShareName
|
||||
if smbShareName == "" {
|
||||
// Extract dataset name from full path (e.g., "pool/dataset" -> "dataset")
|
||||
parts := strings.Split(datasetName, "/")
|
||||
smbShareName = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
// Set SMB path (use mount_point if available, otherwise use dataset name)
|
||||
smbPath := req.SMBPath
|
||||
if smbPath == "" {
|
||||
if mountPoint != "" && mountPoint != "none" {
|
||||
smbPath = mountPoint
|
||||
} else {
|
||||
smbPath = fmt.Sprintf("/mnt/%s", strings.ReplaceAll(datasetName, "/", "_"))
|
||||
}
|
||||
}
|
||||
|
||||
// Insert into database
|
||||
query := `
|
||||
INSERT INTO zfs_shares (
|
||||
dataset_id, share_type, nfs_enabled, nfs_options, nfs_clients,
|
||||
smb_enabled, smb_share_name, smb_path, smb_comment,
|
||||
smb_guest_ok, smb_read_only, smb_browseable, is_active, created_by
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
var shareID string
|
||||
var createdAt, updatedAt time.Time
|
||||
|
||||
// Handle nfs_clients array - use empty array if nil
|
||||
nfsClients := req.NFSClients
|
||||
if nfsClients == nil {
|
||||
nfsClients = []string{}
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, query,
|
||||
datasetID, shareType, req.NFSEnabled, nfsOptions, pq.Array(nfsClients),
|
||||
req.SMBEnabled, smbShareName, smbPath, req.SMBComment,
|
||||
req.SMBGuestOK, req.SMBReadOnly, req.SMBBrowseable, true, userID,
|
||||
).Scan(&shareID, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create share: %w", err)
|
||||
}
|
||||
|
||||
// Apply NFS export if enabled
|
||||
if req.NFSEnabled {
|
||||
if err := s.applyNFSExport(ctx, mountPoint, nfsOptions, req.NFSClients); err != nil {
|
||||
s.logger.Error("Failed to apply NFS export", "error", err, "share_id", shareID)
|
||||
// Don't fail the creation, but log the error
|
||||
}
|
||||
}
|
||||
|
||||
// Apply SMB share if enabled
|
||||
if req.SMBEnabled {
|
||||
if err := s.applySMBShare(ctx, smbShareName, smbPath, req.SMBComment, req.SMBGuestOK, req.SMBReadOnly, req.SMBBrowseable); err != nil {
|
||||
s.logger.Error("Failed to apply SMB share", "error", err, "share_id", shareID)
|
||||
// Don't fail the creation, but log the error
|
||||
}
|
||||
}
|
||||
|
||||
// Return the created share
|
||||
return s.GetShare(ctx, shareID)
|
||||
}
|
||||
|
||||
// UpdateShareRequest represents a share update request
|
||||
type UpdateShareRequest struct {
|
||||
NFSEnabled *bool `json:"nfs_enabled"`
|
||||
NFSOptions *string `json:"nfs_options"`
|
||||
NFSClients *[]string `json:"nfs_clients"`
|
||||
SMBEnabled *bool `json:"smb_enabled"`
|
||||
SMBShareName *string `json:"smb_share_name"`
|
||||
SMBComment *string `json:"smb_comment"`
|
||||
SMBGuestOK *bool `json:"smb_guest_ok"`
|
||||
SMBReadOnly *bool `json:"smb_read_only"`
|
||||
SMBBrowseable *bool `json:"smb_browseable"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// UpdateShare updates an existing share
|
||||
func (s *Service) UpdateShare(ctx context.Context, shareID string, req *UpdateShareRequest) (*Share, error) {
|
||||
// Get current share
|
||||
share, err := s.GetShare(ctx, shareID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build update query dynamically
|
||||
updates := []string{}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if req.NFSEnabled != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_enabled = $%d", argIndex))
|
||||
args = append(args, *req.NFSEnabled)
|
||||
argIndex++
|
||||
}
|
||||
if req.NFSOptions != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_options = $%d", argIndex))
|
||||
args = append(args, *req.NFSOptions)
|
||||
argIndex++
|
||||
}
|
||||
if req.NFSClients != nil {
|
||||
updates = append(updates, fmt.Sprintf("nfs_clients = $%d", argIndex))
|
||||
args = append(args, pq.Array(*req.NFSClients))
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBEnabled != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_enabled = $%d", argIndex))
|
||||
args = append(args, *req.SMBEnabled)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBShareName != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_share_name = $%d", argIndex))
|
||||
args = append(args, *req.SMBShareName)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBComment != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_comment = $%d", argIndex))
|
||||
args = append(args, *req.SMBComment)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBGuestOK != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_guest_ok = $%d", argIndex))
|
||||
args = append(args, *req.SMBGuestOK)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBReadOnly != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_read_only = $%d", argIndex))
|
||||
args = append(args, *req.SMBReadOnly)
|
||||
argIndex++
|
||||
}
|
||||
if req.SMBBrowseable != nil {
|
||||
updates = append(updates, fmt.Sprintf("smb_browseable = $%d", argIndex))
|
||||
args = append(args, *req.SMBBrowseable)
|
||||
argIndex++
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
updates = append(updates, fmt.Sprintf("is_active = $%d", argIndex))
|
||||
args = append(args, *req.IsActive)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(updates) == 0 {
|
||||
return share, nil // No changes
|
||||
}
|
||||
|
||||
// Update share_type based on enabled protocols
|
||||
nfsEnabled := share.NFSEnabled
|
||||
smbEnabled := share.SMBEnabled
|
||||
if req.NFSEnabled != nil {
|
||||
nfsEnabled = *req.NFSEnabled
|
||||
}
|
||||
if req.SMBEnabled != nil {
|
||||
smbEnabled = *req.SMBEnabled
|
||||
}
|
||||
|
||||
shareType := "none"
|
||||
if nfsEnabled && smbEnabled {
|
||||
shareType = "both"
|
||||
} else if nfsEnabled {
|
||||
shareType = "nfs"
|
||||
} else if smbEnabled {
|
||||
shareType = "smb"
|
||||
}
|
||||
|
||||
updates = append(updates, fmt.Sprintf("share_type = $%d", argIndex))
|
||||
args = append(args, shareType)
|
||||
argIndex++
|
||||
|
||||
updates = append(updates, fmt.Sprintf("updated_at = NOW()"))
|
||||
args = append(args, shareID)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE zfs_shares
|
||||
SET %s
|
||||
WHERE id = $%d
|
||||
`, strings.Join(updates, ", "), argIndex)
|
||||
|
||||
_, err = s.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update share: %w", err)
|
||||
}
|
||||
|
||||
// Re-apply NFS export if NFS is enabled
|
||||
if nfsEnabled {
|
||||
nfsOptions := share.NFSOptions
|
||||
if req.NFSOptions != nil {
|
||||
nfsOptions = *req.NFSOptions
|
||||
}
|
||||
nfsClients := share.NFSClients
|
||||
if req.NFSClients != nil {
|
||||
nfsClients = *req.NFSClients
|
||||
}
|
||||
if err := s.applyNFSExport(ctx, share.MountPoint, nfsOptions, nfsClients); err != nil {
|
||||
s.logger.Error("Failed to apply NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
} else {
|
||||
// Remove NFS export if disabled
|
||||
if err := s.removeNFSExport(ctx, share.MountPoint); err != nil {
|
||||
s.logger.Error("Failed to remove NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-apply SMB share if SMB is enabled
|
||||
if smbEnabled {
|
||||
smbShareName := share.SMBShareName
|
||||
if req.SMBShareName != nil {
|
||||
smbShareName = *req.SMBShareName
|
||||
}
|
||||
smbPath := share.SMBPath
|
||||
smbComment := share.SMBComment
|
||||
if req.SMBComment != nil {
|
||||
smbComment = *req.SMBComment
|
||||
}
|
||||
smbGuestOK := share.SMBGuestOK
|
||||
if req.SMBGuestOK != nil {
|
||||
smbGuestOK = *req.SMBGuestOK
|
||||
}
|
||||
smbReadOnly := share.SMBReadOnly
|
||||
if req.SMBReadOnly != nil {
|
||||
smbReadOnly = *req.SMBReadOnly
|
||||
}
|
||||
smbBrowseable := share.SMBBrowseable
|
||||
if req.SMBBrowseable != nil {
|
||||
smbBrowseable = *req.SMBBrowseable
|
||||
}
|
||||
if err := s.applySMBShare(ctx, smbShareName, smbPath, smbComment, smbGuestOK, smbReadOnly, smbBrowseable); err != nil {
|
||||
s.logger.Error("Failed to apply SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
} else {
|
||||
// Remove SMB share if disabled
|
||||
if err := s.removeSMBShare(ctx, share.SMBShareName); err != nil {
|
||||
s.logger.Error("Failed to remove SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetShare(ctx, shareID)
|
||||
}
|
||||
|
||||
// DeleteShare deletes a share
|
||||
func (s *Service) DeleteShare(ctx context.Context, shareID string) error {
|
||||
// Get share to get mount point and share name
|
||||
share, err := s.GetShare(ctx, shareID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove NFS export
|
||||
if share.NFSEnabled {
|
||||
if err := s.removeNFSExport(ctx, share.MountPoint); err != nil {
|
||||
s.logger.Error("Failed to remove NFS export", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove SMB share
|
||||
if share.SMBEnabled {
|
||||
if err := s.removeSMBShare(ctx, share.SMBShareName); err != nil {
|
||||
s.logger.Error("Failed to remove SMB share", "error", err, "share_id", shareID)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
_, err = s.db.ExecContext(ctx, "DELETE FROM zfs_shares WHERE id = $1", shareID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete share: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNFSExport adds or updates an NFS export in /etc/exports
|
||||
func (s *Service) applyNFSExport(ctx context.Context, mountPoint, options string, clients []string) error {
|
||||
if mountPoint == "" || mountPoint == "none" {
|
||||
return fmt.Errorf("mount point is required for NFS export")
|
||||
}
|
||||
|
||||
// Build client list (default to * if empty)
|
||||
clientList := "*"
|
||||
if len(clients) > 0 {
|
||||
clientList = strings.Join(clients, " ")
|
||||
}
|
||||
|
||||
// Build export line
|
||||
exportLine := fmt.Sprintf("%s %s(%s)", mountPoint, clientList, options)
|
||||
|
||||
// Read current /etc/exports
|
||||
exportsPath := "/etc/exports"
|
||||
exportsContent, err := os.ReadFile(exportsPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to read exports file: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(exportsContent), "\n")
|
||||
var newLines []string
|
||||
found := false
|
||||
|
||||
// Check if this mount point already exists
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this line is for our mount point
|
||||
if strings.HasPrefix(line, mountPoint+" ") {
|
||||
newLines = append(newLines, exportLine)
|
||||
found = true
|
||||
} else {
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// Add if not found
|
||||
if !found {
|
||||
newLines = append(newLines, exportLine)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n") + "\n"
|
||||
if err := os.WriteFile(exportsPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write exports file: %w", err)
|
||||
}
|
||||
|
||||
// Apply exports
|
||||
cmd := exec.CommandContext(ctx, "sudo", "exportfs", "-ra")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to apply exports: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
s.logger.Info("NFS export applied", "mount_point", mountPoint, "clients", clientList)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNFSExport removes an NFS export from /etc/exports
|
||||
func (s *Service) removeNFSExport(ctx context.Context, mountPoint string) error {
|
||||
if mountPoint == "" || mountPoint == "none" {
|
||||
return nil // Nothing to remove
|
||||
}
|
||||
|
||||
exportsPath := "/etc/exports"
|
||||
exportsContent, err := os.ReadFile(exportsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to remove
|
||||
}
|
||||
return fmt.Errorf("failed to read exports file: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(exportsContent), "\n")
|
||||
var newLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip lines for this mount point
|
||||
if strings.HasPrefix(line, mountPoint+" ") {
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if newContent != "" && !strings.HasSuffix(newContent, "\n") {
|
||||
newContent += "\n"
|
||||
}
|
||||
if err := os.WriteFile(exportsPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write exports file: %w", err)
|
||||
}
|
||||
|
||||
// Apply exports
|
||||
cmd := exec.CommandContext(ctx, "sudo", "exportfs", "-ra")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to apply exports: %s: %w", string(output), err)
|
||||
}
|
||||
|
||||
s.logger.Info("NFS export removed", "mount_point", mountPoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applySMBShare adds or updates an SMB share in /etc/samba/smb.conf
|
||||
func (s *Service) applySMBShare(ctx context.Context, shareName, path, comment string, guestOK, readOnly, browseable bool) error {
|
||||
if shareName == "" {
|
||||
return fmt.Errorf("SMB share name is required")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("SMB path is required")
|
||||
}
|
||||
|
||||
smbConfPath := "/etc/samba/smb.conf"
|
||||
smbContent, err := os.ReadFile(smbConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Parse and update smb.conf
|
||||
lines := strings.Split(string(smbContent), "\n")
|
||||
var newLines []string
|
||||
inShare := false
|
||||
shareStart := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Check if we're entering our share section
|
||||
if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
|
||||
sectionName := trimmed[1 : len(trimmed)-1]
|
||||
if sectionName == shareName {
|
||||
inShare = true
|
||||
shareStart = i
|
||||
continue
|
||||
} else if inShare {
|
||||
// We've left our share section, insert the share config here
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
inShare = false
|
||||
}
|
||||
}
|
||||
|
||||
if inShare {
|
||||
// Skip lines until we find the next section or end of file
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// If we were still in the share at the end, add it
|
||||
if inShare {
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
} else if shareStart == -1 {
|
||||
// Share doesn't exist, add it at the end
|
||||
newLines = append(newLines, "")
|
||||
newLines = append(newLines, s.buildSMBShareConfig(shareName, path, comment, guestOK, readOnly, browseable))
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if err := os.WriteFile(smbConfPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Reload Samba
|
||||
cmd := exec.CommandContext(ctx, "sudo", "systemctl", "reload", "smbd")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// Try restart if reload fails
|
||||
cmd = exec.CommandContext(ctx, "sudo", "systemctl", "restart", "smbd")
|
||||
if output2, err2 := cmd.CombinedOutput(); err2 != nil {
|
||||
return fmt.Errorf("failed to reload/restart smbd: %s / %s: %w", string(output), string(output2), err2)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("SMB share applied", "share_name", shareName, "path", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildSMBShareConfig builds the SMB share configuration block
|
||||
func (s *Service) buildSMBShareConfig(shareName, path, comment string, guestOK, readOnly, browseable bool) string {
|
||||
var config []string
|
||||
config = append(config, fmt.Sprintf("[%s]", shareName))
|
||||
if comment != "" {
|
||||
config = append(config, fmt.Sprintf(" comment = %s", comment))
|
||||
}
|
||||
config = append(config, fmt.Sprintf(" path = %s", path))
|
||||
if guestOK {
|
||||
config = append(config, " guest ok = yes")
|
||||
} else {
|
||||
config = append(config, " guest ok = no")
|
||||
}
|
||||
if readOnly {
|
||||
config = append(config, " read only = yes")
|
||||
} else {
|
||||
config = append(config, " read only = no")
|
||||
}
|
||||
if browseable {
|
||||
config = append(config, " browseable = yes")
|
||||
} else {
|
||||
config = append(config, " browseable = no")
|
||||
}
|
||||
return strings.Join(config, "\n")
|
||||
}
|
||||
|
||||
// removeSMBShare removes an SMB share from /etc/samba/smb.conf
|
||||
func (s *Service) removeSMBShare(ctx context.Context, shareName string) error {
|
||||
if shareName == "" {
|
||||
return nil // Nothing to remove
|
||||
}
|
||||
|
||||
smbConfPath := "/etc/samba/smb.conf"
|
||||
smbContent, err := os.ReadFile(smbConfPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to remove
|
||||
}
|
||||
return fmt.Errorf("failed to read smb.conf: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(smbContent), "\n")
|
||||
var newLines []string
|
||||
inShare := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Check if we're entering our share section
|
||||
if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
|
||||
sectionName := trimmed[1 : len(trimmed)-1]
|
||||
if sectionName == shareName {
|
||||
inShare = true
|
||||
continue
|
||||
} else if inShare {
|
||||
// We've left our share section
|
||||
inShare = false
|
||||
}
|
||||
}
|
||||
|
||||
if inShare {
|
||||
// Skip lines in this share section
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
// Write back to file
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
if err := os.WriteFile(smbConfPath, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write smb.conf: %w", err)
|
||||
}
|
||||
|
||||
// Reload Samba
|
||||
cmd := exec.CommandContext(ctx, "sudo", "systemctl", "reload", "smbd")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// Try restart if reload fails
|
||||
cmd = exec.CommandContext(ctx, "sudo", "systemctl", "restart", "smbd")
|
||||
if output2, err2 := cmd.CombinedOutput(); err2 != nil {
|
||||
return fmt.Errorf("failed to reload/restart smbd: %s / %s: %w", string(output), string(output2), err2)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("SMB share removed", "share_name", shareName)
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user