+# WEBHOOK_PUSHOVER_PRIORITY=0
+
# ----------------------------------------------------------------------
# Metriche / Prometheus
# ----------------------------------------------------------------------
diff --git a/internal/environment/detect.go b/internal/environment/detect.go
index d734702b..a6fd841d 100644
--- a/internal/environment/detect.go
+++ b/internal/environment/detect.go
@@ -11,6 +11,7 @@ import (
"strings"
"time"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"github.com/tis24dev/proxsave/internal/types"
)
@@ -46,8 +47,7 @@ var (
"/etc/apt/sources.list.d/proxmox.list",
}
- lookPathFunc = exec.LookPath
- commandContextFunc = exec.CommandContext
+ lookPathFunc = exec.LookPath
readFileFunc = os.ReadFile
statFunc = os.Stat
@@ -341,7 +341,10 @@ func runCommand(command string, args ...string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), commandTimeout)
defer cancel()
- cmd := commandContextFunc(ctx, command, args...)
+ cmd, cmdErr := safeexec.TrustedCommandContext(ctx, command, args...)
+ if cmdErr != nil {
+ return "", cmdErr
+ }
output, err := cmd.Output()
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("command %s timed out", command)
diff --git a/internal/environment/detect_additional_test.go b/internal/environment/detect_additional_test.go
index 11e5c583..d8d347c9 100644
--- a/internal/environment/detect_additional_test.go
+++ b/internal/environment/detect_additional_test.go
@@ -3,6 +3,7 @@ package environment
import (
"context"
"os"
+ "os/exec"
"path/filepath"
"strings"
"testing"
@@ -151,7 +152,11 @@ func TestContainsAny(t *testing.T) {
// TestRunCommand tests command execution with timeout
func TestRunCommand(t *testing.T) {
// Test successful command
- output, err := runCommand("echo", "test")
+ echoPath, err := exec.LookPath("echo")
+ if err != nil {
+ t.Fatalf("LookPath(echo) failed: %v", err)
+ }
+ output, err := runCommand(echoPath, "test")
if err != nil {
t.Errorf("runCommand() error = %v", err)
}
diff --git a/internal/identity/identity.go b/internal/identity/identity.go
index f8f0c7e3..755dc24b 100644
--- a/internal/identity/identity.go
+++ b/internal/identity/identity.go
@@ -19,6 +19,7 @@ import (
"time"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
)
const (
@@ -969,7 +970,11 @@ func setImmutableAttributeWithContext(ctx context.Context, path string, enable b
flag = "-i"
}
- cmd := exec.CommandContext(ctx, chattrPath, flag, path)
+ cmd, err := safeexec.TrustedCommandContext(ctx, chattrPath, flag, path)
+ if err != nil {
+ logDebug(logger, "Identity: immutable: chattr path rejected for %s: %v", path, err)
+ return nil
+ }
if err := cmd.Run(); err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
logDebug(logger, "Identity: immutable: chattr canceled for %s: %v", path, ctxErr)
diff --git a/internal/notify/email.go b/internal/notify/email.go
index 8dfa0082..da44380f 100644
--- a/internal/notify/email.go
+++ b/internal/notify/email.go
@@ -1,11 +1,13 @@
package notify
import (
+ "bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
+ "mime/quotedprintable"
"os"
"os/exec"
"path/filepath"
@@ -15,6 +17,7 @@ import (
"time"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"github.com/tis24dev/proxsave/internal/types"
)
@@ -656,7 +659,10 @@ func (e *EmailNotifier) detectRecipientViaUserCfg(cfgPath string, targetUserID s
}
func runCombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) {
- cmd := exec.CommandContext(ctx, name, args...)
+ cmd, err := safeexec.CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
out, err := cmd.CombinedOutput()
if err != nil {
return out, err
@@ -671,6 +677,37 @@ func truncateForLog(s string, maxBytes int) string {
return s[:maxBytes] + "...(truncated)"
}
+func commandForMailTool(ctx context.Context, pathOrName string, args ...string) (*exec.Cmd, error) {
+ if filepath.IsAbs(pathOrName) {
+ return safeexec.TrustedCommandContext(ctx, pathOrName, args...)
+ }
+ return safeexec.CommandContext(ctx, pathOrName, args...)
+}
+
+func lookupAbsolutePath(name string) (string, error) {
+ execPath, err := exec.LookPath(name)
+ if err != nil {
+ return "", err
+ }
+ if filepath.IsAbs(execPath) {
+ return execPath, nil
+ }
+ return "", fmt.Errorf("exec.LookPath returned non-absolute path %q", execPath)
+}
+
+func findMailqPath() (string, error) {
+ candidates := []string{"mailq", "/usr/bin/mailq"}
+ errs := make([]error, 0, len(candidates))
+ for _, candidate := range candidates {
+ path, err := lookupAbsolutePath(candidate)
+ if err == nil {
+ return path, nil
+ }
+ errs = append(errs, fmt.Errorf("%s: %w", candidate, err))
+ }
+ return "", fmt.Errorf("mailq command not found: %w", errors.Join(errs...))
+}
+
// sendViaRelay sends email via cloud relay
func (e *EmailNotifier) sendViaRelay(ctx context.Context, recipient, subject, htmlBody, textBody string, data *NotificationData) error {
// Build payload
@@ -692,12 +729,16 @@ func (e *EmailNotifier) sendViaRelay(ctx context.Context, recipient, subject, ht
func (e *EmailNotifier) isMTAServiceActive(ctx context.Context) (bool, string) {
services := []string{"postfix", "sendmail", "exim4"}
- if _, err := exec.LookPath("systemctl"); err != nil {
+ systemctlPath, err := lookupAbsolutePath("systemctl")
+ if err != nil {
return false, "systemctl not available"
}
for _, service := range services {
- cmd := exec.CommandContext(ctx, "systemctl", "is-active", service)
+ cmd, err := safeexec.TrustedCommandContext(ctx, systemctlPath, "is-active", service)
+ if err != nil {
+ return false, err.Error()
+ }
if err := cmd.Run(); err == nil {
e.logger.Debug("MTA service %s is active", service)
return true, service
@@ -759,16 +800,15 @@ func (e *EmailNotifier) checkRelayHostConfigured(ctx context.Context) (bool, str
// checkMailQueue checks the mail queue status
func (e *EmailNotifier) checkMailQueue(ctx context.Context) (int, error) {
// Try mailq command (works for both Postfix and Sendmail)
- mailqPath := "/usr/bin/mailq"
- if _, err := exec.LookPath("mailq"); err != nil {
- if _, err := exec.LookPath(mailqPath); err != nil {
- return 0, fmt.Errorf("mailq command not found")
- }
- } else {
- mailqPath = "mailq"
+ mailqPath, err := findMailqPath()
+ if err != nil {
+ return 0, err
}
- cmd := exec.CommandContext(ctx, mailqPath)
+ cmd, err := commandForMailTool(ctx, mailqPath)
+ if err != nil {
+ return 0, err
+ }
output, err := cmd.Output()
if err != nil {
return 0, fmt.Errorf("mailq failed: %w", err)
@@ -803,14 +843,15 @@ func (e *EmailNotifier) checkMailQueue(ctx context.Context) (int, error) {
// detectQueueEntry scans the mail queue for a recipient and returns the latest queue ID.
func (e *EmailNotifier) detectQueueEntry(ctx context.Context, recipient string) (string, string, error) {
- mailqPath := "/usr/bin/mailq"
- if _, err := exec.LookPath("mailq"); err == nil {
- mailqPath = "mailq"
- } else if _, err := exec.LookPath(mailqPath); err != nil {
- return "", "", fmt.Errorf("mailq command not found")
+ mailqPath, err := findMailqPath()
+ if err != nil {
+ return "", "", err
}
- cmd := exec.CommandContext(ctx, mailqPath)
+ cmd, err := commandForMailTool(ctx, mailqPath)
+ if err != nil {
+ return "", "", err
+ }
output, err := cmd.Output()
if err != nil {
return "", "", fmt.Errorf("mailq failed: %w", err)
@@ -851,7 +892,10 @@ func (e *EmailNotifier) tailMailLog(ctx context.Context, maxLines int) ([]string
continue
}
- cmd := exec.CommandContext(ctx, "tail", "-n", strconv.Itoa(maxLines), logFile)
+ cmd, err := safeexec.CommandContext(ctx, "tail", "-n", strconv.Itoa(maxLines), logFile)
+ if err != nil {
+ continue
+ }
output, err := cmd.Output()
if err != nil {
if ctx.Err() != nil {
@@ -874,7 +918,10 @@ func (e *EmailNotifier) tailMailLog(ctx context.Context, maxLines int) ([]string
args = append(args, "-u", unit)
}
- cmd := exec.CommandContext(ctx, "journalctl", args...)
+ cmd, err := safeexec.CommandContext(ctx, "journalctl", args...)
+ if err != nil {
+ return nil, ""
+ }
output, err := cmd.Output()
if err == nil && len(output) > 0 {
lines := strings.Split(strings.TrimRight(string(output), "\n"), "\n")
@@ -1084,6 +1131,14 @@ func summarizeSendmailTranscript(transcript string) (highlights []string, remote
return highlights, remoteID, localQueueID
}
+func encodeQuotedPrintableBody(body string) string {
+ var encoded bytes.Buffer
+ writer := quotedprintable.NewWriter(&encoded)
+ _, _ = writer.Write([]byte(body))
+ _ = writer.Close()
+ return encoded.String()
+}
+
func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody string, data *NotificationData) (emailMessage, toHeader string) {
e.logger.Debug("=== Building email message ===")
@@ -1126,17 +1181,17 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody
// Plain text part
email.WriteString(fmt.Sprintf("--%s\n", altBoundary))
email.WriteString("Content-Type: text/plain; charset=UTF-8\n")
- email.WriteString("Content-Transfer-Encoding: 8bit\n")
+ email.WriteString("Content-Transfer-Encoding: quoted-printable\n")
email.WriteString("\n")
- email.WriteString(textBody)
+ email.WriteString(encodeQuotedPrintableBody(textBody))
email.WriteString("\n\n")
// HTML part
email.WriteString(fmt.Sprintf("--%s\n", altBoundary))
email.WriteString("Content-Type: text/html; charset=UTF-8\n")
- email.WriteString("Content-Transfer-Encoding: 8bit\n")
+ email.WriteString("Content-Transfer-Encoding: quoted-printable\n")
email.WriteString("\n")
- email.WriteString(htmlBody)
+ email.WriteString(encodeQuotedPrintableBody(htmlBody))
email.WriteString("\n\n")
email.WriteString(fmt.Sprintf("--%s--\n", altBoundary))
@@ -1178,17 +1233,17 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody
// Plain text part
email.WriteString(fmt.Sprintf("--%s\n", altBoundary))
email.WriteString("Content-Type: text/plain; charset=UTF-8\n")
- email.WriteString("Content-Transfer-Encoding: 8bit\n")
+ email.WriteString("Content-Transfer-Encoding: quoted-printable\n")
email.WriteString("\n")
- email.WriteString(textBody)
+ email.WriteString(encodeQuotedPrintableBody(textBody))
email.WriteString("\n\n")
// HTML part
email.WriteString(fmt.Sprintf("--%s\n", altBoundary))
email.WriteString("Content-Type: text/html; charset=UTF-8\n")
- email.WriteString("Content-Transfer-Encoding: 8bit\n")
+ email.WriteString("Content-Transfer-Encoding: quoted-printable\n")
email.WriteString("\n")
- email.WriteString(htmlBody)
+ email.WriteString(encodeQuotedPrintableBody(htmlBody))
email.WriteString("\n\n")
email.WriteString(fmt.Sprintf("--%s--\n", altBoundary))
@@ -1218,7 +1273,10 @@ func (e *EmailNotifier) sendViaPMF(ctx context.Context, recipient, subject, html
e.logger.Debug("=== Sending email via proxmox-mail-forward ===")
e.logger.Debug("proxmox-mail-forward routing is handled by Proxmox Notifications; To=%q is only a mail header", toHeader)
- cmd := exec.CommandContext(ctx, pmfPath)
+ cmd, err := commandForMailTool(ctx, pmfPath)
+ if err != nil {
+ return "", "", err
+ }
cmd.Stdin = strings.NewReader(emailMessage)
var stdoutBuf, stderrBuf strings.Builder
@@ -1329,7 +1387,10 @@ func (e *EmailNotifier) sendViaSendmail(ctx context.Context, recipient, subject,
}
// Create sendmail command
- cmd := exec.CommandContext(ctx, sendmailPath, args...)
+ cmd, err := commandForMailTool(ctx, sendmailPath, args...)
+ if err != nil {
+ return "", "", "", err
+ }
cmd.Stdin = strings.NewReader(emailMessage)
// Capture stdout and stderr separately
diff --git a/internal/notify/email_delivery_methods_test.go b/internal/notify/email_delivery_methods_test.go
index 41c42765..10a73fea 100644
--- a/internal/notify/email_delivery_methods_test.go
+++ b/internal/notify/email_delivery_methods_test.go
@@ -377,6 +377,43 @@ func TestEmailNotifierBuildEmailMessage_FallsBackWhenLogUnreadable(t *testing.T)
}
}
+func TestEmailNotifierBuildEmailMessage_EncodesUTF8BodiesAsSevenBitSafe(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ logger.SetOutput(io.Discard)
+
+ notifier, err := NewEmailNotifier(EmailConfig{
+ Enabled: true,
+ DeliveryMethod: EmailDeliveryPMF,
+ From: "no-reply@proxmox.example.com",
+ }, types.ProxmoxBS, logger)
+ if err != nil {
+ t.Fatalf("NewEmailNotifier() error = %v", err)
+ }
+
+ emailMessage, _ := notifier.buildEmailMessage(
+ "admin@example.com",
+ "✅ PVE Backup à",
+ "Backup completato ✅ con avvisi: è pieno
",
+ "Backup completato ✅ con avvisi: è pieno",
+ createTestNotificationData(),
+ )
+
+ if strings.Contains(emailMessage, "Content-Transfer-Encoding: 8bit") {
+ t.Fatalf("email message must not use 8bit transfer encoding:\n%s", emailMessage)
+ }
+ if count := strings.Count(emailMessage, "Content-Transfer-Encoding: quoted-printable"); count != 2 {
+ t.Fatalf("expected two quoted-printable body parts, got %d:\n%s", count, emailMessage)
+ }
+ if strings.Contains(emailMessage, "✅") || strings.Contains(emailMessage, "è") || strings.Contains(emailMessage, "à") {
+ t.Fatalf("email message contains raw non-ASCII body/subject characters:\n%s", emailMessage)
+ }
+ for i, b := range []byte(emailMessage) {
+ if b > 0x7f {
+ t.Fatalf("email message contains non-ASCII byte 0x%x at offset %d", b, i)
+ }
+ }
+}
+
func TestEmailNotifierIsMTAServiceActive_SystemctlMissing(t *testing.T) {
logger := logging.New(types.LogLevelDebug, false)
logger.SetOutput(io.Discard)
diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go
index 87e0f1c1..9a020bc3 100644
--- a/internal/notify/webhook.go
+++ b/internal/notify/webhook.go
@@ -24,6 +24,25 @@ type WebhookNotifier struct {
client *http.Client
}
+func resolveWebhookFormat(format, defaultFormat string) string {
+ format = strings.TrimSpace(format)
+ if format == "" {
+ format = strings.TrimSpace(defaultFormat)
+ }
+ if format == "" {
+ return "generic"
+ }
+ return format
+}
+
+func resolveWebhookMethod(method string) string {
+ method = strings.ToUpper(strings.TrimSpace(method))
+ if method == "" {
+ return http.MethodPost
+ }
+ return method
+}
+
// NewWebhookNotifier creates a new webhook notifier
func NewWebhookNotifier(webhookConfig *config.WebhookConfig, logger *logging.Logger) (*WebhookNotifier, error) {
logger.Debug("WebhookNotifier initialization starting...")
@@ -44,6 +63,11 @@ func NewWebhookNotifier(webhookConfig *config.WebhookConfig, logger *logging.Log
return nil, fmt.Errorf("webhook notifications enabled but no endpoints configured")
}
+ notifier := &WebhookNotifier{
+ config: webhookConfig,
+ logger: logger,
+ }
+
// Log each endpoint configuration (with masked sensitive data)
for i, ep := range webhookConfig.Endpoints {
logger.Debug("Endpoint #%d configuration:", i+1)
@@ -58,6 +82,9 @@ func NewWebhookNotifier(webhookConfig *config.WebhookConfig, logger *logging.Log
logger.Debug(" Header: %s (value masked)", k)
}
}
+ if err := notifier.validateEndpoint(ep); err != nil {
+ return nil, err
+ }
}
// Create HTTP client with timeout
@@ -74,11 +101,34 @@ func NewWebhookNotifier(webhookConfig *config.WebhookConfig, logger *logging.Log
logger.Info("✅ WebhookNotifier initialized successfully with %d endpoint(s)", len(webhookConfig.Endpoints))
- return &WebhookNotifier{
- config: webhookConfig,
- logger: logger,
- client: client,
- }, nil
+ notifier.client = client
+ return notifier, nil
+}
+
+func (w *WebhookNotifier) validateEndpoint(ep config.WebhookEndpoint) error {
+ format := resolveWebhookFormat(ep.Format, w.config.DefaultFormat)
+ method := resolveWebhookMethod(ep.Method)
+ if !strings.EqualFold(format, "pushover") {
+ return nil
+ }
+
+ missing := []string{}
+ if ep.Auth.Token == "" {
+ missing = append(missing, "token")
+ }
+ if ep.Auth.User == "" {
+ missing = append(missing, "user")
+ }
+ if len(missing) > 0 {
+ return fmt.Errorf("webhook endpoint %q: Pushover requires Auth.Token and Auth.User; missing %s", ep.Name, strings.Join(missing, "/"))
+ }
+ if ep.Priority < -2 || ep.Priority > 1 {
+ return fmt.Errorf("webhook endpoint %q: PRIORITY must be in range -2..1 (got %d); priority 2 (emergency) is not supported", ep.Name, ep.Priority)
+ }
+ if method != http.MethodPost {
+ return fmt.Errorf("webhook endpoint %q: METHOD must be POST for pushover (got %s)", ep.Name, method)
+ }
+ return nil
}
// Name returns the notifier name
@@ -164,21 +214,29 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We
w.logger.Debug("Endpoint format: %s, URL: %s", endpoint.Format, maskURL(endpoint.URL))
// Determine format to use
- format := endpoint.Format
- if format == "" {
- format = w.config.DefaultFormat
- w.logger.Debug("Using default format: %s", format)
+ format := resolveWebhookFormat(endpoint.Format, w.config.DefaultFormat)
+ if strings.TrimSpace(endpoint.Format) == "" {
+ if strings.TrimSpace(w.config.DefaultFormat) != "" {
+ w.logger.Debug("Using default format: %s", format)
+ } else {
+ w.logger.Debug("No format specified, using generic")
+ }
}
- if format == "" {
- format = "generic"
- w.logger.Debug("No format specified, using generic")
+
+ method := resolveWebhookMethod(endpoint.Method)
+ if strings.TrimSpace(endpoint.Method) == "" {
+ w.logger.Debug("No method specified, using POST")
+ }
+ if strings.EqualFold(format, "pushover") && method != http.MethodPost {
+ return fmt.Errorf("webhook endpoint %q: METHOD must be POST for pushover (got %s)", endpoint.Name, method)
}
// Build payload based on format
w.logger.Debug("Building %s payload...", format)
payloadStart := time.Now()
- payload, err := w.buildPayload(format, data)
+ endpoint.Format = format
+ payload, err := w.buildPayload(endpoint, data)
if err != nil {
w.logger.Error("Failed to build %s payload: %v", format, err)
return fmt.Errorf("failed to build payload: %w", err)
@@ -197,7 +255,9 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We
w.logger.Debug("Payload marshaled: %d bytes", len(payloadBytes))
if w.logger.GetLevel() <= types.LogLevelDebug {
- if len(payloadBytes) > 200 {
+ if strings.EqualFold(format, "pushover") {
+ w.logger.Debug("Payload preview omitted: pushover payload contains credentials")
+ } else if len(payloadBytes) > 200 {
w.logger.Debug("Payload preview (first 200 chars): %s...", string(payloadBytes[:200]))
} else {
w.logger.Debug("Payload content: %s", string(payloadBytes))
@@ -229,12 +289,6 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We
}
}
- // Determine HTTP method
- method := strings.ToUpper(strings.TrimSpace(endpoint.Method))
- if method == "" {
- method = "POST"
- }
-
parsedURL, parseErr := url.Parse(endpoint.URL)
if parseErr != nil {
lastErr = fmt.Errorf("invalid webhook URL: %w", parseErr)
@@ -415,16 +469,19 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We
}
// buildPayload builds the webhook payload based on format
-func (w *WebhookNotifier) buildPayload(format string, data *NotificationData) (interface{}, error) {
+func (w *WebhookNotifier) buildPayload(endpoint config.WebhookEndpoint, data *NotificationData) (interface{}, error) {
+ format := strings.ToLower(endpoint.Format)
w.logger.Debug("buildPayload() called with format=%s", format)
- switch strings.ToLower(format) {
+ switch format {
case "discord":
return buildDiscordPayload(data, w.logger)
case "slack":
return buildSlackPayload(data, w.logger)
case "teams":
return buildTeamsPayload(data, w.logger)
+ case "pushover":
+ return buildPushoverPayload(endpoint, data, w.logger)
case "generic":
return buildGenericPayload(data, w.logger)
default:
diff --git a/internal/notify/webhook_payloads.go b/internal/notify/webhook_payloads.go
index 1c4530ff..7436d04c 100644
--- a/internal/notify/webhook_payloads.go
+++ b/internal/notify/webhook_payloads.go
@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
+ "github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/logging"
)
@@ -575,3 +576,53 @@ func buildGenericPayload(data *NotificationData, logger *logging.Logger) (map[st
logger.Debug("Generic payload built successfully with %d top-level keys", len(payload))
return payload, nil
}
+
+// buildPushoverPayload builds a Pushover-formatted webhook payload.
+// Pushover requires the application token and user/group key in the JSON body
+// (not in headers); this builder reads them from endpoint.Auth.Token and
+// endpoint.Auth.User and rejects requests where either is missing.
+func buildPushoverPayload(endpoint config.WebhookEndpoint, data *NotificationData, logger *logging.Logger) (map[string]interface{}, error) {
+ logger.Debug("buildPushoverPayload() starting...")
+
+ if endpoint.Auth.Token == "" {
+ return nil, fmt.Errorf("pushover: AUTH_TOKEN (Pushover application token) is required")
+ }
+ if endpoint.Auth.User == "" {
+ return nil, fmt.Errorf("pushover: AUTH_USER (Pushover user/group key) is required")
+ }
+
+ title := truncateRunes(fmt.Sprintf("%s Proxmox Backup — %s", GetStatusEmoji(data.Status), data.Hostname), 250)
+
+ message := truncateRunes(fmt.Sprintf(
+ "Status: %s\nDuration: %s\nSize: %s\nErrors: %d | Warnings: %d",
+ data.StatusMessage,
+ FormatDuration(data.BackupDuration),
+ data.BackupSizeHR,
+ data.ErrorCount,
+ data.WarningCount,
+ ), 1024)
+
+ payload := map[string]interface{}{
+ "token": endpoint.Auth.Token,
+ "user": endpoint.Auth.User,
+ "title": title,
+ "message": message,
+ "priority": endpoint.Priority,
+ }
+
+ logger.Debug("Pushover payload built (priority=%d, title_len=%d, message_len=%d)", endpoint.Priority, len([]rune(title)), len([]rune(message)))
+ return payload, nil
+}
+
+// truncateRunes shortens s to at most max runes, suffixing with "…" when cut.
+// Operates on runes (not bytes) so multibyte characters like emoji are not split.
+func truncateRunes(s string, max int) string {
+ if max <= 0 {
+ return ""
+ }
+ r := []rune(s)
+ if len(r) <= max {
+ return s
+ }
+ return string(r[:max-1]) + "…"
+}
diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go
index 85503332..da918043 100644
--- a/internal/notify/webhook_test.go
+++ b/internal/notify/webhook_test.go
@@ -658,7 +658,8 @@ func TestWebhookNotifier_buildPayload_CoversFormats(t *testing.T) {
for _, format := range formats {
format := format
t.Run(format, func(t *testing.T) {
- payload, err := notifier.buildPayload(format, data)
+ ep := config.WebhookEndpoint{Name: "x", URL: "https://example.com", Format: format}
+ payload, err := notifier.buildPayload(ep, data)
if err != nil {
t.Fatalf("buildPayload(%q) error = %v", format, err)
}
@@ -1070,3 +1071,290 @@ func TestMaskHeaderValue(t *testing.T) {
})
}
}
+
+func pushoverTestEndpoint(priority int) config.WebhookEndpoint {
+ return config.WebhookEndpoint{
+ Name: "pushover",
+ URL: "https://api.pushover.net/1/messages.json",
+ Format: "pushover",
+ Method: "POST",
+ Auth: config.WebhookAuth{Type: "none", Token: "app-token-abc", User: "user-key-xyz"},
+ Priority: priority,
+ }
+}
+
+func TestBuildPushoverPayload_Success(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+
+ payload, err := buildPushoverPayload(pushoverTestEndpoint(0), data, logger)
+ if err != nil {
+ t.Fatalf("buildPushoverPayload() error: %v", err)
+ }
+
+ if got := payload["token"]; got != "app-token-abc" {
+ t.Errorf("token = %v, want app-token-abc", got)
+ }
+ if got := payload["user"]; got != "user-key-xyz" {
+ t.Errorf("user = %v, want user-key-xyz", got)
+ }
+ if got := payload["priority"]; got != 0 {
+ t.Errorf("priority = %v, want 0", got)
+ }
+
+ title, ok := payload["title"].(string)
+ if !ok {
+ t.Fatalf("title is not a string: %T", payload["title"])
+ }
+ if !strings.Contains(title, data.Hostname) {
+ t.Errorf("title %q does not contain hostname %q", title, data.Hostname)
+ }
+ if !strings.Contains(title, GetStatusEmoji(data.Status)) {
+ t.Errorf("title %q does not contain status emoji", title)
+ }
+
+ message, ok := payload["message"].(string)
+ if !ok {
+ t.Fatalf("message is not a string: %T", payload["message"])
+ }
+ for _, want := range []string{"Status:", "Duration:", "Size:", "Errors:", "Warnings:"} {
+ if !strings.Contains(message, want) {
+ t.Errorf("message missing %q; got %q", want, message)
+ }
+ }
+}
+
+func TestBuildPushoverPayload_MissingToken(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+ ep := pushoverTestEndpoint(0)
+ ep.Auth.Token = ""
+
+ _, err := buildPushoverPayload(ep, data, logger)
+ if err == nil {
+ t.Fatal("expected error for missing token, got nil")
+ }
+ if !strings.Contains(err.Error(), "AUTH_TOKEN") {
+ t.Errorf("error %q does not mention AUTH_TOKEN", err.Error())
+ }
+}
+
+func TestBuildPushoverPayload_MissingUser(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+ ep := pushoverTestEndpoint(0)
+ ep.Auth.User = ""
+
+ _, err := buildPushoverPayload(ep, data, logger)
+ if err == nil {
+ t.Fatal("expected error for missing user, got nil")
+ }
+ if !strings.Contains(err.Error(), "AUTH_USER") {
+ t.Errorf("error %q does not mention AUTH_USER", err.Error())
+ }
+}
+
+func TestBuildPushoverPayload_TitleTruncated(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+ data.Hostname = strings.Repeat("h", 300)
+
+ payload, err := buildPushoverPayload(pushoverTestEndpoint(0), data, logger)
+ if err != nil {
+ t.Fatalf("buildPushoverPayload() error: %v", err)
+ }
+
+ title := payload["title"].(string)
+ if got := len([]rune(title)); got > 250 {
+ t.Errorf("title rune length = %d, want <= 250", got)
+ }
+ if !strings.HasSuffix(title, "…") {
+ t.Errorf("truncated title should end with ellipsis; got %q", title)
+ }
+}
+
+func TestBuildPushoverPayload_MessageTruncated(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+ data.StatusMessage = strings.Repeat("x", 1100)
+
+ payload, err := buildPushoverPayload(pushoverTestEndpoint(0), data, logger)
+ if err != nil {
+ t.Fatalf("buildPushoverPayload() error: %v", err)
+ }
+
+ message := payload["message"].(string)
+ if got := len([]rune(message)); got > 1024 {
+ t.Errorf("message rune length = %d, want <= 1024", got)
+ }
+ if !strings.HasSuffix(message, "…") {
+ t.Errorf("truncated message should end with ellipsis; got %q", message)
+ }
+}
+
+func TestBuildPushoverPayload_PriorityPassthrough(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ data := createTestNotificationData()
+
+ for _, p := range []int{-2, -1, 0, 1} {
+ payload, err := buildPushoverPayload(pushoverTestEndpoint(p), data, logger)
+ if err != nil {
+ t.Fatalf("priority=%d: buildPushoverPayload() error: %v", p, err)
+ }
+ if got := payload["priority"]; got != p {
+ t.Errorf("priority=%d: got %v", p, got)
+ }
+ }
+}
+
+func TestNewWebhookNotifier_PushoverPriority(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+
+ tests := []struct {
+ name string
+ priority int
+ expectError bool
+ }{
+ {"min valid", -2, false},
+ {"zero", 0, false},
+ {"max valid", 1, false},
+ {"too low", -3, true},
+ {"emergency rejected", 2, true},
+ {"too high", 3, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &config.WebhookConfig{
+ Enabled: true,
+ DefaultFormat: "pushover",
+ Timeout: 30,
+ Endpoints: []config.WebhookEndpoint{pushoverTestEndpoint(tt.priority)},
+ }
+ _, err := NewWebhookNotifier(cfg, logger)
+ if tt.expectError {
+ if err == nil {
+ t.Fatalf("priority=%d: expected error, got nil", tt.priority)
+ }
+ if !strings.Contains(err.Error(), "PRIORITY") {
+ t.Errorf("error %q does not mention PRIORITY", err.Error())
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("priority=%d: unexpected error: %v", tt.priority, err)
+ }
+ })
+ }
+}
+
+func TestNewWebhookNotifier_PushoverPriority_UsesDefaultFormat(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+
+ ep := pushoverTestEndpoint(2)
+ ep.Format = ""
+
+ cfg := &config.WebhookConfig{
+ Enabled: true,
+ DefaultFormat: "pushover",
+ Timeout: 30,
+ Endpoints: []config.WebhookEndpoint{ep},
+ }
+
+ _, err := NewWebhookNotifier(cfg, logger)
+ if err == nil {
+ t.Fatal("expected error for invalid pushover priority resolved from default format, got nil")
+ }
+ if !strings.Contains(err.Error(), "PRIORITY") {
+ t.Fatalf("error %q does not mention PRIORITY", err.Error())
+ }
+}
+
+func TestNewWebhookNotifier_PushoverMethod(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+
+ tests := []struct {
+ name string
+ method string
+ format string
+ defaultFormat string
+ expectError bool
+ }{
+ {name: "explicit post", method: "POST", format: "pushover", expectError: false},
+ {name: "implicit post", method: "", format: "pushover", expectError: false},
+ {name: "default format post", method: "", format: "", defaultFormat: "pushover", expectError: false},
+ {name: "get rejected", method: "GET", format: "pushover", expectError: true},
+ {name: "head rejected", method: "HEAD", format: "pushover", expectError: true},
+ {name: "put rejected", method: "PUT", format: "pushover", expectError: true},
+ {name: "default format get rejected", method: "GET", format: "", defaultFormat: "pushover", expectError: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ep := pushoverTestEndpoint(0)
+ ep.Method = tt.method
+ ep.Format = tt.format
+
+ cfg := &config.WebhookConfig{
+ Enabled: true,
+ DefaultFormat: tt.defaultFormat,
+ Timeout: 30,
+ Endpoints: []config.WebhookEndpoint{ep},
+ }
+
+ _, err := NewWebhookNotifier(cfg, logger)
+ if tt.expectError {
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), "METHOD must be POST") {
+ t.Fatalf("error %q does not mention POST method requirement", err.Error())
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+ }
+}
+
+func TestNewWebhookNotifier_PushoverAuthRequired(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+
+ tests := []struct {
+ name string
+ token string
+ user string
+ missing string
+ }{
+ {name: "missing token", token: "", user: "user-key-xyz", missing: "missing token"},
+ {name: "missing user", token: "app-token-abc", user: "", missing: "missing user"},
+ {name: "missing both", token: "", user: "", missing: "missing token/user"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ep := pushoverTestEndpoint(0)
+ ep.Auth.Token = tt.token
+ ep.Auth.User = tt.user
+
+ cfg := &config.WebhookConfig{
+ Enabled: true,
+ Timeout: 30,
+ Endpoints: []config.WebhookEndpoint{ep},
+ }
+
+ _, err := NewWebhookNotifier(cfg, logger)
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), "Pushover requires Auth.Token and Auth.User") {
+ t.Fatalf("error %q does not mention Pushover auth requirement", err.Error())
+ }
+ if !strings.Contains(err.Error(), tt.missing) {
+ t.Fatalf("error %q does not mention %q", err.Error(), tt.missing)
+ }
+ })
+ }
+}
diff --git a/internal/orchestrator/additional_helpers_test.go b/internal/orchestrator/additional_helpers_test.go
index 94844c43..39286f81 100644
--- a/internal/orchestrator/additional_helpers_test.go
+++ b/internal/orchestrator/additional_helpers_test.go
@@ -828,8 +828,8 @@ storage: backup
if blocks[0].ID != "local" || blocks[1].ID != "backup" {
t.Fatalf("unexpected IDs: %+v", blocks)
}
- if len(blocks[0].data) == 0 || len(blocks[1].data) == 0 {
- t.Fatalf("expected data in blocks")
+ if len(blocks[0].entries) == 0 || len(blocks[1].entries) == 0 {
+ t.Fatalf("expected entries in blocks")
}
// Empty file -> zero blocks
@@ -913,7 +913,12 @@ func TestExtractArchiveNativeSymlinkAndHardlink(t *testing.T) {
}
dest := filepath.Join(tmpDir, "dest")
- if err := extractArchiveNative(context.Background(), tarPath, dest, logger, nil, RestoreModeFull, nil, "", nil); err != nil {
+ if err := extractArchiveNative(context.Background(), restoreArchiveOptions{
+ archivePath: tarPath,
+ destRoot: dest,
+ logger: logger,
+ mode: RestoreModeFull,
+ }); err != nil {
t.Fatalf("extractArchiveNative error: %v", err)
}
@@ -1292,7 +1297,12 @@ func TestExtractArchiveNativeBlocksTraversal(t *testing.T) {
_ = f.Close()
dest := filepath.Join(tmpDir, "dest")
- if err := extractArchiveNative(context.Background(), tarPath, dest, logger, nil, RestoreModeFull, nil, "", nil); err != nil {
+ if err := extractArchiveNative(context.Background(), restoreArchiveOptions{
+ archivePath: tarPath,
+ destRoot: dest,
+ logger: logger,
+ mode: RestoreModeFull,
+ }); err != nil {
t.Fatalf("extractArchiveNative error: %v", err)
}
if _, err := os.Stat(filepath.Join(dest, "../etc/passwd")); err == nil {
diff --git a/internal/orchestrator/backup_run_helpers.go b/internal/orchestrator/backup_run_helpers.go
new file mode 100644
index 00000000..32103742
--- /dev/null
+++ b/internal/orchestrator/backup_run_helpers.go
@@ -0,0 +1,321 @@
+// Package orchestrator coordinates backup, restore, decrypt, and notification workflows.
+package orchestrator
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "path/filepath"
+
+ "filippo.io/age"
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/metrics"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+func (o *Orchestrator) shouldExportBackupMetrics(stats *BackupStats) bool {
+ return stats != nil && o.cfg != nil && o.cfg.MetricsEnabled && !o.dryRun
+}
+
+func (o *Orchestrator) ensureBackupStatsTiming(stats *BackupStats) {
+ if stats.EndTime.IsZero() {
+ stats.EndTime = o.now()
+ }
+ if stats.Duration == 0 && !stats.StartTime.IsZero() {
+ stats.Duration = stats.EndTime.Sub(stats.StartTime)
+ }
+}
+
+func backupMetricsExitCode(stats *BackupStats, runErr error) int {
+ if runErr == nil {
+ if stats.ExitCode == 0 {
+ return types.ExitSuccess.Int()
+ }
+ return stats.ExitCode
+ }
+
+ var backupErr *BackupError
+ if errors.As(runErr, &backupErr) {
+ return backupErr.Code.Int()
+ }
+ return types.ExitGenericError.Int()
+}
+
+func (o *Orchestrator) exportPrometheusBackupMetrics(stats *BackupStats) {
+ m := stats.toPrometheusMetrics()
+ if m == nil {
+ return
+ }
+
+ exporter := metrics.NewPrometheusExporter(o.cfg.MetricsPath, o.logger)
+ if err := exporter.Export(m); err != nil {
+ o.logger.Warning("Failed to export Prometheus metrics: %v", err)
+ }
+}
+
+func (o *Orchestrator) parseFailedBackupLogCounts(stats *BackupStats) {
+ if stats.LogFilePath == "" {
+ o.logger.Debug("No log file path specified, error/warning counts will be 0 (failure path)")
+ return
+ }
+
+ o.logger.Debug("Parsing log file for error/warning counts after failure: %s", stats.LogFilePath)
+ _, errorCount, warningCount := ParseLogCounts(stats.LogFilePath, 0)
+ stats.ErrorCount = errorCount
+ stats.WarningCount = warningCount
+ if errorCount > 0 || warningCount > 0 {
+ o.logger.Debug("Found %d errors and %d warnings in log file (failure path)", errorCount, warningCount)
+ }
+}
+
+func backupFailureExitCode(runErr error) int {
+ var backupErr *BackupError
+ if errors.As(runErr, &backupErr) {
+ return backupErr.Code.Int()
+ }
+ return types.ExitBackupError.Int()
+}
+
+func (o *Orchestrator) buildBackupCollectorConfig() *backup.CollectorConfig {
+ collectorConfig := backup.GetDefaultCollectorConfig()
+ collectorConfig.ExcludePatterns = append([]string(nil), o.excludePatterns...)
+ if o.cfg == nil {
+ return collectorConfig
+ }
+
+ applyCollectorOverrides(collectorConfig, o.cfg)
+ if len(o.cfg.BackupBlacklist) > 0 {
+ collectorConfig.ExcludePatterns = append(collectorConfig.ExcludePatterns, o.cfg.BackupBlacklist...)
+ }
+ return collectorConfig
+}
+
+func (o *Orchestrator) runBackupCollector(run *backupRunContext, workspace *backupWorkspace, collectorConfig *backup.CollectorConfig) (*backup.Collector, error) {
+ collector := backup.NewCollectorWithDeps(o.logger, collectorConfig, workspace.tempDir, run.proxmoxType, o.dryRun, o.collectorDeps())
+ o.logger.Debug("Starting collector run (type=%s)", run.proxmoxType)
+ if err := collector.CollectAll(run.ctx); err != nil {
+ return nil, err
+ }
+ return collector, nil
+}
+
+func (o *Orchestrator) applyBackupCollectionStats(stats *BackupStats, collStats *backup.CollectionStats, collector *backup.Collector) {
+ stats.FilesCollected = int(collStats.FilesProcessed)
+ stats.FilesFailed = int(collStats.FilesFailed)
+ stats.FilesNotFound = int(collStats.FilesNotFound)
+ stats.DirsCreated = int(collStats.DirsCreated)
+ stats.BytesCollected = collStats.BytesCollected
+ stats.FilesIncluded = int(collStats.FilesProcessed)
+ stats.FilesMissing = int(collStats.FilesNotFound)
+ stats.UncompressedSize = collStats.BytesCollected
+ if stats.ProxmoxType.SupportsPVE() {
+ stats.ClusterMode = standaloneClusterMode(collector)
+ }
+}
+
+func standaloneClusterMode(collector *backup.Collector) string {
+ if collector.IsClusteredPVE() {
+ return "cluster"
+ }
+ return "standalone"
+}
+
+func (o *Orchestrator) writeBackupCollectionMetadata(tempDir, hostname string, stats *BackupStats, collector *backup.Collector) {
+ if err := o.writeBackupMetadata(tempDir, stats); err != nil {
+ o.logger.Debug("Failed to write backup metadata: %v", err)
+ }
+ if err := collector.WriteManifest(hostname); err != nil {
+ o.logger.Debug("Failed to write backup manifest: %v", err)
+ }
+}
+
+func (o *Orchestrator) logBackupCollectionSummary(collStats *backup.CollectionStats) {
+ o.logger.Info("Collection completed: %d files (%s), %d failed, %d dirs created",
+ collStats.FilesProcessed,
+ backup.FormatBytes(collStats.BytesCollected),
+ collStats.FilesFailed,
+ collStats.DirsCreated)
+}
+
+func (o *Orchestrator) applyBackupOptimizations(ctx context.Context, tempDir string) error {
+ if !o.optimizationCfg.Enabled() {
+ o.logger.Debug("Skipping optimization step (all features disabled)")
+ return nil
+ }
+
+ fmt.Println()
+ o.logger.Step("Backup optimizations on collected data")
+ if err := backup.ApplyOptimizations(ctx, o.logger, tempDir, o.optimizationCfg); err != nil {
+ o.logger.Warning("Backup optimizations completed with warnings: %v", err)
+ }
+ return nil
+}
+
+func estimatedBackupSizeGB(bytesCollected int64) float64 {
+ estimatedSizeGB := float64(bytesCollected) / (1024.0 * 1024.0 * 1024.0)
+ if estimatedSizeGB < 0.001 {
+ return 0.001
+ }
+ return estimatedSizeGB
+}
+
+func backupDiskValidationError(message string, diskErr error) error {
+ errMsg := message
+ if errMsg == "" && diskErr != nil {
+ errMsg = diskErr.Error()
+ }
+ if errMsg == "" {
+ errMsg = "insufficient disk space"
+ }
+ if diskErr == nil {
+ diskErr = errors.New(errMsg)
+ }
+ return &BackupError{
+ Phase: "disk",
+ Err: fmt.Errorf("disk space validation failed: %w", diskErr),
+ Code: types.ExitDiskSpaceError,
+ }
+}
+
+func (o *Orchestrator) buildBackupArchiverConfig(run *backupRunContext, ageRecipients []age.Recipient) *backup.ArchiverConfig {
+ return BuildArchiverConfig(
+ o.compressionType,
+ run.normalizedLevel,
+ o.compressionThreads,
+ o.compressionMode,
+ o.dryRun,
+ o.cfg != nil && o.cfg.EncryptArchive,
+ ageRecipients,
+ run.collectorConfig.ExcludePatterns,
+ )
+}
+
+func (o *Orchestrator) applyBackupArchiverStats(stats *BackupStats, archiver *backup.Archiver) {
+ stats.Compression = archiver.ResolveCompression()
+ stats.CompressionLevel = archiver.CompressionLevel()
+ stats.CompressionMode = archiver.CompressionMode()
+ stats.CompressionThreads = archiver.CompressionThreads()
+}
+
+func (o *Orchestrator) backupArchivePath(run *backupRunContext, archiver *backup.Archiver) string {
+ archiveBasename := fmt.Sprintf("%s-backup-%s", run.hostname, run.timestamp)
+ return filepath.Join(o.backupPath, archiveBasename+archiver.GetArchiveExtension())
+}
+
+func (o *Orchestrator) logResolvedBackupCompression(stats *BackupStats) {
+ if stats.RequestedCompression != stats.Compression {
+ o.logger.Info("Using %s compression (requested %s)", stats.Compression, stats.RequestedCompression)
+ }
+}
+
+func createBackupArchiveFile(ctx context.Context, archiver *backup.Archiver, tempDir, archivePath string) error {
+ if err := archiver.CreateArchive(ctx, tempDir, archivePath); err != nil {
+ return backupArchiveCreationError(err)
+ }
+ return nil
+}
+
+func backupArchiveCreationError(err error) error {
+ phase := "archive"
+ code := types.ExitArchiveError
+ var compressionErr *backup.CompressionError
+ if errors.As(err, &compressionErr) {
+ phase = "compression"
+ code = types.ExitCompressionError
+ }
+ return &BackupError{Phase: phase, Err: err, Code: code}
+}
+
+func (o *Orchestrator) skipDryRunArtifactVerification(stats *BackupStats, artifacts *backupArtifacts) error {
+ fmt.Println()
+ o.logStep(4, "Verification skipped (dry run mode)")
+ o.logger.Info("[DRY RUN] Would create archive: %s", artifacts.archivePath)
+ stats.EndTime = o.now()
+ return nil
+}
+
+func (o *Orchestrator) recordArchiveSize(stats *BackupStats, artifacts *backupArtifacts) {
+ size, err := artifacts.archiver.GetArchiveSize(artifacts.archivePath)
+ if err != nil {
+ o.logger.Warning("Failed to get archive size: %v", err)
+ return
+ }
+
+ stats.ArchiveSize = size
+ stats.CompressedSize = size
+ stats.updateCompressionMetrics()
+ o.logger.Debug("Archive created: %s (%s)", artifacts.archivePath, backup.FormatBytes(size))
+}
+
+func (o *Orchestrator) generateArchiveChecksum(ctx context.Context, archivePath string) (string, error) {
+ checksum, err := backup.GenerateChecksum(ctx, o.logger, archivePath)
+ if err != nil {
+ return "", &BackupError{
+ Phase: "verification",
+ Err: fmt.Errorf("checksum generation failed: %w", err),
+ Code: types.ExitVerificationError,
+ }
+ }
+ return checksum, nil
+}
+
+func (o *Orchestrator) writeArchiveChecksum(workspace *backupWorkspace, artifacts *backupArtifacts, checksum string) error {
+ checksumContent := fmt.Sprintf("%s %s\n", checksum, filepath.Base(artifacts.archivePath))
+ if err := workspace.fs.WriteFile(artifacts.checksumPath, []byte(checksumContent), 0o640); err != nil {
+ return fmt.Errorf("write checksum file %s: %w", artifacts.checksumPath, err)
+ }
+ o.logger.Debug("Checksum file written to %s", artifacts.checksumPath)
+ return nil
+}
+
+func (o *Orchestrator) writeArchiveManifest(run *backupRunContext, artifacts *backupArtifacts, checksum string) error {
+ manifestPath := artifacts.archivePath + ".manifest.json"
+ manifest := o.newArchiveManifest(run.stats, artifacts.archivePath, checksum)
+ if err := backup.CreateManifest(run.ctx, o.logger, manifest, manifestPath); err != nil {
+ return &BackupError{
+ Phase: "verification",
+ Err: fmt.Errorf("manifest creation failed: %w", err),
+ Code: types.ExitVerificationError,
+ }
+ }
+ run.stats.ManifestPath = manifestPath
+ artifacts.manifestPath = manifestPath
+ return nil
+}
+
+func (o *Orchestrator) newArchiveManifest(stats *BackupStats, archivePath, checksum string) *backup.Manifest {
+ return &backup.Manifest{
+ ArchivePath: archivePath,
+ ArchiveSize: stats.ArchiveSize,
+ SHA256: checksum,
+ CreatedAt: stats.Timestamp,
+ CompressionType: string(stats.Compression),
+ CompressionLevel: stats.CompressionLevel,
+ CompressionMode: stats.CompressionMode,
+ ProxmoxType: string(stats.ProxmoxType),
+ ProxmoxTargets: append([]string(nil), stats.ProxmoxTargets...),
+ ProxmoxVersion: stats.ProxmoxVersion,
+ PVEVersion: stats.PVEVersion,
+ PBSVersion: stats.PBSVersion,
+ Hostname: stats.Hostname,
+ ScriptVersion: stats.ScriptVersion,
+ EncryptionMode: o.archiveEncryptionMode(),
+ ClusterMode: stats.ClusterMode,
+ }
+}
+
+func (o *Orchestrator) archiveEncryptionMode() string {
+ if o.cfg != nil && o.cfg.EncryptArchive {
+ return "age"
+ }
+ return "none"
+}
+
+func (o *Orchestrator) writeLegacyMetadataAlias(workspace *backupWorkspace, artifacts *backupArtifacts) {
+ metadataAlias := artifacts.archivePath + ".metadata"
+ if err := copyFile(workspace.fs, artifacts.manifestPath, metadataAlias); err != nil {
+ o.logger.Warning("Failed to write legacy metadata file %s: %v", metadataAlias, err)
+ } else {
+ o.logger.Debug("Legacy metadata file written to %s", metadataAlias)
+ }
+}
diff --git a/internal/orchestrator/backup_run_helpers_additional_test.go b/internal/orchestrator/backup_run_helpers_additional_test.go
new file mode 100644
index 00000000..49c33b93
--- /dev/null
+++ b/internal/orchestrator/backup_run_helpers_additional_test.go
@@ -0,0 +1,130 @@
+package orchestrator
+
+import (
+ "errors"
+ "math"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/config"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+func TestEstimatedBackupSizeGBMinimumAndScaling(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes int64
+ want float64
+ }{
+ {name: "zero uses minimum", bytes: 0, want: 0.001},
+ {name: "below minimum uses minimum", bytes: 512, want: 0.001},
+ {name: "one gibibyte", bytes: 1024 * 1024 * 1024, want: 1},
+ {name: "two and a half gibibytes", bytes: 5 * 1024 * 1024 * 1024 / 2, want: 2.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := estimatedBackupSizeGB(tt.bytes); math.Abs(got-tt.want) > 0.0000001 {
+ t.Fatalf("estimatedBackupSizeGB(%d)=%f want %f", tt.bytes, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestBackupDiskValidationErrorWrapsDiskError(t *testing.T) {
+ diskErr := errors.New("need 3.5 GB free")
+ err := backupDiskValidationError("", diskErr)
+
+ var backupErr *BackupError
+ if !errors.As(err, &backupErr) {
+ t.Fatalf("expected BackupError, got %T", err)
+ }
+ if backupErr.Phase != "disk" || backupErr.Code != types.ExitDiskSpaceError {
+ t.Fatalf("unexpected BackupError fields: phase=%q code=%v", backupErr.Phase, backupErr.Code)
+ }
+ if !errors.Is(err, diskErr) {
+ t.Fatalf("expected disk error to be wrapped, got %v", err)
+ }
+}
+
+func TestBackupDiskValidationErrorUsesDefaultMessage(t *testing.T) {
+ err := backupDiskValidationError("", nil)
+
+ var backupErr *BackupError
+ if !errors.As(err, &backupErr) {
+ t.Fatalf("expected BackupError, got %T", err)
+ }
+ if !strings.Contains(err.Error(), "insufficient disk space") {
+ t.Fatalf("expected default disk space message, got %q", err.Error())
+ }
+}
+
+func TestBackupMetricsExitCode(t *testing.T) {
+ if got := backupMetricsExitCode(&BackupStats{}, nil); got != types.ExitSuccess.Int() {
+ t.Fatalf("success exit code=%d want %d", got, types.ExitSuccess.Int())
+ }
+ if got := backupMetricsExitCode(&BackupStats{ExitCode: 77}, nil); got != 77 {
+ t.Fatalf("stats exit code=%d want 77", got)
+ }
+
+ runErr := &BackupError{Phase: "disk", Err: errors.New("full"), Code: types.ExitDiskSpaceError}
+ if got := backupMetricsExitCode(&BackupStats{}, runErr); got != types.ExitDiskSpaceError.Int() {
+ t.Fatalf("backup error exit code=%d want %d", got, types.ExitDiskSpaceError.Int())
+ }
+ if got := backupMetricsExitCode(&BackupStats{}, errors.New("boom")); got != types.ExitGenericError.Int() {
+ t.Fatalf("generic error exit code=%d want %d", got, types.ExitGenericError.Int())
+ }
+}
+
+func TestEnsureBackupStatsTimingFillsEndAndDuration(t *testing.T) {
+ now := time.Date(2026, 5, 5, 10, 30, 0, 0, time.UTC)
+ orch := New(newTestLogger(), false)
+ orch.clock = &FakeTime{Current: now}
+
+ stats := &BackupStats{StartTime: now.Add(-90 * time.Second)}
+ orch.ensureBackupStatsTiming(stats)
+
+ if !stats.EndTime.Equal(now) {
+ t.Fatalf("EndTime=%v want %v", stats.EndTime, now)
+ }
+ if stats.Duration != 90*time.Second {
+ t.Fatalf("Duration=%v want %v", stats.Duration, 90*time.Second)
+ }
+}
+
+func TestBuildBackupCollectorConfigMergesRuntimeExcludesAndBlacklist(t *testing.T) {
+ orch := New(newTestLogger(), false)
+ orch.SetBackupConfig("/backup", "/logs", types.CompressionZstd, 3, 2, "fast", []string{"runtime/**"})
+ orch.SetConfig(&config.Config{
+ BackupBlacklist: []string{"/secret", "/tmp/cache"},
+ CustomBackupPaths: []string{"/srv/app"},
+ BaseDir: "/opt/proxsave",
+ ConfigPath: "/etc/proxsave/backup.env",
+ })
+
+ cfg := orch.buildBackupCollectorConfig()
+ for _, want := range []string{"runtime/**", "/secret", "/tmp/cache"} {
+ if !containsString(cfg.ExcludePatterns, want) {
+ t.Fatalf("ExcludePatterns missing %q: %#v", want, cfg.ExcludePatterns)
+ }
+ }
+ if len(cfg.BackupBlacklist) != 2 || cfg.BackupBlacklist[0] != "/secret" || cfg.BackupBlacklist[1] != "/tmp/cache" {
+ t.Fatalf("BackupBlacklist not copied: %#v", cfg.BackupBlacklist)
+ }
+ if len(cfg.CustomBackupPaths) != 1 || cfg.CustomBackupPaths[0] != "/srv/app" {
+ t.Fatalf("CustomBackupPaths not copied: %#v", cfg.CustomBackupPaths)
+ }
+ if cfg.ScriptRepositoryPath != "/opt/proxsave" || cfg.ConfigFilePath != "/etc/proxsave/backup.env" {
+ t.Fatalf("paths not copied: script=%q config=%q", cfg.ScriptRepositoryPath, cfg.ConfigFilePath)
+ }
+}
+
+func containsString(values []string, want string) bool {
+ for _, value := range values {
+ if value == want {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/orchestrator/backup_run_phases.go b/internal/orchestrator/backup_run_phases.go
new file mode 100644
index 00000000..0e681583
--- /dev/null
+++ b/internal/orchestrator/backup_run_phases.go
@@ -0,0 +1,384 @@
+// Package orchestrator coordinates backup, restore, decrypt, and notification workflows.
+package orchestrator
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/environment"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+type backupRunContext struct {
+ ctx context.Context
+ envInfo *environment.EnvironmentInfo
+ hostname string
+ proxmoxType types.ProxmoxType
+ startTime time.Time
+ timestamp string
+ normalizedLevel int
+ collectorConfig *backup.CollectorConfig
+ stats *BackupStats
+}
+
+type backupWorkspace struct {
+ registry *TempDirRegistry
+ fs FS
+ tempRoot string
+ tempDir string
+}
+
+type backupArtifacts struct {
+ archiver *backup.Archiver
+ archivePath string
+ checksumPath string
+ manifestPath string
+ bundlePath string
+}
+
+func (o *Orchestrator) newBackupRunContext(ctx context.Context, envInfo *environment.EnvironmentInfo, hostname string) *backupRunContext {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if envInfo == nil {
+ envInfo = o.envInfo
+ } else {
+ o.SetEnvironmentInfo(envInfo)
+ }
+
+ pType := types.ProxmoxUnknown
+ if envInfo != nil {
+ pType = envInfo.Type
+ }
+
+ startTime := o.startTime
+ if startTime.IsZero() {
+ startTime = o.now()
+ o.startTime = startTime
+ }
+
+ return &backupRunContext{
+ ctx: ctx,
+ envInfo: envInfo,
+ hostname: hostname,
+ proxmoxType: pType,
+ startTime: startTime,
+ timestamp: startTime.Format("20060102-150405"),
+ normalizedLevel: normalizeCompressionLevel(o.compressionType, o.compressionLevel),
+ }
+}
+
+func (o *Orchestrator) initBackupRun(run *backupRunContext) *BackupStats {
+ fmt.Println()
+ o.logStep(1, "Initializing backup statistics and temporary workspace")
+ run.stats = InitializeBackupStats(
+ run.hostname,
+ run.envInfo,
+ o.version,
+ run.startTime,
+ o.cfg,
+ o.compressionType,
+ o.compressionMode,
+ run.normalizedLevel,
+ o.compressionThreads,
+ o.backupPath,
+ o.serverID,
+ o.serverMAC,
+ )
+ if logFile := o.logger.GetLogFilePath(); logFile != "" {
+ run.stats.LogFilePath = logFile
+ }
+ if o.versionUpdateAvailable || o.updateCurrentVersion != "" || o.updateLatestVersion != "" {
+ run.stats.NewVersionAvailable = o.versionUpdateAvailable
+ run.stats.CurrentVersion = o.updateCurrentVersion
+ run.stats.LatestVersion = o.updateLatestVersion
+ }
+ return run.stats
+}
+
+func (o *Orchestrator) exportBackupMetrics(run *backupRunContext, runErr error) {
+ stats := run.stats
+ if !o.shouldExportBackupMetrics(stats) {
+ return
+ }
+
+ o.ensureBackupStatsTiming(stats)
+ stats.ExitCode = backupMetricsExitCode(stats, runErr)
+ o.exportPrometheusBackupMetrics(stats)
+}
+
+func (o *Orchestrator) finalizeFailedBackupStats(run *backupRunContext, runErr error) {
+ stats := run.stats
+ if runErr == nil || stats == nil {
+ return
+ }
+
+ o.ensureBackupStatsTiming(stats)
+ o.parseFailedBackupLogCounts(stats)
+ stats.ExitCode = backupFailureExitCode(runErr)
+}
+
+func (o *Orchestrator) prepareBackupWorkspace(run *backupRunContext, workspace *backupWorkspace) error {
+ o.logger.Debug("Creating temporary directory for collection output")
+ workspace.tempRoot = filepath.Join("/tmp", "proxsave")
+ if err := workspace.fs.MkdirAll(workspace.tempRoot, 0o755); err != nil {
+ return fmt.Errorf("Temp directory creation failed - path: %s: %w", workspace.tempRoot, err)
+ }
+
+ tempDir, err := workspace.fs.MkdirTemp(workspace.tempRoot, fmt.Sprintf("proxsave-%s-%s-", run.hostname, run.timestamp))
+ if err != nil {
+ return fmt.Errorf("failed to create temporary directory: %w", err)
+ }
+ workspace.tempDir = tempDir
+
+ if o.dryRun {
+ o.logger.Info("[DRY RUN] Temporary directory would be: %s", workspace.tempDir)
+ } else {
+ o.logger.Debug("Using temporary directory: %s", workspace.tempDir)
+ }
+ return nil
+}
+
+func (o *Orchestrator) cleanupBackupWorkspace(workspace *backupWorkspace) {
+ if workspace.registry == nil {
+ if cleanupErr := workspace.fs.RemoveAll(workspace.tempDir); cleanupErr != nil {
+ o.logger.Warning("Failed to remove temp directory %s: %v", workspace.tempDir, cleanupErr)
+ }
+ return
+ }
+ o.logger.Debug("Temporary workspace preserved at %s (will be removed at the next startup)", workspace.tempDir)
+}
+
+func (o *Orchestrator) markBackupWorkspace(workspace *backupWorkspace) error {
+ markerPath := filepath.Join(workspace.tempDir, ".proxsave-marker")
+ markerContent := fmt.Sprintf(
+ "Created by PID %d on %s UTC\n",
+ os.Getpid(),
+ o.now().UTC().Format("2006-01-02 15:04:05"),
+ )
+ return workspace.fs.WriteFile(markerPath, []byte(markerContent), 0600)
+}
+
+func (o *Orchestrator) registerBackupWorkspace(workspace *backupWorkspace) {
+ if workspace.registry == nil {
+ return
+ }
+ if err := workspace.registry.Register(workspace.tempDir); err != nil {
+ o.logger.Debug("Failed to register temp directory %s: %v", workspace.tempDir, err)
+ }
+}
+
+func (o *Orchestrator) collectBackupData(run *backupRunContext, workspace *backupWorkspace) error {
+ fmt.Println()
+ o.logStep(2, "Collection of configuration files and optimizations")
+ o.logger.Info("Collecting configuration files...")
+ o.logger.Debug("Collector dry-run=%v excludePatterns=%d", o.dryRun, len(o.excludePatterns))
+
+ collectorConfig := o.buildBackupCollectorConfig()
+ run.collectorConfig = collectorConfig
+
+ if err := collectorConfig.Validate(); err != nil {
+ return &BackupError{Phase: "config", Err: err, Code: types.ExitConfigError}
+ }
+
+ collector, err := o.runBackupCollector(run, workspace, collectorConfig)
+ if err != nil {
+ return &BackupError{Phase: "collection", Err: err, Code: types.ExitCollectionError}
+ }
+
+ collStats := collector.GetStats()
+ o.applyBackupCollectionStats(run.stats, collStats, collector)
+ o.writeBackupCollectionMetadata(workspace.tempDir, run.hostname, run.stats, collector)
+ o.logBackupCollectionSummary(collStats)
+
+ if err := o.validateCollectedBackupSize(run.stats); err != nil {
+ return err
+ }
+
+ return o.applyBackupOptimizations(run.ctx, workspace.tempDir)
+}
+
+func (o *Orchestrator) validateCollectedBackupSize(stats *BackupStats) error {
+ if o.checker == nil || stats.BytesCollected <= 0 {
+ return nil
+ }
+
+ o.logger.Debug("Running disk-space validation for estimated data size")
+ result := o.checker.CheckDiskSpaceForEstimate(estimatedBackupSizeGB(stats.BytesCollected))
+ if result.Passed {
+ o.logger.Debug("Disk check passed: %s", result.Message)
+ return nil
+ }
+
+ return backupDiskValidationError(result.Message, result.Error)
+}
+
+func (o *Orchestrator) createBackupArchive(run *backupRunContext, workspace *backupWorkspace) (*backupArtifacts, error) {
+ fmt.Println()
+ o.logStep(3, "Creation of compressed archive")
+ o.logger.Info("Creating compressed archive...")
+ o.logger.Debug("Archiver configuration: type=%s level=%d mode=%s threads=%d",
+ o.compressionType, run.normalizedLevel, o.compressionMode, o.compressionThreads)
+
+ ageRecipients, err := o.prepareAgeRecipients(run.ctx)
+ if err != nil {
+ return nil, &BackupError{Phase: "encryption", Err: err, Code: types.ExitEncryptionError}
+ }
+
+ archiverConfig := o.buildBackupArchiverConfig(run, ageRecipients)
+ if err := archiverConfig.Validate(); err != nil {
+ return nil, &BackupError{Phase: "config", Err: err, Code: types.ExitConfigError}
+ }
+
+ archiver := backup.NewArchiver(o.logger, archiverConfig)
+ o.applyBackupArchiverStats(run.stats, archiver)
+ archivePath := o.backupArchivePath(run, archiver)
+ o.logResolvedBackupCompression(run.stats)
+
+ if err := createBackupArchiveFile(run.ctx, archiver, workspace.tempDir, archivePath); err != nil {
+ return nil, err
+ }
+
+ run.stats.ArchivePath = archivePath
+ return &backupArtifacts{
+ archiver: archiver,
+ archivePath: archivePath,
+ checksumPath: archivePath + ".sha256",
+ }, nil
+}
+
+func (o *Orchestrator) verifyAndWriteBackupArtifacts(run *backupRunContext, workspace *backupWorkspace, artifacts *backupArtifacts) error {
+ stats := run.stats
+ if o.dryRun {
+ return o.skipDryRunArtifactVerification(stats, artifacts)
+ }
+
+ fmt.Println()
+ o.logStep(4, "Verification of archive and metadata generation")
+ o.recordArchiveSize(stats, artifacts)
+
+ if err := artifacts.archiver.VerifyArchive(run.ctx, artifacts.archivePath); err != nil {
+ return &BackupError{Phase: "verification", Err: err, Code: types.ExitVerificationError}
+ }
+
+ checksum, err := o.generateArchiveChecksum(run.ctx, artifacts.archivePath)
+ if err != nil {
+ return err
+ }
+ stats.Checksum = checksum
+
+ if err := o.writeArchiveChecksum(workspace, artifacts, checksum); err != nil {
+ return &BackupError{
+ Phase: "verification",
+ Err: err,
+ Code: types.ExitVerificationError,
+ }
+ }
+ if err := o.writeArchiveManifest(run, artifacts, checksum); err != nil {
+ return err
+ }
+ o.writeLegacyMetadataAlias(workspace, artifacts)
+ return nil
+}
+
+func (o *Orchestrator) bundleBackupArtifacts(run *backupRunContext, workspace *backupWorkspace, artifacts *backupArtifacts) error {
+ if o.dryRun {
+ return nil
+ }
+
+ bundleEnabled := o.cfg != nil && o.cfg.BundleAssociatedFiles
+ if !bundleEnabled {
+ fmt.Println()
+ o.logger.Skip("Bundling disabled")
+ run.stats.EndTime = o.now()
+ o.logger.Info("✓ Archive created and verified")
+ return nil
+ }
+
+ fmt.Println()
+ o.logStep(5, "Bundling of archive, checksum and metadata")
+ o.logger.Debug("Bundling enabled: creating bundle from %s", filepath.Base(artifacts.archivePath))
+ bundlePath, err := o.createBundle(run.ctx, artifacts.archivePath)
+ if err != nil {
+ return &BackupError{
+ Phase: "archive",
+ Err: fmt.Errorf("bundle creation failed: %w", err),
+ Code: types.ExitArchiveError,
+ }
+ }
+
+ if err := o.removeAssociatedFiles(artifacts.archivePath); err != nil {
+ o.logger.Warning("Failed to remove raw files after bundling: %v", err)
+ } else {
+ o.logger.Debug("Removed raw tar/checksum/metadata after bundling")
+ }
+
+ stats := run.stats
+ if info, err := workspace.fs.Stat(bundlePath); err == nil {
+ stats.ArchiveSize = info.Size()
+ stats.CompressedSize = info.Size()
+ stats.updateCompressionMetrics()
+ }
+ stats.ArchivePath = bundlePath
+ stats.ManifestPath = ""
+ stats.BundleCreated = true
+ artifacts.bundlePath = bundlePath
+ artifacts.archivePath = bundlePath
+ o.logger.Debug("Bundle ready: %s", filepath.Base(bundlePath))
+
+ stats.EndTime = o.now()
+ o.logger.Info("✓ Archive created and verified")
+ return nil
+}
+
+func (o *Orchestrator) finalizeBackupStats(run *backupRunContext) {
+ stats := run.stats
+ stats.Duration = stats.EndTime.Sub(stats.StartTime)
+
+ if stats.LogFilePath != "" {
+ o.logger.Debug("Parsing log file for error/warning counts: %s", stats.LogFilePath)
+ _, errorCount, warningCount := ParseLogCounts(stats.LogFilePath, 0)
+ stats.ErrorCount = errorCount
+ stats.WarningCount = warningCount
+ if errorCount > 0 || warningCount > 0 {
+ o.logger.Debug("Found %d errors and %d warnings in log file", errorCount, warningCount)
+ }
+ } else {
+ o.logger.Debug("No log file path specified, error/warning counts will be 0")
+ }
+
+ switch {
+ case stats.ErrorCount > 0:
+ stats.ExitCode = types.ExitBackupError.Int()
+ case stats.WarningCount > 0:
+ stats.ExitCode = types.ExitGenericError.Int()
+ default:
+ stats.ExitCode = types.ExitSuccess.Int()
+ }
+ o.logger.Debug("Aggregated exit code based on log analysis: %d", stats.ExitCode)
+}
+
+func (o *Orchestrator) dispatchBackupArtifacts(run *backupRunContext) error {
+ if len(o.storageTargets) == 0 {
+ fmt.Println()
+ o.logStep(6, "No storage targets registered - skipping")
+ } else if o.dryRun {
+ fmt.Println()
+ o.logStep(6, "Storage dispatch skipped (dry run mode)")
+ } else {
+ fmt.Println()
+ o.logStep(6, "Dispatching archive to %d storage target(s)", len(o.storageTargets))
+ o.logGlobalRetentionPolicy()
+ }
+
+ if o.dryRun {
+ return nil
+ }
+
+ o.logger.Debug("Dispatching archive to %d storage targets", len(o.storageTargets))
+ return o.dispatchPostBackup(run.ctx, run.stats)
+}
diff --git a/internal/orchestrator/backup_run_phases_test.go b/internal/orchestrator/backup_run_phases_test.go
new file mode 100644
index 00000000..dcfc270c
--- /dev/null
+++ b/internal/orchestrator/backup_run_phases_test.go
@@ -0,0 +1,60 @@
+package orchestrator
+
+import (
+ "context"
+ "errors"
+ "strings"
+ "testing"
+
+ "github.com/tis24dev/proxsave/internal/config"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+func TestCreateBackupArchiveClassifiesAgeRecipientFailureAsEncryption(t *testing.T) {
+ orch := New(newTestLogger(), false)
+ orch.SetConfig(&config.Config{
+ EncryptArchive: true,
+ BaseDir: t.TempDir(),
+ })
+ orch.SetBackupConfig(t.TempDir(), t.TempDir(), types.CompressionNone, 0, 0, "standard", nil)
+
+ run := orch.newBackupRunContext(context.Background(), nil, "test-host")
+ _, err := orch.createBackupArchive(run, &backupWorkspace{tempDir: t.TempDir()})
+ if err == nil {
+ t.Fatal("expected createBackupArchive error")
+ }
+
+ var backupErr *BackupError
+ if !errors.As(err, &backupErr) {
+ t.Fatalf("expected BackupError, got %T: %v", err, err)
+ }
+ if backupErr.Phase != "encryption" {
+ t.Fatalf("Phase=%q; want encryption", backupErr.Phase)
+ }
+ if backupErr.Code != types.ExitEncryptionError {
+ t.Fatalf("Code=%v; want %v", backupErr.Code, types.ExitEncryptionError)
+ }
+}
+
+func TestWriteArchiveChecksumPropagatesWriteError(t *testing.T) {
+ orch := New(newTestLogger(), false)
+ checksumPath := "/backups/test.tar.sha256"
+ writeErr := errors.New("disk full")
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = fakeFS.Cleanup() })
+
+ err := orch.writeArchiveChecksum(
+ &backupWorkspace{fs: writeFileFailFS{FS: fakeFS, failPath: checksumPath, err: writeErr}},
+ &backupArtifacts{archivePath: "/backups/test.tar", checksumPath: checksumPath},
+ "abc123",
+ )
+ if err == nil {
+ t.Fatal("expected writeArchiveChecksum error")
+ }
+ if !errors.Is(err, writeErr) {
+ t.Fatalf("expected wrapped write error, got %v", err)
+ }
+ if !strings.Contains(err.Error(), checksumPath) {
+ t.Fatalf("expected checksum path in error, got %q", err.Error())
+ }
+}
diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go
index 86b5f2b6..05225a8b 100644
--- a/internal/orchestrator/backup_sources.go
+++ b/internal/orchestrator/backup_sources.go
@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
- "os/exec"
"path"
"path/filepath"
"sort"
@@ -17,6 +16,7 @@ import (
"github.com/tis24dev/proxsave/internal/backup"
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
)
// decryptPathOption describes a logical backup source (local, secondary, cloud)
@@ -117,7 +117,10 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
// Use rclone lsf to list files inside the backup directory
lsfCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
- cmd := exec.CommandContext(lsfCtx, "rclone", "lsf", fullPath)
+ cmd, err := safeexec.CommandContext(lsfCtx, "rclone", "lsf", fullPath)
+ if err != nil {
+ return nil, err
+ }
lsfStart := time.Now()
output, err := cmd.CombinedOutput()
if err != nil {
@@ -545,7 +548,10 @@ func inspectRcloneChecksumFile(ctx context.Context, remotePath string, logger *l
defer func() { done(err) }()
logging.DebugStep(logger, "inspect rclone checksum", "executing: rclone cat %s", remotePath)
- cmd := exec.CommandContext(ctx, "rclone", "cat", remotePath)
+ cmd, err := safeexec.CommandContext(ctx, "rclone", "cat", remotePath)
+ if err != nil {
+ return "", err
+ }
stdout, err := cmd.StdoutPipe()
if err != nil {
return "", fmt.Errorf("start rclone cat %s: %w", remotePath, err)
diff --git a/internal/orchestrator/decompress_reader_test.go b/internal/orchestrator/decompress_reader_test.go
index 542c7bc8..ddf86fef 100644
--- a/internal/orchestrator/decompress_reader_test.go
+++ b/internal/orchestrator/decompress_reader_test.go
@@ -1,9 +1,12 @@
package orchestrator
import (
+ "bytes"
"context"
+ "errors"
"io"
"os"
+ "path/filepath"
"strings"
"testing"
)
@@ -36,6 +39,7 @@ func TestCreateDecompressionReaderTar(t *testing.T) {
if reader == nil {
t.Fatalf("reader should not be nil for tar")
}
+ _ = reader.Close()
}
type fakeStreamCommandRunner struct {
@@ -59,6 +63,28 @@ func (f *fakeStreamCommandRunner) RunStream(ctx context.Context, name string, st
return io.NopCloser(strings.NewReader("")), nil
}
+type extractionCloseErrorReadCloser struct {
+ *bytes.Reader
+ err error
+}
+
+func (r *extractionCloseErrorReadCloser) Close() error {
+ return r.err
+}
+
+type closeErrorStreamCommandRunner struct {
+ data []byte
+ closeErr error
+}
+
+func (f *closeErrorStreamCommandRunner) Run(context.Context, string, ...string) ([]byte, error) {
+ return nil, nil
+}
+
+func (f *closeErrorStreamCommandRunner) RunStream(context.Context, string, io.Reader, ...string) (io.ReadCloser, error) {
+ return &extractionCloseErrorReadCloser{Reader: bytes.NewReader(f.data), err: f.closeErr}, nil
+}
+
func TestCreateDecompressionReaderUsesStreamingRunnerForCompressedFormats(t *testing.T) {
orig := restoreCmd
t.Cleanup(func() { restoreCmd = orig })
@@ -98,14 +124,9 @@ func TestCreateDecompressionReaderUsesStreamingRunnerForCompressedFormats(t *tes
if err != nil {
t.Fatalf("createDecompressionReader(%s) error: %v", tt.ext, err)
}
+ defer reader.Close()
- rc, ok := reader.(io.ReadCloser)
- if !ok {
- t.Fatalf("expected io.ReadCloser, got %T", reader)
- }
- defer rc.Close()
-
- out, err := io.ReadAll(rc)
+ out, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
@@ -118,3 +139,40 @@ func TestCreateDecompressionReaderUsesStreamingRunnerForCompressedFormats(t *tes
})
}
}
+
+func TestExtractArchiveNativeReturnsDecompressionCloseError(t *testing.T) {
+ origCmd := restoreCmd
+ origFS := restoreFS
+ t.Cleanup(func() {
+ restoreCmd = origCmd
+ restoreFS = origFS
+ })
+
+ dir := t.TempDir()
+ tarPath := filepath.Join(dir, "source.tar")
+ if err := writeTarFile(tarPath, map[string]string{"etc/example.conf": "ok\n"}); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+ tarData, err := os.ReadFile(tarPath)
+ if err != nil {
+ t.Fatalf("ReadFile: %v", err)
+ }
+
+ closeErr := errors.New("decompressor exited 2")
+ restoreCmd = &closeErrorStreamCommandRunner{data: tarData, closeErr: closeErr}
+ restoreFS = osFS{}
+
+ archivePath := filepath.Join(dir, "archive.tar.zst")
+ if err := os.WriteFile(archivePath, []byte("compressed"), 0o640); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+
+ err = extractArchiveNative(context.Background(), restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: filepath.Join(dir, "dest"),
+ logger: newTestLogger(),
+ })
+ if !errors.Is(err, closeErr) {
+ t.Fatalf("extractArchiveNative error = %v, want close error %v", err, closeErr)
+ }
+}
diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go
index 9ce590d5..f55cde81 100644
--- a/internal/orchestrator/decrypt.go
+++ b/internal/orchestrator/decrypt.go
@@ -10,7 +10,6 @@ import (
"fmt"
"io"
"os"
- "os/exec"
"path"
"path/filepath"
"strconv"
@@ -22,6 +21,7 @@ import (
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/input"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
@@ -186,7 +186,10 @@ func inspectRcloneBundleManifest(ctx context.Context, remotePath string, logger
defer cancel()
logging.DebugStep(logger, "inspect rclone bundle manifest", "executing: rclone cat %s", remotePath)
- cmd := exec.CommandContext(cmdCtx, "rclone", "cat", remotePath)
+ cmd, err := safeexec.CommandContext(cmdCtx, "rclone", "cat", remotePath)
+ if err != nil {
+ return nil, fmt.Errorf("prepare rclone cat: %w", err)
+ }
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("open rclone stream: %w", err)
@@ -270,7 +273,10 @@ func inspectRcloneMetadataManifest(ctx context.Context, remoteMetadataPath, remo
defer func() { done(err) }()
logging.DebugStep(logger, "inspect rclone metadata manifest", "executing: rclone cat %s", remoteMetadataPath)
- cmd := exec.CommandContext(ctx, "rclone", "cat", remoteMetadataPath)
+ cmd, err := safeexec.CommandContext(ctx, "rclone", "cat", remoteMetadataPath)
+ if err != nil {
+ return nil, fmt.Errorf("prepare rclone metadata cat: %w", err)
+ }
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("rclone cat %s failed: %w (output: %s)", remoteMetadataPath, err, strings.TrimSpace(string(output)))
@@ -418,7 +424,10 @@ func downloadRcloneBackup(ctx context.Context, remotePath string, logger *loggin
logging.DebugStep(logger, "download rclone backup", "local temp file=%s", tmpPath)
// Use rclone copyto to download with progress
- cmd := exec.CommandContext(ctx, "rclone", "copyto", remotePath, tmpPath, "--progress")
+ cmd, err := safeexec.CommandContext(ctx, "rclone", "copyto", remotePath, tmpPath, "--progress")
+ if err != nil {
+ return "", nil, err
+ }
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -573,7 +582,10 @@ func rcloneCopyTo(ctx context.Context, remotePath, localPath string, showProgres
if showProgress {
args = append(args, "--progress")
}
- cmd := exec.CommandContext(ctx, "rclone", args...)
+ cmd, err := safeexec.CommandContext(ctx, "rclone", args...)
+ if err != nil {
+ return err
+ }
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go
index 4932b1bf..a952e3b0 100644
--- a/internal/orchestrator/decrypt_test.go
+++ b/internal/orchestrator/decrypt_test.go
@@ -190,7 +190,6 @@ func TestBuildDecryptPathOptions(t *testing.T) {
}
func TestBaseNameFromRemoteRef(t *testing.T) {
- t.Parallel()
tests := []struct {
in string
want string
@@ -464,7 +463,6 @@ func TestParseIdentityInput(t *testing.T) {
}
func TestSanitizeBundleEntryName(t *testing.T) {
- t.Parallel()
tests := []struct {
name string
input string
@@ -489,7 +487,6 @@ func TestSanitizeBundleEntryName(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
got, err := sanitizeBundleEntryName(tt.input)
if tt.expectErr {
if err == nil {
@@ -567,7 +564,12 @@ func createTestBundle(t *testing.T, entries []bundleEntry) string {
t.Helper()
dir := t.TempDir()
bundlePath := filepath.Join(dir, "bundle.tar")
+ createTestBundleAt(t, bundlePath, entries)
+ return bundlePath
+}
+func createTestBundleAt(t *testing.T, bundlePath string, entries []bundleEntry) {
+ t.Helper()
f, err := os.Create(bundlePath)
if err != nil {
t.Fatalf("create bundle: %v", err)
@@ -591,7 +593,88 @@ func createTestBundle(t *testing.T, entries []bundleEntry) string {
if err := tw.Close(); err != nil {
t.Fatalf("close tar writer: %v", err)
}
- return bundlePath
+}
+
+func createPlainBackupBundle(t *testing.T, bundlePath string, archiveData []byte, manifest backup.Manifest, includeChecksum bool) {
+ t.Helper()
+ metaJSON, err := json.Marshal(manifest)
+ if err != nil {
+ t.Fatalf("marshal manifest: %v", err)
+ }
+
+ entries := []bundleEntry{
+ {name: "backup.tar.xz", data: archiveData},
+ {name: "backup.metadata", data: metaJSON},
+ }
+ if includeChecksum {
+ entries = append(entries, bundleEntry{
+ name: "backup.sha256",
+ data: []byte(checksumLineForBytes("backup.tar.xz", archiveData)),
+ })
+ }
+ createTestBundleAt(t, bundlePath, entries)
+}
+
+func useRestoreFS(t *testing.T, fs FS) {
+ t.Helper()
+ orig := restoreFS
+ restoreFS = fs
+ t.Cleanup(func() { restoreFS = orig })
+}
+
+type rawArtifactFixture struct {
+ candidate *backupCandidate
+ workDir string
+}
+
+func createRawArtifactFixture(t *testing.T, includeMetadata bool, checksumContent string) rawArtifactFixture {
+ t.Helper()
+ srcDir := t.TempDir()
+
+ archivePath := filepath.Join(srcDir, "backup.tar.xz")
+ if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ metadataPath := "/nonexistent/backup.metadata"
+ if includeMetadata {
+ metadataPath = filepath.Join(srcDir, "backup.metadata")
+ if err := os.WriteFile(metadataPath, []byte("{}"), 0o644); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ }
+
+ checksumPath := ""
+ if checksumContent != "" {
+ checksumPath = filepath.Join(srcDir, "backup.sha256")
+ if err := os.WriteFile(checksumPath, []byte(checksumContent), 0o644); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+ }
+
+ return rawArtifactFixture{
+ candidate: &backupCandidate{
+ RawArchivePath: archivePath,
+ RawMetadataPath: metadataPath,
+ RawChecksumPath: checksumPath,
+ },
+ workDir: t.TempDir(),
+ }
+}
+
+func plainBundleCandidate(path string, manifest *backup.Manifest) *backupCandidate {
+ return &backupCandidate{
+ Source: sourceBundle,
+ BundlePath: path,
+ Manifest: manifest,
+ }
+}
+
+func preparePlainBundleTestInput(path string, manifest *backup.Manifest) (*backupCandidate, context.Context, *bufio.Reader, *logging.Logger) {
+ return plainBundleCandidate(path, manifest),
+ context.Background(),
+ bufio.NewReader(strings.NewReader("")),
+ logging.New(types.LogLevelError, false)
}
func TestEnsureWritablePath(t *testing.T) {
@@ -1916,34 +1999,10 @@ func TestMoveFileSafe_SameDevice(t *testing.T) {
// =====================================
func TestCopyRawArtifactsToWorkdir_Success(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
+ useRestoreFS(t, osFS{})
- srcDir := t.TempDir()
- workDir := t.TempDir()
-
- // Create source files
- archivePath := filepath.Join(srcDir, "backup.tar.xz")
- if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
- t.Fatalf("write archive: %v", err)
- }
- metadataPath := filepath.Join(srcDir, "backup.metadata")
- if err := os.WriteFile(metadataPath, []byte("{}"), 0o644); err != nil {
- t.Fatalf("write metadata: %v", err)
- }
- checksumPath := filepath.Join(srcDir, "backup.sha256")
- if err := os.WriteFile(checksumPath, []byte("checksum"), 0o644); err != nil {
- t.Fatalf("write checksum: %v", err)
- }
-
- cand := &backupCandidate{
- RawArchivePath: archivePath,
- RawMetadataPath: metadataPath,
- RawChecksumPath: checksumPath,
- }
-
- staged, err := copyRawArtifactsToWorkdir(context.Background(), cand, workDir)
+ fixture := createRawArtifactFixture(t, true, "checksum")
+ staged, err := copyRawArtifactsToWorkdir(context.Background(), fixture.candidate, fixture.workDir)
if err != nil {
t.Fatalf("copyRawArtifactsToWorkdir error: %v", err)
}
@@ -1953,9 +2012,7 @@ func TestCopyRawArtifactsToWorkdir_Success(t *testing.T) {
}
func TestCopyRawArtifactsToWorkdir_ArchiveError(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
+ useRestoreFS(t, osFS{})
cand := &backupCandidate{
RawArchivePath: "/nonexistent/archive.tar.xz",
@@ -1973,26 +2030,11 @@ func TestCopyRawArtifactsToWorkdir_ArchiveError(t *testing.T) {
}
func TestCopyRawArtifactsToWorkdir_MetadataError(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
-
- srcDir := t.TempDir()
- workDir := t.TempDir()
-
- // Create only archive, no metadata
- archivePath := filepath.Join(srcDir, "backup.tar.xz")
- if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
- t.Fatalf("write archive: %v", err)
- }
+ useRestoreFS(t, osFS{})
- cand := &backupCandidate{
- RawArchivePath: archivePath,
- RawMetadataPath: "/nonexistent/backup.metadata",
- RawChecksumPath: "/nonexistent/backup.sha256",
- }
-
- _, err := copyRawArtifactsToWorkdir(context.Background(), cand, workDir)
+ fixture := createRawArtifactFixture(t, false, "")
+ fixture.candidate.RawChecksumPath = "/nonexistent/backup.sha256"
+ _, err := copyRawArtifactsToWorkdir(context.Background(), fixture.candidate, fixture.workDir)
if err == nil {
t.Fatal("expected error for nonexistent metadata")
}
@@ -2002,30 +2044,11 @@ func TestCopyRawArtifactsToWorkdir_MetadataError(t *testing.T) {
}
func TestCopyRawArtifactsToWorkdir_ChecksumError(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
-
- srcDir := t.TempDir()
- workDir := t.TempDir()
-
- // Create archive and metadata, no checksum
- archivePath := filepath.Join(srcDir, "backup.tar.xz")
- if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
- t.Fatalf("write archive: %v", err)
- }
- metadataPath := filepath.Join(srcDir, "backup.metadata")
- if err := os.WriteFile(metadataPath, []byte("{}"), 0o644); err != nil {
- t.Fatalf("write metadata: %v", err)
- }
-
- cand := &backupCandidate{
- RawArchivePath: archivePath,
- RawMetadataPath: metadataPath,
- RawChecksumPath: "/nonexistent/backup.sha256",
- }
+ useRestoreFS(t, osFS{})
- staged, err := copyRawArtifactsToWorkdir(context.Background(), cand, workDir)
+ fixture := createRawArtifactFixture(t, true, "")
+ fixture.candidate.RawChecksumPath = "/nonexistent/backup.sha256"
+ staged, err := copyRawArtifactsToWorkdir(context.Background(), fixture.candidate, fixture.workDir)
if err != nil {
t.Fatalf("expected checksum to be optional, got error: %v", err)
}
@@ -2535,30 +2558,10 @@ func TestInspectRcloneMetadataManifest_RcloneFails(t *testing.T) {
// =====================================
func TestCopyRawArtifactsToWorkdir_ContextWorks(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
-
- srcDir := t.TempDir()
- workDir := t.TempDir()
-
- // Create source files
- archivePath := filepath.Join(srcDir, "backup.tar.xz")
- if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
- t.Fatalf("write archive: %v", err)
- }
- metadataPath := filepath.Join(srcDir, "backup.metadata")
- if err := os.WriteFile(metadataPath, []byte("{}"), 0o644); err != nil {
- t.Fatalf("write metadata: %v", err)
- }
+ useRestoreFS(t, osFS{})
- cand := &backupCandidate{
- RawArchivePath: archivePath,
- RawMetadataPath: metadataPath,
- RawChecksumPath: "",
- }
-
- staged, err := copyRawArtifactsToWorkdirWithLogger(context.TODO(), cand, workDir, nil)
+ fixture := createRawArtifactFixture(t, true, "")
+ staged, err := copyRawArtifactsToWorkdirWithLogger(context.TODO(), fixture.candidate, fixture.workDir, nil)
if err != nil {
t.Fatalf("copyRawArtifactsToWorkdirWithLogger error: %v", err)
}
@@ -3305,34 +3308,10 @@ func TestInspectRcloneBundleManifest_StartError(t *testing.T) {
}
func TestCopyRawArtifactsToWorkdir_WithChecksum(t *testing.T) {
- origFS := restoreFS
- restoreFS = osFS{}
- t.Cleanup(func() { restoreFS = origFS })
-
- srcDir := t.TempDir()
- workDir := t.TempDir()
-
- // Create source files including checksum
- archivePath := filepath.Join(srcDir, "backup.tar.xz")
- if err := os.WriteFile(archivePath, []byte("archive data"), 0o644); err != nil {
- t.Fatalf("write archive: %v", err)
- }
- metadataPath := filepath.Join(srcDir, "backup.metadata")
- if err := os.WriteFile(metadataPath, []byte("{}"), 0o644); err != nil {
- t.Fatalf("write metadata: %v", err)
- }
- checksumPath := filepath.Join(srcDir, "backup.sha256")
- if err := os.WriteFile(checksumPath, []byte("checksum backup.tar.xz"), 0o644); err != nil {
- t.Fatalf("write checksum: %v", err)
- }
+ useRestoreFS(t, osFS{})
- cand := &backupCandidate{
- RawArchivePath: archivePath,
- RawMetadataPath: metadataPath,
- RawChecksumPath: checksumPath,
- }
-
- staged, err := copyRawArtifactsToWorkdirWithLogger(context.Background(), cand, workDir, nil)
+ fixture := createRawArtifactFixture(t, true, "checksum backup.tar.xz")
+ staged, err := copyRawArtifactsToWorkdirWithLogger(context.Background(), fixture.candidate, fixture.workDir, nil)
if err != nil {
t.Fatalf("copyRawArtifactsToWorkdirWithLogger error: %v", err)
}
@@ -3713,26 +3692,9 @@ func TestSelectDecryptCandidate_RequireEncryptedAllPlain(t *testing.T) {
// Create a plain backup bundle (must have .bundle.tar suffix)
bundlePath := filepath.Join(backupDir, "backup-2024-01-01.bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
- // Add archive (plain, no .age extension)
archiveData := []byte("archive content")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
- // Add metadata with encryption=none
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now(), ArchivePath: "backup.tar.xz"}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- // Add checksum
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
cfg := &config.Config{
BackupPath: backupDir,
@@ -3824,23 +3786,9 @@ exit 1
// Bundle must have .bundle.tar suffix to be discovered
bundlePath := filepath.Join(backupDir, "backup.bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
archiveData := []byte("archive content")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
manifest := backup.Manifest{EncryptionMode: "age", Hostname: "test", CreatedAt: time.Now(), ArchivePath: "backup.tar.xz"}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
cfg := &config.Config{
BackupPath: backupDir,
@@ -3875,23 +3823,9 @@ func TestPreparePlainBundle_StatErrorAfterExtract(t *testing.T) {
// Create a valid bundle
bundlePath := filepath.Join(tmp, "bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
archiveData := []byte("archive content")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now()}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
// Create FakeFS that will fail on stat for the extracted archive
fake := NewFakeFS()
@@ -3908,18 +3842,11 @@ func TestPreparePlainBundle_StatErrorAfterExtract(t *testing.T) {
fake.StatErr["/tmp/proxsave"] = nil // Allow this stat
// After extraction, stat will be called on the plain archive - we set error later
- orig := restoreFS
- restoreFS = fake
- defer func() { restoreFS = orig }()
-
- cand := &backupCandidate{
- Source: sourceBundle,
- BundlePath: bundlePath,
- Manifest: &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
- }
- ctx := context.Background()
- reader := bufio.NewReader(strings.NewReader(""))
- logger := logging.New(types.LogLevelError, false)
+ useRestoreFS(t, fake)
+ cand, ctx, reader, logger := preparePlainBundleTestInput(
+ bundlePath,
+ &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
+ )
// The test shows that with proper setup, stat error would be triggered
// For now, run with FakeFS to cover the MkdirAll/MkdirTemp paths
@@ -3975,19 +3902,9 @@ func TestPreparePlainBundle_MkdirTempErrorWithRcloneCleanup(t *testing.T) {
// Create a fake downloaded bundle file
bundlePath := filepath.Join(tmp, "downloaded.bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
archiveData := []byte("data")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
- metaJSON, _ := json.Marshal(backup.Manifest{EncryptionMode: "none", ArchivePath: "backup.tar.xz"})
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ manifest := backup.Manifest{EncryptionMode: "none", ArchivePath: "backup.tar.xz"}
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
// Track if cleanup was called
cleanupCalled := false
@@ -4161,23 +4078,9 @@ func TestPreparePlainBundle_CopyFileError(t *testing.T) {
// Create a valid bundle
bundlePath := filepath.Join(tmp, "bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
archiveData := []byte("archive content")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now(), ArchivePath: "backup.tar.xz"}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
// Use FakeFS
fake := NewFakeFS()
@@ -4190,18 +4093,11 @@ func TestPreparePlainBundle_CopyFileError(t *testing.T) {
// After extraction, set OpenFile error for the archive copy destination
// The copyFile function will try to create the destination file
- orig := restoreFS
- restoreFS = fake
- defer func() { restoreFS = orig }()
-
- cand := &backupCandidate{
- Source: sourceBundle,
- BundlePath: bundlePath,
- Manifest: &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
- }
- ctx := context.Background()
- reader := bufio.NewReader(strings.NewReader(""))
- logger := logging.New(types.LogLevelError, false)
+ useRestoreFS(t, fake)
+ cand, ctx, reader, logger := preparePlainBundleTestInput(
+ bundlePath,
+ &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
+ )
bundle, err := preparePlainBundle(ctx, reader, cand, "1.0.0", logger)
// This test verifies that the path goes through successfully for plain archives
@@ -4270,39 +4166,18 @@ func TestPreparePlainBundle_StatErrorOnPlainArchive(t *testing.T) {
// Create a valid bundle with plain (non-encrypted) archive
bundlePath := filepath.Join(tmp, "bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
archiveData := []byte("archive content for stat test")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now(), ArchivePath: "backup.tar.xz"}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
// Use wrapped osFS that fails stat on plain archive after several calls
fake := &fakeStatFailOnPlainArchive{}
- orig := restoreFS
- restoreFS = fake
- defer func() { restoreFS = orig }()
-
- cand := &backupCandidate{
- Source: sourceBundle,
- BundlePath: bundlePath,
- Manifest: &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
- }
- ctx := context.Background()
- reader := bufio.NewReader(strings.NewReader(""))
- logger := logging.New(types.LogLevelError, false)
+ useRestoreFS(t, fake)
+ cand, ctx, reader, logger := preparePlainBundleTestInput(
+ bundlePath,
+ &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
+ )
bundle, err := preparePlainBundle(ctx, reader, cand, "1.0.0", logger)
if err == nil {
@@ -4328,17 +4203,9 @@ func TestPreparePlainBundle_MkdirAllErrorWithRcloneDownloadCleanup(t *testing.T)
// Create a valid bundle that rclone will "download"
bundlePath := filepath.Join(downloadDir, "backup.bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
archiveData := []byte("archive content")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now()}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, false)
// Script that copies the pre-made bundle to the destination
script := fmt.Sprintf(`#!/bin/bash
@@ -4404,39 +4271,18 @@ func TestPreparePlainBundle_GenerateChecksumErrorPath(t *testing.T) {
// Create a valid bundle
bundlePath := filepath.Join(tmp, "bundle.tar")
- bundleFile, _ := os.Create(bundlePath)
- tw := tar.NewWriter(bundleFile)
-
archiveData := []byte("archive content for checksum error test")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
-
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now(), ArchivePath: "backup.tar.xz"}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
-
- checksum := checksumLineForBytes("backup.tar.xz", archiveData)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
- tw.Write(checksum)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, bundlePath, archiveData, manifest, true)
// Use FS that removes file after stat
fake := &fakeStatThenRemoveFS{}
- orig := restoreFS
- restoreFS = fake
- defer func() { restoreFS = orig }()
-
- cand := &backupCandidate{
- Source: sourceBundle,
- BundlePath: bundlePath,
- Manifest: &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
- }
- ctx := context.Background()
- reader := bufio.NewReader(strings.NewReader(""))
- logger := logging.New(types.LogLevelError, false)
+ useRestoreFS(t, fake)
+ cand, ctx, reader, logger := preparePlainBundleTestInput(
+ bundlePath,
+ &backup.Manifest{EncryptionMode: "none", Hostname: "test"},
+ )
bundle, err := preparePlainBundle(ctx, reader, cand, "1.0.0", logger)
if err == nil {
@@ -4476,17 +4322,9 @@ func TestPreparePlainBundle_MkdirAllErrorAfterRcloneDownload(t *testing.T) {
// Create the bundle that will be "downloaded"
sourceBundlePath := filepath.Join(bundleDir, "backup.bundle.tar")
- bundleFile, _ := os.Create(sourceBundlePath)
- tw := tar.NewWriter(bundleFile)
archiveData := []byte("archive")
- tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o640})
- tw.Write(archiveData)
manifest := backup.Manifest{EncryptionMode: "none", Hostname: "test", CreatedAt: time.Now()}
- metaJSON, _ := json.Marshal(manifest)
- tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
- tw.Write(metaJSON)
- tw.Close()
- bundleFile.Close()
+ createPlainBackupBundle(t, sourceBundlePath, archiveData, manifest, false)
// Script that copies the bundle to destination
script := fmt.Sprintf(`#!/bin/bash
diff --git a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go
index 84cf85de..9dd6dc32 100644
--- a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go
+++ b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go
@@ -27,31 +27,130 @@ import (
var decryptTUIE2EMu sync.Mutex
+const (
+ timedSimScreenWaitTimeout = 10 * time.Second
+ timedSimCompletionTimeout = 15 * time.Second
+ timedSimDefaultSettle = 15 * time.Millisecond
+ timedSimKeyDelay = 15 * time.Millisecond
+)
+
type notifyingSimulationScreen struct {
tcell.SimulationScreen
- notify func()
+ mu sync.Mutex
+ snapshot timedSimScreenSnapshot
+ notify func()
+}
+
+type timedSimScreenSnapshot struct {
+ cells []tcell.SimCell
+ width int
+ height int
+ cursorX int
+ cursorY int
+ cursorVisible bool
+ ready bool
}
func (s *notifyingSimulationScreen) Show() {
+ s.mu.Lock()
s.SimulationScreen.Show()
- if s.notify != nil {
- s.notify()
- }
+ s.captureLocked()
+ s.mu.Unlock()
+ s.notifyChange()
}
func (s *notifyingSimulationScreen) Sync() {
+ s.mu.Lock()
s.SimulationScreen.Sync()
+ s.captureLocked()
+ s.mu.Unlock()
+ s.notifyChange()
+}
+
+func (s *notifyingSimulationScreen) snapshotState() timedSimScreenSnapshot {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return cloneTimedSimScreenSnapshot(s.snapshot)
+}
+
+func (s *notifyingSimulationScreen) captureLocked() {
+ cells, width, height := s.SimulationScreen.GetContents()
+ cursorX, cursorY, cursorVisible := s.SimulationScreen.GetCursor()
+ s.snapshot = timedSimScreenSnapshot{
+ cells: cloneSimCells(cells),
+ width: width,
+ height: height,
+ cursorX: cursorX,
+ cursorY: cursorY,
+ cursorVisible: cursorVisible,
+ ready: true,
+ }
+}
+
+func (s *notifyingSimulationScreen) notifyChange() {
if s.notify != nil {
s.notify()
}
}
+func cloneTimedSimScreenSnapshot(snapshot timedSimScreenSnapshot) timedSimScreenSnapshot {
+ snapshot.cells = cloneSimCells(snapshot.cells)
+ return snapshot
+}
+
+func cloneSimCells(cells []tcell.SimCell) []tcell.SimCell {
+ if len(cells) == 0 {
+ return nil
+ }
+ cloned := make([]tcell.SimCell, len(cells))
+ for i, cell := range cells {
+ cloned[i] = cell
+ if cell.Bytes != nil {
+ cloned[i].Bytes = append([]byte(nil), cell.Bytes...)
+ }
+ if cell.Runes != nil {
+ cloned[i].Runes = append([]rune(nil), cell.Runes...)
+ }
+ }
+ return cloned
+}
+
type timedSimKey struct {
- Key tcell.Key
- R rune
- Mod tcell.ModMask
- Wait time.Duration
- WaitForText string
+ Key tcell.Key
+ R rune
+ Mod tcell.ModMask
+ WaitForText string
+ Wait time.Duration
+ RequireNewApp bool
+ SettleAfterMatch time.Duration
+}
+
+type timedSimHarness struct {
+ t *testing.T
+ done chan struct{}
+ closeDoneOnce sync.Once
+ injectWG sync.WaitGroup
+ screenStateCh chan struct{}
+ runCompleted chan struct{}
+ closeRunOnce sync.Once
+
+ appMu sync.RWMutex
+ apps []*tui.App
+ current *timedSimAppState
+}
+
+type timedSimAppState struct {
+ generation int
+ app *tui.App
+ screen *notifyingSimulationScreen
+}
+
+type timedSimScreenState struct {
+ generation int
+ text string
+ focusType string
+ ready bool
+ screen *notifyingSimulationScreen
}
type decryptTUIFixture struct {
@@ -68,143 +167,240 @@ type decryptTUIFixture struct {
ExpectedChecksum string
}
-func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) {
+func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) *timedSimHarness {
t.Helper()
decryptTUIE2EMu.Lock()
orig := newTUIApp
- done := make(chan struct{})
- var injectWG sync.WaitGroup
+ h := &timedSimHarness{
+ t: t,
+ done: make(chan struct{}),
+ screenStateCh: make(chan struct{}, 1),
+ runCompleted: make(chan struct{}),
+ }
+
t.Cleanup(func() {
- close(done)
- injectWG.Wait()
+ h.stop()
newTUIApp = orig
decryptTUIE2EMu.Unlock()
})
- baseScreen := tcell.NewSimulationScreen("UTF-8")
- if err := baseScreen.Init(); err != nil {
- t.Fatalf("screen.Init: %v", err)
+ newTUIApp = func() *tui.App {
+ app := tui.NewApp()
+
+ baseScreen := tcell.NewSimulationScreen("UTF-8")
+ if err := baseScreen.Init(); err != nil {
+ t.Fatalf("screen.Init: %v", err)
+ }
+ baseScreen.SetSize(120, 40)
+
+ screen := ¬ifyingSimulationScreen{
+ SimulationScreen: baseScreen,
+ notify: h.notifyScreenStateChanged,
+ }
+
+ h.appMu.Lock()
+ state := &timedSimAppState{
+ generation: len(h.apps) + 1,
+ app: app,
+ screen: screen,
+ }
+ h.apps = append(h.apps, app)
+ h.current = state
+ h.appMu.Unlock()
+
+ app.SetScreen(screen)
+ h.notifyScreenStateChanged()
+ return app
+ }
+
+ h.injectWG.Add(1)
+ go h.run(keys)
+
+ return h
+}
+
+func (h *timedSimHarness) notifyScreenStateChanged() {
+ select {
+ case h.screenStateCh <- struct{}{}:
+ default:
+ }
+}
+
+func (h *timedSimHarness) markRunCompleted() {
+ if h == nil {
+ return
+ }
+ if h.runCompleted == nil {
+ return
+ }
+ h.closeRunOnce.Do(func() {
+ close(h.runCompleted)
+ })
+}
+
+func (h *timedSimHarness) stop() {
+ if h == nil {
+ return
}
- baseScreen.SetSize(120, 40)
+ h.closeDoneOnce.Do(func() {
+ close(h.done)
+ })
+ h.StopAll()
+ h.injectWG.Wait()
+}
- type timedSimScreenState struct {
- signature string
- text string
+func (h *timedSimHarness) StopAll() {
+ if h == nil {
+ return
}
+ h.appMu.RLock()
+ apps := append([]*tui.App(nil), h.apps...)
+ h.appMu.RUnlock()
+ for i := len(apps) - 1; i >= 0; i-- {
+ apps[i].Stop()
+ }
+}
+
+func (h *timedSimHarness) run(keys []timedSimKey) {
+ defer h.injectWG.Done()
- screenStateCh := make(chan struct{}, 1)
- var appMu sync.RWMutex
- var currentApp *tui.App
- screen := ¬ifyingSimulationScreen{
- SimulationScreen: baseScreen,
- notify: func() {
- select {
- case screenStateCh <- struct{}{}:
- default:
+ generation := 0
+ for idx, key := range keys {
+ minGeneration := generation
+ if minGeneration == 0 || key.RequireNewApp {
+ minGeneration++
+ }
+
+ state, ok := h.waitForScreenText(idx, key, minGeneration)
+ if !ok {
+ return
+ }
+ generation = state.generation
+ if key.Wait > 0 && strings.TrimSpace(key.WaitForText) == "" {
+ if !h.sleepOrDone(key.Wait) {
+ return
}
- },
+ }
+
+ settle := key.SettleAfterMatch
+ if settle <= 0 {
+ settle = timedSimDefaultSettle
+ }
+ if !h.sleepOrDone(settle) {
+ return
+ }
+
+ mod := key.Mod
+ if mod == 0 {
+ mod = tcell.ModNone
+ }
+ state.screen.InjectKey(key.Key, key.R, mod)
+ if !h.sleepOrDone(timedSimKeyDelay) {
+ return
+ }
}
- var once sync.Once
- newTUIApp = func() *tui.App {
- app := tui.NewApp()
- appMu.Lock()
- currentApp = app
- appMu.Unlock()
- app.SetScreen(screen)
+ timer := time.NewTimer(timedSimCompletionTimeout)
+ defer timer.Stop()
+ select {
+ case <-h.runCompleted:
+ case <-h.done:
+ case <-timer.C:
+ h.t.Errorf("TUI simulation did not finish within %s after injecting %d key(s)\n%s", timedSimCompletionTimeout, len(keys), h.describeCurrentState())
+ h.StopAll()
+ }
+}
- once.Do(func() {
- injectWG.Add(1)
- go func() {
- defer injectWG.Done()
- var lastInjectedState string
-
- currentScreenState := func() timedSimScreenState {
- appMu.RLock()
- app := currentApp
- appMu.RUnlock()
-
- var focus any
- if app != nil {
- focus = app.GetFocus()
- }
-
- return timedSimScreenState{
- signature: timedSimScreenStateSignature(screen, focus),
- text: timedSimScreenText(screen),
- }
- }
-
- waitForScreenText := func(expected string) bool {
- expected = strings.TrimSpace(expected)
- for {
- current := currentScreenState()
- if current.signature != "" {
- if (expected == "" || strings.Contains(current.text, expected)) &&
- (lastInjectedState == "" || current.signature != lastInjectedState) {
- return true
- }
- }
-
- select {
- case <-done:
- return false
- case <-screenStateCh:
- }
- }
- }
-
- for _, k := range keys {
- if k.Wait > 0 {
- if !waitForScreenText(k.WaitForText) {
- return
- }
- }
- current := currentScreenState()
- mod := k.Mod
- if mod == 0 {
- mod = tcell.ModNone
- }
- select {
- case <-done:
- return
- default:
- }
- screen.InjectKey(k.Key, k.R, mod)
- lastInjectedState = current.signature
- }
- }()
- })
+func (h *timedSimHarness) waitForScreenText(index int, key timedSimKey, minGeneration int) (timedSimScreenState, bool) {
+ expected := strings.TrimSpace(key.WaitForText)
+ timeout := timedSimScreenWaitTimeout
+ if key.Wait > 0 {
+ timeout = key.Wait
+ }
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
- return app
+ for {
+ state := h.currentScreenState()
+ if state.ready && state.generation >= minGeneration && (expected == "" || strings.Contains(state.text, expected)) {
+ return state, true
+ }
+
+ select {
+ case <-h.done:
+ return timedSimScreenState{}, false
+ case <-h.screenStateCh:
+ case <-timer.C:
+ h.t.Errorf(
+ "TUI simulation timed out at action %d waiting for text %q within %s (min generation=%d, current generation=%d, focus=%s)\nCurrent screen:\n%s",
+ index,
+ expected,
+ timeout,
+ minGeneration,
+ state.generation,
+ state.focusType,
+ state.text,
+ )
+ h.StopAll()
+ return state, false
+ }
}
}
-func timedSimScreenStateSignature(screen tcell.SimulationScreen, focus any) string {
- cells, width, height := screen.GetContents()
- cursorX, cursorY, cursorVisible := screen.GetCursor()
+func (h *timedSimHarness) currentScreenState() timedSimScreenState {
+ h.appMu.RLock()
+ current := h.current
+ h.appMu.RUnlock()
+ if current == nil || current.screen == nil {
+ return timedSimScreenState{}
+ }
- sum := sha256.New()
- fmt.Fprintf(sum, "size:%d:%d cursor:%d:%d:%t focus:%T:%p\n", width, height, cursorX, cursorY, cursorVisible, focus, focus)
- for _, cell := range cells {
- fg, bg, attr := cell.Style.Decompose()
- fmt.Fprintf(sum, "%x/%d/%d/%d;", cell.Bytes, fg, bg, attr)
+ focusType := ""
+ if current.app != nil {
+ if focus := current.app.GetFocus(); focus != nil {
+ focusType = fmt.Sprintf("%T", focus)
+ }
+ }
+ snapshot := current.screen.snapshotState()
+ return timedSimScreenState{
+ generation: current.generation,
+ text: timedSimScreenText(snapshot),
+ focusType: focusType,
+ ready: snapshot.ready,
+ screen: current.screen,
}
- return hex.EncodeToString(sum.Sum(nil))
}
-func timedSimScreenText(screen tcell.SimulationScreen) string {
- cells, width, height := screen.GetContents()
- if width <= 0 || height <= 0 || len(cells) < width*height {
+func (h *timedSimHarness) describeCurrentState() string {
+ state := h.currentScreenState()
+ return fmt.Sprintf("current generation=%d focus=%s ready=%t\nCurrent screen:\n%s", state.generation, state.focusType, state.ready, state.text)
+}
+
+func (h *timedSimHarness) sleepOrDone(d time.Duration) bool {
+ if d <= 0 {
+ return true
+ }
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ select {
+ case <-h.done:
+ return false
+ case <-timer.C:
+ return true
+ }
+}
+
+func timedSimScreenText(snapshot timedSimScreenSnapshot) string {
+ if !snapshot.ready || snapshot.width <= 0 || snapshot.height <= 0 || len(snapshot.cells) < snapshot.width*snapshot.height {
return ""
}
var b strings.Builder
- for y := 0; y < height; y++ {
- row := make([]byte, 0, width)
- for x := 0; x < width; x++ {
- cell := cells[y*width+x]
+ for y := 0; y < snapshot.height; y++ {
+ row := make([]byte, 0, snapshot.width)
+ for x := 0; x < snapshot.width; x++ {
+ cell := snapshot.cells[y*snapshot.width+x]
if len(cell.Bytes) == 0 {
row = append(row, ' ')
continue
@@ -312,24 +508,25 @@ func createDecryptTUIEncryptedFixture(t *testing.T) *decryptTUIFixture {
func successDecryptTUISequence(secret string) []timedSimKey {
keys := []timedSimKey{
- {Key: tcell.KeyEnter, Wait: 1 * time.Second, WaitForText: "Select backup source"},
- {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond, WaitForText: "Select backup"},
+ {Key: tcell.KeyEnter, WaitForText: "Select backup source", RequireNewApp: true},
+ {Key: tcell.KeyEnter, WaitForText: "Select backup", RequireNewApp: true},
}
- for _, r := range secret {
+ for idx, r := range secret {
keys = append(keys, timedSimKey{
- Key: tcell.KeyRune,
- R: r,
- Wait: 35 * time.Millisecond,
- WaitForText: "Decrypt key",
+ Key: tcell.KeyRune,
+ R: r,
+ WaitForText: "Decrypt key",
+ RequireNewApp: idx == 0,
+ SettleAfterMatch: 5 * time.Millisecond,
})
}
keys = append(keys,
- timedSimKey{Key: tcell.KeyTab, Wait: 150 * time.Millisecond, WaitForText: "Decrypt key"},
- timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Decrypt key"},
- timedSimKey{Key: tcell.KeyTab, Wait: 500 * time.Millisecond, WaitForText: "Destination directory"},
- timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Destination directory"},
+ timedSimKey{Key: tcell.KeyTab, WaitForText: "Decrypt key"},
+ timedSimKey{Key: tcell.KeyEnter, WaitForText: "Decrypt key"},
+ timedSimKey{Key: tcell.KeyTab, WaitForText: "Destination directory", RequireNewApp: true},
+ timedSimKey{Key: tcell.KeyEnter, WaitForText: "Destination directory"},
)
return keys
@@ -337,23 +534,30 @@ func successDecryptTUISequence(secret string) []timedSimKey {
func abortDecryptTUISequence() []timedSimKey {
return []timedSimKey{
- {Key: tcell.KeyEnter, Wait: 1 * time.Second, WaitForText: "Select backup source"},
- {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond, WaitForText: "Select backup"},
- {Key: tcell.KeyRune, R: '0', Wait: 500 * time.Millisecond, WaitForText: "Decrypt key"},
- {Key: tcell.KeyTab, Wait: 150 * time.Millisecond, WaitForText: "Decrypt key"},
- {Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Decrypt key"},
+ {Key: tcell.KeyEnter, WaitForText: "Select backup source", RequireNewApp: true},
+ {Key: tcell.KeyEnter, WaitForText: "Select backup", RequireNewApp: true},
+ {Key: tcell.KeyRune, R: '0', WaitForText: "Decrypt key", RequireNewApp: true},
+ {Key: tcell.KeyTab, WaitForText: "Decrypt key"},
+ {Key: tcell.KeyEnter, WaitForText: "Decrypt key"},
}
}
-func runDecryptWorkflowTUIForTest(t *testing.T, ctx context.Context, cfg *config.Config, configPath string) error {
+func runDecryptWorkflowTUIForTest(t *testing.T, sim *timedSimHarness, ctx context.Context, cfg *config.Config, configPath string) error {
t.Helper()
+ runCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
logger := logging.New(types.LogLevelError, false)
logger.SetOutput(io.Discard)
errCh := make(chan error, 1)
go func() {
- errCh <- RunDecryptWorkflowTUI(ctx, cfg, logger, "1.0.0", configPath, "test-build")
+ err := RunDecryptWorkflowTUI(runCtx, cfg, logger, "1.0.0", configPath, "test-build")
+ if sim != nil {
+ sim.markRunCompleted()
+ }
+ errCh <- err
}()
waitTimeout := 30 * time.Second
@@ -370,10 +574,29 @@ func runDecryptWorkflowTUIForTest(t *testing.T, ctx context.Context, cfg *config
case err := <-errCh:
return err
case <-timer.C:
- if err := ctx.Err(); err != nil {
+ cancel()
+ if sim != nil {
+ sim.StopAll()
+ }
+
+ shutdownTimer := time.NewTimer(2 * time.Second)
+ defer shutdownTimer.Stop()
+ select {
+ case err := <-errCh:
+ return err
+ case <-shutdownTimer.C:
+ }
+
+ if err := runCtx.Err(); err != nil {
+ if sim != nil {
+ t.Fatalf("RunDecryptWorkflowTUI did not return within %s (context state: %v)\n%s", waitTimeout, err, sim.describeCurrentState())
+ }
t.Fatalf("RunDecryptWorkflowTUI did not return within %s (context state: %v)", waitTimeout, err)
return nil
}
+ if sim != nil {
+ t.Fatalf("RunDecryptWorkflowTUI did not return within %s\n%s", waitTimeout, sim.describeCurrentState())
+ }
t.Fatalf("RunDecryptWorkflowTUI did not return within %s", waitTimeout)
return nil
}
diff --git a/internal/orchestrator/decrypt_tui_e2e_test.go b/internal/orchestrator/decrypt_tui_e2e_test.go
index 925b81d0..bea6a525 100644
--- a/internal/orchestrator/decrypt_tui_e2e_test.go
+++ b/internal/orchestrator/decrypt_tui_e2e_test.go
@@ -18,12 +18,12 @@ func TestRunDecryptWorkflowTUI_SuccessLocalEncrypted(t *testing.T) {
t.Cleanup(func() { restoreFS = origFS })
fixture := createDecryptTUIEncryptedFixture(t)
- withTimedSimAppSequence(t, successDecryptTUISequence(fixture.Secret))
+ sim := withTimedSimAppSequence(t, successDecryptTUISequence(fixture.Secret))
ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second)
defer cancel()
- if err := runDecryptWorkflowTUIForTest(t, ctx, fixture.Config, fixture.ConfigPath); err != nil {
+ if err := runDecryptWorkflowTUIForTest(t, sim, ctx, fixture.Config, fixture.ConfigPath); err != nil {
t.Fatalf("RunDecryptWorkflowTUI error: %v", err)
}
@@ -79,12 +79,12 @@ func TestRunDecryptWorkflowTUI_AbortAtSecretPrompt(t *testing.T) {
t.Cleanup(func() { restoreFS = origFS })
fixture := createDecryptTUIEncryptedFixture(t)
- withTimedSimAppSequence(t, abortDecryptTUISequence())
+ sim := withTimedSimAppSequence(t, abortDecryptTUISequence())
ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second)
defer cancel()
- err := runDecryptWorkflowTUIForTest(t, ctx, fixture.Config, fixture.ConfigPath)
+ err := runDecryptWorkflowTUIForTest(t, sim, ctx, fixture.Config, fixture.ConfigPath)
if !errors.Is(err, ErrDecryptAborted) {
t.Fatalf("RunDecryptWorkflowTUI error=%v; want %v", err, ErrDecryptAborted)
}
diff --git a/internal/orchestrator/deps.go b/internal/orchestrator/deps.go
index 6530c30b..9a5099a4 100644
--- a/internal/orchestrator/deps.go
+++ b/internal/orchestrator/deps.go
@@ -7,10 +7,12 @@ import (
"io/fs"
"os"
"os/exec"
+ "syscall"
"time"
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
)
// FS abstracts filesystem operations to simplify testing.
@@ -32,6 +34,8 @@ type FS interface {
CreateTemp(dir, pattern string) (*os.File, error)
MkdirTemp(dir, pattern string) (string, error)
Rename(oldpath, newpath string) error
+ Lchown(path string, uid, gid int) error
+ UtimesNano(path string, times []syscall.Timespec) error
}
// Prompter encapsulates interactive prompts.
@@ -93,6 +97,10 @@ func (osFS) CreateTemp(dir, pattern string) (*os.File, error) {
}
func (osFS) MkdirTemp(dir, pattern string) (string, error) { return os.MkdirTemp(dir, pattern) }
func (osFS) Rename(oldpath, newpath string) error { return os.Rename(oldpath, newpath) }
+func (osFS) Lchown(path string, uid, gid int) error { return os.Lchown(path, uid, gid) }
+func (osFS) UtimesNano(path string, times []syscall.Timespec) error {
+ return syscall.UtimesNano(path, times)
+}
type consolePrompter struct{}
@@ -123,7 +131,10 @@ type osCommandRunner struct{}
const defaultCommandWaitDelay = 3 * time.Second
func (osCommandRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
- cmd := exec.CommandContext(ctx, name, args...)
+ cmd, err := safeexec.CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
cmd.WaitDelay = defaultCommandWaitDelay
out, err := cmd.CombinedOutput()
if err != nil && errors.Is(err, exec.ErrWaitDelay) {
@@ -134,7 +145,10 @@ func (osCommandRunner) Run(ctx context.Context, name string, args ...string) ([]
// RunStream returns a stdout pipe for streaming commands that read from stdin.
func (osCommandRunner) RunStream(ctx context.Context, name string, stdin io.Reader, args ...string) (io.ReadCloser, error) {
- cmd := exec.CommandContext(ctx, name, args...)
+ cmd, err := safeexec.CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
cmd.Stdin = stdin
stdout, err := cmd.StdoutPipe()
if err != nil {
diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go
index b2fbfb3e..6154e075 100644
--- a/internal/orchestrator/deps_test.go
+++ b/internal/orchestrator/deps_test.go
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strings"
+ "syscall"
"testing"
"time"
@@ -23,6 +24,12 @@ type FakeFS struct {
MkdirAllErr error
MkdirTempErr error
OpenFileErr map[string]error
+ Ownership map[string]FakeOwnership
+}
+
+type FakeOwnership struct {
+ UID int
+ GID int
}
func NewFakeFS() *FakeFS {
@@ -32,6 +39,7 @@ func NewFakeFS() *FakeFS {
StatErr: make(map[string]error),
StatErrors: make(map[string]error),
OpenFileErr: make(map[string]error),
+ Ownership: make(map[string]FakeOwnership),
}
}
@@ -186,6 +194,22 @@ func (f *FakeFS) Rename(oldpath, newpath string) error {
return os.Rename(f.onDisk(oldpath), f.onDisk(newpath))
}
+func (f *FakeFS) Lchown(path string, uid, gid int) error {
+ diskPath := f.onDisk(path)
+ if _, err := os.Lstat(diskPath); err != nil {
+ return err
+ }
+ if f.Ownership == nil {
+ f.Ownership = make(map[string]FakeOwnership)
+ }
+ f.Ownership[diskPath] = FakeOwnership{UID: uid, GID: gid}
+ return nil
+}
+
+func (f *FakeFS) UtimesNano(path string, times []syscall.Timespec) error {
+ return syscall.UtimesNano(f.onDisk(path), times)
+}
+
// FakeTime provides deterministic time.
type FakeTime struct {
Current time.Time
@@ -201,14 +225,16 @@ func (f *FakeTime) Advance(d time.Duration) {
// FakeCommandRunner records invocations and returns predefined outputs/errors.
type FakeCommandRunner struct {
- Outputs map[string][]byte
- Errors map[string]error
- Calls []string
+ Outputs map[string][]byte
+ Errors map[string]error
+ Calls []string
+ Contexts []context.Context
}
func (f *FakeCommandRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
key := commandKey(name, args)
f.Calls = append(f.Calls, key)
+ f.Contexts = append(f.Contexts, ctx)
var out []byte
if f.Outputs != nil {
out = f.Outputs[key]
@@ -237,6 +263,16 @@ func commandKey(name string, args []string) string {
return fmt.Sprintf("%s %s", name, strings.Join(args, " "))
}
+func backgroundRollbackCallKey(timeoutSeconds int, scriptPath string) string {
+ return commandKey("sh", []string{
+ "-c",
+ backgroundRollbackCommand,
+ "proxsave-rollback",
+ fmt.Sprintf("%d", timeoutSeconds),
+ scriptPath,
+ })
+}
+
// FakePrompter simulates user choices.
type FakePrompter struct {
Mode RestoreMode
diff --git a/internal/orchestrator/guards_cleanup_test.go b/internal/orchestrator/guards_cleanup_test.go
index 7e4e9684..c8b0eb8b 100644
--- a/internal/orchestrator/guards_cleanup_test.go
+++ b/internal/orchestrator/guards_cleanup_test.go
@@ -11,8 +11,6 @@ import (
)
func TestGuardMountpointsFromMountinfo_VisibleAndHidden(t *testing.T) {
- t.Parallel()
-
mountinfo := strings.Join([]string{
"10 1 0:1 " + mountGuardBaseDir + "/g1 /mnt/visible rw - ext4 /dev/sda1 rw",
"20 1 0:1 " + mountGuardBaseDir + "/g2 /mnt/hidden rw - ext4 /dev/sda1 rw",
@@ -32,8 +30,6 @@ func TestGuardMountpointsFromMountinfo_VisibleAndHidden(t *testing.T) {
}
func TestGuardMountpointsFromMountinfo_UnescapesMountpoint(t *testing.T) {
- t.Parallel()
-
mountinfo := "10 1 0:1 " + mountGuardBaseDir + "/g1 /mnt/with\\040space rw - ext4 /dev/sda1 rw\n"
visible, hidden, mounts := guardMountpointsFromMountinfo(mountinfo)
if mounts != 1 {
diff --git a/internal/orchestrator/mount_guard_more_test.go b/internal/orchestrator/mount_guard_more_test.go
index 41109084..67719ac7 100644
--- a/internal/orchestrator/mount_guard_more_test.go
+++ b/internal/orchestrator/mount_guard_more_test.go
@@ -14,8 +14,6 @@ import (
)
func TestGuardDirForTarget(t *testing.T) {
- t.Parallel()
-
target := "/mnt/datastore"
sum := sha256.Sum256([]byte(target))
id := fmt.Sprintf("%x", sum[:8])
@@ -34,8 +32,6 @@ func TestGuardDirForTarget(t *testing.T) {
}
func TestIsMountedFromMountinfo(t *testing.T) {
- t.Parallel()
-
mountinfo := strings.Join([]string{
"36 25 0:32 / / rw,relatime - ext4 /dev/sda1 rw",
`37 36 0:33 / /mnt/pbs\040datastore rw,relatime - ext4 /dev/sdb1 rw`,
@@ -98,8 +94,6 @@ func TestFstabMountpointsSet_Error(t *testing.T) {
}
func TestSplitPathAndMountRootWithPrefix(t *testing.T) {
- t.Parallel()
-
if got := splitPath("a//b/ /c/"); strings.Join(got, ",") != "a,b,c" {
t.Fatalf("splitPath unexpected: %#v", got)
}
@@ -112,8 +106,6 @@ func TestSplitPathAndMountRootWithPrefix(t *testing.T) {
}
func TestSortByLengthDesc(t *testing.T) {
- t.Parallel()
-
items := []string{"a", "abc", "ab"}
sortByLengthDesc(items)
if len(items) != 3 {
@@ -125,8 +117,6 @@ func TestSortByLengthDesc(t *testing.T) {
}
func TestFirstFstabMountpointMatch(t *testing.T) {
- t.Parallel()
-
mountpoints := []string{"/mnt/storage/pbs", "/mnt/storage", "/"}
if got := firstFstabMountpointMatch("/mnt/storage/pbs/ds1/data", mountpoints); got != "/mnt/storage/pbs" {
t.Fatalf("firstFstabMountpointMatch got %q want %q", got, "/mnt/storage/pbs")
diff --git a/internal/orchestrator/network_apply.go b/internal/orchestrator/network_apply.go
index faa76b18..ca5a6f7d 100644
--- a/internal/orchestrator/network_apply.go
+++ b/internal/orchestrator/network_apply.go
@@ -295,8 +295,8 @@ func armNetworkRollback(ctx context.Context, logger *logging.Logger, backupPath
if handle.unitName == "" {
logging.DebugStep(logger, "arm network rollback", "Arm timer via background sleep (%ds)", timeoutSeconds)
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", timeoutSeconds, handle.scriptPath)
- if output, err := restoreCmd.Run(ctx, "sh", "-c", cmd); err != nil {
+ output, err := runBackgroundRollbackTimer(ctx, timeoutSeconds, handle.scriptPath)
+ if err != nil {
logger.Debug("Background rollback output: %s", strings.TrimSpace(string(output)))
return nil, fmt.Errorf("failed to arm rollback timer: %w", err)
}
@@ -794,6 +794,15 @@ func shellQuote(value string) string {
return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'"
}
+const backgroundRollbackCommand = `nohup sh -c 'sleep "$1"; /bin/sh "$2"' proxsave-rollback-worker "$1" "$2" >/dev/null 2>&1 &`
+
+func runBackgroundRollbackTimer(ctx context.Context, timeoutSeconds int, scriptPath string) ([]byte, error) {
+ if timeoutSeconds < 1 {
+ timeoutSeconds = 1
+ }
+ return restoreCmd.Run(ctx, "sh", "-c", backgroundRollbackCommand, "proxsave-rollback", fmt.Sprintf("%d", timeoutSeconds), scriptPath)
+}
+
func commandAvailable(name string) bool {
_, err := exec.LookPath(name)
return err == nil
diff --git a/internal/orchestrator/network_apply_additional_test.go b/internal/orchestrator/network_apply_additional_test.go
index b3e27803..2d939b5f 100644
--- a/internal/orchestrator/network_apply_additional_test.go
+++ b/internal/orchestrator/network_apply_additional_test.go
@@ -606,11 +606,12 @@ func TestArmNetworkRollback_SystemdRunFailureFallsBackToNohup(t *testing.T) {
foundSystemdRun := false
foundFallback := false
+ wantFallback := backgroundRollbackCallKey(30, handle.scriptPath)
for _, call := range fakeCmd.CallsList() {
if strings.HasPrefix(call, "systemd-run ") {
foundSystemdRun = true
}
- if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") {
+ if call == wantFallback {
foundFallback = true
}
}
@@ -653,8 +654,9 @@ func TestArmNetworkRollback_WithoutSystemdRunUsesNohup(t *testing.T) {
}
foundFallback := false
+ wantFallback := backgroundRollbackCallKey(1, handle.scriptPath)
for _, call := range fakeCmd.CallsList() {
- if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") {
+ if call == wantFallback {
foundFallback = true
}
}
@@ -693,8 +695,9 @@ func TestArmNetworkRollback_SubSecondTimeoutArmsAtLeastOneSecond(t *testing.T) {
}
foundSleep1 := false
+ wantFallback := backgroundRollbackCallKey(1, handle.scriptPath)
for _, call := range fakeCmd.CallsList() {
- if strings.Contains(call, "sleep 1;") {
+ if call == wantFallback {
foundSleep1 = true
}
}
@@ -723,7 +726,7 @@ func TestArmNetworkRollback_FallbackCommandFailureReturnsError(t *testing.T) {
restoreCmd = &FakeCommandRunner{
Errors: map[string]error{
- "sh -c nohup sh -c 'sleep 1; /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh' >/dev/null 2>&1 &": errors.New("boom"),
+ backgroundRollbackCallKey(1, "/tmp/proxsave/network_rollback_20260201_123456.sh"): errors.New("boom"),
},
}
@@ -733,6 +736,28 @@ func TestArmNetworkRollback_FallbackCommandFailureReturnsError(t *testing.T) {
}
}
+func TestRunBackgroundRollbackTimer_UsesPositionalArgsForScriptPath(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ scriptPath := "/tmp/proxsave dir/rollback's ; touch /tmp/proxsave-injected.sh"
+ if _, err := runBackgroundRollbackTimer(context.Background(), 2, scriptPath); err != nil {
+ t.Fatalf("runBackgroundRollbackTimer error: %v", err)
+ }
+
+ want := backgroundRollbackCallKey(2, scriptPath)
+ calls := fakeCmd.CallsList()
+ if len(calls) != 1 || calls[0] != want {
+ t.Fatalf("unexpected calls: %#v", calls)
+ }
+ if strings.Contains(backgroundRollbackCommand, scriptPath) {
+ t.Fatalf("rollback script path must not be interpolated into shell command")
+ }
+}
+
func TestDisarmNetworkRollback_RemovesMarkerAndStopsTimer(t *testing.T) {
origFS := restoreFS
origCmd := restoreCmd
diff --git a/internal/orchestrator/nic_mapping.go b/internal/orchestrator/nic_mapping.go
index f2a0372c..b77dc273 100644
--- a/internal/orchestrator/nic_mapping.go
+++ b/internal/orchestrator/nic_mapping.go
@@ -312,7 +312,7 @@ func loadBackupNetworkInventoryFromArchive(ctx context.Context, archivePath stri
return &inv, used, nil
}
-func readArchiveEntry(ctx context.Context, archivePath string, candidates []string, maxBytes int64) ([]byte, string, error) {
+func readArchiveEntry(ctx context.Context, archivePath string, candidates []string, maxBytes int64) (data []byte, used string, err error) {
file, err := restoreFS.Open(archivePath)
if err != nil {
return nil, "", err
@@ -323,9 +323,7 @@ func readArchiveEntry(ctx context.Context, archivePath string, candidates []stri
if err != nil {
return nil, "", err
}
- if closer, ok := reader.(io.Closer); ok {
- defer closer.Close()
- }
+ defer closeDecompressionReader(reader, &err, "close decompression reader")
tr := tar.NewReader(reader)
diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go
index 2107ddef..de9a9548 100644
--- a/internal/orchestrator/orchestrator.go
+++ b/internal/orchestrator/orchestrator.go
@@ -4,7 +4,6 @@ import (
"archive/tar"
"context"
"encoding/json"
- "errors"
"fmt"
"io"
"os"
@@ -504,528 +503,54 @@ func (o *Orchestrator) ensureTempRegistry() *TempDirRegistry {
// RunGoBackup performs the entire backup using Go components (collector + archiver)
func (o *Orchestrator) RunGoBackup(ctx context.Context, envInfo *environment.EnvironmentInfo, hostname string) (stats *BackupStats, err error) {
- if envInfo == nil {
- envInfo = o.envInfo
- } else {
- o.SetEnvironmentInfo(envInfo)
- }
- pType := types.ProxmoxUnknown
- if envInfo != nil {
- pType = envInfo.Type
- }
- done := logging.DebugStart(o.logger, "backup run", "type=%s hostname=%s", pType, hostname)
+ run := o.newBackupRunContext(ctx, envInfo, hostname)
+ done := logging.DebugStart(o.logger, "backup run", "type=%s hostname=%s", run.proxmoxType, hostname)
defer func() { done(err) }()
- o.logger.Info("Starting Go-based backup orchestration for %s", pType)
-
- // Unified cleanup of previous execution artifacts
- registry := o.cleanupPreviousExecutionArtifacts()
- fs := o.filesystem()
+ o.logger.Info("Starting Go-based backup orchestration for %s", run.proxmoxType)
- startTime := o.startTime
- if startTime.IsZero() {
- startTime = o.now()
- o.startTime = startTime
+ workspace := &backupWorkspace{
+ registry: o.cleanupPreviousExecutionArtifacts(),
+ fs: o.filesystem(),
}
- normalizedLevel := normalizeCompressionLevel(o.compressionType, o.compressionLevel)
-
- fmt.Println()
- o.logStep(1, "Initializing backup statistics and temporary workspace")
- stats = InitializeBackupStats(
- hostname,
- envInfo,
- o.version,
- startTime,
- o.cfg,
- o.compressionType,
- o.compressionMode,
- normalizedLevel,
- o.compressionThreads,
- o.backupPath,
- o.serverID,
- o.serverMAC,
- )
- // Get log file path from logger (more reliable than env var)
- if logFile := o.logger.GetLogFilePath(); logFile != "" {
- stats.LogFilePath = logFile
- }
-
- // Propagate version update information (if any) into stats so that
- // downstream notification adapters can include it in their payloads.
- if o.versionUpdateAvailable || o.updateCurrentVersion != "" || o.updateLatestVersion != "" {
- stats.NewVersionAvailable = o.versionUpdateAvailable
- stats.CurrentVersion = o.updateCurrentVersion
- stats.LatestVersion = o.updateLatestVersion
- }
-
- metricsStats := stats
+ stats = o.initBackupRun(run)
defer func() {
- if metricsStats == nil || o.cfg == nil || !o.cfg.MetricsEnabled || o.dryRun {
- return
- }
-
- if metricsStats.EndTime.IsZero() {
- metricsStats.EndTime = o.now()
- }
- if metricsStats.Duration == 0 && !metricsStats.StartTime.IsZero() {
- metricsStats.Duration = metricsStats.EndTime.Sub(metricsStats.StartTime)
- }
-
- if err != nil {
- var backupErr *BackupError
- if errors.As(err, &backupErr) {
- metricsStats.ExitCode = backupErr.Code.Int()
- } else {
- metricsStats.ExitCode = types.ExitGenericError.Int()
- }
- } else if metricsStats.ExitCode == 0 {
- metricsStats.ExitCode = types.ExitSuccess.Int()
- }
-
- if m := metricsStats.toPrometheusMetrics(); m != nil {
- exporter := metrics.NewPrometheusExporter(o.cfg.MetricsPath, o.logger)
- if exportErr := exporter.Export(m); exportErr != nil {
- o.logger.Warning("Failed to export Prometheus metrics: %v", exportErr)
- }
- }
+ o.exportBackupMetrics(run, err)
}()
-
- // Ensure that, in case of failure, we still perform log parsing,
- // derive an exit code and dispatch notifications/log rotation.
defer func() {
- if err == nil || stats == nil {
- return
- }
-
- // Ensure end time and duration are set
- if stats.EndTime.IsZero() {
- stats.EndTime = o.now()
- }
- if stats.Duration == 0 && !stats.StartTime.IsZero() {
- stats.Duration = stats.EndTime.Sub(stats.StartTime)
- }
-
- // Parse log file to populate error/warning counts
- if stats.LogFilePath != "" {
- o.logger.Debug("Parsing log file for error/warning counts after failure: %s", stats.LogFilePath)
- _, errorCount, warningCount := ParseLogCounts(stats.LogFilePath, 0)
- stats.ErrorCount = errorCount
- stats.WarningCount = warningCount
- if errorCount > 0 || warningCount > 0 {
- o.logger.Debug("Found %d errors and %d warnings in log file (failure path)", errorCount, warningCount)
- }
- } else {
- o.logger.Debug("No log file path specified, error/warning counts will be 0 (failure path)")
- }
-
- // Derive exit code from the error when possible
- var backupErr *BackupError
- if errors.As(err, &backupErr) {
- stats.ExitCode = backupErr.Code.Int()
- } else {
- stats.ExitCode = types.ExitBackupError.Int()
- }
-
+ o.finalizeFailedBackupStats(run, err)
}()
- o.logger.Debug("Creating temporary directory for collection output")
- // Create temporary directory for collection (outside backup path)
- // Note: /tmp/proxsave is validated in pre-backup checks (CheckTempDirectory)
- // This MkdirAll is a fallback for cases where pre-checks don't run
- timestampStr := startTime.Format("20060102-150405")
- tempRoot := filepath.Join("/tmp", "proxsave")
- if err := fs.MkdirAll(tempRoot, 0o755); err != nil {
- return nil, fmt.Errorf("Temp directory creation failed - path: %s: %w", tempRoot, err)
- }
- tempDir, err := fs.MkdirTemp(tempRoot, fmt.Sprintf("proxsave-%s-%s-", hostname, timestampStr))
- if err != nil {
- return nil, fmt.Errorf("failed to create temporary directory: %w", err)
- }
- if o.dryRun {
- o.logger.Info("[DRY RUN] Temporary directory would be: %s", tempDir)
- } else {
- o.logger.Debug("Using temporary directory: %s", tempDir)
+ if err := o.prepareBackupWorkspace(run, workspace); err != nil {
+ return stats, err
}
defer func() {
- if registry == nil {
- if cleanupErr := fs.RemoveAll(tempDir); cleanupErr != nil {
- o.logger.Warning("Failed to remove temp directory %s: %v", tempDir, cleanupErr)
- }
- return
- }
- o.logger.Debug("Temporary workspace preserved at %s (will be removed at the next startup)", tempDir)
+ o.cleanupBackupWorkspace(workspace)
}()
-
- // Create marker file for parity with Bash cleanup guarantees
- markerPath := filepath.Join(tempDir, ".proxsave-marker")
- markerContent := fmt.Sprintf(
- "Created by PID %d on %s UTC\n",
- os.Getpid(),
- o.now().UTC().Format("2006-01-02 15:04:05"),
- )
- if err := fs.WriteFile(markerPath, []byte(markerContent), 0600); err != nil {
+ if err := o.markBackupWorkspace(workspace); err != nil {
return stats, fmt.Errorf("failed to create temp marker file: %w", err)
}
+ o.registerBackupWorkspace(workspace)
- if registry != nil {
- if err := registry.Register(tempDir); err != nil {
- o.logger.Debug("Failed to register temp directory %s: %v", tempDir, err)
- }
- }
-
- // Step 1: Collect configuration files
- fmt.Println()
- o.logStep(2, "Collection of configuration files and optimizations")
- o.logger.Info("Collecting configuration files...")
- o.logger.Debug("Collector dry-run=%v excludePatterns=%d", o.dryRun, len(o.excludePatterns))
- collectorConfig := backup.GetDefaultCollectorConfig()
- collectorConfig.ExcludePatterns = append([]string(nil), o.excludePatterns...)
- if o.cfg != nil {
- applyCollectorOverrides(collectorConfig, o.cfg)
- if len(o.cfg.BackupBlacklist) > 0 {
- collectorConfig.ExcludePatterns = append(collectorConfig.ExcludePatterns, o.cfg.BackupBlacklist...)
- }
- }
-
- if err := collectorConfig.Validate(); err != nil {
- return stats, &BackupError{
- Phase: "config",
- Err: err,
- Code: types.ExitConfigError,
- }
- }
-
- collector := backup.NewCollectorWithDeps(o.logger, collectorConfig, tempDir, pType, o.dryRun, o.collectorDeps())
-
- o.logger.Debug("Starting collector run (type=%s)", pType)
- if err := collector.CollectAll(ctx); err != nil {
- // Return collection-specific error
- return stats, &BackupError{
- Phase: "collection",
- Err: err,
- Code: types.ExitCollectionError,
- }
- }
-
- // Get collection statistics
- collStats := collector.GetStats()
- stats.FilesCollected = int(collStats.FilesProcessed)
- stats.FilesFailed = int(collStats.FilesFailed)
- stats.FilesNotFound = int(collStats.FilesNotFound)
- stats.DirsCreated = int(collStats.DirsCreated)
- stats.BytesCollected = collStats.BytesCollected
- stats.FilesIncluded = int(collStats.FilesProcessed)
- stats.FilesMissing = int(collStats.FilesNotFound)
- stats.UncompressedSize = collStats.BytesCollected
- if stats.ProxmoxType.SupportsPVE() {
- if collector.IsClusteredPVE() {
- stats.ClusterMode = "cluster"
- } else {
- stats.ClusterMode = "standalone"
- }
- }
-
- if err := o.writeBackupMetadata(tempDir, stats); err != nil {
- o.logger.Debug("Failed to write backup metadata: %v", err)
- }
-
- // Write backup manifest with file status details
- if err := collector.WriteManifest(hostname); err != nil {
- o.logger.Debug("Failed to write backup manifest: %v", err)
- }
-
- o.logger.Info("Collection completed: %d files (%s), %d failed, %d dirs created",
- collStats.FilesProcessed,
- backup.FormatBytes(collStats.BytesCollected),
- collStats.FilesFailed,
- collStats.DirsCreated)
-
- // Additional disk space check using estimated size and safety factor
- if o.checker != nil && stats.BytesCollected > 0 {
- o.logger.Debug("Running disk-space validation for estimated data size")
- estimatedSizeGB := float64(stats.BytesCollected) / (1024.0 * 1024.0 * 1024.0)
- // Ensure we always reserve at least a small amount
- if estimatedSizeGB < 0.001 {
- estimatedSizeGB = 0.001
- }
- result := o.checker.CheckDiskSpaceForEstimate(estimatedSizeGB)
- if result.Passed {
- o.logger.Debug("Disk check passed: %s", result.Message)
- } else {
- errMsg := result.Message
- diskErr := result.Error
- if errMsg == "" && diskErr != nil {
- errMsg = diskErr.Error()
- }
- if errMsg == "" {
- errMsg = "insufficient disk space"
- }
- if diskErr == nil {
- diskErr = errors.New(errMsg)
- }
- return stats, &BackupError{
- Phase: "disk",
- Err: fmt.Errorf("disk space validation failed: %w", diskErr),
- Code: types.ExitDiskSpaceError,
- }
- }
- }
-
- if o.optimizationCfg.Enabled() {
- fmt.Println()
- o.logger.Step("Backup optimizations on collected data")
- if err := backup.ApplyOptimizations(ctx, o.logger, tempDir, o.optimizationCfg); err != nil {
- o.logger.Warning("Backup optimizations completed with warnings: %v", err)
- }
- } else {
- o.logger.Debug("Skipping optimization step (all features disabled)")
+ if err := o.collectBackupData(run, workspace); err != nil {
+ return stats, err
}
-
- // Step 2: Create archive
- fmt.Println()
- o.logStep(3, "Creation of compressed archive")
- o.logger.Info("Creating compressed archive...")
- o.logger.Debug("Archiver configuration: type=%s level=%d mode=%s threads=%d",
- o.compressionType, normalizedLevel, o.compressionMode, o.compressionThreads)
-
- // Generate archive filename
- archiveBasename := fmt.Sprintf("%s-backup-%s", hostname, timestampStr)
-
- ageRecipients, err := o.prepareAgeRecipients(ctx)
+ artifacts, err := o.createBackupArchive(run, workspace)
if err != nil {
- return stats, &BackupError{
- Phase: "config",
- Err: err,
- Code: types.ExitConfigError,
- }
+ return stats, err
}
-
- archiverConfig := BuildArchiverConfig(
- o.compressionType,
- normalizedLevel,
- o.compressionThreads,
- o.compressionMode,
- o.dryRun,
- o.cfg != nil && o.cfg.EncryptArchive,
- ageRecipients,
- collectorConfig.ExcludePatterns,
- )
-
- if err := archiverConfig.Validate(); err != nil {
- return stats, &BackupError{
- Phase: "config",
- Err: err,
- Code: types.ExitConfigError,
- }
+ if err := o.verifyAndWriteBackupArtifacts(run, workspace, artifacts); err != nil {
+ return stats, err
}
-
- archiver := backup.NewArchiver(o.logger, archiverConfig)
- effectiveCompression := archiver.ResolveCompression()
- stats.Compression = effectiveCompression
- stats.CompressionLevel = archiver.CompressionLevel()
- stats.CompressionMode = archiver.CompressionMode()
- stats.CompressionThreads = archiver.CompressionThreads()
- archiveExt := archiver.GetArchiveExtension()
- archivePath := filepath.Join(o.backupPath, archiveBasename+archiveExt)
- if stats.RequestedCompression != stats.Compression {
- o.logger.Info("Using %s compression (requested %s)", stats.Compression, stats.RequestedCompression)
- }
-
- if err := archiver.CreateArchive(ctx, tempDir, archivePath); err != nil {
- phase := "archive"
- code := types.ExitArchiveError
- var compressionErr *backup.CompressionError
- if errors.As(err, &compressionErr) {
- phase = "compression"
- code = types.ExitCompressionError
- }
-
- return stats, &BackupError{
- Phase: phase,
- Err: err,
- Code: code,
- }
+ if err := o.bundleBackupArtifacts(run, workspace, artifacts); err != nil {
+ return stats, err
}
-
- stats.ArchivePath = archivePath
- checksumPath := archivePath + ".sha256"
-
- // Get archive size
- if !o.dryRun {
- fmt.Println()
- o.logStep(4, "Verification of archive and metadata generation")
- if size, err := archiver.GetArchiveSize(archivePath); err == nil {
- stats.ArchiveSize = size
- stats.CompressedSize = size
- stats.updateCompressionMetrics()
- o.logger.Debug("Archive created: %s (%s)", archivePath, backup.FormatBytes(size))
- } else {
- o.logger.Warning("Failed to get archive size: %v", err)
- }
-
- // Verify archive (skipped internally when encryption is enabled)
- if err := archiver.VerifyArchive(ctx, archivePath); err != nil {
- // Return verification-specific error
- return stats, &BackupError{
- Phase: "verification",
- Err: err,
- Code: types.ExitVerificationError,
- }
- }
-
- // Generate checksum and manifest for the archive
- checksum, err := backup.GenerateChecksum(ctx, o.logger, archivePath)
- if err != nil {
- return stats, &BackupError{
- Phase: "verification",
- Err: fmt.Errorf("checksum generation failed: %w", err),
- Code: types.ExitVerificationError,
- }
- }
- stats.Checksum = checksum
-
- checksumContent := fmt.Sprintf("%s %s\n", checksum, filepath.Base(archivePath))
- if err := fs.WriteFile(checksumPath, []byte(checksumContent), 0640); err != nil {
- o.logger.Warning("Failed to write checksum file %s: %v", checksumPath, err)
- } else {
- o.logger.Debug("Checksum file written to %s", checksumPath)
- }
-
- manifestPath := archivePath + ".manifest.json"
- manifestCreatedAt := stats.Timestamp
- encryptionMode := "none"
- if o.cfg != nil && o.cfg.EncryptArchive {
- encryptionMode = "age"
- }
- targets := append([]string(nil), stats.ProxmoxTargets...)
- manifest := &backup.Manifest{
- ArchivePath: archivePath,
- ArchiveSize: stats.ArchiveSize,
- SHA256: checksum,
- CreatedAt: manifestCreatedAt,
- CompressionType: string(stats.Compression),
- CompressionLevel: stats.CompressionLevel,
- CompressionMode: stats.CompressionMode,
- ProxmoxType: string(stats.ProxmoxType),
- ProxmoxTargets: targets,
- ProxmoxVersion: stats.ProxmoxVersion,
- PVEVersion: stats.PVEVersion,
- PBSVersion: stats.PBSVersion,
- Hostname: stats.Hostname,
- ScriptVersion: stats.ScriptVersion,
- EncryptionMode: encryptionMode,
- ClusterMode: stats.ClusterMode,
- }
-
- if err := backup.CreateManifest(ctx, o.logger, manifest, manifestPath); err != nil {
- return stats, &BackupError{
- Phase: "verification",
- Err: fmt.Errorf("manifest creation failed: %w", err),
- Code: types.ExitVerificationError,
- }
- }
- stats.ManifestPath = manifestPath
-
- // Maintain Bash-compatible metadata filename for downstream tooling
- metadataAlias := archivePath + ".metadata"
- if err := copyFile(fs, manifestPath, metadataAlias); err != nil {
- o.logger.Warning("Failed to write legacy metadata file %s: %v", metadataAlias, err)
- } else {
- o.logger.Debug("Legacy metadata file written to %s", metadataAlias)
- }
-
- // Create bundle (if requested) before dispatching to other storage targets
- bundleEnabled := o.cfg != nil && o.cfg.BundleAssociatedFiles
- if bundleEnabled {
- fmt.Println()
- o.logStep(5, "Bundling of archive, checksum and metadata")
- o.logger.Debug("Bundling enabled: creating bundle from %s", filepath.Base(archivePath))
- bundlePath, err := o.createBundle(ctx, archivePath)
- if err != nil {
- return stats, &BackupError{
- Phase: "archive",
- Err: fmt.Errorf("bundle creation failed: %w", err),
- Code: types.ExitArchiveError,
- }
- }
-
- if err := o.removeAssociatedFiles(archivePath); err != nil {
- o.logger.Warning("Failed to remove raw files after bundling: %v", err)
- } else {
- o.logger.Debug("Removed raw tar/checksum/metadata after bundling")
- }
-
- if info, err := fs.Stat(bundlePath); err == nil {
- stats.ArchiveSize = info.Size()
- stats.CompressedSize = info.Size()
- stats.updateCompressionMetrics()
- }
- stats.ArchivePath = bundlePath
- stats.ManifestPath = ""
- stats.BundleCreated = true
- archivePath = bundlePath
- o.logger.Debug("Bundle ready: %s", filepath.Base(bundlePath))
- } else {
- fmt.Println()
- o.logger.Skip("Bundling disabled")
- }
-
- stats.EndTime = o.now()
-
- o.logger.Info("✓ Archive created and verified")
- } else {
- fmt.Println()
- o.logStep(4, "Verification skipped (dry run mode)")
- o.logger.Info("[DRY RUN] Would create archive: %s", archivePath)
- stats.EndTime = o.now()
- }
-
- stats.Duration = stats.EndTime.Sub(stats.StartTime)
-
- // Parse log file to populate error/warning counts before dispatch
- if stats.LogFilePath != "" {
- o.logger.Debug("Parsing log file for error/warning counts: %s", stats.LogFilePath)
- _, errorCount, warningCount := ParseLogCounts(stats.LogFilePath, 0)
- stats.ErrorCount = errorCount
- stats.WarningCount = warningCount
- if errorCount > 0 || warningCount > 0 {
- o.logger.Debug("Found %d errors and %d warnings in log file", errorCount, warningCount)
- }
- } else {
- o.logger.Debug("No log file path specified, error/warning counts will be 0")
- }
-
- // Determine aggregated exit code (similar to legacy Bash logic)
- switch {
- case stats.ErrorCount > 0:
- stats.ExitCode = types.ExitBackupError.Int()
- case stats.WarningCount > 0:
- stats.ExitCode = types.ExitGenericError.Int()
- default:
- stats.ExitCode = types.ExitSuccess.Int()
- }
- o.logger.Debug("Aggregated exit code based on log analysis: %d", stats.ExitCode)
-
- if len(o.storageTargets) == 0 {
- fmt.Println()
- o.logStep(6, "No storage targets registered - skipping")
- } else if o.dryRun {
- fmt.Println()
- o.logStep(6, "Storage dispatch skipped (dry run mode)")
- } else {
- fmt.Println()
- o.logStep(6, "Dispatching archive to %d storage target(s)", len(o.storageTargets))
- o.logGlobalRetentionPolicy()
- }
-
- if !o.dryRun {
- o.logger.Debug("Dispatching archive to %d storage targets", len(o.storageTargets))
- if err := o.dispatchPostBackup(ctx, stats); err != nil {
- return stats, err
- }
+ o.finalizeBackupStats(run)
+ if err := o.dispatchBackupArtifacts(run); err != nil {
+ return stats, err
}
fmt.Println()
- o.logger.Debug("Go backup completed in %s", backup.FormatDuration(stats.Duration))
+ o.logger.Debug("Go backup completed in %s", backup.FormatDuration(run.stats.Duration))
return stats, nil
}
diff --git a/internal/orchestrator/pbs_mount_guard_test.go b/internal/orchestrator/pbs_mount_guard_test.go
index a9efbc17..75a4d385 100644
--- a/internal/orchestrator/pbs_mount_guard_test.go
+++ b/internal/orchestrator/pbs_mount_guard_test.go
@@ -3,8 +3,6 @@ package orchestrator
import "testing"
func TestPBSMountGuardRootForDatastorePath(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
in string
@@ -25,7 +23,6 @@ func TestPBSMountGuardRootForDatastorePath(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
if got := pbsMountGuardRootForDatastorePath(tt.in); got != tt.want {
t.Fatalf("pbsMountGuardRootForDatastorePath(%q)=%q want %q", tt.in, got, tt.want)
}
diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go
index 5c0d1d23..6d360c4b 100644
--- a/internal/orchestrator/pbs_staged_apply_additional_test.go
+++ b/internal/orchestrator/pbs_staged_apply_additional_test.go
@@ -12,8 +12,6 @@ import (
)
func TestPBSConfigHasHeader_AcceptsAndRejectsExpectedForms(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
content string
@@ -348,8 +346,6 @@ func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) {
}
func TestDetectPBSDatastoreCfgDuplicateKeys_DetectsDuplicateKeys(t *testing.T) {
- t.Parallel()
-
blocks := []pbsDatastoreBlock{{
Name: "DS1",
Lines: []string{
@@ -366,8 +362,6 @@ func TestDetectPBSDatastoreCfgDuplicateKeys_DetectsDuplicateKeys(t *testing.T) {
}
func TestDetectPBSDatastoreCfgDuplicateKeys_AllowsUniqueKeys(t *testing.T) {
- t.Parallel()
-
blocks := []pbsDatastoreBlock{{
Name: "DS1",
Lines: []string{
@@ -382,8 +376,6 @@ func TestDetectPBSDatastoreCfgDuplicateKeys_AllowsUniqueKeys(t *testing.T) {
}
func TestParsePBSDatastoreCfgBlocks_IgnoresGarbageAndHandlesMissingNames(t *testing.T) {
- t.Parallel()
-
content := strings.Join([]string{
"path /should/be/ignored",
"datastore:",
@@ -416,8 +408,6 @@ func TestParsePBSDatastoreCfgBlocks_IgnoresGarbageAndHandlesMissingNames(t *test
}
func TestParsePBSDatastoreCfgBlocks_DropsEmptyNamedBlocks(t *testing.T) {
- t.Parallel()
-
content := strings.Join([]string{
"datastore: :",
" path /mnt/ignored",
@@ -440,8 +430,6 @@ func TestParsePBSDatastoreCfgBlocks_DropsEmptyNamedBlocks(t *testing.T) {
}
func TestShouldApplyPBSDatastoreBlock_CoversCommonBranches(t *testing.T) {
- t.Parallel()
-
if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: "/"}, newTestLogger()); ok || !strings.Contains(reason, "invalid") {
t.Fatalf("expected invalid path rejection, got ok=%v reason=%q", ok, reason)
}
@@ -464,8 +452,6 @@ func TestShouldApplyPBSDatastoreBlock_CoversCommonBranches(t *testing.T) {
}
func TestWriteDeferredPBSDatastoreCfg_EmptyInputIsNoop(t *testing.T) {
- t.Parallel()
-
if path, err := writeDeferredPBSDatastoreCfg(nil); err != nil {
t.Fatalf("err=%v", err)
} else if path != "" {
diff --git a/internal/orchestrator/pbs_staged_apply_test.go b/internal/orchestrator/pbs_staged_apply_test.go
index 0ee7ab72..d881eff2 100644
--- a/internal/orchestrator/pbs_staged_apply_test.go
+++ b/internal/orchestrator/pbs_staged_apply_test.go
@@ -63,8 +63,6 @@ func TestApplyPBSRemoteCfgFromStage_RemovesWhenEmpty(t *testing.T) {
}
func TestShouldApplyPBSDatastoreBlock_AllowsMountLikePathsOnRootFS(t *testing.T) {
- t.Parallel()
-
dir, err := os.MkdirTemp("/mnt", "proxsave-test-ds-")
if err != nil {
t.Skipf("cannot create temp dir under /mnt: %v", err)
diff --git a/internal/orchestrator/pve_staged_apply_test.go b/internal/orchestrator/pve_staged_apply_test.go
index a5561beb..fcc80117 100644
--- a/internal/orchestrator/pve_staged_apply_test.go
+++ b/internal/orchestrator/pve_staged_apply_test.go
@@ -7,8 +7,6 @@ import (
)
func TestPVEStorageMountGuardItems_BuildsExpectedTargets(t *testing.T) {
- t.Parallel()
-
candidates := []pveStorageMountGuardCandidate{
{StorageID: "Data1", StorageType: "dir", Path: "/mnt/datastore/Data1"},
{StorageID: "Synology-Archive", StorageType: "dir", Path: "/mnt/Synology_NFS/PBS_Backup"},
@@ -40,8 +38,6 @@ func TestPVEStorageMountGuardItems_BuildsExpectedTargets(t *testing.T) {
}
func TestApplyPVEBackupJobsFromStage_CreatesJobsViaPvesh(t *testing.T) {
- t.Parallel()
-
origFS := restoreFS
origCmd := restoreCmd
t.Cleanup(func() {
diff --git a/internal/orchestrator/resolv_conf_repair.go b/internal/orchestrator/resolv_conf_repair.go
index 396e6415..bce82238 100644
--- a/internal/orchestrator/resolv_conf_repair.go
+++ b/internal/orchestrator/resolv_conf_repair.go
@@ -148,7 +148,7 @@ func repairResolvConfWithSystemdResolved(logger *logging.Logger) (bool, error) {
return false, nil
}
-func readTarEntry(ctx context.Context, archivePath, name string, maxBytes int64) ([]byte, error) {
+func readTarEntry(ctx context.Context, archivePath, name string, maxBytes int64) (data []byte, err error) {
file, err := restoreFS.Open(archivePath)
if err != nil {
return nil, fmt.Errorf("open archive: %w", err)
@@ -159,9 +159,7 @@ func readTarEntry(ctx context.Context, archivePath, name string, maxBytes int64)
if err != nil {
return nil, fmt.Errorf("create decompression reader: %w", err)
}
- if closer, ok := reader.(io.Closer); ok {
- defer closer.Close()
- }
+ defer closeDecompressionReader(reader, &err, "close decompression reader")
wantA := strings.TrimPrefix(strings.TrimSpace(name), "./")
wantB := "./" + wantA
diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go
index 03b79316..f2f779fd 100644
--- a/internal/orchestrator/restore.go
+++ b/internal/orchestrator/restore.go
@@ -1,44 +1,21 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
package orchestrator
import (
- "archive/tar"
"bufio"
- "compress/gzip"
"context"
"errors"
"fmt"
- "io"
"os"
- "os/exec"
- "path"
- "path/filepath"
- "sort"
- "strings"
- "sync/atomic"
- "syscall"
"time"
"github.com/tis24dev/proxsave/internal/config"
- "github.com/tis24dev/proxsave/internal/input"
"github.com/tis24dev/proxsave/internal/logging"
)
+// ErrRestoreAborted is returned when a restore workflow is intentionally aborted by the user.
var ErrRestoreAborted = errors.New("restore workflow aborted by user")
-var (
- serviceStopTimeout = 45 * time.Second
- serviceStopNoBlockTimeout = 15 * time.Second
- serviceStartTimeout = 30 * time.Second
- serviceVerifyTimeout = 30 * time.Second
- serviceStatusCheckTimeout = 5 * time.Second
- servicePollInterval = 500 * time.Millisecond
- serviceRetryDelay = 500 * time.Millisecond
- restoreLogSequence uint64
- restoreGlob = filepath.Glob
-)
-
-const restoreTempPattern = ".proxsave-tmp-*"
-
// RestoreAbortInfo contains information about an aborted restore with network rollback.
type RestoreAbortInfo struct {
NetworkRollbackArmed bool
@@ -61,6 +38,7 @@ func ClearRestoreAbortInfo() {
lastRestoreAbortInfo = nil
}
+// RunRestoreWorkflow runs the CLI restore workflow using stdin prompts and the provided configuration.
func RunRestoreWorkflow(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string) (err error) {
if cfg == nil {
return fmt.Errorf("configuration not available")
@@ -75,1715 +53,3 @@ func RunRestoreWorkflow(ctx context.Context, cfg *config.Config, logger *logging
ui := newCLIWorkflowUI(bufio.NewReader(os.Stdin), logger)
return runRestoreWorkflowWithUI(ctx, cfg, logger, version, ui)
}
-
-// checkZFSPoolsAfterRestore checks if ZFS pools need to be imported after restore
-func checkZFSPoolsAfterRestore(logger *logging.Logger) error {
- if _, err := restoreCmd.Run(context.Background(), "which", "zpool"); err != nil {
- // zpool utility not available -> no ZFS tooling installed
- return nil
- }
-
- logger.Info("Checking ZFS pool status...")
-
- configuredPools := detectConfiguredZFSPools()
- importablePools, importOutput, importErr := detectImportableZFSPools()
-
- if len(configuredPools) > 0 {
- logger.Warning("Found %d ZFS pool(s) configured for automatic import:", len(configuredPools))
- for _, pool := range configuredPools {
- logger.Warning(" - %s", pool)
- }
- logger.Info("")
- }
-
- if importErr != nil {
- logger.Warning("`zpool import` command returned an error: %v", importErr)
- if strings.TrimSpace(importOutput) != "" {
- logger.Warning("`zpool import` output:\n%s", importOutput)
- }
- } else if len(importablePools) > 0 {
- logger.Warning("`zpool import` reports pools waiting to be imported:")
- for _, pool := range importablePools {
- logger.Warning(" - %s", pool)
- }
- logger.Info("")
- }
-
- if len(importablePools) == 0 {
- logger.Info("`zpool import` did not report pools waiting for import.")
-
- if len(configuredPools) > 0 {
- logger.Info("")
- for _, pool := range configuredPools {
- if _, err := restoreCmd.Run(context.Background(), "zpool", "status", pool); err == nil {
- logger.Info("Pool %s is already imported (no manual action needed)", pool)
- } else {
- logger.Warning("Systemd expects pool %s, but `zpool import` and `zpool status` did not report it. Check disk visibility and pool status.", pool)
- }
- }
- }
- return nil
- }
-
- logger.Info("⚠ IMPORTANT: ZFS pools may need manual import after restore!")
- logger.Info(" Before rebooting, run these commands:")
- logger.Info(" 1. Check available pools: zpool import")
- for _, pool := range importablePools {
- logger.Info(" 2. Import pool manually: zpool import %s", pool)
- }
- logger.Info(" 3. Verify pool status: zpool status")
- logger.Info("")
- logger.Info(" If pools fail to import, check:")
- logger.Info(" - journalctl -u zfs-import@.service oppure import@.service")
- logger.Info(" - zpool import -d /dev/disk/by-id")
- logger.Info("")
-
- return nil
-}
-
-func stopPVEClusterServices(ctx context.Context, logger *logging.Logger) error {
- services := []string{"pve-cluster", "pvedaemon", "pveproxy", "pvestatd"}
- for _, service := range services {
- if err := stopServiceWithRetries(ctx, logger, service); err != nil {
- return fmt.Errorf("failed to stop PVE services (%s): %w", service, err)
- }
- }
- return nil
-}
-
-func startPVEClusterServices(ctx context.Context, logger *logging.Logger) error {
- services := []string{"pve-cluster", "pvedaemon", "pveproxy", "pvestatd"}
- for _, service := range services {
- if err := startServiceWithRetries(ctx, logger, service); err != nil {
- return fmt.Errorf("failed to start PVE services (%s): %w", service, err)
- }
- }
- return nil
-}
-
-func stopPBSServices(ctx context.Context, logger *logging.Logger) error {
- if _, err := restoreCmd.Run(ctx, "which", "systemctl"); err != nil {
- return fmt.Errorf("systemctl not available: %w", err)
- }
- services := []string{"proxmox-backup-proxy", "proxmox-backup"}
- var failures []string
- for _, service := range services {
- if err := stopServiceWithRetries(ctx, logger, service); err != nil {
- failures = append(failures, fmt.Sprintf("%s: %v", service, err))
- }
- }
- if len(failures) > 0 {
- return errors.New(strings.Join(failures, "; "))
- }
- return nil
-}
-
-func startPBSServices(ctx context.Context, logger *logging.Logger) error {
- if _, err := restoreCmd.Run(ctx, "which", "systemctl"); err != nil {
- return fmt.Errorf("systemctl not available: %w", err)
- }
- services := []string{"proxmox-backup", "proxmox-backup-proxy"}
- var failures []string
- for _, service := range services {
- if err := startServiceWithRetries(ctx, logger, service); err != nil {
- failures = append(failures, fmt.Sprintf("%s: %v", service, err))
- }
- }
- if len(failures) > 0 {
- return errors.New(strings.Join(failures, "; "))
- }
- return nil
-}
-
-func unmountEtcPVE(ctx context.Context, logger *logging.Logger) error {
- output, err := restoreCmd.Run(ctx, "umount", "/etc/pve")
- msg := strings.TrimSpace(string(output))
- if err != nil {
- if strings.Contains(msg, "not mounted") {
- logger.Info("Skipping umount /etc/pve (already unmounted)")
- return nil
- }
- if msg != "" {
- return fmt.Errorf("umount /etc/pve failed: %s", msg)
- }
- return fmt.Errorf("umount /etc/pve failed: %w", err)
- }
- if msg != "" {
- logger.Debug("umount /etc/pve output: %s", msg)
- }
- return nil
-}
-
-func runCommandWithTimeout(ctx context.Context, logger *logging.Logger, timeout time.Duration, name string, args ...string) error {
- return execCommand(ctx, logger, timeout, name, args...)
-}
-
-func execCommand(ctx context.Context, logger *logging.Logger, timeout time.Duration, name string, args ...string) error {
- execCtx := ctx
- var cancel context.CancelFunc
- if timeout > 0 {
- execCtx, cancel = context.WithTimeout(ctx, timeout)
- defer cancel()
- }
-
- output, err := restoreCmd.Run(execCtx, name, args...)
- msg := strings.TrimSpace(string(output))
- if err != nil {
- if timeout > 0 && (errors.Is(execCtx.Err(), context.DeadlineExceeded) || errors.Is(err, context.DeadlineExceeded)) {
- return fmt.Errorf("%s %s timed out after %s", name, strings.Join(args, " "), timeout)
- }
- if msg != "" {
- return fmt.Errorf("%s %s failed: %s", name, strings.Join(args, " "), msg)
- }
- return fmt.Errorf("%s %s failed: %w", name, strings.Join(args, " "), err)
- }
- if msg != "" && logger != nil {
- logger.Debug("%s %s: %s", name, strings.Join(args, " "), msg)
- }
- return nil
-}
-
-func stopServiceWithRetries(ctx context.Context, logger *logging.Logger, service string) error {
- attempts := []struct {
- description string
- args []string
- timeout time.Duration
- }{
- {"stop (no-block)", []string{"stop", "--no-block", service}, serviceStopNoBlockTimeout},
- {"stop (blocking)", []string{"stop", service}, serviceStopTimeout},
- {"aggressive stop", []string{"kill", "--signal=SIGTERM", "--kill-who=all", service}, serviceStopTimeout},
- {"force kill", []string{"kill", "--signal=SIGKILL", "--kill-who=all", service}, serviceStopTimeout},
- }
-
- var lastErr error
- for i, attempt := range attempts {
- if i > 0 {
- if err := sleepWithContext(ctx, serviceRetryDelay); err != nil {
- return err
- }
- }
-
- if logger != nil {
- logger.Debug("Attempting %s for %s (%d/%d)", attempt.description, service, i+1, len(attempts))
- }
-
- if err := runCommandWithTimeoutCountdown(ctx, logger, attempt.timeout, service, attempt.description, "systemctl", attempt.args...); err != nil {
- lastErr = err
- continue
- }
- if err := waitForServiceInactive(ctx, logger, service, serviceVerifyTimeout); err != nil {
- lastErr = err
- continue
- }
- resetFailedService(ctx, logger, service)
- return nil
- }
-
- if lastErr == nil {
- lastErr = fmt.Errorf("unable to stop %s", service)
- }
- return lastErr
-}
-
-func startServiceWithRetries(ctx context.Context, logger *logging.Logger, service string) error {
- attempts := []struct {
- description string
- args []string
- }{
- {"start", []string{"start", service}},
- {"retry start", []string{"start", service}},
- {"aggressive restart", []string{"restart", service}},
- }
-
- var lastErr error
- for i, attempt := range attempts {
- if i > 0 {
- if err := sleepWithContext(ctx, serviceRetryDelay); err != nil {
- return err
- }
- }
-
- if logger != nil {
- logger.Debug("Attempting %s for %s (%d/%d)", attempt.description, service, i+1, len(attempts))
- }
-
- if err := runCommandWithTimeout(ctx, logger, serviceStartTimeout, "systemctl", attempt.args...); err != nil {
- lastErr = err
- continue
- }
- return nil
- }
-
- if lastErr == nil {
- lastErr = fmt.Errorf("unable to start %s", service)
- }
- return lastErr
-}
-
-func runCommandWithTimeoutCountdown(ctx context.Context, logger *logging.Logger, timeout time.Duration, service, action, name string, args ...string) error {
- if timeout <= 0 {
- return execCommand(ctx, logger, timeout, name, args...)
- }
-
- execCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
-
- type result struct {
- out []byte
- err error
- }
-
- resultCh := make(chan result, 1)
- go func() {
- out, err := restoreCmd.Run(execCtx, name, args...)
- resultCh <- result{out: out, err: err}
- }()
-
- progressEnabled := isTerminal(int(os.Stderr.Fd()))
- deadline := time.Now().Add(timeout)
- ticker := time.NewTicker(1 * time.Second)
- defer ticker.Stop()
-
- writeProgress := func(left time.Duration) {
- if !progressEnabled {
- return
- }
- seconds := int(left.Round(time.Second).Seconds())
- if seconds < 0 {
- seconds = 0
- }
- fmt.Fprintf(os.Stderr, "\rStopping %s: %s (attempt timeout in %ds)...", service, action, seconds)
- }
-
- for {
- select {
- case r := <-resultCh:
- if progressEnabled {
- fmt.Fprint(os.Stderr, "\r")
- fmt.Fprintln(os.Stderr, strings.Repeat(" ", 80))
- fmt.Fprint(os.Stderr, "\r")
- }
- msg := strings.TrimSpace(string(r.out))
- if r.err != nil {
- if errors.Is(execCtx.Err(), context.DeadlineExceeded) || errors.Is(r.err, context.DeadlineExceeded) {
- return fmt.Errorf("%s %s timed out after %s", name, strings.Join(args, " "), timeout)
- }
- if msg != "" {
- return fmt.Errorf("%s %s failed: %s", name, strings.Join(args, " "), msg)
- }
- return fmt.Errorf("%s %s failed: %w", name, strings.Join(args, " "), r.err)
- }
- if msg != "" && logger != nil {
- logger.Debug("%s %s: %s", name, strings.Join(args, " "), msg)
- }
- return nil
- case <-ticker.C:
- writeProgress(time.Until(deadline))
- case <-execCtx.Done():
- writeProgress(0)
- if progressEnabled {
- fmt.Fprintln(os.Stderr)
- }
- select {
- case r := <-resultCh:
- msg := strings.TrimSpace(string(r.out))
- if msg != "" && logger != nil {
- logger.Debug("%s %s: %s", name, strings.Join(args, " "), msg)
- }
- default:
- }
- return fmt.Errorf("%s %s timed out after %s", name, strings.Join(args, " "), timeout)
- }
- }
-}
-
-func waitForServiceInactive(ctx context.Context, logger *logging.Logger, service string, timeout time.Duration) error {
- if timeout <= 0 {
- return nil
- }
- deadline := time.Now().Add(timeout)
- progressEnabled := isTerminal(int(os.Stderr.Fd()))
- ticker := time.NewTicker(1 * time.Second)
- defer ticker.Stop()
- for {
- remaining := time.Until(deadline)
- if remaining <= 0 {
- if progressEnabled {
- fmt.Fprintln(os.Stderr)
- }
- return fmt.Errorf("%s still active after %s", service, timeout)
- }
-
- checkTimeout := minDuration(remaining, serviceStatusCheckTimeout)
- active, err := isServiceActive(ctx, service, checkTimeout)
- if err != nil {
- return err
- }
- if !active {
- if logger != nil {
- logger.Debug("%s stopped successfully", service)
- }
- return nil
- }
-
- wait := minDuration(remaining, servicePollInterval)
- timer := time.NewTimer(wait)
- select {
- case <-ctx.Done():
- if !timer.Stop() {
- <-timer.C
- }
- if progressEnabled {
- fmt.Fprintln(os.Stderr)
- }
- return ctx.Err()
- case <-timer.C:
- }
- select {
- case <-ticker.C:
- if progressEnabled {
- seconds := int(remaining.Round(time.Second).Seconds())
- if seconds < 0 {
- seconds = 0
- }
- fmt.Fprintf(os.Stderr, "\rWaiting for %s to stop (%ds remaining)...", service, seconds)
- }
- default:
- }
- }
-}
-
-func resetFailedService(ctx context.Context, logger *logging.Logger, service string) {
- resetCtx, cancel := context.WithTimeout(ctx, serviceStatusCheckTimeout)
- defer cancel()
-
- if _, err := restoreCmd.Run(resetCtx, "systemctl", "reset-failed", service); err != nil {
- if logger != nil {
- logger.Debug("systemctl reset-failed %s ignored: %v", service, err)
- }
- }
-}
-
-func isServiceActive(ctx context.Context, service string, timeout time.Duration) (bool, error) {
- if timeout <= 0 {
- timeout = serviceStatusCheckTimeout
- }
- checkCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
-
- output, err := restoreCmd.Run(checkCtx, "systemctl", "is-active", service)
- msg := strings.TrimSpace(string(output))
- if err == nil {
- return true, nil
- }
- if errors.Is(checkCtx.Err(), context.DeadlineExceeded) || errors.Is(err, context.DeadlineExceeded) {
- return false, fmt.Errorf("systemctl is-active %s timed out after %s", service, timeout)
- }
- if msg == "" {
- msg = err.Error()
- }
- lower := strings.ToLower(msg)
- if strings.Contains(lower, "deactivating") || strings.Contains(lower, "activating") {
- return true, nil
- }
- if strings.Contains(lower, "inactive") || strings.Contains(lower, "failed") || strings.Contains(lower, "dead") {
- return false, nil
- }
- return false, fmt.Errorf("systemctl is-active %s failed: %s", service, msg)
-}
-
-func minDuration(a, b time.Duration) time.Duration {
- if a < b {
- return a
- }
- return b
-}
-
-func sleepWithContext(ctx context.Context, d time.Duration) error {
- if d <= 0 {
- return nil
- }
- timer := time.NewTimer(d)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-timer.C:
- return nil
- }
-}
-
-func detectConfiguredZFSPools() []string {
- pools := make(map[string]struct{})
-
- directories := []string{
- "/etc/systemd/system/zfs-import.target.wants",
- "/etc/systemd/system/multi-user.target.wants",
- }
-
- for _, dir := range directories {
- entries, err := restoreFS.ReadDir(dir)
- if err != nil {
- continue
- }
-
- for _, entry := range entries {
- if pool := parsePoolNameFromUnit(entry.Name()); pool != "" {
- pools[pool] = struct{}{}
- }
- }
- }
-
- globPatterns := []string{
- "/etc/systemd/system/zfs-import@*.service",
- "/etc/systemd/system/import@*.service",
- }
-
- for _, pattern := range globPatterns {
- matches, err := restoreGlob(pattern)
- if err != nil {
- continue
- }
- for _, match := range matches {
- if pool := parsePoolNameFromUnit(filepath.Base(match)); pool != "" {
- pools[pool] = struct{}{}
- }
- }
- }
-
- var poolNames []string
- for pool := range pools {
- poolNames = append(poolNames, pool)
- }
- sort.Strings(poolNames)
- return poolNames
-}
-
-func parsePoolNameFromUnit(unitName string) string {
- switch {
- case strings.HasPrefix(unitName, "zfs-import@") && strings.HasSuffix(unitName, ".service"):
- pool := strings.TrimPrefix(unitName, "zfs-import@")
- return strings.TrimSuffix(pool, ".service")
- case strings.HasPrefix(unitName, "import@") && strings.HasSuffix(unitName, ".service"):
- pool := strings.TrimPrefix(unitName, "import@")
- return strings.TrimSuffix(pool, ".service")
- default:
- return ""
- }
-}
-
-func detectImportableZFSPools() ([]string, string, error) {
- output, err := restoreCmd.Run(context.Background(), "zpool", "import")
- poolNames := parseZpoolImportOutput(string(output))
- if err != nil {
- return poolNames, string(output), err
- }
- return poolNames, string(output), nil
-}
-
-func parseZpoolImportOutput(output string) []string {
- var pools []string
- scanner := bufio.NewScanner(strings.NewReader(output))
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(strings.ToLower(line), "pool:") {
- pool := strings.TrimSpace(line[len("pool:"):])
- if pool != "" {
- pools = append(pools, pool)
- }
- }
- }
- return pools
-}
-
-func combinePoolNames(a, b []string) []string {
- merged := make(map[string]struct{})
- for _, pool := range a {
- merged[pool] = struct{}{}
- }
- for _, pool := range b {
- merged[pool] = struct{}{}
- }
-
- if len(merged) == 0 {
- return nil
- }
-
- names := make([]string, 0, len(merged))
- for pool := range merged {
- names = append(names, pool)
- }
- sort.Strings(names)
- return names
-}
-
-func shouldRecreateDirectories(systemType SystemType, categories []Category) bool {
- return (systemType.SupportsPVE() && hasCategoryID(categories, "storage_pve")) ||
- (systemType.SupportsPBS() && hasCategoryID(categories, "datastore_pbs"))
-}
-
-func hasCategoryID(categories []Category, id string) bool {
- for _, cat := range categories {
- if cat.ID == id {
- return true
- }
- }
- return false
-}
-
-// shouldStopPBSServices reports whether any selected categories belong to PBS-specific configuration
-// and therefore require stopping PBS services before restore.
-func shouldStopPBSServices(categories []Category) bool {
- for _, cat := range categories {
- if cat.Type == CategoryTypePBS {
- return true
- }
- // Some common categories (e.g. SSL) include PBS paths that require restarting PBS services.
- for _, p := range cat.Paths {
- p = strings.TrimSpace(p)
- if strings.HasPrefix(p, "./etc/proxmox-backup/") || strings.HasPrefix(p, "./var/lib/proxmox-backup/") {
- return true
- }
- }
- }
- return false
-}
-
-func splitExportCategories(categories []Category) (normal []Category, export []Category) {
- for _, cat := range categories {
- if cat.ExportOnly {
- export = append(export, cat)
- continue
- }
- normal = append(normal, cat)
- }
- return normal, export
-}
-
-// redirectClusterCategoryToExport removes pve_cluster from normal categories and adds it to export-only list.
-func redirectClusterCategoryToExport(normal []Category, export []Category) ([]Category, []Category) {
- filtered := make([]Category, 0, len(normal))
- for _, cat := range normal {
- if cat.ID == "pve_cluster" {
- export = append(export, cat)
- continue
- }
- filtered = append(filtered, cat)
- }
- return filtered, export
-}
-
-func exportDestRoot(baseDir string) string {
- base := strings.TrimSpace(baseDir)
- if base == "" {
- base = "/opt/proxsave"
- }
- return filepath.Join(base, fmt.Sprintf("proxmox-config-export-%s", nowRestore().Format("20060102-150405")))
-}
-
-// runFullRestore performs a full restore without selective options (fallback)
-func runFullRestore(ctx context.Context, reader *bufio.Reader, candidate *backupCandidate, prepared *preparedBundle, destRoot string, logger *logging.Logger, dryRun bool) error {
- if err := confirmRestoreAction(ctx, reader, candidate, destRoot); err != nil {
- return err
- }
-
- safeFstabMerge := destRoot == "/" && isRealRestoreFS(restoreFS)
- skipFn := func(name string) bool {
- if !safeFstabMerge {
- return false
- }
- clean := strings.TrimPrefix(strings.TrimSpace(name), "./")
- clean = strings.TrimPrefix(clean, "/")
- return clean == "etc/fstab"
- }
-
- if safeFstabMerge {
- logger.Warning("Full restore safety: /etc/fstab will not be overwritten; Smart Merge will be applied after extraction.")
- }
-
- if err := extractPlainArchive(ctx, prepared.ArchivePath, destRoot, logger, skipFn); err != nil {
- return err
- }
-
- if safeFstabMerge {
- logger.Info("")
- fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-")
- if err != nil {
- logger.Warning("Failed to create temp dir for fstab merge: %v", err)
- } else {
- defer restoreFS.RemoveAll(fsTempDir)
- fsCategory := []Category{{
- ID: "filesystem",
- Name: "Filesystem Configuration",
- Paths: []string{
- "./etc/fstab",
- },
- }}
- if err := extractArchiveNative(ctx, prepared.ArchivePath, fsTempDir, logger, fsCategory, RestoreModeCustom, nil, "", nil); err != nil {
- logger.Warning("Failed to extract filesystem config for merge: %v", err)
- } else {
- // Best-effort: extract ProxSave inventory files used for stable fstab device remapping.
- invCategory := []Category{{
- ID: "fstab_inventory",
- Name: "Fstab inventory (device mapping)",
- Paths: []string{
- "./var/lib/proxsave-info/commands/system/blkid.txt",
- "./var/lib/proxsave-info/commands/system/lsblk_json.json",
- "./var/lib/proxsave-info/commands/system/lsblk.txt",
- "./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json",
- },
- }}
- if err := extractArchiveNative(ctx, prepared.ArchivePath, fsTempDir, logger, invCategory, RestoreModeCustom, nil, "", nil); err != nil {
- logger.Debug("Failed to extract fstab inventory data (continuing): %v", err)
- }
-
- currentFstab := filepath.Join(destRoot, "etc", "fstab")
- backupFstab := filepath.Join(fsTempDir, "etc", "fstab")
- if err := SmartMergeFstab(ctx, logger, reader, currentFstab, backupFstab, dryRun); err != nil {
- if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) {
- logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.")
- return err
- }
- logger.Warning("Smart Fstab Merge failed: %v", err)
- }
- }
- }
- }
-
- logger.Info("Restore completed successfully.")
- return nil
-}
-
-func confirmRestoreAction(ctx context.Context, reader *bufio.Reader, cand *backupCandidate, dest string) error {
- manifest := cand.Manifest
- fmt.Println()
- fmt.Printf("Selected backup: %s (%s)\n", cand.DisplayBase, manifest.CreatedAt.Format("2006-01-02 15:04:05"))
- cleanDest := filepath.Clean(strings.TrimSpace(dest))
- if cleanDest == "" || cleanDest == "." {
- cleanDest = string(os.PathSeparator)
- }
- if cleanDest == string(os.PathSeparator) {
- fmt.Println("Restore destination: / (system root; original paths will be preserved)")
- fmt.Println("WARNING: This operation will overwrite configuration files on this system.")
- } else {
- fmt.Printf("Restore destination: %s (original paths will be preserved under this directory)\n", cleanDest)
- fmt.Printf("WARNING: This operation will overwrite existing files under %s.\n", cleanDest)
- }
- fmt.Println("Type RESTORE to proceed or 0 to cancel.")
-
- for {
- fmt.Print("Confirmation: ")
- line, err := input.ReadLineWithContext(ctx, reader)
- if err != nil {
- return err
- }
- switch strings.TrimSpace(line) {
- case "RESTORE":
- return nil
- case "0":
- return ErrRestoreAborted
- default:
- fmt.Println("Please type RESTORE to confirm or 0 to cancel.")
- }
- }
-}
-
-func extractPlainArchive(ctx context.Context, archivePath, destRoot string, logger *logging.Logger, skipFn func(entryName string) bool) error {
- if err := restoreFS.MkdirAll(destRoot, 0o755); err != nil {
- return fmt.Errorf("create destination directory: %w", err)
- }
-
- // Only enforce root privileges when writing to the real system root.
- if destRoot == "/" && isRealRestoreFS(restoreFS) && os.Geteuid() != 0 {
- return fmt.Errorf("restore to %s requires root privileges", destRoot)
- }
-
- logger.Info("Extracting archive %s into %s", filepath.Base(archivePath), destRoot)
-
- // Use native Go extraction to preserve atime/ctime from PAX headers
- if err := extractArchiveNative(ctx, archivePath, destRoot, logger, nil, RestoreModeFull, nil, "", skipFn); err != nil {
- return fmt.Errorf("archive extraction failed: %w", err)
- }
-
- return nil
-}
-
-// runSafeClusterApply applies selected cluster configs via pvesh without touching config.db.
-// It operates on files extracted to exportRoot (e.g. exportDestRoot).
-func runSafeClusterApply(ctx context.Context, reader *bufio.Reader, exportRoot string, logger *logging.Logger) (err error) {
- if logger == nil {
- logger = logging.GetDefaultLogger()
- }
- ui := newCLIWorkflowUI(reader, logger)
- return runSafeClusterApplyWithUI(ctx, ui, exportRoot, logger, nil)
-}
-
-type vmEntry struct {
- VMID string
- Kind string // qemu | lxc
- Name string
- Path string
-}
-
-func scanVMConfigs(exportRoot, node string) ([]vmEntry, error) {
- var entries []vmEntry
- base := filepath.Join(exportRoot, "etc/pve/nodes", node)
-
- type dirSpec struct {
- kind string
- path string
- }
-
- dirs := []dirSpec{
- {kind: "qemu", path: filepath.Join(base, "qemu-server")},
- {kind: "lxc", path: filepath.Join(base, "lxc")},
- }
-
- for _, spec := range dirs {
- infos, err := restoreFS.ReadDir(spec.path)
- if err != nil {
- continue
- }
- for _, entry := range infos {
- if entry.IsDir() {
- continue
- }
- name := entry.Name()
- if !strings.HasSuffix(name, ".conf") {
- continue
- }
- vmid := strings.TrimSuffix(name, ".conf")
- vmPath := filepath.Join(spec.path, name)
- vmName := readVMName(vmPath)
- entries = append(entries, vmEntry{
- VMID: vmid,
- Kind: spec.kind,
- Name: vmName,
- Path: vmPath,
- })
- }
- }
-
- return entries, nil
-}
-
-func listExportNodeDirs(exportRoot string) ([]string, error) {
- nodesRoot := filepath.Join(exportRoot, "etc/pve/nodes")
- entries, err := restoreFS.ReadDir(nodesRoot)
- if err != nil {
- if errors.Is(err, os.ErrNotExist) || os.IsNotExist(err) {
- return nil, nil
- }
- return nil, err
- }
-
- var nodes []string
- for _, entry := range entries {
- if !entry.IsDir() {
- continue
- }
- name := strings.TrimSpace(entry.Name())
- if name == "" {
- continue
- }
- nodes = append(nodes, name)
- }
- sort.Strings(nodes)
- return nodes, nil
-}
-
-func countVMConfigsForNode(exportRoot, node string) (qemuCount, lxcCount int) {
- base := filepath.Join(exportRoot, "etc/pve/nodes", node)
-
- countInDir := func(dir string) int {
- entries, err := restoreFS.ReadDir(dir)
- if err != nil {
- return 0
- }
- n := 0
- for _, entry := range entries {
- if entry.IsDir() {
- continue
- }
- if strings.HasSuffix(entry.Name(), ".conf") {
- n++
- }
- }
- return n
- }
-
- qemuCount = countInDir(filepath.Join(base, "qemu-server"))
- lxcCount = countInDir(filepath.Join(base, "lxc"))
- return qemuCount, lxcCount
-}
-
-func promptExportNodeSelection(ctx context.Context, reader *bufio.Reader, exportRoot, currentNode string, exportNodes []string) (string, error) {
- for {
- fmt.Println()
- fmt.Printf("WARNING: VM/CT configs in this backup are stored under different node names.\n")
- fmt.Printf("Current node: %s\n", currentNode)
- fmt.Println("Select which exported node to import VM/CT configs from (they will be applied to the current node):")
- for idx, node := range exportNodes {
- qemuCount, lxcCount := countVMConfigsForNode(exportRoot, node)
- fmt.Printf(" [%d] %s (qemu=%d, lxc=%d)\n", idx+1, node, qemuCount, lxcCount)
- }
- fmt.Println(" [0] Skip VM/CT apply")
-
- fmt.Print("Choice: ")
- line, err := input.ReadLineWithContext(ctx, reader)
- if err != nil {
- return "", err
- }
- trimmed := strings.TrimSpace(line)
- if trimmed == "0" {
- return "", nil
- }
- if trimmed == "" {
- continue
- }
- idx, err := parseMenuIndex(trimmed, len(exportNodes))
- if err != nil {
- fmt.Println(err)
- continue
- }
- return exportNodes[idx], nil
- }
-}
-
-func stringSliceContains(items []string, want string) bool {
- for _, item := range items {
- if item == want {
- return true
- }
- }
- return false
-}
-
-func readVMName(confPath string) string {
- data, err := restoreFS.ReadFile(confPath)
- if err != nil {
- return ""
- }
- for _, line := range strings.Split(string(data), "\n") {
- t := strings.TrimSpace(line)
- if strings.HasPrefix(t, "name:") {
- return strings.TrimSpace(strings.TrimPrefix(t, "name:"))
- }
- if strings.HasPrefix(t, "hostname:") {
- return strings.TrimSpace(strings.TrimPrefix(t, "hostname:"))
- }
- }
- return ""
-}
-
-func applyVMConfigs(ctx context.Context, entries []vmEntry, logger *logging.Logger) (applied, failed int) {
- for _, vm := range entries {
- if err := ctx.Err(); err != nil {
- logger.Warning("VM apply aborted: %v", err)
- return applied, failed
- }
- target := fmt.Sprintf("/nodes/%s/%s/%s/config", detectNodeForVM(), vm.Kind, vm.VMID)
- args := []string{"set", target, "--filename", vm.Path}
- if err := runPvesh(ctx, logger, args); err != nil {
- logger.Warning("Failed to apply %s (vmid=%s kind=%s): %v", target, vm.VMID, vm.Kind, err)
- failed++
- } else {
- display := vm.VMID
- if vm.Name != "" {
- display = fmt.Sprintf("%s (%s)", vm.VMID, vm.Name)
- }
- logger.Info("Applied VM/CT config %s", display)
- applied++
- }
- }
- return applied, failed
-}
-
-func detectNodeForVM() string {
- host, _ := os.Hostname()
- host = shortHost(host)
- if host != "" {
- return host
- }
- return "localhost"
-}
-
-type storageBlock struct {
- ID string
- data []string
-}
-
-func applyStorageCfg(ctx context.Context, cfgPath string, logger *logging.Logger) (applied, failed int, err error) {
- blocks, perr := parseStorageBlocks(cfgPath)
- if perr != nil {
- return 0, 0, perr
- }
- if len(blocks) == 0 {
- logger.Info("No storage definitions detected in storage.cfg")
- return 0, 0, nil
- }
-
- for _, blk := range blocks {
- tmp, tmpErr := restoreFS.CreateTemp("", fmt.Sprintf("pve-storage-%s-*.cfg", sanitizeID(blk.ID)))
- if tmpErr != nil {
- failed++
- continue
- }
- tmpName := tmp.Name()
- if _, werr := tmp.WriteString(strings.Join(blk.data, "\n") + "\n"); werr != nil {
- _ = tmp.Close()
- _ = restoreFS.Remove(tmpName)
- failed++
- continue
- }
- _ = tmp.Close()
-
- args := []string{"set", fmt.Sprintf("/cluster/storage/%s", blk.ID), "-conf", tmpName}
- if runErr := runPvesh(ctx, logger, args); runErr != nil {
- logger.Warning("Failed to apply storage %s: %v", blk.ID, runErr)
- failed++
- } else {
- logger.Info("Applied storage definition %s", blk.ID)
- applied++
- }
-
- _ = restoreFS.Remove(tmpName)
-
- if err := ctx.Err(); err != nil {
- return applied, failed, err
- }
- }
-
- return applied, failed, nil
-}
-
-func parseStorageBlocks(cfgPath string) ([]storageBlock, error) {
- data, err := restoreFS.ReadFile(cfgPath)
- if err != nil {
- return nil, err
- }
-
- var blocks []storageBlock
- var current *storageBlock
-
- flush := func() {
- if current != nil {
- blocks = append(blocks, *current)
- current = nil
- }
- }
-
- for _, line := range strings.Split(string(data), "\n") {
- trimmed := strings.TrimSpace(line)
- if trimmed == "" {
- flush()
- continue
- }
-
- // storage.cfg blocks use `: ` (e.g. `dir: local`, `nfs: backup`).
- // Older exports may still use `storage: ` blocks.
- _, name, ok := parseProxmoxNotificationHeader(trimmed)
- if ok {
- flush()
- current = &storageBlock{ID: name, data: []string{line}}
- continue
- }
- if current != nil {
- current.data = append(current.data, line)
- }
- }
- flush()
-
- return blocks, nil
-}
-
-func runPvesh(ctx context.Context, logger *logging.Logger, args []string) error {
- output, err := restoreCmd.Run(ctx, "pvesh", args...)
- if len(output) > 0 {
- logger.Debug("pvesh %v output: %s", args, strings.TrimSpace(string(output)))
- }
- if err != nil {
- return fmt.Errorf("pvesh %v failed: %w", args, err)
- }
- return nil
-}
-
-func shortHost(host string) string {
- if idx := strings.Index(host, "."); idx > 0 {
- return host[:idx]
- }
- return host
-}
-
-func sanitizeID(id string) string {
- var b strings.Builder
- for _, r := range id {
- if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_' {
- b.WriteRune(r)
- } else {
- b.WriteRune('_')
- }
- }
- return b.String()
-}
-
-// promptClusterRestoreMode asks how to handle cluster database restore (safe export vs full recovery).
-func promptClusterRestoreMode(ctx context.Context, reader *bufio.Reader) (int, error) {
- fmt.Println()
- fmt.Println("Cluster backup detected. Choose how to restore the cluster database:")
- fmt.Println(" [1] SAFE: Do NOT write /var/lib/pve-cluster/config.db. Export cluster files only (manual/apply via API).")
- fmt.Println(" [2] RECOVERY: Restore full cluster database (/var/lib/pve-cluster). Use only when cluster is offline/isolated.")
- fmt.Println(" [0] Exit")
-
- for {
- fmt.Print("Choice: ")
- choiceLine, err := input.ReadLineWithContext(ctx, reader)
- if err != nil {
- return 0, err
- }
- switch strings.TrimSpace(choiceLine) {
- case "1":
- return 1, nil
- case "2":
- return 2, nil
- case "0":
- return 0, nil
- default:
- fmt.Println("Please enter 1, 2, or 0.")
- }
- }
-}
-
-// extractSelectiveArchive extracts only files matching selected categories
-func extractSelectiveArchive(ctx context.Context, archivePath, destRoot string, categories []Category, mode RestoreMode, logger *logging.Logger) (logPath string, err error) {
- done := logging.DebugStart(logger, "extract selective archive", "archive=%s dest=%s categories=%d mode=%s", archivePath, destRoot, len(categories), mode)
- defer func() { done(err) }()
- if err := restoreFS.MkdirAll(destRoot, 0o755); err != nil {
- return "", fmt.Errorf("create destination directory: %w", err)
- }
-
- // Only enforce root privileges when writing to the real system root.
- if destRoot == "/" && isRealRestoreFS(restoreFS) && os.Geteuid() != 0 {
- return "", fmt.Errorf("restore to %s requires root privileges", destRoot)
- }
-
- // Create detailed log directory
- logDir := "/tmp/proxsave"
- if err := restoreFS.MkdirAll(logDir, 0755); err != nil {
- logger.Warning("Could not create log directory: %v", err)
- }
-
- // Create detailed log file
- timestamp := nowRestore().Format("20060102_150405")
- logSeq := atomic.AddUint64(&restoreLogSequence, 1)
- logPath = filepath.Join(logDir, fmt.Sprintf("restore_%s_%d.log", timestamp, logSeq))
- logFile, err := restoreFS.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
- if err != nil {
- logger.Warning("Could not create detailed log file: %v", err)
- logFile = nil
- } else {
- defer logFile.Close()
- logger.Info("Detailed restore log: %s", logPath)
- logging.DebugStep(logger, "extract selective archive", "log file=%s", logPath)
- }
-
- logger.Info("Extracting selected categories from archive %s into %s", filepath.Base(archivePath), destRoot)
-
- // Use native Go extraction with category filter
- if err := extractArchiveNative(ctx, archivePath, destRoot, logger, categories, mode, logFile, logPath, nil); err != nil {
- return logPath, err
- }
-
- return logPath, nil
-}
-
-// extractArchiveNative extracts TAR archives natively in Go, preserving all timestamps
-// If categories is nil, all files are extracted. Otherwise, only files matching the categories are extracted.
-func extractArchiveNative(ctx context.Context, archivePath, destRoot string, logger *logging.Logger, categories []Category, mode RestoreMode, logFile *os.File, logFilePath string, skipFn func(entryName string) bool) error {
- // Open the archive file
- file, err := restoreFS.Open(archivePath)
- if err != nil {
- return fmt.Errorf("open archive: %w", err)
- }
- defer file.Close()
-
- // Create decompression reader based on file extension
- reader, err := createDecompressionReader(ctx, file, archivePath)
- if err != nil {
- return fmt.Errorf("create decompression reader: %w", err)
- }
- if closer, ok := reader.(io.Closer); ok {
- defer closer.Close()
- }
-
- // Create TAR reader
- tarReader := tar.NewReader(reader)
-
- // Write log header if log file is available
- if logFile != nil {
- fmt.Fprintf(logFile, "=== PROXMOX RESTORE LOG ===\n")
- fmt.Fprintf(logFile, "Date: %s\n", nowRestore().Format("2006-01-02 15:04:05"))
- fmt.Fprintf(logFile, "Mode: %s\n", getModeName(mode))
- if len(categories) > 0 {
- fmt.Fprintf(logFile, "Selected categories: %d categories\n", len(categories))
- for _, cat := range categories {
- fmt.Fprintf(logFile, " - %s (%s)\n", cat.Name, cat.ID)
- }
- } else {
- fmt.Fprintf(logFile, "Selected categories: ALL (full restore)\n")
- }
- fmt.Fprintf(logFile, "Archive: %s\n", filepath.Base(archivePath))
- fmt.Fprintf(logFile, "\n")
- }
-
- // Extract files (selective or full)
- filesExtracted := 0
- filesSkipped := 0
- filesFailed := 0
- selectiveMode := len(categories) > 0
-
- var restoredTemp, skippedTemp *os.File
- if logFile != nil {
- if tmp, err := restoreFS.CreateTemp("", "restored_entries_*.log"); err == nil {
- restoredTemp = tmp
- defer func() {
- tmp.Close()
- _ = restoreFS.Remove(tmp.Name())
- }()
- } else {
- logger.Warning("Could not create temporary file for restored entries: %v", err)
- }
-
- if tmp, err := restoreFS.CreateTemp("", "skipped_entries_*.log"); err == nil {
- skippedTemp = tmp
- defer func() {
- tmp.Close()
- _ = restoreFS.Remove(tmp.Name())
- }()
- } else {
- logger.Warning("Could not create temporary file for skipped entries: %v", err)
- }
- }
-
- for {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- }
-
- header, err := tarReader.Next()
- if err == io.EOF {
- break
- }
- if err != nil {
- return fmt.Errorf("read tar header: %w", err)
- }
-
- if skipFn != nil && skipFn(header.Name) {
- filesSkipped++
- if skippedTemp != nil {
- fmt.Fprintf(skippedTemp, "SKIPPED: %s (skipped by restore policy)\n", header.Name)
- }
- continue
- }
-
- // Check if file should be extracted (selective mode)
- if selectiveMode {
- shouldExtract := false
- for _, cat := range categories {
- if PathMatchesCategory(header.Name, cat) {
- shouldExtract = true
- break
- }
- }
-
- if !shouldExtract {
- filesSkipped++
- if skippedTemp != nil {
- fmt.Fprintf(skippedTemp, "SKIPPED: %s (does not match any selected category)\n", header.Name)
- }
- continue
- }
- }
-
- if err := extractTarEntry(tarReader, header, destRoot, logger); err != nil {
- logger.Warning("Failed to extract %s: %v", header.Name, err)
- filesFailed++
- continue
- }
-
- filesExtracted++
- if restoredTemp != nil {
- fmt.Fprintf(restoredTemp, "RESTORED: %s\n", header.Name)
- }
- if filesExtracted%100 == 0 {
- logger.Debug("Extracted %d files...", filesExtracted)
- }
- }
-
- // Write detailed log
- if logFile != nil {
- fmt.Fprintf(logFile, "=== FILES RESTORED ===\n")
- if restoredTemp != nil {
- if _, err := restoredTemp.Seek(0, 0); err == nil {
- if _, err := io.Copy(logFile, restoredTemp); err != nil {
- logger.Warning("Could not write restored entries to log: %v", err)
- }
- }
- }
- fmt.Fprintf(logFile, "\n")
-
- fmt.Fprintf(logFile, "=== FILES SKIPPED ===\n")
- if skippedTemp != nil {
- if _, err := skippedTemp.Seek(0, 0); err == nil {
- if _, err := io.Copy(logFile, skippedTemp); err != nil {
- logger.Warning("Could not write skipped entries to log: %v", err)
- }
- }
- }
- fmt.Fprintf(logFile, "\n")
-
- fmt.Fprintf(logFile, "=== SUMMARY ===\n")
- fmt.Fprintf(logFile, "Total files extracted: %d\n", filesExtracted)
- fmt.Fprintf(logFile, "Total files skipped: %d\n", filesSkipped)
- fmt.Fprintf(logFile, "Total files in archive: %d\n", filesExtracted+filesSkipped)
- }
-
- if filesFailed == 0 {
- if selectiveMode {
- logger.Info("Successfully restored all %d configuration files/directories", filesExtracted)
- } else {
- logger.Info("Successfully restored all %d files/directories", filesExtracted)
- }
- } else {
- logger.Warning("Restored %d files/directories; %d item(s) failed (see detailed log)", filesExtracted, filesFailed)
- }
-
- if filesSkipped > 0 {
- logger.Info("%d additional archive entries (logs, diagnostics, system defaults) were left unchanged on this system; see detailed log for details", filesSkipped)
- }
-
- if logFilePath != "" {
- logger.Info("Detailed restore log: %s", logFilePath)
- }
-
- return nil
-}
-
-func isRealRestoreFS(fs FS) bool {
- switch fs.(type) {
- case osFS, *osFS:
- return true
- default:
- return false
- }
-}
-
-// createDecompressionReader creates appropriate decompression reader based on file extension
-func createDecompressionReader(ctx context.Context, file *os.File, archivePath string) (io.Reader, error) {
- switch {
- case strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz"):
- return gzip.NewReader(file)
- case strings.HasSuffix(archivePath, ".tar.xz"):
- return createXZReader(ctx, file)
- case strings.HasSuffix(archivePath, ".tar.zst") || strings.HasSuffix(archivePath, ".tar.zstd"):
- return createZstdReader(ctx, file)
- case strings.HasSuffix(archivePath, ".tar.bz2"):
- return createBzip2Reader(ctx, file)
- case strings.HasSuffix(archivePath, ".tar.lzma"):
- return createLzmaReader(ctx, file)
- case strings.HasSuffix(archivePath, ".tar"):
- return file, nil
- default:
- return nil, fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath))
- }
-}
-
-// createXZReader creates an XZ decompression reader using injectable command runner
-func createXZReader(ctx context.Context, file *os.File) (io.Reader, error) {
- return runRestoreCommandStream(ctx, "xz", file, "-d", "-c")
-}
-
-// createZstdReader creates a Zstd decompression reader using injectable command runner
-func createZstdReader(ctx context.Context, file *os.File) (io.Reader, error) {
- return runRestoreCommandStream(ctx, "zstd", file, "-d", "-c")
-}
-
-// createBzip2Reader creates a Bzip2 decompression reader using injectable command runner
-func createBzip2Reader(ctx context.Context, file *os.File) (io.Reader, error) {
- return runRestoreCommandStream(ctx, "bzip2", file, "-d", "-c")
-}
-
-// createLzmaReader creates an LZMA decompression reader using injectable command runner
-func createLzmaReader(ctx context.Context, file *os.File) (io.Reader, error) {
- return runRestoreCommandStream(ctx, "lzma", file, "-d", "-c")
-}
-
-// runRestoreCommandStream starts a command that reads from stdin and exposes stdout as a ReadCloser.
-// It prefers an injectable streaming runner when available; otherwise falls back to exec.CommandContext.
-func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader, args ...string) (io.Reader, error) {
- type streamingRunner interface {
- RunStream(ctx context.Context, name string, stdin io.Reader, args ...string) (io.ReadCloser, error)
- }
- if sr, ok := restoreCmd.(streamingRunner); ok && sr != nil {
- return sr.RunStream(ctx, name, stdin, args...)
- }
-
- cmd := exec.CommandContext(ctx, name, args...)
- cmd.Stdin = stdin
- stdout, err := cmd.StdoutPipe()
- if err != nil {
- return nil, fmt.Errorf("create %s pipe: %w", name, err)
- }
- if err := cmd.Start(); err != nil {
- stdout.Close()
- return nil, fmt.Errorf("start %s: %w", name, err)
- }
- return &waitReadCloser{ReadCloser: stdout, wait: cmd.Wait}, nil
-}
-
-func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, error) {
- return sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entryName)
-}
-
-func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (string, string, error) {
- cleanDestRoot := filepath.Clean(destRoot)
- if cleanDestRoot == "" {
- cleanDestRoot = string(os.PathSeparator)
- }
-
- absDestRoot, err := filepath.Abs(cleanDestRoot)
- if err != nil {
- return "", "", fmt.Errorf("resolve destination root: %w", err)
- }
-
- name := strings.TrimSpace(entryName)
- if name == "" {
- return "", "", fmt.Errorf("empty archive entry name")
- }
-
- sanitized := path.Clean(name)
- for strings.HasPrefix(sanitized, string(os.PathSeparator)) {
- sanitized = strings.TrimPrefix(sanitized, string(os.PathSeparator))
- }
-
- if sanitized == "" || sanitized == "." {
- return "", "", fmt.Errorf("invalid archive entry name: %q", entryName)
- }
-
- if sanitized == ".." || strings.HasPrefix(sanitized, "../") || strings.Contains(sanitized, "/../") {
- return "", "", fmt.Errorf("illegal path: %s", entryName)
- }
-
- target := filepath.Join(absDestRoot, filepath.FromSlash(sanitized))
- absTarget, err := filepath.Abs(target)
- if err != nil {
- return "", "", fmt.Errorf("resolve extraction target: %w", err)
- }
-
- rel, err := filepath.Rel(absDestRoot, absTarget)
- if err != nil {
- return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err)
- }
- if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." || filepath.IsAbs(rel) {
- return "", "", fmt.Errorf("illegal path: %s", entryName)
- }
-
- if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil {
- if isPathSecurityError(err) {
- return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err)
- }
- if !isPathOperationalError(err) {
- return "", "", fmt.Errorf("resolve extraction target: %w", err)
- }
- }
-
- return absTarget, absDestRoot, nil
-}
-
-func shouldSkipProxmoxSystemRestore(relTarget string) (bool, string) {
- rel := filepath.ToSlash(filepath.Clean(strings.TrimSpace(relTarget)))
- rel = strings.TrimPrefix(rel, "./")
- rel = strings.TrimPrefix(rel, "/")
-
- switch rel {
- case "etc/proxmox-backup/domains.cfg":
- return true, "PBS auth realms must be recreated (domains.cfg is too fragile to restore raw)"
- case "etc/proxmox-backup/user.cfg":
- return true, "PBS users must be recreated (user.cfg should not be restored raw)"
- case "etc/proxmox-backup/acl.cfg":
- return true, "PBS permissions must be recreated (acl.cfg should not be restored raw)"
- case "var/lib/proxmox-backup/.clusterlock":
- return true, "PBS runtime lock files must not be restored"
- }
-
- if strings.HasPrefix(rel, "var/lib/proxmox-backup/lock/") {
- return true, "PBS runtime lock files must not be restored"
- }
-
- return false, ""
-}
-
-// extractTarEntry extracts a single TAR entry, preserving all attributes including atime/ctime
-func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, logger *logging.Logger) error {
- target, cleanDestRoot, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, header.Name)
- if err != nil {
- return err
- }
-
- // Hard guard: never write directly into /etc/pve when restoring to system root
- if cleanDestRoot == string(os.PathSeparator) && strings.HasPrefix(target, "/etc/pve") {
- logger.Warning("Skipping restore to %s (writes to /etc/pve are prohibited)", target)
- return nil
- }
-
- if cleanDestRoot == string(os.PathSeparator) {
- relTarget, err := filepath.Rel(cleanDestRoot, target)
- if err != nil {
- return fmt.Errorf("determine restore target for %s: %w", header.Name, err)
- }
- if skip, reason := shouldSkipProxmoxSystemRestore(relTarget); skip {
- logger.Warning("Skipping restore to %s (%s)", target, reason)
- return nil
- }
- }
-
- // Create parent directories
- if err := restoreFS.MkdirAll(filepath.Dir(target), 0755); err != nil {
- return fmt.Errorf("create parent directory: %w", err)
- }
-
- switch header.Typeflag {
- case tar.TypeDir:
- return extractDirectory(target, header, logger)
- case tar.TypeReg:
- return extractRegularFile(tarReader, target, header, logger)
- case tar.TypeSymlink:
- return extractSymlink(target, header, cleanDestRoot, logger)
- case tar.TypeLink:
- return extractHardlink(target, header, cleanDestRoot)
- default:
- logger.Debug("Skipping unsupported file type %d: %s", header.Typeflag, header.Name)
- return nil
- }
-}
-
-// extractDirectory creates a directory with proper permissions and timestamps
-func extractDirectory(target string, header *tar.Header, logger *logging.Logger) (retErr error) {
- // Create with an owner-accessible mode first so the directory can be opened
- // before applying restrictive archive permissions.
- if err := restoreFS.MkdirAll(target, 0o700); err != nil {
- return fmt.Errorf("create directory: %w", err)
- }
-
- dirFile, err := restoreFS.Open(target)
- if err != nil {
- return fmt.Errorf("open directory: %w", err)
- }
- defer func() {
- if dirFile == nil {
- return
- }
- if err := dirFile.Close(); err != nil && retErr == nil {
- retErr = fmt.Errorf("close directory: %w", err)
- }
- }()
-
- // Apply metadata on the opened directory handle so logical FS paths
- // (e.g. FakeFS-backed test roots) do not leak through to host paths.
- // Ownership remains best-effort to match the previous restore behavior on
- // unprivileged runs and filesystems that do not support chown.
- if err := atomicFileChown(dirFile, header.Uid, header.Gid); err != nil {
- logger.Debug("Failed to chown directory %s: %v", target, err)
- }
- if err := atomicFileChmod(dirFile, os.FileMode(header.Mode)); err != nil {
- return fmt.Errorf("chmod directory: %w", err)
- }
-
- // Set timestamps (mtime, atime)
- if err := setTimestamps(target, header); err != nil {
- logger.Debug("Failed to set timestamps on directory %s: %v", target, err)
- }
-
- return nil
-}
-
-// extractRegularFile extracts a regular file with content and timestamps
-func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header, logger *logging.Logger) (retErr error) {
- tmpPath := ""
- var outFile *os.File
- appendDeferredErr := func(prefix string, err error) {
- if err == nil {
- return
- }
- wrapped := fmt.Errorf("%s: %w", prefix, err)
- if retErr == nil {
- retErr = wrapped
- return
- }
- retErr = errors.Join(retErr, wrapped)
- }
- closeOutFile := func() error {
- if outFile == nil {
- return nil
- }
- err := outFile.Close()
- outFile = nil
- return err
- }
-
- // Write to a sibling temp file first so a truncated archive entry cannot clobber
- // an existing target before the content is fully copied and closed.
- outFile, err := restoreFS.CreateTemp(filepath.Dir(target), restoreTempPattern)
- if err != nil {
- return fmt.Errorf("create file: %w", err)
- }
- tmpPath = outFile.Name()
- defer func() {
- appendDeferredErr("close file", closeOutFile())
- if tmpPath != "" {
- if err := restoreFS.Remove(tmpPath); err != nil && logger != nil {
- logger.Debug("Failed to remove temp file %s: %v", tmpPath, err)
- }
- }
- }()
-
- // Copy content
- if _, err := io.Copy(outFile, tarReader); err != nil {
- return fmt.Errorf("write file content: %w", err)
- }
-
- // Set metadata on the temp file before replacing the target so failures do not
- // leave the final path in a partially restored state.
- // Ownership remains best-effort to match the previous restore behavior on
- // unprivileged runs and filesystems that do not support chown.
- if err := atomicFileChown(outFile, header.Uid, header.Gid); err != nil {
- logger.Debug("Failed to chown file %s: %v", target, err)
- }
- if err := atomicFileChmod(outFile, os.FileMode(header.Mode)); err != nil {
- return fmt.Errorf("chmod file: %w", err)
- }
-
- // Close before renaming into place.
- if err := closeOutFile(); err != nil {
- return fmt.Errorf("close file: %w", err)
- }
-
- if err := restoreFS.Rename(tmpPath, target); err != nil {
- return fmt.Errorf("replace file: %w", err)
- }
- tmpPath = ""
-
- // Set timestamps (mtime, atime, ctime via syscall)
- if err := setTimestamps(target, header); err != nil {
- logger.Debug("Failed to set timestamps on file %s: %v", target, err)
- }
-
- return nil
-}
-
-// extractSymlink creates a symbolic link
-func extractSymlink(target string, header *tar.Header, destRoot string, logger *logging.Logger) error {
- linkTarget := header.Linkname
-
- // Pre-validation: ensure the symlink target resolves within destRoot before creation.
- if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), linkTarget); err != nil {
- return fmt.Errorf("symlink target escapes root before creation: %s -> %s: %w", header.Name, linkTarget, err)
- }
-
- // Remove existing file/link if it exists
- _ = restoreFS.Remove(target)
-
- // Create symlink
- if err := restoreFS.Symlink(linkTarget, target); err != nil {
- return fmt.Errorf("create symlink: %w", err)
- }
-
- // POST-CREATION VALIDATION: Verify the created symlink's target stays within destRoot
- actualTarget, err := restoreFS.Readlink(target)
- if err != nil {
- restoreFS.Remove(target) // Clean up
- return fmt.Errorf("read created symlink %s: %w", target, err)
- }
-
- if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil {
- restoreFS.Remove(target)
- return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err)
- }
-
- // Set ownership (on the symlink itself, not the target)
- if err := os.Lchown(target, header.Uid, header.Gid); err != nil {
- logger.Debug("Failed to lchown symlink %s: %v", target, err)
- }
-
- // Note: timestamps on symlinks are not typically preserved
- return nil
-}
-
-// extractHardlink creates a hard link
-func extractHardlink(target string, header *tar.Header, destRoot string) error {
- // Validate hard link target
- linkName := header.Linkname
-
- // Reject absolute hard link targets immediately
- if filepath.IsAbs(linkName) {
- return fmt.Errorf("absolute hardlink target not allowed: %s", linkName)
- }
-
- // Validate the hard link target stays within extraction root
- if _, err := resolvePathWithinRootFS(restoreFS, destRoot, linkName); err != nil {
- return fmt.Errorf("hardlink target escapes root: %s -> %s: %w", header.Name, linkName, err)
- }
-
- linkTarget := filepath.Join(destRoot, linkName)
-
- // Remove existing file/link if it exists
- _ = restoreFS.Remove(target)
-
- // Create hard link
- if err := restoreFS.Link(linkTarget, target); err != nil {
- return fmt.Errorf("create hardlink: %w", err)
- }
-
- return nil
-}
-
-// setTimestamps sets atime, mtime, and attempts to set ctime via syscall
-func setTimestamps(target string, header *tar.Header) error {
- // Convert times to Unix format
- atime := header.AccessTime
- mtime := header.ModTime
-
- // Use syscall.UtimesNano to set atime and mtime with nanosecond precision
- times := []syscall.Timespec{
- {Sec: atime.Unix(), Nsec: int64(atime.Nanosecond())},
- {Sec: mtime.Unix(), Nsec: int64(mtime.Nanosecond())},
- }
-
- if err := syscall.UtimesNano(target, times); err != nil {
- return fmt.Errorf("set atime/mtime: %w", err)
- }
-
- // Note: ctime (change time) cannot be set directly by user-space programs
- // It is automatically updated by the kernel when file metadata changes
- // The header.ChangeTime is stored in PAX but cannot be restored
-
- return nil
-}
-
-// getModeName returns a human-readable name for the restore mode
-func getModeName(mode RestoreMode) string {
- switch mode {
- case RestoreModeFull:
- return "FULL restore (all files)"
- case RestoreModeStorage:
- return "STORAGE/DATASTORE only"
- case RestoreModeBase:
- return "SYSTEM BASE only"
- case RestoreModeCustom:
- return "CUSTOM selection"
- default:
- return "Unknown mode"
- }
-}
diff --git a/internal/orchestrator/restore_access_control_ui.go b/internal/orchestrator/restore_access_control_ui.go
index db96a8c6..82bf9d2f 100644
--- a/internal/orchestrator/restore_access_control_ui.go
+++ b/internal/orchestrator/restore_access_control_ui.go
@@ -402,8 +402,8 @@ func armAccessControlRollback(ctx context.Context, logger *logging.Logger, backu
}
if handle.unitName == "" {
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", timeoutSeconds, handle.scriptPath)
- if output, err := restoreCmd.Run(ctx, "sh", "-c", cmd); err != nil {
+ output, err := runBackgroundRollbackTimer(ctx, timeoutSeconds, handle.scriptPath)
+ if err != nil {
logger.Debug("Background rollback output: %s", strings.TrimSpace(string(output)))
return nil, fmt.Errorf("failed to schedule rollback timer: %w", err)
}
diff --git a/internal/orchestrator/restore_access_control_ui_additional_test.go b/internal/orchestrator/restore_access_control_ui_additional_test.go
index e8427514..8252d1f0 100644
--- a/internal/orchestrator/restore_access_control_ui_additional_test.go
+++ b/internal/orchestrator/restore_access_control_ui_additional_test.go
@@ -224,8 +224,7 @@ func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) {
t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName)
}
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath)
- wantBackground := "sh -c " + cmd
+ wantBackground := backgroundRollbackCallKey(2, scriptPath)
calls := fakeCmd.CallsList()
if len(calls) != 2 || calls[1] != wantBackground {
t.Fatalf("unexpected calls: %#v", calls)
@@ -243,8 +242,7 @@ func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) {
timestamp := fakeTime.Current.Format("20060102_150405")
scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp))
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath)
- backgroundKey := "sh -c " + cmd
+ backgroundKey := backgroundRollbackCallKey(1, scriptPath)
fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom")
if _, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil {
@@ -290,7 +288,7 @@ func TestArmAccessControlRollback_DefaultWorkDirAndMinTimeout(t *testing.T) {
if len(calls) != 1 {
t.Fatalf("unexpected calls: %#v", calls)
}
- if !strings.Contains(calls[0], "sleep 1; /bin/sh") {
+ if calls[0] != backgroundRollbackCallKey(1, handle.scriptPath) {
t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0])
}
}
diff --git a/internal/orchestrator/restore_archive.go b/internal/orchestrator/restore_archive.go
new file mode 100644
index 00000000..363b7387
--- /dev/null
+++ b/internal/orchestrator/restore_archive.go
@@ -0,0 +1,316 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync/atomic"
+
+ "github.com/tis24dev/proxsave/internal/input"
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+var restoreLogSequence uint64
+
+func shouldRecreateDirectories(systemType SystemType, categories []Category) bool {
+ return (systemType.SupportsPVE() && hasCategoryID(categories, "storage_pve")) ||
+ (systemType.SupportsPBS() && hasCategoryID(categories, "datastore_pbs"))
+}
+
+func hasCategoryID(categories []Category, id string) bool {
+ for _, cat := range categories {
+ if cat.ID == id {
+ return true
+ }
+ }
+ return false
+}
+
+// shouldStopPBSServices reports whether any selected categories belong to PBS-specific configuration
+// and therefore require stopping PBS services before restore.
+func shouldStopPBSServices(categories []Category) bool {
+ for _, cat := range categories {
+ if cat.Type == CategoryTypePBS {
+ return true
+ }
+ // Some common categories (e.g. SSL) include PBS paths that require restarting PBS services.
+ for _, p := range cat.Paths {
+ p = strings.TrimSpace(p)
+ if strings.HasPrefix(p, "./etc/proxmox-backup/") || strings.HasPrefix(p, "./var/lib/proxmox-backup/") {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func splitExportCategories(categories []Category) (normal []Category, export []Category) {
+ for _, cat := range categories {
+ if cat.ExportOnly {
+ export = append(export, cat)
+ continue
+ }
+ normal = append(normal, cat)
+ }
+ return normal, export
+}
+
+// redirectClusterCategoryToExport removes pve_cluster from normal categories and adds it to export-only list.
+func redirectClusterCategoryToExport(normal []Category, export []Category) ([]Category, []Category) {
+ filtered := make([]Category, 0, len(normal))
+ for _, cat := range normal {
+ if cat.ID == "pve_cluster" {
+ export = append(export, cat)
+ continue
+ }
+ filtered = append(filtered, cat)
+ }
+ return filtered, export
+}
+
+func exportDestRoot(baseDir string) string {
+ base := strings.TrimSpace(baseDir)
+ if base == "" {
+ base = "/opt/proxsave"
+ }
+ return filepath.Join(base, fmt.Sprintf("proxmox-config-export-%s", nowRestore().Format("20060102-150405")))
+}
+
+// runFullRestore performs a full restore without selective options (fallback)
+func runFullRestore(ctx context.Context, reader *bufio.Reader, candidate *backupCandidate, prepared *preparedBundle, destRoot string, logger *logging.Logger, dryRun bool) error {
+ if err := confirmRestoreAction(ctx, reader, candidate, destRoot); err != nil {
+ return err
+ }
+
+ safeFstabMerge := destRoot == "/" && isRealRestoreFS(restoreFS)
+ if safeFstabMerge {
+ logger.Warning("Full restore safety: /etc/fstab will not be overwritten; Smart Merge will be applied after extraction.")
+ }
+
+ if err := extractPlainArchive(ctx, prepared.ArchivePath, destRoot, logger, fullRestoreSkipFn(safeFstabMerge)); err != nil {
+ return err
+ }
+
+ if safeFstabMerge {
+ if err := runFullRestoreFstabMerge(ctx, reader, prepared.ArchivePath, destRoot, logger, dryRun); err != nil {
+ return err
+ }
+ }
+
+ logger.Info("Restore completed successfully.")
+ return nil
+}
+
+func fullRestoreSkipFn(safeFstabMerge bool) func(name string) bool {
+ return func(name string) bool {
+ if !safeFstabMerge {
+ return false
+ }
+ clean := strings.TrimPrefix(strings.TrimSpace(name), "./")
+ clean = strings.TrimPrefix(clean, "/")
+ return clean == "etc/fstab"
+ }
+}
+
+func runFullRestoreFstabMerge(ctx context.Context, reader *bufio.Reader, archivePath, destRoot string, logger *logging.Logger, dryRun bool) error {
+ logger.Info("")
+ fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-")
+ if err != nil {
+ logger.Warning("Failed to create temp dir for fstab merge: %v", err)
+ return nil
+ }
+ defer restoreFS.RemoveAll(fsTempDir)
+
+ if err := extractFullRestoreFstab(ctx, archivePath, fsTempDir, logger); err != nil {
+ logger.Warning("Failed to extract filesystem config for merge: %v", err)
+ return nil
+ }
+ extractFullRestoreFstabInventory(ctx, archivePath, fsTempDir, logger)
+ currentFstab := filepath.Join(destRoot, "etc", "fstab")
+ backupFstab := filepath.Join(fsTempDir, "etc", "fstab")
+ if err := SmartMergeFstab(ctx, logger, reader, currentFstab, backupFstab, dryRun); err != nil {
+ if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) {
+ logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.")
+ return err
+ }
+ logger.Warning("Smart Fstab Merge failed: %v", err)
+ }
+ return nil
+}
+
+func extractFullRestoreFstab(ctx context.Context, archivePath, fsTempDir string, logger *logging.Logger) error {
+ return extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: fsTempDir,
+ logger: logger,
+ categories: []Category{{
+ ID: "filesystem",
+ Name: "Filesystem Configuration",
+ Paths: []string{"./etc/fstab"},
+ }},
+ mode: RestoreModeCustom,
+ })
+}
+
+func extractFullRestoreFstabInventory(ctx context.Context, archivePath, fsTempDir string, logger *logging.Logger) {
+ invCategory := []Category{{
+ ID: "fstab_inventory",
+ Name: "Fstab inventory (device mapping)",
+ Paths: []string{
+ "./var/lib/proxsave-info/commands/system/blkid.txt",
+ "./var/lib/proxsave-info/commands/system/lsblk_json.json",
+ "./var/lib/proxsave-info/commands/system/lsblk.txt",
+ "./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json",
+ },
+ }}
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: fsTempDir,
+ logger: logger,
+ categories: invCategory,
+ mode: RestoreModeCustom,
+ }); err != nil {
+ logger.Debug("Failed to extract fstab inventory data (continuing): %v", err)
+ }
+}
+
+func confirmRestoreAction(ctx context.Context, reader *bufio.Reader, cand *backupCandidate, dest string) error {
+ manifest := cand.Manifest
+ fmt.Println()
+ fmt.Printf("Selected backup: %s (%s)\n", cand.DisplayBase, manifest.CreatedAt.Format("2006-01-02 15:04:05"))
+ cleanDest := filepath.Clean(strings.TrimSpace(dest))
+ if cleanDest == "" || cleanDest == "." {
+ cleanDest = string(os.PathSeparator)
+ }
+ if cleanDest == string(os.PathSeparator) {
+ fmt.Println("Restore destination: / (system root; original paths will be preserved)")
+ fmt.Println("WARNING: This operation will overwrite configuration files on this system.")
+ } else {
+ fmt.Printf("Restore destination: %s (original paths will be preserved under this directory)\n", cleanDest)
+ fmt.Printf("WARNING: This operation will overwrite existing files under %s.\n", cleanDest)
+ }
+ fmt.Println("Type RESTORE to proceed or 0 to cancel.")
+
+ for {
+ fmt.Print("Confirmation: ")
+ line, err := input.ReadLineWithContext(ctx, reader)
+ if err != nil {
+ return err
+ }
+ switch strings.TrimSpace(line) {
+ case "RESTORE":
+ return nil
+ case "0":
+ return ErrRestoreAborted
+ default:
+ fmt.Println("Please type RESTORE to confirm or 0 to cancel.")
+ }
+ }
+}
+
+func extractPlainArchive(ctx context.Context, archivePath, destRoot string, logger *logging.Logger, skipFn func(entryName string) bool) error {
+ if err := restoreFS.MkdirAll(destRoot, 0o755); err != nil {
+ return fmt.Errorf("create destination directory: %w", err)
+ }
+
+ // Only enforce root privileges when writing to the real system root.
+ if destRoot == "/" && isRealRestoreFS(restoreFS) && os.Geteuid() != 0 {
+ return fmt.Errorf("restore to %s requires root privileges", destRoot)
+ }
+
+ logger.Info("Extracting archive %s into %s", filepath.Base(archivePath), destRoot)
+
+ // Use native Go extraction to preserve atime/ctime from PAX headers
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: destRoot,
+ logger: logger,
+ mode: RestoreModeFull,
+ skipFn: skipFn,
+ }); err != nil {
+ return fmt.Errorf("archive extraction failed: %w", err)
+ }
+
+ return nil
+}
+
+// extractSelectiveArchive extracts only files matching selected categories
+func extractSelectiveArchive(ctx context.Context, archivePath, destRoot string, categories []Category, mode RestoreMode, logger *logging.Logger) (logPath string, err error) {
+ done := logging.DebugStart(logger, "extract selective archive", "archive=%s dest=%s categories=%d mode=%s", archivePath, destRoot, len(categories), mode)
+ defer func() { done(err) }()
+ if err := restoreFS.MkdirAll(destRoot, 0o755); err != nil {
+ return "", fmt.Errorf("create destination directory: %w", err)
+ }
+
+ // Only enforce root privileges when writing to the real system root.
+ if destRoot == "/" && isRealRestoreFS(restoreFS) && os.Geteuid() != 0 {
+ return "", fmt.Errorf("restore to %s requires root privileges", destRoot)
+ }
+
+ // Create detailed log directory
+ logDir := "/tmp/proxsave"
+ if err := restoreFS.MkdirAll(logDir, 0o755); err != nil {
+ logger.Warning("Could not create log directory: %v", err)
+ }
+
+ // Create detailed log file
+ timestamp := nowRestore().Format("20060102_150405")
+ logSeq := atomic.AddUint64(&restoreLogSequence, 1)
+ logPath = filepath.Join(logDir, fmt.Sprintf("restore_%s_%d.log", timestamp, logSeq))
+ logFile, err := restoreFS.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
+ if err != nil {
+ logger.Warning("Could not create detailed log file: %v", err)
+ logFile = nil
+ } else {
+ defer logFile.Close()
+ logger.Info("Detailed restore log: %s", logPath)
+ logging.DebugStep(logger, "extract selective archive", "log file=%s", logPath)
+ }
+
+ logger.Info("Extracting selected categories from archive %s into %s", filepath.Base(archivePath), destRoot)
+
+ // Use native Go extraction with category filter
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: destRoot,
+ logger: logger,
+ categories: categories,
+ mode: mode,
+ logFile: logFile,
+ logFilePath: logPath,
+ }); err != nil {
+ return logPath, err
+ }
+
+ return logPath, nil
+}
+
+func isRealRestoreFS(fs FS) bool {
+ switch fs.(type) {
+ case osFS, *osFS:
+ return true
+ default:
+ return false
+ }
+}
+
+// getModeName returns a human-readable name for the restore mode
+func getModeName(mode RestoreMode) string {
+ switch mode {
+ case RestoreModeFull:
+ return "FULL restore (all files)"
+ case RestoreModeStorage:
+ return "STORAGE/DATASTORE only"
+ case RestoreModeBase:
+ return "SYSTEM BASE only"
+ case RestoreModeCustom:
+ return "CUSTOM selection"
+ default:
+ return "Unknown mode"
+ }
+}
diff --git a/internal/orchestrator/restore_archive_additional_test.go b/internal/orchestrator/restore_archive_additional_test.go
new file mode 100644
index 00000000..f8fda96d
--- /dev/null
+++ b/internal/orchestrator/restore_archive_additional_test.go
@@ -0,0 +1,40 @@
+package orchestrator
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestExtractPlainArchiveHonorsSkipFn(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+ restoreFS = osFS{}
+
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "backup.tar")
+ if err := writeTarFile(archivePath, map[string]string{
+ "etc/fstab": "/dev/sda1 / ext4 defaults 0 1\n",
+ "etc/hosts": "127.0.0.1 localhost\n",
+ }); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ destRoot := filepath.Join(tmpDir, "restore")
+ if err := extractPlainArchive(context.Background(), archivePath, destRoot, newTestLogger(), fullRestoreSkipFn(true)); err != nil {
+ t.Fatalf("extractPlainArchive error: %v", err)
+ }
+
+ if _, err := os.Stat(filepath.Join(destRoot, "etc", "fstab")); !os.IsNotExist(err) {
+ t.Fatalf("expected skipped fstab to be absent, stat err=%v", err)
+ }
+
+ hosts, err := os.ReadFile(filepath.Join(destRoot, "etc", "hosts"))
+ if err != nil {
+ t.Fatalf("expected hosts to be extracted: %v", err)
+ }
+ if string(hosts) != "127.0.0.1 localhost\n" {
+ t.Fatalf("hosts content=%q", string(hosts))
+ }
+}
diff --git a/internal/orchestrator/restore_archive_entries.go b/internal/orchestrator/restore_archive_entries.go
new file mode 100644
index 00000000..0411c84b
--- /dev/null
+++ b/internal/orchestrator/restore_archive_entries.go
@@ -0,0 +1,281 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "archive/tar"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "syscall"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+const restoreTempPattern = ".proxsave-tmp-*"
+
+// extractTarEntry extracts a single TAR entry, preserving all attributes including atime/ctime
+func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, logger *logging.Logger) error {
+ target, cleanDestRoot, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, header.Name)
+ if err != nil {
+ return err
+ }
+
+ skip, err := shouldSkipRestoreEntryTarget(header, target, cleanDestRoot, logger)
+ if err != nil {
+ return err
+ }
+ if skip {
+ return nil
+ }
+
+ // Create parent directories
+ if err := restoreFS.MkdirAll(filepath.Dir(target), 0755); err != nil {
+ return fmt.Errorf("create parent directory: %w", err)
+ }
+
+ return extractTypedTarEntry(tarReader, header, target, cleanDestRoot, logger)
+}
+
+func shouldSkipRestoreEntryTarget(header *tar.Header, target, cleanDestRoot string, logger *logging.Logger) (bool, error) {
+ if cleanDestRoot != string(os.PathSeparator) {
+ return false, nil
+ }
+ // Hard guard: never write directly into /etc/pve when restoring to system root
+ if target == "/etc/pve" || strings.HasPrefix(target, "/etc/pve/") {
+ logger.Warning("Skipping restore to %s (writes to /etc/pve are prohibited)", target)
+ return true, nil
+ }
+ relTarget, err := filepath.Rel(cleanDestRoot, target)
+ if err != nil {
+ return false, fmt.Errorf("determine restore target for %s: %w", header.Name, err)
+ }
+ if skip, reason := shouldSkipProxmoxSystemRestore(relTarget); skip {
+ logger.Warning("Skipping restore to %s (%s)", target, reason)
+ return true, nil
+ }
+ return false, nil
+}
+
+func extractTypedTarEntry(tarReader *tar.Reader, header *tar.Header, target, cleanDestRoot string, logger *logging.Logger) error {
+ switch header.Typeflag {
+ case tar.TypeDir:
+ return extractDirectory(target, header, logger)
+ case tar.TypeReg:
+ return extractRegularFile(tarReader, target, header, logger)
+ case tar.TypeSymlink:
+ return extractSymlink(target, header, cleanDestRoot, logger)
+ case tar.TypeLink:
+ return extractHardlink(target, header, cleanDestRoot)
+ default:
+ logger.Debug("Skipping unsupported file type %d: %s", header.Typeflag, header.Name)
+ return nil
+ }
+}
+
+// extractDirectory creates a directory with proper permissions and timestamps
+func extractDirectory(target string, header *tar.Header, logger *logging.Logger) (retErr error) {
+ // Create with an owner-accessible mode first so the directory can be opened
+ // before applying restrictive archive permissions.
+ if err := restoreFS.MkdirAll(target, 0o700); err != nil {
+ return fmt.Errorf("create directory: %w", err)
+ }
+
+ dirFile, err := restoreFS.Open(target)
+ if err != nil {
+ return fmt.Errorf("open directory: %w", err)
+ }
+ defer func() {
+ if dirFile == nil {
+ return
+ }
+ if err := dirFile.Close(); err != nil && retErr == nil {
+ retErr = fmt.Errorf("close directory: %w", err)
+ }
+ }()
+
+ // Apply metadata on the opened directory handle so logical FS paths
+ // (e.g. FakeFS-backed test roots) do not leak through to host paths.
+ // Ownership remains best-effort to match the previous restore behavior on
+ // unprivileged runs and filesystems that do not support chown.
+ if err := atomicFileChown(dirFile, header.Uid, header.Gid); err != nil {
+ logger.Debug("Failed to chown directory %s: %v", target, err)
+ }
+ if err := atomicFileChmod(dirFile, os.FileMode(header.Mode)); err != nil {
+ return fmt.Errorf("chmod directory: %w", err)
+ }
+
+ // Set timestamps (mtime, atime)
+ if err := setTimestamps(target, header); err != nil {
+ logger.Debug("Failed to set timestamps on directory %s: %v", target, err)
+ }
+
+ return nil
+}
+
+// extractRegularFile extracts a regular file with content and timestamps
+func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header, logger *logging.Logger) (retErr error) {
+ tmpPath := ""
+ var outFile *os.File
+ appendDeferredErr := func(prefix string, err error) {
+ if err == nil {
+ return
+ }
+ wrapped := fmt.Errorf("%s: %w", prefix, err)
+ if retErr == nil {
+ retErr = wrapped
+ return
+ }
+ retErr = errors.Join(retErr, wrapped)
+ }
+ closeOutFile := func() error {
+ if outFile == nil {
+ return nil
+ }
+ err := outFile.Close()
+ outFile = nil
+ return err
+ }
+
+ // Write to a sibling temp file first so a truncated archive entry cannot clobber
+ // an existing target before the content is fully copied and closed.
+ outFile, err := restoreFS.CreateTemp(filepath.Dir(target), restoreTempPattern)
+ if err != nil {
+ return fmt.Errorf("create file: %w", err)
+ }
+ tmpPath = outFile.Name()
+ defer func() {
+ appendDeferredErr("close file", closeOutFile())
+ if tmpPath != "" {
+ if err := restoreFS.Remove(tmpPath); err != nil && logger != nil {
+ logger.Debug("Failed to remove temp file %s: %v", tmpPath, err)
+ }
+ }
+ }()
+
+ // Copy content
+ if _, err := io.Copy(outFile, tarReader); err != nil {
+ return fmt.Errorf("write file content: %w", err)
+ }
+
+ // Set metadata on the temp file before replacing the target so failures do not
+ // leave the final path in a partially restored state.
+ // Ownership remains best-effort to match the previous restore behavior on
+ // unprivileged runs and filesystems that do not support chown.
+ if err := atomicFileChown(outFile, header.Uid, header.Gid); err != nil {
+ logger.Debug("Failed to chown file %s: %v", target, err)
+ }
+ if err := atomicFileChmod(outFile, os.FileMode(header.Mode)); err != nil {
+ return fmt.Errorf("chmod file: %w", err)
+ }
+
+ // Close before renaming into place.
+ if err := closeOutFile(); err != nil {
+ return fmt.Errorf("close file: %w", err)
+ }
+
+ if err := restoreFS.Rename(tmpPath, target); err != nil {
+ return fmt.Errorf("replace file: %w", err)
+ }
+ tmpPath = ""
+
+ // Set timestamps (mtime, atime, ctime via syscall)
+ if err := setTimestamps(target, header); err != nil {
+ logger.Debug("Failed to set timestamps on file %s: %v", target, err)
+ }
+
+ return nil
+}
+
+// extractSymlink creates a symbolic link
+func extractSymlink(target string, header *tar.Header, destRoot string, logger *logging.Logger) error {
+ linkTarget := header.Linkname
+
+ // Pre-validation: ensure the symlink target resolves within destRoot before creation.
+ if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), linkTarget); err != nil {
+ return fmt.Errorf("symlink target escapes root before creation: %s -> %s: %w", header.Name, linkTarget, err)
+ }
+
+ // Remove existing file/link if it exists
+ _ = restoreFS.Remove(target)
+
+ // Create symlink
+ if err := restoreFS.Symlink(linkTarget, target); err != nil {
+ return fmt.Errorf("create symlink: %w", err)
+ }
+
+ // POST-CREATION VALIDATION: Verify the created symlink's target stays within destRoot
+ actualTarget, err := restoreFS.Readlink(target)
+ if err != nil {
+ restoreFS.Remove(target) // Clean up
+ return fmt.Errorf("read created symlink %s: %w", target, err)
+ }
+
+ if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil {
+ restoreFS.Remove(target)
+ return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err)
+ }
+
+ // Set ownership (on the symlink itself, not the target)
+ if err := restoreFS.Lchown(target, header.Uid, header.Gid); err != nil {
+ logger.Debug("Failed to lchown symlink %s: %v", target, err)
+ }
+
+ // Note: timestamps on symlinks are not typically preserved
+ return nil
+}
+
+// extractHardlink creates a hard link
+func extractHardlink(target string, header *tar.Header, destRoot string) error {
+ // Validate hard link target
+ linkName := filepath.FromSlash(header.Linkname)
+ if linkName == "" || filepath.Clean(linkName) == "." {
+ return fmt.Errorf("empty hardlink target not allowed")
+ }
+
+ // Reject absolute hard link targets immediately
+ if filepath.IsAbs(linkName) {
+ return fmt.Errorf("absolute hardlink target not allowed: %s", linkName)
+ }
+
+ // Resolve and validate the hard link target stays within extraction root.
+ linkTarget, err := resolvePathWithinRootFS(restoreFS, destRoot, linkName)
+ if err != nil {
+ return fmt.Errorf("hardlink target escapes root: %s -> %s: %w", header.Name, linkName, err)
+ }
+
+ // Remove existing file/link if it exists
+ _ = restoreFS.Remove(target)
+
+ // Create hard link
+ if err := restoreFS.Link(linkTarget, target); err != nil {
+ return fmt.Errorf("create hardlink: %w", err)
+ }
+
+ return nil
+}
+
+// setTimestamps sets atime, mtime, and attempts to set ctime via syscall
+func setTimestamps(target string, header *tar.Header) error {
+ // Convert times to Unix format
+ atime := header.AccessTime
+ mtime := header.ModTime
+
+ // Use syscall.UtimesNano to set atime and mtime with nanosecond precision
+ times := []syscall.Timespec{
+ {Sec: atime.Unix(), Nsec: int64(atime.Nanosecond())},
+ {Sec: mtime.Unix(), Nsec: int64(mtime.Nanosecond())},
+ }
+
+ if err := restoreFS.UtimesNano(target, times); err != nil {
+ return fmt.Errorf("set atime/mtime: %w", err)
+ }
+
+ // Note: ctime (change time) cannot be set directly by user-space programs
+ // It is automatically updated by the kernel when file metadata changes
+ // The header.ChangeTime is stored in PAX but cannot be restored
+
+ return nil
+}
diff --git a/internal/orchestrator/restore_archive_extract.go b/internal/orchestrator/restore_archive_extract.go
new file mode 100644
index 00000000..ea16abb4
--- /dev/null
+++ b/internal/orchestrator/restore_archive_extract.go
@@ -0,0 +1,248 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "archive/tar"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+type restoreArchiveOptions struct {
+ archivePath string
+ destRoot string
+ logger *logging.Logger
+ categories []Category
+ mode RestoreMode
+ logFile *os.File
+ logFilePath string
+ skipFn func(entryName string) bool
+}
+
+type restoreExtractionStats struct {
+ filesExtracted int
+ filesSkipped int
+ filesFailed int
+}
+
+type restoreExtractionLog struct {
+ logger *logging.Logger
+ logFile *os.File
+ logFilePath string
+ restoredTemp *os.File
+ skippedTemp *os.File
+}
+
+// extractArchiveNative extracts TAR archives natively in Go, preserving all timestamps.
+func extractArchiveNative(ctx context.Context, opts restoreArchiveOptions) (err error) {
+ file, err := restoreFS.Open(opts.archivePath)
+ if err != nil {
+ return fmt.Errorf("open archive: %w", err)
+ }
+ defer file.Close()
+
+ reader, err := createDecompressionReader(ctx, file, opts.archivePath)
+ if err != nil {
+ return fmt.Errorf("create decompression reader: %w", err)
+ }
+ defer closeDecompressionReader(reader, &err, "close decompression reader")
+
+ extractionLog := newRestoreExtractionLog(opts)
+ defer extractionLog.close()
+ extractionLog.writeHeader(opts)
+
+ stats, err := processRestoreArchiveEntries(ctx, tar.NewReader(reader), opts, extractionLog)
+ if err != nil {
+ return err
+ }
+
+ extractionLog.writeSummary(stats)
+ logRestoreExtractionSummary(opts, stats)
+ return nil
+}
+
+func newRestoreExtractionLog(opts restoreArchiveOptions) *restoreExtractionLog {
+ extractionLog := &restoreExtractionLog{
+ logger: opts.logger,
+ logFile: opts.logFile,
+ logFilePath: opts.logFilePath,
+ }
+ if opts.logFile == nil {
+ return extractionLog
+ }
+
+ if tmp, err := restoreFS.CreateTemp("", "restored_entries_*.log"); err == nil {
+ extractionLog.restoredTemp = tmp
+ } else {
+ opts.logger.Warning("Could not create temporary file for restored entries: %v", err)
+ }
+ if tmp, err := restoreFS.CreateTemp("", "skipped_entries_*.log"); err == nil {
+ extractionLog.skippedTemp = tmp
+ } else {
+ opts.logger.Warning("Could not create temporary file for skipped entries: %v", err)
+ }
+ return extractionLog
+}
+
+func (log *restoreExtractionLog) close() {
+ closeAndRemoveRestoreTemp(log.restoredTemp)
+ closeAndRemoveRestoreTemp(log.skippedTemp)
+}
+
+func closeAndRemoveRestoreTemp(file *os.File) {
+ if file == nil {
+ return
+ }
+ file.Close()
+ _ = restoreFS.Remove(file.Name())
+}
+
+func (log *restoreExtractionLog) writeHeader(opts restoreArchiveOptions) {
+ if log.logFile == nil {
+ return
+ }
+ fmt.Fprintf(log.logFile, "=== PROXMOX RESTORE LOG ===\n")
+ fmt.Fprintf(log.logFile, "Date: %s\n", nowRestore().Format("2006-01-02 15:04:05"))
+ fmt.Fprintf(log.logFile, "Mode: %s\n", getModeName(opts.mode))
+ if len(opts.categories) > 0 {
+ fmt.Fprintf(log.logFile, "Selected categories: %d categories\n", len(opts.categories))
+ for _, cat := range opts.categories {
+ fmt.Fprintf(log.logFile, " - %s (%s)\n", cat.Name, cat.ID)
+ }
+ } else {
+ fmt.Fprintf(log.logFile, "Selected categories: ALL (full restore)\n")
+ }
+ fmt.Fprintf(log.logFile, "Archive: %s\n", filepath.Base(opts.archivePath))
+ fmt.Fprintf(log.logFile, "\n")
+}
+
+func processRestoreArchiveEntries(ctx context.Context, tarReader *tar.Reader, opts restoreArchiveOptions, extractionLog *restoreExtractionLog) (restoreExtractionStats, error) {
+ var stats restoreExtractionStats
+ selectiveMode := len(opts.categories) > 0
+ for {
+ if err := ctx.Err(); err != nil {
+ return stats, err
+ }
+
+ header, err := tarReader.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return stats, fmt.Errorf("read tar header: %w", err)
+ }
+
+ if skipRestoreArchiveEntry(header, opts, selectiveMode, extractionLog, &stats) {
+ continue
+ }
+ if err := extractTarEntry(tarReader, header, opts.destRoot, opts.logger); err != nil {
+ opts.logger.Warning("Failed to extract %s: %v", header.Name, err)
+ stats.filesFailed++
+ continue
+ }
+
+ stats.filesExtracted++
+ extractionLog.recordRestored(header.Name)
+ if stats.filesExtracted%100 == 0 {
+ opts.logger.Debug("Extracted %d files...", stats.filesExtracted)
+ }
+ }
+ return stats, nil
+}
+
+func skipRestoreArchiveEntry(header *tar.Header, opts restoreArchiveOptions, selectiveMode bool, extractionLog *restoreExtractionLog, stats *restoreExtractionStats) bool {
+ if opts.skipFn != nil && opts.skipFn(header.Name) {
+ stats.filesSkipped++
+ extractionLog.recordSkipped(header.Name, "skipped by restore policy")
+ return true
+ }
+ if !selectiveMode || restoreEntryMatchesCategories(header.Name, opts.categories) {
+ return false
+ }
+ stats.filesSkipped++
+ extractionLog.recordSkipped(header.Name, "does not match any selected category")
+ return true
+}
+
+func restoreEntryMatchesCategories(entryName string, categories []Category) bool {
+ for _, cat := range categories {
+ if PathMatchesCategory(entryName, cat) {
+ return true
+ }
+ }
+ return false
+}
+
+func (log *restoreExtractionLog) recordSkipped(name, reason string) {
+ if log.skippedTemp != nil {
+ fmt.Fprintf(log.skippedTemp, "SKIPPED: %s (%s)\n", name, reason)
+ }
+}
+
+func (log *restoreExtractionLog) recordRestored(name string) {
+ if log.restoredTemp != nil {
+ fmt.Fprintf(log.restoredTemp, "RESTORED: %s\n", name)
+ }
+}
+
+func (log *restoreExtractionLog) writeSummary(stats restoreExtractionStats) {
+ if log.logFile == nil {
+ return
+ }
+ fmt.Fprintf(log.logFile, "=== FILES RESTORED ===\n")
+ log.copyTempEntries(log.restoredTemp, "restored")
+ fmt.Fprintf(log.logFile, "\n")
+
+ fmt.Fprintf(log.logFile, "=== FILES SKIPPED ===\n")
+ log.copyTempEntries(log.skippedTemp, "skipped")
+ fmt.Fprintf(log.logFile, "\n")
+
+ fmt.Fprintf(log.logFile, "=== SUMMARY ===\n")
+ fmt.Fprintf(log.logFile, "Total files extracted: %d\n", stats.filesExtracted)
+ fmt.Fprintf(log.logFile, "Total files skipped: %d\n", stats.filesSkipped)
+ fmt.Fprintf(log.logFile, "Total files failed: %d\n", stats.filesFailed)
+ fmt.Fprintf(log.logFile, "Total files in archive: %d\n", stats.filesExtracted+stats.filesSkipped+stats.filesFailed)
+}
+
+func (log *restoreExtractionLog) copyTempEntries(tempFile *os.File, label string) {
+ if tempFile == nil {
+ return
+ }
+ if _, err := tempFile.Seek(0, 0); err == nil {
+ if _, err := io.Copy(log.logFile, tempFile); err != nil {
+ log.logger.Warning("Could not write %s entries to log: %v", label, err)
+ }
+ }
+}
+
+func logRestoreExtractionSummary(opts restoreArchiveOptions, stats restoreExtractionStats) {
+ if stats.filesFailed == 0 {
+ if len(opts.categories) > 0 {
+ opts.logger.Info("Successfully restored all %d configuration files/directories", stats.filesExtracted)
+ } else {
+ opts.logger.Info("Successfully restored all %d files/directories", stats.filesExtracted)
+ }
+ } else {
+ if opts.logFilePath != "" {
+ opts.logger.Warning("Restored %d files/directories; %d item(s) failed (see detailed log)", stats.filesExtracted, stats.filesFailed)
+ } else {
+ opts.logger.Warning("Restored %d files/directories; %d item(s) failed", stats.filesExtracted, stats.filesFailed)
+ }
+ }
+
+ if stats.filesSkipped > 0 {
+ if opts.logFilePath != "" {
+ opts.logger.Info("%d additional archive entries (logs, diagnostics, system defaults) were left unchanged on this system; see detailed log for details", stats.filesSkipped)
+ } else {
+ opts.logger.Info("%d additional archive entries (logs, diagnostics, system defaults) were left unchanged on this system", stats.filesSkipped)
+ }
+ }
+
+ if opts.logFilePath != "" {
+ opts.logger.Info("Detailed restore log: %s", opts.logFilePath)
+ }
+}
diff --git a/internal/orchestrator/restore_archive_extract_summary_test.go b/internal/orchestrator/restore_archive_extract_summary_test.go
new file mode 100644
index 00000000..5c1caaf6
--- /dev/null
+++ b/internal/orchestrator/restore_archive_extract_summary_test.go
@@ -0,0 +1,65 @@
+package orchestrator
+
+import (
+ "bytes"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+func TestRestoreExtractionLogWriteSummaryIncludesFailedFiles(t *testing.T) {
+ logPath := filepath.Join(t.TempDir(), "restore.log")
+ logFile, err := os.Create(logPath)
+ if err != nil {
+ t.Fatalf("create log file: %v", err)
+ }
+
+ extractionLog := &restoreExtractionLog{
+ logger: newTestLogger(),
+ logFile: logFile,
+ }
+ extractionLog.writeSummary(restoreExtractionStats{
+ filesExtracted: 2,
+ filesSkipped: 3,
+ filesFailed: 4,
+ })
+ if err := logFile.Close(); err != nil {
+ t.Fatalf("close log file: %v", err)
+ }
+
+ content, err := os.ReadFile(logPath)
+ if err != nil {
+ t.Fatalf("read log file: %v", err)
+ }
+ text := string(content)
+ if !strings.Contains(text, "Total files failed: 4") {
+ t.Fatalf("summary missing failed count:\n%s", text)
+ }
+ if !strings.Contains(text, "Total files in archive: 9") {
+ t.Fatalf("summary total should include failed files:\n%s", text)
+ }
+}
+
+func TestLogRestoreExtractionSummaryOmitsDetailedLogHintWithoutLogPath(t *testing.T) {
+ var buf bytes.Buffer
+ logger := logging.New(types.LogLevelInfo, false)
+ logger.SetOutput(&buf)
+
+ logRestoreExtractionSummary(restoreArchiveOptions{logger: logger}, restoreExtractionStats{
+ filesExtracted: 2,
+ filesSkipped: 1,
+ filesFailed: 1,
+ })
+
+ output := buf.String()
+ if strings.Contains(output, "see detailed log") {
+ t.Fatalf("did not expect detailed log hint without log path:\n%s", output)
+ }
+ if !strings.Contains(output, "1 item(s) failed") {
+ t.Fatalf("expected failed count in summary:\n%s", output)
+ }
+}
diff --git a/internal/orchestrator/restore_archive_paths.go b/internal/orchestrator/restore_archive_paths.go
new file mode 100644
index 00000000..8c9840f7
--- /dev/null
+++ b/internal/orchestrator/restore_archive_paths.go
@@ -0,0 +1,115 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "fmt"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+)
+
+func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, error) {
+ return sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entryName)
+}
+
+func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (string, string, error) {
+ absDestRoot, err := resolveRestoreDestRoot(destRoot)
+ if err != nil {
+ return "", "", fmt.Errorf("resolve destination root: %w", err)
+ }
+
+ sanitized, err := normalizeRestoreEntryName(entryName)
+ if err != nil {
+ return "", "", err
+ }
+ absTarget, err := resolveRestoreEntryTarget(absDestRoot, sanitized)
+ if err != nil {
+ return "", "", fmt.Errorf("resolve extraction target: %w", err)
+ }
+ if err := ensureRestoreTargetWithinRoot(absDestRoot, absTarget, entryName); err != nil {
+ return "", "", err
+ }
+ if err := ensureRestoreTargetResolverAllows(fsys, absDestRoot, absTarget, entryName); err != nil {
+ return "", "", err
+ }
+
+ return absTarget, absDestRoot, nil
+}
+
+func resolveRestoreDestRoot(destRoot string) (string, error) {
+ cleanDestRoot := filepath.Clean(destRoot)
+ if cleanDestRoot == "" {
+ cleanDestRoot = string(os.PathSeparator)
+ }
+ return filepath.Abs(cleanDestRoot)
+}
+
+func normalizeRestoreEntryName(entryName string) (string, error) {
+ name := strings.TrimSpace(entryName)
+ if name == "" {
+ return "", fmt.Errorf("empty archive entry name")
+ }
+ sanitized := path.Clean(name)
+ for strings.HasPrefix(sanitized, string(os.PathSeparator)) {
+ sanitized = strings.TrimPrefix(sanitized, string(os.PathSeparator))
+ }
+ if sanitized == "" || sanitized == "." {
+ return "", fmt.Errorf("invalid archive entry name: %q", entryName)
+ }
+ if sanitized == ".." || strings.HasPrefix(sanitized, "../") || strings.Contains(sanitized, "/../") {
+ return "", fmt.Errorf("illegal path: %s", entryName)
+ }
+ return sanitized, nil
+}
+
+func resolveRestoreEntryTarget(absDestRoot, sanitized string) (string, error) {
+ target := filepath.Join(absDestRoot, filepath.FromSlash(sanitized))
+ return filepath.Abs(target)
+}
+
+func ensureRestoreTargetWithinRoot(absDestRoot, absTarget, entryName string) error {
+ rel, err := filepath.Rel(absDestRoot, absTarget)
+ if err != nil {
+ return fmt.Errorf("illegal path: %s: %w", entryName, err)
+ }
+ if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." || filepath.IsAbs(rel) {
+ return fmt.Errorf("illegal path: %s", entryName)
+ }
+ return nil
+}
+
+func ensureRestoreTargetResolverAllows(fsys FS, absDestRoot, absTarget, entryName string) error {
+ if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil {
+ if isPathSecurityError(err) {
+ return fmt.Errorf("illegal path: %s: %w", entryName, err)
+ }
+ if !isPathOperationalError(err) {
+ return fmt.Errorf("resolve extraction target: %w", err)
+ }
+ }
+ return nil
+}
+
+func shouldSkipProxmoxSystemRestore(relTarget string) (bool, string) {
+ rel := filepath.ToSlash(filepath.Clean(strings.TrimSpace(relTarget)))
+ rel = strings.TrimPrefix(rel, "./")
+ rel = strings.TrimPrefix(rel, "/")
+
+ switch rel {
+ case "etc/proxmox-backup/domains.cfg":
+ return true, "PBS auth realms must be recreated (domains.cfg is too fragile to restore raw)"
+ case "etc/proxmox-backup/user.cfg":
+ return true, "PBS users must be recreated (user.cfg should not be restored raw)"
+ case "etc/proxmox-backup/acl.cfg":
+ return true, "PBS permissions must be recreated (acl.cfg should not be restored raw)"
+ case "var/lib/proxmox-backup/.clusterlock":
+ return true, "PBS runtime lock files must not be restored"
+ }
+
+ if strings.HasPrefix(rel, "var/lib/proxmox-backup/lock/") {
+ return true, "PBS runtime lock files must not be restored"
+ }
+
+ return false, ""
+}
diff --git a/internal/orchestrator/restore_cluster_apply.go b/internal/orchestrator/restore_cluster_apply.go
new file mode 100644
index 00000000..a2fbac64
--- /dev/null
+++ b/internal/orchestrator/restore_cluster_apply.go
@@ -0,0 +1,531 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/input"
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+// runSafeClusterApply applies selected cluster configs via pvesh without touching config.db.
+// It operates on files extracted to exportRoot (e.g. exportDestRoot).
+func runSafeClusterApply(ctx context.Context, reader *bufio.Reader, exportRoot string, logger *logging.Logger) (err error) {
+ if logger == nil {
+ logger = logging.GetDefaultLogger()
+ }
+ ui := newCLIWorkflowUI(reader, logger)
+ return runSafeClusterApplyWithUI(ctx, ui, exportRoot, logger, nil)
+}
+
+type vmEntry struct {
+ VMID string
+ Kind string // qemu | lxc
+ Name string
+ Path string
+}
+
+func scanVMConfigs(exportRoot, node string) ([]vmEntry, error) {
+ var entries []vmEntry
+ base := filepath.Join(exportRoot, "etc/pve/nodes", node)
+
+ type dirSpec struct {
+ kind string
+ path string
+ }
+
+ dirs := []dirSpec{
+ {kind: "qemu", path: filepath.Join(base, "qemu-server")},
+ {kind: "lxc", path: filepath.Join(base, "lxc")},
+ }
+
+ for _, spec := range dirs {
+ infos, err := restoreFS.ReadDir(spec.path)
+ if err != nil {
+ continue
+ }
+ for _, entry := range infos {
+ if entry.IsDir() {
+ continue
+ }
+ name := entry.Name()
+ if !strings.HasSuffix(name, ".conf") {
+ continue
+ }
+ vmid := strings.TrimSuffix(name, ".conf")
+ vmPath := filepath.Join(spec.path, name)
+ vmName := readVMName(vmPath)
+ entries = append(entries, vmEntry{
+ VMID: vmid,
+ Kind: spec.kind,
+ Name: vmName,
+ Path: vmPath,
+ })
+ }
+ }
+
+ return entries, nil
+}
+
+func listExportNodeDirs(exportRoot string) ([]string, error) {
+ nodesRoot := filepath.Join(exportRoot, "etc/pve/nodes")
+ entries, err := restoreFS.ReadDir(nodesRoot)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ var nodes []string
+ for _, entry := range entries {
+ if !entry.IsDir() {
+ continue
+ }
+ name := strings.TrimSpace(entry.Name())
+ if name == "" {
+ continue
+ }
+ nodes = append(nodes, name)
+ }
+ sort.Strings(nodes)
+ return nodes, nil
+}
+
+func countVMConfigsForNode(exportRoot, node string) (qemuCount, lxcCount int) {
+ base := filepath.Join(exportRoot, "etc/pve/nodes", node)
+
+ countInDir := func(dir string) int {
+ entries, err := restoreFS.ReadDir(dir)
+ if err != nil {
+ return 0
+ }
+ n := 0
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ if strings.HasSuffix(entry.Name(), ".conf") {
+ n++
+ }
+ }
+ return n
+ }
+
+ qemuCount = countInDir(filepath.Join(base, "qemu-server"))
+ lxcCount = countInDir(filepath.Join(base, "lxc"))
+ return qemuCount, lxcCount
+}
+
+func promptExportNodeSelection(ctx context.Context, reader *bufio.Reader, exportRoot, currentNode string, exportNodes []string) (string, error) {
+ for {
+ fmt.Println()
+ fmt.Printf("WARNING: VM/CT configs in this backup are stored under different node names.\n")
+ fmt.Printf("Current node: %s\n", currentNode)
+ fmt.Println("Select which exported node to import VM/CT configs from (they will be applied to the current node):")
+ for idx, node := range exportNodes {
+ qemuCount, lxcCount := countVMConfigsForNode(exportRoot, node)
+ fmt.Printf(" [%d] %s (qemu=%d, lxc=%d)\n", idx+1, node, qemuCount, lxcCount)
+ }
+ fmt.Println(" [0] Skip VM/CT apply")
+
+ fmt.Print("Choice: ")
+ line, err := input.ReadLineWithContext(ctx, reader)
+ if err != nil {
+ return "", err
+ }
+ trimmed := strings.TrimSpace(line)
+ if trimmed == "0" {
+ return "", nil
+ }
+ if trimmed == "" {
+ continue
+ }
+ idx, err := parseMenuIndex(trimmed, len(exportNodes))
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ return exportNodes[idx], nil
+ }
+}
+
+func stringSliceContains(items []string, want string) bool {
+ for _, item := range items {
+ if item == want {
+ return true
+ }
+ }
+ return false
+}
+
+func readVMName(confPath string) string {
+ data, err := restoreFS.ReadFile(confPath)
+ if err != nil {
+ return ""
+ }
+ for _, line := range strings.Split(string(data), "\n") {
+ t := strings.TrimSpace(line)
+ if strings.HasPrefix(t, "name:") {
+ return strings.TrimSpace(strings.TrimPrefix(t, "name:"))
+ }
+ if strings.HasPrefix(t, "hostname:") {
+ return strings.TrimSpace(strings.TrimPrefix(t, "hostname:"))
+ }
+ }
+ return ""
+}
+
+func applyVMConfigs(ctx context.Context, entries []vmEntry, logger *logging.Logger) (applied, failed int) {
+ node := localNodeName()
+ for _, vm := range entries {
+ if err := ctx.Err(); err != nil {
+ logger.Warning("VM apply aborted: %v", err)
+ return applied, failed
+ }
+ target := fmt.Sprintf("/nodes/%s/%s/%s/config", node, vm.Kind, vm.VMID)
+ configArgs, err := pveshArgsFromColonConfigFile(vm.Path)
+ if err != nil {
+ logger.Warning("Failed to read %s (vmid=%s kind=%s): %v", vm.Path, vm.VMID, vm.Kind, err)
+ failed++
+ continue
+ }
+
+ exists, err := pveshGuestExists(ctx, logger, target)
+ if err != nil {
+ logger.Warning("Failed to check existing VM/CT config %s (vmid=%s kind=%s): %v", target, vm.VMID, vm.Kind, err)
+ failed++
+ continue
+ }
+ if !exists {
+ createArgs, err := pveshCreateGuestArgs(node, vm, configArgs)
+ if err != nil {
+ logger.Warning("Failed to prepare VM/CT create for %s (vmid=%s kind=%s): %v", target, vm.VMID, vm.Kind, err)
+ failed++
+ continue
+ }
+ if err := runPvesh(ctx, logger, createArgs); err != nil {
+ logger.Warning("Failed to create VM/CT config %s (vmid=%s kind=%s): %v", target, vm.VMID, vm.Kind, err)
+ failed++
+ continue
+ }
+ }
+
+ args := append([]string{"set", target}, configArgs...)
+ if err := runPvesh(ctx, logger, args); err != nil {
+ logger.Warning("Failed to apply %s (vmid=%s kind=%s): %v", target, vm.VMID, vm.Kind, err)
+ failed++
+ } else {
+ display := vm.VMID
+ if vm.Name != "" {
+ display = fmt.Sprintf("%s (%s)", vm.VMID, vm.Name)
+ }
+ logger.Info("Applied VM/CT config %s", display)
+ applied++
+ }
+ }
+ return applied, failed
+}
+
+func localNodeName() string {
+ host, _ := os.Hostname()
+ host = shortHost(host)
+ if host != "" {
+ return host
+ }
+ return "localhost"
+}
+
+func pveshGuestExists(ctx context.Context, logger *logging.Logger, target string) (bool, error) {
+ if err := runPvesh(ctx, logger, []string{"get", target}); err != nil {
+ if isPveshNotFoundError(err) {
+ return false, nil
+ }
+ return false, err
+ }
+ return true, nil
+}
+
+func pveshCreateGuestArgs(node string, vm vmEntry, configArgs []string) ([]string, error) {
+ args := []string{
+ "create",
+ fmt.Sprintf("/nodes/%s/%s", node, vm.Kind),
+ fmt.Sprintf("--vmid=%s", vm.VMID),
+ }
+ switch vm.Kind {
+ case "qemu":
+ return args, nil
+ case "lxc":
+ ostemplate, ok := pveshArgValue(configArgs, "ostemplate")
+ if !ok {
+ return nil, fmt.Errorf("missing ostemplate in LXC config")
+ }
+ return append(args, fmt.Sprintf("--ostemplate=%s", ostemplate)), nil
+ default:
+ return nil, fmt.Errorf("unsupported guest kind %q", vm.Kind)
+ }
+}
+
+func pveshArgValue(args []string, key string) (string, bool) {
+ prefix := "--" + key + "="
+ for _, arg := range args {
+ if strings.HasPrefix(arg, prefix) {
+ return strings.TrimPrefix(arg, prefix), true
+ }
+ }
+ return "", false
+}
+
+func isPveshNotFoundError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ for _, marker := range []string{"not found", "does not exist", "no such", "unable to find", "404"} {
+ if strings.Contains(msg, marker) {
+ return true
+ }
+ }
+ return false
+}
+
+type storageBlock struct {
+ ID string
+ Type string
+ entries []proxmoxNotificationEntry
+}
+
+func pveshArgsFromColonConfigFile(path string) ([]string, error) {
+ data, err := restoreFS.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return pveshArgsFromColonConfigLines(strings.Split(string(data), "\n")), nil
+}
+
+func pveshArgsFromColonConfigLines(lines []string) []string {
+ args := make([]string, 0, len(lines)*2)
+ for _, line := range lines {
+ if strings.HasPrefix(strings.TrimSpace(line), "[") {
+ break
+ }
+ key, value, ok := parseColonConfigLine(line)
+ if !ok {
+ continue
+ }
+ args = append(args, fmt.Sprintf("--%s=%s", key, value))
+ }
+ return args
+}
+
+func pveshArgsFromProxmoxEntries(entries []proxmoxNotificationEntry) []string {
+ args := make([]string, 0, len(entries)*2)
+ for _, entry := range entries {
+ key := strings.TrimSpace(entry.Key)
+ value := strings.TrimSpace(entry.Value)
+ if key == "" || value == "" {
+ continue
+ }
+ args = append(args, fmt.Sprintf("--%s=%s", key, value))
+ }
+ return args
+}
+
+func storageBlockPveshArgs(block storageBlock) ([]string, bool) {
+ storageType := strings.TrimSpace(block.Type)
+ if storageType == "" {
+ storageType = storageEntryValue(block.entries, "type")
+ }
+ if storageType == "" {
+ return nil, false
+ }
+
+ args := []string{
+ fmt.Sprintf("--storage=%s", block.ID),
+ fmt.Sprintf("--type=%s", storageType),
+ }
+ for _, entry := range block.entries {
+ if strings.EqualFold(strings.TrimSpace(entry.Key), "type") {
+ continue
+ }
+ args = append(args, pveshArgsFromProxmoxEntries([]proxmoxNotificationEntry{entry})...)
+ }
+ return args, true
+}
+
+func storageEntryValue(entries []proxmoxNotificationEntry, want string) string {
+ for _, entry := range entries {
+ if strings.EqualFold(strings.TrimSpace(entry.Key), want) {
+ return strings.TrimSpace(entry.Value)
+ }
+ }
+ return ""
+}
+
+func parseColonConfigLine(line string) (key, value string, ok bool) {
+ trimmed := strings.TrimSpace(line)
+ if trimmed == "" || strings.HasPrefix(trimmed, "#") {
+ return "", "", false
+ }
+ idx := strings.Index(trimmed, ":")
+ if idx <= 0 {
+ return "", "", false
+ }
+ key = strings.TrimSpace(trimmed[:idx])
+ value = strings.TrimSpace(trimmed[idx+1:])
+ if key == "" || value == "" {
+ return "", "", false
+ }
+ return key, value, true
+}
+
+func applyStorageCfg(ctx context.Context, cfgPath string, logger *logging.Logger) (applied, failed int, err error) {
+ blocks, perr := parseStorageBlocks(cfgPath)
+ if perr != nil {
+ return 0, 0, perr
+ }
+ if len(blocks) == 0 {
+ logger.Info("No storage definitions detected in storage.cfg")
+ return 0, 0, nil
+ }
+
+ for _, blk := range blocks {
+ createArgs, ok := storageBlockPveshArgs(blk)
+ if !ok {
+ logger.Warning("Skipping storage %s: storage type missing", blk.ID)
+ failed++
+ continue
+ }
+ args := append([]string{"create", "/storage"}, createArgs...)
+
+ if runErr := runPvesh(ctx, logger, args); runErr != nil {
+ logger.Warning("Failed to apply storage %s: %v", blk.ID, runErr)
+ failed++
+ } else {
+ logger.Info("Applied storage definition %s", blk.ID)
+ applied++
+ }
+
+ if err := ctx.Err(); err != nil {
+ return applied, failed, err
+ }
+ }
+
+ return applied, failed, nil
+}
+
+func parseStorageBlocks(cfgPath string) ([]storageBlock, error) {
+ data, err := restoreFS.ReadFile(cfgPath)
+ if err != nil {
+ return nil, err
+ }
+
+ var blocks []storageBlock
+ var current *storageBlock
+
+ flush := func() {
+ if current != nil {
+ blocks = append(blocks, *current)
+ current = nil
+ }
+ }
+
+ for _, line := range strings.Split(string(data), "\n") {
+ trimmed := strings.TrimSpace(line)
+ if trimmed == "" {
+ flush()
+ continue
+ }
+
+ // storage.cfg blocks use `: ` (e.g. `dir: local`, `nfs: backup`).
+ // Older exports may still use `storage: ` blocks.
+ typ, name, ok := parseSectionHeader(trimmed)
+ if ok {
+ flush()
+ storageType := ""
+ if !strings.EqualFold(typ, "storage") {
+ storageType = typ
+ }
+ current = &storageBlock{ID: name, Type: storageType}
+ continue
+ }
+ if current != nil {
+ key, value := parseProxmoxNotificationKV(trimmed)
+ if strings.TrimSpace(key) == "" {
+ continue
+ }
+ current.entries = append(current.entries, proxmoxNotificationEntry{Key: key, Value: value})
+ }
+ }
+ flush()
+
+ return blocks, nil
+}
+
+func runPvesh(ctx context.Context, logger *logging.Logger, args []string) error {
+ output, err := restoreCmd.Run(ctx, "pvesh", args...)
+ if len(output) > 0 {
+ logger.Debug("pvesh %v output: %s", args, strings.TrimSpace(string(output)))
+ }
+ if err != nil {
+ return fmt.Errorf("pvesh %v failed: %w", args, err)
+ }
+ return nil
+}
+
+func shortHost(host string) string {
+ if idx := strings.Index(host, "."); idx > 0 {
+ return host[:idx]
+ }
+ return host
+}
+
+func sanitizeID(id string) string {
+ var b strings.Builder
+ for _, r := range id {
+ if isSafeIDRune(r) {
+ b.WriteRune(r)
+ } else {
+ b.WriteRune('_')
+ }
+ }
+ return b.String()
+}
+
+func isSafeIDRune(r rune) bool {
+ return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_'
+}
+
+// promptClusterRestoreMode asks how to handle cluster database restore (safe export vs full recovery).
+func promptClusterRestoreMode(ctx context.Context, reader *bufio.Reader) (int, error) {
+ fmt.Println()
+ fmt.Println("Cluster backup detected. Choose how to restore the cluster database:")
+ fmt.Println(" [1] SAFE: Do NOT write /var/lib/pve-cluster/config.db. Export cluster files only (manual/apply via API).")
+ fmt.Println(" [2] RECOVERY: Restore full cluster database (/var/lib/pve-cluster). Use only when cluster is offline/isolated.")
+ fmt.Println(" [0] Exit")
+
+ for {
+ fmt.Print("Choice: ")
+ choiceLine, err := input.ReadLineWithContext(ctx, reader)
+ if err != nil {
+ return 0, err
+ }
+ switch strings.TrimSpace(choiceLine) {
+ case "1":
+ return 1, nil
+ case "2":
+ return 2, nil
+ case "0":
+ return 0, nil
+ default:
+ fmt.Println("Please enter 1, 2, or 0.")
+ }
+ }
+}
diff --git a/internal/orchestrator/restore_cluster_apply_additional_test.go b/internal/orchestrator/restore_cluster_apply_additional_test.go
new file mode 100644
index 00000000..d9d5ac62
--- /dev/null
+++ b/internal/orchestrator/restore_cluster_apply_additional_test.go
@@ -0,0 +1,60 @@
+package orchestrator
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func TestRunSafeClusterApplyWithUI_SkipsStorageDatacenterWhenStoragePVEStaged(t *testing.T) {
+ origCmd := restoreCmd
+ origFS := restoreFS
+ t.Cleanup(func() {
+ restoreCmd = origCmd
+ restoreFS = origFS
+ })
+ restoreFS = osFS{}
+
+ pathDir := t.TempDir()
+ pveshPath := filepath.Join(pathDir, "pvesh")
+ if err := os.WriteFile(pveshPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
+ t.Fatalf("write pvesh: %v", err)
+ }
+ t.Setenv("PATH", pathDir+string(os.PathListSeparator)+os.Getenv("PATH"))
+
+ runner := &recordingRunner{}
+ restoreCmd = runner
+
+ exportRoot := t.TempDir()
+ pveDir := filepath.Join(exportRoot, "etc", "pve")
+ if err := os.MkdirAll(pveDir, 0o755); err != nil {
+ t.Fatalf("mkdir pve dir: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(pveDir, "storage.cfg"), []byte("storage: local\n type dir\n"), 0o640); err != nil {
+ t.Fatalf("write storage.cfg: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(pveDir, "datacenter.cfg"), []byte("keyboard: it\n"), 0o640); err != nil {
+ t.Fatalf("write datacenter.cfg: %v", err)
+ }
+
+ plan := &RestorePlan{
+ SystemType: SystemTypePVE,
+ StagedCategories: []Category{{ID: "storage_pve", Type: CategoryTypePVE}},
+ }
+ ui := &fakeRestoreWorkflowUI{
+ applyStorageCfg: true,
+ applyDatacenterCfg: true,
+ }
+
+ if err := runSafeClusterApplyWithUI(context.Background(), ui, exportRoot, newTestLogger(), plan); err != nil {
+ t.Fatalf("runSafeClusterApplyWithUI error: %v", err)
+ }
+
+ for _, call := range runner.calls {
+ if strings.Contains(call, "pvesh create /storage") || strings.Contains(call, "pvesh set /storage") || strings.Contains(call, "/cluster/config") {
+ t.Fatalf("storage/datacenter apply should be skipped for storage_pve staged restore; calls=%#v", runner.calls)
+ }
+ }
+}
diff --git a/internal/orchestrator/restore_coverage_extra_test.go b/internal/orchestrator/restore_coverage_extra_test.go
index 15bdd2b9..e21566cc 100644
--- a/internal/orchestrator/restore_coverage_extra_test.go
+++ b/internal/orchestrator/restore_coverage_extra_test.go
@@ -21,6 +21,8 @@ func (runOnlyRunner) Run(ctx context.Context, name string, args ...string) ([]by
return nil, fmt.Errorf("unexpected command: %s", commandKey(name, args))
}
+type zfsContextTestKey struct{}
+
type recordingRunner struct {
calls []string
}
@@ -63,7 +65,7 @@ func TestDetectImportableZFSPools_ReturnsPoolsAndErrorWhenCommandFails(t *testin
}
restoreCmd = fake
- pools, output, err := detectImportableZFSPools()
+ pools, output, err := detectImportableZFSPools(context.Background())
if err == nil {
t.Fatalf("expected error")
}
@@ -86,7 +88,7 @@ func TestCheckZFSPoolsAfterRestore_ReturnsNilWhenZpoolMissing(t *testing.T) {
}
restoreCmd = fake
- if err := checkZFSPoolsAfterRestore(newTestLogger()); err != nil {
+ if err := checkZFSPoolsAfterRestore(context.Background(), newTestLogger()); err != nil {
t.Fatalf("expected nil error when zpool missing, got %v", err)
}
if len(fake.Calls) != 1 || fake.Calls[0] != "which zpool" {
@@ -94,6 +96,50 @@ func TestCheckZFSPoolsAfterRestore_ReturnsNilWhenZpoolMissing(t *testing.T) {
}
}
+func TestCheckZFSPoolsAfterRestore_UsesProvidedContext(t *testing.T) {
+ orig := restoreCmd
+ t.Cleanup(func() { restoreCmd = orig })
+
+ fake := &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "which zpool": []byte("/sbin/zpool\n"),
+ "zpool import": []byte(""),
+ },
+ }
+ restoreCmd = fake
+
+ ctx := context.WithValue(context.Background(), zfsContextTestKey{}, "restore")
+ if err := checkZFSPoolsAfterRestore(ctx, newTestLogger()); err != nil {
+ t.Fatalf("checkZFSPoolsAfterRestore error: %v", err)
+ }
+
+ if len(fake.Contexts) == 0 {
+ t.Fatalf("expected command contexts to be recorded")
+ }
+ for i, got := range fake.Contexts {
+ if got.Value(zfsContextTestKey{}) != "restore" {
+ t.Fatalf("command context %d did not use restore context", i)
+ }
+ }
+}
+
+func TestCheckZFSPoolsAfterRestore_ReturnsCanceledContext(t *testing.T) {
+ orig := restoreCmd
+ t.Cleanup(func() { restoreCmd = orig })
+
+ fake := &FakeCommandRunner{}
+ restoreCmd = fake
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ if err := checkZFSPoolsAfterRestore(ctx, newTestLogger()); err != context.Canceled {
+ t.Fatalf("checkZFSPoolsAfterRestore error = %v, want context.Canceled", err)
+ }
+ if len(fake.Calls) != 0 {
+ t.Fatalf("expected no commands after canceled context, got %#v", fake.Calls)
+ }
+}
+
func TestCheckZFSPoolsAfterRestore_ConfiguredPools_NoImportables(t *testing.T) {
origCmd := restoreCmd
origFS := restoreFS
@@ -131,7 +177,7 @@ func TestCheckZFSPoolsAfterRestore_ConfiguredPools_NoImportables(t *testing.T) {
}
restoreCmd = fake
- if err := checkZFSPoolsAfterRestore(newTestLogger()); err != nil {
+ if err := checkZFSPoolsAfterRestore(context.Background(), newTestLogger()); err != nil {
t.Fatalf("checkZFSPoolsAfterRestore error: %v", err)
}
@@ -172,7 +218,7 @@ func TestCheckZFSPoolsAfterRestore_ReportsImportablePools(t *testing.T) {
}
restoreCmd = fake
- if err := checkZFSPoolsAfterRestore(newTestLogger()); err != nil {
+ if err := checkZFSPoolsAfterRestore(context.Background(), newTestLogger()); err != nil {
t.Fatalf("checkZFSPoolsAfterRestore error: %v", err)
}
@@ -313,10 +359,10 @@ func TestRunSafeClusterApply_AppliesVMStorageAndDatacenterConfigs(t *testing.T)
}
wantPrefixes := []string{
- "pvesh set /nodes/" + node + "/qemu/100/config --filename ",
- "pvesh set /nodes/" + node + "/lxc/101/config --filename ",
- "pvesh set /cluster/storage/local -conf ",
- "pvesh set /cluster/storage/backup_ext -conf ",
+ "pvesh set /nodes/" + node + "/qemu/100/config --name=vm100",
+ "pvesh set /nodes/" + node + "/lxc/101/config --hostname=ct101",
+ "pvesh create /storage --storage=local --type=dir --path=/var/lib/vz",
+ "pvesh create /storage --storage=backup_ext --type=nfs --server=10.0.0.1",
"pvesh set /cluster/config -conf ",
}
for _, prefix := range wantPrefixes {
@@ -331,6 +377,14 @@ func TestRunSafeClusterApply_AppliesVMStorageAndDatacenterConfigs(t *testing.T)
t.Fatalf("expected a call with prefix %q; calls=%#v", prefix, runner.calls)
}
}
+ for _, call := range runner.calls {
+ if strings.Contains(call, "--filename") {
+ t.Fatalf("VM/CT apply must not use invalid --filename flag; calls=%#v", runner.calls)
+ }
+ if strings.Contains(call, "/cluster/storage/") || (strings.Contains(call, " -conf ") && strings.Contains(call, "storage")) {
+ t.Fatalf("storage apply must not use invalid cluster storage path or -conf flag; calls=%#v", runner.calls)
+ }
+ }
}
func TestRunSafeClusterApply_AppliesPoolsFromUserCfg(t *testing.T) {
@@ -485,11 +539,10 @@ func TestRunSafeClusterApply_UsesSingleExportedNodeWhenHostnameMismatch(t *testi
t.Fatalf("runSafeClusterApply error: %v", err)
}
- wantPrefix := "pvesh set /nodes/" + targetNode + "/qemu/100/config --filename "
- wantSourceSuffix := filepath.Join("etc", "pve", "nodes", sourceNode, "qemu-server", "100.conf")
+ wantPrefix := "pvesh set /nodes/" + targetNode + "/qemu/100/config --name=vm100"
found := false
for _, call := range runner.calls {
- if strings.HasPrefix(call, wantPrefix) && strings.Contains(call, wantSourceSuffix) {
+ if strings.HasPrefix(call, wantPrefix) {
found = true
break
}
@@ -547,11 +600,10 @@ func TestRunSafeClusterApply_PromptsForSourceNodeWhenMultipleExportNodes(t *test
t.Fatalf("runSafeClusterApply error: %v", err)
}
- wantPrefix := "pvesh set /nodes/" + targetNode + "/qemu/101/config --filename "
- wantSourceSuffix := filepath.Join("etc", "pve", "nodes", sourceNode2, "qemu-server", "101.conf")
+ wantPrefix := "pvesh set /nodes/" + targetNode + "/qemu/101/config --name=vm101"
found := false
for _, call := range runner.calls {
- if strings.HasPrefix(call, wantPrefix) && strings.Contains(call, wantSourceSuffix) {
+ if strings.HasPrefix(call, wantPrefix) {
found = true
break
}
@@ -612,13 +664,9 @@ func TestRunRestoreCommandStream_FallsBackToExecCommand(t *testing.T) {
if err != nil {
t.Fatalf("runRestoreCommandStream error: %v", err)
}
- rc, ok := reader.(io.ReadCloser)
- if !ok {
- t.Fatalf("expected io.ReadCloser, got %T", reader)
- }
- defer rc.Close()
+ defer reader.Close()
- out, err := io.ReadAll(rc)
+ out, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("read: %v", err)
}
diff --git a/internal/orchestrator/restore_decision.go b/internal/orchestrator/restore_decision.go
index dafa20a6..903984f6 100644
--- a/internal/orchestrator/restore_decision.go
+++ b/internal/orchestrator/restore_decision.go
@@ -91,14 +91,12 @@ func inspectRestoreArchiveContents(archivePath string, logger *logging.Logger) (
if err != nil {
return nil, err
}
- if closer, ok := reader.(interface{ Close() error }); ok {
- defer func() {
- if closeErr := closer.Close(); closeErr != nil && err == nil {
- inspection = nil
- err = fmt.Errorf("inspect archive: %w", closeErr)
- }
- }()
- }
+ defer func() {
+ if closeErr := reader.Close(); closeErr != nil && err == nil {
+ inspection = nil
+ err = fmt.Errorf("inspect archive: %w", closeErr)
+ }
+ }()
tarReader := tar.NewReader(reader)
archivePaths, metadata, metadataErr, collectErr := collectRestoreArchiveFacts(tarReader)
diff --git a/internal/orchestrator/restore_decompression.go b/internal/orchestrator/restore_decompression.go
new file mode 100644
index 00000000..224fea52
--- /dev/null
+++ b/internal/orchestrator/restore_decompression.go
@@ -0,0 +1,103 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "compress/gzip"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/safeexec"
+)
+
+type restoreDecompressionFormat struct {
+ matches func(string) bool
+ open func(context.Context, *os.File) (io.ReadCloser, error)
+}
+
+// createDecompressionReader creates appropriate decompression reader based on file extension
+func createDecompressionReader(ctx context.Context, file *os.File, archivePath string) (io.ReadCloser, error) {
+ for _, format := range restoreDecompressionFormats() {
+ if format.matches(archivePath) {
+ return format.open(ctx, file)
+ }
+ }
+ return nil, fmt.Errorf("unsupported archive format: %s", filepath.Base(archivePath))
+}
+
+func closeDecompressionReader(reader io.Closer, errp *error, operation string) {
+ if reader == nil || errp == nil {
+ return
+ }
+ if closeErr := reader.Close(); closeErr != nil && *errp == nil {
+ *errp = fmt.Errorf("%s: %w", operation, closeErr)
+ }
+}
+
+func restoreDecompressionFormats() []restoreDecompressionFormat {
+ return []restoreDecompressionFormat{
+ {
+ matches: func(path string) bool { return strings.HasSuffix(path, ".tar.gz") || strings.HasSuffix(path, ".tgz") },
+ open: func(_ context.Context, file *os.File) (io.ReadCloser, error) { return gzip.NewReader(file) },
+ },
+ {matches: func(path string) bool { return strings.HasSuffix(path, ".tar.xz") }, open: createXZReader},
+ {
+ matches: func(path string) bool {
+ return strings.HasSuffix(path, ".tar.zst") || strings.HasSuffix(path, ".tar.zstd")
+ },
+ open: createZstdReader,
+ },
+ {matches: func(path string) bool { return strings.HasSuffix(path, ".tar.bz2") }, open: createBzip2Reader},
+ {matches: func(path string) bool { return strings.HasSuffix(path, ".tar.lzma") }, open: createLzmaReader},
+ {matches: func(path string) bool { return strings.HasSuffix(path, ".tar") }, open: func(_ context.Context, file *os.File) (io.ReadCloser, error) { return file, nil }},
+ }
+}
+
+// createXZReader creates an XZ decompression reader using injectable command runner
+func createXZReader(ctx context.Context, file *os.File) (io.ReadCloser, error) {
+ return runRestoreCommandStream(ctx, "xz", file, "-d", "-c")
+}
+
+// createZstdReader creates a Zstd decompression reader using injectable command runner
+func createZstdReader(ctx context.Context, file *os.File) (io.ReadCloser, error) {
+ return runRestoreCommandStream(ctx, "zstd", file, "-d", "-c")
+}
+
+// createBzip2Reader creates a Bzip2 decompression reader using injectable command runner
+func createBzip2Reader(ctx context.Context, file *os.File) (io.ReadCloser, error) {
+ return runRestoreCommandStream(ctx, "bzip2", file, "-d", "-c")
+}
+
+// createLzmaReader creates an LZMA decompression reader using injectable command runner
+func createLzmaReader(ctx context.Context, file *os.File) (io.ReadCloser, error) {
+ return runRestoreCommandStream(ctx, "lzma", file, "-d", "-c")
+}
+
+// runRestoreCommandStream starts a command that reads from stdin and exposes stdout as a ReadCloser.
+// It prefers an injectable streaming runner when available; otherwise falls back to safeexec.
+func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader, args ...string) (io.ReadCloser, error) {
+ type streamingRunner interface {
+ RunStream(ctx context.Context, name string, stdin io.Reader, args ...string) (io.ReadCloser, error)
+ }
+ if sr, ok := restoreCmd.(streamingRunner); ok && sr != nil {
+ return sr.RunStream(ctx, name, stdin, args...)
+ }
+
+ cmd, err := safeexec.CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
+ cmd.Stdin = stdin
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, fmt.Errorf("create %s pipe: %w", name, err)
+ }
+ if err := cmd.Start(); err != nil {
+ stdout.Close()
+ return nil, fmt.Errorf("start %s: %w", name, err)
+ }
+ return &waitReadCloser{ReadCloser: stdout, wait: cmd.Wait}, nil
+}
diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go
index 313cfa91..bfe6c39e 100644
--- a/internal/orchestrator/restore_errors_test.go
+++ b/internal/orchestrator/restore_errors_test.go
@@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"strings"
+ "syscall"
"testing"
"time"
@@ -54,7 +55,7 @@ func TestRunRestoreCommandStream_UsesStreamingRunner(t *testing.T) {
if err != nil {
t.Fatalf("createXZReader: %v", err)
}
- defer reader.(io.Closer).Close()
+ defer reader.Close()
buf, err := io.ReadAll(reader)
if err != nil {
@@ -550,9 +551,11 @@ func TestApplyStorageCfg_WithMultipleBlocks(t *testing.T) {
// Write storage config with multiple blocks
cfgPath := filepath.Join(t.TempDir(), "storage.cfg")
content := `storage: local
+ type dir
path /var/lib/vz
storage: backup
+ type nfs
path /mnt/backup
`
if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil {
@@ -569,6 +572,16 @@ storage: backup
if applied != 2 {
t.Fatalf("expected 2 applied, got %d (failed=%d)", applied, failed)
}
+ calls := strings.Join(restoreCmd.(*FakeCommandRunner).CallsList(), "\n")
+ if strings.Contains(calls, " -conf ") {
+ t.Fatalf("storage apply must not use -conf; calls=%s", calls)
+ }
+ if !strings.Contains(calls, "pvesh create /storage --storage=local --type=dir --path=/var/lib/vz") {
+ t.Fatalf("missing local storage args; calls=%s", calls)
+ }
+ if !strings.Contains(calls, "pvesh create /storage --storage=backup --type=nfs --path=/mnt/backup") {
+ t.Fatalf("missing backup storage args; calls=%s", calls)
+ }
}
func TestApplyStorageCfg_PveshError(t *testing.T) {
@@ -582,6 +595,7 @@ func TestApplyStorageCfg_PveshError(t *testing.T) {
cfgPath := filepath.Join(t.TempDir(), "storage.cfg")
content := `storage: local
+ type dir
path /var/lib/vz
`
if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil {
@@ -882,6 +896,12 @@ func (f *ErrorInjectingFS) MkdirTemp(dir, pattern string) (string, error) {
func (f *ErrorInjectingFS) Rename(oldpath, newpath string) error {
return f.base.Rename(oldpath, newpath)
}
+func (f *ErrorInjectingFS) Lchown(path string, uid, gid int) error {
+ return f.base.Lchown(path, uid, gid)
+}
+func (f *ErrorInjectingFS) UtimesNano(path string, times []syscall.Timespec) error {
+ return f.base.UtimesNano(path, times)
+}
func (f *ErrorInjectingFS) MkdirAll(path string, perm os.FileMode) error {
if f.mkdirAllErr != nil {
@@ -1704,6 +1724,49 @@ func TestApplyVMConfigs_SuccessfulApply(t *testing.T) {
if applied != 1 || failed != 0 {
t.Fatalf("expected (1,0), got (%d,%d)", applied, failed)
}
+ calls := strings.Join(fake.CallsList(), "\n")
+ if strings.Contains(calls, "--filename") {
+ t.Fatalf("VM apply must not use --filename; calls=%s", calls)
+ }
+ if !strings.Contains(calls, "pvesh set /nodes/") || !strings.Contains(calls, "/qemu/100/config --name=test-vm") {
+ t.Fatalf("missing VM config args; calls=%s", calls)
+ }
+}
+
+func TestApplyVMConfigs_CreatesMissingGuestBeforeSet(t *testing.T) {
+ orig := restoreCmd
+ t.Cleanup(func() { restoreCmd = orig })
+
+ node := localNodeName()
+ getCall := fmt.Sprintf("pvesh get /nodes/%s/qemu/100/config", node)
+ fake := &FakeCommandRunner{
+ Outputs: map[string][]byte{},
+ Errors: map[string]error{
+ getCall: fmt.Errorf("not found"),
+ },
+ }
+ restoreCmd = fake
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "100.conf")
+ if err := os.WriteFile(configPath, []byte("name: test-vm"), 0o644); err != nil {
+ t.Fatalf("write config: %v", err)
+ }
+
+ entries := []vmEntry{{VMID: "100", Kind: "qemu", Path: configPath}}
+ logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
+ applied, failed := applyVMConfigs(context.Background(), entries, logger)
+
+ if applied != 1 || failed != 0 {
+ t.Fatalf("expected (1,0), got (%d,%d)", applied, failed)
+ }
+ calls := strings.Join(fake.CallsList(), "\n")
+ if !strings.Contains(calls, fmt.Sprintf("pvesh create /nodes/%s/qemu --vmid=100", node)) {
+ t.Fatalf("missing create call; calls=%s", calls)
+ }
+ if !strings.Contains(calls, fmt.Sprintf("pvesh set /nodes/%s/qemu/100/config --name=test-vm", node)) {
+ t.Fatalf("missing set call; calls=%s", calls)
+ }
}
// --------------------------------------------------------------------------
@@ -1906,7 +1969,12 @@ func TestExtractArchiveNative_OpenError(t *testing.T) {
restoreFS = osFS{}
logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
- err := extractArchiveNative(context.Background(), "/nonexistent/archive.tar", "/tmp", logger, nil, RestoreModeFull, nil, "", nil)
+ err := extractArchiveNative(context.Background(), restoreArchiveOptions{
+ archivePath: "/nonexistent/archive.tar",
+ destRoot: "/tmp",
+ logger: logger,
+ mode: RestoreModeFull,
+ })
if err == nil || !strings.Contains(err.Error(), "open archive") {
t.Fatalf("expected open error, got: %v", err)
}
diff --git a/internal/orchestrator/restore_filesystem_test.go b/internal/orchestrator/restore_filesystem_test.go
index 97d8a448..b627a727 100644
--- a/internal/orchestrator/restore_filesystem_test.go
+++ b/internal/orchestrator/restore_filesystem_test.go
@@ -276,7 +276,13 @@ func TestExtractArchiveNative_SkipFnSkipsFstab(t *testing.T) {
return name == "etc/fstab"
}
- if err := extractArchiveNative(context.Background(), archivePath, destRoot, newTestLogger(), nil, RestoreModeFull, nil, "", skipFn); err != nil {
+ if err := extractArchiveNative(context.Background(), restoreArchiveOptions{
+ archivePath: archivePath,
+ destRoot: destRoot,
+ logger: newTestLogger(),
+ mode: RestoreModeFull,
+ skipFn: skipFn,
+ }); err != nil {
t.Fatalf("extractArchiveNative error: %v", err)
}
diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go
index 50e27a6a..0c2c6899 100644
--- a/internal/orchestrator/restore_firewall.go
+++ b/internal/orchestrator/restore_firewall.go
@@ -476,8 +476,8 @@ func armFirewallRollback(ctx context.Context, logger *logging.Logger, backupPath
}
if handle.unitName == "" {
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", timeoutSeconds, handle.scriptPath)
- if output, err := restoreCmd.Run(ctx, "sh", "-c", cmd); err != nil {
+ output, err := runBackgroundRollbackTimer(ctx, timeoutSeconds, handle.scriptPath)
+ if err != nil {
logger.Debug("Background rollback output: %s", strings.TrimSpace(string(output)))
return nil, fmt.Errorf("failed to arm rollback timer: %w", err)
}
diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go
index 11f47ecb..beef3796 100644
--- a/internal/orchestrator/restore_firewall_additional_test.go
+++ b/internal/orchestrator/restore_firewall_additional_test.go
@@ -704,8 +704,7 @@ func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) {
t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName)
}
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath)
- wantBackground := "sh -c " + cmd
+ wantBackground := backgroundRollbackCallKey(2, scriptPath)
calls := fakeCmd.CallsList()
if len(calls) != 2 || calls[1] != wantBackground {
t.Fatalf("unexpected calls: %#v", calls)
@@ -723,8 +722,7 @@ func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) {
timestamp := fakeTime.Current.Format("20060102_150405")
scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp))
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath)
- backgroundKey := "sh -c " + cmd
+ backgroundKey := backgroundRollbackCallKey(1, scriptPath)
fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom")
if _, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil {
@@ -1563,7 +1561,7 @@ func TestArmFirewallRollback_DefaultWorkDirAndMinTimeout(t *testing.T) {
if len(calls) != 1 {
t.Fatalf("unexpected calls: %#v", calls)
}
- if !strings.Contains(calls[0], "sleep 1; /bin/sh") {
+ if calls[0] != backgroundRollbackCallKey(1, handle.scriptPath) {
t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0])
}
}
diff --git a/internal/orchestrator/restore_ha.go b/internal/orchestrator/restore_ha.go
index 0111a16d..6765e7c9 100644
--- a/internal/orchestrator/restore_ha.go
+++ b/internal/orchestrator/restore_ha.go
@@ -404,8 +404,8 @@ func armHARollback(ctx context.Context, logger *logging.Logger, backupPath strin
}
if handle.unitName == "" {
- cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", timeoutSeconds, handle.scriptPath)
- if output, err := restoreCmd.Run(ctx, "sh", "-c", cmd); err != nil {
+ output, err := runBackgroundRollbackTimer(ctx, timeoutSeconds, handle.scriptPath)
+ if err != nil {
logger.Debug("Background rollback output: %s", strings.TrimSpace(string(output)))
return nil, fmt.Errorf("failed to schedule rollback timer: %w", err)
}
diff --git a/internal/orchestrator/restore_ha_additional_test.go b/internal/orchestrator/restore_ha_additional_test.go
index 4bee40e4..0d1554b7 100644
--- a/internal/orchestrator/restore_ha_additional_test.go
+++ b/internal/orchestrator/restore_ha_additional_test.go
@@ -275,7 +275,7 @@ func TestArmHARollback_CoversSchedulingPaths(t *testing.T) {
t.Fatalf("expected backup path in script, got:\n%s", string(script))
}
- wantBackground := "sh -c nohup sh -c 'sleep 2; /bin/sh " + handle.scriptPath + "' >/dev/null 2>&1 &"
+ wantBackground := backgroundRollbackCallKey(2, handle.scriptPath)
calls := env.cmd.CallsList()
if len(calls) != 1 || calls[0] != wantBackground {
t.Fatalf("unexpected calls: %#v", calls)
@@ -290,7 +290,7 @@ func TestArmHARollback_CoversSchedulingPaths(t *testing.T) {
if err != nil {
t.Fatalf("armHARollback error: %v", err)
}
- wantBackground := "sh -c nohup sh -c 'sleep 1; /bin/sh " + handle.scriptPath + "' >/dev/null 2>&1 &"
+ wantBackground := backgroundRollbackCallKey(1, handle.scriptPath)
calls := env.cmd.CallsList()
if len(calls) != 1 || calls[0] != wantBackground {
t.Fatalf("unexpected calls: %#v", calls)
@@ -319,7 +319,7 @@ func TestArmHARollback_CoversSchedulingPaths(t *testing.T) {
t.Fatalf("expected unitName to be cleared after systemd-run failure, got %q", handle.unitName)
}
- wantBackground := "sh -c nohup sh -c 'sleep 2; /bin/sh " + scriptPath + "' >/dev/null 2>&1 &"
+ wantBackground := backgroundRollbackCallKey(2, scriptPath)
calls := env.cmd.CallsList()
if len(calls) != 2 || calls[0] != systemdKey || calls[1] != wantBackground {
t.Fatalf("unexpected calls: %#v", calls)
@@ -332,7 +332,7 @@ func TestArmHARollback_CoversSchedulingPaths(t *testing.T) {
timestamp := env.fakeTime.Current.Format("20060102_150405")
scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("ha_rollback_%s.sh", timestamp))
- backgroundKey := "sh -c nohup sh -c 'sleep 1; /bin/sh " + scriptPath + "' >/dev/null 2>&1 &"
+ backgroundKey := backgroundRollbackCallKey(1, scriptPath)
env.cmd.Errors = map[string]error{
backgroundKey: fmt.Errorf("boom"),
}
diff --git a/internal/orchestrator/restore_notifications.go b/internal/orchestrator/restore_notifications.go
index df841569..8c6305d7 100644
--- a/internal/orchestrator/restore_notifications.go
+++ b/internal/orchestrator/restore_notifications.go
@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "regexp"
"strings"
"github.com/tis24dev/proxsave/internal/logging"
@@ -23,6 +24,8 @@ type proxmoxNotificationSection struct {
RedactFlags []string
}
+var sectionHeaderTypePattern = regexp.MustCompile(`^[A-Za-z0-9_-]+$`)
+
func maybeApplyNotificationsFromStage(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot string, dryRun bool) (err error) {
if plan == nil {
return nil
@@ -401,7 +404,7 @@ func parseProxmoxNotificationSections(content string) ([]proxmoxNotificationSect
return out, nil
}
-func parseProxmoxNotificationHeader(line string) (typ, name string, ok bool) {
+func parseSectionHeader(line string) (typ, name string, ok bool) {
idx := strings.Index(line, ":")
if idx <= 0 {
return "", "", false
@@ -411,19 +414,16 @@ func parseProxmoxNotificationHeader(line string) (typ, name string, ok bool) {
if typ == "" || name == "" {
return "", "", false
}
- for _, r := range typ {
- switch {
- case r >= 'a' && r <= 'z':
- case r >= 'A' && r <= 'Z':
- case r >= '0' && r <= '9':
- case r == '-' || r == '_':
- default:
- return "", "", false
- }
+ if !sectionHeaderTypePattern.MatchString(typ) {
+ return "", "", false
}
return typ, name, true
}
+func parseProxmoxNotificationHeader(line string) (typ, name string, ok bool) {
+ return parseSectionHeader(line)
+}
+
func parseProxmoxNotificationKV(line string) (key, value string) {
fields := strings.Fields(line)
if len(fields) == 0 {
diff --git a/internal/orchestrator/restore_services.go b/internal/orchestrator/restore_services.go
new file mode 100644
index 00000000..1590a325
--- /dev/null
+++ b/internal/orchestrator/restore_services.go
@@ -0,0 +1,473 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+var (
+ serviceStopTimeout = 45 * time.Second
+ serviceStopNoBlockTimeout = 15 * time.Second
+ serviceStartTimeout = 30 * time.Second
+ serviceVerifyTimeout = 30 * time.Second
+ serviceStatusCheckTimeout = 5 * time.Second
+ servicePollInterval = 500 * time.Millisecond
+ serviceRetryDelay = 500 * time.Millisecond
+)
+
+type restoreCommandResult struct {
+ out []byte
+ err error
+}
+
+type restoreCommandProgress struct {
+ enabled bool
+ service string
+ action string
+ deadline time.Time
+}
+
+type serviceInactiveWaiter struct {
+ ctx context.Context
+ logger *logging.Logger
+ service string
+ timeout time.Duration
+ deadline time.Time
+ progressEnabled bool
+ ticker *time.Ticker
+}
+
+func stopPVEClusterServices(ctx context.Context, logger *logging.Logger) error {
+ services := []string{"pve-cluster", "pvedaemon", "pveproxy", "pvestatd"}
+ for _, service := range services {
+ if err := stopServiceWithRetries(ctx, logger, service); err != nil {
+ return fmt.Errorf("failed to stop PVE services (%s): %w", service, err)
+ }
+ }
+ return nil
+}
+
+func startPVEClusterServices(ctx context.Context, logger *logging.Logger) error {
+ services := []string{"pve-cluster", "pvedaemon", "pveproxy", "pvestatd"}
+ for _, service := range services {
+ if err := startServiceWithRetries(ctx, logger, service); err != nil {
+ return fmt.Errorf("failed to start PVE services (%s): %w", service, err)
+ }
+ }
+ return nil
+}
+
+func stopPBSServices(ctx context.Context, logger *logging.Logger) error {
+ if _, err := restoreCmd.Run(ctx, "which", "systemctl"); err != nil {
+ return fmt.Errorf("systemctl not available: %w", err)
+ }
+ services := []string{"proxmox-backup-proxy", "proxmox-backup"}
+ var failures []string
+ for _, service := range services {
+ if err := stopServiceWithRetries(ctx, logger, service); err != nil {
+ failures = append(failures, fmt.Sprintf("%s: %v", service, err))
+ }
+ }
+ if len(failures) > 0 {
+ return errors.New(strings.Join(failures, "; "))
+ }
+ return nil
+}
+
+func startPBSServices(ctx context.Context, logger *logging.Logger) error {
+ if _, err := restoreCmd.Run(ctx, "which", "systemctl"); err != nil {
+ return fmt.Errorf("systemctl not available: %w", err)
+ }
+ services := []string{"proxmox-backup", "proxmox-backup-proxy"}
+ var failures []string
+ for _, service := range services {
+ if err := startServiceWithRetries(ctx, logger, service); err != nil {
+ failures = append(failures, fmt.Sprintf("%s: %v", service, err))
+ }
+ }
+ if len(failures) > 0 {
+ return errors.New(strings.Join(failures, "; "))
+ }
+ return nil
+}
+
+func unmountEtcPVE(ctx context.Context, logger *logging.Logger) error {
+ output, err := restoreCmd.Run(ctx, "umount", "/etc/pve")
+ msg := strings.TrimSpace(string(output))
+ if err != nil {
+ if strings.Contains(msg, "not mounted") {
+ logger.Info("Skipping umount /etc/pve (already unmounted)")
+ return nil
+ }
+ if msg != "" {
+ return fmt.Errorf("umount /etc/pve failed: %s", msg)
+ }
+ return fmt.Errorf("umount /etc/pve failed: %w", err)
+ }
+ if msg != "" {
+ logger.Debug("umount /etc/pve output: %s", msg)
+ }
+ return nil
+}
+
+func runCommandWithTimeout(ctx context.Context, logger *logging.Logger, timeout time.Duration, name string, args ...string) error {
+ return execCommand(ctx, logger, timeout, name, args...)
+}
+
+func execCommand(ctx context.Context, logger *logging.Logger, timeout time.Duration, name string, args ...string) error {
+ execCtx, cancel := commandContextWithTimeout(ctx, timeout)
+ defer cancel()
+ output, err := restoreCmd.Run(execCtx, name, args...)
+ msg := strings.TrimSpace(string(output))
+ if err != nil {
+ return restoreCommandError(execCtx, timeout, name, args, msg, err)
+ }
+ logRestoreCommandOutput(logger, name, args, msg)
+ return nil
+}
+
+func commandContextWithTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
+ if timeout <= 0 {
+ return ctx, func() {}
+ }
+ return context.WithTimeout(ctx, timeout)
+}
+
+func restoreCommandError(execCtx context.Context, timeout time.Duration, name string, args []string, msg string, err error) error {
+ command := fmt.Sprintf("%s %s", name, strings.Join(args, " "))
+ if timeout > 0 && (errors.Is(execCtx.Err(), context.DeadlineExceeded) || errors.Is(err, context.DeadlineExceeded)) {
+ return fmt.Errorf("%s timed out after %s", command, timeout)
+ }
+ if msg != "" {
+ return fmt.Errorf("%s failed: %s", command, msg)
+ }
+ return fmt.Errorf("%s failed: %w", command, err)
+}
+
+func logRestoreCommandOutput(logger *logging.Logger, name string, args []string, msg string) {
+ if msg != "" && logger != nil {
+ logger.Debug("%s %s: %s", name, strings.Join(args, " "), msg)
+ }
+}
+
+func stopServiceWithRetries(ctx context.Context, logger *logging.Logger, service string) error {
+ attempts := []struct {
+ description string
+ args []string
+ timeout time.Duration
+ }{
+ {"stop (no-block)", []string{"stop", "--no-block", service}, serviceStopNoBlockTimeout},
+ {"stop (blocking)", []string{"stop", service}, serviceStopTimeout},
+ {"aggressive stop", []string{"kill", "--signal=SIGTERM", "--kill-who=all", service}, serviceStopTimeout},
+ {"force kill", []string{"kill", "--signal=SIGKILL", "--kill-who=all", service}, serviceStopTimeout},
+ }
+
+ var lastErr error
+ for i, attempt := range attempts {
+ if i > 0 {
+ if err := sleepWithContext(ctx, serviceRetryDelay); err != nil {
+ return err
+ }
+ }
+
+ if logger != nil {
+ logger.Debug("Attempting %s for %s (%d/%d)", attempt.description, service, i+1, len(attempts))
+ }
+
+ if err := runCommandWithTimeoutCountdown(ctx, logger, attempt.timeout, service, attempt.description, "systemctl", attempt.args...); err != nil {
+ lastErr = err
+ continue
+ }
+ if err := waitForServiceInactive(ctx, logger, service, serviceVerifyTimeout); err != nil {
+ lastErr = err
+ continue
+ }
+ resetFailedService(ctx, logger, service)
+ return nil
+ }
+
+ if lastErr == nil {
+ lastErr = fmt.Errorf("unable to stop %s", service)
+ }
+ return lastErr
+}
+
+func startServiceWithRetries(ctx context.Context, logger *logging.Logger, service string) error {
+ attempts := []struct {
+ description string
+ args []string
+ }{
+ {"start", []string{"start", service}},
+ {"retry start", []string{"start", service}},
+ {"aggressive restart", []string{"restart", service}},
+ }
+
+ var lastErr error
+ for i, attempt := range attempts {
+ if i > 0 {
+ if err := sleepWithContext(ctx, serviceRetryDelay); err != nil {
+ return err
+ }
+ }
+
+ if logger != nil {
+ logger.Debug("Attempting %s for %s (%d/%d)", attempt.description, service, i+1, len(attempts))
+ }
+
+ if err := runCommandWithTimeout(ctx, logger, serviceStartTimeout, "systemctl", attempt.args...); err != nil {
+ lastErr = err
+ continue
+ }
+ return nil
+ }
+
+ if lastErr == nil {
+ lastErr = fmt.Errorf("unable to start %s", service)
+ }
+ return lastErr
+}
+
+func runCommandWithTimeoutCountdown(ctx context.Context, logger *logging.Logger, timeout time.Duration, service, action, name string, args ...string) error {
+ if timeout <= 0 {
+ return execCommand(ctx, logger, timeout, name, args...)
+ }
+
+ execCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ resultCh := startRestoreCommand(execCtx, name, args...)
+ progress := newRestoreCommandProgress(service, action, timeout)
+ ticker := time.NewTicker(1 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case r := <-resultCh:
+ progress.clear()
+ return finishRestoreCommandResult(execCtx, logger, timeout, name, args, r)
+ case <-ticker.C:
+ progress.write(time.Until(progress.deadline))
+ case <-execCtx.Done():
+ return finishRestoreCommandTimeout(logger, name, args, timeout, resultCh, progress)
+ }
+ }
+}
+
+func startRestoreCommand(ctx context.Context, name string, args ...string) <-chan restoreCommandResult {
+ resultCh := make(chan restoreCommandResult, 1)
+ go func() {
+ out, err := restoreCmd.Run(ctx, name, args...)
+ resultCh <- restoreCommandResult{out: out, err: err}
+ }()
+ return resultCh
+}
+
+func newRestoreCommandProgress(service, action string, timeout time.Duration) restoreCommandProgress {
+ return restoreCommandProgress{
+ enabled: isTerminal(int(os.Stderr.Fd())),
+ service: service,
+ action: action,
+ deadline: time.Now().Add(timeout),
+ }
+}
+
+func (progress restoreCommandProgress) write(left time.Duration) {
+ if !progress.enabled {
+ return
+ }
+ seconds := int(left.Round(time.Second).Seconds())
+ if seconds < 0 {
+ seconds = 0
+ }
+ fmt.Fprintf(os.Stderr, "\rStopping %s: %s (attempt timeout in %ds)...", progress.service, progress.action, seconds)
+}
+
+func (progress restoreCommandProgress) clear() {
+ if !progress.enabled {
+ return
+ }
+ fmt.Fprint(os.Stderr, "\r")
+ fmt.Fprintln(os.Stderr, strings.Repeat(" ", 80))
+ fmt.Fprint(os.Stderr, "\r")
+}
+
+func (progress restoreCommandProgress) newline() {
+ if progress.enabled {
+ fmt.Fprintln(os.Stderr)
+ }
+}
+
+func finishRestoreCommandResult(execCtx context.Context, logger *logging.Logger, timeout time.Duration, name string, args []string, result restoreCommandResult) error {
+ msg := strings.TrimSpace(string(result.out))
+ if result.err != nil {
+ return restoreCommandError(execCtx, timeout, name, args, msg, result.err)
+ }
+ logRestoreCommandOutput(logger, name, args, msg)
+ return nil
+}
+
+func finishRestoreCommandTimeout(logger *logging.Logger, name string, args []string, timeout time.Duration, resultCh <-chan restoreCommandResult, progress restoreCommandProgress) error {
+ progress.write(0)
+ progress.newline()
+ select {
+ case result := <-resultCh:
+ logRestoreCommandOutput(logger, name, args, strings.TrimSpace(string(result.out)))
+ default:
+ }
+ return fmt.Errorf("%s %s timed out after %s", name, strings.Join(args, " "), timeout)
+}
+
+func waitForServiceInactive(ctx context.Context, logger *logging.Logger, service string, timeout time.Duration) error {
+ if timeout <= 0 {
+ return nil
+ }
+ waiter := newServiceInactiveWaiter(ctx, logger, service, timeout)
+ defer waiter.ticker.Stop()
+ for {
+ remaining := time.Until(waiter.deadline)
+ if err := waiter.ensureTimeRemaining(remaining); err != nil {
+ return err
+ }
+ active, err := isServiceActive(ctx, service, minDuration(remaining, serviceStatusCheckTimeout))
+ if err != nil {
+ return err
+ }
+ if !active {
+ waiter.logStopped()
+ return nil
+ }
+ if err := waiter.sleepOrCancel(remaining); err != nil {
+ return err
+ }
+ waiter.writeProgress(remaining)
+ }
+}
+
+func newServiceInactiveWaiter(ctx context.Context, logger *logging.Logger, service string, timeout time.Duration) serviceInactiveWaiter {
+ return serviceInactiveWaiter{
+ ctx: ctx,
+ logger: logger,
+ service: service,
+ timeout: timeout,
+ deadline: time.Now().Add(timeout),
+ progressEnabled: isTerminal(int(os.Stderr.Fd())),
+ ticker: time.NewTicker(1 * time.Second),
+ }
+}
+
+func (waiter serviceInactiveWaiter) ensureTimeRemaining(remaining time.Duration) error {
+ if remaining > 0 {
+ return nil
+ }
+ waiter.writeNewline()
+ return fmt.Errorf("%s still active after %s", waiter.service, waiter.timeout)
+}
+
+func (waiter serviceInactiveWaiter) logStopped() {
+ if waiter.logger != nil {
+ waiter.logger.Debug("%s stopped successfully", waiter.service)
+ }
+}
+
+func (waiter serviceInactiveWaiter) sleepOrCancel(remaining time.Duration) error {
+ timer := time.NewTimer(minDuration(remaining, servicePollInterval))
+ defer timer.Stop()
+ select {
+ case <-waiter.ctx.Done():
+ waiter.writeNewline()
+ return waiter.ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}
+
+func (waiter serviceInactiveWaiter) writeProgress(remaining time.Duration) {
+ select {
+ case <-waiter.ticker.C:
+ if waiter.progressEnabled {
+ seconds := int(remaining.Round(time.Second).Seconds())
+ if seconds < 0 {
+ seconds = 0
+ }
+ fmt.Fprintf(os.Stderr, "\rWaiting for %s to stop (%ds remaining)...", waiter.service, seconds)
+ }
+ default:
+ }
+}
+
+func (waiter serviceInactiveWaiter) writeNewline() {
+ if waiter.progressEnabled {
+ fmt.Fprintln(os.Stderr)
+ }
+}
+
+func resetFailedService(ctx context.Context, logger *logging.Logger, service string) {
+ resetCtx, cancel := context.WithTimeout(ctx, serviceStatusCheckTimeout)
+ defer cancel()
+
+ if _, err := restoreCmd.Run(resetCtx, "systemctl", "reset-failed", service); err != nil {
+ if logger != nil {
+ logger.Debug("systemctl reset-failed %s ignored: %v", service, err)
+ }
+ }
+}
+
+func isServiceActive(ctx context.Context, service string, timeout time.Duration) (bool, error) {
+ if timeout <= 0 {
+ timeout = serviceStatusCheckTimeout
+ }
+ checkCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ output, err := restoreCmd.Run(checkCtx, "systemctl", "is-active", service)
+ msg := strings.TrimSpace(string(output))
+ if err == nil {
+ return true, nil
+ }
+ if errors.Is(checkCtx.Err(), context.DeadlineExceeded) || errors.Is(err, context.DeadlineExceeded) {
+ return false, fmt.Errorf("systemctl is-active %s timed out after %s", service, timeout)
+ }
+ if msg == "" {
+ msg = err.Error()
+ }
+ return parseSystemctlActiveState(service, msg)
+}
+
+func parseSystemctlActiveState(service, msg string) (bool, error) {
+ lower := strings.ToLower(msg)
+ if strings.Contains(lower, "deactivating") || strings.Contains(lower, "activating") {
+ return true, nil
+ }
+ if strings.Contains(lower, "inactive") || strings.Contains(lower, "failed") || strings.Contains(lower, "dead") {
+ return false, nil
+ }
+ return false, fmt.Errorf("systemctl is-active %s failed: %s", service, msg)
+}
+
+func minDuration(a, b time.Duration) time.Duration {
+ if a < b {
+ return a
+ }
+ return b
+}
+
+func sleepWithContext(ctx context.Context, d time.Duration) error {
+ if d <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}
diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go
index 2a9c37a5..622057ca 100644
--- a/internal/orchestrator/restore_test.go
+++ b/internal/orchestrator/restore_test.go
@@ -63,6 +63,29 @@ func TestExtractTarEntry_BlocksPathTraversal(t *testing.T) {
}
}
+func TestShouldSkipRestoreEntryTargetEtcPVEBoundary(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ header := &tar.Header{Name: "etc/pveuser.conf"}
+
+ for _, target := range []string{"/etc/pve", "/etc/pve/local.cfg"} {
+ skip, err := shouldSkipRestoreEntryTarget(header, target, string(os.PathSeparator), logger)
+ if err != nil {
+ t.Fatalf("shouldSkipRestoreEntryTarget(%q) error: %v", target, err)
+ }
+ if !skip {
+ t.Fatalf("expected %q to be skipped", target)
+ }
+ }
+
+ skip, err := shouldSkipRestoreEntryTarget(header, "/etc/pveuser.conf", string(os.PathSeparator), logger)
+ if err != nil {
+ t.Fatalf("shouldSkipRestoreEntryTarget false-positive path error: %v", err)
+ }
+ if skip {
+ t.Fatalf("did not expect /etc/pveuser.conf to match /etc/pve guard")
+ }
+}
+
func TestExtractPlainArchive_WithFakeFS_RestoresFiles(t *testing.T) {
origRestoreFS := restoreFS
fakeFS := NewFakeFS()
@@ -748,6 +771,18 @@ func TestExtractDirectory_Success(t *testing.T) {
// extractHardlink tests
// --------------------------------------------------------------------------
+type recordingLinkFS struct {
+ *FakeFS
+ oldname string
+ newname string
+}
+
+func (f *recordingLinkFS) Link(oldname, newname string) error {
+ f.oldname = oldname
+ f.newname = newname
+ return f.FakeFS.Link(oldname, newname)
+}
+
func TestExtractHardlink_AbsoluteTargetRejected(t *testing.T) {
header := &tar.Header{
Name: "link",
@@ -774,6 +809,86 @@ func TestExtractHardlink_EscapesRoot(t *testing.T) {
}
}
+func TestExtractHardlink_UsesResolvedTargetPath(t *testing.T) {
+ orig := restoreFS
+ fakeFS := NewFakeFS()
+ recordingFS := &recordingLinkFS{FakeFS: fakeFS}
+ restoreFS = recordingFS
+ t.Cleanup(func() {
+ restoreFS = orig
+ _ = fakeFS.Cleanup()
+ })
+
+ destRoot := fakeFS.Root
+ realDir := filepath.Join(destRoot, "real")
+ if err := fakeFS.MkdirAll(realDir, 0o755); err != nil {
+ t.Fatalf("mkdir real dir: %v", err)
+ }
+ realTarget := filepath.Join(realDir, "target.txt")
+ if err := fakeFS.WriteFile(realTarget, []byte("test"), 0o644); err != nil {
+ t.Fatalf("write real target: %v", err)
+ }
+ if err := os.Symlink("real", filepath.Join(destRoot, "alias")); err != nil {
+ t.Fatalf("create alias symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "hardlink.txt",
+ Linkname: filepath.Join("alias", "target.txt"),
+ Typeflag: tar.TypeLink,
+ }
+ linkFile := filepath.Join(destRoot, header.Name)
+
+ if err := extractHardlink(linkFile, header, destRoot); err != nil {
+ t.Fatalf("extractHardlink failed: %v", err)
+ }
+ if recordingFS.oldname != realTarget {
+ t.Fatalf("hardlink source = %q, want resolved target %q", recordingFS.oldname, realTarget)
+ }
+ if recordingFS.newname != linkFile {
+ t.Fatalf("hardlink destination = %q, want %q", recordingFS.newname, linkFile)
+ }
+
+ realInfo, err := os.Stat(realTarget)
+ if err != nil {
+ t.Fatalf("stat real target: %v", err)
+ }
+ linkInfo, err := os.Stat(linkFile)
+ if err != nil {
+ t.Fatalf("stat hardlink: %v", err)
+ }
+ if !os.SameFile(realInfo, linkInfo) {
+ t.Fatalf("hardlink does not point to resolved target")
+ }
+}
+
+func TestExtractHardlink_RejectsSymlinkEscapeTarget(t *testing.T) {
+ orig := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = orig })
+
+ destRoot := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil {
+ t.Fatalf("create escape symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "link.txt",
+ Linkname: filepath.Join("escape-link", "target.txt"),
+ Typeflag: tar.TypeLink,
+ }
+ linkFile := filepath.Join(destRoot, header.Name)
+
+ err := extractHardlink(linkFile, header, destRoot)
+ if err == nil || !strings.Contains(err.Error(), "escapes root") {
+ t.Fatalf("expected escapes root error, got: %v", err)
+ }
+ if _, err := os.Lstat(linkFile); !os.IsNotExist(err) {
+ t.Fatalf("hardlink should not be created, got err=%v", err)
+ }
+}
+
func TestExtractHardlink_Success(t *testing.T) {
orig := restoreFS
t.Cleanup(func() { restoreFS = orig })
@@ -1065,6 +1180,9 @@ nfs: nfs-backup
if blocks[0].ID != "local" || blocks[1].ID != "nfs-backup" {
t.Fatalf("unexpected block IDs: %v, %v", blocks[0].ID, blocks[1].ID)
}
+ if blocks[0].Type != "dir" || blocks[1].Type != "nfs" {
+ t.Fatalf("unexpected block types: %v, %v", blocks[0].Type, blocks[1].Type)
+ }
}
func TestParseStorageBlocks_LegacyStoragePrefix(t *testing.T) {
@@ -1092,6 +1210,9 @@ func TestParseStorageBlocks_LegacyStoragePrefix(t *testing.T) {
if blocks[0].ID != "local" {
t.Fatalf("unexpected block ID: %v", blocks[0].ID)
}
+ if blocks[0].Type != "" {
+ t.Fatalf("legacy storage block type = %q, want empty because type is in entries", blocks[0].Type)
+ }
}
// --------------------------------------------------------------------------
@@ -1232,17 +1353,45 @@ func TestReadVMName_FileNotFound(t *testing.T) {
}
// --------------------------------------------------------------------------
-// detectNodeForVM tests
+// localNodeName tests
// --------------------------------------------------------------------------
-func TestDetectNodeForVM_ReturnsHostname(t *testing.T) {
- node := detectNodeForVM()
- // detectNodeForVM returns the current hostname, not the node from path
+func TestLocalNodeName_ReturnsHostname(t *testing.T) {
+ node := localNodeName()
if node == "" {
t.Fatalf("expected non-empty node from hostname")
}
}
+func TestPveshArgsFromColonConfigLinesStopsAtSectionHeader(t *testing.T) {
+ args := pveshArgsFromColonConfigLines([]string{
+ "name: vm100",
+ "memory: 2048",
+ "[snapshot]",
+ "parent: base",
+ "snaptime: 123",
+ })
+
+ got := strings.Join(args, " ")
+ if !strings.Contains(got, "--name=vm100") || !strings.Contains(got, "--memory=2048") {
+ t.Fatalf("expected pre-section args, got %v", args)
+ }
+ if strings.Contains(got, "parent") || strings.Contains(got, "snaptime") {
+ t.Fatalf("snapshot section args must be ignored, got %v", args)
+ }
+}
+
+func TestPveshCreateGuestArgsIncludesLXCOstemplate(t *testing.T) {
+ args, err := pveshCreateGuestArgs("node1", vmEntry{VMID: "101", Kind: "lxc"}, []string{"--hostname=ct101", "--ostemplate=local:vztmpl/debian.tar.zst"})
+ if err != nil {
+ t.Fatalf("pveshCreateGuestArgs error = %v", err)
+ }
+ got := strings.Join(args, " ")
+ if !strings.Contains(got, "create /nodes/node1/lxc --vmid=101") || !strings.Contains(got, "--ostemplate=local:vztmpl/debian.tar.zst") {
+ t.Fatalf("unexpected create args: %v", args)
+ }
+}
+
// --------------------------------------------------------------------------
// detectConfiguredZFSPools tests
// --------------------------------------------------------------------------
diff --git a/internal/orchestrator/restore_workflow_abort_test.go b/internal/orchestrator/restore_workflow_abort_test.go
index 032330a2..27215e07 100644
--- a/internal/orchestrator/restore_workflow_abort_test.go
+++ b/internal/orchestrator/restore_workflow_abort_test.go
@@ -14,6 +14,18 @@ import (
"github.com/tis24dev/proxsave/internal/types"
)
+func TestRunRestoreWorkflowWithUIClearsStaleAbortInfoBeforeValidation(t *testing.T) {
+ lastRestoreAbortInfo = &RestoreAbortInfo{NetworkRollbackArmed: true}
+ t.Cleanup(ClearRestoreAbortInfo)
+
+ if err := runRestoreWorkflowWithUI(context.Background(), nil, nil, "vtest", nil); err == nil {
+ t.Fatalf("expected configuration error")
+ }
+ if got := GetLastRestoreAbortInfo(); got != nil {
+ t.Fatalf("expected stale abort info to be cleared, got %#v", got)
+ }
+}
+
func TestRunRestoreWorkflow_FstabPromptInputAborted_AbortsWorkflow(t *testing.T) {
origRestoreFS := restoreFS
origRestoreCmd := restoreCmd
diff --git a/internal/orchestrator/restore_workflow_ui.go b/internal/orchestrator/restore_workflow_ui.go
index c5dd7c27..0fd91a27 100644
--- a/internal/orchestrator/restore_workflow_ui.go
+++ b/internal/orchestrator/restore_workflow_ui.go
@@ -1,3 +1,4 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
package orchestrator
import (
@@ -46,6 +47,7 @@ func prepareRestoreBundleWithUI(ctx context.Context, cfg *config.Config, logger
}
func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (err error) {
+ ClearRestoreAbortInfo()
if cfg == nil {
return fmt.Errorf("configuration not available")
}
@@ -453,7 +455,13 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
"./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json",
},
}}
- if err := extractArchiveNative(ctx, prepared.ArchivePath, fsTempDir, logger, invCategory, RestoreModeCustom, nil, "", nil); err != nil {
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: prepared.ArchivePath,
+ destRoot: fsTempDir,
+ logger: logger,
+ categories: invCategory,
+ mode: RestoreModeCustom,
+ }); err != nil {
logger.Debug("Failed to extract fstab inventory data (continuing): %v", err)
}
@@ -509,7 +517,13 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
"./var/lib/proxsave-info/commands/pve/mapping_dir.json",
},
}}
- if err := extractArchiveNative(ctx, prepared.ArchivePath, exportRoot, logger, safeInvCategory, RestoreModeCustom, nil, "", nil); err != nil {
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: prepared.ArchivePath,
+ destRoot: exportRoot,
+ logger: logger,
+ categories: safeInvCategory,
+ mode: RestoreModeCustom,
+ }); err != nil {
logger.Debug("Failed to extract SAFE apply inventory (continuing): %v", err)
}
@@ -841,7 +855,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
if hasCategoryID(plan.NormalCategories, "zfs") {
logger.Info("")
- if err := checkZFSPoolsAfterRestore(logger); err != nil {
+ if err := checkZFSPoolsAfterRestore(ctx, logger); err != nil {
logger.Warning("ZFS pool check: %v", err)
}
} else {
@@ -978,7 +992,13 @@ func runFullRestoreWithUI(ctx context.Context, ui RestoreWorkflowUI, candidate *
"./etc/fstab",
},
}}
- if err := extractArchiveNative(ctx, prepared.ArchivePath, fsTempDir, logger, fsCategory, RestoreModeCustom, nil, "", nil); err != nil {
+ if err := extractArchiveNative(ctx, restoreArchiveOptions{
+ archivePath: prepared.ArchivePath,
+ destRoot: fsTempDir,
+ logger: logger,
+ categories: fsCategory,
+ mode: RestoreModeCustom,
+ }); err != nil {
logger.Warning("Failed to extract filesystem config for merge: %v", err)
} else {
currentFstab := filepath.Join(destRoot, "etc", "fstab")
diff --git a/internal/orchestrator/restore_zfs.go b/internal/orchestrator/restore_zfs.go
new file mode 100644
index 00000000..e4bfbf12
--- /dev/null
+++ b/internal/orchestrator/restore_zfs.go
@@ -0,0 +1,223 @@
+// Package orchestrator coordinates backup, restore, decrypt, and related workflows.
+package orchestrator
+
+import (
+ "bufio"
+ "context"
+ "path/filepath"
+ "sort"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+var restoreGlob = filepath.Glob
+
+// checkZFSPoolsAfterRestore checks if ZFS pools need to be imported after restore.
+func checkZFSPoolsAfterRestore(ctx context.Context, logger *logging.Logger) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ if _, err := restoreCmd.Run(ctx, "which", "zpool"); err != nil {
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ return ctxErr
+ }
+ // zpool utility not available -> no ZFS tooling installed
+ return nil
+ }
+
+ logger.Info("Checking ZFS pool status...")
+
+ configuredPools := detectConfiguredZFSPools()
+ importablePools, importOutput, importErr := detectImportableZFSPools(ctx)
+ if importErr != nil {
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ return ctxErr
+ }
+ }
+
+ logConfiguredZFSPools(logger, configuredPools)
+ logImportableZFSPools(logger, importablePools, importOutput, importErr)
+
+ if len(importablePools) == 0 {
+ return logNoImportableZFSPools(ctx, logger, configuredPools)
+ }
+
+ logManualZFSImportInstructions(logger, importablePools)
+ return nil
+}
+
+func logConfiguredZFSPools(logger *logging.Logger, configuredPools []string) {
+ if len(configuredPools) == 0 {
+ return
+ }
+ logger.Warning("Found %d ZFS pool(s) configured for automatic import:", len(configuredPools))
+ for _, pool := range configuredPools {
+ logger.Warning(" - %s", pool)
+ }
+ logger.Info("")
+}
+
+func logImportableZFSPools(logger *logging.Logger, importablePools []string, importOutput string, importErr error) {
+ if importErr != nil {
+ logger.Warning("`zpool import` command returned an error: %v", importErr)
+ if strings.TrimSpace(importOutput) != "" {
+ logger.Warning("`zpool import` output:\n%s", importOutput)
+ }
+ return
+ }
+ if len(importablePools) > 0 {
+ logger.Warning("`zpool import` reports pools waiting to be imported:")
+ for _, pool := range importablePools {
+ logger.Warning(" - %s", pool)
+ }
+ logger.Info("")
+ }
+}
+
+func logNoImportableZFSPools(ctx context.Context, logger *logging.Logger, configuredPools []string) error {
+ logger.Info("`zpool import` did not report pools waiting for import.")
+ if len(configuredPools) == 0 {
+ return nil
+ }
+ logger.Info("")
+ for _, pool := range configuredPools {
+ if _, err := restoreCmd.Run(ctx, "zpool", "status", pool); err == nil {
+ logger.Info("Pool %s is already imported (no manual action needed)", pool)
+ } else {
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ return ctxErr
+ }
+ logger.Warning("Systemd expects pool %s, but `zpool import` and `zpool status` did not report it. Check disk visibility and pool status.", pool)
+ }
+ }
+ return nil
+}
+
+func logManualZFSImportInstructions(logger *logging.Logger, importablePools []string) {
+ logger.Info("⚠ IMPORTANT: ZFS pools may need manual import after restore!")
+ logger.Info(" Before rebooting, run these commands:")
+ logger.Info(" 1. Check available pools: zpool import")
+ for _, pool := range importablePools {
+ logger.Info(" 2. Import pool manually: zpool import %s", pool)
+ }
+ logger.Info(" 3. Verify pool status: zpool status")
+ logger.Info("")
+ logger.Info(" If pools fail to import, check:")
+ logger.Info(" - journalctl -u zfs-import@.service oppure import@.service")
+ logger.Info(" - zpool import -d /dev/disk/by-id")
+ logger.Info("")
+}
+
+func detectConfiguredZFSPools() []string {
+ pools := make(map[string]struct{})
+ addConfiguredZFSPoolsFromDirs(pools)
+ addConfiguredZFSPoolsFromGlobPatterns(pools)
+ return sortedPoolNames(pools)
+}
+
+func addConfiguredZFSPoolsFromDirs(pools map[string]struct{}) {
+ directories := []string{
+ "/etc/systemd/system/zfs-import.target.wants",
+ "/etc/systemd/system/multi-user.target.wants",
+ }
+
+ for _, dir := range directories {
+ entries, err := restoreFS.ReadDir(dir)
+ if err != nil {
+ continue
+ }
+
+ for _, entry := range entries {
+ if pool := parsePoolNameFromUnit(entry.Name()); pool != "" {
+ pools[pool] = struct{}{}
+ }
+ }
+ }
+}
+
+func addConfiguredZFSPoolsFromGlobPatterns(pools map[string]struct{}) {
+ globPatterns := []string{
+ "/etc/systemd/system/zfs-import@*.service",
+ "/etc/systemd/system/import@*.service",
+ }
+
+ for _, pattern := range globPatterns {
+ matches, err := restoreGlob(pattern)
+ if err != nil {
+ continue
+ }
+ for _, match := range matches {
+ if pool := parsePoolNameFromUnit(filepath.Base(match)); pool != "" {
+ pools[pool] = struct{}{}
+ }
+ }
+ }
+}
+
+func sortedPoolNames(pools map[string]struct{}) []string {
+ var poolNames []string
+ for pool := range pools {
+ poolNames = append(poolNames, pool)
+ }
+ sort.Strings(poolNames)
+ return poolNames
+}
+
+func parsePoolNameFromUnit(unitName string) string {
+ switch {
+ case strings.HasPrefix(unitName, "zfs-import@") && strings.HasSuffix(unitName, ".service"):
+ pool := strings.TrimPrefix(unitName, "zfs-import@")
+ return strings.TrimSuffix(pool, ".service")
+ case strings.HasPrefix(unitName, "import@") && strings.HasSuffix(unitName, ".service"):
+ pool := strings.TrimPrefix(unitName, "import@")
+ return strings.TrimSuffix(pool, ".service")
+ default:
+ return ""
+ }
+}
+
+func detectImportableZFSPools(ctx context.Context) ([]string, string, error) {
+ output, err := restoreCmd.Run(ctx, "zpool", "import")
+ poolNames := parseZpoolImportOutput(string(output))
+ if err != nil {
+ return poolNames, string(output), err
+ }
+ return poolNames, string(output), nil
+}
+
+func parseZpoolImportOutput(output string) []string {
+ var pools []string
+ scanner := bufio.NewScanner(strings.NewReader(output))
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if strings.HasPrefix(strings.ToLower(line), "pool:") {
+ pool := strings.TrimSpace(line[len("pool:"):])
+ if pool != "" {
+ pools = append(pools, pool)
+ }
+ }
+ }
+ return pools
+}
+
+func combinePoolNames(a, b []string) []string {
+ merged := make(map[string]struct{})
+ for _, pool := range a {
+ merged[pool] = struct{}{}
+ }
+ for _, pool := range b {
+ merged[pool] = struct{}{}
+ }
+
+ if len(merged) == 0 {
+ return nil
+ }
+
+ names := make([]string, 0, len(merged))
+ for pool := range merged {
+ names = append(names, pool)
+ }
+ sort.Strings(names)
+ return names
+}
diff --git a/internal/orchestrator/temp_registry_test.go b/internal/orchestrator/temp_registry_test.go
index 071a4bc8..37eafc45 100644
--- a/internal/orchestrator/temp_registry_test.go
+++ b/internal/orchestrator/temp_registry_test.go
@@ -15,8 +15,6 @@ func newTestLogger() *logging.Logger {
}
func TestTempDirRegistryRegisterAndDeregister(t *testing.T) {
- t.Parallel()
-
regPath := filepath.Join(t.TempDir(), "temp-dirs.json")
registry, err := NewTempDirRegistry(newTestLogger(), regPath)
if err != nil {
@@ -54,8 +52,6 @@ func TestTempDirRegistryRegisterAndDeregister(t *testing.T) {
}
func TestTempDirRegistryCleanupOrphaned(t *testing.T) {
- t.Parallel()
-
regPath := filepath.Join(t.TempDir(), "temp-dirs.json")
registry, err := NewTempDirRegistry(newTestLogger(), regPath)
if err != nil {
diff --git a/internal/orchestrator/tui_simulation_test.go b/internal/orchestrator/tui_simulation_test.go
index b846286d..20d2cfdd 100644
--- a/internal/orchestrator/tui_simulation_test.go
+++ b/internal/orchestrator/tui_simulation_test.go
@@ -13,7 +13,10 @@ import (
"github.com/tis24dev/proxsave/internal/tui"
)
-const simAppInitialDrawTimeout = 2 * time.Second
+const (
+ simAppInitialDrawTimeout = 2 * time.Second
+ simAppCompletionTimeout = 10 * time.Second
+)
type simKey struct {
Key tcell.Key
@@ -35,9 +38,23 @@ func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} {
done := make(chan struct{})
var injectOnce sync.Once
var injectWG sync.WaitGroup
+ var appMu sync.RWMutex
+ var currentApp *tui.App
+
+ stopCurrentApp := func() {
+ appMu.RLock()
+ app := currentApp
+ appMu.RUnlock()
+ if app != nil {
+ app.Stop()
+ }
+ }
newTUIApp = func() *tui.App {
app := tui.NewApp()
+ appMu.Lock()
+ currentApp = app
+ appMu.Unlock()
app.SetScreen(screen)
readyCh := make(chan struct{})
var readyOnce sync.Once
@@ -68,6 +85,8 @@ func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} {
case <-done:
return
case <-timer.C:
+ t.Errorf("TUI simulation did not render its initial draw within %s", simAppInitialDrawTimeout)
+ stopCurrentApp()
return
}
@@ -83,12 +102,27 @@ func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} {
}
screen.InjectKey(k.Key, k.R, mod)
}
+
+ if !timer.Stop() {
+ select {
+ case <-timer.C:
+ default:
+ }
+ }
+ timer.Reset(simAppCompletionTimeout)
+ select {
+ case <-done:
+ case <-timer.C:
+ t.Errorf("TUI simulation did not finish within %s after injecting %d key(s)", simAppCompletionTimeout, len(keys))
+ stopCurrentApp()
+ }
}()
})
return app
}
t.Cleanup(func() {
+ stopCurrentApp()
close(done)
injectWG.Wait()
newTUIApp = orig
diff --git a/internal/orchestrator/unescape_proc_path_test.go b/internal/orchestrator/unescape_proc_path_test.go
index 86c7352b..8e0a1676 100644
--- a/internal/orchestrator/unescape_proc_path_test.go
+++ b/internal/orchestrator/unescape_proc_path_test.go
@@ -3,8 +3,6 @@ package orchestrator
import "testing"
func TestUnescapeProcPath(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
in string
@@ -25,7 +23,6 @@ func TestUnescapeProcPath(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
if got := unescapeProcPath(tt.in); got != tt.want {
t.Fatalf("unescapeProcPath(%q)=%q want %q", tt.in, got, tt.want)
}
diff --git a/internal/orchestrator/workflow_ui_tui_decrypt.go b/internal/orchestrator/workflow_ui_tui_decrypt.go
index 531571d5..fa0483a5 100644
--- a/internal/orchestrator/workflow_ui_tui_decrypt.go
+++ b/internal/orchestrator/workflow_ui_tui_decrypt.go
@@ -5,6 +5,7 @@ import (
"fmt"
"path/filepath"
"strings"
+ "sync"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
@@ -86,35 +87,67 @@ func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage strin
form.SetParentView(page)
done := make(chan struct{})
+ started := make(chan struct{})
+ var startOnce sync.Once
var runErr error
+ queueProgressUpdate := func(update func()) {
+ select {
+ case <-taskCtx.Done():
+ return
+ default:
+ }
+ go func() {
+ select {
+ case <-taskCtx.Done():
+ return
+ default:
+ }
+ app.QueueUpdateDraw(update)
+ }()
+ }
+
report := func(message string) {
message = strings.TrimSpace(message)
if message == "" {
return
}
- app.QueueUpdateDraw(func() {
+ queueProgressUpdate(func() {
messageView.SetText(tview.Escape(message))
})
}
- go func() {
- runErr = run(taskCtx, report)
- close(done)
- app.QueueUpdateDraw(func() {
- app.Stop()
+ startTask := func() {
+ startOnce.Do(func() {
+ close(started)
+ go func() {
+ runErr = run(taskCtx, report)
+ close(done)
+ app.Stop()
+ }()
})
- }()
+ }
app.SetRoot(page, true).SetFocus(form.Form)
+ app.SetAfterDrawFunc(func(screen tcell.Screen) {
+ startTask()
+ })
if err := app.RunWithContext(taskCtx); err != nil {
cancel()
- <-done
+ select {
+ case <-started:
+ <-done
+ default:
+ }
return err
}
cancel()
- <-done
+ select {
+ case <-started:
+ <-done
+ default:
+ }
return runErr
}
diff --git a/internal/pbs/namespaces.go b/internal/pbs/namespaces.go
index d168bef4..973074f8 100644
--- a/internal/pbs/namespaces.go
+++ b/internal/pbs/namespaces.go
@@ -6,14 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
- "os/exec"
"path/filepath"
"time"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"github.com/tis24dev/proxsave/internal/safefs"
)
-var execCommand = exec.CommandContext
+var execCommand = safeexec.CommandContext
// Namespace represents a single PBS namespace.
type Namespace struct {
@@ -57,7 +57,7 @@ func listNamespacesViaCLI(ctx context.Context, datastore string) ([]Namespace, e
return nil, err
}
- cmd := execCommand(
+ cmd, cmdErr := execCommand(
ctx,
"proxmox-backup-manager",
"datastore",
@@ -66,6 +66,9 @@ func listNamespacesViaCLI(ctx context.Context, datastore string) ([]Namespace, e
datastore,
"--output-format=json",
)
+ if cmdErr != nil {
+ return nil, cmdErr
+ }
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
diff --git a/internal/pbs/namespaces_test.go b/internal/pbs/namespaces_test.go
index f151caeb..22e7a473 100644
--- a/internal/pbs/namespaces_test.go
+++ b/internal/pbs/namespaces_test.go
@@ -3,6 +3,7 @@ package pbs
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"os"
"os/exec"
@@ -192,6 +193,13 @@ func TestListNamespacesViaCLI_ErrorIncludesStderr(t *testing.T) {
}
}
+func TestListNamespacesViaCLI_ExecCommandError(t *testing.T) {
+ setExecCommandStub(t, "cmd-failure")
+ if _, err := listNamespacesViaCLI(context.Background(), "dummy"); err == nil || !strings.Contains(err.Error(), "simulated execCommand failure") {
+ t.Fatalf("expected execCommand error, got %v", err)
+ }
+}
+
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
return
@@ -213,13 +221,22 @@ func TestHelperProcess(t *testing.T) {
func setExecCommandStub(t *testing.T, scenario string) {
t.Helper()
original := execCommand
- execCommand = func(context.Context, string, ...string) *exec.Cmd {
+ if scenario == "cmd-failure" {
+ execCommand = func(context.Context, string, ...string) (*exec.Cmd, error) {
+ return nil, errors.New("simulated execCommand failure")
+ }
+ t.Cleanup(func() {
+ execCommand = original
+ })
+ return
+ }
+ execCommand = func(context.Context, string, ...string) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess", "--")
cmd.Env = append(os.Environ(),
"GO_WANT_HELPER_PROCESS=1",
"PBS_HELPER_SCENARIO="+scenario,
)
- return cmd
+ return cmd, nil
}
t.Cleanup(func() {
execCommand = original
diff --git a/internal/safeexec/safeexec.go b/internal/safeexec/safeexec.go
new file mode 100644
index 00000000..b4960255
--- /dev/null
+++ b/internal/safeexec/safeexec.go
@@ -0,0 +1,281 @@
+package safeexec
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "strings"
+ "unicode"
+)
+
+var ErrCommandNotAllowed = errors.New("command not allowed")
+
+// CommandContext creates commands only for binaries that are intentionally
+// allowed by the application. Keep exec.CommandContext calls in the switch so
+// static analyzers can see literal command names.
+func CommandContext(ctx context.Context, name string, args ...string) (*exec.Cmd, error) {
+ if strings.TrimSpace(name) != name || name == "" || strings.ContainsAny(name, `/\`) {
+ return nil, fmt.Errorf("%w: %q", ErrCommandNotAllowed, name)
+ }
+
+ switch name {
+ case "apt-cache":
+ return exec.CommandContext(ctx, "apt-cache", args...), nil
+ case "blkid":
+ return exec.CommandContext(ctx, "blkid", args...), nil
+ case "bridge":
+ return exec.CommandContext(ctx, "bridge", args...), nil
+ case "bzip2":
+ return exec.CommandContext(ctx, "bzip2", args...), nil
+ case "cat":
+ return exec.CommandContext(ctx, "cat", args...), nil
+ case "ceph":
+ return exec.CommandContext(ctx, "ceph", args...), nil
+ case "chattr":
+ return exec.CommandContext(ctx, "chattr", args...), nil
+ case "crontab":
+ return exec.CommandContext(ctx, "crontab", args...), nil
+ case "df":
+ return exec.CommandContext(ctx, "df", args...), nil
+ case "dmidecode":
+ return exec.CommandContext(ctx, "dmidecode", args...), nil
+ case "dpkg":
+ return exec.CommandContext(ctx, "dpkg", args...), nil
+ case "dpkg-query":
+ return exec.CommandContext(ctx, "dpkg-query", args...), nil
+ case "echo":
+ return exec.CommandContext(ctx, "echo", args...), nil
+ case "ethtool":
+ return exec.CommandContext(ctx, "ethtool", args...), nil
+ case "false":
+ return exec.CommandContext(ctx, "false", args...), nil
+ case "firewall-cmd":
+ return exec.CommandContext(ctx, "firewall-cmd", args...), nil
+ case "free":
+ return exec.CommandContext(ctx, "free", args...), nil
+ case "hostname":
+ return exec.CommandContext(ctx, "hostname", args...), nil
+ case "ifreload":
+ return exec.CommandContext(ctx, "ifreload", args...), nil
+ case "ifup":
+ return exec.CommandContext(ctx, "ifup", args...), nil
+ case "ip":
+ return exec.CommandContext(ctx, "ip", args...), nil
+ case "iptables":
+ return exec.CommandContext(ctx, "iptables", args...), nil
+ case "iptables-save":
+ return exec.CommandContext(ctx, "iptables-save", args...), nil
+ case "ip6tables":
+ return exec.CommandContext(ctx, "ip6tables", args...), nil
+ case "ip6tables-save":
+ return exec.CommandContext(ctx, "ip6tables-save", args...), nil
+ case "journalctl":
+ return exec.CommandContext(ctx, "journalctl", args...), nil
+ case "lsblk":
+ return exec.CommandContext(ctx, "lsblk", args...), nil
+ case "lspci":
+ return exec.CommandContext(ctx, "lspci", args...), nil
+ case "lscpu":
+ return exec.CommandContext(ctx, "lscpu", args...), nil
+ case "lsmod":
+ return exec.CommandContext(ctx, "lsmod", args...), nil
+ case "lsusb":
+ return exec.CommandContext(ctx, "lsusb", args...), nil
+ case "lvs":
+ return exec.CommandContext(ctx, "lvs", args...), nil
+ case "lzma":
+ return exec.CommandContext(ctx, "lzma", args...), nil
+ case "mailq":
+ return exec.CommandContext(ctx, "mailq", args...), nil
+ case "mount":
+ return exec.CommandContext(ctx, "mount", args...), nil
+ case "mountpoint":
+ return exec.CommandContext(ctx, "mountpoint", args...), nil
+ case "nft":
+ return exec.CommandContext(ctx, "nft", args...), nil
+ case "pbzip2":
+ return exec.CommandContext(ctx, "pbzip2", args...), nil
+ case "pgrep":
+ return exec.CommandContext(ctx, "pgrep", args...), nil
+ case "pigz":
+ return exec.CommandContext(ctx, "pigz", args...), nil
+ case "ping":
+ return exec.CommandContext(ctx, "ping", args...), nil
+ case "pvs":
+ return exec.CommandContext(ctx, "pvs", args...), nil
+ case "proxmox-backup-client":
+ return exec.CommandContext(ctx, "proxmox-backup-client", args...), nil
+ case "proxmox-backup-manager":
+ return exec.CommandContext(ctx, "proxmox-backup-manager", args...), nil
+ case "proxmox-mail-forward":
+ return exec.CommandContext(ctx, "proxmox-mail-forward", args...), nil
+ case "proxmox-tape":
+ return exec.CommandContext(ctx, "proxmox-tape", args...), nil
+ case "ps":
+ return exec.CommandContext(ctx, "ps", args...), nil
+ case "pvecm":
+ return exec.CommandContext(ctx, "pvecm", args...), nil
+ case "pve-firewall":
+ return exec.CommandContext(ctx, "pve-firewall", args...), nil
+ case "pvenode":
+ return exec.CommandContext(ctx, "pvenode", args...), nil
+ case "pvesh":
+ return exec.CommandContext(ctx, "pvesh", args...), nil
+ case "pvesm":
+ return exec.CommandContext(ctx, "pvesm", args...), nil
+ case "pveum":
+ return exec.CommandContext(ctx, "pveum", args...), nil
+ case "pveversion":
+ return exec.CommandContext(ctx, "pveversion", args...), nil
+ case "rclone":
+ return exec.CommandContext(ctx, "rclone", args...), nil
+ case "sendmail":
+ return exec.CommandContext(ctx, "sendmail", args...), nil
+ case "sensors":
+ return exec.CommandContext(ctx, "sensors", args...), nil
+ case "sh":
+ return exec.CommandContext(ctx, "sh", args...), nil
+ case "smartctl":
+ return exec.CommandContext(ctx, "smartctl", args...), nil
+ case "ss":
+ return exec.CommandContext(ctx, "ss", args...), nil
+ case "systemctl":
+ return exec.CommandContext(ctx, "systemctl", args...), nil
+ case "systemd-run":
+ return exec.CommandContext(ctx, "systemd-run", args...), nil
+ case "sysctl":
+ return exec.CommandContext(ctx, "sysctl", args...), nil
+ case "tail":
+ return exec.CommandContext(ctx, "tail", args...), nil
+ case "tar":
+ return exec.CommandContext(ctx, "tar", args...), nil
+ case "udevadm":
+ return exec.CommandContext(ctx, "udevadm", args...), nil
+ case "umount":
+ return exec.CommandContext(ctx, "umount", args...), nil
+ case "uname":
+ return exec.CommandContext(ctx, "uname", args...), nil
+ case "ufw":
+ return exec.CommandContext(ctx, "ufw", args...), nil
+ case "vgs":
+ return exec.CommandContext(ctx, "vgs", args...), nil
+ case "which":
+ return exec.CommandContext(ctx, "which", args...), nil
+ case "xz":
+ return exec.CommandContext(ctx, "xz", args...), nil
+ case "zfs":
+ return exec.CommandContext(ctx, "zfs", args...), nil
+ case "zpool":
+ return exec.CommandContext(ctx, "zpool", args...), nil
+ case "zstd":
+ return exec.CommandContext(ctx, "zstd", args...), nil
+ default:
+ return nil, fmt.Errorf("%w: %q", ErrCommandNotAllowed, name)
+ }
+}
+
+func CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) {
+ cmd, err := CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
+ return cmd.CombinedOutput()
+}
+
+func Output(ctx context.Context, name string, args ...string) ([]byte, error) {
+ cmd, err := CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
+ return cmd.Output()
+}
+
+func TrustedCommandContext(ctx context.Context, execPath string, args ...string) (*exec.Cmd, error) {
+ if err := ValidateTrustedExecutablePath(execPath); err != nil {
+ return nil, err
+ }
+ // #nosec G204 -- execPath is absolute, regular, executable, and not world-writable.
+ return exec.CommandContext(ctx, execPath, args...), nil // nosemgrep: go.lang.security.audit.dangerous-exec-command.dangerous-exec-command
+}
+
+func ValidateTrustedExecutablePath(execPath string) error {
+ clean := strings.TrimSpace(execPath)
+ if clean == "" {
+ return fmt.Errorf("executable path is empty")
+ }
+ if !filepath.IsAbs(clean) {
+ return fmt.Errorf("executable path must be absolute: %s", execPath)
+ }
+ info, err := os.Stat(clean)
+ if err != nil {
+ return fmt.Errorf("stat executable path: %w", err)
+ }
+ if !info.Mode().IsRegular() {
+ return fmt.Errorf("executable path is not a regular file: %s", clean)
+ }
+ if info.Mode().Perm()&0o111 == 0 {
+ return fmt.Errorf("executable path is not executable: %s", clean)
+ }
+ if info.Mode().Perm()&0o002 != 0 {
+ return fmt.Errorf("executable path is world-writable: %s", clean)
+ }
+ return nil
+}
+
+func ValidateRcloneRemoteName(remote string) error {
+ if remote == "" {
+ return fmt.Errorf("rclone remote name is empty")
+ }
+ if strings.HasPrefix(remote, "-") {
+ return fmt.Errorf("rclone remote name must not start with '-'")
+ }
+ if strings.ContainsAny(remote, `/\:`) {
+ return fmt.Errorf("rclone remote name contains a path separator or colon")
+ }
+ for _, r := range remote {
+ if unicode.IsSpace(r) || unicode.IsControl(r) {
+ return fmt.Errorf("rclone remote name contains whitespace or control characters")
+ }
+ }
+ return nil
+}
+
+func ValidateRemoteRelativePath(value, field string) error {
+ clean := strings.TrimSpace(value)
+ if clean == "" {
+ return nil
+ }
+ for _, r := range clean {
+ if unicode.IsControl(r) {
+ return fmt.Errorf("%s contains control characters", field)
+ }
+ }
+ normalized := path.Clean(strings.Trim(clean, "/"))
+ if normalized == "." {
+ return nil
+ }
+ if strings.HasPrefix(normalized, "../") || normalized == ".." {
+ return fmt.Errorf("%s must not traverse outside the configured remote", field)
+ }
+ return nil
+}
+
+func ProcPath(pid int, leaf string) (string, error) {
+ if pid <= 0 {
+ return "", fmt.Errorf("pid must be positive")
+ }
+ switch leaf {
+ case "comm":
+ return fmt.Sprintf("/proc/%d/comm", pid), nil
+ case "status":
+ return fmt.Sprintf("/proc/%d/status", pid), nil
+ case "exe":
+ return fmt.Sprintf("/proc/%d/exe", pid), nil
+ default:
+ return "", fmt.Errorf("unsupported proc leaf: %s", leaf)
+ }
+}
diff --git a/internal/safeexec/safeexec_test.go b/internal/safeexec/safeexec_test.go
new file mode 100644
index 00000000..3144ad80
--- /dev/null
+++ b/internal/safeexec/safeexec_test.go
@@ -0,0 +1,110 @@
+package safeexec
+
+import (
+ "context"
+ "errors"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestCommandContextAllowlist(t *testing.T) {
+ allowedCommands := []string{
+ "rclone",
+ "tar",
+ "xz",
+ "zstd",
+ "systemctl",
+ "mailq",
+ "tail",
+ "journalctl",
+ "pvesh",
+ "pveum",
+ "proxmox-backup-manager",
+ }
+ for _, command := range allowedCommands {
+ if _, err := CommandContext(context.Background(), command); err != nil {
+ t.Fatalf("CommandContext(%q) allowed command error: %v", command, err)
+ }
+ }
+ if _, err := CommandContext(context.Background(), "not-a-proxsave-command"); !errors.Is(err, ErrCommandNotAllowed) {
+ t.Fatalf("CommandContext unknown command error = %v, want ErrCommandNotAllowed", err)
+ }
+ if _, err := CommandContext(context.Background(), "/bin/sh"); !errors.Is(err, ErrCommandNotAllowed) {
+ t.Fatalf("CommandContext path command error = %v, want ErrCommandNotAllowed", err)
+ }
+}
+
+func TestValidateTrustedExecutablePath(t *testing.T) {
+ dir := t.TempDir()
+ execPath := filepath.Join(dir, "proxsave")
+ if err := os.WriteFile(execPath, []byte("#!/bin/sh\nexit 0\n"), 0o700); err != nil {
+ t.Fatal(err)
+ }
+ if err := ValidateTrustedExecutablePath(execPath); err != nil {
+ t.Fatalf("ValidateTrustedExecutablePath valid error: %v", err)
+ }
+
+ if err := ValidateTrustedExecutablePath("relative"); err == nil {
+ t.Fatalf("expected relative path to be rejected")
+ }
+
+ worldWritable := filepath.Join(dir, "ww")
+ if err := os.WriteFile(worldWritable, []byte("#!/bin/sh\nexit 0\n"), 0o777); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.Chmod(worldWritable, 0o777); err != nil {
+ t.Fatal(err)
+ }
+ if err := ValidateTrustedExecutablePath(worldWritable); err == nil {
+ t.Fatalf("expected world-writable executable to be rejected")
+ }
+}
+
+func TestValidateRcloneRemoteName(t *testing.T) {
+ valid := []string{"remote", "s3backup_01", "gdrive-prod"}
+ for _, name := range valid {
+ if err := ValidateRcloneRemoteName(name); err != nil {
+ t.Fatalf("ValidateRcloneRemoteName(%q) error: %v", name, err)
+ }
+ }
+
+ invalid := []string{"", "-remote", "bad remote", "bad/remote", "bad:remote", "bad\nremote"}
+ for _, name := range invalid {
+ if err := ValidateRcloneRemoteName(name); err == nil {
+ t.Fatalf("ValidateRcloneRemoteName(%q) expected error", name)
+ }
+ }
+}
+
+func TestValidateRemoteRelativePath(t *testing.T) {
+ valid := []string{"", "tenant/a", "/tenant/a/", "tenant with spaces/a"}
+ for _, value := range valid {
+ if err := ValidateRemoteRelativePath(value, "path"); err != nil {
+ t.Fatalf("ValidateRemoteRelativePath(%q) error: %v", value, err)
+ }
+ }
+
+ invalid := []string{"../escape", "tenant/../../escape", "bad\npath"}
+ for _, value := range invalid {
+ if err := ValidateRemoteRelativePath(value, "path"); err == nil {
+ t.Fatalf("ValidateRemoteRelativePath(%q) expected error", value)
+ }
+ }
+}
+
+func TestProcPath(t *testing.T) {
+ got, err := ProcPath(123, "status")
+ if err != nil {
+ t.Fatalf("ProcPath valid error: %v", err)
+ }
+ if got != "/proc/123/status" {
+ t.Fatalf("ProcPath = %q", got)
+ }
+ if _, err := ProcPath(0, "status"); err == nil {
+ t.Fatalf("expected pid 0 to be rejected")
+ }
+ if _, err := ProcPath(123, "../status"); err == nil {
+ t.Fatalf("expected unsupported leaf to be rejected")
+ }
+}
diff --git a/internal/security/procscan.go b/internal/security/procscan.go
index 0ea3ab98..b8755fa4 100644
--- a/internal/security/procscan.go
+++ b/internal/security/procscan.go
@@ -6,6 +6,8 @@ import (
"path/filepath"
"regexp"
"strings"
+
+ "github.com/tis24dev/proxsave/internal/safeexec"
)
// Heuristic detection for safe kernel-style processes.
@@ -28,25 +30,28 @@ type procInfo struct {
func readProcInfo(pid int) procInfo {
info := procInfo{}
- commPath := fmt.Sprintf("/proc/%d/comm", pid)
- if data, err := os.ReadFile(commPath); err == nil {
- info.comm = strings.TrimSpace(string(data))
+ if commPath, err := safeexec.ProcPath(pid, "comm"); err == nil {
+ if data, err := os.ReadFile(commPath); err == nil {
+ info.comm = strings.TrimSpace(string(data))
+ }
}
- statusPath := fmt.Sprintf("/proc/%d/status", pid)
- if data, err := os.ReadFile(statusPath); err == nil {
- lines := strings.Split(string(data), "\n")
- for _, line := range lines {
- if strings.HasPrefix(line, "PPid:") {
- _, _ = fmt.Sscanf(line, "PPid:\t%d", &info.ppid)
- break
+ if statusPath, err := safeexec.ProcPath(pid, "status"); err == nil {
+ if data, err := os.ReadFile(statusPath); err == nil {
+ lines := strings.Split(string(data), "\n")
+ for _, line := range lines {
+ if strings.HasPrefix(line, "PPid:") {
+ _, _ = fmt.Sscanf(line, "PPid:\t%d", &info.ppid)
+ break
+ }
}
}
}
- exePath := fmt.Sprintf("/proc/%d/exe", pid)
- if target, err := filepath.EvalSymlinks(exePath); err == nil {
- info.exe = target
+ if exePath, err := safeexec.ProcPath(pid, "exe"); err == nil {
+ if target, err := filepath.EvalSymlinks(exePath); err == nil {
+ info.exe = target
+ }
}
return info
diff --git a/internal/security/security.go b/internal/security/security.go
index 152c1802..5104a469 100644
--- a/internal/security/security.go
+++ b/internal/security/security.go
@@ -21,6 +21,7 @@ import (
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/environment"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"github.com/tis24dev/proxsave/internal/types"
)
@@ -634,7 +635,11 @@ func (c *Checker) checkFirewall(ctx context.Context) {
return
}
- cmd := exec.CommandContext(ctx, "iptables", "-L", "-n")
+ cmd, err := safeexec.CommandContext(ctx, "iptables", "-L", "-n")
+ if err != nil {
+ c.addWarning("Failed to prepare iptables command: %v", err)
+ return
+ }
output, err := cmd.Output()
if err != nil {
c.addWarning("Failed to run iptables -L -n: %v", err)
@@ -664,7 +669,11 @@ func (c *Checker) checkOpenPorts(ctx context.Context) {
return
}
- cmd := exec.CommandContext(ctx, "ss", "-tulnap")
+ cmd, err := safeexec.CommandContext(ctx, "ss", "-tulnap")
+ if err != nil {
+ c.addWarning("Failed to prepare 'ss -tulnap': %v", err)
+ return
+ }
output, err := cmd.Output()
if err != nil {
c.addWarning("Failed to execute 'ss -tulnap': %v", err)
@@ -700,7 +709,10 @@ func (c *Checker) checkOpenPortsAgainstSuspiciousList(ctx context.Context) {
if _, err := exec.LookPath("ss"); err != nil {
return
}
- cmd := exec.CommandContext(ctx, "ss", "-tuln")
+ cmd, err := safeexec.CommandContext(ctx, "ss", "-tuln")
+ if err != nil {
+ return
+ }
output, err := cmd.Output()
if err != nil {
return
@@ -721,7 +733,11 @@ func (c *Checker) checkOpenPortsAgainstSuspiciousList(ctx context.Context) {
}
func (c *Checker) checkSuspiciousProcesses(ctx context.Context) {
- cmd := exec.CommandContext(ctx, "ps", "-eo", "user=,state=,vsz=,pid=,command=")
+ cmd, err := safeexec.CommandContext(ctx, "ps", "-eo", "user=,state=,vsz=,pid=,command=")
+ if err != nil {
+ c.addWarning("Failed to prepare 'ps' for process inspection: %v", err)
+ return
+ }
output, err := cmd.Output()
if err != nil {
c.addWarning("Failed to execute 'ps' for process inspection: %v", err)
diff --git a/internal/storage/cloud.go b/internal/storage/cloud.go
index 8d306ad1..b367ff50 100644
--- a/internal/storage/cloud.go
+++ b/internal/storage/cloud.go
@@ -15,6 +15,7 @@ import (
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/safeexec"
"github.com/tis24dev/proxsave/internal/types"
"github.com/tis24dev/proxsave/pkg/utils"
)
@@ -89,6 +90,28 @@ func (c *CloudStorage) buildRcloneArgs(subcommand string) []string {
return args
}
+func validateRcloneArgs(args []string) error {
+ if len(args) == 0 {
+ return fmt.Errorf("missing rclone subcommand")
+ }
+ switch args[0] {
+ case "copyto", "delete", "deletefile", "ls", "lsf", "lsl", "mkdir", "touch":
+ default:
+ return fmt.Errorf("rclone subcommand not allowed: %s", args[0])
+ }
+ for _, arg := range args {
+ if strings.TrimSpace(arg) == "" {
+ return fmt.Errorf("rclone argument must not be empty")
+ }
+ for _, r := range arg {
+ if r < 0x20 || r == 0x7f {
+ return fmt.Errorf("rclone argument contains control characters")
+ }
+ }
+ }
+ return nil
+}
+
func splitRemoteRef(ref string) (remoteName, relPath string) {
parts := strings.SplitN(ref, ":", 2)
if len(parts) < 2 {
@@ -163,9 +186,19 @@ func NewCloudStorage(cfg *config.Config, logger *logging.Logger) (*CloudStorage,
// (base path from CLOUD_REMOTE plus optional CLOUD_REMOTE_PATH)
rawRemote := strings.TrimSpace(cfg.CloudRemote)
remoteName, basePath := splitRemoteRef(rawRemote)
+ remoteName = strings.TrimSpace(remoteName)
+ if err := safeexec.ValidateRcloneRemoteName(remoteName); err != nil {
+ return nil, fmt.Errorf("invalid CLOUD_REMOTE: %w", err)
+ }
basePath = strings.Trim(strings.TrimSpace(basePath), "/")
+ if err := safeexec.ValidateRemoteRelativePath(basePath, "CLOUD_REMOTE path"); err != nil {
+ return nil, err
+ }
userPrefix := strings.Trim(strings.TrimSpace(cfg.CloudRemotePath), "/")
+ if err := safeexec.ValidateRemoteRelativePath(userPrefix, "CLOUD_REMOTE_PATH"); err != nil {
+ return nil, err
+ }
combinedPrefix := strings.Trim(path.Join(basePath, userPrefix), "/")
@@ -1759,6 +1792,12 @@ func (c *CloudStorage) markCloudLogPathAvailable() {
}
func (c *CloudStorage) exec(ctx context.Context, name string, args ...string) ([]byte, error) {
+ if name != "rclone" {
+ return nil, fmt.Errorf("cloud storage may only execute rclone, got %q", name)
+ }
+ if err := validateRcloneArgs(args); err != nil {
+ return nil, err
+ }
if c.execCommand != nil {
return c.execCommand(ctx, name, args...)
}
@@ -1773,7 +1812,10 @@ func (c *CloudStorage) callWaitForRetry(ctx context.Context, d time.Duration) er
}
func defaultExecCommand(ctx context.Context, name string, args ...string) ([]byte, error) {
- cmd := exec.CommandContext(ctx, name, args...)
+ cmd, err := safeexec.CommandContext(ctx, name, args...)
+ if err != nil {
+ return nil, err
+ }
return cmd.CombinedOutput()
}
diff --git a/internal/tui/abort_context_test.go b/internal/tui/abort_context_test.go
index 93778c1c..2654353f 100644
--- a/internal/tui/abort_context_test.go
+++ b/internal/tui/abort_context_test.go
@@ -167,6 +167,25 @@ func TestAppRunWithContext_NilContextRunsUntilStopped(t *testing.T) {
}
}
+func TestAppRunWithContext_StopBeforeRunStopsWhenRunStarts(t *testing.T) {
+ app, _, _ := newSimulationApp(t)
+ app.Stop()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- app.RunWithContext(context.Background())
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Fatalf("err=%v want nil", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for pre-run Stop to end RunWithContext")
+ }
+}
+
func TestAppRunWithContext_ReturnsNilWhenStoppedWithoutCancellation(t *testing.T) {
app, _, started := newSimulationApp(t)
done := make(chan error, 1)
diff --git a/internal/tui/app.go b/internal/tui/app.go
index e190f67f..f1d480e9 100644
--- a/internal/tui/app.go
+++ b/internal/tui/app.go
@@ -2,16 +2,27 @@ package tui
import (
"context"
+ "sync"
"sync/atomic"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
)
+const (
+ appRunStateIdle = iota
+ appRunStateStarting
+ appRunStateRunning
+ appRunStateFinished
+)
+
// App wraps tview.Application with Proxmox-specific configuration
type App struct {
*tview.Application
- stopHook func()
+ stopHook func()
+ runMu sync.Mutex
+ runState int
+ stopRequested bool
}
// NewApp creates a new TUI application with Proxmox theme
@@ -49,8 +60,57 @@ func (a *App) Stop() {
return
}
if a.Application != nil {
- a.Application.Stop()
+ a.runMu.Lock()
+ switch a.runState {
+ case appRunStateIdle, appRunStateStarting:
+ // tview.Stop before Run clears the configured screen; apply it once
+ // the event loop can process the request instead.
+ a.stopRequested = true
+ a.runMu.Unlock()
+ return
+ case appRunStateRunning:
+ a.runMu.Unlock()
+ a.Application.Stop()
+ return
+ default:
+ a.runMu.Unlock()
+ }
+ }
+}
+
+func (a *App) Run() error {
+ if a == nil || a.Application == nil {
+ return nil
}
+
+ a.runMu.Lock()
+ a.runState = appRunStateStarting
+ a.runMu.Unlock()
+
+ go a.markRunningAndStopIfRequested()
+
+ err := a.Application.Run()
+
+ a.runMu.Lock()
+ a.runState = appRunStateFinished
+ a.stopRequested = false
+ a.runMu.Unlock()
+
+ return err
+}
+
+func (a *App) markRunningAndStopIfRequested() {
+ a.QueueUpdate(func() {
+ a.runMu.Lock()
+ a.runState = appRunStateRunning
+ stopRequested := a.stopRequested
+ a.stopRequested = false
+ a.runMu.Unlock()
+
+ if stopRequested {
+ a.Application.Stop()
+ }
+ })
}
func (a *App) RunWithContext(ctx context.Context) error {
diff --git a/internal/tui/wizard/post_install_audit_core.go b/internal/tui/wizard/post_install_audit_core.go
index 5d9a76a0..9ba0ee41 100644
--- a/internal/tui/wizard/post_install_audit_core.go
+++ b/internal/tui/wizard/post_install_audit_core.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/tis24dev/proxsave/internal/config"
+ "github.com/tis24dev/proxsave/internal/safeexec"
)
// PostInstallAuditSuggestion represents an optional feature that appears to be enabled
@@ -51,11 +52,14 @@ func postInstallAuditAllowedKeysSet() map[string]struct{} {
func runPostInstallAuditDryRun(ctx context.Context, execPath, configPath string) (output string, exitCode int, err error) {
// Run a dry-run with warning-level logs to keep output minimal while still capturing
// all actionable "set KEY=false" hints.
- cmd := exec.CommandContext(ctx, execPath,
+ cmd, err := safeexec.TrustedCommandContext(ctx, execPath,
"--dry-run",
"--log-level", "warning",
"--config", configPath,
)
+ if err != nil {
+ return "", -1, err
+ }
out, runErr := cmd.CombinedOutput()
if runErr == nil {
return string(out), 0, nil
diff --git a/internal/types/exit_codes.go b/internal/types/exit_codes.go
index 439c5f58..b91b84ac 100644
--- a/internal/types/exit_codes.go
+++ b/internal/types/exit_codes.go
@@ -49,6 +49,9 @@ const (
// ExitSecurityError - Errors detected by the security check.
ExitSecurityError ExitCode = 14
+
+ // ExitEncryptionError - Error during encryption setup or processing.
+ ExitEncryptionError ExitCode = 15
)
// String returns a human-readable description of the exit code.
@@ -84,6 +87,8 @@ func (e ExitCode) String() string {
return "panic error"
case ExitSecurityError:
return "security error"
+ case ExitEncryptionError:
+ return "encryption error"
default:
return "unknown error"
}
diff --git a/internal/types/exit_codes_test.go b/internal/types/exit_codes_test.go
index bda518c5..2cb8f7f6 100644
--- a/internal/types/exit_codes_test.go
+++ b/internal/types/exit_codes_test.go
@@ -17,6 +17,7 @@ func TestExitCodeString(t *testing.T) {
{"network error", ExitNetworkError, "network error"},
{"permission error", ExitPermissionError, "permission error"},
{"verification error", ExitVerificationError, "verification error"},
+ {"encryption error", ExitEncryptionError, "encryption error"},
{"unknown", ExitCode(99), "unknown error"},
}
@@ -45,6 +46,7 @@ func TestExitCodeInt(t *testing.T) {
{"network error", ExitNetworkError, 6},
{"permission error", ExitPermissionError, 7},
{"verification error", ExitVerificationError, 8},
+ {"encryption error", ExitEncryptionError, 15},
}
for _, tt := range tests {