package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"strings"
	"time"
)

// imageHTTPClient is the client used for image downloads. It deliberately has
// no overall Timeout, because streaming a large image can legitimately take a
// long time. It does bound the TLS handshake and the wait for response
// headers, so an unresponsive server fails the attempt (and triggers a retry)
// instead of blocking forever. A body that stalls mid-stream is not bounded.
var imageHTTPClient = &http.Client{
	Transport: &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   30 * time.Second,
		ResponseHeaderTimeout: 60 * time.Second,
	},
}

// isURL reports whether location should be fetched over HTTP(S) rather than
// read from the local filesystem. The scheme match is case-insensitive, as
// URI schemes are (RFC 3986 section 3.1).
func isURL(location string) bool {
	lower := strings.ToLower(location)
	return strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://")
}

// downloadOrCopyImage stages the image named by config.ImageURL under /tmp.
// HTTP(S) URLs are downloaded; anything else is treated as a local path.
// On success config.ImageURL is rewritten to the staged /tmp path.
func downloadOrCopyImage(config *Config) error {
	if isURL(config.ImageURL) {
		return downloadImage(config)
	}
	return handleLocalImage(config)
}

// handleLocalImage verifies the local image at config.ImageURL and makes it
// available as /tmp/<basename>. It copies the image there unless /tmp already
// holds either the very same file or a same-sized copy that passes
// verifyImage. On success config.ImageURL is set to the /tmp path.
func handleLocalImage(config *Config) error {
	localPath := config.ImageURL

	srcInfo, err := os.Stat(localPath)
	if os.IsNotExist(err) {
		return fmt.Errorf("local image file not found: %s", localPath)
	}
	if err != nil {
		return fmt.Errorf("cannot access local image %s: %w", localPath, err)
	}

	fmt.Printf("Using local image: %s\n", localPath)
	if err := verifyImage(localPath); err != nil {
		return fmt.Errorf("local image verification failed: %w", err)
	}
	fmt.Println("Local image verification passed!")

	tmpPath := filepath.Join("/tmp", filepath.Base(localPath))

	cachedInfo, statErr := os.Stat(tmpPath)
	switch {
	case statErr != nil:
		// Nothing usable at the destination yet.
		fmt.Printf("Copying image to /tmp...\n")
		if err := copyFile(localPath, tmpPath); err != nil {
			return fmt.Errorf("failed to copy image to /tmp: %w", err)
		}
		fmt.Println("Image copied successfully!")

	case os.SameFile(srcInfo, cachedInfo):
		// The source already is the /tmp file. The path may be spelled
		// differently, or reached through a symlink. copyFile would truncate it
		// (os.Create) before reading it, so use the file in place.

	default:
		fmt.Printf("Image already exists in /tmp, verifying...\n")
		// A same-named file from an earlier run may be a different image. A size
		// mismatch is a cheap signal that it is stale.
		switch {
		case cachedInfo.Size() != srcInfo.Size():
			fmt.Println("Cached image size differs from source, copying fresh...")
		case verifyImage(tmpPath) != nil:
			fmt.Println("Cached image corrupted, copying fresh...")
		default:
			fmt.Println("Using cached image from /tmp")
			config.ImageURL = tmpPath
			return nil
		}
		if err := copyFile(localPath, tmpPath); err != nil {
			return fmt.Errorf("failed to copy image to /tmp: %w", err)
		}
	}

	config.ImageURL = tmpPath
	return nil
}

// copyFile copies src to dst, creating or truncating dst. It prints a
// "Progress: N%" line each time another 10% of src has been written.
// It reports errors from closing dst. On any failure it removes dst, so a
// partial copy is never left behind to be mistaken for a cached image.
func copyFile(src, dst string) (err error) {
	sourceFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	fileInfo, err := sourceFile.Stat()
	if err != nil {
		return err
	}
	fileSize := fileInfo.Size()

	destFile, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := destFile.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			_ = os.Remove(dst) // best-effort cleanup; the copy error is what matters
		}
	}()

	fmt.Printf("Copying %s (%.2f MB)...\n", filepath.Base(src), float64(fileSize)/(1024*1024))

	progress := &copyProgressReporter{total: fileSize}
	if _, err = io.Copy(io.MultiWriter(destFile, progress), sourceFile); err != nil {
		return err
	}
	// Avoid printing 100% a second time if the reporter already did.
	if progress.last < 100 {
		fmt.Println("Progress: 100%")
	}
	return nil
}

// copyProgressReporter is an io.Writer that counts bytes written through it.
// Whenever the percentage of total reaches at least 10 points above the last
// one reported, it prints "Progress: N%". It never fails. When total <= 0,
// nothing is printed.
type copyProgressReporter struct {
	total   int64 // expected byte count
	written int64 // bytes seen so far
	last    int   // last percentage printed
}

func (p *copyProgressReporter) Write(b []byte) (int, error) {
	p.written += int64(len(b))
	if p.total > 0 {
		pct := int(float64(p.written) / float64(p.total) * 100)
		if pct >= p.last+10 {
			fmt.Printf("Progress: %d%%\n", pct)
			p.last = pct
		}
	}
	return len(b), nil
}

