package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/atlasos/calypso/internal/common/database"
	"github.com/atlasos/calypso/internal/common/logger"
	"github.com/google/uuid"
)

// ErrReplicationTaskNotFound is returned by GetReplicationTask,
// UpdateReplicationTask and DeleteReplicationTask when no task has the given
// ID. Its message matches the plain error these methods returned before, so
// callers comparing err.Error() keep working; prefer
// errors.Is(err, ErrReplicationTaskNotFound).
var ErrReplicationTaskNotFound = errors.New("replication task not found")

// Defaults applied to a CreateReplicationRequest when the corresponding field
// is omitted, plus the status stored for newly created tasks.
const (
	defaultReplicationCompression = "lz4"
	defaultReplicationSSHPort     = 22
	defaultReplicationSSHUser     = "root"
	initialReplicationStatus      = "idle"
)

// replicationTaskColumns lists every replication_tasks column in the exact
// order scanReplicationTaskFrom scans them. Every SELECT and RETURNING clause
// in this file uses it so the query and the Scan call cannot drift apart.
const replicationTaskColumns = `
	id, name, direction, source_dataset, target_host, target_port, target_user,
	target_dataset, target_ssh_key_path, source_host, source_port, source_user,
	local_dataset, schedule_type, schedule_config, compression, encryption,
	recursive, incremental, auto_snapshot, enabled, status, last_run_at,
	last_run_status, last_run_error, next_run_at, last_snapshot_sent,
	last_snapshot_received, total_runs, successful_runs, failed_runs,
	bytes_sent, bytes_received, created_by, created_at, updated_at`

// replicationTaskSelect is the shared prefix of every task read query.
const replicationTaskSelect = `SELECT ` + replicationTaskColumns + ` FROM replication_tasks`

// ReplicationService handles ZFS replication task management, persisting
// tasks in the replication_tasks table.
type ReplicationService struct {
	db     *database.DB
	logger *logger.Logger
}

// NewReplicationService creates a new replication service backed by db that
// reports through log.
func NewReplicationService(db *database.DB, log *logger.Logger) *ReplicationService {
	return &ReplicationService{
		db:     db,
		logger: log,
	}
}

// ReplicationTask represents a ZFS replication task as stored in the
// replication_tasks table. Pointer fields are nil when the column is NULL.
// Outbound tasks use SourceDataset and the Target* fields; inbound tasks use
// the Source* fields and LocalDataset.
type ReplicationTask struct {
	ID                   string                 `json:"id"`
	Name                 string                 `json:"name"`
	Direction            string                 `json:"direction"` // "outbound" or "inbound"
	SourceDataset        *string                `json:"source_dataset,omitempty"`
	TargetHost           *string                `json:"target_host,omitempty"`
	TargetPort           *int                   `json:"target_port,omitempty"`
	TargetUser           *string                `json:"target_user,omitempty"`
	TargetDataset        *string                `json:"target_dataset,omitempty"`
	TargetSSHKeyPath     *string                `json:"target_ssh_key_path,omitempty"`
	SourceHost           *string                `json:"source_host,omitempty"`
	SourcePort           *int                   `json:"source_port,omitempty"`
	SourceUser           *string                `json:"source_user,omitempty"`
	LocalDataset         *string                `json:"local_dataset,omitempty"`
	ScheduleType         *string                `json:"schedule_type,omitempty"`
	ScheduleConfig       map[string]interface{} `json:"schedule_config,omitempty"`
	Compression          string                 `json:"compression"`
	Encryption           bool                   `json:"encryption"`
	Recursive            bool                   `json:"recursive"`
	Incremental          bool                   `json:"incremental"`
	AutoSnapshot         bool                   `json:"auto_snapshot"`
	Enabled              bool                   `json:"enabled"`
	Status               string                 `json:"status"`
	LastRunAt            *time.Time             `json:"last_run_at,omitempty"`
	LastRunStatus        *string                `json:"last_run_status,omitempty"`
	LastRunError         *string                `json:"last_run_error,omitempty"`
	NextRunAt            *time.Time             `json:"next_run_at,omitempty"`
	LastSnapshotSent     *string                `json:"last_snapshot_sent,omitempty"`
	LastSnapshotReceived *string                `json:"last_snapshot_received,omitempty"`
	TotalRuns            int                    `json:"total_runs"`
	SuccessfulRuns       int                    `json:"successful_runs"`
	FailedRuns           int                    `json:"failed_runs"`
	BytesSent            int64                  `json:"bytes_sent"`
	BytesReceived        int64                  `json:"bytes_received"`
	CreatedBy            string                 `json:"created_by,omitempty"`
	CreatedAt            time.Time              `json:"created_at"`
	UpdatedAt            time.Time              `json:"updated_at"`
}

