277 lines
6.5 KiB
Go
277 lines
6.5 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// isURL reports whether path starts with an "http://" or "https://" prefix.
// The comparison is case-sensitive; anything else is treated as a local path
// by callers.
func isURL(path string) bool {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(path, scheme) {
			return true
		}
	}
	return false
}
|
|
|
|
func downloadOrCopyImage(config *Config) error {
|
|
// Check if it's a local file
|
|
if !isURL(config.ImageURL) {
|
|
return handleLocalImage(config)
|
|
}
|
|
|
|
// It's a URL, download it
|
|
return downloadImage(config)
|
|
}
|
|
|
|
func handleLocalImage(config *Config) error {
|
|
localPath := config.ImageURL
|
|
|
|
// Check if file exists
|
|
if _, err := os.Stat(localPath); os.IsNotExist(err) {
|
|
return fmt.Errorf("local image file not found: %s", localPath)
|
|
}
|
|
|
|
fmt.Printf("Using local image: %s\n", localPath)
|
|
|
|
// Verify image
|
|
if err := verifyImage(localPath); err != nil {
|
|
return fmt.Errorf("local image verification failed: %w", err)
|
|
}
|
|
|
|
fmt.Println("Local image verification passed!")
|
|
|
|
// Check if it's already in /tmp, if not copy it
|
|
filename := filepath.Base(localPath)
|
|
tmpPath := filepath.Join("/tmp", filename)
|
|
|
|
absLocal, _ := filepath.Abs(localPath)
|
|
absTmp, _ := filepath.Abs(tmpPath)
|
|
|
|
if absLocal != absTmp {
|
|
// Check if already exists in /tmp
|
|
if _, err := os.Stat(tmpPath); err == nil {
|
|
fmt.Printf("Image already exists in /tmp, verifying...\n")
|
|
if err := verifyImage(tmpPath); err != nil {
|
|
fmt.Println("Cached image corrupted, copying fresh...")
|
|
if err := copyFile(localPath, tmpPath); err != nil {
|
|
return fmt.Errorf("failed to copy image to /tmp: %w", err)
|
|
}
|
|
} else {
|
|
fmt.Println("Using cached image from /tmp")
|
|
}
|
|
} else {
|
|
fmt.Printf("Copying image to /tmp...\n")
|
|
if err := copyFile(localPath, tmpPath); err != nil {
|
|
return fmt.Errorf("failed to copy image to /tmp: %w", err)
|
|
}
|
|
fmt.Println("Image copied successfully!")
|
|
}
|
|
config.ImageURL = tmpPath
|
|
} else {
|
|
// Already in /tmp, ensure config path is normalized
|
|
config.ImageURL = tmpPath
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// copyFile copies the contents of src to dst, creating dst or truncating an
// existing file. It prints the file size and coarse progress (every 10%,
// then a final 100%) to stdout.
//
// It refuses to copy a file onto itself. Once dst has been created, any
// failure (including an error from the final Close) removes the partial dst
// and is returned to the caller.
func copyFile(src, dst string) (err error) {
	sourceFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	srcInfo, err := sourceFile.Stat()
	if err != nil {
		return err
	}
	// os.Create truncates, so copying a file onto itself (same path, a
	// symlink or a hard link) would destroy the source. A Stat error just
	// means dst does not exist yet or is inaccessible; os.Create reports
	// the latter below.
	if dstInfo, statErr := os.Stat(dst); statErr == nil && os.SameFile(srcInfo, dstInfo) {
		return fmt.Errorf("source and destination are the same file: %s", dst)
	}
	fileSize := srcInfo.Size()

	destFile, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// Close can report deferred write errors, so it decides success just
		// like the writes do.
		if closeErr := destFile.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			// Best effort: don't leave a truncated copy that a later run
			// could mistake for a complete one.
			os.Remove(dst)
		}
	}()

	fmt.Printf("Copying %s (%.2f MB)...\n", filepath.Base(src), float64(fileSize)/(1024*1024))

	buf := make([]byte, 1024*1024) // 1MB buffer
	var written int64
	lastProgress := 0

	for {
		nr, readErr := sourceFile.Read(buf)
		if nr > 0 {
			nw, writeErr := destFile.Write(buf[:nr])
			if writeErr != nil {
				return writeErr
			}
			if nw != nr {
				return io.ErrShortWrite
			}
			written += int64(nw)

			// Report every 10%; 100% is printed exactly once after the loop.
			if fileSize > 0 {
				progress := int(float64(written) / float64(fileSize) * 100)
				if progress >= lastProgress+10 && progress < 100 {
					fmt.Printf("Progress: %d%%\n", progress)
					lastProgress = progress
				}
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}

	fmt.Println("Progress: 100%")
	return nil
}
|
|
|
|
func downloadImage(config *Config) error {
|
|
filename := getFilenameFromURL(config.ImageURL)
|
|
filepath := filepath.Join("/tmp", filename)
|
|
|
|
if _, err := os.Stat(filepath); err == nil {
|
|
fmt.Printf("Image already exists at %s\n", filepath)
|
|
|
|
if err := verifyImage(filepath); err != nil {
|
|
fmt.Printf("Image verification failed: %v\n", err)
|
|
fmt.Println("Removing corrupted image and re-downloading...")
|
|
os.Remove(filepath)
|
|
} else {
|
|
fmt.Println("Image verification passed, skipping download")
|
|
config.ImageURL = filepath
|
|
return nil
|
|
}
|
|
}
|
|
|
|
maxRetries := 3
|
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
|
if attempt > 1 {
|
|
fmt.Printf("\nRetry attempt %d/%d...\n", attempt, maxRetries)
|
|
time.Sleep(time.Second * 2)
|
|
}
|
|
|
|
fmt.Printf("Downloading image from %s...\n", config.ImageURL)
|
|
|
|
resp, err := http.Get(config.ImageURL)
|
|
if err != nil {
|
|
if attempt == maxRetries {
|
|
return fmt.Errorf("failed to download image after %d attempts: %w", maxRetries, err)
|
|
}
|
|
fmt.Printf("Download failed: %v\n", err)
|
|
continue
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
if attempt == maxRetries {
|
|
return fmt.Errorf("bad status: %s", resp.Status)
|
|
}
|
|
fmt.Printf("Bad status: %s\n", resp.Status)
|
|
continue
|
|
}
|
|
|
|
out, err := os.Create(filepath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create file: %w", err)
|
|
}
|
|
defer out.Close()
|
|
|
|
counter := &WriteCounter{Total: resp.ContentLength}
|
|
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
|
if err != nil {
|
|
out.Close()
|
|
os.Remove(filepath)
|
|
if attempt == maxRetries {
|
|
return fmt.Errorf("failed to save image after %d attempts: %w", maxRetries, err)
|
|
}
|
|
fmt.Printf("\nDownload failed: %v\n", err)
|
|
continue
|
|
}
|
|
out.Close()
|
|
|
|
fmt.Println("\nDownload completed! Verifying image integrity...")
|
|
|
|
if err := verifyImage(filepath); err != nil {
|
|
os.Remove(filepath)
|
|
if attempt == maxRetries {
|
|
return fmt.Errorf("image verification failed after %d attempts: %w", maxRetries, err)
|
|
}
|
|
fmt.Printf("Verification failed: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
fmt.Println("Image verification passed!")
|
|
config.ImageURL = filepath
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("failed to download and verify image after %d attempts", maxRetries)
|
|
}
|
|
|
|
// verifyImage checks the image at imagePath in two steps:
//  1. It runs `qemu-img check` and fails if the command exits unsuccessfully
//     (or cannot be started). The command's combined output is included in
//     the error.
//  2. It reads the whole file and prints its SHA-256 digest. The digest is
//     only reported, not compared against an expected value.
//
// Both steps log a line to stdout before they start.
func verifyImage(imagePath string) error {
	fmt.Print("Checking image with qemu-img...\n")
	if output, err := exec.Command("qemu-img", "check", imagePath).CombinedOutput(); err != nil {
		return fmt.Errorf("qemu-img check failed: %w\nOutput: %s", err, string(output))
	}

	fmt.Print("Computing SHA256 checksum...\n")
	f, err := os.Open(imagePath)
	if err != nil {
		return fmt.Errorf("failed to open file for checksum: %w", err)
	}
	defer f.Close()

	digest := sha256.New()
	if _, err = io.Copy(digest, f); err != nil {
		return fmt.Errorf("failed to compute checksum: %w", err)
	}

	fmt.Printf("SHA256: %x\n", digest.Sum(nil))
	return nil
}
|
|
|
|
// getFilenameFromURL derives a local file name from the last path segment of
// rawURL. Any query string or fragment is ignored and trailing slashes are
// trimmed, so "https://host/img.qcow2?token=x" yields "img.qcow2".
//
// The result is joined onto a directory by downloadImage. A segment that is
// empty or would navigate (".", "..") is therefore replaced with
// "downloaded-image"; previously such URLs resolved to the directory itself
// or its parent.
func getFilenameFromURL(rawURL string) string {
	// Everything from the first '?' or '#' onward is query/fragment, not path.
	if i := strings.IndexAny(rawURL, "?#"); i >= 0 {
		rawURL = rawURL[:i]
	}
	rawURL = strings.TrimRight(rawURL, "/")

	name := rawURL[strings.LastIndex(rawURL, "/")+1:]
	switch name {
	case "", ".", "..":
		return "downloaded-image"
	}
	return name
}
|
|
|
|
// WriteCounter is an io.Writer that only counts the bytes written to it and
// redraws a one-line progress indicator on stdout after every write. It is
// meant to sit on the side of a copy, e.g. via io.TeeReader.
type WriteCounter struct {
	// Total is the expected byte count (values <= 0 are shown as unknown);
	// Downloaded is the running count of bytes seen so far.
	Total, Downloaded int64
}

// Write records len(p) bytes and refreshes the progress line. It never
// fails and always reports the whole slice as consumed.
func (wc *WriteCounter) Write(p []byte) (int, error) {
	wc.Downloaded += int64(len(p))
	wc.printProgress()
	return len(p), nil
}

// printProgress rewrites the current terminal line (leading carriage return)
// with the percentage and whole-MiB counts when Total is known, or just the
// MiB downloaded so far otherwise. No newline is printed.
func (wc *WriteCounter) printProgress() {
	doneMB := wc.Downloaded / 1024 / 1024
	if wc.Total <= 0 {
		fmt.Printf("\rDownloading... %d MB", doneMB)
		return
	}
	pct := float64(wc.Downloaded) / float64(wc.Total) * 100
	fmt.Printf("\rDownloading... %.0f%% (%d/%d MB)", pct, doneMB, wc.Total/1024/1024)
}
|