// downloadImage downloads config.ImageURL to /tmp/<name>, with <name> taken
// from getFilenameFromURL, and points config.ImageURL at the local copy.
// An existing file at that path is reused if it passes verifyImage.
//
// Each attempt streams into "<dest>.part" and is renamed onto dest only after
// it verifies. An interrupted download therefore never leaves a truncated
// file at the final path. Up to three attempts are made, 2s apart. Failure to
// create the .part file or to rename it is returned immediately, not retried.
func downloadImage(config *Config) error {
	dest := filepath.Join("/tmp", getFilenameFromURL(config.ImageURL))

	if _, err := os.Stat(dest); err == nil {
		fmt.Printf("Image already exists at %s\n", dest)
		verr := verifyImage(dest)
		if verr == nil {
			fmt.Println("Image verification passed, skipping download")
			config.ImageURL = dest
			return nil
		}
		fmt.Printf("Image verification failed: %v\n", verr)
		fmt.Println("Removing corrupted image and re-downloading...")
		_ = os.Remove(dest) // best effort; a verified download is renamed over it anyway
	}

	partPath := dest + ".part"
	const maxRetries = 3
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			fmt.Printf("\nRetry attempt %d/%d...\n", attempt, maxRetries)
			time.Sleep(2 * time.Second)
		}
		last := attempt == maxRetries

		fmt.Printf("Downloading image from %s...\n", config.ImageURL)
		resp, err := imageHTTPClient.Get(config.ImageURL)
		if err != nil {
			if last {
				return fmt.Errorf("failed to download image after %d attempts: %w", maxRetries, err)
			}
			fmt.Printf("Download failed: %v\n", err)
			continue
		}

		// Bodies and files are closed explicitly in this loop. A defer here
		// would only fire when downloadImage returns, holding every failed
		// attempt's connection and file open across retries.
		if resp.StatusCode != http.StatusOK {
			resp.Body.Close()
			if last {
				return fmt.Errorf("bad status: %s", resp.Status)
			}
			fmt.Printf("Bad status: %s\n", resp.Status)
			continue
		}

		out, err := os.Create(partPath)
		if err != nil {
			resp.Body.Close()
			return fmt.Errorf("failed to create file: %w", err)
		}

		counter := &WriteCounter{Total: resp.ContentLength}
		_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
		resp.Body.Close()
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			_ = os.Remove(partPath) // best-effort cleanup of the partial file
			if last {
				return fmt.Errorf("failed to save image after %d attempts: %w", maxRetries, err)
			}
			fmt.Printf("\nDownload failed: %v\n", err)
			continue
		}

		fmt.Println("\nDownload completed! Verifying image integrity...")
		if err := verifyImage(partPath); err != nil {
			_ = os.Remove(partPath) // best-effort cleanup of the bad download
			if last {
				return fmt.Errorf("image verification failed after %d attempts: %w", maxRetries, err)
			}
			fmt.Printf("Verification failed: %v\n", err)
			continue
		}

		if err := os.Rename(partPath, dest); err != nil {
			_ = os.Remove(partPath)
			return fmt.Errorf("failed to move verified image to %s: %w", dest, err)
		}
		fmt.Println("Image verification passed!")
		config.ImageURL = dest
		return nil
	}

	// Not reached in practice: every failure path on the final attempt returns.
	return fmt.Errorf("failed to download and verify image after %d attempts", maxRetries)
}

// verifyImage runs `qemu-img check` on imagePath. If that succeeds, it
// computes and prints the file's SHA-256. The checksum is informational only
// and is not compared against any expected value. A nil error therefore means
// qemu-img check exited successfully and the whole file could be read.
func verifyImage(imagePath string) error {
	fmt.Printf("Checking image with qemu-img...\n")
	output, err := exec.Command("qemu-img", "check", imagePath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("qemu-img check failed: %w\nOutput: %s", err, string(output))
	}

	fmt.Printf("Computing SHA256 checksum...\n")
	file, err := os.Open(imagePath)
	if err != nil {
		return fmt.Errorf("failed to open file for checksum: %w", err)
	}
	defer file.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return fmt.Errorf("failed to compute checksum: %w", err)
	}
	fmt.Printf("SHA256: %s\n", fmt.Sprintf("%x", hash.Sum(nil)))
	return nil
}

// getFilenameFromURL returns a safe local file name for rawURL: the last
// element of the URL's path, with the query and fragment excluded.
// If that element is empty, "/", "." or ".." (which would make the caller's
// filepath.Join resolve to /tmp itself or escape it), it returns
// "downloaded-image" instead.
func getFilenameFromURL(rawURL string) string {
	p := rawURL
	if u, err := url.Parse(rawURL); err == nil {
		p = u.Path
	}
	name := path.Base(p)
	switch name {
	case "", "/", ".", "..":
		return "downloaded-image"
	}
	return name
}

// WriteCounter is an io.Writer that tallies bytes passing through it, for use
// with io.TeeReader, and redraws a one-line download progress indicator on
// every write. Total is the expected size; if it is <= 0, only the downloaded
// amount is shown.
type WriteCounter struct {
	Total      int64
	Downloaded int64
}

// Write records len(p) bytes and redraws the progress line. It never fails.
func (wc *WriteCounter) Write(p []byte) (int, error) {
	n := len(p)
	wc.Downloaded += int64(n)
	wc.printProgress()
	return n, nil
}

// printProgress rewrites the current terminal line (via "\r") with the
// download percentage and whole-MB counts.
func (wc *WriteCounter) printProgress() {
	fmt.Printf("\r")
	if wc.Total > 0 {
		percent := float64(wc.Downloaded) / float64(wc.Total) * 100
		fmt.Printf("Downloading... %.0f%% (%d/%d MB)", percent, wc.Downloaded/1024/1024, wc.Total/1024/1024)
	} else {
		fmt.Printf("Downloading... %d MB", wc.Downloaded/1024/1024)
	}
}