// CreateReplicationRequest represents a request to create a replication task.
// UpdateReplicationTask accepts the same payload and replaces every
// configuration field with it.
type CreateReplicationRequest struct {
	Name             string                 `json:"name" binding:"required"`
	Direction        string                 `json:"direction" binding:"required"` // "outbound" or "inbound"
	SourceDataset    *string                `json:"source_dataset"`
	TargetHost       *string                `json:"target_host"`
	TargetPort       *int                   `json:"target_port"`
	TargetUser       *string                `json:"target_user"`
	TargetDataset    *string                `json:"target_dataset"`
	TargetSSHKeyPath *string                `json:"target_ssh_key_path"`
	SourceHost       *string                `json:"source_host"`
	SourcePort       *int                   `json:"source_port"`
	SourceUser       *string                `json:"source_user"`
	LocalDataset     *string                `json:"local_dataset"`
	ScheduleType     *string                `json:"schedule_type"`
	ScheduleConfig   map[string]interface{} `json:"schedule_config"`
	Compression      string                 `json:"compression"`
	Encryption       bool                   `json:"encryption"`
	Recursive        bool                   `json:"recursive"`
	Incremental      bool                   `json:"incremental"`
	AutoSnapshot     bool                   `json:"auto_snapshot"`
	Enabled          bool                   `json:"enabled"`
}

// validateAndApplyDefaults checks the fields required by r.Direction and
// then fills in defaults for omitted optional fields. It mutates r in place.
//
// A required string counts as missing when it is nil, empty, or whitespace
// only, since such a value cannot name a real host or dataset.
func (r *CreateReplicationRequest) validateAndApplyDefaults() error {
	if r == nil {
		return errors.New("replication request is required")
	}

	missing := func(p *string) bool {
		return p == nil || strings.TrimSpace(*p) == ""
	}

	switch r.Direction {
	case "outbound":
		if missing(r.SourceDataset) || missing(r.TargetHost) || missing(r.TargetDataset) {
			return errors.New("outbound replication requires source_dataset, target_host, and target_dataset")
		}
	case "inbound":
		if missing(r.SourceHost) || missing(r.SourceDataset) || missing(r.LocalDataset) {
			return errors.New("inbound replication requires source_host, source_dataset, and local_dataset")
		}
	default:
		return errors.New("invalid direction: must be 'outbound' or 'inbound'")
	}

	if r.Compression == "" {
		r.Compression = defaultReplicationCompression
	}
	// Each pointer default gets its own variable so the source and target
	// fields never alias the same memory.
	if r.TargetPort == nil {
		port := defaultReplicationSSHPort
		r.TargetPort = &port
	}
	if r.SourcePort == nil {
		port := defaultReplicationSSHPort
		r.SourcePort = &port
	}
	if r.TargetUser == nil {
		user := defaultReplicationSSHUser
		r.TargetUser = &user
	}
	if r.SourceUser == nil {
		user := defaultReplicationSSHUser
		r.SourceUser = &user
	}
	return nil
}

// scheduleColumns converts the optional schedule fields into nullable SQL
// parameters: schedule type first, then the schedule config encoded as JSON.
// A nil ScheduleType or ScheduleConfig becomes NULL.
func (r *CreateReplicationRequest) scheduleColumns() (sql.NullString, sql.NullString, error) {
	var scheduleType, scheduleConfig sql.NullString
	if r.ScheduleType != nil {
		scheduleType = sql.NullString{String: *r.ScheduleType, Valid: true}
	}
	if r.ScheduleConfig != nil {
		raw, err := json.Marshal(r.ScheduleConfig)
		if err != nil {
			return sql.NullString{}, sql.NullString{}, fmt.Errorf("failed to marshal schedule config: %w", err)
		}
		scheduleConfig = sql.NullString{String: string(raw), Valid: true}
	}
	return scheduleType, scheduleConfig, nil
}

// ListReplicationTasks returns replication tasks. If directionFilter is
// non-empty, only tasks with that direction are returned, newest first.
// Otherwise all tasks are returned ordered by direction, then newest first.
//
// Rows that fail to decode are logged and skipped rather than failing the
// whole listing. The result is nil when no task matches.
func (s *ReplicationService) ListReplicationTasks(ctx context.Context, directionFilter string) ([]*ReplicationTask, error) {
	query := replicationTaskSelect + ` ORDER BY direction, created_at DESC`
	var args []interface{}
	if directionFilter != "" {
		query = replicationTaskSelect + ` WHERE direction = $1 ORDER BY created_at DESC`
		args = []interface{}{directionFilter}
	}

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query replication tasks: %w", err)
	}
	defer rows.Close()

	var tasks []*ReplicationTask
	for rows.Next() {
		task, err := s.scanReplicationTask(rows)
		if err != nil {
			// Best effort: one undecodable row (e.g. malformed schedule_config)
			// should not hide every other task.
			s.logger.Error("Failed to scan replication task", "error", err)
			continue
		}
		tasks = append(tasks, task)
	}
	if err := rows.Err(); err != nil {
		// Don't hand back a silently truncated list alongside the error.
		return nil, fmt.Errorf("failed to iterate replication tasks: %w", err)
	}
	return tasks, nil
}

// GetReplicationTask retrieves a replication task by ID. It returns
// ErrReplicationTaskNotFound when no task has that ID.
func (s *ReplicationService) GetReplicationTask(ctx context.Context, id string) (*ReplicationTask, error) {
	row := s.db.QueryRowContext(ctx, replicationTaskSelect+` WHERE id = $1`, id)
	task, err := s.scanReplicationTaskRow(row)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrReplicationTaskNotFound
		}
		return nil, fmt.Errorf("failed to get replication task: %w", err)
	}
	return task, nil
}

// CreateReplicationTask validates req, applies defaults (mutating req), and
// inserts a new task with a fresh UUID and status "idle". It returns the row
// as stored by the database.
func (s *ReplicationService) CreateReplicationTask(ctx context.Context, req *CreateReplicationRequest, createdBy string) (*ReplicationTask, error) {
	if err := req.validateAndApplyDefaults(); err != nil {
		return nil, err
	}
	scheduleType, scheduleConfig, err := req.scheduleColumns()
	if err != nil {
		return nil, err
	}

	const query = `
		INSERT INTO replication_tasks (
			id, name, direction, source_dataset, target_host, target_port, target_user,
			target_dataset, target_ssh_key_path, source_host, source_port, source_user,
			local_dataset, schedule_type, schedule_config, compression, encryption,
			recursive, incremental, auto_snapshot, enabled, status, created_by,
			created_at, updated_at
		) VALUES (
			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16,
			$17, $18, $19, $20, $21, $22, $23, NOW(), NOW()
		)
		RETURNING ` + replicationTaskColumns

	id := uuid.New().String()
	row := s.db.QueryRowContext(ctx, query,
		id, req.Name, req.Direction, req.SourceDataset, req.TargetHost,
		req.TargetPort, req.TargetUser, req.TargetDataset, req.TargetSSHKeyPath,
		req.SourceHost, req.SourcePort, req.SourceUser, req.LocalDataset,
		scheduleType, scheduleConfig, req.Compression, req.Encryption,
		req.Recursive, req.Incremental, req.AutoSnapshot, req.Enabled,
		initialReplicationStatus, createdBy,
	)
	task, err := s.scanReplicationTaskRow(row)
	if err != nil {
		return nil, fmt.Errorf("failed to create replication task: %w", err)
	}

	s.logger.Info("Replication task created", "id", id, "name", req.Name, "direction", req.Direction)
	return task, nil
}

// UpdateReplicationTask validates req, applies defaults (mutating req), and
// overwrites the configuration of task id with it. It returns
// ErrReplicationTaskNotFound when no task has that ID.
func (s *ReplicationService) UpdateReplicationTask(ctx context.Context, id string, req *CreateReplicationRequest) (*ReplicationTask, error) {
	if err := req.validateAndApplyDefaults(); err != nil {
		return nil, err
	}
	scheduleType, scheduleConfig, err := req.scheduleColumns()
	if err != nil {
		return nil, err
	}

	// The SET list covers only request fields; status, run statistics,
	// created_by and created_at are not modified here.
	const query = `
		UPDATE replication_tasks SET
			name = $1, direction = $2, source_dataset = $3, target_host = $4,
			target_port = $5, target_user = $6, target_dataset = $7,
			target_ssh_key_path = $8, source_host = $9, source_port = $10,
			source_user = $11, local_dataset = $12, schedule_type = $13,
			schedule_config = $14, compression = $15, encryption = $16,
			recursive = $17, incremental = $18, auto_snapshot = $19,
			enabled = $20, updated_at = NOW()
		WHERE id = $21
		RETURNING ` + replicationTaskColumns

	row := s.db.QueryRowContext(ctx, query,
		req.Name, req.Direction, req.SourceDataset, req.TargetHost, req.TargetPort,
		req.TargetUser, req.TargetDataset, req.TargetSSHKeyPath, req.SourceHost,
		req.SourcePort, req.SourceUser, req.LocalDataset, scheduleType,
		scheduleConfig, req.Compression, req.Encryption, req.Recursive,
		req.Incremental, req.AutoSnapshot, req.Enabled, id,
	)
	task, err := s.scanReplicationTaskRow(row)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrReplicationTaskNotFound
		}
		return nil, fmt.Errorf("failed to update replication task: %w", err)
	}

	s.logger.Info("Replication task updated", "id", id)
	return task, nil
}

// DeleteReplicationTask deletes a replication task. It returns
// ErrReplicationTaskNotFound when no row was deleted.
func (s *ReplicationService) DeleteReplicationTask(ctx context.Context, id string) error {
	result, err := s.db.ExecContext(ctx, `DELETE FROM replication_tasks WHERE id = $1`, id)
	if err != nil {
		return fmt.Errorf("failed to delete replication task: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrReplicationTaskNotFound
	}

	s.logger.Info("Replication task deleted", "id", id)
	return nil
}

// replicationTaskScanner is the Scan method shared by *sql.Row and *sql.Rows.
type replicationTaskScanner interface {
	Scan(dest ...interface{}) error
}

// scanReplicationTaskRow scans a single replication task row.
func (s *ReplicationService) scanReplicationTaskRow(row *sql.Row) (*ReplicationTask, error) {
	return s.scanReplicationTaskFrom(row)
}

// scanReplicationTask scans a replication task from the current row of rows.
func (s *ReplicationService) scanReplicationTask(rows *sql.Rows) (*ReplicationTask, error) {
	return s.scanReplicationTaskFrom(rows)
}

// scanReplicationTaskFrom decodes one row laid out as replicationTaskColumns.
// NULL columns become nil pointers (or "" for CreatedBy). Scan errors,
// including sql.ErrNoRows, are returned unwrapped. A malformed
// schedule_config produces a wrapped unmarshal error.
func (s *ReplicationService) scanReplicationTaskFrom(src replicationTaskScanner) (*ReplicationTask, error) {
	var task ReplicationTask
	var sourceDataset, targetHost, targetUser, targetDataset, targetSSHKeyPath sql.NullString
	var sourceHost, sourceUser, localDataset sql.NullString
	var targetPort, sourcePort sql.NullInt64
	var scheduleType, scheduleConfigJSON sql.NullString
	var lastRunAt, nextRunAt sql.NullTime
	var lastRunStatus, lastRunError, lastSnapshotSent, lastSnapshotReceived sql.NullString
	var createdBy sql.NullString

	// Destination order must match replicationTaskColumns exactly.
	if err := src.Scan(
		&task.ID, &task.Name, &task.Direction,
		&sourceDataset, &targetHost, &targetPort, &targetUser, &targetDataset, &targetSSHKeyPath,
		&sourceHost, &sourcePort, &sourceUser, &localDataset,
		&scheduleType, &scheduleConfigJSON,
		&task.Compression, &task.Encryption, &task.Recursive, &task.Incremental,
		&task.AutoSnapshot, &task.Enabled, &task.Status,
		&lastRunAt, &lastRunStatus, &lastRunError, &nextRunAt,
		&lastSnapshotSent, &lastSnapshotReceived,
		&task.TotalRuns, &task.SuccessfulRuns, &task.FailedRuns,
		&task.BytesSent, &task.BytesReceived,
		&createdBy, &task.CreatedAt, &task.UpdatedAt,
	); err != nil {
		// Unwrapped so callers can still recognise sql.ErrNoRows.
		return nil, err
	}

	optString := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		v := ns.String
		return &v
	}
	optInt := func(ni sql.NullInt64) *int {
		if !ni.Valid {
			return nil
		}
		v := int(ni.Int64)
		return &v
	}
	optTime := func(nt sql.NullTime) *time.Time {
		if !nt.Valid {
			return nil
		}
		v := nt.Time
		return &v
	}

	task.SourceDataset = optString(sourceDataset)
	task.TargetHost = optString(targetHost)
	task.TargetPort = optInt(targetPort)
	task.TargetUser = optString(targetUser)
	task.TargetDataset = optString(targetDataset)
	task.TargetSSHKeyPath = optString(targetSSHKeyPath)
	task.SourceHost = optString(sourceHost)
	task.SourcePort = optInt(sourcePort)
	task.SourceUser = optString(sourceUser)
	task.LocalDataset = optString(localDataset)
	task.ScheduleType = optString(scheduleType)
	task.LastRunAt = optTime(lastRunAt)
	task.LastRunStatus = optString(lastRunStatus)
	task.LastRunError = optString(lastRunError)
	task.NextRunAt = optTime(nextRunAt)
	task.LastSnapshotSent = optString(lastSnapshotSent)
	task.LastSnapshotReceived = optString(lastSnapshotReceived)
	// NullString.String is "" for NULL, matching the field's zero value.
	task.CreatedBy = createdBy.String

	if scheduleConfigJSON.Valid {
		if err := json.Unmarshal([]byte(scheduleConfigJSON.String), &task.ScheduleConfig); err != nil {
			return nil, fmt.Errorf("failed to unmarshal schedule config: %w", err)
		}
	}
	return &task, nil
}