diff --git a/cmd/proxsave/encryption_setup.go b/cmd/proxsave/encryption_setup.go new file mode 100644 index 00000000..0d493c36 --- /dev/null +++ b/cmd/proxsave/encryption_setup.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/orchestrator" + "github.com/tis24dev/proxsave/internal/types" +) + +type encryptionSetupResult struct { + Config *config.Config + RecipientPath string + WroteRecipientFile bool + ReusedExistingRecipients bool +} + +func runInitialEncryptionSetupWithUI(ctx context.Context, configPath string, ui orchestrator.AgeSetupUI) (*encryptionSetupResult, error) { + cfg, err := config.LoadConfig(configPath) + if err != nil { + return nil, fmt.Errorf("failed to reload configuration after install: %w", err) + } + + logger := logging.New(types.LogLevelError, false) + logger.SetOutput(io.Discard) + + orch := orchestrator.New(logger, false) + orch.SetConfig(cfg) + + var setupResult *orchestrator.AgeRecipientSetupResult + if ui != nil { + setupResult, err = orch.EnsureAgeRecipientsReadyWithUIDetails(ctx, ui) + } else { + setupResult, err = orch.EnsureAgeRecipientsReadyWithDetails(ctx) + } + if err != nil { + if errors.Is(err, orchestrator.ErrAgeRecipientSetupAborted) { + return nil, fmt.Errorf("encryption setup aborted by user: %w", errInteractiveAborted) + } + return nil, fmt.Errorf("encryption setup failed: %w", err) + } + + result := &encryptionSetupResult{Config: cfg} + if setupResult != nil { + result.RecipientPath = setupResult.RecipientPath + result.WroteRecipientFile = setupResult.WroteRecipientFile + result.ReusedExistingRecipients = setupResult.ReusedExistingRecipients + } + + return result, nil +} + +func runInitialEncryptionSetup(ctx context.Context, configPath string) error { + _, err := runInitialEncryptionSetupWithUI(ctx, configPath, nil) + return err +} diff --git a/cmd/proxsave/encryption_setup_test.go b/cmd/proxsave/encryption_setup_test.go new file mode 100644 index 00000000..72456499 --- /dev/null +++ b/cmd/proxsave/encryption_setup_test.go @@ -0,0 +1,228 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "testing" + + "filippo.io/age" + + "github.com/tis24dev/proxsave/internal/orchestrator" + "github.com/tis24dev/proxsave/internal/testutil" +) + +type testAgeSetupUI = testutil.AgeSetupUIStub[orchestrator.AgeRecipientDraft] + +func TestRunInitialEncryptionSetupWithUIReloadsConfig(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + content := "BASE_DIR=" + baseDir + "\nENCRYPT_ARCHIVE=true\nAGE_RECIPIENT=" + id.Recipient().String() + "\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + result, err := runInitialEncryptionSetupWithUI(context.Background(), configPath, nil) + if err != nil { + t.Fatalf("runInitialEncryptionSetupWithUI error: %v", err) + } + if result == nil || result.Config == nil { + t.Fatalf("expected config result") + } + if len(result.Config.AgeRecipients) != 1 || result.Config.AgeRecipients[0] != id.Recipient().String() { + t.Fatalf("AgeRecipients=%v; want [%s]", result.Config.AgeRecipients, id.Recipient().String()) + } + if !result.ReusedExistingRecipients { + t.Fatalf("expected ReusedExistingRecipients=true") + } + if result.WroteRecipientFile { + t.Fatalf("expected WroteRecipientFile=false") + } + if result.RecipientPath != "" { + t.Fatalf("RecipientPath=%q; want empty for reuse-only result", result.RecipientPath) + } +} + +func TestRunInitialEncryptionSetupWithUIUsesProvidedUI(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + content := "BASE_DIR=" + baseDir + "\nENCRYPT_ARCHIVE=true\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + ui := &testAgeSetupUI{ + AbortErr: orchestrator.ErrAgeRecipientSetupAborted, + Drafts: []*orchestrator.AgeRecipientDraft{ + {Kind: orchestrator.AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + + result, err := runInitialEncryptionSetupWithUI(context.Background(), configPath, ui) + if err != nil { + t.Fatalf("runInitialEncryptionSetupWithUI error: %v", err) + } + + expectedPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") + if result == nil || result.Config == nil { + t.Fatalf("expected setup result with config") + } + if result.RecipientPath != expectedPath { + t.Fatalf("RecipientPath=%q; want %q", result.RecipientPath, expectedPath) + } + if !result.WroteRecipientFile { + t.Fatalf("expected WroteRecipientFile=true") + } + if result.ReusedExistingRecipients { + t.Fatalf("expected ReusedExistingRecipients=false") + } + if _, err := os.Stat(expectedPath); err != nil { + t.Fatalf("expected recipient file at %s: %v", expectedPath, err) + } +} + +func TestRunInitialEncryptionSetupWithUIReusesExistingFileWithoutReportingWrite(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseDir := t.TempDir() + recipientPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") + if err := os.MkdirAll(filepath.Dir(recipientPath), 0o700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(recipientPath, []byte(id.Recipient().String()+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", recipientPath, err) + } + + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(configPath), err) + } + content := "BASE_DIR=" + baseDir + "\nENCRYPT_ARCHIVE=true\nAGE_RECIPIENT_FILE=" + recipientPath + "\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", configPath, err) + } + + result, err := runInitialEncryptionSetupWithUI(context.Background(), configPath, nil) + if err != nil { + t.Fatalf("runInitialEncryptionSetupWithUI error: %v", err) + } + + if result == nil || result.Config == nil { + t.Fatalf("expected setup result with config") + } + if !result.ReusedExistingRecipients { + t.Fatalf("expected ReusedExistingRecipients=true") + } + if result.WroteRecipientFile { + t.Fatalf("expected WroteRecipientFile=false") + } + if result.RecipientPath != "" { + t.Fatalf("RecipientPath=%q; want empty for reuse-only result", result.RecipientPath) + } +} + +func TestRunNewKeySetupKeepsDefaultRecipientPathContract(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + ui := &testAgeSetupUI{ + AbortErr: orchestrator.ErrAgeRecipientSetupAborted, + Overwrite: true, + Drafts: []*orchestrator.AgeRecipientDraft{ + {Kind: orchestrator.AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + + recipientPath, err := runNewKeySetup(context.Background(), configPath, baseDir, nil, ui) + if err != nil { + t.Fatalf("runNewKeySetup error: %v", err) + } + + target := filepath.Join(baseDir, "identity", "age", "recipient.txt") + if recipientPath != target { + t.Fatalf("recipientPath=%q; want %q", recipientPath, target) + } + content, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(%s): %v", target, err) + } + if got := string(content); got != id.Recipient().String()+"\n" { + t.Fatalf("content=%q; want %q", got, id.Recipient().String()+"\n") + } +} + +func TestRunNewKeySetupUsesConfiguredRecipientFile(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(configPath), err) + } + + customPath := filepath.Join(baseDir, "custom", "recipient.txt") + content := "BASE_DIR=" + baseDir + "\nENCRYPT_ARCHIVE=true\nAGE_RECIPIENT_FILE=" + customPath + "\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", configPath, err) + } + + ui := &testAgeSetupUI{ + AbortErr: orchestrator.ErrAgeRecipientSetupAborted, + Overwrite: true, + Drafts: []*orchestrator.AgeRecipientDraft{ + {Kind: orchestrator.AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + + recipientPath, err := runNewKeySetup(context.Background(), configPath, baseDir, nil, ui) + if err != nil { + t.Fatalf("runNewKeySetup error: %v", err) + } + if recipientPath != customPath { + t.Fatalf("recipientPath=%q; want %q", recipientPath, customPath) + } + + customContent, err := os.ReadFile(customPath) + if err != nil { + t.Fatalf("ReadFile(%s): %v", customPath, err) + } + if got := string(customContent); got != id.Recipient().String()+"\n" { + t.Fatalf("content=%q; want %q", got, id.Recipient().String()+"\n") + } + + defaultPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") + if _, err := os.Stat(defaultPath); !os.IsNotExist(err) { + t.Fatalf("default path %s should not be written, stat err=%v", defaultPath, err) + } +} diff --git a/cmd/proxsave/helpers_test.go b/cmd/proxsave/helpers_test.go index e27d735a..dd3490f2 100644 --- a/cmd/proxsave/helpers_test.go +++ b/cmd/proxsave/helpers_test.go @@ -411,8 +411,55 @@ func TestInputMapInputError(t *testing.T) { func TestValidateFutureFeatures_SecondaryWithoutPath(t *testing.T) { cfg := &config.Config{SecondaryEnabled: true} - if err := validateFutureFeatures(cfg); err == nil { - t.Error("expected error for secondary enabled without path") + err := validateFutureFeatures(cfg) + if err == nil { + t.Fatal("expected error for secondary enabled without path") + } + if got, want := err.Error(), "SECONDARY_PATH is required when SECONDARY_ENABLED=true"; got != want { + t.Fatalf("validateFutureFeatures error = %q, want %q", got, want) + } +} + +func TestValidateFutureFeatures_SecondaryRejectsRemotePath(t *testing.T) { + cfg := &config.Config{ + SecondaryEnabled: true, + SecondaryPath: "remote:path", + } + + err := validateFutureFeatures(cfg) + if err == nil { + t.Fatal("expected error for remote-style secondary path") + } + if got, want := err.Error(), "SECONDARY_PATH must be an absolute local filesystem path"; got != want { + t.Fatalf("validateFutureFeatures error = %q, want %q", got, want) + } +} + +func TestValidateFutureFeatures_SecondaryAllowsEmptyLogPath(t *testing.T) { + cfg := &config.Config{ + SecondaryEnabled: true, + SecondaryPath: "/backup/secondary", + SecondaryLogPath: "", + } + + if err := validateFutureFeatures(cfg); err != nil { + t.Fatalf("expected empty secondary log path to be allowed, got %v", err) + } +} + +func TestValidateFutureFeatures_SecondaryRejectsInvalidLogPath(t *testing.T) { + cfg := &config.Config{ + SecondaryEnabled: true, + SecondaryPath: "/backup/secondary", + SecondaryLogPath: "remote:/logs", + } + + err := validateFutureFeatures(cfg) + if err == nil { + t.Fatal("expected error for invalid secondary log path") + } + if got, want := err.Error(), "SECONDARY_LOG_PATH must be an absolute local filesystem path"; got != want { + t.Fatalf("validateFutureFeatures error = %q, want %q", got, want) } } diff --git a/cmd/proxsave/install.go b/cmd/proxsave/install.go index 55606234..01288642 100644 --- a/cmd/proxsave/install.go +++ b/cmd/proxsave/install.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "os" "os/exec" "path/filepath" @@ -13,15 +12,28 @@ import ( "strings" "github.com/tis24dev/proxsave/internal/config" + cronutil "github.com/tis24dev/proxsave/internal/cron" "github.com/tis24dev/proxsave/internal/identity" "github.com/tis24dev/proxsave/internal/logging" - "github.com/tis24dev/proxsave/internal/notify" - "github.com/tis24dev/proxsave/internal/orchestrator" "github.com/tis24dev/proxsave/internal/tui/wizard" - "github.com/tis24dev/proxsave/internal/types" buildinfo "github.com/tis24dev/proxsave/internal/version" ) +var ( + newInstallEnsureInteractiveStdin = ensureInteractiveStdin + newInstallConfirmCLI = confirmNewInstallCLI + newInstallConfirmTUI = wizard.ConfirmNewInstall + newInstallRunInstall = runInstall + newInstallRunInstallTUI = runInstallTUI + configureCronTimeFunc = configureCronTime +) + +type installConfigResult struct { + EnableEncryption bool + SkipConfigWizard bool + CronSchedule string +} + func runInstall(ctx context.Context, configPath string, bootstrap *logging.BootstrapLogger) (err error) { logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "resolving configuration path") resolvedPath, err := resolveInstallConfigPath(configPath) @@ -74,11 +86,18 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots } logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "running config wizard") - enableEncryption, skipConfigWizard, err := runConfigWizardCLI(ctx, reader, configPath, tmpConfigPath, baseDir, bootstrap) + configResult, err := runConfigWizardCLI(ctx, reader, configPath, tmpConfigPath, baseDir, bootstrap) if err != nil { return err } - logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "config wizard done (encryption=%v skip=%v)", enableEncryption, skipConfigWizard) + logging.DebugStepBootstrap( + bootstrap, + "install workflow (cli)", + "config wizard done (encryption=%v skip=%v cron=%s)", + configResult.EnableEncryption, + configResult.SkipConfigWizard, + configResult.CronSchedule, + ) logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "installing support docs") if err := installSupportDocs(baseDir, bootstrap); err != nil { @@ -86,12 +105,12 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots } logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "running encryption setup if needed") - if err := runEncryptionSetupIfNeeded(ctx, configPath, enableEncryption, skipConfigWizard, bootstrap); err != nil { + if err := runEncryptionSetupIfNeeded(ctx, configPath, configResult.EnableEncryption, configResult.SkipConfigWizard, bootstrap); err != nil { return err } // Optional post-install audit: run a dry-run and offer to disable unused collectors. - if !skipConfigWizard { + if !configResult.SkipConfigWizard { logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "post-install audit") if err := runPostInstallAuditCLI(ctx, reader, execInfo.ExecPath, configPath, bootstrap); err != nil { return err @@ -106,10 +125,16 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots } logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "finalizing symlinks and cron") - runPostInstallSymlinksAndCron(ctx, baseDir, execInfo, bootstrap) + runPostInstallSymlinksAndCron( + ctx, + baseDir, + execInfo, + bootstrap, + buildInstallCronSchedule(configResult.SkipConfigWizard, configResult.CronSchedule), + ) logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "detecting telegram identity") - telegramCode = detectTelegramCode(baseDir) + telegramCode = detectTelegramCodeWithContext(ctx, baseDir) if telegramCode != "" { logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "telegram identity detected") } else { @@ -125,123 +150,6 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots return nil } -func runTelegramSetupCLI(ctx context.Context, reader *bufio.Reader, baseDir, configPath string, bootstrap *logging.BootstrapLogger) error { - cfg, err := config.LoadConfig(configPath) - if err != nil { - if bootstrap != nil { - bootstrap.Warning("Telegram setup: unable to load config (skipping): %v", err) - } - return nil - } - if cfg == nil || !cfg.TelegramEnabled { - return nil - } - - mode := strings.ToLower(strings.TrimSpace(cfg.TelegramBotType)) - if mode == "" { - mode = "centralized" - } - if mode == "personal" { - // No centralized pairing check exists for personal mode. - if bootstrap != nil { - bootstrap.Info("Telegram setup: personal mode selected (no centralized pairing check)") - } - return nil - } - - fmt.Println("\n--- Telegram setup (optional) ---") - fmt.Println("You enabled Telegram notifications (centralized bot).") - - info, idErr := identity.Detect(baseDir, nil) - if idErr != nil { - fmt.Printf("WARNING: Unable to compute server identity (non-blocking): %v\n", idErr) - if bootstrap != nil { - bootstrap.Warning("Telegram setup: identity detection failed (non-blocking): %v", idErr) - } - return nil - } - - serverID := "" - if info != nil { - serverID = strings.TrimSpace(info.ServerID) - } - if serverID == "" { - fmt.Println("WARNING: Server ID unavailable; skipping Telegram setup.") - if bootstrap != nil { - bootstrap.Warning("Telegram setup: server ID unavailable; skipping") - } - return nil - } - - fmt.Printf("Server ID: %s\n", serverID) - if info != nil && strings.TrimSpace(info.IdentityFile) != "" { - fmt.Printf("Identity file: %s\n", strings.TrimSpace(info.IdentityFile)) - } - fmt.Println() - fmt.Println("1) Open Telegram and start @ProxmoxAN_bot") - fmt.Println("2) Send the Server ID above (digits only)") - fmt.Println("3) Verify pairing (recommended)") - fmt.Println() - - check, err := promptYesNo(ctx, reader, "Check Telegram pairing now? [Y/n]: ", true) - if err != nil { - return wrapInstallError(err) - } - if !check { - fmt.Println("Skipped verification. You can verify later by running proxsave.") - if bootstrap != nil { - bootstrap.Info("Telegram setup: verification skipped by user") - } - return nil - } - - serverHost := strings.TrimSpace(cfg.TelegramServerAPIHost) - if serverHost == "" { - serverHost = "https://bot.tis24.it:1443" - } - - attempts := 0 - for { - attempts++ - status := notify.CheckTelegramRegistration(ctx, serverHost, serverID, nil) - if status.Code == 200 && status.Error == nil { - fmt.Println("✓ Telegram linked successfully.") - if bootstrap != nil { - bootstrap.Info("Telegram setup: verified (attempts=%d)", attempts) - } - return nil - } - - msg := strings.TrimSpace(status.Message) - if msg == "" { - msg = "Registration not active yet" - } - fmt.Printf("Telegram: %s\n", msg) - switch status.Code { - case 403, 409: - fmt.Println("Hint: Start the bot, send the Server ID, then retry.") - case 422: - fmt.Println("Hint: The Server ID appears invalid. If this persists, re-run the installer.") - default: - if status.Error != nil { - fmt.Printf("Hint: Check failed: %v\n", status.Error) - } - } - - retry, err := promptYesNo(ctx, reader, "Check again? [y/N]: ", false) - if err != nil { - return wrapInstallError(err) - } - if !retry { - fmt.Println("Verification not completed. You can retry later by running proxsave.") - if bootstrap != nil { - bootstrap.Info("Telegram setup: not verified (attempts=%d last=%d %s)", attempts, status.Code, msg) - } - return nil - } - } -} - func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, configPath string, bootstrap *logging.BootstrapLogger) error { fmt.Println("\n--- Post-install check (optional) ---") run, err := promptYesNo(ctx, reader, "Run a dry-run to detect unused components and reduce warnings? [Y/n]: ", true) @@ -361,25 +269,25 @@ func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, func runNewInstall(ctx context.Context, configPath string, bootstrap *logging.BootstrapLogger, useCLI bool) (err error) { done := logging.DebugStartBootstrap(bootstrap, "new-install workflow", "config=%s", configPath) defer func() { done(err) }() - resolvedPath, err := resolveInstallConfigPath(configPath) - if err != nil { - return err - } - - baseDir := deriveBaseDirFromConfig(resolvedPath) logging.DebugStepBootstrap(bootstrap, "new-install workflow", "ensuring interactive stdin") - if err := ensureInteractiveStdin(); err != nil { + if err := newInstallEnsureInteractiveStdin(); err != nil { return err } - buildSig := buildSignature() - if strings.TrimSpace(buildSig) == "" { - buildSig = "n/a" + logging.DebugStepBootstrap(bootstrap, "new-install workflow", "building reset plan") + plan, err := buildNewInstallPlan(configPath) + if err != nil { + return err } logging.DebugStepBootstrap(bootstrap, "new-install workflow", "confirming reset") - confirm, err := wizard.ConfirmNewInstall(baseDir, buildSig) + var confirm bool + if useCLI { + confirm, err = newInstallConfirmCLI(ctx, bufio.NewReader(os.Stdin), plan) + } else { + confirm, err = newInstallConfirmTUI(ctx, plan.BaseDir, plan.BuildSignature, plan.PreservedEntries) + } if err != nil { return wrapInstallError(err) } @@ -387,16 +295,18 @@ func runNewInstall(ctx context.Context, configPath string, bootstrap *logging.Bo return wrapInstallError(errInteractiveAborted) } - bootstrap.Info("Resetting %s (preserving env/ and identity/)", baseDir) + if bootstrap != nil { + bootstrap.Info("Resetting %s (preserving %s)", plan.BaseDir, formatNewInstallPreservedEntries(plan.PreservedEntries)) + } logging.DebugStepBootstrap(bootstrap, "new-install workflow", "resetting base dir") - if err := resetInstallBaseDir(baseDir, bootstrap); err != nil { + if err := resetInstallBaseDirWithContext(ctx, plan.BaseDir, bootstrap); err != nil { return err } if useCLI { - return runInstall(ctx, resolvedPath, bootstrap) + return newInstallRunInstall(ctx, plan.ResolvedConfigPath, bootstrap) } - return runInstallTUI(ctx, resolvedPath, bootstrap) + return newInstallRunInstallTUI(ctx, plan.ResolvedConfigPath, bootstrap) } func printInstallFooter(installErr error, configPath, baseDir, telegramCode, permStatus, permMessage string) { @@ -472,7 +382,7 @@ func printInstallFooter(installErr error, configPath, baseDir, telegramCode, per fmt.Println(" --help - Show all options") fmt.Println(" --dry-run - Test without changes") fmt.Println(" --install - Re-run interactive installation/setup") - fmt.Println(" --new-install - Wipe installation directory (keep env/identity) then run installer") + fmt.Println(" --new-install - Wipe installation directory (keep build/env/identity) then run installer") fmt.Println(" --upgrade - Update proxsave binary to latest release (also adds missing keys to backup.env)") fmt.Println(" --newkey - Generate a new encryption key for backups") fmt.Println(" --decrypt - Decrypt an existing backup archive") @@ -537,53 +447,64 @@ func handleLegacyInstall(ctx context.Context, reader *bufio.Reader, baseDir stri return nil } -func runConfigWizardCLI(ctx context.Context, reader *bufio.Reader, configPath, tmpConfigPath, baseDir string, bootstrap *logging.BootstrapLogger) (enableEncryption bool, skipConfigWizard bool, err error) { +func runConfigWizardCLI(ctx context.Context, reader *bufio.Reader, configPath, tmpConfigPath, baseDir string, bootstrap *logging.BootstrapLogger) (result installConfigResult, err error) { done := logging.DebugStartBootstrap(bootstrap, "install config wizard (cli)", "config=%s", configPath) defer func() { done(err) }() logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "preparing base template") template, skipConfigWizard, err := prepareBaseTemplate(ctx, reader, configPath) if err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) } if skipConfigWizard { - return false, true, nil + return installConfigResult{SkipConfigWizard: true}, nil } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring secondary storage") if template, err = configureSecondaryStorage(ctx, reader, template); err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring cloud storage") if template, err = configureCloudStorage(ctx, reader, template); err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring firewall rules") if template, err = configureFirewallRules(ctx, reader, template); err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring notifications") if template, err = configureNotifications(ctx, reader, template); err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring encryption") - enableEncryption, err = configureEncryption(ctx, reader, &template) + result.EnableEncryption, err = configureEncryption(ctx, reader, &template) if err != nil { - return false, false, wrapInstallError(err) + return installConfigResult{}, wrapInstallError(err) + } + + logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "configuring cron time") + cronTime, err := configureCronTimeFunc(ctx, reader, cronutil.DefaultTime) + if err != nil { + return installConfigResult{}, wrapInstallError(err) + } + result.CronSchedule = cronutil.TimeToSchedule(cronTime) + + if bootstrap != nil { + bootstrap.Info("Cron schedule selected: %s", cronTime) } logging.DebugStepBootstrap(bootstrap, "install config wizard (cli)", "writing configuration") if err := writeConfigFile(configPath, tmpConfigPath, template); err != nil { - return false, false, err + return installConfigResult{}, err } if bootstrap != nil { bootstrap.Info("✓ Configuration saved at %s", configPath) } - return enableEncryption, false, nil + return result, nil } func runEncryptionSetupIfNeeded(ctx context.Context, configPath string, enableEncryption, skipConfigWizard bool, bootstrap *logging.BootstrapLogger) (err error) { @@ -605,7 +526,7 @@ func runEncryptionSetupIfNeeded(ctx context.Context, configPath string, enableEn return nil } -func runPostInstallSymlinksAndCron(ctx context.Context, baseDir string, execInfo ExecInfo, bootstrap *logging.BootstrapLogger) { +func runPostInstallSymlinksAndCron(ctx context.Context, baseDir string, execInfo ExecInfo, bootstrap *logging.BootstrapLogger, cronSchedule string) { done := logging.DebugStartBootstrap(bootstrap, "post-install setup", "base=%s", baseDir) defer func() { done(nil) }() // Clean up legacy bash-based symlinks that point to the old installer scripts. @@ -624,13 +545,15 @@ func runPostInstallSymlinksAndCron(ctx context.Context, baseDir string, execInfo // Migrate legacy cron entries pointing to the bash script to the Go binary. // If no cron entry exists at all, create a default one at 02:00 every day. - cronSchedule := resolveCronSchedule(nil) + if strings.TrimSpace(cronSchedule) == "" { + cronSchedule = resolveCronScheduleFromEnv() + } logging.DebugStepBootstrap(bootstrap, "post-install setup", "migrating cron entries") migrateLegacyCronEntries(ctx, baseDir, execInfo.ExecPath, bootstrap, cronSchedule) } -func detectTelegramCode(baseDir string) string { - info, err := identity.Detect(baseDir, nil) +func detectTelegramCodeWithContext(ctx context.Context, baseDir string) string { + info, err := identity.DetectWithContext(ctx, baseDir, nil) if err != nil { return "" } @@ -639,6 +562,13 @@ func detectTelegramCode(baseDir string) string { } func resetInstallBaseDir(baseDir string, bootstrap *logging.BootstrapLogger) (err error) { + return resetInstallBaseDirWithContext(context.Background(), baseDir, bootstrap) +} + +func resetInstallBaseDirWithContext(ctx context.Context, baseDir string, bootstrap *logging.BootstrapLogger) (err error) { + if ctx == nil { + ctx = context.Background() + } done := logging.DebugStartBootstrap(bootstrap, "reset install base", "base=%s", baseDir) defer func() { done(err) }() baseDir = filepath.Clean(baseDir) @@ -655,21 +585,22 @@ func resetInstallBaseDir(baseDir string, bootstrap *logging.BootstrapLogger) (er return fmt.Errorf("failed to list base directory %s: %w", baseDir, err) } - preserve := map[string]struct{}{ - "env": {}, - "identity": {}, - "build": {}, - } + preserve := newInstallPreserveSet() for _, entry := range entries { + if err := ctx.Err(); err != nil { + return err + } name := entry.Name() if _, keep := preserve[name]; keep { - bootstrap.Info("Preserving %s", filepath.Join(baseDir, name)) + logBootstrapInfo(bootstrap, "Preserving %s", filepath.Join(baseDir, name)) continue } target := filepath.Join(baseDir, name) logging.DebugStepBootstrap(bootstrap, "reset install base", "removing %s", target) - clearImmutableAttributes(target, bootstrap) + if err := clearImmutableAttributesWithContext(ctx, target, bootstrap); err != nil { + return err + } // Best-effort: ensure write permission before removal if entry.IsDir() { _ = os.Chmod(target, 0o700) @@ -679,7 +610,7 @@ func resetInstallBaseDir(baseDir string, bootstrap *logging.BootstrapLogger) (er if err := os.RemoveAll(target); err != nil { return fmt.Errorf("failed to remove %s: %w", target, err) } - bootstrap.Info("Removed %s", target) + logBootstrapInfo(bootstrap, "Removed %s", target) } return nil @@ -700,22 +631,18 @@ func printInstallBanner(configPath string) { } func prepareBaseTemplate(ctx context.Context, reader *bufio.Reader, configPath string) (string, bool, error) { - if info, err := os.Stat(configPath); err == nil { - if info.Mode().IsRegular() { - overwrite, err := promptYesNo(ctx, reader, fmt.Sprintf("%s already exists. Overwrite? [y/N]: ", configPath), false) - if err != nil { - return "", false, err - } - if !overwrite { - fmt.Println("Existing configuration detected, keeping current backup.env and skipping configuration wizard.") - return "", true, nil - } - } - } else if !os.IsNotExist(err) { - return "", false, fmt.Errorf("failed to access configuration file: %w", err) + decision, err := prepareExistingConfigDecisionCLI(ctx, reader, configPath) + if err != nil { + return "", false, err } - - return config.DefaultEnvTemplate(), false, nil + if decision.AbortInstall { + return "", false, errInteractiveAborted + } + if decision.SkipConfigWizard { + fmt.Println("Existing configuration detected, keeping current backup.env and skipping configuration wizard.") + return "", true, nil + } + return decision.BaseTemplate, false, nil } func configureSecondaryStorage(ctx context.Context, reader *bufio.Reader, template string) (string, error) { @@ -730,23 +657,35 @@ func configureSecondaryStorage(ctx context.Context, reader *bufio.Reader, templa return "", err } if enableSecondary { - secondaryPath, err := promptNonEmpty(ctx, reader, "Secondary backup path (SECONDARY_PATH): ") - if err != nil { - return "", err + var secondaryPath string + for { + secondaryPath, err = promptNonEmpty(ctx, reader, "Secondary backup path (SECONDARY_PATH): ") + if err != nil { + return "", err + } + secondaryPath = sanitizeEnvValue(secondaryPath) + if err := config.ValidateRequiredSecondaryPath(secondaryPath); err != nil { + fmt.Printf("%v\n", err) + continue + } + break } - secondaryPath = sanitizeEnvValue(secondaryPath) - secondaryLog, err := promptNonEmpty(ctx, reader, "Secondary log path (SECONDARY_LOG_PATH): ") - if err != nil { - return "", err + var secondaryLog string + for { + secondaryLog, err = promptOptional(ctx, reader, "Secondary log path (SECONDARY_LOG_PATH, optional - press Enter to skip): ") + if err != nil { + return "", err + } + secondaryLog = sanitizeEnvValue(secondaryLog) + if err := config.ValidateOptionalSecondaryLogPath(secondaryLog); err != nil { + fmt.Printf("%v\n", err) + continue + } + break } - secondaryLog = sanitizeEnvValue(secondaryLog) - template = setEnvValue(template, "SECONDARY_ENABLED", "true") - template = setEnvValue(template, "SECONDARY_PATH", secondaryPath) - template = setEnvValue(template, "SECONDARY_LOG_PATH", secondaryLog) + template = config.ApplySecondaryStorageSettings(template, true, secondaryPath, secondaryLog) } else { - template = setEnvValue(template, "SECONDARY_ENABLED", "false") - template = setEnvValue(template, "SECONDARY_PATH", "") - template = setEnvValue(template, "SECONDARY_LOG_PATH", "") + template = config.ApplySecondaryStorageSettings(template, false, "", "") } return template, nil } @@ -838,6 +777,22 @@ func configureEncryption(ctx context.Context, reader *bufio.Reader, template *st return enableEncryption, nil } +func configureCronTime(ctx context.Context, reader *bufio.Reader, defaultCron string) (string, error) { + fmt.Println("\n--- Schedule ---") + for { + cronTime, err := promptOptional(ctx, reader, fmt.Sprintf("Cron time for daily proxsave job (HH:MM) [%s]: ", defaultCron)) + if err != nil { + return "", err + } + normalized, err := cronutil.NormalizeTime(cronTime, defaultCron) + if err != nil { + fmt.Printf("%v\n", err) + continue + } + return normalized, nil + } +} + func writeConfigFile(configPath, tmpConfigPath, content string) error { dir := filepath.Dir(configPath) if err := os.MkdirAll(dir, 0o700); err != nil { @@ -852,25 +807,6 @@ func writeConfigFile(configPath, tmpConfigPath, content string) error { return nil } -func runInitialEncryptionSetup(ctx context.Context, configPath string) error { - cfg, err := config.LoadConfig(configPath) - if err != nil { - return fmt.Errorf("failed to reload configuration after install: %w", err) - } - logger := logging.New(types.LogLevelError, false) - logger.SetOutput(io.Discard) - orch := orchestrator.New(logger, false) - orch.SetConfig(cfg) - if err := orch.EnsureAgeRecipientsReady(ctx); err != nil { - if errors.Is(err, orchestrator.ErrAgeRecipientSetupAborted) { - // Treat AGE wizard abort as an interactive abort for install UX - return fmt.Errorf("encryption setup aborted by user: %w", errInteractiveAborted) - } - return fmt.Errorf("encryption setup failed: %w", err) - } - return nil -} - func wrapInstallError(err error) error { if err == nil { return nil @@ -905,9 +841,20 @@ func isInstallAbortedError(err error) bool { // clearImmutableAttributes attempts to remove immutable flags (chattr -i) so deletion can proceed. // It logs warnings on failure but does not return an error, since removal will report issues later. func clearImmutableAttributes(target string, bootstrap *logging.BootstrapLogger) { + _ = clearImmutableAttributesWithContext(context.Background(), target, bootstrap) +} + +func clearImmutableAttributesWithContext(ctx context.Context, target string, bootstrap *logging.BootstrapLogger) error { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return err + } + chattrPath, err := exec.LookPath("chattr") if err != nil { - return + return nil } argsList := [][]string{{chattrPath, "-i", target}} @@ -916,14 +863,21 @@ func clearImmutableAttributes(target string, bootstrap *logging.BootstrapLogger) } for _, args := range argsList { - cmd := exec.Command(args[0], args[1:]...) + if err := ctx.Err(); err != nil { + return err + } + cmd := exec.CommandContext(ctx, args[0], args[1:]...) if out, err := cmd.CombinedOutput(); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } trimmed := strings.TrimSpace(string(out)) if trimmed != "" { - bootstrap.Warning("Failed to clear immutable flag on %s: %v (%s)", target, err, trimmed) + logBootstrapWarning(bootstrap, "Failed to clear immutable flag on %s: %v (%s)", target, err, trimmed) } else { - bootstrap.Warning("Failed to clear immutable flag on %s: %v", target, err) + logBootstrapWarning(bootstrap, "Failed to clear immutable flag on %s: %v", target, err) } } } + return nil } diff --git a/cmd/proxsave/install_existing_config.go b/cmd/proxsave/install_existing_config.go new file mode 100644 index 00000000..a243d1fc --- /dev/null +++ b/cmd/proxsave/install_existing_config.go @@ -0,0 +1,110 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/tis24dev/proxsave/internal/config" +) + +type existingConfigMode int + +const ( + existingConfigOverwrite existingConfigMode = iota + existingConfigEdit + existingConfigKeepContinue + existingConfigCancel +) + +type existingConfigDecision struct { + BaseTemplate string + SkipConfigWizard bool + AbortInstall bool +} + +func promptExistingConfigModeCLI(ctx context.Context, reader *bufio.Reader, configPath string) (existingConfigMode, error) { + info, err := os.Stat(configPath) + if err != nil { + if os.IsNotExist(err) { + return existingConfigOverwrite, nil + } + return existingConfigCancel, fmt.Errorf("failed to access configuration file: %w", err) + } + if !info.Mode().IsRegular() { + return existingConfigCancel, fmt.Errorf("configuration file path is not a regular file: %s", configPath) + } + + fmt.Printf("%s already exists.\n", configPath) + fmt.Println("Choose how to proceed:") + fmt.Println(" [1] Overwrite (start from embedded template)") + fmt.Println(" [2] Edit existing (use current file as base)") + fmt.Println(" [3] Keep existing & continue (skip configuration wizard)") + fmt.Println(" [0] Cancel installation") + + for { + choice, err := promptOptional(ctx, reader, "Choice [3]: ") + if err != nil { + return existingConfigCancel, err + } + switch strings.TrimSpace(choice) { + case "": + fallthrough + case "3": + return existingConfigKeepContinue, nil + case "1": + return existingConfigOverwrite, nil + case "2": + return existingConfigEdit, nil + case "0": + return existingConfigCancel, nil + default: + fmt.Println("Please enter 1, 2, 3 or 0.") + } + } +} + +func resolveExistingConfigDecision(mode existingConfigMode, configPath string) (existingConfigDecision, error) { + switch mode { + case existingConfigOverwrite: + return existingConfigDecision{ + BaseTemplate: config.DefaultEnvTemplate(), + SkipConfigWizard: false, + AbortInstall: false, + }, nil + case existingConfigEdit: + content, err := os.ReadFile(configPath) + if err != nil { + return existingConfigDecision{}, fmt.Errorf("read existing configuration: %w", err) + } + return existingConfigDecision{ + BaseTemplate: string(content), + SkipConfigWizard: false, + AbortInstall: false, + }, nil + case existingConfigKeepContinue: + return existingConfigDecision{ + BaseTemplate: "", + SkipConfigWizard: true, + AbortInstall: false, + }, nil + case existingConfigCancel: + return existingConfigDecision{ + BaseTemplate: "", + SkipConfigWizard: false, + AbortInstall: true, + }, nil + default: + return existingConfigDecision{}, fmt.Errorf("unsupported existing configuration mode: %d", mode) + } +} + +func prepareExistingConfigDecisionCLI(ctx context.Context, reader *bufio.Reader, configPath string) (existingConfigDecision, error) { + mode, err := promptExistingConfigModeCLI(ctx, reader, configPath) + if err != nil { + return existingConfigDecision{}, err + } + return resolveExistingConfigDecision(mode, configPath) +} diff --git a/cmd/proxsave/install_existing_config_test.go b/cmd/proxsave/install_existing_config_test.go new file mode 100644 index 00000000..8de7a58a --- /dev/null +++ b/cmd/proxsave/install_existing_config_test.go @@ -0,0 +1,170 @@ +package main + +import ( + "bufio" + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestPromptExistingConfigModeCLIMissingFileDefaultsToOverwrite(t *testing.T) { + missing := filepath.Join(t.TempDir(), "missing.env") + mode, err := promptExistingConfigModeCLI(context.Background(), bufio.NewReader(strings.NewReader("")), missing) + if err != nil { + t.Fatalf("promptExistingConfigModeCLI error: %v", err) + } + if mode != existingConfigOverwrite { + t.Fatalf("expected overwrite mode, got %v", mode) + } +} + +func TestPromptExistingConfigModeCLIOptions(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + tests := []struct { + name string + input string + want existingConfigMode + }{ + {name: "default keep continue", input: "\n", want: existingConfigKeepContinue}, + {name: "overwrite", input: "1\n", want: existingConfigOverwrite}, + {name: "edit", input: "2\n", want: existingConfigEdit}, + {name: "keep continue", input: "3\n", want: existingConfigKeepContinue}, + {name: "cancel", input: "0\n", want: existingConfigCancel}, + {name: "invalid then overwrite", input: "x\n1\n", want: existingConfigOverwrite}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + reader := bufio.NewReader(strings.NewReader(tc.input)) + var mode existingConfigMode + var err error + captureStdout(t, func() { + mode, err = promptExistingConfigModeCLI(context.Background(), reader, cfgFile) + }) + if err != nil { + t.Fatalf("promptExistingConfigModeCLI error: %v", err) + } + if mode != tc.want { + t.Fatalf("mode = %v, want %v", mode, tc.want) + } + }) + } +} + +func TestResolveExistingConfigDecision(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + + overwrite, err := resolveExistingConfigDecision(existingConfigOverwrite, cfgFile) + if err != nil { + t.Fatalf("overwrite decision error: %v", err) + } + if overwrite.SkipConfigWizard || overwrite.AbortInstall { + t.Fatalf("overwrite decision flags are invalid: %+v", overwrite) + } + if strings.TrimSpace(overwrite.BaseTemplate) == "" { + t.Fatalf("overwrite base template should not be empty") + } + + edit, err := resolveExistingConfigDecision(existingConfigEdit, cfgFile) + if err != nil { + t.Fatalf("edit decision error: %v", err) + } + if edit.SkipConfigWizard || edit.AbortInstall { + t.Fatalf("edit decision flags are invalid: %+v", edit) + } + if !strings.Contains(edit.BaseTemplate, "EXISTING=1") { + t.Fatalf("expected existing content, got %q", edit.BaseTemplate) + } + + keep, err := resolveExistingConfigDecision(existingConfigKeepContinue, cfgFile) + if err != nil { + t.Fatalf("keep decision error: %v", err) + } + if !keep.SkipConfigWizard || keep.AbortInstall { + t.Fatalf("keep decision flags are invalid: %+v", keep) + } + + cancel, err := resolveExistingConfigDecision(existingConfigCancel, cfgFile) + if err != nil { + t.Fatalf("cancel decision error: %v", err) + } + if cancel.SkipConfigWizard || !cancel.AbortInstall { + t.Fatalf("cancel decision flags are invalid: %+v", cancel) + } +} + +func TestPrepareExistingConfigDecisionCLICancel(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + reader := bufio.NewReader(strings.NewReader("0\n")) + decision, err := prepareExistingConfigDecisionCLI(context.Background(), reader, cfgFile) + if err != nil { + t.Fatalf("prepareExistingConfigDecisionCLI error: %v", err) + } + if !decision.AbortInstall { + t.Fatalf("expected abort decision, got %+v", decision) + } +} + +func TestResolveExistingConfigDecisionEditReadError(t *testing.T) { + cfgFile := filepath.Join(t.TempDir(), "missing.env") + _, err := resolveExistingConfigDecision(existingConfigEdit, cfgFile) + if err == nil { + t.Fatalf("expected read error for missing file") + } +} + +func TestPromptExistingConfigModeCLIPropagatesReadError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + cfgFile := createTempFile(t, "EXISTING=1\n") + _, err := promptExistingConfigModeCLI(ctx, bufio.NewReader(strings.NewReader("1\n")), cfgFile) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected interactive aborted error, got %v", err) + } +} + +func TestPromptExistingConfigModeCLINonRegularFile(t *testing.T) { + dirPath := t.TempDir() + _, err := promptExistingConfigModeCLI(context.Background(), bufio.NewReader(strings.NewReader("1\n")), dirPath) + if err == nil { + t.Fatalf("expected error for non-regular file") + } + if !strings.Contains(err.Error(), "not a regular file") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestResolveExistingConfigDecisionUnsupportedMode(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + _, err := resolveExistingConfigDecision(existingConfigMode(99), cfgFile) + if err == nil { + t.Fatalf("expected unsupported mode error") + } +} + +func TestPromptExistingConfigModeCLIStatError(t *testing.T) { + pathWithNul := string([]byte{0}) + _, err := promptExistingConfigModeCLI(context.Background(), bufio.NewReader(strings.NewReader("1\n")), pathWithNul) + if err == nil { + t.Fatalf("expected stat error") + } +} + +func TestResolveExistingConfigDecisionEditExistingContentExact(t *testing.T) { + cfg := filepath.Join(t.TempDir(), "backup.env") + content := "KEY=VALUE\nANOTHER=1\n" + if err := os.WriteFile(cfg, []byte(content), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + decision, err := resolveExistingConfigDecision(existingConfigEdit, cfg) + if err != nil { + t.Fatalf("resolveExistingConfigDecision error: %v", err) + } + if decision.BaseTemplate != content { + t.Fatalf("expected exact content, got %q", decision.BaseTemplate) + } +} diff --git a/cmd/proxsave/install_test.go b/cmd/proxsave/install_test.go index fc9b8350..46ae7e6b 100644 --- a/cmd/proxsave/install_test.go +++ b/cmd/proxsave/install_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + cronutil "github.com/tis24dev/proxsave/internal/cron" "github.com/tis24dev/proxsave/internal/logging" ) @@ -105,7 +106,7 @@ func TestIsInstallAbortedError(t *testing.T) { } } -func TestResetInstallBaseDirPreservesEnvAndIdentity(t *testing.T) { +func TestResetInstallBaseDirPreservesCoreDirectories(t *testing.T) { base := t.TempDir() // setup contents @@ -134,6 +135,15 @@ func TestResetInstallBaseDirPreservesEnvAndIdentity(t *testing.T) { t.Fatalf("setup identity file: %v", err) } + buildDir := filepath.Join(base, "build") + if err := os.Mkdir(buildDir, 0o755); err != nil { + t.Fatalf("setup build: %v", err) + } + buildFile := filepath.Join(buildDir, "keep.txt") + if err := os.WriteFile(buildFile, []byte("build"), 0o600); err != nil { + t.Fatalf("setup build file: %v", err) + } + logger := logging.NewBootstrapLogger() if err := resetInstallBaseDir(base, logger); err != nil { t.Fatalf("resetInstallBaseDir returned error: %v", err) @@ -157,6 +167,73 @@ func TestResetInstallBaseDirPreservesEnvAndIdentity(t *testing.T) { if _, err := os.Stat(idFile); err != nil { t.Fatalf("identity file should remain: %v", err) } + if _, err := os.Stat(buildDir); err != nil { + t.Fatalf("build dir should remain: %v", err) + } + if _, err := os.Stat(buildFile); err != nil { + t.Fatalf("build file should remain: %v", err) + } +} + +func TestResetInstallBaseDirRespectsSharedPreserveSet(t *testing.T) { + base := t.TempDir() + for _, entry := range newInstallPreservedEntries() { + dirPath := filepath.Join(base, entry) + if err := os.MkdirAll(dirPath, 0o755); err != nil { + t.Fatalf("setup %s: %v", entry, err) + } + filePath := filepath.Join(dirPath, "keep.txt") + if err := os.WriteFile(filePath, []byte(entry), 0o600); err != nil { + t.Fatalf("setup %s file: %v", entry, err) + } + } + if err := os.WriteFile(filepath.Join(base, "drop.txt"), []byte("drop"), 0o600); err != nil { + t.Fatalf("setup drop file: %v", err) + } + + logger := logging.NewBootstrapLogger() + if err := resetInstallBaseDir(base, logger); err != nil { + t.Fatalf("resetInstallBaseDir returned error: %v", err) + } + + for _, entry := range newInstallPreservedEntries() { + filePath := filepath.Join(base, entry, "keep.txt") + if _, err := os.Stat(filePath); err != nil { + t.Fatalf("expected preserved file for %s, got %v", entry, err) + } + } + if _, err := os.Stat(filepath.Join(base, "drop.txt")); !os.IsNotExist(err) { + t.Fatalf("expected drop.txt removed, got err=%v", err) + } +} + +func TestResetInstallBaseDirAllowsNilBootstrap(t *testing.T) { + base := t.TempDir() + preservedDir := filepath.Join(base, "env") + if err := os.MkdirAll(preservedDir, 0o755); err != nil { + t.Fatalf("setup env: %v", err) + } + preservedFile := filepath.Join(preservedDir, "backup.env") + if err := os.WriteFile(preservedFile, []byte("KEEP=1"), 0o600); err != nil { + t.Fatalf("setup env file: %v", err) + } + removedFile := filepath.Join(base, "drop.txt") + if err := os.WriteFile(removedFile, []byte("drop"), 0o600); err != nil { + t.Fatalf("setup drop file: %v", err) + } + + captureStdout(t, func() { + if err := resetInstallBaseDir(base, nil); err != nil { + t.Fatalf("resetInstallBaseDir returned error: %v", err) + } + }) + + if _, err := os.Stat(preservedFile); err != nil { + t.Fatalf("expected preserved file to remain, got %v", err) + } + if _, err := os.Stat(removedFile); !os.IsNotExist(err) { + t.Fatalf("expected drop.txt removed, got err=%v", err) + } } func TestResetInstallBaseDirRefusesRoot(t *testing.T) { @@ -166,9 +243,28 @@ func TestResetInstallBaseDirRefusesRoot(t *testing.T) { } } +func TestResetInstallBaseDirWithContext_CanceledBeforeRemoval(t *testing.T) { + base := t.TempDir() + dropFile := filepath.Join(base, "drop.txt") + if err := os.WriteFile(dropFile, []byte("drop"), 0o600); err != nil { + t.Fatalf("setup drop file: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := resetInstallBaseDirWithContext(ctx, base, logging.NewBootstrapLogger()) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } + if _, statErr := os.Stat(dropFile); statErr != nil { + t.Fatalf("expected file to remain after canceled reset, got %v", statErr) + } +} + func TestPrepareBaseTemplateExistingSkip(t *testing.T) { cfgFile := createTempFile(t, "existing config") - reader := bufio.NewReader(strings.NewReader("n\n")) + reader := bufio.NewReader(strings.NewReader("3\n")) var tmpl string var skip bool var err error @@ -188,7 +284,7 @@ func TestPrepareBaseTemplateExistingSkip(t *testing.T) { func TestPrepareBaseTemplateOverwrite(t *testing.T) { cfgFile := createTempFile(t, "old") - reader := bufio.NewReader(strings.NewReader("y\n")) + reader := bufio.NewReader(strings.NewReader("1\n")) var tmpl string var skip bool var err error @@ -206,6 +302,35 @@ func TestPrepareBaseTemplateOverwrite(t *testing.T) { } } +func TestPrepareBaseTemplateEditExisting(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + reader := bufio.NewReader(strings.NewReader("2\n")) + var tmpl string + var skip bool + var err error + captureStdout(t, func() { + tmpl, skip, err = prepareBaseTemplate(context.Background(), reader, cfgFile) + }) + if err != nil { + t.Fatalf("prepareBaseTemplate error: %v", err) + } + if skip { + t.Fatalf("expected skip=false for edit existing") + } + if !strings.Contains(tmpl, "EXISTING=1") { + t.Fatalf("expected existing template content, got %q", tmpl) + } +} + +func TestPrepareBaseTemplateCancel(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + reader := bufio.NewReader(strings.NewReader("0\n")) + _, _, err := prepareBaseTemplate(context.Background(), reader, cfgFile) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected interactive abort, got %v", err) + } +} + func TestConfigureSecondaryStorageEnabled(t *testing.T) { var result string var err error @@ -228,6 +353,60 @@ func TestConfigureSecondaryStorageEnabled(t *testing.T) { } } +func TestConfigureSecondaryStorageEnabledWithEmptyLogPath(t *testing.T) { + var result string + var err error + ctx := context.Background() + reader := bufio.NewReader(strings.NewReader("y\n/mnt/secondary\n\n")) + captureStdout(t, func() { + result, err = configureSecondaryStorage(ctx, reader, "") + }) + if err != nil { + t.Fatalf("configureSecondaryStorage error: %v", err) + } + if !strings.Contains(result, "SECONDARY_ENABLED=true") { + t.Fatalf("expected SECONDARY_ENABLED=true in template: %q", result) + } + if !strings.Contains(result, "SECONDARY_PATH=/mnt/secondary") { + t.Fatalf("expected secondary path in template: %q", result) + } + if !strings.Contains(result, "SECONDARY_LOG_PATH=") { + t.Fatalf("expected empty secondary log path in template: %q", result) + } +} + +func TestConfigureSecondaryStorageRejectsInvalidBackupPath(t *testing.T) { + var result string + var err error + ctx := context.Background() + reader := bufio.NewReader(strings.NewReader("y\nrelative/path\n/mnt/secondary\n\n")) + captureStdout(t, func() { + result, err = configureSecondaryStorage(ctx, reader, "") + }) + if err != nil { + t.Fatalf("configureSecondaryStorage error: %v", err) + } + if !strings.Contains(result, "SECONDARY_PATH=/mnt/secondary") { + t.Fatalf("expected corrected secondary path in template: %q", result) + } +} + +func TestConfigureSecondaryStorageRejectsInvalidLogPath(t *testing.T) { + var result string + var err error + ctx := context.Background() + reader := bufio.NewReader(strings.NewReader("y\n/mnt/secondary\nremote:/logs\n\n")) + captureStdout(t, func() { + result, err = configureSecondaryStorage(ctx, reader, "") + }) + if err != nil { + t.Fatalf("configureSecondaryStorage error: %v", err) + } + if !strings.Contains(result, "SECONDARY_LOG_PATH=") { + t.Fatalf("expected empty secondary log path in template: %q", result) + } +} + func TestConfigureSecondaryStorageDisabled(t *testing.T) { var result string var err error @@ -242,6 +421,38 @@ func TestConfigureSecondaryStorageDisabled(t *testing.T) { if !strings.Contains(result, "SECONDARY_ENABLED=false") { t.Fatalf("expected disabled flag in template: %q", result) } + if !strings.Contains(result, "SECONDARY_PATH=") { + t.Fatalf("expected cleared secondary path in template: %q", result) + } + if !strings.Contains(result, "SECONDARY_LOG_PATH=") { + t.Fatalf("expected cleared secondary log path in template: %q", result) + } +} + +func TestConfigureSecondaryStorageDisabledClearsExistingValues(t *testing.T) { + var result string + var err error + ctx := context.Background() + reader := bufio.NewReader(strings.NewReader("n\n")) + template := "SECONDARY_ENABLED=true\nSECONDARY_PATH=/mnt/old-secondary\nSECONDARY_LOG_PATH=/mnt/old-secondary/logs\n" + captureStdout(t, func() { + result, err = configureSecondaryStorage(ctx, reader, template) + }) + if err != nil { + t.Fatalf("configureSecondaryStorage error: %v", err) + } + for _, needle := range []string{ + "SECONDARY_ENABLED=false", + "SECONDARY_PATH=", + "SECONDARY_LOG_PATH=", + } { + if !strings.Contains(result, needle) { + t.Fatalf("expected %q in template: %q", needle, result) + } + } + if strings.Contains(result, "/mnt/old-secondary") { + t.Fatalf("expected old secondary values to be cleared: %q", result) + } } func TestConfigureCloudStorageEnabled(t *testing.T) { @@ -367,6 +578,130 @@ func TestConfigureEncryption(t *testing.T) { } } +func TestConfigureCronTime(t *testing.T) { + t.Run("empty input uses default", func(t *testing.T) { + var cronTime string + var err error + reader := bufio.NewReader(strings.NewReader("\n")) + captureStdout(t, func() { + cronTime, err = configureCronTime(context.Background(), reader, cronutil.DefaultTime) + }) + if err != nil { + t.Fatalf("configureCronTime returned error: %v", err) + } + if cronTime != cronutil.DefaultTime { + t.Fatalf("configureCronTime default = %q, want %q", cronTime, cronutil.DefaultTime) + } + }) + + t.Run("invalid input re-prompts until valid", func(t *testing.T) { + var cronTime string + var err error + reader := bufio.NewReader(strings.NewReader("24:00\n3:7\n")) + output := captureStdout(t, func() { + cronTime, err = configureCronTime(context.Background(), reader, cronutil.DefaultTime) + }) + if err != nil { + t.Fatalf("configureCronTime returned error: %v", err) + } + if cronTime != "03:07" { + t.Fatalf("configureCronTime normalized = %q, want %q", cronTime, "03:07") + } + if !strings.Contains(output, "cron hour must be between 00 and 23") { + t.Fatalf("expected validation error in output, got %q", output) + } + }) + + t.Run("aborted input returns sentinel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + reader := bufio.NewReader(strings.NewReader("03:15\n")) + _, err := configureCronTime(ctx, reader, cronutil.DefaultTime) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected errInteractiveAborted, got %v", err) + } + }) +} + +func TestRunConfigWizardCLIReturnsCronSchedule(t *testing.T) { + cfgDir := t.TempDir() + configPath := filepath.Join(cfgDir, "env", "backup.env") + tmpConfigPath := configPath + ".tmp" + reader := bufio.NewReader(strings.NewReader("n\nn\nn\nn\nn\nn\n03:15\n")) + + var result installConfigResult + var err error + captureStdout(t, func() { + result, err = runConfigWizardCLI(context.Background(), reader, configPath, tmpConfigPath, "/opt/proxsave", nil) + }) + if err != nil { + t.Fatalf("runConfigWizardCLI returned error: %v", err) + } + if result.SkipConfigWizard { + t.Fatal("expected SkipConfigWizard=false") + } + if result.EnableEncryption { + t.Fatal("expected EnableEncryption=false") + } + if result.CronSchedule != "15 03 * * *" { + t.Fatalf("CronSchedule = %q, want %q", result.CronSchedule, "15 03 * * *") + } + + content, readErr := os.ReadFile(configPath) + if readErr != nil { + t.Fatalf("expected config file to be written: %v", readErr) + } + if !strings.Contains(string(content), "ENCRYPT_ARCHIVE=false") { + t.Fatalf("expected config content to be written, got %q", string(content)) + } +} + +func TestRunConfigWizardCLISkipLeavesCronScheduleEmpty(t *testing.T) { + cfgFile := createTempFile(t, "EXISTING=1\n") + tmpConfigPath := cfgFile + ".tmp" + reader := bufio.NewReader(strings.NewReader("3\n")) + + var result installConfigResult + var err error + captureStdout(t, func() { + result, err = runConfigWizardCLI(context.Background(), reader, cfgFile, tmpConfigPath, "/opt/proxsave", nil) + }) + if err != nil { + t.Fatalf("runConfigWizardCLI returned error: %v", err) + } + if !result.SkipConfigWizard { + t.Fatal("expected SkipConfigWizard=true") + } + if result.CronSchedule != "" { + t.Fatalf("expected empty CronSchedule when skipping wizard, got %q", result.CronSchedule) + } +} + +func TestRunConfigWizardCLIAbortAtCronPromptDoesNotWriteConfig(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "env", "backup.env") + tmpConfigPath := configPath + ".tmp" + + originalConfigureCronTime := configureCronTimeFunc + t.Cleanup(func() { configureCronTimeFunc = originalConfigureCronTime }) + + configureCronTimeFunc = func(ctx context.Context, reader *bufio.Reader, defaultCron string) (string, error) { + return "", errInteractiveAborted + } + + reader := bufio.NewReader(strings.NewReader("n\nn\nn\nn\nn\nn\n")) + + _, err := runConfigWizardCLI(context.Background(), reader, configPath, tmpConfigPath, "/opt/proxsave", nil) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected errInteractiveAborted, got %v", err) + } + if _, statErr := os.Stat(configPath); !os.IsNotExist(statErr) { + t.Fatalf("expected config file not to exist, got err=%v", statErr) + } + if _, statErr := os.Stat(tmpConfigPath); !os.IsNotExist(statErr) { + t.Fatalf("expected temp config file not to exist, got err=%v", statErr) + } +} + func createTempFile(t *testing.T, content string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), "config-*.env") diff --git a/cmd/proxsave/install_tui.go b/cmd/proxsave/install_tui.go index 93fb39c7..bb9555d0 100644 --- a/cmd/proxsave/install_tui.go +++ b/cmd/proxsave/install_tui.go @@ -5,14 +5,11 @@ import ( "errors" "fmt" "os" - "path/filepath" "strings" - "filippo.io/age" - + cronutil "github.com/tis24dev/proxsave/internal/cron" "github.com/tis24dev/proxsave/internal/identity" "github.com/tis24dev/proxsave/internal/logging" - "github.com/tis24dev/proxsave/internal/orchestrator" "github.com/tis24dev/proxsave/internal/tui/wizard" ) @@ -65,7 +62,7 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo // Check if config exists logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "checking existing configuration") - existingAction, err := wizard.CheckExistingConfig(configPath, buildSig) + existingAction, err := wizard.CheckExistingConfig(ctx, configPath, buildSig) if err != nil { return err } @@ -75,9 +72,12 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo baseTemplate := "" switch existingAction { - case wizard.ExistingConfigSkip: - logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "user skipped configuration") + case wizard.ExistingConfigCancel: + logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "user cancelled installation") return wrapInstallError(errInteractiveAborted) + case wizard.ExistingConfigKeepContinue: + logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "using existing configuration and skipping wizard") + skipConfigWizard = true case wizard.ExistingConfigEdit: logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "editing existing configuration") content, readErr := os.ReadFile(configPath) @@ -122,7 +122,9 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo return err } - bootstrap.Debug("Configuration saved at %s", configPath) + if bootstrap != nil { + bootstrap.Debug("Configuration saved at %s", configPath) + } } // Install support docs @@ -136,73 +138,49 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo if bootstrap != nil { bootstrap.Info("Running initial encryption setup (AGE recipients)") } - logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "running AGE setup wizard") - recipientPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") - ageData, err := wizard.RunAgeSetupWizard(ctx, recipientPath, configPath, buildSig) + logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "running AGE setup via orchestrator") + setupResult, err := runInitialEncryptionSetupWithUI(ctx, configPath, wizard.NewAgeSetupUI(configPath, buildSig)) if err != nil { - if errors.Is(err, wizard.ErrAgeSetupCancelled) { - return fmt.Errorf("encryption setup aborted by user: %w", errInteractiveAborted) - } else { - return fmt.Errorf("AGE setup failed: %w", err) - } + return err } - // Process the AGE data based on setup type - var recipientKey string - switch ageData.SetupType { - case "existing": - recipientKey = ageData.PublicKey - case "passphrase": - // Derive recipient from passphrase - recipient, err := deriveRecipientFromPassphrase(ageData.Passphrase) - if err != nil { - return fmt.Errorf("failed to derive recipient from passphrase: %w", err) - } - recipientKey = recipient - case "privatekey": - // Derive recipient from private key - recipient, err := deriveRecipientFromPrivateKey(ageData.PrivateKey) - if err != nil { - return fmt.Errorf("failed to derive recipient from private key: %w", err) + if bootstrap != nil { + bootstrap.Info("AGE encryption configured successfully") + if setupResult.WroteRecipientFile && setupResult.RecipientPath != "" { + bootstrap.Info("Recipient saved to: %s", setupResult.RecipientPath) + } else if setupResult.ReusedExistingRecipients { + bootstrap.Info("Using existing AGE recipient configuration") } - recipientKey = recipient + bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") } - - // Save the recipient - logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "saving AGE recipient") - if err := wizard.SaveAgeRecipient(recipientPath, recipientKey); err != nil { - return fmt.Errorf("failed to save AGE recipient: %w", err) - } - - bootstrap.Info("AGE encryption configured successfully") - bootstrap.Info("Recipient saved to: %s", recipientPath) - bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") } // Optional post-install audit: run a dry-run and offer to disable unused collectors // based on actionable warning hints like "set BACKUP_*=false to disable". - auditRes, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig) - if bootstrap != nil { - if auditErr != nil { - bootstrap.Warning("Post-install check failed (non-blocking): %v", auditErr) - } else { - switch { - case !auditRes.Ran: - bootstrap.Info("Post-install audit: skipped by user") - case auditRes.CollectErr != nil: - bootstrap.Warning("Post-install audit failed (non-blocking): %v", auditRes.CollectErr) - case len(auditRes.Suggestions) == 0: - bootstrap.Info("Post-install audit: no unused components detected") - default: - keys := make([]string, 0, len(auditRes.Suggestions)) - for _, s := range auditRes.Suggestions { - keys = append(keys, s.Key) - } - bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", ")) - if len(auditRes.AppliedKeys) > 0 { - bootstrap.Info("Post-install audit: disabled (%d): %s", len(auditRes.AppliedKeys), strings.Join(auditRes.AppliedKeys, ", ")) - } else { - bootstrap.Info("Post-install audit: no disables applied") + if !skipConfigWizard { + auditRes, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig) + if bootstrap != nil { + if auditErr != nil { + bootstrap.Warning("Post-install check failed (non-blocking): %v", auditErr) + } else { + switch { + case !auditRes.Ran: + bootstrap.Info("Post-install audit: skipped by user") + case auditRes.CollectErr != nil: + bootstrap.Warning("Post-install audit failed (non-blocking): %v", auditRes.CollectErr) + case len(auditRes.Suggestions) == 0: + bootstrap.Info("Post-install audit: no unused components detected") + default: + keys := make([]string, 0, len(auditRes.Suggestions)) + for _, s := range auditRes.Suggestions { + keys = append(keys, s.Key) + } + bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", ")) + if len(auditRes.AppliedKeys) > 0 { + bootstrap.Info("Post-install audit: disabled (%d): %s", len(auditRes.AppliedKeys), strings.Join(auditRes.AppliedKeys, ", ")) + } else { + bootstrap.Info("Post-install audit: no disables applied") + } } } } @@ -210,21 +188,16 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo // Telegram setup (centralized bot): if enabled during install, guide the user through // pairing and allow an explicit verification step with retry + skip. - if wizardData != nil && (wizardData.NotificationMode == "telegram" || wizardData.NotificationMode == "both") { + if !skipConfigWizard && wizardData != nil && (wizardData.NotificationMode == "telegram" || wizardData.NotificationMode == "both") { telegramRes, telegramErr := wizard.RunTelegramSetupWizard(ctx, baseDir, configPath, buildSig) if telegramErr != nil && bootstrap != nil { bootstrap.Warning("Telegram setup failed (non-blocking): %v", telegramErr) } + if bootstrap != nil && telegramErr == nil { + logTelegramSetupBootstrapOutcome(bootstrap, telegramRes.TelegramSetupBootstrap) + } if bootstrap != nil && telegramRes.Shown { - if telegramRes.ConfigError != "" { - bootstrap.Warning("Telegram setup: failed to load config (non-blocking): %s", telegramRes.ConfigError) - } - if telegramRes.IdentityDetectError != "" { - bootstrap.Warning("Telegram setup: identity detection issue (non-blocking): %s", telegramRes.IdentityDetectError) - } - if telegramRes.TelegramMode == "personal" { - bootstrap.Info("Telegram setup: personal mode selected (no centralized pairing check)") - } else if telegramRes.Verified { + if telegramRes.Verified { bootstrap.Info("Telegram setup: verified (code=%d)", telegramRes.LastStatusCode) } else if telegramRes.SkippedVerification { bootstrap.Info("Telegram setup: verification skipped by user") @@ -251,12 +224,16 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo ensureGoSymlink(execInfo.ExecPath, bootstrap) // Migrate legacy cron entries - cronSchedule := resolveCronSchedule(wizardData) + wizardCronSchedule := "" + if wizardData != nil { + wizardCronSchedule = cronutil.TimeToSchedule(wizardData.CronTime) + } + cronSchedule := buildInstallCronSchedule(skipConfigWizard, wizardCronSchedule) logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "migrating cron entries") migrateLegacyCronEntries(ctx, baseDir, execInfo.ExecPath, bootstrap, cronSchedule) // Attempt to resolve or create a server identity for Telegram pairing - if info, err := identity.Detect(baseDir, nil); err == nil { + if info, err := identity.DetectWithContext(ctx, baseDir, nil); err == nil { if code := info.ServerID; code != "" { telegramCode = code } @@ -275,23 +252,3 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo return nil } - -// deriveRecipientFromPassphrase derives a deterministic AGE recipient from a passphrase -func deriveRecipientFromPassphrase(passphrase string) (string, error) { - return orchestrator.DeriveDeterministicRecipientFromPassphrase(passphrase) -} - -// deriveRecipientFromPrivateKey derives the recipient (public key) from an AGE private key -func deriveRecipientFromPrivateKey(privateKey string) (string, error) { - privateKey = strings.TrimSpace(privateKey) - if privateKey == "" { - return "", fmt.Errorf("private key cannot be empty") - } - - identity, err := age.ParseX25519Identity(privateKey) - if err != nil { - return "", fmt.Errorf("invalid AGE private key: %w", err) - } - - return identity.Recipient().String(), nil -} diff --git a/cmd/proxsave/main.go b/cmd/proxsave/main.go index 73d06bd6..26e05ec3 100644 --- a/cmd/proxsave/main.go +++ b/cmd/proxsave/main.go @@ -786,7 +786,7 @@ func run() int { serverIDValue := strings.TrimSpace(cfg.ServerID) serverMACValue := "" telegramServerStatus := "Telegram disabled" - if info, err := identity.Detect(cfg.BaseDir, logger); err != nil { + if info, err := identity.DetectWithContext(ctx, cfg.BaseDir, logger); err != nil { logging.Warning("WARNING: Failed to load server identity: %v", err) identityInfo = info } else { @@ -1566,10 +1566,7 @@ func printNetworkRollbackCountdown(abortInfo *orchestrator.RestoreAbortInfo) { } fmt.Printf("\r Remaining: %ds ", int(remaining.Seconds())) - select { - case <-ticker.C: - continue - } + <-ticker.C } fmt.Printf("%s===========================================%s\n", color, colorReset) @@ -1637,7 +1634,7 @@ func printFinalSummary(finalExitCode int) { fmt.Println(" --help - Show all options") fmt.Println(" --dry-run - Test without changes") fmt.Println(" --install - Re-run interactive installation/setup") - fmt.Println(" --new-install - Wipe installation directory (keep env/identity) then run installer") + fmt.Println(" --new-install - Wipe installation directory (keep build/env/identity) then run installer") fmt.Println(" --env-migration - Run installer and migrate legacy Bash backup.env to Go template") fmt.Println(" --env-migration-dry-run - Preview installer/migration without writing files") fmt.Println(" --upgrade - Update proxsave binary to latest release (also adds missing keys to backup.env)") diff --git a/cmd/proxsave/new_install.go b/cmd/proxsave/new_install.go new file mode 100644 index 00000000..7e8ae2c9 --- /dev/null +++ b/cmd/proxsave/new_install.go @@ -0,0 +1,83 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "sort" + "strings" +) + +type newInstallPlan struct { + ResolvedConfigPath string + BaseDir string + BuildSignature string + PreservedEntries []string +} + +var newInstallBuildSignature = buildSignature + +func buildNewInstallPlan(configPath string) (newInstallPlan, error) { + resolvedPath, err := resolveInstallConfigPath(configPath) + if err != nil { + return newInstallPlan{}, err + } + + buildSig := strings.TrimSpace(newInstallBuildSignature()) + if buildSig == "" { + buildSig = "n/a" + } + + return newInstallPlan{ + ResolvedConfigPath: resolvedPath, + BaseDir: deriveBaseDirFromConfig(resolvedPath), + BuildSignature: buildSig, + PreservedEntries: newInstallPreservedEntries(), + }, nil +} + +func newInstallPreservedEntries() []string { + preserved := []string{"env", "identity", "build"} + sort.Strings(preserved) + return preserved +} + +func newInstallPreserveSet() map[string]struct{} { + preserved := newInstallPreservedEntries() + result := make(map[string]struct{}, len(preserved)) + for _, entry := range preserved { + result[entry] = struct{}{} + } + return result +} + +func formatNewInstallPreservedEntries(entries []string) string { + formatted := make([]string, 0, len(entries)) + for _, entry := range entries { + trimmed := strings.TrimSpace(entry) + if trimmed == "" { + continue + } + formatted = append(formatted, trimmed+"/") + } + if len(formatted) == 0 { + return "(none)" + } + return strings.Join(formatted, " ") +} + +func confirmNewInstallCLI(ctx context.Context, reader *bufio.Reader, plan newInstallPlan) (bool, error) { + if reader == nil { + reader = bufio.NewReader(os.Stdin) + } + + fmt.Println() + fmt.Println("--- New installation reset ---") + fmt.Printf("Base directory: %s\n", plan.BaseDir) + fmt.Printf("Build signature: %s\n", plan.BuildSignature) + fmt.Printf("Preserved entries: %s\n", formatNewInstallPreservedEntries(plan.PreservedEntries)) + fmt.Println("Everything else under the base directory will be removed.") + + return promptYesNo(ctx, reader, "Continue? [y/N]: ", false) +} diff --git a/cmd/proxsave/new_install_test.go b/cmd/proxsave/new_install_test.go new file mode 100644 index 00000000..de43452e --- /dev/null +++ b/cmd/proxsave/new_install_test.go @@ -0,0 +1,425 @@ +package main + +import ( + "bufio" + "context" + "errors" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func testNewInstallPreservedEntries() []string { + return []string{"build", "env", "identity"} +} + +func registerNewInstallBuildSignature(t *testing.T, fn func() string) { + t.Helper() + original := newInstallBuildSignature + newInstallBuildSignature = fn + t.Cleanup(func() { + newInstallBuildSignature = original + }) +} + +func registerTestStdin(t *testing.T, content string) { + t.Helper() + original := os.Stdin + file, err := os.CreateTemp(t.TempDir(), "stdin-*") + if err != nil { + t.Fatalf("create temp stdin: %v", err) + } + if _, err := file.WriteString(content); err != nil { + _ = file.Close() + t.Fatalf("write temp stdin: %v", err) + } + if _, err := file.Seek(0, 0); err != nil { + _ = file.Close() + t.Fatalf("seek temp stdin: %v", err) + } + os.Stdin = file + t.Cleanup(func() { + os.Stdin = original + _ = file.Close() + }) +} + +func TestNewInstallPreservedEntries(t *testing.T) { + got := newInstallPreservedEntries() + want := testNewInstallPreservedEntries() + if !reflect.DeepEqual(got, want) { + t.Fatalf("newInstallPreservedEntries() = %#v, want %#v", got, want) + } +} + +func TestBuildNewInstallPlan(t *testing.T) { + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + + plan, err := buildNewInstallPlan(configPath) + if err != nil { + t.Fatalf("buildNewInstallPlan error: %v", err) + } + if plan.ResolvedConfigPath != configPath { + t.Fatalf("resolved config path = %q, want %q", plan.ResolvedConfigPath, configPath) + } + if plan.BaseDir != baseDir { + t.Fatalf("base dir = %q, want %q", plan.BaseDir, baseDir) + } + if strings.TrimSpace(plan.BuildSignature) == "" { + t.Fatalf("build signature should not be empty") + } + if !reflect.DeepEqual(plan.PreservedEntries, newInstallPreservedEntries()) { + t.Fatalf("preserved entries = %#v, want %#v", plan.PreservedEntries, newInstallPreservedEntries()) + } +} + +func TestBuildNewInstallPlanUsesNAWhenBuildSignatureBlank(t *testing.T) { + registerNewInstallBuildSignature(t, func() string { return " " }) + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + + plan, err := buildNewInstallPlan(configPath) + if err != nil { + t.Fatalf("buildNewInstallPlan error: %v", err) + } + if plan.BuildSignature != "n/a" { + t.Fatalf("build signature = %q, want %q", plan.BuildSignature, "n/a") + } +} + +func TestBuildNewInstallPlanRejectsEmptyConfigPath(t *testing.T) { + if _, err := buildNewInstallPlan(" "); err == nil { + t.Fatalf("expected error for empty config path") + } +} + +func TestFormatNewInstallPreservedEntries(t *testing.T) { + tests := []struct { + name string + entries []string + want string + }{ + { + name: "formats trimmed entries", + entries: []string{" build ", "env", " identity"}, + want: "build/ env/ identity/", + }, + { + name: "returns none for nil input", + entries: nil, + want: "(none)", + }, + { + name: "returns none for blank input", + entries: []string{"", " ", "\t"}, + want: "(none)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatNewInstallPreservedEntries(tt.entries); got != tt.want { + t.Fatalf("formatNewInstallPreservedEntries(%v) = %q, want %q", tt.entries, got, tt.want) + } + }) + } +} + +func TestConfirmNewInstallCLIContinue(t *testing.T) { + plan := newInstallPlan{ + BaseDir: "/opt/proxsave", + BuildSignature: "sig-123", + PreservedEntries: testNewInstallPreservedEntries(), + } + + reader := bufio.NewReader(strings.NewReader("y\n")) + var confirmed bool + var err error + output := captureStdout(t, func() { + confirmed, err = confirmNewInstallCLI(context.Background(), reader, plan) + }) + if err != nil { + t.Fatalf("confirmNewInstallCLI error: %v", err) + } + if !confirmed { + t.Fatalf("expected confirmation=true") + } + if !strings.Contains(output, "Preserved entries: build/ env/ identity/") { + t.Fatalf("expected preserved entries output, got %q", output) + } +} + +func TestConfirmNewInstallCLIContextCancelled(t *testing.T) { + plan := newInstallPlan{ + BaseDir: "/opt/proxsave", + BuildSignature: "sig-123", + PreservedEntries: testNewInstallPreservedEntries(), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := confirmNewInstallCLI(ctx, bufio.NewReader(strings.NewReader("y\n")), plan) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected errInteractiveAborted, got %v", err) + } +} + +func TestConfirmNewInstallCLIUsesStdinWhenReaderNil(t *testing.T) { + registerTestStdin(t, "y\n") + + plan := newInstallPlan{ + BaseDir: "/opt/proxsave", + BuildSignature: "sig-123", + PreservedEntries: testNewInstallPreservedEntries(), + } + + var confirmed bool + var err error + output := captureStdout(t, func() { + confirmed, err = confirmNewInstallCLI(context.Background(), nil, plan) + }) + if err != nil { + t.Fatalf("confirmNewInstallCLI error: %v", err) + } + if !confirmed { + t.Fatalf("expected confirmation=true") + } + if !strings.Contains(output, "Continue? [y/N]: ") { + t.Fatalf("expected prompt in output, got %q", output) + } +} + +func TestRunNewInstallCLIUsesCLIConfirmOnly(t *testing.T) { + originalEnsure := newInstallEnsureInteractiveStdin + originalConfirmCLI := newInstallConfirmCLI + originalConfirmTUI := newInstallConfirmTUI + originalRunInstall := newInstallRunInstall + originalRunInstallTUI := newInstallRunInstallTUI + defer func() { + newInstallEnsureInteractiveStdin = originalEnsure + newInstallConfirmCLI = originalConfirmCLI + newInstallConfirmTUI = originalConfirmTUI + newInstallRunInstall = originalRunInstall + newInstallRunInstallTUI = originalRunInstallTUI + }() + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + stalePath := filepath.Join(baseDir, "stale.txt") + if err := os.WriteFile(stalePath, []byte("stale"), 0o600); err != nil { + t.Fatalf("write stale marker: %v", err) + } + + newInstallEnsureInteractiveStdin = func() error { return nil } + + cliConfirmCalled := false + newInstallConfirmCLI = func(ctx context.Context, reader *bufio.Reader, plan newInstallPlan) (bool, error) { + cliConfirmCalled = true + if plan.BaseDir != baseDir { + t.Fatalf("plan base dir = %q, want %q", plan.BaseDir, baseDir) + } + return true, nil + } + + newInstallConfirmTUI = func(ctx context.Context, baseDirArg, buildSig string, preservedEntries []string) (bool, error) { + t.Fatalf("TUI confirmation must not be called in --cli mode") + return false, nil + } + + runInstallCalled := false + newInstallRunInstall = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + runInstallCalled = true + if cfg != configPath { + t.Fatalf("runInstall config path = %q, want %q", cfg, configPath) + } + return nil + } + newInstallRunInstallTUI = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstallTUI must not be called in --cli mode") + return nil + } + + if err := runNewInstall(context.Background(), configPath, logging.NewBootstrapLogger(), true); err != nil { + t.Fatalf("runNewInstall error: %v", err) + } + if !cliConfirmCalled { + t.Fatalf("expected CLI confirmation to be called") + } + if !runInstallCalled { + t.Fatalf("expected runInstall to be called") + } + if _, err := os.Stat(stalePath); !os.IsNotExist(err) { + t.Fatalf("expected stale marker to be removed by reset, got err=%v", err) + } +} + +func TestRunNewInstallCancelSkipsReset(t *testing.T) { + originalEnsure := newInstallEnsureInteractiveStdin + originalConfirmCLI := newInstallConfirmCLI + originalRunInstall := newInstallRunInstall + originalRunInstallTUI := newInstallRunInstallTUI + defer func() { + newInstallEnsureInteractiveStdin = originalEnsure + newInstallConfirmCLI = originalConfirmCLI + newInstallRunInstall = originalRunInstall + newInstallRunInstallTUI = originalRunInstallTUI + }() + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + markerPath := filepath.Join(baseDir, "marker.txt") + if err := os.WriteFile(markerPath, []byte("keep"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + + newInstallEnsureInteractiveStdin = func() error { return nil } + newInstallConfirmCLI = func(ctx context.Context, reader *bufio.Reader, plan newInstallPlan) (bool, error) { + return false, nil + } + newInstallRunInstall = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstall must not be called on cancel") + return nil + } + newInstallRunInstallTUI = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstallTUI must not be called on cancel") + return nil + } + + err := runNewInstall(context.Background(), configPath, logging.NewBootstrapLogger(), true) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected interactive abort, got %v", err) + } + if _, statErr := os.Stat(markerPath); statErr != nil { + t.Fatalf("expected marker to remain after cancel, got %v", statErr) + } +} + +func TestRunNewInstallTUIPassesContextToConfirm(t *testing.T) { + originalEnsure := newInstallEnsureInteractiveStdin + originalConfirmCLI := newInstallConfirmCLI + originalConfirmTUI := newInstallConfirmTUI + originalRunInstall := newInstallRunInstall + originalRunInstallTUI := newInstallRunInstallTUI + defer func() { + newInstallEnsureInteractiveStdin = originalEnsure + newInstallConfirmCLI = originalConfirmCLI + newInstallConfirmTUI = originalConfirmTUI + newInstallRunInstall = originalRunInstall + newInstallRunInstallTUI = originalRunInstallTUI + }() + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + ctx := t.Context() + + newInstallEnsureInteractiveStdin = func() error { return nil } + newInstallConfirmCLI = func(ctx context.Context, reader *bufio.Reader, plan newInstallPlan) (bool, error) { + t.Fatalf("CLI confirmation must not be called in TUI mode") + return false, nil + } + newInstallConfirmTUI = func(gotCtx context.Context, baseDirArg, buildSig string, preservedEntries []string) (bool, error) { + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + if baseDirArg != baseDir { + t.Fatalf("baseDir=%q, want %q", baseDirArg, baseDir) + } + return false, nil + } + newInstallRunInstall = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstall must not be called in TUI mode") + return nil + } + newInstallRunInstallTUI = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstallTUI must not be called when confirmation is declined") + return nil + } + + err := runNewInstall(ctx, configPath, logging.NewBootstrapLogger(), false) + if !errors.Is(err, errInteractiveAborted) { + t.Fatalf("expected interactive abort, got %v", err) + } +} + +func TestRunNewInstallTUIUsesTUIConfirmAndRunInstallTUI(t *testing.T) { + originalEnsure := newInstallEnsureInteractiveStdin + originalConfirmCLI := newInstallConfirmCLI + originalConfirmTUI := newInstallConfirmTUI + originalRunInstall := newInstallRunInstall + originalRunInstallTUI := newInstallRunInstallTUI + defer func() { + newInstallEnsureInteractiveStdin = originalEnsure + newInstallConfirmCLI = originalConfirmCLI + newInstallConfirmTUI = originalConfirmTUI + newInstallRunInstall = originalRunInstall + newInstallRunInstallTUI = originalRunInstallTUI + }() + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + stalePath := filepath.Join(baseDir, "stale.txt") + if err := os.WriteFile(stalePath, []byte("stale"), 0o600); err != nil { + t.Fatalf("write stale marker: %v", err) + } + + ctx := t.Context() + newInstallEnsureInteractiveStdin = func() error { return nil } + newInstallConfirmCLI = func(ctx context.Context, reader *bufio.Reader, plan newInstallPlan) (bool, error) { + t.Fatalf("CLI confirmation must not be called in TUI mode") + return false, nil + } + + tuiConfirmCalled := false + newInstallConfirmTUI = func(gotCtx context.Context, baseDirArg, buildSig string, preservedEntries []string) (bool, error) { + tuiConfirmCalled = true + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + if baseDirArg != baseDir { + t.Fatalf("baseDir=%q, want %q", baseDirArg, baseDir) + } + if strings.TrimSpace(buildSig) == "" { + t.Fatalf("expected non-empty build signature") + } + if !reflect.DeepEqual(preservedEntries, testNewInstallPreservedEntries()) { + t.Fatalf("preservedEntries=%#v, want %#v", preservedEntries, testNewInstallPreservedEntries()) + } + return true, nil + } + + newInstallRunInstall = func(ctx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + t.Fatalf("runInstall must not be called in TUI mode") + return nil + } + + runInstallTuiCalled := false + newInstallRunInstallTUI = func(gotCtx context.Context, cfg string, bootstrap *logging.BootstrapLogger) error { + runInstallTuiCalled = true + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + if cfg != configPath { + t.Fatalf("runInstallTUI config path = %q, want %q", cfg, configPath) + } + return nil + } + + if err := runNewInstall(ctx, configPath, logging.NewBootstrapLogger(), false); err != nil { + t.Fatalf("runNewInstall error: %v", err) + } + if !tuiConfirmCalled { + t.Fatalf("expected TUI confirmation to be called") + } + if !runInstallTuiCalled { + t.Fatalf("expected runInstallTUI to be called") + } + if _, err := os.Stat(stalePath); !os.IsNotExist(err) { + t.Fatalf("expected stale marker to be removed by reset, got err=%v", err) + } +} diff --git a/cmd/proxsave/newkey.go b/cmd/proxsave/newkey.go index 9099f652..5adbf853 100644 --- a/cmd/proxsave/newkey.go +++ b/cmd/proxsave/newkey.go @@ -75,99 +75,89 @@ func runNewKey(ctx context.Context, configPath string, logLevel types.LogLevel, } func runNewKeyTUI(ctx context.Context, configPath, baseDir string, bootstrap *logging.BootstrapLogger) (err error) { - recipientPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") sig := buildSignature() if strings.TrimSpace(sig) == "" { sig = "n/a" } - done := logging.DebugStartBootstrap(bootstrap, "newkey workflow (tui)", "recipient=%s", recipientPath) + done := logging.DebugStartBootstrap(bootstrap, "newkey workflow (tui)", "config=%s", configPath) defer func() { done(err) }() - // If a recipient already exists, ask for confirmation before overwriting - if _, err := os.Stat(recipientPath); err == nil { - logging.DebugStepBootstrap(bootstrap, "newkey workflow (tui)", "existing recipient found") - confirm, err := wizard.ConfirmRecipientOverwrite(recipientPath, configPath, sig) - if err != nil { - return err - } - if !confirm { - return wrapInstallError(errInteractiveAborted) - } - if err := orchestrator.BackupAgeRecipientFile(recipientPath); err != nil && bootstrap != nil { - bootstrap.Warning("WARNING: %v", err) - } + logging.DebugStepBootstrap(bootstrap, "newkey workflow (tui)", "running AGE setup via orchestrator") + recipientPath, err := runNewKeySetup(ctx, configPath, baseDir, logging.GetDefaultLogger(), wizard.NewAgeSetupUI(configPath, sig)) + if err != nil { + return err } - recipients := make([]string, 0, 2) - for { - logging.DebugStepBootstrap(bootstrap, "newkey workflow (tui)", "running AGE setup wizard") - ageData, err := wizard.RunAgeSetupWizard(ctx, recipientPath, configPath, sig) - if err != nil { - if errors.Is(err, wizard.ErrAgeSetupCancelled) { - return wrapInstallError(errInteractiveAborted) - } - return fmt.Errorf("AGE setup failed: %w", err) - } - - // Process the AGE data based on setup type - var recipientKey string - switch ageData.SetupType { - case "existing": - recipientKey = ageData.PublicKey - case "passphrase": - recipient, err := deriveRecipientFromPassphrase(ageData.Passphrase) - if err != nil { - return fmt.Errorf("failed to derive recipient from passphrase: %w", err) - } - recipientKey = recipient - case "privatekey": - recipient, err := deriveRecipientFromPrivateKey(ageData.PrivateKey) - if err != nil { - return fmt.Errorf("failed to derive recipient from private key: %w", err) - } - recipientKey = recipient - default: - return fmt.Errorf("unknown AGE setup type: %s", ageData.SetupType) - } + logNewKeySuccess(recipientPath, bootstrap) - if err := orchestrator.ValidateRecipientString(recipientKey); err != nil { - return fmt.Errorf("invalid recipient: %w", err) - } - recipients = append(recipients, recipientKey) + return nil +} - logging.DebugStepBootstrap(bootstrap, "newkey workflow (tui)", "recipient count=%d", len(recipients)) - addMore, err := wizard.ConfirmAddRecipient(configPath, sig, len(recipients)) - if err != nil { - return err - } - if !addMore { - break - } +func runNewKeyCLI(ctx context.Context, configPath, baseDir string, logger *logging.Logger, bootstrap *logging.BootstrapLogger) error { + recipientPath, err := runNewKeySetup(ctx, configPath, baseDir, logger, nil) + if err != nil { + return err } - recipients = orchestrator.DedupeRecipientStrings(recipients) - if len(recipients) == 0 { - return fmt.Errorf("no AGE recipients provided") - } - logging.DebugStepBootstrap(bootstrap, "newkey workflow (tui)", "saving recipients") - if err := orchestrator.WriteRecipientFile(recipientPath, recipients); err != nil { - return fmt.Errorf("failed to save AGE recipients: %w", err) + logNewKeySuccess(recipientPath, bootstrap) + + return nil +} + +func logNewKeySuccess(recipientPath string, bootstrap *logging.BootstrapLogger) { + if bootstrap != nil { + bootstrap.Info("✓ New AGE recipient(s) generated and saved to %s", recipientPath) + bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") + return } - bootstrap.Info("✓ New AGE recipient(s) generated and saved to %s", recipientPath) - bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") + fmt.Printf("✓ New AGE recipient(s) generated and saved to %s\n", recipientPath) + fmt.Println("IMPORTANT: Keep your passphrase/private key offline and secure!") +} - return nil +func modeLabel(useCLI bool) string { + if useCLI { + return "cli" + } + return "tui" } -func runNewKeyCLI(ctx context.Context, configPath, baseDir string, logger *logging.Logger, bootstrap *logging.BootstrapLogger) error { - recipientPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") +func loadNewKeyConfig(configPath, baseDir string) (*config.Config, string, error) { + defaultRecipientPath := filepath.Join(baseDir, "identity", "age", "recipient.txt") cfg := &config.Config{ BaseDir: baseDir, ConfigPath: configPath, EncryptArchive: true, - AgeRecipientFile: recipientPath, + AgeRecipientFile: defaultRecipientPath, + } + + if _, err := os.Stat(configPath); err == nil { + loaded, err := config.LoadConfig(configPath) + if err != nil { + return nil, "", fmt.Errorf("load configuration for newkey: %w", err) + } + cfg = loaded + cfg.BaseDir = baseDir + cfg.ConfigPath = configPath + cfg.EncryptArchive = true + } else if !errors.Is(err, os.ErrNotExist) { + return nil, "", fmt.Errorf("inspect configuration for newkey: %w", err) + } + + recipientPath := strings.TrimSpace(cfg.AgeRecipientFile) + if recipientPath == "" { + recipientPath = defaultRecipientPath + } + cfg.AgeRecipientFile = recipientPath + + return cfg, recipientPath, nil +} + +func runNewKeySetup(ctx context.Context, configPath, baseDir string, logger *logging.Logger, ui orchestrator.AgeSetupUI) (string, error) { + cfg, recipientPath, err := loadNewKeyConfig(configPath, baseDir) + if err != nil { + return "", err } if logger == nil { @@ -179,22 +169,17 @@ func runNewKeyCLI(ctx context.Context, configPath, baseDir string, logger *loggi orch.SetConfig(cfg) orch.SetForceNewAgeRecipient(true) - if err := orch.EnsureAgeRecipientsReady(ctx); err != nil { + if ui != nil { + err = orch.EnsureAgeRecipientsReadyWithUI(ctx, ui) + } else { + err = orch.EnsureAgeRecipientsReady(ctx) + } + if err != nil { if errors.Is(err, orchestrator.ErrAgeRecipientSetupAborted) { - return wrapInstallError(errInteractiveAborted) + return "", wrapInstallError(errInteractiveAborted) } - return fmt.Errorf("AGE setup failed: %w", err) + return "", fmt.Errorf("AGE setup failed: %w", err) } - bootstrap.Info("✓ New AGE recipient(s) generated and saved to %s", recipientPath) - bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") - - return nil -} - -func modeLabel(useCLI bool) string { - if useCLI { - return "cli" - } - return "tui" + return recipientPath, nil } diff --git a/cmd/proxsave/newkey_test.go b/cmd/proxsave/newkey_test.go new file mode 100644 index 00000000..91bf11ff --- /dev/null +++ b/cmd/proxsave/newkey_test.go @@ -0,0 +1,151 @@ +package main + +import ( + "bytes" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func captureNewKeyStdout(t *testing.T, fn func()) string { + t.Helper() + orig := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + os.Stdout = w + + var buf bytes.Buffer + done := make(chan struct{}) + go func() { + _, _ = io.Copy(&buf, r) + close(done) + }() + + cleaned := false + cleanup := func() { + if cleaned { + return + } + cleaned = true + os.Stdout = orig + _ = w.Close() + <-done + _ = r.Close() + } + t.Cleanup(cleanup) + + fn() + + cleanup() + return buf.String() +} + +func TestLogNewKeySuccessWithoutBootstrapFallsBackToStdout(t *testing.T) { + recipientPath := filepath.Join("/tmp", "identity", "age", "recipient.txt") + + output := captureNewKeyStdout(t, func() { + logNewKeySuccess(recipientPath, nil) + }) + + if !strings.Contains(output, "✓ New AGE recipient(s) generated and saved to "+recipientPath) { + t.Fatalf("expected recipient success message, got %q", output) + } + if !strings.Contains(output, "IMPORTANT: Keep your passphrase/private key offline and secure!") { + t.Fatalf("expected security reminder, got %q", output) + } +} + +func TestLogNewKeySuccessWithBootstrapUsesBootstrapLogger(t *testing.T) { + recipientPath := filepath.Join("/tmp", "identity", "age", "recipient.txt") + bootstrap := logging.NewBootstrapLogger() + bootstrap.SetLevel(types.LogLevelInfo) + + var mirrorBuf bytes.Buffer + mirror := logging.New(types.LogLevelDebug, false) + mirror.SetOutput(&mirrorBuf) + bootstrap.SetMirrorLogger(mirror) + + output := captureNewKeyStdout(t, func() { + logNewKeySuccess(recipientPath, bootstrap) + }) + + if !strings.Contains(output, "✓ New AGE recipient(s) generated and saved to "+recipientPath) { + t.Fatalf("expected bootstrap stdout success message, got %q", output) + } + if !strings.Contains(output, "IMPORTANT: Keep your passphrase/private key offline and secure!") { + t.Fatalf("expected bootstrap stdout security reminder, got %q", output) + } + + mirrorOutput := mirrorBuf.String() + if !strings.Contains(mirrorOutput, "New AGE recipient(s) generated and saved to "+recipientPath) { + t.Fatalf("expected mirror logger success message, got %q", mirrorOutput) + } + if !strings.Contains(mirrorOutput, "IMPORTANT: Keep your passphrase/private key offline and secure!") { + t.Fatalf("expected mirror logger security reminder, got %q", mirrorOutput) + } +} + +func TestLoadNewKeyConfigUsesConfiguredRecipientFile(t *testing.T) { + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(configPath), err) + } + + customPath := filepath.Join(baseDir, "custom", "recipient.txt") + content := "BASE_DIR=" + baseDir + "\nENCRYPT_ARCHIVE=false\nAGE_RECIPIENT_FILE=" + customPath + "\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", configPath, err) + } + + cfg, recipientPath, err := loadNewKeyConfig(configPath, baseDir) + if err != nil { + t.Fatalf("loadNewKeyConfig error: %v", err) + } + if recipientPath != customPath { + t.Fatalf("recipientPath=%q; want %q", recipientPath, customPath) + } + if cfg == nil { + t.Fatalf("expected config") + } + if cfg.BaseDir != baseDir { + t.Fatalf("BaseDir=%q; want %q", cfg.BaseDir, baseDir) + } + if cfg.ConfigPath != configPath { + t.Fatalf("ConfigPath=%q; want %q", cfg.ConfigPath, configPath) + } + if cfg.AgeRecipientFile != customPath { + t.Fatalf("AgeRecipientFile=%q; want %q", cfg.AgeRecipientFile, customPath) + } + if !cfg.EncryptArchive { + t.Fatalf("EncryptArchive=false; want true") + } +} + +func TestLoadNewKeyConfigFailsForInvalidExistingConfig(t *testing.T) { + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "env", "backup.env") + if err := os.MkdirAll(filepath.Dir(configPath), 0o700); err != nil { + t.Fatalf("MkdirAll(%s): %v", filepath.Dir(configPath), err) + } + + content := "BASE_DIR=" + baseDir + "\nCUSTOM_BACKUP_PATHS=\"\nunterminated\n" + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", configPath, err) + } + + _, _, err := loadNewKeyConfig(configPath, baseDir) + if err == nil { + t.Fatalf("expected loadNewKeyConfig to fail for invalid config") + } + if !strings.Contains(err.Error(), "load configuration for newkey") { + t.Fatalf("expected wrapped configuration load error, got %v", err) + } +} diff --git a/cmd/proxsave/prompts.go b/cmd/proxsave/prompts.go index 15b906af..fe9ca654 100644 --- a/cmd/proxsave/prompts.go +++ b/cmd/proxsave/prompts.go @@ -49,18 +49,25 @@ func promptYesNo(ctx context.Context, reader *bufio.Reader, question string, def func promptNonEmpty(ctx context.Context, reader *bufio.Reader, question string) (string, error) { for { - if err := ctx.Err(); err != nil { - return "", errInteractiveAborted - } - fmt.Print(question) - resp, err := input.ReadLineWithContext(ctx, reader) + resp, err := promptOptional(ctx, reader, question) if err != nil { return "", err } - resp = strings.TrimSpace(resp) if resp != "" { return resp, nil } fmt.Println("Value cannot be empty.") } } + +func promptOptional(ctx context.Context, reader *bufio.Reader, question string) (string, error) { + if err := ctx.Err(); err != nil { + return "", errInteractiveAborted + } + fmt.Print(question) + resp, err := input.ReadLineWithContext(ctx, reader) + if err != nil { + return "", err + } + return strings.TrimSpace(resp), nil +} diff --git a/cmd/proxsave/runtime_helpers.go b/cmd/proxsave/runtime_helpers.go index b95d90eb..0c8e8794 100644 --- a/cmd/proxsave/runtime_helpers.go +++ b/cmd/proxsave/runtime_helpers.go @@ -239,8 +239,13 @@ func resolveHostname() string { } func validateFutureFeatures(cfg *config.Config) error { - if cfg.SecondaryEnabled && cfg.SecondaryPath == "" { - return fmt.Errorf("secondary backup enabled but SECONDARY_PATH is empty") + if cfg.SecondaryEnabled { + if err := config.ValidateRequiredSecondaryPath(cfg.SecondaryPath); err != nil { + return err + } + if err := config.ValidateOptionalSecondaryLogPath(cfg.SecondaryLogPath); err != nil { + return err + } } if cfg.CloudEnabled && cfg.CloudRemote == "" { logging.Warning("Cloud backup enabled but CLOUD_REMOTE is empty – disabling cloud storage for this run") @@ -339,6 +344,9 @@ func fetchStorageStats(ctx context.Context, backend storage.Storage, logger *log func formatStorageInitSummary(name string, cfg *config.Config, location storage.BackupLocation, stats *storage.StorageStats, backups []*types.BackupMetadata) string { retentionConfig := storage.NewRetentionConfigFromConfig(cfg, location) + if retentionConfig.Policy == "gfs" { + retentionConfig = storage.EffectiveGFSRetentionConfig(retentionConfig) + } if stats == nil { reason := "unable to gather stats" diff --git a/cmd/proxsave/runtime_helpers_more_test.go b/cmd/proxsave/runtime_helpers_more_test.go index 2573e8ed..f71a813b 100644 --- a/cmd/proxsave/runtime_helpers_more_test.go +++ b/cmd/proxsave/runtime_helpers_more_test.go @@ -278,15 +278,15 @@ func TestFormatStorageInitSummary(t *testing.T) { cfgGFS := &config.Config{ RetentionPolicy: "gfs", - RetentionDaily: 1, - RetentionWeekly: 0, + RetentionDaily: 0, + RetentionWeekly: 1, RetentionMonthly: 0, RetentionYearly: -1, } now := time.Now() backups := []*types.BackupMetadata{ {Timestamp: now.Add(-1 * time.Hour)}, - {Timestamp: now.Add(-2 * time.Hour)}, + {Timestamp: now.Add(-8 * 24 * time.Hour)}, } stats := &storage.StorageStats{TotalBackups: 2} @@ -294,6 +294,12 @@ func TestFormatStorageInitSummary(t *testing.T) { if !bytes.Contains([]byte(summary), []byte("Kept (est.):")) { t.Fatalf("expected GFS summary to include retention estimates, got: %s", summary) } + if !bytes.Contains([]byte(summary), []byte("Daily: 1/1")) { + t.Fatalf("expected GFS summary to normalize daily tier, got: %s", summary) + } + if !bytes.Contains([]byte(summary), []byte("Weekly: 1/1")) { + t.Fatalf("expected GFS summary to keep one weekly backup, got: %s", summary) + } } func TestCleanupAfterRun(t *testing.T) { diff --git a/cmd/proxsave/schedule_helpers.go b/cmd/proxsave/schedule_helpers.go index 4b7da8a3..d8d5b279 100644 --- a/cmd/proxsave/schedule_helpers.go +++ b/cmd/proxsave/schedule_helpers.go @@ -3,49 +3,35 @@ package main import ( "fmt" "os" - "strconv" "strings" - "github.com/tis24dev/proxsave/internal/tui/wizard" + cronutil "github.com/tis24dev/proxsave/internal/cron" ) -// resolveCronSchedule returns a cron schedule string (e.g. "0 2 * * *") derived from -// wizard data or environment variables, falling back to 02:00 if unavailable. -func resolveCronSchedule(data *wizard.InstallWizardData) string { - // Try wizard data first - if data != nil { - cron := strings.TrimSpace(data.CronTime) - if cron != "" { - if schedule := cronToSchedule(cron); schedule != "" { - return schedule - } - } - } - - // Environment overrides +// resolveCronScheduleFromEnv returns a cron schedule string derived from the +// legacy environment overrides, falling back to 02:00 if unavailable. +func resolveCronScheduleFromEnv() string { if s := strings.TrimSpace(os.Getenv("CRON_SCHEDULE")); s != "" { return s } + hour := strings.TrimSpace(os.Getenv("CRON_HOUR")) min := strings.TrimSpace(os.Getenv("CRON_MINUTE")) if hour != "" && min != "" { return fmt.Sprintf("%s %s * * *", min, hour) } - // Default: 02:00 - return "0 2 * * *" + return cronutil.TimeToSchedule(cronutil.DefaultTime) } -// cronToSchedule converts HH:MM into "MM HH * * *". -func cronToSchedule(cron string) string { - parts := strings.Split(cron, ":") - if len(parts) != 2 { - return "" - } - hour, errH := strconv.Atoi(parts[0]) - min, errM := strconv.Atoi(parts[1]) - if errH != nil || errM != nil || hour < 0 || hour > 23 || min < 0 || min > 59 { - return "" +// buildInstallCronSchedule keeps wizard-driven installs independent from +// env-based overrides while preserving the existing skip-wizard behavior. +func buildInstallCronSchedule(skipConfigWizard bool, cronSchedule string) string { + if !skipConfigWizard { + if schedule := strings.TrimSpace(cronSchedule); schedule != "" { + return schedule + } + return cronutil.TimeToSchedule(cronutil.DefaultTime) } - return fmt.Sprintf("%02d %02d * * *", min, hour) + return resolveCronScheduleFromEnv() } diff --git a/cmd/proxsave/schedule_helpers_test.go b/cmd/proxsave/schedule_helpers_test.go index 1e906af3..fd268f54 100644 --- a/cmd/proxsave/schedule_helpers_test.go +++ b/cmd/proxsave/schedule_helpers_test.go @@ -3,45 +3,14 @@ package main import ( "testing" - "github.com/tis24dev/proxsave/internal/tui/wizard" + cronutil "github.com/tis24dev/proxsave/internal/cron" ) -func TestCronToSchedule(t *testing.T) { - tests := []struct { - name string - in string - want string - }{ - {"valid with padding", "2:5", "05 02 * * *"}, - {"valid already padded", "02:05", "05 02 * * *"}, - {"invalid format", "0205", ""}, - {"invalid hour", "24:00", ""}, - {"invalid minute", "00:60", ""}, - {"non numeric", "aa:bb", ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := cronToSchedule(tt.in); got != tt.want { - t.Fatalf("cronToSchedule(%q) = %q, want %q", tt.in, got, tt.want) - } - }) - } -} - -func TestResolveCronSchedule(t *testing.T) { - t.Run("wizard data takes precedence", func(t *testing.T) { - t.Setenv("CRON_SCHEDULE", "0 4 * * *") - data := &wizard.InstallWizardData{CronTime: "03:15"} - if got := resolveCronSchedule(data); got != "15 03 * * *" { - t.Fatalf("resolveCronSchedule(wizard) = %q, want %q", got, "15 03 * * *") - } - }) - +func TestResolveCronScheduleFromEnv(t *testing.T) { t.Run("env CRON_SCHEDULE overrides", func(t *testing.T) { t.Setenv("CRON_SCHEDULE", "5 1 * * *") - if got := resolveCronSchedule(nil); got != "5 1 * * *" { - t.Fatalf("resolveCronSchedule(env) = %q, want %q", got, "5 1 * * *") + if got := resolveCronScheduleFromEnv(); got != "5 1 * * *" { + t.Fatalf("resolveCronScheduleFromEnv() = %q, want %q", got, "5 1 * * *") } }) @@ -49,8 +18,8 @@ func TestResolveCronSchedule(t *testing.T) { t.Setenv("CRON_SCHEDULE", "") t.Setenv("CRON_HOUR", "22") t.Setenv("CRON_MINUTE", "10") - if got := resolveCronSchedule(nil); got != "10 22 * * *" { - t.Fatalf("resolveCronSchedule(hour/minute) = %q, want %q", got, "10 22 * * *") + if got := resolveCronScheduleFromEnv(); got != "10 22 * * *" { + t.Fatalf("resolveCronScheduleFromEnv() = %q, want %q", got, "10 22 * * *") } }) @@ -58,8 +27,31 @@ func TestResolveCronSchedule(t *testing.T) { t.Setenv("CRON_SCHEDULE", "") t.Setenv("CRON_HOUR", "") t.Setenv("CRON_MINUTE", "") - if got := resolveCronSchedule(nil); got != "0 2 * * *" { - t.Fatalf("resolveCronSchedule(default) = %q, want %q", got, "0 2 * * *") + if got := resolveCronScheduleFromEnv(); got != cronutil.TimeToSchedule(cronutil.DefaultTime) { + t.Fatalf("resolveCronScheduleFromEnv() = %q, want %q", got, cronutil.TimeToSchedule(cronutil.DefaultTime)) + } + }) +} + +func TestBuildInstallCronSchedule(t *testing.T) { + t.Run("wizard schedule takes precedence over env", func(t *testing.T) { + t.Setenv("CRON_SCHEDULE", "5 1 * * *") + if got := buildInstallCronSchedule(false, "15 03 * * *"); got != "15 03 * * *" { + t.Fatalf("buildInstallCronSchedule(false, schedule) = %q, want %q", got, "15 03 * * *") + } + }) + + t.Run("wizard run with empty schedule falls back to default time not env", func(t *testing.T) { + t.Setenv("CRON_SCHEDULE", "5 1 * * *") + if got := buildInstallCronSchedule(false, ""); got != cronutil.TimeToSchedule(cronutil.DefaultTime) { + t.Fatalf("buildInstallCronSchedule(false, \"\") = %q, want %q", got, cronutil.TimeToSchedule(cronutil.DefaultTime)) + } + }) + + t.Run("skip wizard uses env fallback", func(t *testing.T) { + t.Setenv("CRON_SCHEDULE", "5 1 * * *") + if got := buildInstallCronSchedule(true, "15 03 * * *"); got != "5 1 * * *" { + t.Fatalf("buildInstallCronSchedule(true, schedule) = %q, want %q", got, "5 1 * * *") } }) } diff --git a/cmd/proxsave/telegram_setup_cli.go b/cmd/proxsave/telegram_setup_cli.go new file mode 100644 index 00000000..f6c7c0c6 --- /dev/null +++ b/cmd/proxsave/telegram_setup_cli.go @@ -0,0 +1,247 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/notify" + "github.com/tis24dev/proxsave/internal/orchestrator" +) + +var ( + telegramSetupBuildBootstrap = orchestrator.BuildTelegramSetupBootstrap + telegramSetupCheckRegistration = notify.CheckTelegramRegistration + telegramSetupPromptYesNo = promptYesNo +) + +const maxTelegramSetupVerificationAttempts = 10 +const maxTelegramSetupStatusMessageLen = 200 + +func sanitizeTelegramSetupStatusMessage(raw string) string { + msg := strings.TrimSpace(raw) + if msg == "" { + return "" + } + + sanitized := stripTelegramTerminalSequences(msg) + sanitized = truncateTelegramSetupStatusMessage(sanitized, maxTelegramSetupStatusMessageLen) + if sanitized != "" { + return sanitized + } + + quoted := strconv.QuoteToASCII(msg) + quoted = strings.TrimPrefix(quoted, `"`) + quoted = strings.TrimSuffix(quoted, `"`) + return truncateTelegramSetupStatusMessage(quoted, maxTelegramSetupStatusMessageLen) +} + +func truncateTelegramSetupStatusMessage(msg string, max int) string { + msg = strings.TrimSpace(msg) + if msg == "" || max <= 0 { + return "" + } + runes := []rune(msg) + if len(runes) <= max { + return msg + } + return string(runes[:max]) + "...(truncated)" +} + +func stripTelegramTerminalSequences(msg string) string { + var b strings.Builder + b.Grow(len(msg)) + pendingSpace := false + + for i := 0; i < len(msg); { + switch msg[i] { + case 0x1b: + i = skipTelegramEscapeSequence(msg, i) + pendingSpace = true + continue + case 0x9b: + i = skipTelegramCSI(msg, i+1) + pendingSpace = true + continue + } + + r, size := utf8.DecodeRuneInString(msg[i:]) + if r == utf8.RuneError && size == 1 { + i++ + continue + } + if unicode.IsSpace(r) || unicode.IsControl(r) { + pendingSpace = true + i += size + continue + } + if !unicode.IsPrint(r) { + i += size + continue + } + if pendingSpace && b.Len() > 0 { + b.WriteByte(' ') + } + pendingSpace = false + b.WriteRune(r) + i += size + } + + return strings.TrimSpace(b.String()) +} + +func skipTelegramEscapeSequence(msg string, i int) int { + if i >= len(msg) || msg[i] != 0x1b { + return i + 1 + } + i++ + if i >= len(msg) { + return i + } + switch msg[i] { + case '[': + return skipTelegramCSI(msg, i+1) + case ']': + return skipTelegramOSC(msg, i+1) + case 'P', 'X', '^', '_': + return skipTelegramST(msg, i+1) + default: + return i + 1 + } +} + +func skipTelegramCSI(msg string, i int) int { + for i < len(msg) { + b := msg[i] + i++ + if b >= 0x40 && b <= 0x7e { + return i + } + } + return i +} + +func skipTelegramOSC(msg string, i int) int { + for i < len(msg) { + switch msg[i] { + case 0x07: + return i + 1 + case 0x1b: + if i+1 < len(msg) && msg[i+1] == '\\' { + return i + 2 + } + } + i++ + } + return i +} + +func skipTelegramST(msg string, i int) int { + for i < len(msg) { + if msg[i] == 0x1b && i+1 < len(msg) && msg[i+1] == '\\' { + return i + 2 + } + i++ + } + return i +} + +func logTelegramSetupBootstrapOutcome(bootstrap *logging.BootstrapLogger, state orchestrator.TelegramSetupBootstrap) { + switch state.Eligibility { + case orchestrator.TelegramSetupSkipConfigError: + if strings.TrimSpace(state.ConfigError) != "" { + logBootstrapWarning(bootstrap, "Telegram setup: unable to load config (skipping): %s", state.ConfigError) + } + case orchestrator.TelegramSetupSkipPersonalMode: + logBootstrapInfo(bootstrap, "Telegram setup: personal mode selected (no centralized pairing check)") + case orchestrator.TelegramSetupSkipIdentityUnavailable: + if strings.TrimSpace(state.IdentityDetectError) != "" { + logBootstrapWarning(bootstrap, "Telegram setup: identity detection failed (non-blocking): %s", state.IdentityDetectError) + return + } + logBootstrapWarning(bootstrap, "Telegram setup: server ID unavailable; skipping") + } +} + +func runTelegramSetupCLI(ctx context.Context, reader *bufio.Reader, baseDir, configPath string, bootstrap *logging.BootstrapLogger) error { + state, err := telegramSetupBuildBootstrap(configPath, baseDir) + if err != nil { + logBootstrapWarning(bootstrap, "Telegram setup bootstrap failed (non-blocking): %v", err) + return nil + } + + logTelegramSetupBootstrapOutcome(bootstrap, state) + if state.Eligibility != orchestrator.TelegramSetupEligibleCentralized { + return nil + } + + fmt.Println("\n--- Telegram setup (optional) ---") + fmt.Println("You enabled Telegram notifications (centralized bot).") + fmt.Printf("Server ID: %s\n", state.ServerID) + if strings.TrimSpace(state.IdentityFile) != "" { + fmt.Printf("Identity file: %s\n", strings.TrimSpace(state.IdentityFile)) + } + fmt.Println() + fmt.Println("1) Open Telegram and start @ProxmoxAN_bot") + fmt.Println("2) Send the Server ID above (digits only)") + fmt.Println("3) Verify pairing (recommended)") + fmt.Println() + + check, err := telegramSetupPromptYesNo(ctx, reader, "Check Telegram pairing now? [Y/n]: ", true) + if err != nil { + return wrapInstallError(err) + } + if !check { + fmt.Println("Skipped verification. You can verify later by running proxsave.") + logBootstrapInfo(bootstrap, "Telegram setup: verification skipped by user") + return nil + } + + attempts := 0 + for { + attempts++ + status := telegramSetupCheckRegistration(ctx, state.ServerAPIHost, state.ServerID, nil) + if status.Code == 200 && status.Error == nil { + fmt.Println("✓ Telegram linked successfully.") + logBootstrapInfo(bootstrap, "Telegram setup: verified (attempts=%d)", attempts) + return nil + } + + msg := sanitizeTelegramSetupStatusMessage(status.Message) + if msg == "" { + msg = "Registration not active yet" + } + fmt.Printf("Telegram: %s\n", msg) + switch status.Code { + case 403, 409: + fmt.Println("Hint: Start the bot, send the Server ID, then retry.") + case 422: + fmt.Println("Hint: The Server ID appears invalid. If this persists, re-run the installer.") + default: + if status.Error != nil { + fmt.Printf("Hint: Check failed: %v\n", status.Error) + } + } + + if attempts >= maxTelegramSetupVerificationAttempts { + fmt.Println("Maximum verification attempts reached. You can retry later by running proxsave.") + logBootstrapInfo(bootstrap, "Telegram setup: not verified (attempts=%d last=%d %s)", attempts, status.Code, msg) + return nil + } + + retry, err := telegramSetupPromptYesNo(ctx, reader, "Check again? [y/N]: ", false) + if err != nil { + return wrapInstallError(err) + } + if !retry { + fmt.Println("Verification not completed. You can retry later by running proxsave.") + logBootstrapInfo(bootstrap, "Telegram setup: not verified (attempts=%d last=%d %s)", attempts, status.Code, msg) + return nil + } + } +} diff --git a/cmd/proxsave/telegram_setup_cli_test.go b/cmd/proxsave/telegram_setup_cli_test.go new file mode 100644 index 00000000..3ab0c7a3 --- /dev/null +++ b/cmd/proxsave/telegram_setup_cli_test.go @@ -0,0 +1,318 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "errors" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/notify" + "github.com/tis24dev/proxsave/internal/orchestrator" + "github.com/tis24dev/proxsave/internal/types" +) + +func stubTelegramSetupCLIDeps(t *testing.T) { + t.Helper() + + origBuildBootstrap := telegramSetupBuildBootstrap + origCheckRegistration := telegramSetupCheckRegistration + origPromptYesNo := telegramSetupPromptYesNo + + t.Cleanup(func() { + telegramSetupBuildBootstrap = origBuildBootstrap + telegramSetupCheckRegistration = origCheckRegistration + telegramSetupPromptYesNo = origPromptYesNo + }) +} + +func TestRunTelegramSetupCLI_SkipOnConfigError(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipConfigError, + ConfigError: "parse failed", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + t.Fatalf("prompt should not run for config skip") + return false, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + t.Fatalf("registration check should not run for config skip") + return notify.TelegramRegistrationStatus{} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } +} + +func TestRunTelegramSetupCLI_SkipOnPersonalMode(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipPersonalMode, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "personal", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + t.Fatalf("prompt should not run for personal mode") + return false, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + t.Fatalf("registration check should not run for personal mode") + return notify.TelegramRegistrationStatus{} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } +} + +func TestRunTelegramSetupCLI_SkipOnMissingIdentity(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipIdentityUnavailable, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + IdentityDetectError: "detect failed", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + t.Fatalf("prompt should not run when identity is unavailable") + return false, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + t.Fatalf("registration check should not run when identity is unavailable") + return notify.TelegramRegistrationStatus{} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } +} + +func TestRunTelegramSetupCLI_DeclineVerification(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupEligibleCentralized, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + ServerID: "123456789", + IdentityFile: "/tmp/.server_identity", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + if !strings.Contains(question, "Check Telegram pairing now?") { + t.Fatalf("unexpected question: %s", question) + } + return false, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + t.Fatalf("registration check should not run when user declines") + return notify.TelegramRegistrationStatus{} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } +} + +func TestRunTelegramSetupCLI_VerifiesSuccessfully(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + var promptCalls int + var checkCalls int + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupEligibleCentralized, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + ServerID: "123456789", + IdentityFile: "/tmp/.server_identity", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + promptCalls++ + if promptCalls != 1 { + t.Fatalf("unexpected prompt call count: %d", promptCalls) + } + return true, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + checkCalls++ + if serverAPIHost != "https://api.example.test" { + t.Fatalf("serverAPIHost=%q, want https://api.example.test", serverAPIHost) + } + if serverID != "123456789" { + t.Fatalf("serverID=%q, want 123456789", serverID) + } + return notify.TelegramRegistrationStatus{Code: 200, Message: "ok"} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } + if promptCalls != 1 { + t.Fatalf("promptCalls=%d, want 1", promptCalls) + } + if checkCalls != 1 { + t.Fatalf("checkCalls=%d, want 1", checkCalls) + } +} + +func TestRunTelegramSetupCLI_StopsAfterMaxVerificationAttempts(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + var promptCalls int + var checkCalls int + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupEligibleCentralized, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + ServerID: "123456789", + IdentityFile: "/tmp/.server_identity", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + promptCalls++ + return true, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + checkCalls++ + return notify.TelegramRegistrationStatus{ + Code: 409, + Message: "not linked yet", + } + } + + bootstrap := logging.NewBootstrapLogger() + var mirrorBuf bytes.Buffer + mirror := logging.New(types.LogLevelDebug, false) + mirror.SetOutput(&mirrorBuf) + bootstrap.SetMirrorLogger(mirror) + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", bootstrap); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } + if checkCalls != maxTelegramSetupVerificationAttempts { + t.Fatalf("checkCalls=%d, want %d", checkCalls, maxTelegramSetupVerificationAttempts) + } + if promptCalls != maxTelegramSetupVerificationAttempts { + t.Fatalf("promptCalls=%d, want %d", promptCalls, maxTelegramSetupVerificationAttempts) + } + if !strings.Contains(mirrorBuf.String(), "Telegram setup: not verified (attempts=10 last=409 not linked yet)") { + t.Fatalf("expected max-attempt failure log, got %q", mirrorBuf.String()) + } +} + +func TestRunTelegramSetupCLI_BootstrapErrorNonBlocking(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{}, errors.New("boom") + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + t.Fatalf("prompt should not run on bootstrap error") + return false, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + t.Fatalf("registration check should not run on bootstrap error") + return notify.TelegramRegistrationStatus{} + } + + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } +} + +func TestSanitizeTelegramSetupStatusMessage_StripsTerminalEscapes(t *testing.T) { + raw := " \x1b[31mneeds\tpairing\r\nnow\x1b[0m\x07 " + + got := sanitizeTelegramSetupStatusMessage(raw) + + if got != "needs pairing now" { + t.Fatalf("sanitizeTelegramSetupStatusMessage(%q) = %q, want %q", raw, got, "needs pairing now") + } + if strings.Contains(got, "\x1b") { + t.Fatalf("sanitized message should not contain escape characters: %q", got) + } +} + +func TestSanitizeTelegramSetupStatusMessage_FallsBackToQuotedSafeText(t *testing.T) { + raw := strings.Repeat("\x1b", maxTelegramSetupStatusMessageLen+5) + + got := sanitizeTelegramSetupStatusMessage(raw) + + if got == "" { + t.Fatal("expected fallback message") + } + if strings.Contains(got, "\x1b") { + t.Fatalf("fallback should not contain raw escape characters: %q", got) + } + if !strings.Contains(got, `\x1b`) { + t.Fatalf("fallback should retain a safe escaped representation, got %q", got) + } + if !strings.Contains(got, "...(truncated)") { + t.Fatalf("expected truncated fallback output, got %q", got) + } +} + +func TestRunTelegramSetupCLI_SanitizesRegistrationStatusOutput(t *testing.T) { + stubTelegramSetupCLIDeps(t) + + promptCalls := 0 + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupEligibleCentralized, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + ServerID: "123456789", + }, nil + } + telegramSetupPromptYesNo = func(ctx context.Context, reader *bufio.Reader, question string, defaultYes bool) (bool, error) { + promptCalls++ + return promptCalls == 1, nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + return notify.TelegramRegistrationStatus{ + Code: 500, + Message: "\x1b[31mneeds\tpairing\r\nnow\x1b[0m\x07", + Error: errors.New("unexpected status 500"), + } + } + + output := captureStdout(t, func() { + if err := runTelegramSetupCLI(context.Background(), bufio.NewReader(strings.NewReader("")), t.TempDir(), "/fake/backup.env", logging.NewBootstrapLogger()); err != nil { + t.Fatalf("runTelegramSetupCLI error: %v", err) + } + }) + + if !strings.Contains(output, "Telegram: needs pairing now") { + t.Fatalf("expected sanitized Telegram status in output, got %q", output) + } + if strings.Contains(output, "\x1b") { + t.Fatalf("output should not contain raw escape sequences, got %q", output) + } +} diff --git a/cmd/proxsave/upgrade.go b/cmd/proxsave/upgrade.go index 9f4ff1f5..0aac1efb 100644 --- a/cmd/proxsave/upgrade.go +++ b/cmd/proxsave/upgrade.go @@ -177,12 +177,12 @@ func runUpgrade(ctx context.Context, args *cli.Args, bootstrap *logging.Bootstra cleanupLegacyBashSymlinks(baseDir, bootstrap) ensureGoSymlink(execPath, bootstrap) - cronSchedule := resolveCronSchedule(nil) + cronSchedule := resolveCronScheduleFromEnv() logging.DebugStepBootstrap(bootstrap, "upgrade workflow", "migrating cron entries") migrateLegacyCronEntries(ctx, baseDir, execPath, bootstrap, cronSchedule) telegramCode := "" - if info, err := identity.Detect(baseDir, nil); err == nil { + if info, err := identity.DetectWithContext(ctx, baseDir, nil); err == nil { if code := strings.TrimSpace(info.ServerID); code != "" { telegramCode = code } @@ -631,7 +631,7 @@ func printUpgradeFooter(upgradeErr error, version, configPath, baseDir, telegram fmt.Println(" proxsave (alias: proxmox-backup) - Start backup") fmt.Println(" --upgrade - Update proxsave binary to latest release (also adds missing keys to backup.env)") fmt.Println(" --install - Re-run interactive installation/setup") - fmt.Println(" --new-install - Wipe installation directory (keep env/identity) then run installer") + fmt.Println(" --new-install - Wipe installation directory (keep build/env/identity) then run installer") fmt.Println(" --upgrade-config - Upgrade configuration file using the embedded template (run after installing a new binary)") fmt.Println() diff --git a/docs/BACKUP_ENV_MAPPING.md b/docs/BACKUP_ENV_MAPPING.md index e10ec583..001b21d6 100644 --- a/docs/BACKUP_ENV_MAPPING.md +++ b/docs/BACKUP_ENV_MAPPING.md @@ -116,11 +116,15 @@ ENABLE_SECONDARY_BACKUP = RENAMED(SECONDARY_ENABLED) ✅ FULL_SECURITY_CHECK = RENAMED(SECURITY_CHECK_ENABLED) ✅ LOCAL_BACKUP_PATH = RENAMED(BACKUP_PATH) ✅ LOCAL_LOG_PATH = RENAMED(LOG_PATH) ✅ +EMAIL_ENABLE = RENAMED(EMAIL_ENABLED) ✅ +GOTIFY_ENABLE = RENAMED(GOTIFY_ENABLED) ✅ PROMETHEUS_ENABLED = RENAMED(METRICS_ENABLED) ✅ PROMETHEUS_TEXTFILE_DIR = RENAMED(METRICS_PATH) ✅ PXAR_INCLUDE_PATTERN = RENAMED(PXAR_FILE_INCLUDE_PATTERN) ✅ RCLONE_REMOTE = RENAMED(CLOUD_REMOTE) ✅ SECONDARY_BACKUP_PATH = RENAMED(SECONDARY_PATH) ✅ +TELEGRAM_ENABLE = RENAMED(TELEGRAM_ENABLED) ✅ +WEBHOOK_ENABLE = RENAMED(WEBHOOK_ENABLED) ✅ ### Semantic changes ⚠️ *Require manual value conversion* diff --git a/docs/CLI_REFERENCE.md b/docs/CLI_REFERENCE.md index c9850468..5ed74a5b 100644 --- a/docs/CLI_REFERENCE.md +++ b/docs/CLI_REFERENCE.md @@ -131,19 +131,23 @@ Some interactive commands support two interface modes: **Use `--cli` when**: TUI rendering issues occur or advanced debugging is needed. **Existing configuration**: -- If the configuration file already exists, the **TUI wizard** prompts you to **Overwrite**, **Edit existing** (uses the current file as base and pre-fills the wizard fields), or **Keep & exit**. -- In **CLI mode** (`--cli`), you will be prompted to overwrite; choosing "No" keeps the file and skips the configuration wizard. +- If the configuration file already exists, **both TUI and CLI** now offer the same choices: + - **Overwrite** (start from embedded template) + - **Edit existing** (use current file as base and pre-fill wizard fields) + - **Keep existing & continue** (leave file untouched and skip configuration wizard) + - **Cancel** (abort installation) +- In **Keep existing & continue** mode, config-dependent post-steps are skipped (encryption setup, post-install audit, Telegram pairing), while finalization steps still run (docs install, symlink/cron finalization, permissions normalization). **Wizard workflow**: 1. Generates/updates the configuration file (`configs/backup.env` by default) -2. Optionally configures secondary storage +2. Optionally configures secondary storage (`SECONDARY_PATH` required if enabled; `SECONDARY_LOG_PATH` optional; invalid secondary paths are re-prompted/rejected; disabling secondary storage clears both saved secondary paths) 3. Optionally configures cloud storage (rclone) 4. Optionally enables firewall rules collection (`BACKUP_FIREWALL_RULES=false` by default) 5. Optionally sets up notifications (Telegram, Email; Email defaults to `EMAIL_DELIVERY_METHOD=relay`) 6. Optionally configures encryption (AGE setup) -7. (TUI) Optionally selects a cron time (HH:MM) for the `proxsave` cron entry +7. Optionally selects a cron time (HH:MM, default `02:00`) for the `proxsave` cron entry in both CLI and TUI install flows 8. Optionally runs a post-install dry-run audit and offers to disable unused collectors (actionable hints like `set BACKUP_*=false to disable`) -9. (If Telegram enabled) Shows Server ID and offers pairing verification (retry/skip supported) +9. (If Telegram centralized mode is enabled and config + Server ID resolve successfully) Shows Server ID and offers pairing verification (retry/skip supported); otherwise install continues and logs why pairing was skipped 10. Finalizes installation (symlinks, cron migration, permission checks) **Install log**: The installer writes a session log under `/tmp/proxsave/install-*.log` (includes audit results and Telegram pairing outcome). @@ -346,21 +350,21 @@ Next step: ./build/proxsave --dry-run # TUI mode (default) - terminal interface ./build/proxsave --newkey -# CLI mode - text prompts (for debugging) +# CLI mode - text prompts (for debugging or when TUI rendering is unavailable) ./build/proxsave --newkey --cli ``` **Use `--cli` when**: TUI rendering issues occur or advanced debugging is needed. **`--newkey` workflow**: -1. Uses the default recipient file: `${BASE_DIR}/identity/age/recipient.txt` (same as `AGE_RECIPIENT_FILE` in the template) +1. Uses the configured `AGE_RECIPIENT_FILE` when present; otherwise falls back to `${BASE_DIR}/identity/age/recipient.txt` 2. Prompts for one of: - **Existing public recipient**: paste an `age1...` recipient - **Passphrase-derived**: enter a passphrase (proxsave derives the recipient; the passphrase is **not stored**) - **Private key-derived**: paste an `AGE-SECRET-KEY-...` key (not stored; proxsave stores only the derived public recipient) 3. Writes/overwrites the recipient file after confirmation -**Note**: In `--cli` mode (text prompts), you can add multiple recipients. The default TUI flow saves a single recipient; you can always add more by editing the recipient file (one per line). +**Note**: Both CLI and TUI `--newkey` flows support adding multiple recipients and de-duplicate repeated entries before saving. **For complete encryption guide**, see: **[Encryption Guide](ENCRYPTION.md)** diff --git a/docs/CLOUD_STORAGE.md b/docs/CLOUD_STORAGE.md index 240b112f..3e4161b8 100644 --- a/docs/CLOUD_STORAGE.md +++ b/docs/CLOUD_STORAGE.md @@ -815,7 +815,7 @@ cp -a /restore/* / A: No, currently only one `CLOUD_REMOTE` is supported. Workaround: Use `rclone union` to combine multiple backends. **Q: Can I use a network address like "192.168.0.10/folder" for SECONDARY_PATH?** -A: **No**. `SECONDARY_PATH` and `BACKUP_PATH` require **filesystem-mounted paths only**. Network shares must be mounted first using NFS/CIFS/SMB mount commands, then you use the local mount point path (e.g., `/mnt/nas-backup`). +A: **No**. `SECONDARY_PATH` and `BACKUP_PATH` require **absolute local filesystem paths**. For network shares, mount them first using NFS/CIFS/SMB, then use the local mount point path (e.g., `/mnt/nas-backup`). If you want to use a direct network address without mounting, configure it as `CLOUD_REMOTE` using rclone with an S3-compatible backend (like MinIO) or appropriate protocol. diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index e0f3d826..cba8da5a 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -408,10 +408,10 @@ BACKUP_EXCLUDE_PATTERNS="*/cache/**, /var/tmp/**, *.log" # Enable secondary storage SECONDARY_ENABLED=false # true | false -# Secondary backup path +# Secondary backup path (required when SECONDARY_ENABLED=true) SECONDARY_PATH=/mnt/secondary/backup -# Secondary log path +# Secondary log path (optional) SECONDARY_LOG_PATH=/mnt/secondary/log ``` @@ -421,7 +421,9 @@ Additional local storage for redundant backup copies - mounted NAS, USB drives, ### IMPORTANT PATH REQUIREMENTS -- `SECONDARY_PATH` **must be a filesystem-mounted path** (e.g., `/mnt/nas-backup`, `/media/usb-drive`) +- `SECONDARY_PATH` **must be an absolute local filesystem path** (e.g., `/mnt/nas-backup`, `/media/usb-drive`) +- `SECONDARY_LOG_PATH`, when set, must follow the **same absolute local path rules** +- `SECONDARY_LOG_PATH` is optional; when empty, secondary backup copies still run, but secondary log copy/cleanup is disabled - `SECONDARY_PATH` **CANNOT** be a network address (e.g., `192.168.0.10/folder`, `//server/share`) - Network shares **must be mounted first** using standard Linux mounting (NFS/CIFS/SMB) @@ -443,6 +445,7 @@ sudo mount -t cifs //192.168.0.10/backup /mnt/nas-backup -o credentials=/root/.s **2. Then configure SECONDARY_PATH**: ```bash SECONDARY_PATH=/mnt/nas-backup # ✓ Correct - uses mounted path +SECONDARY_LOG_PATH=/mnt/nas-logs # Optional ``` ### What NOT to Do @@ -460,6 +463,7 @@ SECONDARY_PATH=\\192.168.0.10\backup # ✗ WRONG - Windows path - Secondary storage is **non-critical** (failures log warnings, don't abort backup) - Files copied via native Go (no dependency on rclone) - Same retention policy as primary storage +- Invalid configured secondary paths fail fast during configuration loading --- @@ -800,8 +804,8 @@ TELEGRAM_CHAT_ID= # Chat ID (your user ID or group ID) 3. Open Telegram and start `@ProxmoxAN_bot` 4. Send the Server ID to the bot 5. Verify pairing: - - **TUI installer**: press `Check` (retry supported). `Continue` appears only after success; use `Skip` (or `ESC`) to proceed without verification. - - **CLI installer**: opt into the check and retry when prompted. + - **TUI installer**: the Telegram setup screen is shown only when config loads successfully, centralized mode is active, and a Server ID is available. When shown, press `Check` (retry supported). `Continue` appears only after success; use `Skip` (or `ESC`) to proceed without verification. + - **CLI installer**: the same eligibility rules apply, then you can opt into the check and retry when prompted. - Normal runs also verify automatically and will skip Telegram if not paired yet. **Setup personal bot**: @@ -831,6 +835,8 @@ EMAIL_RECIPIENT= # e.g., "admin@example.com" EMAIL_FROM=no-reply@proxmox.tis24.it ``` +If `EMAIL_ENABLED` is omitted, the default remains `false`. The legacy alias `EMAIL_ENABLE` is still accepted during migration and runtime loading. + **Delivery methods**: - **relay**: Uses cloud relay (outbound HTTPS) - **sendmail**: Uses `/usr/sbin/sendmail` (requires a working local MTA, e.g. postfix) diff --git a/docs/ENCRYPTION.md b/docs/ENCRYPTION.md index 63dcad13..99fffc35 100644 --- a/docs/ENCRYPTION.md +++ b/docs/ENCRYPTION.md @@ -133,7 +133,7 @@ You can create/update recipients in two ways: # Dedicated wizard (TUI by default) ./build/proxsave --newkey -# Use CLI prompts instead of TUI (useful for debugging and multi-recipient setups) +# Use CLI prompts instead of TUI (useful for debugging or when TUI rendering is unavailable) ./build/proxsave --newkey --cli ``` @@ -147,7 +147,7 @@ If `ENCRYPT_ARCHIVE=true` and no recipients are configured, proxsave will start **Notes**: - Proxsave stores **only recipients** (public keys) in `${BASE_DIR}/identity/age/recipient.txt`. Keep private keys and passphrases offline. - `AGE_RECIPIENT` and `AGE_RECIPIENT_FILE` are **merged and de-duplicated**. -- The CLI setup supports multiple recipients; otherwise you can add multiple recipients by editing the file (one per line). +- Both TUI and CLI setup flows support multiple recipients and de-duplicate repeated entries before saving. --- diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 43b7954e..e09b6b79 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -207,28 +207,45 @@ The installation wizard creates your configuration file interactively: ./build/proxsave --new-install ``` -If the configuration file already exists, the **TUI wizard** will ask whether to: +If the configuration file already exists, **both TUI and CLI** ask whether to: - **Overwrite** (start from the embedded template) - **Edit existing** (use the current file as base and pre-fill the wizard fields) -- **Keep & exit** (leave the file untouched and exit) +- **Keep existing & continue** (leave the file untouched and skip the configuration wizard) +- **Cancel** (exit installation) + +In **Keep existing & continue** mode, config-dependent post-steps are skipped: +- AGE setup +- Post-install check wizard +- Telegram pairing wizard + +Final install steps still run: +- Support docs installation +- Symlink and cron finalization +- Permission normalization **Wizard prompts:** 1. **Configuration file path**: Default `configs/backup.env` (accepts absolute or relative paths within repo) -2. **Secondary storage**: Optional path for backup/log copies +2. **Secondary storage**: Optional path for backup/log copies; disabling it clears both saved secondary paths from `backup.env` 3. **Cloud storage (rclone)**: Optional rclone configuration (supports `CLOUD_REMOTE` as a remote name (recommended) or legacy `remote:path`; `CLOUD_LOG_PATH` supports path-only (recommended) or `otherremote:/path`) 4. **Firewall rules**: Optional firewall rules collection toggle (`BACKUP_FIREWALL_RULES=false` by default; supports iptables/nftables) 5. **Notifications**: Enable Telegram (centralized) and Email notifications (wizard defaults to `EMAIL_DELIVERY_METHOD=relay`; you can switch to `sendmail` or `pmf` later) 6. **Encryption**: AGE encryption setup (runs sub-wizard immediately if enabled) -7. **Cron schedule**: Choose cron time (HH:MM) for the `proxsave` cron entry (TUI mode only) +7. **Cron schedule**: Choose cron time (HH:MM, default `02:00`) for the `proxsave` cron entry in both CLI and TUI install modes 8. **Post-install check (optional)**: Runs `proxsave --dry-run` and shows actionable warnings like `set BACKUP_*=false to disable`, allowing you to disable unused collectors and reduce WARNING noise -9. **Telegram pairing (optional)**: If Telegram (centralized) is enabled, shows your Server ID and lets you verify pairing with the bot (retry/skip supported) +9. **Telegram pairing (optional)**: If Telegram centralized mode is enabled and the installer can load a valid config plus a Server ID, it shows your Server ID and lets you verify pairing with the bot (retry/skip supported). Otherwise installation continues and logs why pairing was skipped. #### Telegram pairing wizard (TUI) -If you enable Telegram notifications during `--install` (centralized bot), the installer opens an additional **Telegram Setup** screen after the post-install check. +If you enable Telegram notifications during `--install`, the installer opens an additional **Telegram Setup** screen only when all of these are true: +- `TELEGRAM_ENABLED=true` +- `BOT_TELEGRAM_TYPE=centralized` (or left empty, which defaults to centralized) +- `backup.env` loads successfully +- a Server ID can be resolved from `/identity/.server_identity` + +If any of those checks fail, installation continues without this screen and logs the skip reason (for example config load failure, personal mode, or missing server identity). -It does **not** modify your `backup.env`. It only: +When shown, it does **not** modify your `backup.env`. It only: - Computes/loads the **Server ID** and persists it (identity file) - Guides you through pairing with the centralized bot - Lets you verify pairing immediately (retry supported) @@ -239,7 +256,7 @@ It does **not** modify your `backup.env`. It only: - **Status**: live feedback from the pairing check - **Actions**: - `Check`: verify pairing (press again to retry) - - `Continue`: available only after a successful check (centralized mode), or immediately in personal mode / when the Server ID is unavailable + - `Continue`: available only after a successful check - `Skip`: leave without verification (in centralized mode, `ESC` behaves like Skip when not verified) **Where the Server ID is stored:** @@ -251,7 +268,7 @@ It does **not** modify your `backup.env`. It only: - Other errors: temporary server/network issue; retry or skip and pair later **CLI mode:** -- With `--install --cli`, the installer prints the Server ID and asks whether to run the check now (with a retry loop). +- With `--install --cli`, the installer follows the same eligibility rules, then prints the Server ID and asks whether to run the check now (with a retry loop). **Features:** @@ -260,7 +277,7 @@ It does **not** modify your `backup.env`. It only: - Creates all necessary directories with proper permissions (0700) - Immediate AGE key generation if encryption is enabled - Optional post-install audit to disable unused collectors (keeps changes explicit; nothing is disabled silently) -- Optional Telegram pairing wizard (centralized mode) that displays Server ID and verifies the bot registration (retry/skip supported) +- Optional Telegram pairing wizard (centralized mode, valid config, Server ID available) that displays Server ID and verifies the bot registration (retry/skip supported) - Install session log under `/tmp/proxsave/install-*.log` (includes audit results and Telegram pairing outcome) After completion, edit `configs/backup.env` manually for advanced options. diff --git a/docs/MIGRATION_GUIDE.md b/docs/MIGRATION_GUIDE.md index 9854a545..0c58a3da 100644 --- a/docs/MIGRATION_GUIDE.md +++ b/docs/MIGRATION_GUIDE.md @@ -242,7 +242,7 @@ These old Bash variable names **still work** in Go (automatic fallback): | `LOCAL_BACKUP_PATH` | `BACKUP_PATH` | ✅ Auto-fallback | | `ENABLE_CLOUD_BACKUP` | `CLOUD_ENABLED` | ✅ Auto-fallback | | `PROMETHEUS_ENABLED` | `METRICS_ENABLED` | ✅ Auto-fallback | -| `PROMETHEUS_PATH` | `METRICS_PATH` | ✅ Auto-fallback | +| `PROMETHEUS_TEXTFILE_DIR` | `METRICS_PATH` | ✅ Auto-fallback | | `TELEGRAM_ENABLE` | `TELEGRAM_ENABLED` | ✅ Auto-fallback | | `EMAIL_ENABLE` | `EMAIL_ENABLED` | ✅ Auto-fallback | | `GOTIFY_ENABLE` | `GOTIFY_ENABLED` | ✅ Auto-fallback | @@ -251,6 +251,8 @@ These old Bash variable names **still work** in Go (automatic fallback): **What this means**: You can keep using old variable names, and Go will automatically read them. However, **it's recommended to update to new names** for clarity and future compatibility. +For email notifications, if `EMAIL_ENABLED` is omitted entirely, the runtime default is `false`, matching the template. + ### Variables Requiring Conversion #### 1. Storage Thresholds (SEMANTIC CHANGE) @@ -445,7 +447,7 @@ TELEGRAM_ENABLE=true TELEGRAM_ENABLED=true ``` -**Note**: Automatic fallback should handle this, but explicitly updating is cleaner. +**Note**: Automatic fallback handles the legacy `_ENABLE` aliases, but explicitly updating is cleaner. For email, leaving `EMAIL_ENABLED` unset now keeps notifications disabled by default. --- diff --git a/go.mod b/go.mod index f99518e7..f68fbd19 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tis24dev/proxsave -go 1.25 +go 1.25.0 toolchain go1.25.8 @@ -8,9 +8,9 @@ require ( filippo.io/age v1.3.1 github.com/gdamore/tcell/v2 v2.13.8 github.com/rivo/tview v0.42.0 - golang.org/x/crypto v0.48.0 - golang.org/x/term v0.40.0 - golang.org/x/text v0.34.0 + golang.org/x/crypto v0.49.0 + golang.org/x/term v0.41.0 + golang.org/x/text v0.35.0 ) require ( @@ -19,5 +19,5 @@ require ( github.com/gdamore/encoding v1.0.1 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - golang.org/x/sys v0.41.0 // indirect + golang.org/x/sys v0.42.0 // indirect ) diff --git a/go.sum b/go.sum index 3ce3e362..7d5f6201 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,8 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -36,20 +36,20 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index 8fe16eba..fa15ef27 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -1997,9 +1997,7 @@ func (c *Collector) describeDiskUsage(ctx context.Context, path string, ioTimeou if err != nil { return "", err } - total := int64(stat.Blocks) * int64(stat.Bsize) - available := int64(stat.Bavail) * int64(stat.Bsize) - used := total - available + total, available, used := safefs.SpaceUsageFromStatfs(stat) if total <= 0 { return "", fmt.Errorf("invalid filesystem statistics for %s", path) } diff --git a/internal/backup/collector_pve_additional_test.go b/internal/backup/collector_pve_additional_test.go index 5b5ef981..8c4e915c 100644 --- a/internal/backup/collector_pve_additional_test.go +++ b/internal/backup/collector_pve_additional_test.go @@ -2,13 +2,17 @@ package backup import ( "context" + "fmt" "io" "os" "path/filepath" "strings" + "syscall" "testing" + "time" "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/safefs" "github.com/tis24dev/proxsave/internal/types" ) @@ -88,6 +92,40 @@ func TestPatternWriterWrite_WritesRelativePathLine(t *testing.T) { } } +func TestDescribeDiskUsageMatchesStatfsSemantics(t *testing.T) { + collector := newTestCollector(t) + tempDir := t.TempDir() + + var stat syscall.Statfs_t + if err := syscall.Statfs(tempDir, &stat); err != nil { + t.Fatalf("Statfs: %v", err) + } + total, available, used := safefs.SpaceUsageFromStatfs(stat) + + got, err := collector.describeDiskUsage(context.Background(), tempDir, time.Second) + if err != nil { + t.Fatalf("describeDiskUsage error: %v", err) + } + + want := fmt.Sprintf("Used: %s / Total: %s (Free: %s)", + FormatBytes(used), + FormatBytes(total), + FormatBytes(available), + ) + if got != want { + t.Fatalf("describeDiskUsage = %q; want %q", got, want) + } +} + +func TestDescribeDiskUsageReturnsStatfsError(t *testing.T) { + collector := newTestCollector(t) + missingPath := filepath.Join(t.TempDir(), "missing") + + if _, err := collector.describeDiskUsage(context.Background(), missingPath, time.Second); err == nil { + t.Fatalf("describeDiskUsage() error = nil; want statfs error") + } +} + func TestCollectorCopyBackupSample_CopiesFile(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) logger.SetOutput(io.Discard) diff --git a/internal/config/config.go b/internal/config/config.go index bb35388e..70be759d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,17 @@ import ( "github.com/tis24dev/proxsave/pkg/utils" ) +const ( + telegramEnabledKey = "TELEGRAM_ENABLED" + telegramEnableLegacyKey = "TELEGRAM_ENABLE" + emailEnabledKey = "EMAIL_ENABLED" + emailEnableLegacyKey = "EMAIL_ENABLE" + gotifyEnabledKey = "GOTIFY_ENABLED" + gotifyEnableLegacyKey = "GOTIFY_ENABLE" + webhookEnabledKey = "WEBHOOK_ENABLED" + webhookEnableLegacyKey = "WEBHOOK_ENABLE" +) + var ( multiValueKeys = map[string]bool{ "BACKUP_EXCLUDE_PATTERNS": true, @@ -24,6 +35,13 @@ var ( "CUSTOM_BACKUP_PATHS": true, "BACKUP_BLACKLIST": true, } + + legacyNotificationEnableAliases = map[string]string{ + telegramEnableLegacyKey: telegramEnabledKey, + emailEnableLegacyKey: emailEnabledKey, + gotifyEnableLegacyKey: gotifyEnabledKey, + webhookEnableLegacyKey: webhookEnabledKey, + } ) // Config contains the full backup system configuration. @@ -113,7 +131,7 @@ type Config struct { // GFS (Grandfather-Father-Son) retention settings // If ANY of these is > 0, GFS retention is enabled (overrides simple retention) - RetentionDaily int // Keep backups from last N days (0 = disabled) + RetentionDaily int // Keep the GFS daily tier; in GFS mode values <= 0 are normalized to 1 RetentionWeekly int // Keep N weekly backups, one per week (0 = disabled) RetentionMonthly int // Keep N monthly backups, one per month (0 = disabled) RetentionYearly int // Keep N yearly backups, one per year (0 = keep all yearly) @@ -307,10 +325,12 @@ func (c *Config) loadEnvOverrides() { "MAX_LOCAL_BACKUPS", "MAX_SECONDARY_BACKUPS", "MAX_CLOUD_BACKUPS", "RETENTION_DAILY", "RETENTION_WEEKLY", "RETENTION_MONTHLY", "RETENTION_YEARLY", "BUNDLE_ASSOCIATED_FILES", "ENCRYPT_ARCHIVE", "AGE_RECIPIENT", "AGE_RECIPIENT_FILE", - "TELEGRAM_ENABLED", "BOT_TELEGRAM_TYPE", "TELEGRAM_BOT_TOKEN", "TELEGRAM_CHAT_ID", - "EMAIL_ENABLED", "EMAIL_DELIVERY_METHOD", "EMAIL_FALLBACK_SENDMAIL", + "TELEGRAM_ENABLE", "TELEGRAM_ENABLED", "BOT_TELEGRAM_TYPE", "TELEGRAM_BOT_TOKEN", "TELEGRAM_CHAT_ID", + "EMAIL_ENABLE", "EMAIL_ENABLED", "EMAIL_DELIVERY_METHOD", "EMAIL_FALLBACK_SENDMAIL", "EMAIL_RECIPIENT", "EMAIL_FROM", - "WEBHOOK_ENABLED", "WEBHOOK_ENDPOINTS", "WEBHOOK_FORMAT", "WEBHOOK_TIMEOUT", + "GOTIFY_ENABLE", "GOTIFY_ENABLED", "GOTIFY_SERVER_URL", "GOTIFY_TOKEN", + "GOTIFY_PRIORITY_SUCCESS", "GOTIFY_PRIORITY_WARNING", "GOTIFY_PRIORITY_FAILURE", + "WEBHOOK_ENABLE", "WEBHOOK_ENABLED", "WEBHOOK_ENDPOINTS", "WEBHOOK_FORMAT", "WEBHOOK_TIMEOUT", "WEBHOOK_MAX_RETRIES", "WEBHOOK_RETRY_DELAY", "METRICS_ENABLED", "METRICS_PATH", "SECURITY_CHECK_ENABLED", "AUTO_UPDATE_HASHES", "AUTO_FIX_PERMISSIONS", @@ -325,7 +345,12 @@ func (c *Config) loadEnvOverrides() { for _, key := range envKeys { if envValue := os.Getenv(key); envValue != "" { - c.raw[key] = envValue + upperKey := strings.ToUpper(key) + if canonicalKey, ok := legacyNotificationEnableAliases[upperKey]; ok { + c.raw[canonicalKey] = envValue + continue + } + c.raw[upperKey] = envValue } } } @@ -344,10 +369,25 @@ func (c *Config) parse() error { if err := c.parseCollectionSettings(); err != nil { return err } + if err := c.validateSecondarySettings(); err != nil { + return err + } c.autoDetectPBSAuth() return nil } +func (c *Config) validateSecondarySettings() error { + if c.SecondaryEnabled { + if err := ValidateRequiredSecondaryPath(c.SecondaryPath); err != nil { + return err + } + } + if err := ValidateOptionalSecondaryLogPath(c.SecondaryLogPath); err != nil { + return err + } + return nil +} + func (c *Config) parseGeneralSettings() { c.BackupEnabled = c.getBool("BACKUP_ENABLED", true) c.DryRun = c.getBool("DRY_RUN", false) @@ -568,20 +608,20 @@ func (c *Config) parseRetentionSettings() { } func (c *Config) parseNotificationSettings() { - c.TelegramEnabled = c.getBool("TELEGRAM_ENABLED", false) + c.TelegramEnabled = c.getBoolWithLegacyAlias(telegramEnabledKey, telegramEnableLegacyKey, false) c.TelegramBotType = c.getString("BOT_TELEGRAM_TYPE", "centralized") c.TelegramBotToken = c.getString("TELEGRAM_BOT_TOKEN", "") c.TelegramChatID = c.getString("TELEGRAM_CHAT_ID", "") c.TelegramServerAPIHost = "https://bot.tis24.it:1443" c.ServerID = "" - c.EmailEnabled = c.getBool("EMAIL_ENABLED", true) + c.EmailEnabled = c.getBoolWithLegacyAlias(emailEnabledKey, emailEnableLegacyKey, false) c.EmailDeliveryMethod = c.getString("EMAIL_DELIVERY_METHOD", "relay") c.EmailFallbackSendmail = c.getBool("EMAIL_FALLBACK_SENDMAIL", true) c.EmailRecipient = c.getString("EMAIL_RECIPIENT", "") c.EmailFrom = c.getString("EMAIL_FROM", "no-reply@proxmox.tis24.it") - c.GotifyEnabled = c.getBool("GOTIFY_ENABLED", false) + c.GotifyEnabled = c.getBoolWithLegacyAlias(gotifyEnabledKey, gotifyEnableLegacyKey, false) c.GotifyServerURL = strings.TrimSpace(c.getString("GOTIFY_SERVER_URL", "")) c.GotifyToken = strings.TrimSpace(c.getString("GOTIFY_TOKEN", "")) c.GotifyPrioritySuccess = c.ensurePositiveInt("GOTIFY_PRIORITY_SUCCESS", 2) @@ -595,7 +635,7 @@ func (c *Config) parseNotificationSettings() { c.WorkerMaxRetries = 2 c.WorkerRetryDelay = 2 - c.WebhookEnabled = c.getBool("WEBHOOK_ENABLED", false) + c.WebhookEnabled = c.getBoolWithLegacyAlias(webhookEnabledKey, webhookEnableLegacyKey, false) c.WebhookDefaultFormat = c.getString("WEBHOOK_FORMAT", "generic") c.WebhookTimeout = c.getInt("WEBHOOK_TIMEOUT", 30) c.WebhookMaxRetries = c.getInt("WEBHOOK_MAX_RETRIES", 3) @@ -991,6 +1031,10 @@ func (c *Config) getBoolWithFallback(keys []string, defaultValue bool) bool { return defaultValue } +func (c *Config) getBoolWithLegacyAlias(key, legacyKey string, defaultValue bool) bool { + return c.getBoolWithFallback([]string{key, legacyKey}, defaultValue) +} + func (c *Config) getIntWithFallback(keys []string, defaultValue int) int { for _, key := range keys { if val, ok := c.raw[key]; ok { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3fb7ab09..0f2ec7cb 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -9,25 +9,9 @@ import ( "github.com/tis24dev/proxsave/internal/types" ) -func setBaseDirEnv(t *testing.T, value string) func() { +func setBaseDirEnv(t *testing.T, value string) { t.Helper() - - prev := os.Getenv("BASE_DIR") - if value == "" { - _ = os.Unsetenv("BASE_DIR") - } else { - if err := os.Setenv("BASE_DIR", value); err != nil { - t.Fatalf("failed to set BASE_DIR: %v", err) - } - } - - return func() { - if prev == "" { - _ = os.Unsetenv("BASE_DIR") - } else { - _ = os.Setenv("BASE_DIR", prev) - } - } + t.Setenv("BASE_DIR", value) } func TestLoadConfig(t *testing.T) { @@ -59,8 +43,7 @@ BACKUP_BLACKLIST=/var/data/tmp t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/env/base/dir") - defer cleanup() + setBaseDirEnv(t, "/env/base/dir") cfg, err := LoadConfig(configPath) if err != nil { @@ -260,8 +243,7 @@ AGE_RECIPIENT_FILE=${BASE_DIR}/identity/age/recipient.txt t.Fatalf("Failed to write config: %v", err) } - cleanup := setBaseDirEnv(t, "/custom/base") - defer cleanup() + setBaseDirEnv(t, "/custom/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -288,6 +270,72 @@ func TestLoadConfigNotFound(t *testing.T) { } } +func TestLoadConfigAllowsInvalidSecondaryPathWhenDisabled(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "invalid-secondary.env") + content := `BACKUP_PATH=/test/backup +LOG_PATH=/test/log +SECONDARY_ENABLED=false +SECONDARY_PATH=remote:path +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("Failed to create config file: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if cfg.SecondaryEnabled { + t.Fatal("SecondaryEnabled expected false") + } + if cfg.SecondaryPath != "remote:path" { + t.Fatalf("SecondaryPath = %q; want %q", cfg.SecondaryPath, "remote:path") + } +} + +func TestLoadConfigRejectsInvalidSecondaryPathWhenEnabled(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "invalid-secondary-enabled.env") + content := `BACKUP_PATH=/test/backup +LOG_PATH=/test/log +SECONDARY_ENABLED=true +SECONDARY_PATH=remote:path +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("Failed to create config file: %v", err) + } + + _, err := LoadConfig(configPath) + if err == nil { + t.Fatal("expected LoadConfig to fail") + } + if got, want := err.Error(), "SECONDARY_PATH must be an absolute local filesystem path"; !strings.Contains(got, want) { + t.Fatalf("LoadConfig() error = %q, want substring %q", got, want) + } +} + +func TestLoadConfigRejectsInvalidSecondaryLogPathWhenConfigured(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "invalid-secondary-log.env") + content := `BACKUP_PATH=/test/backup +LOG_PATH=/test/log +SECONDARY_ENABLED=false +SECONDARY_LOG_PATH=remote:/logs +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("Failed to create config file: %v", err) + } + + _, err := LoadConfig(configPath) + if err == nil { + t.Fatal("expected LoadConfig to fail") + } + if got, want := err.Error(), "SECONDARY_LOG_PATH must be an absolute local filesystem path"; !strings.Contains(got, want) { + t.Fatalf("LoadConfig() error = %q, want substring %q", got, want) + } +} + func TestLoadConfigWithQuotes(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test_quotes.env") @@ -301,8 +349,7 @@ LOG_PATH=/path/without/quotes t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/quotes/base") - defer cleanup() + setBaseDirEnv(t, "/quotes/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -344,8 +391,7 @@ DEBUG_LEVEL=4 t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/comments/base") - defer cleanup() + setBaseDirEnv(t, "/comments/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -398,8 +444,7 @@ func TestConfigDefaults(t *testing.T) { t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/defaults/base") - defer cleanup() + setBaseDirEnv(t, "/defaults/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -427,11 +472,96 @@ func TestConfigDefaults(t *testing.T) { t.Errorf("Default LocalRetentionDays = %d; want 7", cfg.LocalRetentionDays) } + if cfg.EmailEnabled { + t.Error("Expected default EmailEnabled to be false") + } + if cfg.BaseDir != "/defaults/base" { t.Errorf("Default BaseDir = %q; want %q", cfg.BaseDir, "/defaults/base") } } +func TestLoadConfigNotificationLegacyEnableAliases(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "legacy_notifications.env") + + content := `TELEGRAM_ENABLE=true +EMAIL_ENABLE=true +GOTIFY_ENABLE=true +WEBHOOK_ENABLE=true +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("Failed to create test config: %v", err) + } + + setBaseDirEnv(t, "/legacy/notifications/base") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.TelegramEnabled { + t.Error("Expected TelegramEnabled to be true via TELEGRAM_ENABLE") + } + if !cfg.EmailEnabled { + t.Error("Expected EmailEnabled to be true via EMAIL_ENABLE") + } + if !cfg.GotifyEnabled { + t.Error("Expected GotifyEnabled to be true via GOTIFY_ENABLE") + } + if !cfg.WebhookEnabled { + t.Error("Expected WebhookEnabled to be true via WEBHOOK_ENABLE") + } +} + +func TestLoadEnvOverridesNotificationLegacyEnableAliases(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "legacy_notification_env_override.env") + + content := `TELEGRAM_ENABLED=false +EMAIL_ENABLED=false +GOTIFY_ENABLED=false +WEBHOOK_ENABLED=false +BOT_TELEGRAM_TYPE=centralized +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("Failed to create test config: %v", err) + } + + overrides := map[string]string{ + "TELEGRAM_ENABLE": "true", + "EMAIL_ENABLE": "true", + "GOTIFY_ENABLE": "true", + "WEBHOOK_ENABLE": "true", + "BOT_TELEGRAM_TYPE": "personal", + } + for key, value := range overrides { + t.Setenv(key, value) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.TelegramEnabled { + t.Error("Expected TelegramEnabled to be true via TELEGRAM_ENABLE env override") + } + if !cfg.EmailEnabled { + t.Error("Expected EmailEnabled to be true via EMAIL_ENABLE env override") + } + if !cfg.GotifyEnabled { + t.Error("Expected GotifyEnabled to be true via GOTIFY_ENABLE env override") + } + if !cfg.WebhookEnabled { + t.Error("Expected WebhookEnabled to be true via WEBHOOK_ENABLE env override") + } + if cfg.TelegramBotType != "personal" { + t.Errorf("TelegramBotType = %q; want personal from env override", cfg.TelegramBotType) + } +} + func TestLoadConfigBaseDirFromConfig(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "base_dir.env") @@ -443,8 +573,7 @@ BACKUP_PATH=${BASE_DIR}/backup-data t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "") - defer cleanup() + setBaseDirEnv(t, "") cfg, err := LoadConfig(configPath) if err != nil { @@ -471,8 +600,7 @@ PVE_CONFIG_PATH=/etc/pve t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/config-vars/base") - defer cleanup() + setBaseDirEnv(t, "/config-vars/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -499,8 +627,7 @@ MAX_LOCAL_BACKUPS=10 t.Fatalf("Failed to create test config: %v", err) } - cleanup := setBaseDirEnv(t, "/retention/base") - defer cleanup() + setBaseDirEnv(t, "/retention/base") cfg, err := LoadConfig(configPath) if err != nil { @@ -608,8 +735,7 @@ LOCK_PATH=/test/lock t.Fatalf("failed to write config: %v", err) } - cleanup := setBaseDirEnv(t, "/env/base/dir") - defer cleanup() + setBaseDirEnv(t, "/env/base/dir") cfg, err := LoadConfig(configPath) if err != nil { @@ -668,16 +794,8 @@ BACKUP_PATH=/fromfile t.Fatalf("failed to write config: %v", err) } - if err := os.Setenv("BACKUP_ENABLED", "true"); err != nil { - t.Fatalf("failed to set env BACKUP_ENABLED: %v", err) - } - if err := os.Setenv("BACKUP_PATH", "/fromenv"); err != nil { - t.Fatalf("failed to set env BACKUP_PATH: %v", err) - } - defer func() { - _ = os.Unsetenv("BACKUP_ENABLED") - _ = os.Unsetenv("BACKUP_PATH") - }() + t.Setenv("BACKUP_ENABLED", "true") + t.Setenv("BACKUP_PATH", "/fromenv") cfg, err := LoadConfig(configPath) if err != nil { @@ -692,6 +810,57 @@ BACKUP_PATH=/fromfile } } +func TestLoadEnvOverridesOverridesNotificationFields(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "notification_env_override.env") + content := `GOTIFY_SERVER_URL=https://from-file.example +GOTIFY_TOKEN=file-token +GOTIFY_PRIORITY_SUCCESS=2 +WEBHOOK_ENDPOINTS=file_hook +WEBHOOK_FORMAT=generic +WEBHOOK_TIMEOUT=30 +` + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + overrides := map[string]string{ + "GOTIFY_SERVER_URL": "https://from-env.example", + "GOTIFY_TOKEN": "env-token", + "GOTIFY_PRIORITY_SUCCESS": "9", + "WEBHOOK_ENDPOINTS": "env_hook", + "WEBHOOK_FORMAT": "slack", + "WEBHOOK_TIMEOUT": "45", + } + for key, value := range overrides { + t.Setenv(key, value) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg.GotifyServerURL != "https://from-env.example" { + t.Fatalf("GotifyServerURL = %q; want https://from-env.example", cfg.GotifyServerURL) + } + if cfg.GotifyToken != "env-token" { + t.Fatalf("GotifyToken = %q; want env-token", cfg.GotifyToken) + } + if cfg.GotifyPrioritySuccess != 9 { + t.Fatalf("GotifyPrioritySuccess = %d; want 9", cfg.GotifyPrioritySuccess) + } + if len(cfg.WebhookEndpointNames) != 1 || cfg.WebhookEndpointNames[0] != "env_hook" { + t.Fatalf("WebhookEndpointNames = %#v; want [env_hook]", cfg.WebhookEndpointNames) + } + if cfg.WebhookDefaultFormat != "slack" { + t.Fatalf("WebhookDefaultFormat = %q; want slack", cfg.WebhookDefaultFormat) + } + if cfg.WebhookTimeout != 45 { + t.Fatalf("WebhookTimeout = %d; want 45", cfg.WebhookTimeout) + } +} + func TestConfigFallbackHelpers(t *testing.T) { cfg := &Config{ raw: map[string]string{ @@ -731,13 +900,9 @@ func TestConfigFallbackHelpers(t *testing.T) { } func TestExpandEnvVarsAndBaseDir(t *testing.T) { - restoreBase := setBaseDirEnv(t, "/env/base") - defer restoreBase() + setBaseDirEnv(t, "/env/base") - if err := os.Setenv("FOO", "bar"); err != nil { - t.Fatalf("failed to set FOO: %v", err) - } - defer func() { _ = os.Unsetenv("FOO") }() + t.Setenv("FOO", "bar") in := "${FOO}/$FOO/${BASE_DIR}/suffix" got := expandEnvVars(in) @@ -883,8 +1048,7 @@ func TestAutoDetectPBSTokenParsesFiles(t *testing.T) { } func TestAutoDetectPBSAuthEnvAndTokenPriority(t *testing.T) { - restoreBase := setBaseDirEnv(t, "/pbs/base") - defer restoreBase() + setBaseDirEnv(t, "/pbs/base") tmpDir := t.TempDir() tokenFile := filepath.Join(tmpDir, "pbs_token") @@ -893,14 +1057,9 @@ func TestAutoDetectPBSAuthEnvAndTokenPriority(t *testing.T) { } // Case 1: environment variables have highest priority - _ = os.Setenv("PBS_REPOSITORY", "envrepo") - _ = os.Setenv("PBS_PASSWORD", "envpass") - _ = os.Setenv("PBS_FINGERPRINT", "envfp") - defer func() { - _ = os.Unsetenv("PBS_REPOSITORY") - _ = os.Unsetenv("PBS_PASSWORD") - _ = os.Unsetenv("PBS_FINGERPRINT") - }() + t.Setenv("PBS_REPOSITORY", "envrepo") + t.Setenv("PBS_PASSWORD", "envpass") + t.Setenv("PBS_FINGERPRINT", "envfp") cfg := &Config{ SecureAccount: tmpDir, @@ -916,9 +1075,9 @@ func TestAutoDetectPBSAuthEnvAndTokenPriority(t *testing.T) { } // Case 2: no env, use raw config values - _ = os.Unsetenv("PBS_REPOSITORY") - _ = os.Unsetenv("PBS_PASSWORD") - _ = os.Unsetenv("PBS_FINGERPRINT") + t.Setenv("PBS_REPOSITORY", "") + t.Setenv("PBS_PASSWORD", "") + t.Setenv("PBS_FINGERPRINT", "") cfg2 := &Config{ SecureAccount: tmpDir, diff --git a/internal/config/env_mutation.go b/internal/config/env_mutation.go new file mode 100644 index 00000000..5ade6071 --- /dev/null +++ b/internal/config/env_mutation.go @@ -0,0 +1,24 @@ +package config + +import ( + "strings" + + "github.com/tis24dev/proxsave/pkg/utils" +) + +// ApplySecondaryStorageSettings writes the canonical secondary-storage state +// into an env template. Disabled secondary storage always clears both +// SECONDARY_PATH and SECONDARY_LOG_PATH so the saved config matches user intent. +func ApplySecondaryStorageSettings(template string, enabled bool, secondaryPath string, secondaryLogPath string) string { + if enabled { + template = utils.SetEnvValue(template, "SECONDARY_ENABLED", "true") + template = utils.SetEnvValue(template, "SECONDARY_PATH", strings.TrimSpace(secondaryPath)) + template = utils.SetEnvValue(template, "SECONDARY_LOG_PATH", strings.TrimSpace(secondaryLogPath)) + return template + } + + template = utils.SetEnvValue(template, "SECONDARY_ENABLED", "false") + template = utils.SetEnvValue(template, "SECONDARY_PATH", "") + template = utils.SetEnvValue(template, "SECONDARY_LOG_PATH", "") + return template +} diff --git a/internal/config/env_mutation_test.go b/internal/config/env_mutation_test.go new file mode 100644 index 00000000..6b434b68 --- /dev/null +++ b/internal/config/env_mutation_test.go @@ -0,0 +1,74 @@ +package config + +import ( + "strings" + "testing" +) + +func TestApplySecondaryStorageSettingsEnabled(t *testing.T) { + template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=\n" + + got := ApplySecondaryStorageSettings(template, true, " /mnt/secondary ", " /mnt/secondary/log ") + + for _, needle := range []string{ + "SECONDARY_ENABLED=true", + "SECONDARY_PATH=/mnt/secondary", + "SECONDARY_LOG_PATH=/mnt/secondary/log", + } { + if !strings.Contains(got, needle) { + t.Fatalf("expected %q in template:\n%s", needle, got) + } + } +} + +func TestApplySecondaryStorageSettingsEnabledWithEmptyLogPath(t *testing.T) { + template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=/old/log\n" + + got := ApplySecondaryStorageSettings(template, true, "/mnt/secondary", "") + + for _, needle := range []string{ + "SECONDARY_ENABLED=true", + "SECONDARY_PATH=/mnt/secondary", + "SECONDARY_LOG_PATH=", + } { + if !strings.Contains(got, needle) { + t.Fatalf("expected %q in template:\n%s", needle, got) + } + } +} + +func TestApplySecondaryStorageSettingsDisabledClearsValues(t *testing.T) { + template := "SECONDARY_ENABLED=true\nSECONDARY_PATH=/mnt/old-secondary\nSECONDARY_LOG_PATH=/mnt/old-secondary/logs\n" + + got := ApplySecondaryStorageSettings(template, false, "/ignored", "/ignored/logs") + + for _, needle := range []string{ + "SECONDARY_ENABLED=false", + "SECONDARY_PATH=", + "SECONDARY_LOG_PATH=", + } { + if !strings.Contains(got, needle) { + t.Fatalf("expected %q in template:\n%s", needle, got) + } + } + if strings.Contains(got, "/mnt/old-secondary") { + t.Fatalf("expected old secondary values to be cleared:\n%s", got) + } +} + +func TestApplySecondaryStorageSettingsDisabledAppendsCanonicalState(t *testing.T) { + template := "BACKUP_ENABLED=true\n" + + got := ApplySecondaryStorageSettings(template, false, "", "") + + for _, needle := range []string{ + "BACKUP_ENABLED=true", + "SECONDARY_ENABLED=false", + "SECONDARY_PATH=", + "SECONDARY_LOG_PATH=", + } { + if !strings.Contains(got, needle) { + t.Fatalf("expected %q in template:\n%s", needle, got) + } + } +} diff --git a/internal/config/migration.go b/internal/config/migration.go index 653ddd48..cf5c2f5f 100644 --- a/internal/config/migration.go +++ b/internal/config/migration.go @@ -53,6 +53,10 @@ var migrationRules = map[string]migrationRule{ "BACKUP_NETWORK_CONFIGS": {LegacyKeys: []string{"BACKUP_NETWORK_CONFIG"}}, "BACKUP_REMOTE_CONFIGS": {LegacyKeys: []string{"BACKUP_REMOTE_CFG"}}, "BACKUP_CRON_JOBS": {LegacyKeys: []string{"BACKUP_CRONTABS"}}, + "TELEGRAM_ENABLED": {LegacyKeys: []string{"TELEGRAM_ENABLE"}}, + "EMAIL_ENABLED": {LegacyKeys: []string{"EMAIL_ENABLE"}}, + "GOTIFY_ENABLED": {LegacyKeys: []string{"GOTIFY_ENABLE"}}, + "WEBHOOK_ENABLED": {LegacyKeys: []string{"WEBHOOK_ENABLE"}}, "METRICS_ENABLED": {LegacyKeys: []string{"PROMETHEUS_ENABLED"}}, "METRICS_PATH": {LegacyKeys: []string{"PROMETHEUS_TEXTFILE_DIR"}}, "PXAR_FILE_INCLUDE_PATTERN": {LegacyKeys: []string{"PXAR_INCLUDE_PATTERN"}}, @@ -206,8 +210,13 @@ func validateMigratedConfig(cfg *Config) error { if strings.TrimSpace(cfg.LogPath) == "" { return fmt.Errorf("LOG_PATH cannot be empty") } - if cfg.SecondaryEnabled && strings.TrimSpace(cfg.SecondaryPath) == "" { - return fmt.Errorf("SECONDARY_PATH required when SECONDARY_ENABLED=true") + if cfg.SecondaryEnabled { + if err := ValidateRequiredSecondaryPath(cfg.SecondaryPath); err != nil { + return err + } + if err := ValidateOptionalSecondaryLogPath(cfg.SecondaryLogPath); err != nil { + return err + } } if cfg.CloudEnabled && strings.TrimSpace(cfg.CloudRemote) == "" { return fmt.Errorf("CLOUD_REMOTE required when CLOUD_ENABLED=true") diff --git a/internal/config/migration_test.go b/internal/config/migration_test.go index 450d7701..5457aaa4 100644 --- a/internal/config/migration_test.go +++ b/internal/config/migration_test.go @@ -159,6 +159,8 @@ const baseInstallTemplate = `BACKUP_ENABLED=true BACKUP_PATH=/default/backup LOG_PATH=/default/log SECONDARY_ENABLED=false +SECONDARY_PATH= +SECONDARY_LOG_PATH= CLOUD_ENABLED=false SET_BACKUP_PERMISSIONS=false BACKUP_USER=backup @@ -192,6 +194,29 @@ func TestMigrateLegacyEnvCreatesConfigAndKeepsValues(t *testing.T) { }) } +func TestMigrateLegacyEnvRejectsInvalidSecondaryPath(t *testing.T) { + withTemplate(t, baseInstallTemplate, func() { + tmpDir := t.TempDir() + legacyPath := filepath.Join(tmpDir, "legacy.env") + outputPath := filepath.Join(tmpDir, "backup.env") + legacyContent := strings.Join([]string{ + "ENABLE_SECONDARY_BACKUP=true", + "SECONDARY_BACKUP_PATH=remote:path", + }, "\n") + "\n" + if err := os.WriteFile(legacyPath, []byte(legacyContent), 0600); err != nil { + t.Fatalf("failed to write legacy env: %v", err) + } + + _, err := MigrateLegacyEnv(legacyPath, outputPath) + if err == nil { + t.Fatal("expected migration to fail") + } + if got, want := err.Error(), "SECONDARY_PATH must be an absolute local filesystem path"; !strings.Contains(got, want) { + t.Fatalf("MigrateLegacyEnv error = %q, want substring %q", got, want) + } + }) +} + func TestMigrateLegacyEnvCreatesBackupWhenOverwriting(t *testing.T) { withTemplate(t, baseInstallTemplate, func() { tmpDir := t.TempDir() @@ -401,6 +426,57 @@ NEW_FLAG=true }) } +func TestPlanLegacyEnvMigrationMapsLegacyNotificationEnableAliases(t *testing.T) { + template := `TELEGRAM_ENABLED=false +EMAIL_ENABLED=false +GOTIFY_ENABLED=false +WEBHOOK_ENABLED=false +` + withTemplate(t, template, func() { + tmpDir := t.TempDir() + legacyPath := filepath.Join(tmpDir, "legacy.env") + outputPath := filepath.Join(tmpDir, "backup.env") + + legacyContent := strings.Join([]string{ + "TELEGRAM_ENABLE=true", + "EMAIL_ENABLE=true", + "GOTIFY_ENABLE=true", + "WEBHOOK_ENABLE=true", + "", + }, "\n") + if err := os.WriteFile(legacyPath, []byte(legacyContent), 0o600); err != nil { + t.Fatalf("failed to write legacy env: %v", err) + } + + _, merged, err := PlanLegacyEnvMigration(legacyPath, outputPath) + if err != nil { + t.Fatalf("PlanLegacyEnvMigration returned error: %v", err) + } + + for _, want := range []string{ + "TELEGRAM_ENABLED=true", + "EMAIL_ENABLED=true", + "GOTIFY_ENABLED=true", + "WEBHOOK_ENABLED=true", + } { + if !strings.Contains(merged, want) { + t.Fatalf("expected migrated config to contain %q:\n%s", want, merged) + } + } + + for _, legacyKey := range []string{ + "TELEGRAM_ENABLE=", + "EMAIL_ENABLE=", + "GOTIFY_ENABLE=", + "WEBHOOK_ENABLE=", + } { + if strings.Contains(merged, legacyKey) { + t.Fatalf("expected migrated config to replace legacy key %q:\n%s", legacyKey, merged) + } + } + }) +} + func TestInvertBoolAndBoolToString(t *testing.T) { tests := []struct { in string diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index ad5a9b0f..8369ff69 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -93,7 +93,7 @@ LOG_PATH=${BASE_DIR}/log # Primary log storage path # ---------------------------------------------------------------------- # Secondary storage # ---------------------------------------------------------------------- -# IMPORTANT: SECONDARY_PATH must be a filesystem-mounted path (e.g., /mnt/nas-backup) +# IMPORTANT: SECONDARY_PATH must be an absolute local filesystem path (e.g., /mnt/nas-backup) # It CANNOT be a network address like "192.168.0.10/folder" or "//server/share" # # For local network storage (NAS): @@ -111,8 +111,8 @@ LOG_PATH=${BASE_DIR}/log # Primary log storage path # For direct network access without mounting, use CLOUD_REMOTE with rclone instead. # ---------------------------------------------------------------------- SECONDARY_ENABLED=false # true-false = enable disable copy backup on secondary path -SECONDARY_PATH= # Secondary backup storage path -SECONDARY_LOG_PATH= # Secondary log storage path +SECONDARY_PATH= # Required absolute secondary backup path when secondary storage is enabled +SECONDARY_LOG_PATH= # Optional absolute secondary log path (same rules as SECONDARY_PATH) # ---------------------------------------------------------------------- # Cloud storage (rclone) diff --git a/internal/config/validation_secondary.go b/internal/config/validation_secondary.go new file mode 100644 index 00000000..06564208 --- /dev/null +++ b/internal/config/validation_secondary.go @@ -0,0 +1,58 @@ +package config + +import ( + "fmt" + "path/filepath" + "strings" +) + +const secondaryPathFormatMessage = "must be an absolute local filesystem path" + +// ValidateRequiredSecondaryPath validates SECONDARY_PATH when secondary storage is enabled. +func ValidateRequiredSecondaryPath(path string) error { + return validateSecondaryLocalPath(path, "SECONDARY_PATH", true) +} + +// ValidateOptionalSecondaryPath validates SECONDARY_PATH when configured but not required. +func ValidateOptionalSecondaryPath(path string) error { + return validateSecondaryLocalPath(path, "SECONDARY_PATH", false) +} + +// ValidateOptionalSecondaryLogPath validates SECONDARY_LOG_PATH when provided. +func ValidateOptionalSecondaryLogPath(path string) error { + return validateSecondaryLocalPath(path, "SECONDARY_LOG_PATH", false) +} + +func validateSecondaryLocalPath(path, fieldName string, required bool) error { + clean := strings.TrimSpace(path) + if clean == "" { + if required { + return fmt.Errorf("%s is required when SECONDARY_ENABLED=true", fieldName) + } + return nil + } + + if isUNCStylePath(clean) { + return fmt.Errorf("%s %s", fieldName, secondaryPathFormatMessage) + } + + if strings.Contains(clean, ":") && !filepath.IsAbs(clean) { + return fmt.Errorf("%s %s", fieldName, secondaryPathFormatMessage) + } + + if !filepath.IsAbs(clean) { + return fmt.Errorf("%s %s", fieldName, secondaryPathFormatMessage) + } + + return nil +} + +func isUNCStylePath(path string) bool { + if strings.HasPrefix(path, `\\`) { + return true + } + if strings.HasPrefix(path, "//") { + return len(path) == 2 || path[2] != '/' + } + return false +} diff --git a/internal/config/validation_secondary_test.go b/internal/config/validation_secondary_test.go new file mode 100644 index 00000000..a90a0c8f --- /dev/null +++ b/internal/config/validation_secondary_test.go @@ -0,0 +1,102 @@ +package config + +import "testing" + +func TestValidateRequiredSecondaryPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + wantErr string + }{ + {name: "valid mount path", path: "/mnt/secondary"}, + {name: "valid subdirectory", path: "/mnt/secondary/log"}, + {name: "valid absolute with colon", path: "/mnt/data:archive"}, + {name: "empty", path: "", wantErr: "SECONDARY_PATH is required when SECONDARY_ENABLED=true"}, + {name: "relative", path: "relative/path", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + {name: "rclone remote", path: "gdrive:backups", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + {name: "host remote", path: "host:/backup", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + {name: "unc share", path: "//server/share", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + {name: "windows unc share", path: `\\server\share`, wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateRequiredSecondaryPath(tt.path) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("ValidateRequiredSecondaryPath(%q) error = %v", tt.path, err) + } + return + } + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("ValidateRequiredSecondaryPath(%q) error = %v, want %q", tt.path, err, tt.wantErr) + } + }) + } +} + +func TestValidateOptionalSecondaryLogPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + wantErr string + }{ + {name: "empty allowed", path: ""}, + {name: "valid path", path: "/mnt/secondary/log"}, + {name: "relative", path: "logs", wantErr: "SECONDARY_LOG_PATH must be an absolute local filesystem path"}, + {name: "remote style", path: "remote:/logs", wantErr: "SECONDARY_LOG_PATH must be an absolute local filesystem path"}, + {name: "unc share", path: "//server/logs", wantErr: "SECONDARY_LOG_PATH must be an absolute local filesystem path"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateOptionalSecondaryLogPath(tt.path) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("ValidateOptionalSecondaryLogPath(%q) error = %v", tt.path, err) + } + return + } + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("ValidateOptionalSecondaryLogPath(%q) error = %v, want %q", tt.path, err, tt.wantErr) + } + }) + } +} + +func TestValidateOptionalSecondaryPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + wantErr string + }{ + {name: "empty allowed", path: ""}, + {name: "valid path", path: "/mnt/secondary"}, + {name: "relative", path: "relative/path", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + {name: "remote style", path: "remote:/backup", wantErr: "SECONDARY_PATH must be an absolute local filesystem path"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateOptionalSecondaryPath(tt.path) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("ValidateOptionalSecondaryPath(%q) error = %v", tt.path, err) + } + return + } + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("ValidateOptionalSecondaryPath(%q) error = %v, want %q", tt.path, err, tt.wantErr) + } + }) + } +} diff --git a/internal/cron/cron.go b/internal/cron/cron.go new file mode 100644 index 00000000..4c6c5193 --- /dev/null +++ b/internal/cron/cron.go @@ -0,0 +1,51 @@ +package cron + +import ( + "fmt" + "strconv" + "strings" +) + +const DefaultTime = "02:00" + +// NormalizeTime validates a cron time in HH:MM form and returns a normalized, +// zero-padded value. Empty input falls back to defaultValue. +func NormalizeTime(input string, defaultValue string) (string, error) { + value := strings.TrimSpace(input) + if value == "" { + value = strings.TrimSpace(defaultValue) + } + hour, minute, err := parseTime(value) + if err != nil { + return "", err + } + return fmt.Sprintf("%02d:%02d", hour, minute), nil +} + +// TimeToSchedule converts HH:MM into "MM HH * * *". Invalid input returns "". +func TimeToSchedule(cronTime string) string { + hour, minute, err := parseTime(strings.TrimSpace(cronTime)) + if err != nil { + return "" + } + return fmt.Sprintf("%02d %02d * * *", minute, hour) +} + +func parseTime(value string) (int, int, error) { + parts := strings.Split(strings.TrimSpace(value), ":") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("cron time must be in HH:MM format") + } + + hour, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil || hour < 0 || hour > 23 { + return 0, 0, fmt.Errorf("cron hour must be between 00 and 23") + } + + minute, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil || minute < 0 || minute > 59 { + return 0, 0, fmt.Errorf("cron minute must be between 00 and 59") + } + + return hour, minute, nil +} diff --git a/internal/cron/cron_test.go b/internal/cron/cron_test.go new file mode 100644 index 00000000..40077505 --- /dev/null +++ b/internal/cron/cron_test.go @@ -0,0 +1,61 @@ +package cron + +import "testing" + +func TestNormalizeTime(t *testing.T) { + tests := []struct { + name string + input string + defaultValue string + want string + wantErr string + }{ + {name: "default fallback", input: "", defaultValue: DefaultTime, want: DefaultTime}, + {name: "normalize short values", input: "3:7", defaultValue: DefaultTime, want: "03:07"}, + {name: "trim whitespace", input: " 03:15 ", defaultValue: DefaultTime, want: "03:15"}, + {name: "invalid format", input: "0315", defaultValue: DefaultTime, wantErr: "cron time must be in HH:MM format"}, + {name: "invalid hour", input: "24:00", defaultValue: DefaultTime, wantErr: "cron hour must be between 00 and 23"}, + {name: "invalid minute", input: "00:60", defaultValue: DefaultTime, wantErr: "cron minute must be between 00 and 59"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeTime(tt.input, tt.defaultValue) + if tt.wantErr != "" { + if err == nil { + t.Fatal("expected error") + } + if err.Error() != tt.wantErr { + t.Fatalf("NormalizeTime(%q) error = %q, want %q", tt.input, err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("NormalizeTime(%q) returned error: %v", tt.input, err) + } + if got != tt.want { + t.Fatalf("NormalizeTime(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTimeToSchedule(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "valid", in: "02:05", want: "05 02 * * *"}, + {name: "normalized short", in: "2:5", want: "05 02 * * *"}, + {name: "invalid", in: "bad", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TimeToSchedule(tt.in); got != tt.want { + t.Fatalf("TimeToSchedule(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} diff --git a/internal/environment/detect_additional_test.go b/internal/environment/detect_additional_test.go index cd780ea7..732ed4ec 100644 --- a/internal/environment/detect_additional_test.go +++ b/internal/environment/detect_additional_test.go @@ -248,14 +248,7 @@ func TestDetectPBSViaSources(t *testing.T) { // TestExtendPath tests PATH environment variable extension func TestExtendPath(t *testing.T) { - // Save original PATH - originalPath := os.Getenv("PATH") - defer func() { - _ = os.Setenv("PATH", originalPath) - }() - - // Set a minimal PATH - _ = os.Setenv("PATH", "/usr/local/bin") + t.Setenv("PATH", "/usr/local/bin") extendPath() @@ -271,10 +264,7 @@ func TestExtendPath(t *testing.T) { // TestExtendPathIdempotent tests that extendPath doesn't duplicate paths func TestExtendPathIdempotent(t *testing.T) { - originalPath := os.Getenv("PATH") - defer func() { - _ = os.Setenv("PATH", originalPath) - }() + t.Setenv("PATH", "") // Call extendPath twice extendPath() diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 87ae5827..2e2621e4 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -2,6 +2,7 @@ package identity import ( "bufio" + "context" "crypto/sha256" "encoding/base64" "errors" @@ -44,6 +45,14 @@ var ( // Detect resolves the server identity (ID + MAC address) and ensures persistence. func Detect(baseDir string, logger *logging.Logger) (*Info, error) { + return DetectWithContext(context.Background(), baseDir, logger) +} + +// DetectWithContext resolves the server identity (ID + MAC address) and ensures persistence. +func DetectWithContext(ctx context.Context, baseDir string, logger *logging.Logger) (*Info, error) { + if ctx == nil { + ctx = context.Background() + } info := &Info{} baseDir = strings.TrimSpace(baseDir) logDebug(logger, "Identity: starting detection (baseDir=%q)", baseDir) @@ -78,7 +87,9 @@ func Detect(baseDir string, logger *logging.Logger) (*Info, error) { if strings.TrimSpace(info.PrimaryMAC) == "" && strings.TrimSpace(boundMAC) != "" { info.PrimaryMAC = boundMAC } - maybeUpgradeIdentityFile(identityPath, id, info.PrimaryMAC, macs, logger) + if err := maybeUpgradeIdentityFileWithContext(ctx, identityPath, id, info.PrimaryMAC, macs, logger); err != nil { + return info, err + } return info, nil } logDebug(logger, "Identity: identity file %s returned empty server ID; generating new one", identityPath) @@ -106,7 +117,10 @@ func Detect(baseDir string, logger *logging.Logger) (*Info, error) { logDebug(logger, "Identity: identity directory ready: %s", identityDir) logDebug(logger, "Identity: persisting identity file (0600 + immutable) to %s", identityPath) - if err := writeIdentityFile(identityPath, encodedFile, logger); err != nil { + if err := writeIdentityFileWithContext(ctx, identityPath, encodedFile, logger); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return info, err + } logWarning(logger, "Identity: failed to write server identity file %s: %v (server ID will NOT be persisted)", identityPath, err) return info, nil } @@ -731,21 +745,30 @@ func computeSystemKey(machineID, hostnamePart, extra string) string { return fmt.Sprintf("%x", sum)[:16] } -func maybeUpgradeIdentityFile(path string, serverID string, primaryMAC string, macs []string, logger *logging.Logger) { +func maybeUpgradeIdentityFile(path string, serverID string, primaryMAC string, macs []string, logger *logging.Logger) error { + return maybeUpgradeIdentityFileWithContext(context.Background(), path, serverID, primaryMAC, macs, logger) +} + +func maybeUpgradeIdentityFileWithContext(ctx context.Context, path string, serverID string, primaryMAC string, macs []string, logger *logging.Logger) error { + if ctx == nil { + ctx = context.Background() + } data, err := os.ReadFile(path) if err != nil { - return + return nil } if identityPayloadHasKeyLabels(string(data), logger) { - return + return nil } updated, err := encodeProtectedServerIDWithMACs(serverID, macs, primaryMAC, logger) if err != nil { - return + return err } - if err := writeIdentityFile(path, updated, logger); err != nil { + if err := writeIdentityFileWithContext(ctx, path, updated, logger); err != nil { logDebug(logger, "Identity: failed to upgrade identity file format: %v", err) + return err } + return nil } func normalizeMAC(mac string) string { @@ -784,10 +807,19 @@ func identityPayloadHasKeyLabels(fileContent string, logger *logging.Logger) boo } func writeIdentityFile(path, content string, logger *logging.Logger) error { + return writeIdentityFileWithContext(context.Background(), path, content, logger) +} + +func writeIdentityFileWithContext(ctx context.Context, path, content string, logger *logging.Logger) error { + if ctx == nil { + ctx = context.Background() + } logDebug(logger, "Identity: writeIdentityFile: start path=%s contentBytes=%d", path, len(content)) // Ensure file is writable even if immutable was previously set - _ = setImmutableAttribute(path, false, logger) + if err := setImmutableAttributeWithContext(ctx, path, false, logger); err != nil { + return err + } if err := os.WriteFile(path, []byte(content), 0o600); err != nil { logDebug(logger, "Identity: writeIdentityFile: os.WriteFile failed: %v", err) @@ -799,7 +831,9 @@ func writeIdentityFile(path, content string, logger *logging.Logger) error { return err } - _ = setImmutableAttribute(path, true, logger) + if err := setImmutableAttributeWithContext(ctx, path, true, logger); err != nil { + return err + } logDebug(logger, "Identity: writeIdentityFile: done path=%s", path) return nil @@ -890,6 +924,18 @@ func logDebug(logger *logging.Logger, format string, args ...interface{}) { } func setImmutableAttribute(path string, enable bool, logger *logging.Logger) error { + return setImmutableAttributeWithContext(context.Background(), path, enable, logger) +} + +func setImmutableAttributeWithContext(ctx context.Context, path string, enable bool, logger *logging.Logger) error { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + logDebug(logger, "Identity: immutable: context canceled before chattr for %s: %v", path, err) + return err + } + if runtime.GOOS != "linux" { logDebug(logger, "Identity: immutable: skip (GOOS=%s)", runtime.GOOS) return nil @@ -921,8 +967,12 @@ func setImmutableAttribute(path string, enable bool, logger *logging.Logger) err flag = "-i" } - cmd := exec.Command(chattrPath, flag, path) + cmd := exec.CommandContext(ctx, chattrPath, flag, path) if err := cmd.Run(); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + logDebug(logger, "Identity: immutable: chattr canceled for %s: %v", path, ctxErr) + return ctxErr + } logDebug(logger, "Identity: immutable: chattr failed (ignored): %v", err) return nil } diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go index 82716962..21711565 100644 --- a/internal/identity/identity_test.go +++ b/internal/identity/identity_test.go @@ -2,11 +2,14 @@ package identity import ( "bytes" + "context" "crypto/sha256" "encoding/base64" + "errors" "fmt" "os" "path/filepath" + "runtime" "strings" "testing" "time" @@ -200,6 +203,25 @@ func TestDetectCreatesIdentityFileInBaseDir(t *testing.T) { } } +func TestSetImmutableAttributeWithContext_CanceledBeforeCommand(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("requires linux") + } + + path := filepath.Join(t.TempDir(), "identity") + if err := os.WriteFile(path, []byte("data"), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := setImmutableAttributeWithContext(ctx, path, false, nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } +} + func TestDetectUsesExistingIdentityFile(t *testing.T) { baseDir := t.TempDir() identityDir := filepath.Join(baseDir, identityDirName) @@ -242,6 +264,69 @@ func TestDetectUsesExistingIdentityFile(t *testing.T) { } } +func TestDetectWithContext_PropagatesCancellationDuringLegacyUpgrade(t *testing.T) { + origRead := readFirstLineFunc + origHost := hostnameFunc + t.Cleanup(func() { + readFirstLineFunc = origRead + hostnameFunc = origHost + }) + + hostnameFunc = func() (string, error) { return "host-one", nil } + readFirstLineFunc = func(path string, limit int) string { + switch path { + case "/etc/machine-id": + return "machine-one" + case "/sys/class/dmi/id/product_uuid": + return "" + default: + return "" + } + } + + baseDir := t.TempDir() + identityDir := filepath.Join(baseDir, identityDirName) + if err := os.MkdirAll(identityDir, 0o755); err != nil { + t.Fatalf("failed to create identity dir: %v", err) + } + identityPath := filepath.Join(identityDir, identityFileName) + + t.Cleanup(func() { + _ = setImmutableAttribute(identityPath, false, nil) + }) + + const serverID = "1234567890123456" + _, macs := collectMACCandidates(nil) + if len(macs) == 0 { + t.Skip("no non-loopback MACs available on this system") + } + primaryMAC := macs[0] + legacy, err := encodeProtectedServerIDLegacy(serverID, primaryMAC) + if err != nil { + t.Fatalf("encodeProtectedServerIDLegacy() error = %v", err) + } + if err := os.WriteFile(identityPath, []byte(legacy), 0o600); err != nil { + t.Fatalf("failed to write legacy identity file: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + info, err := DetectWithContext(ctx, baseDir, nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } + if info == nil { + t.Fatal("expected info even on cancellation") + } + if info.ServerID != serverID { + t.Fatalf("ServerID = %q, want %q", info.ServerID, serverID) + } + if info.IdentityFile != identityPath { + t.Fatalf("IdentityFile = %q, want %q", info.IdentityFile, identityPath) + } +} + func TestLoadServerIDTriesAllMACAddresses(t *testing.T) { baseDir := t.TempDir() identityDir := filepath.Join(baseDir, identityDirName) @@ -632,7 +717,9 @@ func TestMaybeUpgradeIdentityFileRewritesLegacyToV2WithAltMACs(t *testing.T) { t.Fatalf("failed to write legacy identity file: %v", err) } - maybeUpgradeIdentityFile(path, serverID, macPrimary, []string{macPrimary, macAlt}, nil) + if err := maybeUpgradeIdentityFile(path, serverID, macPrimary, []string{macPrimary, macAlt}, nil); err != nil { + t.Fatalf("maybeUpgradeIdentityFile() error = %v", err) + } upgraded, err := os.ReadFile(path) if err != nil { @@ -1613,7 +1700,9 @@ func TestCollectMACCandidatesWithLogger(t *testing.T) { func TestMaybeUpgradeIdentityFileNonExistent(t *testing.T) { // Should not panic on non-existent file - maybeUpgradeIdentityFile("/nonexistent/path/identity.conf", "1234567890123456", "aa:bb:cc:dd:ee:ff", nil, nil) + if err := maybeUpgradeIdentityFile("/nonexistent/path/identity.conf", "1234567890123456", "aa:bb:cc:dd:ee:ff", nil, nil); err != nil { + t.Fatalf("maybeUpgradeIdentityFile() error = %v", err) + } } func TestMaybeUpgradeIdentityFileAlreadyUpgraded(t *testing.T) { @@ -1655,7 +1744,9 @@ func TestMaybeUpgradeIdentityFileAlreadyUpgraded(t *testing.T) { original, _ := os.ReadFile(path) // Try to upgrade - should be no-op since already v2 - maybeUpgradeIdentityFile(path, serverID, macs[0], macs, nil) + if err := maybeUpgradeIdentityFile(path, serverID, macs[0], macs, nil); err != nil { + t.Fatalf("maybeUpgradeIdentityFile() error = %v", err) + } // Content should not have changed (same format) after, _ := os.ReadFile(path) diff --git a/internal/logging/bootstrap.go b/internal/logging/bootstrap.go index 49dc83ac..47f9aede 100644 --- a/internal/logging/bootstrap.go +++ b/internal/logging/bootstrap.go @@ -34,6 +34,9 @@ func NewBootstrapLogger() *BootstrapLogger { // SetLevel updates the minimum level used during Flush. func (b *BootstrapLogger) SetLevel(level types.LogLevel) { + if b == nil { + return + } b.mu.Lock() defer b.mu.Unlock() b.minLevel = level @@ -41,6 +44,9 @@ func (b *BootstrapLogger) SetLevel(level types.LogLevel) { // Println records a raw line (used for banners/text without a header). func (b *BootstrapLogger) Println(message string) { + if b == nil { + return + } fmt.Println(message) b.mirrorLog(types.LogLevelInfo, message) b.recordRaw(message) @@ -48,6 +54,9 @@ func (b *BootstrapLogger) Println(message string) { // Debug records a debug message without printing it to the console. func (b *BootstrapLogger) Debug(format string, args ...interface{}) { + if b == nil { + return + } msg := fmt.Sprintf(format, args...) b.mirrorLog(types.LogLevelDebug, msg) b.record(types.LogLevelDebug, msg) @@ -55,6 +64,9 @@ func (b *BootstrapLogger) Debug(format string, args ...interface{}) { // Printf records a formatted line as raw. func (b *BootstrapLogger) Printf(format string, args ...interface{}) { + if b == nil { + return + } msg := fmt.Sprintf(format, args...) fmt.Println(msg) b.mirrorLog(types.LogLevelInfo, msg) @@ -63,6 +75,9 @@ func (b *BootstrapLogger) Printf(format string, args ...interface{}) { // Info logs an early informational message. func (b *BootstrapLogger) Info(format string, args ...interface{}) { + if b == nil { + return + } msg := fmt.Sprintf(format, args...) fmt.Println(msg) b.mirrorLog(types.LogLevelInfo, msg) @@ -71,6 +86,9 @@ func (b *BootstrapLogger) Info(format string, args ...interface{}) { // Warning records an early warning message (printed to stderr). func (b *BootstrapLogger) Warning(format string, args ...interface{}) { + if b == nil { + return + } msg := fmt.Sprintf(format, args...) if !strings.HasSuffix(msg, "\n") { msg += "\n" @@ -83,6 +101,9 @@ func (b *BootstrapLogger) Warning(format string, args ...interface{}) { // Error records an early error message (stderr). func (b *BootstrapLogger) Error(format string, args ...interface{}) { + if b == nil { + return + } msg := fmt.Sprintf(format, args...) if !strings.HasSuffix(msg, "\n") { msg += "\n" @@ -114,6 +135,9 @@ func (b *BootstrapLogger) recordRaw(message string) { // Flush flushes accumulated entries into the main logger (only the first time). func (b *BootstrapLogger) Flush(logger *Logger) { + if b == nil || logger == nil { + return + } b.mu.Lock() defer b.mu.Unlock() if b.flushed { @@ -121,9 +145,7 @@ func (b *BootstrapLogger) Flush(logger *Logger) { } for _, entry := range b.entries { if entry.raw { - if logger != nil { - logger.AppendRaw(entry.message) - } + logger.AppendRaw(entry.message) continue } if entry.level > b.minLevel { @@ -150,6 +172,9 @@ func (b *BootstrapLogger) Flush(logger *Logger) { // SetMirrorLogger forwards every bootstrap message to the provided logger. func (b *BootstrapLogger) SetMirrorLogger(logger *Logger) { + if b == nil { + return + } b.mu.Lock() b.mirror = logger b.mu.Unlock() diff --git a/internal/logging/bootstrap_test.go b/internal/logging/bootstrap_test.go index 8bbacbbf..a7b752f1 100644 --- a/internal/logging/bootstrap_test.go +++ b/internal/logging/bootstrap_test.go @@ -92,3 +92,26 @@ func TestBootstrapLoggerDebugMirrorsAndFlushesAtDebugLevel(t *testing.T) { t.Fatalf("expected debug message to be flushed, got %q", flushBuf.String()) } } + +func TestBootstrapLoggerNilReceiverNoops(t *testing.T) { + var b *BootstrapLogger + + logger := New(types.LogLevelDebug, false) + var buf bytes.Buffer + logger.SetOutput(&buf) + + b.SetLevel(types.LogLevelDebug) + b.SetMirrorLogger(logger) + b.Println("plain") + b.Printf("plain-%d", 1) + b.Debug("debug") + b.Info("info") + b.Warning("warn") + b.Error("err") + b.Flush(logger) + b.Flush(nil) + + if buf.Len() != 0 { + t.Fatalf("nil bootstrap receiver should not write to logger, got %q", buf.String()) + } +} diff --git a/internal/notify/context_helpers.go b/internal/notify/context_helpers.go new file mode 100644 index 00000000..46955b33 --- /dev/null +++ b/internal/notify/context_helpers.go @@ -0,0 +1,25 @@ +package notify + +import ( + "context" + "time" +) + +func sleepWithContext(ctx context.Context, d time.Duration) error { + if err := ctx.Err(); err != nil { + return err + } + if d <= 0 { + return nil + } + + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/notify/email.go b/internal/notify/email.go index 20a3afac..ec25bc12 100644 --- a/internal/notify/email.go +++ b/internal/notify/email.go @@ -760,15 +760,22 @@ func (e *EmailNotifier) detectQueueEntry(ctx context.Context, recipient string) } // tailMailLog reads the last maxLines from the first available mail log file. -func (e *EmailNotifier) tailMailLog(maxLines int) ([]string, string) { +func (e *EmailNotifier) tailMailLog(ctx context.Context, maxLines int) ([]string, string) { + if err := ctx.Err(); err != nil { + return nil, "" + } + for _, logFile := range mailLogPaths { if _, err := os.Stat(logFile); err != nil { continue } - cmd := exec.Command("tail", "-n", strconv.Itoa(maxLines), logFile) + cmd := exec.CommandContext(ctx, "tail", "-n", strconv.Itoa(maxLines), logFile) output, err := cmd.Output() if err != nil { + if ctx.Err() != nil { + return nil, "" + } continue } @@ -777,13 +784,16 @@ func (e *EmailNotifier) tailMailLog(maxLines int) ([]string, string) { } // Fallback to journald if traditional log files are unavailable + if err := ctx.Err(); err != nil { + return nil, "" + } if _, err := exec.LookPath("journalctl"); err == nil { args := []string{"--no-pager", "-n", strconv.Itoa(maxLines)} for _, unit := range []string{"postfix.service", "sendmail.service", "exim4.service"} { args = append(args, "-u", unit) } - cmd := exec.Command("journalctl", args...) + cmd := exec.CommandContext(ctx, "journalctl", args...) output, err := cmd.Output() if err == nil && len(output) > 0 { lines := strings.Split(strings.TrimRight(string(output), "\n"), "\n") @@ -795,8 +805,8 @@ func (e *EmailNotifier) tailMailLog(maxLines int) ([]string, string) { } // checkRecentMailLogs checks recent mail log entries for errors -func (e *EmailNotifier) checkRecentMailLogs() []string { - lines, _ := e.tailMailLog(50) +func (e *EmailNotifier) checkRecentMailLogs(ctx context.Context) []string { + lines, _ := e.tailMailLog(ctx, 50) if len(lines) == 0 { return nil } @@ -832,8 +842,8 @@ func extractQueueID(outputs ...string) string { } // inspectMailLogStatus looks for a delivery status line for the given queue ID. -func (e *EmailNotifier) inspectMailLogStatus(queueID string) (status, matchedLine, logPath string) { - lines, logPath := e.tailMailLog(80) +func (e *EmailNotifier) inspectMailLogStatus(ctx context.Context, queueID string) (status, matchedLine, logPath string) { + lines, logPath := e.tailMailLog(ctx, 80) if len(lines) == 0 || logPath == "" { return "", "", logPath } @@ -1297,7 +1307,13 @@ func (e *EmailNotifier) sendViaSendmail(ctx context.Context, recipient, subject, e.logger.Debug("=== Post-send verification ===") // Brief pause to let sendmail process the message - time.Sleep(500 * time.Millisecond) + if err := sleepWithContext(ctx, 500*time.Millisecond); err != nil { + e.logger.Debug("Skipping post-send verification because context ended: %v", err) + e.logger.Debug("✅ Email handed off to sendmail successfully") + e.logger.Info("NOTE: Sendmail exit code 0 means email accepted to queue, not necessarily delivered") + e.logger.Info(" To verify actual delivery, check: mailq and /var/log/mail.log") + return queueID, "sendmail", sendmailPath, nil + } // Check queue again to see if message is stuck if queueCount, err := e.checkMailQueue(ctx); err == nil { @@ -1317,7 +1333,7 @@ func (e *EmailNotifier) sendViaSendmail(ctx context.Context, recipient, subject, } // Check recent mail logs for errors (always surface summary, details only in debug) - recentErrors := e.checkRecentMailLogs() + recentErrors := e.checkRecentMailLogs(ctx) if len(recentErrors) > 0 { e.logger.Warning("⚠ Recent mail log entries indicate potential delivery issues (found %d error-like lines)", len(recentErrors)) e.logger.Info(" Suggestion: inspect /var/log/mail.log (or maillog/mail.err) on this host for details") @@ -1345,7 +1361,7 @@ func (e *EmailNotifier) sendViaSendmail(ctx context.Context, recipient, subject, } if queueID != "" { - status, matchedLine, logPath := e.inspectMailLogStatus(queueID) + status, matchedLine, logPath := e.inspectMailLogStatus(ctx, queueID) e.logMailLogStatus(queueID, status, matchedLine, logPath) } else { e.logger.Debug("Sendmail did not report a queue ID; attempting to detect from mail queue output") @@ -1356,7 +1372,7 @@ func (e *EmailNotifier) sendViaSendmail(ctx context.Context, recipient, subject, if queueLine != "" && e.logger.GetLevel() >= types.LogLevelDebug { e.logger.Debug("Mail queue entry: %s", queueLine) } - status, matchedLine, logPath := e.inspectMailLogStatus(queueID) + status, matchedLine, logPath := e.inspectMailLogStatus(ctx, queueID) e.logMailLogStatus(queueID, status, matchedLine, logPath) } else { e.logger.Debug("No matching mail queue entry found for %s immediately after sending", recipient) diff --git a/internal/notify/email_mailq_test.go b/internal/notify/email_mailq_test.go index 2b7733f6..8c10cdab 100644 --- a/internal/notify/email_mailq_test.go +++ b/internal/notify/email_mailq_test.go @@ -2,6 +2,8 @@ package notify import ( "context" + "os" + "path/filepath" "testing" "github.com/tis24dev/proxsave/internal/logging" @@ -95,3 +97,34 @@ func TestEmailNotifierDetectQueueEntryNotFound(t *testing.T) { t.Fatalf("detectQueueEntry()=(%q,%q) want empty", queueID, line) } } + +func TestEmailNotifierTailMailLogSkipsWorkWhenContextCanceled(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + notifier, err := NewEmailNotifier(EmailConfig{Enabled: true, DeliveryMethod: EmailDeliverySendmail}, types.ProxmoxBS, logger) + if err != nil { + t.Fatalf("NewEmailNotifier() error=%v", err) + } + + origMailLogPaths := mailLogPaths + t.Cleanup(func() { mailLogPaths = origMailLogPaths }) + + logDir := t.TempDir() + logFile := filepath.Join(logDir, "mail.log") + if err := os.WriteFile(logFile, []byte("postfix/smtp[2]: ABC123: status=sent\n"), 0o600); err != nil { + t.Fatalf("write log file: %v", err) + } + mailLogPaths = []string{logFile} + + mockCmdEnv(t, "tail", "postfix/smtp[2]: ABC123: status=sent", 0) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + lines, logPath := notifier.tailMailLog(ctx, 50) + if len(lines) != 0 { + t.Fatalf("tailMailLog() returned lines after context cancellation: %#v", lines) + } + if logPath != "" { + t.Fatalf("tailMailLog() returned log path %q after context cancellation", logPath) + } +} diff --git a/internal/notify/email_parsing_test.go b/internal/notify/email_parsing_test.go index 41c9a153..92c9d6f1 100644 --- a/internal/notify/email_parsing_test.go +++ b/internal/notify/email_parsing_test.go @@ -2,6 +2,7 @@ package notify import ( "bytes" + "context" "io" "os" "path/filepath" @@ -84,7 +85,7 @@ func TestInspectMailLogStatus(t *testing.T) { t.Fatalf("NewEmailNotifier() error=%v", err) } - status, matchedLine, usedPath := notifier.inspectMailLogStatus(queueID) + status, matchedLine, usedPath := notifier.inspectMailLogStatus(context.Background(), queueID) if status != "sent" { t.Fatalf("status=%q want %q (matchedLine=%q)", status, "sent", matchedLine) } @@ -126,7 +127,7 @@ func TestEmailNotifierCheckRecentMailLogsDetectsErrors(t *testing.T) { t.Fatalf("NewEmailNotifier() error=%v", err) } - lines := notifier.checkRecentMailLogs() + lines := notifier.checkRecentMailLogs(context.Background()) if len(lines) < 3 { t.Fatalf("expected >=3 error-like lines, got %d: %#v", len(lines), lines) } @@ -184,7 +185,7 @@ func TestInspectMailLogStatus_Variants(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - status, matched, usedPath := notifier.inspectMailLogStatus(tt.queueID) + status, matched, usedPath := notifier.inspectMailLogStatus(context.Background(), tt.queueID) if status != tt.want { t.Fatalf("status=%q want %q (matched=%q)", status, tt.want, matched) } diff --git a/internal/notify/email_relay.go b/internal/notify/email_relay.go index 241dfd32..8d99bd9a 100644 --- a/internal/notify/email_relay.go +++ b/internal/notify/email_relay.go @@ -86,13 +86,19 @@ func sendViaCloudRelay( var lastErr error skipDefaultDelay := false for attempt := 0; attempt <= config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return err + } + if attempt > 0 { if skipDefaultDelay { logger.Debug("Retry attempt %d/%d resuming after rate-limit cooldown (no extra delay)", attempt, config.MaxRetries) skipDefaultDelay = false } else { logger.Debug("Retry attempt %d/%d after %ds delay...", attempt, config.MaxRetries, config.RetryDelay) - time.Sleep(time.Duration(config.RetryDelay) * time.Second) + if err := sleepWithContext(ctx, time.Duration(config.RetryDelay)*time.Second); err != nil { + return err + } } } @@ -114,6 +120,9 @@ func sendViaCloudRelay( // Send request resp, err := client.Do(req) if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } lastErr = fmt.Errorf("request failed: %w", err) logger.Warning("Cloud relay request failed (attempt %d/%d): %v", attempt+1, config.MaxRetries+1, err) continue @@ -123,6 +132,9 @@ func sendViaCloudRelay( body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } lastErr = fmt.Errorf("failed to read response: %w", err) continue } @@ -174,7 +186,9 @@ func sendViaCloudRelay( // Otherwise, retry with longer delay logger.Debug("Waiting 5 seconds before retry due to rate limiting...") - time.Sleep(5 * time.Second) + if err := sleepWithContext(ctx, 5*time.Second); err != nil { + return err + } skipDefaultDelay = true lastErr = fmt.Errorf("rate limit exceeded") continue diff --git a/internal/notify/email_relay_test.go b/internal/notify/email_relay_test.go index d2bf68e7..97468d43 100644 --- a/internal/notify/email_relay_test.go +++ b/internal/notify/email_relay_test.go @@ -5,8 +5,10 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "errors" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -175,10 +177,9 @@ func TestSendViaCloudRelay_StatusHandling(t *testing.T) { } func TestSendViaCloudRelay_RetryOnServerError(t *testing.T) { - attempts := 0 + var attempts int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempts++ - if attempts < 3 { + if atomic.AddInt32(&attempts, 1) < 3 { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(`{"error":"temporary"}`)) return @@ -210,7 +211,47 @@ func TestSendViaCloudRelay_RetryOnServerError(t *testing.T) { if err != nil { t.Fatalf("expected success after retries, got error: %v", err) } - if attempts != 3 { - t.Fatalf("expected 3 attempts, got %d", attempts) + if got := atomic.LoadInt32(&attempts); got != 3 { + t.Fatalf("expected 3 attempts, got %d", got) + } +} + +func TestSendViaCloudRelay_StopsRetryingWhenContextCanceled(t *testing.T) { + var attempts int32 + ctx, cancel := context.WithCancel(context.Background()) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(&attempts, 1) == 1 { + cancel() + } + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"temporary"}`)) + })) + defer server.Close() + + cfg := CloudRelayConfig{ + WorkerURL: server.URL, + WorkerToken: "token", + HMACSecret: "secret", + Timeout: 5, + MaxRetries: 3, + RetryDelay: 0, + } + + logger := logging.New(types.LogLevelDebug, false) + err := sendViaCloudRelay(ctx, cfg, EmailRelayPayload{ + To: "dest@test.invalid", + Subject: "subject", + Report: map[string]interface{}{"ok": true}, + Timestamp: time.Now().Unix(), + ServerMAC: "00:11:22:33:44:55", + ScriptVersion: "0.0.1", + ServerID: "server-id", + }, logger) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation error, got %v", err) + } + if got := atomic.LoadInt32(&attempts); got != 1 { + t.Fatalf("expected 1 attempt after cancellation, got %d", got) } } diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go index d317a0f5..87e0f1c1 100644 --- a/internal/notify/webhook.go +++ b/internal/notify/webhook.go @@ -218,9 +218,15 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return err + } + if attempt > 0 { w.logger.Debug("Retry attempt %d/%d after %ds delay...", attempt, maxRetries, retryDelay) - time.Sleep(time.Duration(retryDelay) * time.Second) + if err := sleepWithContext(ctx, time.Duration(retryDelay)*time.Second); err != nil { + return err + } } // Determine HTTP method @@ -315,6 +321,9 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We requestDuration := time.Since(requestStart) if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } lastErr = fmt.Errorf("request failed: %w", err) w.logger.Warning("⚠️ Request failed after %dms (attempt %d/%d): %v", requestDuration.Milliseconds(), attempt+1, maxRetries+1, err) @@ -327,6 +336,9 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We resp.Body.Close() if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } lastErr = fmt.Errorf("failed to read response: %w", err) w.logger.Warning("Failed to read response body: %v", err) continue @@ -376,7 +388,9 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We w.logger.Warning("⚠️ Rate limited (HTTP 429): %s", string(body)) if attempt < maxRetries { w.logger.Debug("Waiting 10 seconds before retry due to rate limiting...") - time.Sleep(10 * time.Second) + if err := sleepWithContext(ctx, 10*time.Second); err != nil { + return err + } } lastErr = fmt.Errorf("rate limit exceeded (HTTP 429)") continue diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go index 78926cba..841c828f 100644 --- a/internal/notify/webhook_test.go +++ b/internal/notify/webhook_test.go @@ -277,6 +277,52 @@ func TestWebhookNotifier_Send_Success(t *testing.T) { } } +func TestWebhookNotifier_SendToEndpoint_StopsRetryingWhenContextCanceled(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + attempts := 0 + ctx, cancel := context.WithCancel(context.Background()) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + cancel() + } + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"temporary"}`)) + })) + defer server.Close() + + cfg := config.WebhookConfig{ + Enabled: true, + DefaultFormat: "generic", + Timeout: 30, + MaxRetries: 3, + RetryDelay: 0, + Endpoints: []config.WebhookEndpoint{ + { + Name: "test-webhook", + URL: server.URL, + Format: "generic", + Method: "POST", + Auth: config.WebhookAuth{Type: "none"}, + }, + }, + } + + notifier, err := NewWebhookNotifier(&cfg, logger) + if err != nil { + t.Fatalf("Failed to create notifier: %v", err) + } + + err = notifier.sendToEndpoint(ctx, cfg.Endpoints[0], createTestNotificationData()) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation error, got %v", err) + } + if attempts != 1 { + t.Fatalf("expected 1 attempt after cancellation, got %d", attempts) + } +} + func TestWebhookNotifier_Send_Retry(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) attempts := 0 diff --git a/internal/orchestrator/additional_helpers_test.go b/internal/orchestrator/additional_helpers_test.go index 8b20d737..c1b6caf6 100644 --- a/internal/orchestrator/additional_helpers_test.go +++ b/internal/orchestrator/additional_helpers_test.go @@ -175,7 +175,7 @@ func (s *stubStorage) VerifyUpload(ctx context.Context, localFile, remoteFile st return true, nil } func (s *stubStorage) GetStats(ctx context.Context) (*storage.StorageStats, error) { - return &storage.StorageStats{TotalBackups: len(s.list), AvailableSpace: 1024, TotalSpace: 2048}, nil + return &storage.StorageStats{TotalBackups: len(s.list), AvailableSpace: 1024, UsedSpace: 768, TotalSpace: 2048}, nil } func TestApplyStorageStatsSimplePrimary(t *testing.T) { @@ -184,12 +184,12 @@ func TestApplyStorageStatsSimplePrimary(t *testing.T) { logger: logging.New(types.LogLevelError, false), } stats := &BackupStats{} - storageStats := &storage.StorageStats{TotalBackups: 3, AvailableSpace: 100, TotalSpace: 200} + storageStats := &storage.StorageStats{TotalBackups: 3, AvailableSpace: 100, UsedSpace: 80, TotalSpace: 200} retentionCfg := storage.RetentionConfig{Policy: "simple", MaxBackups: 5} adapter.applyStorageStats(storageStats, retentionCfg, stats) - if stats.LocalBackups != 3 || stats.LocalFreeSpace != 100 || stats.LocalTotalSpace != 200 { + if stats.LocalBackups != 3 || stats.LocalFreeSpace != 100 || stats.LocalUsedSpace != 80 || stats.LocalTotalSpace != 200 { t.Fatalf("local stats not set correctly: %+v", stats) } if stats.LocalRetentionPolicy != "simple" { @@ -210,7 +210,7 @@ func TestApplyStorageStatsGFSPrimary(t *testing.T) { logger: logging.New(types.LogLevelError, false), } stats := &BackupStats{} - storageStats := &storage.StorageStats{TotalBackups: len(backups), AvailableSpace: 500, TotalSpace: 1000} + storageStats := &storage.StorageStats{TotalBackups: len(backups), AvailableSpace: 500, UsedSpace: 400, TotalSpace: 1000} retentionCfg := storage.RetentionConfig{Policy: "gfs", Daily: 1, Weekly: 1, Monthly: 1, Yearly: 1} adapter.applyStorageStats(storageStats, retentionCfg, stats) @@ -226,6 +226,28 @@ func TestApplyStorageStatsGFSPrimary(t *testing.T) { } } +func TestApplyStorageStatsSimpleSecondaryUsesUsedSpace(t *testing.T) { + adapter := &StorageAdapter{ + backend: &stubStorage{loc: storage.LocationSecondary}, + logger: logging.New(types.LogLevelError, false), + } + stats := &BackupStats{} + storageStats := &storage.StorageStats{TotalBackups: 4, AvailableSpace: 300, UsedSpace: 700, TotalSpace: 1000} + retentionCfg := storage.RetentionConfig{Policy: "simple", MaxBackups: 7} + + adapter.applyStorageStats(storageStats, retentionCfg, stats) + + if !stats.SecondaryEnabled { + t.Fatalf("SecondaryEnabled = false, want true") + } + if stats.SecondaryBackups != 4 || stats.SecondaryFreeSpace != 300 || stats.SecondaryUsedSpace != 700 || stats.SecondaryTotalSpace != 1000 { + t.Fatalf("secondary stats not set correctly: %+v", stats) + } + if stats.SecondaryRetentionPolicy != "simple" { + t.Fatalf("SecondaryRetentionPolicy = %q, want simple", stats.SecondaryRetentionPolicy) + } +} + func TestSetAndFinalizeStorageStatus(t *testing.T) { stats := &BackupStats{} adapter := &StorageAdapter{ diff --git a/internal/orchestrator/age_setup_ui.go b/internal/orchestrator/age_setup_ui.go new file mode 100644 index 00000000..172d2f0c --- /dev/null +++ b/internal/orchestrator/age_setup_ui.go @@ -0,0 +1,24 @@ +package orchestrator + +import "context" + +type AgeRecipientInputKind int + +const ( + AgeRecipientInputExisting AgeRecipientInputKind = iota + AgeRecipientInputPassphrase + AgeRecipientInputPrivateKey +) + +type AgeRecipientDraft struct { + Kind AgeRecipientInputKind + PublicKey string + Passphrase string + PrivateKey string +} + +type AgeSetupUI interface { + ConfirmOverwriteExistingRecipient(ctx context.Context, recipientPath string) (bool, error) + CollectRecipientDraft(ctx context.Context, recipientPath string) (*AgeRecipientDraft, error) + ConfirmAddAnotherRecipient(ctx context.Context, currentCount int) (bool, error) +} diff --git a/internal/orchestrator/age_setup_ui_cli.go b/internal/orchestrator/age_setup_ui_cli.go new file mode 100644 index 00000000..b6a9b4a0 --- /dev/null +++ b/internal/orchestrator/age_setup_ui_cli.go @@ -0,0 +1,86 @@ +package orchestrator + +import ( + "bufio" + "context" + "fmt" + "os" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type cliAgeSetupUI struct { + reader *bufio.Reader + logger *logging.Logger +} + +func newCLIAgeSetupUI(reader *bufio.Reader, logger *logging.Logger) AgeSetupUI { + if reader == nil { + reader = bufio.NewReader(os.Stdin) + } + return &cliAgeSetupUI{ + reader: reader, + logger: logger, + } +} + +func (u *cliAgeSetupUI) ConfirmOverwriteExistingRecipient(ctx context.Context, recipientPath string) (bool, error) { + fmt.Printf("WARNING: this will remove the existing AGE recipients stored at %s. Existing backups remain decryptable with your old private key.\n", recipientPath) + return promptYesNoAge(ctx, u.reader, fmt.Sprintf("Delete %s and enter a new recipient? [y/N]: ", recipientPath)) +} + +func (u *cliAgeSetupUI) CollectRecipientDraft(ctx context.Context, recipientPath string) (*AgeRecipientDraft, error) { + for { + fmt.Println("\n[1] Use an existing AGE public key") + fmt.Println("[2] Generate an AGE public key using a personal passphrase/password - not stored on the server") + fmt.Println("[3] Generate an AGE public key from an existing personal private key - not stored on the server") + fmt.Println("[4] Exit setup") + + option, err := promptOptionAge(ctx, u.reader, "Select an option [1-4]: ") + if err != nil { + return nil, err + } + if option == "4" { + return nil, ErrAgeRecipientSetupAborted + } + + switch option { + case "1": + value, err := promptPublicRecipientAge(ctx, u.reader) + if err != nil { + u.warn(err) + continue + } + return &AgeRecipientDraft{Kind: AgeRecipientInputExisting, PublicKey: value}, nil + case "2": + passphrase, err := promptAndConfirmPassphraseAge(ctx) + if err != nil { + u.warn(err) + continue + } + return &AgeRecipientDraft{Kind: AgeRecipientInputPassphrase, Passphrase: passphrase}, nil + case "3": + privateKey, err := promptPrivateKeyValueAge(ctx) + if err != nil { + u.warn(err) + continue + } + return &AgeRecipientDraft{Kind: AgeRecipientInputPrivateKey, PrivateKey: privateKey}, nil + } + } +} + +func (u *cliAgeSetupUI) ConfirmAddAnotherRecipient(ctx context.Context, currentCount int) (bool, error) { + return promptYesNoAge(ctx, u.reader, "Add another recipient? [y/N]: ") +} + +func (u *cliAgeSetupUI) warn(err error) { + if err == nil { + return + } + if u.logger != nil { + u.logger.Warning("Encryption setup: %v", err) + return + } + fmt.Printf("WARNING: %v\n", err) +} diff --git a/internal/orchestrator/age_setup_workflow.go b/internal/orchestrator/age_setup_workflow.go new file mode 100644 index 00000000..339fcf67 --- /dev/null +++ b/internal/orchestrator/age_setup_workflow.go @@ -0,0 +1,254 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "filippo.io/age" +) + +type AgeRecipientSetupResult struct { + RecipientPath string + WroteRecipientFile bool + ReusedExistingRecipients bool +} + +func (o *Orchestrator) EnsureAgeRecipientsReadyWithUI(ctx context.Context, ui AgeSetupUI) error { + if o == nil || o.cfg == nil || !o.cfg.EncryptArchive { + return nil + } + _, _, err := o.prepareAgeRecipientsWithUI(ctx, ui) + return err +} + +func (o *Orchestrator) EnsureAgeRecipientsReadyWithUIDetails(ctx context.Context, ui AgeSetupUI) (*AgeRecipientSetupResult, error) { + if o == nil || o.cfg == nil || !o.cfg.EncryptArchive { + return nil, nil + } + _, result, err := o.prepareAgeRecipientsWithUI(ctx, ui) + return result, err +} + +func (o *Orchestrator) EnsureAgeRecipientsReadyWithDetails(ctx context.Context) (*AgeRecipientSetupResult, error) { + return o.EnsureAgeRecipientsReadyWithUIDetails(ctx, nil) +} + +func (o *Orchestrator) prepareAgeRecipientsWithUI(ctx context.Context, ui AgeSetupUI) ([]age.Recipient, *AgeRecipientSetupResult, error) { + if o.cfg == nil || !o.cfg.EncryptArchive { + return nil, nil, nil + } + + if o.ageRecipientCache != nil && !o.forceNewAgeRecipient { + return cloneRecipients(o.ageRecipientCache), &AgeRecipientSetupResult{ReusedExistingRecipients: true}, nil + } + + recipients, candidatePath, err := o.collectRecipientStrings() + if err != nil { + return nil, nil, err + } + + result := &AgeRecipientSetupResult{} + if len(recipients) > 0 && !o.forceNewAgeRecipient { + result.ReusedExistingRecipients = true + } + + if len(recipients) == 0 { + if ui == nil { + if !o.isInteractiveShell() { + if o.logger != nil { + o.logger.Error("Encryption setup requires interaction. Run the script interactively to complete the AGE recipient setup, then re-run in automated mode.") + o.logger.Debug("HINT Set AGE_RECIPIENT or AGE_RECIPIENT_FILE to bypass the interactive setup and re-run.") + } + return nil, nil, fmt.Errorf("age recipients not configured") + } + ui = newCLIAgeSetupUI(nil, o.logger) + } + + wizardRecipients, setupResult, err := o.runAgeSetupWorkflow(ctx, candidatePath, ui) + if err != nil { + return nil, nil, err + } + recipients = append(recipients, wizardRecipients...) + result = setupResult + if o.cfg.AgeRecipientFile == "" { + o.cfg.AgeRecipientFile = setupResult.RecipientPath + } + } + + if len(recipients) == 0 { + return nil, nil, fmt.Errorf("no AGE recipients configured after setup") + } + + parsed, err := parseRecipientStrings(recipients) + if err != nil { + return nil, nil, err + } + o.ageRecipientCache = cloneRecipients(parsed) + o.forceNewAgeRecipient = false + return cloneRecipients(parsed), result, nil +} + +func (o *Orchestrator) runAgeSetupWorkflow(ctx context.Context, candidatePath string, ui AgeSetupUI) ([]string, *AgeRecipientSetupResult, error) { + targetPath := strings.TrimSpace(candidatePath) + fs := o.filesystem() + if targetPath == "" { + targetPath = o.defaultAgeRecipientFile() + } + if targetPath == "" { + return nil, nil, fmt.Errorf("unable to determine default path for AGE recipients") + } + + if o.logger != nil { + o.logger.Info("Encryption setup: no AGE recipients found, starting interactive wizard") + o.logger.Debug("Encryption setup: target recipient file resolved to %s (force new recipient=%t)", targetPath, o.forceNewAgeRecipient) + } + + confirmedOverwriteExisting := false + if o.forceNewAgeRecipient { + if _, err := fs.Stat(targetPath); err == nil { + confirmedOverwriteExisting = true + if o.logger != nil { + o.logger.Debug("Encryption setup: existing AGE recipient file found at %s; requesting overwrite confirmation", targetPath) + } + confirm, err := ui.ConfirmOverwriteExistingRecipient(ctx, targetPath) + if err != nil { + return nil, nil, mapAgeSetupAbort(err) + } + if !confirm { + if o.logger != nil { + o.logger.Info("Encryption setup: overwrite declined for %s; leaving existing AGE recipient file unchanged", targetPath) + } + return nil, nil, ErrAgeRecipientSetupAborted + } + if o.logger != nil { + o.logger.Debug("Encryption setup: overwrite confirmed for %s; backup will be created before replacing the file", targetPath) + } + } else if !errors.Is(err, os.ErrNotExist) { + return nil, nil, fmt.Errorf("failed to inspect existing AGE recipients at %s: %w", targetPath, err) + } else if o.logger != nil { + o.logger.Debug("Encryption setup: no existing AGE recipient file found at %s; a new file will be created", targetPath) + } + } + + recipients := make([]string, 0) + for { + draft, err := ui.CollectRecipientDraft(ctx, targetPath) + if err != nil { + return nil, nil, mapAgeSetupAbort(err) + } + if draft == nil { + return nil, nil, ErrAgeRecipientSetupAborted + } + + value, err := resolveAgeRecipientDraft(draft) + if err != nil { + if o.logger != nil { + o.logger.Warning("Encryption setup: %v", err) + } + continue + } + recipients = append(recipients, value) + + more, err := ui.ConfirmAddAnotherRecipient(ctx, len(recipients)) + if err != nil { + return nil, nil, mapAgeSetupAbort(err) + } + if !more { + break + } + } + + recipients = dedupeRecipientStrings(recipients) + if len(recipients) == 0 { + return nil, nil, fmt.Errorf("no recipients provided") + } + if o.logger != nil { + o.logger.Debug("Encryption setup: collected %d unique AGE recipient(s) for %s", len(recipients), targetPath) + } + + backupPath := "" + if confirmedOverwriteExisting { + if o.logger != nil { + o.logger.Debug("Encryption setup: creating backup of existing AGE recipient file at %s before overwrite", targetPath) + } + var err error + backupPath, err = backupExistingRecipientFileWithDeps(fs, o.clock, targetPath) + if err != nil { + if o.logger != nil { + o.logger.Warning("Encryption setup: failed to back up existing AGE recipients at %s: %v", targetPath, err) + } + return nil, nil, fmt.Errorf("backup existing AGE recipients at %s: %w", targetPath, err) + } + if o.logger != nil { + o.logger.Info("Encryption setup: existing AGE recipients backed up to %s", backupPath) + } + } + + if o.logger != nil { + o.logger.Debug("Encryption setup: writing %d AGE recipient(s) to %s (overwrite existing=%t)", len(recipients), targetPath, confirmedOverwriteExisting) + } + if err := writeRecipientFileWithDeps(fs, o.clock, targetPath, recipients); err != nil { + return nil, nil, err + } + + if o.logger != nil { + o.logger.Info("Saved %d AGE recipient(s) to %s", len(recipients), targetPath) + if backupPath != "" { + o.logger.Debug("Encryption setup: previous AGE recipient file for %s was preserved at %s", targetPath, backupPath) + } + o.logger.Info("Reminder: keep the AGE private key offline; the server stores only recipients.") + } + return recipients, &AgeRecipientSetupResult{ + RecipientPath: targetPath, + WroteRecipientFile: true, + }, nil +} + +func resolveAgeRecipientDraft(draft *AgeRecipientDraft) (string, error) { + if draft == nil { + return "", fmt.Errorf("recipient draft is required") + } + + switch draft.Kind { + case AgeRecipientInputExisting: + value := strings.TrimSpace(draft.PublicKey) + if err := ValidateRecipientString(value); err != nil { + return "", err + } + return value, nil + case AgeRecipientInputPassphrase: + passphrase := strings.TrimSpace(draft.Passphrase) + defer resetString(&passphrase) + if passphrase == "" { + return "", fmt.Errorf("passphrase cannot be empty") + } + if err := validatePassphraseStrength([]byte(passphrase)); err != nil { + return "", err + } + recipient, err := deriveDeterministicRecipientFromPassphrase(passphrase) + if err != nil { + return "", err + } + return recipient, nil + case AgeRecipientInputPrivateKey: + privateKey := strings.TrimSpace(draft.PrivateKey) + defer resetString(&privateKey) + return ParseAgePrivateKeyRecipient(privateKey) + default: + return "", fmt.Errorf("unsupported AGE setup input kind: %d", draft.Kind) + } +} + +func mapAgeSetupAbort(err error) error { + if err == nil { + return nil + } + mapped := mapInputAbortToAgeAbort(err) + if errors.Is(mapped, ErrAgeRecipientSetupAborted) { + return ErrAgeRecipientSetupAborted + } + return mapped +} diff --git a/internal/orchestrator/age_setup_workflow_test.go b/internal/orchestrator/age_setup_workflow_test.go new file mode 100644 index 00000000..377cb7ac --- /dev/null +++ b/internal/orchestrator/age_setup_workflow_test.go @@ -0,0 +1,293 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "filippo.io/age" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/testutil" +) + +type mockAgeSetupUI = testutil.AgeSetupUIStub[AgeRecipientDraft] + +type renameFailFS struct { + *FakeFS + err error +} + +func (f *renameFailFS) Rename(oldpath, newpath string) error { + return f.err +} + +func TestEnsureAgeRecipientsReadyWithUI_ReusesConfiguredRecipientsWithoutPrompting(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + ui := &mockAgeSetupUI{AbortErr: ErrAgeRecipientSetupAborted} + orch := newEncryptionTestOrchestrator(&config.Config{ + EncryptArchive: true, + BaseDir: t.TempDir(), + AgeRecipients: []string{id.Recipient().String()}, + }) + + if err := orch.EnsureAgeRecipientsReadyWithUI(context.Background(), ui); err != nil { + t.Fatalf("EnsureAgeRecipientsReadyWithUI error: %v", err) + } + if ui.CollectCalls != 0 || ui.OverwriteCalls != 0 || ui.AddCalls != 0 { + t.Fatalf("UI should not have been used when recipients already exist: %#v", ui) + } +} + +func TestEnsureAgeRecipientsReadyWithUI_ConfiguresRecipientsWithoutTTY(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + tmp := t.TempDir() + ui := &mockAgeSetupUI{ + AbortErr: ErrAgeRecipientSetupAborted, + Drafts: []*AgeRecipientDraft{ + {Kind: AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + cfg := &config.Config{EncryptArchive: true, BaseDir: tmp} + orch := newEncryptionTestOrchestrator(cfg) + + if err := orch.EnsureAgeRecipientsReadyWithUI(context.Background(), ui); err != nil { + t.Fatalf("EnsureAgeRecipientsReadyWithUI error: %v", err) + } + + target := filepath.Join(tmp, "identity", "age", "recipient.txt") + content, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(%s): %v", target, err) + } + if got := string(content); got != id.Recipient().String()+"\n" { + t.Fatalf("content=%q; want %q", got, id.Recipient().String()+"\n") + } + if cfg.AgeRecipientFile != target { + t.Fatalf("AgeRecipientFile=%q; want %q", cfg.AgeRecipientFile, target) + } +} + +func TestMapAgeSetupAbort_NormalizesAbortSignals(t *testing.T) { + if !errors.Is(mapAgeSetupAbort(context.Canceled), ErrAgeRecipientSetupAborted) { + t.Fatalf("expected context.Canceled to normalize to %v", ErrAgeRecipientSetupAborted) + } + if !errors.Is(mapAgeSetupAbort(ErrAgeRecipientSetupAborted), ErrAgeRecipientSetupAborted) { + t.Fatalf("expected ErrAgeRecipientSetupAborted to remain normalized") + } + + sentinel := errors.New("boom") + if got := mapAgeSetupAbort(sentinel); got != sentinel { + t.Fatalf("expected non-abort error passthrough, got %v", got) + } +} + +func TestEnsureAgeRecipientsReadyWithUI_ForceNewRecipientDeclineReturnsAbort(t *testing.T) { + tmp := t.TempDir() + target := filepath.Join(tmp, "identity", "age", "recipient.txt") + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(target, []byte("old\n"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + ui := &mockAgeSetupUI{ + AbortErr: ErrAgeRecipientSetupAborted, + Overwrite: false, + } + orch := newEncryptionTestOrchestrator(&config.Config{ + EncryptArchive: true, + BaseDir: tmp, + AgeRecipientFile: target, + }) + orch.SetForceNewAgeRecipient(true) + + err := orch.EnsureAgeRecipientsReadyWithUI(context.Background(), ui) + if !errors.Is(err, ErrAgeRecipientSetupAborted) { + t.Fatalf("err=%v; want %v", err, ErrAgeRecipientSetupAborted) + } + if ui.OverwriteCalls != 1 { + t.Fatalf("overwriteCalls=%d; want 1", ui.OverwriteCalls) + } + if ui.CollectCalls != 0 { + t.Fatalf("collectCalls=%d; want 0", ui.CollectCalls) + } + if _, statErr := os.Stat(target); statErr != nil { + t.Fatalf("recipient file should remain in place, stat err=%v", statErr) + } +} + +func TestEnsureAgeRecipientsReadyWithUI_ForceNewRecipientSuccessfulOverwriteCreatesBackupOnCommit(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + tmp := t.TempDir() + target := filepath.Join(tmp, "identity", "age", "recipient.txt") + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(target, []byte("old\n"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + ui := &mockAgeSetupUI{ + AbortErr: ErrAgeRecipientSetupAborted, + Overwrite: true, + Drafts: []*AgeRecipientDraft{ + {Kind: AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + cfg := &config.Config{ + EncryptArchive: true, + BaseDir: tmp, + AgeRecipientFile: target, + } + fakeTime := &FakeTime{Current: time.Date(2026, 3, 17, 10, 11, 12, 0, time.UTC)} + orch := newEncryptionTestOrchestrator(cfg) + orch.SetForceNewAgeRecipient(true) + orch.clock = fakeTime + + if err := orch.EnsureAgeRecipientsReadyWithUI(context.Background(), ui); err != nil { + t.Fatalf("EnsureAgeRecipientsReadyWithUI error: %v", err) + } + + backupPath := target + ".bak-" + fakeTime.Current.Format("20060102-150405.000000000") + backup, err := os.ReadFile(backupPath) + if err != nil { + t.Fatalf("ReadFile(%s): %v", backupPath, err) + } + if got := strings.TrimSpace(string(backup)); got != "old" { + t.Fatalf("backup content=%q; want %q", got, "old") + } + + content, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(%s): %v", target, err) + } + if got := strings.TrimSpace(string(content)); got != id.Recipient().String() { + t.Fatalf("content=%q; want %q", got, id.Recipient().String()) + } + if ui.OverwriteCalls != 1 { + t.Fatalf("overwriteCalls=%d; want 1", ui.OverwriteCalls) + } +} + +func TestRunAgeSetupWorkflow_ForceNewRecipientBackupFailurePreservesOriginal(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + fs := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fs.Root) }) + fakeTime := &FakeTime{Current: time.Date(2026, 3, 17, 12, 0, 0, 0, time.UTC)} + target := "/identity/age/recipient.txt" + if err := fs.AddFile(target, []byte("old\n")); err != nil { + t.Fatalf("AddFile: %v", err) + } + backupPath := target + ".bak-" + fakeTime.Current.Format("20060102-150405.000000000") + fs.OpenFileErr[filepath.Clean(backupPath)] = errors.New("disk full") + + ui := &mockAgeSetupUI{ + AbortErr: ErrAgeRecipientSetupAborted, + Overwrite: true, + Drafts: []*AgeRecipientDraft{ + {Kind: AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + orch := newEncryptionTestOrchestrator(&config.Config{EncryptArchive: true, AgeRecipientFile: target}) + orch.SetForceNewAgeRecipient(true) + orch.fs = fs + orch.clock = fakeTime + + _, _, err = orch.runAgeSetupWorkflow(context.Background(), target, ui) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "backup existing AGE recipients at "+target) { + t.Fatalf("err=%v; want backup failure context", err) + } + + content, readErr := fs.ReadFile(target) + if readErr != nil { + t.Fatalf("ReadFile(%s): %v", target, readErr) + } + if got := strings.TrimSpace(string(content)); got != "old" { + t.Fatalf("original content=%q; want %q", got, "old") + } + if _, statErr := fs.Stat(backupPath); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("backup stat err=%v; want not exist", statErr) + } +} + +func TestRunAgeSetupWorkflow_ForceNewRecipientWriteFailurePreservesOriginalAndBackup(t *testing.T) { + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + + baseFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(baseFS.Root) }) + fs := &renameFailFS{FakeFS: baseFS, err: errors.New("rename failed")} + fakeTime := &FakeTime{Current: time.Date(2026, 3, 17, 12, 30, 0, 0, time.UTC)} + target := "/identity/age/recipient.txt" + if err := fs.AddFile(target, []byte("old\n")); err != nil { + t.Fatalf("AddFile: %v", err) + } + + ui := &mockAgeSetupUI{ + AbortErr: ErrAgeRecipientSetupAborted, + Overwrite: true, + Drafts: []*AgeRecipientDraft{ + {Kind: AgeRecipientInputExisting, PublicKey: id.Recipient().String()}, + }, + AddMore: []bool{false}, + } + orch := newEncryptionTestOrchestrator(&config.Config{EncryptArchive: true, AgeRecipientFile: target}) + orch.SetForceNewAgeRecipient(true) + orch.fs = fs + orch.clock = fakeTime + + _, _, err = orch.runAgeSetupWorkflow(context.Background(), target, ui) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "write recipient file") { + t.Fatalf("err=%v; want write recipient file failure", err) + } + + backupPath := target + ".bak-" + fakeTime.Current.Format("20060102-150405.000000000") + backup, readErr := fs.ReadFile(backupPath) + if readErr != nil { + t.Fatalf("ReadFile(%s): %v", backupPath, readErr) + } + if got := strings.TrimSpace(string(backup)); got != "old" { + t.Fatalf("backup content=%q; want %q", got, "old") + } + + content, readErr := fs.ReadFile(target) + if readErr != nil { + t.Fatalf("ReadFile(%s): %v", target, readErr) + } + if got := strings.TrimSpace(string(content)); got != "old" { + t.Fatalf("original content=%q; want %q", got, "old") + } +} diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go index c5def503..eb858720 100644 --- a/internal/orchestrator/backup_sources.go +++ b/internal/orchestrator/backup_sources.go @@ -48,16 +48,18 @@ func buildDecryptPathOptions(cfg *config.Config, logger *logging.Logger) (option logging.DebugStep(logger, "build backup source options", "skip local (empty)") } - if clean := strings.TrimSpace(cfg.SecondaryPath); clean != "" { - logging.DebugStep(logger, "build backup source options", "add secondary path=%q", clean) - options = append(options, decryptPathOption{ - Label: "Secondary backups", - Path: clean, - }) - } else if cfg.SecondaryEnabled { - logging.DebugStep(logger, "build backup source options", "skip secondary (enabled but path empty)") + if cfg.SecondaryEnabled { + if clean := strings.TrimSpace(cfg.SecondaryPath); clean != "" { + logging.DebugStep(logger, "build backup source options", "add secondary path=%q", clean) + options = append(options, decryptPathOption{ + Label: "Secondary backups", + Path: clean, + }) + } else { + logging.DebugStep(logger, "build backup source options", "skip secondary (enabled but path empty)") + } } else { - logging.DebugStep(logger, "build backup source options", "skip secondary (path empty)") + logging.DebugStep(logger, "build backup source options", "skip secondary (disabled)") } if strings.TrimSpace(cfg.CloudRemote) != "" || strings.TrimSpace(cfg.CloudRemotePath) != "" { diff --git a/internal/orchestrator/backup_sources_test.go b/internal/orchestrator/backup_sources_test.go index ad27869f..d51148bb 100644 --- a/internal/orchestrator/backup_sources_test.go +++ b/internal/orchestrator/backup_sources_test.go @@ -16,6 +16,11 @@ import ( "github.com/tis24dev/proxsave/internal/types" ) +func prependPathEnv(t *testing.T, dir string) { + t.Helper() + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) +} + func TestIsRcloneRemote(t *testing.T) { tests := []struct { name string @@ -204,8 +209,7 @@ func TestDiscoverRcloneBackups_ListsAndParsesBundles(t *testing.T) { ctx := context.Background() logger := logging.New(types.LogLevelDebug, false) - manifest, cleanup := setupFakeRcloneListAndCat(t) - defer cleanup() + manifest := setupFakeRcloneListAndCat(t) candidates, err := discoverRcloneBackups(ctx, nil, "gdrive:pbs-backups/server1", logger, nil) if err != nil { @@ -267,16 +271,8 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) ctx := context.Background() candidates, err := discoverRcloneBackups(ctx, nil, "gdrive:pbs-backups/server1", nil, nil) @@ -394,17 +390,10 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - _ = os.Setenv("BUNDLE_PATH", bundlePath) - _ = os.Setenv("RAW_NEWEST_META", rawNewestMeta) - _ = os.Setenv("RAW_OLD_META", rawOldMeta) - defer os.Unsetenv("BUNDLE_PATH") - defer os.Unsetenv("RAW_NEWEST_META") - defer os.Unsetenv("RAW_OLD_META") + prependPathEnv(t, tmpDir) + t.Setenv("BUNDLE_PATH", bundlePath) + t.Setenv("RAW_NEWEST_META", rawNewestMeta) + t.Setenv("RAW_OLD_META", rawOldMeta) // Ensure archives appear in lsf snapshot; their content is not fetched. _ = os.WriteFile(rawNewestArchive, []byte("x"), 0o600) @@ -433,8 +422,7 @@ esac func TestDiscoverRcloneBackups_AllowsNilLogger(t *testing.T) { ctx := context.Background() - manifest, cleanup := setupFakeRcloneListAndCat(t) - defer cleanup() + manifest := setupFakeRcloneListAndCat(t) candidates, err := discoverRcloneBackups(ctx, nil, "gdrive:pbs-backups/server1", nil, nil) if err != nil { @@ -519,16 +507,8 @@ func TestInspectRcloneBundleManifest_UsesRcloneCat(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("BUNDLE_PATH", bundlePath); err != nil { - t.Fatalf("set BUNDLE_PATH: %v", err) - } - defer os.Unsetenv("BUNDLE_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("BUNDLE_PATH", bundlePath) ctx := context.Background() logger := logging.New(types.LogLevelInfo, false) @@ -553,9 +533,8 @@ func TestInspectRcloneBundleManifest_UsesRcloneCat(t *testing.T) { // setupFakeRcloneListAndCat creates a temporary bundle and installs a fake // rclone binary that supports `lsf` and `cat`, emulating cloud discovery. -// It returns the manifest embedded in the bundle and a cleanup function that -// restores PATH and auxiliary env vars. -func setupFakeRcloneListAndCat(t *testing.T) (backup.Manifest, func()) { +// It returns the manifest embedded in the bundle. +func setupFakeRcloneListAndCat(t *testing.T) backup.Manifest { t.Helper() tmpDir := t.TempDir() @@ -616,20 +595,10 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - if err := os.Setenv("BUNDLE_PATH", bundlePath); err != nil { - t.Fatalf("set BUNDLE_PATH: %v", err) - } + prependPathEnv(t, tmpDir) + t.Setenv("BUNDLE_PATH", bundlePath) - cleanup := func() { - _ = os.Setenv("PATH", oldPath) - _ = os.Unsetenv("BUNDLE_PATH") - } - - return manifest, cleanup + return manifest } func TestDiscoverBackupCandidates_NoLoggerSkipsRawArtifactsWithoutChecksumVerification(t *testing.T) { @@ -872,19 +841,9 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") - if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { - t.Fatalf("set CHECKSUM_PATH: %v", err) - } - defer os.Unsetenv("CHECKSUM_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) + t.Setenv("CHECKSUM_PATH", checksumPath) candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil) if err != nil { @@ -973,19 +932,9 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") - if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { - t.Fatalf("set CHECKSUM_PATH: %v", err) - } - defer os.Unsetenv("CHECKSUM_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) + t.Setenv("CHECKSUM_PATH", checksumPath) cfg := &config.Config{RcloneTimeoutConnection: 3} candidates, err := discoverRcloneBackups(context.Background(), cfg, "gdrive:pbs-backups/server1", nil, nil) @@ -1077,19 +1026,9 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") - if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { - t.Fatalf("set CHECKSUM_PATH: %v", err) - } - defer os.Unsetenv("CHECKSUM_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) + t.Setenv("CHECKSUM_PATH", checksumPath) candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil) if err != nil { diff --git a/internal/orchestrator/bundle_test.go b/internal/orchestrator/bundle_test.go index 280060e3..be0ba164 100644 --- a/internal/orchestrator/bundle_test.go +++ b/internal/orchestrator/bundle_test.go @@ -3,6 +3,7 @@ package orchestrator import ( "archive/tar" "context" + "errors" "io" "os" "path/filepath" @@ -11,6 +12,94 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +type trackingBundleFS struct { + FS + createdFiles []*os.File + createdPaths []string + openErr map[string]error + renameErr error +} + +func (f *trackingBundleFS) recordCreatedFile(file *os.File) { + if file == nil { + return + } + f.createdFiles = append(f.createdFiles, file) + f.createdPaths = append(f.createdPaths, filepath.Clean(file.Name())) +} + +func (f *trackingBundleFS) Create(name string) (*os.File, error) { + file, err := f.FS.Create(name) + if err == nil { + f.recordCreatedFile(file) + } + return file, err +} + +func (f *trackingBundleFS) CreateTemp(dir, pattern string) (*os.File, error) { + file, err := f.FS.CreateTemp(dir, pattern) + if err == nil { + f.recordCreatedFile(file) + } + return file, err +} + +func (f *trackingBundleFS) Open(path string) (*os.File, error) { + if err, ok := f.openErr[filepath.Clean(path)]; ok { + return nil, err + } + return f.FS.Open(path) +} + +func (f *trackingBundleFS) Rename(oldpath, newpath string) error { + if f.renameErr != nil { + return f.renameErr + } + return f.FS.Rename(oldpath, newpath) +} + +func assertPathAbsent(t *testing.T, path string) { + t.Helper() + + if _, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected %s to be absent, got %v", path, err) + } +} + +func assertTrackedFilesClosed(t *testing.T, files []*os.File) { + t.Helper() + + if len(files) == 0 { + t.Fatalf("expected tracked bundle file") + } + for _, file := range files { + if err := file.Close(); !errors.Is(err, os.ErrClosed) { + t.Fatalf("bundle file %s close after createBundle = %v, want ErrClosed", file.Name(), err) + } + } +} + +func assertTrackedPathsAbsent(t *testing.T, paths []string) { + t.Helper() + + if len(paths) == 0 { + t.Fatalf("expected tracked bundle file path") + } + for _, path := range paths { + assertPathAbsent(t, path) + } +} + +func writeBundleFixtures(t *testing.T, archive string, data map[string]string) { + t.Helper() + + for suffix, content := range data { + if err := os.WriteFile(archive+suffix, []byte(content), 0o640); err != nil { + t.Fatalf("write %s: %v", suffix, err) + } + } +} + func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) tempDir := t.TempDir() @@ -23,16 +112,12 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { ".metadata": "metadata-json", ".metadata.sha256": "checksum2", } + writeBundleFixtures(t, archive, testData) - for suffix, content := range testData { - if err := os.WriteFile(archive+suffix, []byte(content), 0o640); err != nil { - t.Fatalf("write %s: %v", suffix, err) - } - } - + bundleFS := &trackingBundleFS{FS: osFS{}} o := &Orchestrator{ logger: logger, - fs: osFS{}, + fs: bundleFS, } bundlePath, err := o.createBundle(context.Background(), archive) @@ -44,6 +129,16 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { if bundlePath != expectedPath { t.Fatalf("bundle path = %s, want %s", bundlePath, expectedPath) } + if len(bundleFS.createdPaths) == 0 { + t.Fatalf("expected tracked bundle file path") + } + for _, path := range bundleFS.createdPaths { + if path == expectedPath { + t.Fatalf("expected bundle to be written via temp file, got %s", path) + } + } + assertTrackedPathsAbsent(t, bundleFS.createdPaths) + assertTrackedFilesClosed(t, bundleFS.createdFiles) // Verify bundle file exists bundleInfo, err := os.Stat(bundlePath) @@ -121,6 +216,103 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { } } +func TestCreateBundle_ClosesBundleFileOnInputOpenError(t *testing.T) { + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + tempDir := t.TempDir() + archive := filepath.Join(tempDir, "backup.tar") + + testData := map[string]string{ + "": "archive-content", + ".sha256": "checksum1", + ".metadata": "metadata-json", + } + writeBundleFixtures(t, archive, testData) + + forcedErr := errors.New("forced open failure") + bundleFS := &trackingBundleFS{ + FS: osFS{}, + openErr: map[string]error{ + filepath.Clean(archive + ".sha256"): forcedErr, + }, + } + o := &Orchestrator{ + logger: logger, + fs: bundleFS, + } + + _, err := o.createBundle(context.Background(), archive) + if !errors.Is(err, forcedErr) { + t.Fatalf("createBundle error = %v, want wrapped %v", err, forcedErr) + } + assertTrackedFilesClosed(t, bundleFS.createdFiles) + assertPathAbsent(t, archive+".bundle.tar") + assertTrackedPathsAbsent(t, bundleFS.createdPaths) +} + +func TestCreateBundle_RemovesTempFileOnRenameError(t *testing.T) { + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + tempDir := t.TempDir() + archive := filepath.Join(tempDir, "backup.tar") + + testData := map[string]string{ + "": "archive-content", + ".sha256": "checksum1", + ".metadata": "metadata-json", + } + writeBundleFixtures(t, archive, testData) + + forcedErr := errors.New("forced rename failure") + bundleFS := &trackingBundleFS{ + FS: osFS{}, + renameErr: forcedErr, + } + o := &Orchestrator{ + logger: logger, + fs: bundleFS, + } + + _, err := o.createBundle(context.Background(), archive) + if !errors.Is(err, forcedErr) { + t.Fatalf("createBundle error = %v, want wrapped %v", err, forcedErr) + } + assertTrackedFilesClosed(t, bundleFS.createdFiles) + assertPathAbsent(t, archive+".bundle.tar") + assertTrackedPathsAbsent(t, bundleFS.createdPaths) +} + +func TestCreateBundle_RemovesFinalBundleOnDirectoryOpenErrorDuringSync(t *testing.T) { + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + tempDir := t.TempDir() + archive := filepath.Join(tempDir, "backup.tar") + + testData := map[string]string{ + "": "archive-content", + ".sha256": "checksum1", + ".metadata": "metadata-json", + } + writeBundleFixtures(t, archive, testData) + + forcedErr := errors.New("forced directory open failure during sync") + bundleFS := &trackingBundleFS{ + FS: osFS{}, + openErr: map[string]error{ + filepath.Clean(tempDir): forcedErr, + }, + } + o := &Orchestrator{ + logger: logger, + fs: bundleFS, + } + + _, err := o.createBundle(context.Background(), archive) + if !errors.Is(err, forcedErr) { + t.Fatalf("createBundle error = %v, want wrapped %v", err, forcedErr) + } + assertTrackedFilesClosed(t, bundleFS.createdFiles) + assertPathAbsent(t, archive+".bundle.tar") + assertTrackedPathsAbsent(t, bundleFS.createdPaths) +} + func TestRemoveAssociatedFiles_RemovesAll(t *testing.T) { logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) tempDir := t.TempDir() diff --git a/internal/orchestrator/decrypt_additional_test.go b/internal/orchestrator/decrypt_additional_test.go index a92a2fc5..cd0bc48f 100644 --- a/internal/orchestrator/decrypt_additional_test.go +++ b/internal/orchestrator/decrypt_additional_test.go @@ -81,11 +81,7 @@ func TestDownloadRcloneBackup(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH")) logger := logging.New(types.LogLevelDebug, false) logger.SetOutput(io.Discard) diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go index 4c795df2..e6efd30b 100644 --- a/internal/orchestrator/decrypt_test.go +++ b/internal/orchestrator/decrypt_test.go @@ -124,6 +124,17 @@ func TestBuildDecryptPathOptions(t *testing.T) { wantPaths: []string{"/backup/local"}, wantLabel: []string{"Local backups"}, }, + { + name: "secondary disabled ignores stale path", + cfg: &config.Config{ + BackupPath: "/backup/local", + SecondaryEnabled: false, + SecondaryPath: "remote:path", + }, + wantCount: 1, + wantPaths: []string{"/backup/local"}, + wantLabel: []string{"Local backups"}, + }, { name: "cloud enabled but empty remote", cfg: &config.Config{ @@ -228,16 +239,8 @@ func TestInspectRcloneMetadataManifest_JSONArchivePathEmptyUsesRemoteArchivePath t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -275,16 +278,8 @@ func TestInspectRcloneMetadataManifest_LegacyInfersAgeFromArchiveExt(t *testing. t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { - t.Fatalf("set METADATA_PATH: %v", err) - } - defer os.Unsetenv("METADATA_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("METADATA_PATH", metadataPath) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -330,16 +325,8 @@ func TestInspectRcloneBundleManifest_ReturnsErrorWhenManifestMissing(t *testing. t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("BUNDLE_PATH", bundlePath); err != nil { - t.Fatalf("set BUNDLE_PATH: %v", err) - } - defer os.Unsetenv("BUNDLE_PATH") + prependPathEnv(t, tmpDir) + t.Setenv("BUNDLE_PATH", bundlePath) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2093,24 +2080,10 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) - - if err := os.Setenv("ARCHIVE_SRC", archiveSrc); err != nil { - t.Fatalf("set ARCHIVE_SRC: %v", err) - } - if err := os.Setenv("METADATA_SRC", metadataSrc); err != nil { - t.Fatalf("set METADATA_SRC: %v", err) - } - if err := os.Setenv("CHECKSUM_SRC", checksumSrc); err != nil { - t.Fatalf("set CHECKSUM_SRC: %v", err) - } - defer os.Unsetenv("ARCHIVE_SRC") - defer os.Unsetenv("METADATA_SRC") - defer os.Unsetenv("CHECKSUM_SRC") + prependPathEnv(t, binDir) + t.Setenv("ARCHIVE_SRC", archiveSrc) + t.Setenv("METADATA_SRC", metadataSrc) + t.Setenv("CHECKSUM_SRC", checksumSrc) cand := &decryptCandidate{ IsRclone: true, @@ -2298,11 +2271,7 @@ func TestInspectRcloneBundleManifest_TarReadErrorInLoop(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2345,11 +2314,7 @@ func TestInspectRcloneBundleManifest_UnmarshalError(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2400,11 +2365,7 @@ func TestInspectRcloneBundleManifest_ValidManifest(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelDebug, false) logger.SetOutput(io.Discard) @@ -2441,11 +2402,7 @@ func TestInspectRcloneMetadataManifest_EmptyData(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2482,11 +2439,7 @@ func TestInspectRcloneMetadataManifest_LegacyPlainEncryption(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2533,11 +2486,7 @@ func TestInspectRcloneMetadataManifest_LegacyWithComments(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2567,11 +2516,7 @@ func TestInspectRcloneMetadataManifest_RcloneFails(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2743,11 +2688,7 @@ func TestDownloadRcloneBackup_RcloneRunError(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -2909,9 +2850,7 @@ esac t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - os.Setenv("PATH", binDir+string(os.PathListSeparator)+oldPath) - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, binDir) // Mock password input to return the correct key readPassword = func(fd int) ([]byte, error) { @@ -3354,11 +3293,7 @@ func TestInspectRcloneBundleManifest_StartError(t *testing.T) { t.Fatalf("write fake rclone: %v", err) } - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { - t.Fatalf("set PATH: %v", err) - } - defer os.Setenv("PATH", oldPath) + prependPathEnv(t, tmpDir) logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) @@ -3652,9 +3587,7 @@ exit 1 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) ctx := context.Background() logger := logging.New(types.LogLevelDebug, false) @@ -3687,9 +3620,7 @@ exit 0 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) workDir := filepath.Join(tmp, "work") if err := os.MkdirAll(workDir, 0o755); err != nil { @@ -3745,9 +3676,7 @@ fi t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) workDir := filepath.Join(tmp, "work") if err := os.MkdirAll(workDir, 0o755); err != nil { @@ -3845,9 +3774,7 @@ exit 0 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) cfg := &config.Config{ BackupPath: "", @@ -3887,9 +3814,7 @@ exit 1 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) // Create local backup directory with valid backup backupDir := filepath.Join(tmp, "backups") @@ -4024,9 +3949,7 @@ exit 1 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) cand := &decryptCandidate{ Source: sourceBundle, @@ -4082,9 +4005,7 @@ exit 1 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) // First allow the rclone download to work by using real FS initially orig := restoreFS @@ -4140,9 +4061,7 @@ cat "%s" t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) ctx := context.Background() logger := logging.New(types.LogLevelError, false) @@ -4178,9 +4097,7 @@ exit 1 t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) ctx := context.Background() logger := logging.New(types.LogLevelError, false) @@ -4222,9 +4139,7 @@ cat "%s" t.Fatalf("write rclone: %v", err) } - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) ctx := context.Background() logger := logging.New(types.LogLevelError, false) @@ -4438,9 +4353,7 @@ exit 0 } // Prepend fake rclone to PATH - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) orig := restoreFS // Use regular osFS - the download will work, then MkdirAll for /tmp/proxsave should succeed @@ -4585,9 +4498,7 @@ exit 0 `, sourceBundlePath) os.WriteFile(fakeRclone, []byte(script), 0o755) - origPath := os.Getenv("PATH") - os.Setenv("PATH", tmp+":"+origPath) - defer os.Setenv("PATH", origPath) + prependPathEnv(t, tmp) // Use FS that fails MkdirAll after the first call (download uses MkdirAll too) fake := &fakeMkdirAllFailAfterDownloadFS{failAfterCall: 1} diff --git a/internal/orchestrator/decrypt_tui.go b/internal/orchestrator/decrypt_tui.go index 7602d8a2..04ea8aad 100644 --- a/internal/orchestrator/decrypt_tui.go +++ b/internal/orchestrator/decrypt_tui.go @@ -4,11 +4,8 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" "strings" - "filippo.io/age" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -16,21 +13,11 @@ import ( "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/logging" "github.com/tis24dev/proxsave/internal/tui" - "github.com/tis24dev/proxsave/internal/tui/components" ) const ( decryptWizardSubtitle = "Decrypt Backup Workflow" decryptNavText = "[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to select | ESC to exit screens | Mouse clicks enabled" - - pathActionOverwrite = "overwrite" - pathActionNew = "new" - pathActionCancel = "cancel" -) - -var ( - promptOverwriteActionFunc = promptOverwriteAction - promptNewPathInputFunc = promptNewPathInput ) // RunDecryptWorkflowTUI runs the decrypt workflow using a TUI flow. @@ -104,324 +91,15 @@ func filterEncryptedCandidates(candidates []*decryptCandidate) []*decryptCandida return filtered } -func ensureWritablePathTUI(path, description, configPath, buildSig string) (string, error) { - current := filepath.Clean(path) - if description == "" { - description = "file" - } - var failureMessage string - - for { - if _, err := restoreFS.Stat(current); errors.Is(err, os.ErrNotExist) { - return current, nil - } else if err != nil && !errors.Is(err, os.ErrExist) { - return "", fmt.Errorf("stat %s: %w", current, err) - } - - action, err := promptOverwriteActionFunc(current, description, failureMessage, configPath, buildSig) - if err != nil { - return "", err - } - failureMessage = "" - - switch action { - case pathActionOverwrite: - if err := restoreFS.Remove(current); err != nil && !errors.Is(err, os.ErrNotExist) { - failureMessage = fmt.Sprintf("Failed to remove existing %s: %v", description, err) - continue - } - return current, nil - case pathActionNew: - newPath, err := promptNewPathInputFunc(current, configPath, buildSig) - if err != nil { - if errors.Is(err, ErrDecryptAborted) { - return "", ErrDecryptAborted - } - failureMessage = err.Error() - continue - } - current = filepath.Clean(newPath) - default: - return "", ErrDecryptAborted - } - } -} - -func promptOverwriteAction(path, description, failureMessage, configPath, buildSig string) (string, error) { - app := newTUIApp() - var choice string - - message := fmt.Sprintf("The %s [yellow]%s[white] already exists.\nSelect how you want to proceed.", description, path) - if strings.TrimSpace(failureMessage) != "" { - message = fmt.Sprintf("%s\n\n[red]%s[white]", message, failureMessage) - } - message += "\n\n[yellow]Use ←→ or TAB to switch buttons | ENTER to confirm[white]" - - modal := tview.NewModal(). - SetText(message). - AddButtons([]string{"Overwrite", "Use different path", "Cancel"}). - SetDoneFunc(func(buttonIndex int, buttonLabel string) { - switch buttonLabel { - case "Overwrite": - choice = pathActionOverwrite - case "Use different path": - choice = pathActionNew - default: - choice = pathActionCancel - } - app.Stop() - }) - - modal.SetBorder(true). - SetTitle(" Existing file "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.WarningYellow). - SetBorderColor(tui.WarningYellow). - SetBackgroundColor(tcell.ColorBlack) - - wrapped := buildWizardPage("Destination path", configPath, buildSig, modal) - if err := app.SetRoot(wrapped, true).SetFocus(modal).Run(); err != nil { - return "", err - } - return choice, nil -} - -func promptNewPathInput(defaultPath, configPath, buildSig string) (string, error) { - app := newTUIApp() - var newPath string - var cancelled bool - - form := components.NewForm(app) - label := "New path" - form.AddInputFieldWithValidation(label, defaultPath, 64, func(value string) error { - if strings.TrimSpace(value) == "" { - return fmt.Errorf("path cannot be empty") - } - return nil - }) - form.SetOnSubmit(func(values map[string]string) error { - newPath = strings.TrimSpace(values[label]) - return nil - }) - form.SetOnCancel(func() { - cancelled = true - }) - form.AddSubmitButton("Continue") - form.AddCancelButton("Cancel") - - helper := tview.NewTextView(). - SetText("Provide a writable filesystem path for the decrypted files."). - SetWrap(true). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true) - - content := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(helper, 3, 0, false). - AddItem(form.Form, 0, 1, true) - - page := buildWizardPage("Choose destination path", configPath, buildSig, content) - form.SetParentView(page) - - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { - return "", err - } - if cancelled { - return "", ErrDecryptAborted - } - return filepath.Clean(newPath), nil -} - -func preparePlainBundleTUI(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, configPath, buildSig string) (*preparedBundle, error) { - return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { - return decryptArchiveWithTUIPrompts(ctx, encryptedPath, outputPath, displayName, configPath, buildSig, logger) - }) -} - -func decryptArchiveWithTUIPrompts(ctx context.Context, encryptedPath, outputPath, displayName, configPath, buildSig string, logger *logging.Logger) error { - var promptError string - if ctx == nil { - ctx = context.Background() - } - for { - if err := ctx.Err(); err != nil { - return err - } - identities, err := promptDecryptIdentity(displayName, configPath, buildSig, promptError) - if err != nil { - return err - } - - if err := ctx.Err(); err != nil { - return err - } - if err := decryptWithIdentity(encryptedPath, outputPath, identities...); err != nil { - var noMatch *age.NoIdentityMatchError - if errors.Is(err, age.ErrIncorrectIdentity) || errors.As(err, &noMatch) { - promptError = "Provided key or passphrase does not match this archive." - logger.Warning("Incorrect key or passphrase for %s", filepath.Base(encryptedPath)) - continue - } - return err - } - return nil - } -} - -func promptDecryptIdentity(displayName, configPath, buildSig, errorMessage string) ([]age.Identity, error) { - app := newTUIApp() - var ( - chosenIdentity []age.Identity - cancelled bool - ) - - name := displayName - if strings.TrimSpace(name) == "" { - name = "selected backup" - } - infoMessage := fmt.Sprintf("Provide the AGE secret key or passphrase used for [yellow]%s[white].", name) - if strings.TrimSpace(errorMessage) != "" { - infoMessage = fmt.Sprintf("%s\n\n[red]%s[white]", infoMessage, errorMessage) - } - infoText := tview.NewTextView(). - SetText(infoMessage). - SetWrap(true). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true) - - form := components.NewForm(app) - label := "Key or passphrase:" - form.AddPasswordField(label, 64) - form.SetOnSubmit(func(values map[string]string) error { - raw := strings.TrimSpace(values[label]) - if raw == "" { - return fmt.Errorf("key or passphrase cannot be empty") - } - identity, err := parseIdentityInput(raw) - resetString(&raw) - if err != nil { - return fmt.Errorf("invalid key or passphrase: %w", err) - } - chosenIdentity = identity - return nil - }) - form.SetOnCancel(func() { - cancelled = true - }) - // Buttons: Continue, Cancel - form.AddSubmitButton("Continue") - form.AddCancelButton("Cancel") - - content := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(infoText, 3, 0, false). - AddItem(form.Form, 0, 1, true) - - page := buildWizardPage("Enter decryption secret", configPath, buildSig, content) - form.SetParentView(page) - - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { - return nil, err - } - if cancelled { - return nil, ErrDecryptAborted - } - if len(chosenIdentity) == 0 { - return nil, fmt.Errorf("missing identity") - } - return chosenIdentity, nil -} - -func enableFormNavigation(form *components.Form, dropdownOpen *bool) { - if form == nil || form.Form == nil { - return - } - form.Form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { - if event == nil { - return event - } - if dropdownOpen != nil && *dropdownOpen { - return event - } - - formItemIndex, buttonIndex := form.Form.GetFocusedItemIndex() - isOnButton := formItemIndex < 0 && buttonIndex >= 0 - isOnField := formItemIndex >= 0 - - if isOnButton { - switch event.Key() { - case tcell.KeyLeft, tcell.KeyUp: - return tcell.NewEventKey(tcell.KeyBacktab, 0, tcell.ModNone) - case tcell.KeyRight, tcell.KeyDown: - return tcell.NewEventKey(tcell.KeyTab, 0, tcell.ModNone) - } - } else if isOnField { - // If focused item is a ListFormItem, let it handle navigation internally - if formItemIndex >= 0 { - if _, ok := form.Form.GetFormItem(formItemIndex).(*components.ListFormItem); ok { - return event - } - } - // For other form fields, convert arrows to tab navigation - switch event.Key() { - case tcell.KeyUp: - return tcell.NewEventKey(tcell.KeyBacktab, 0, tcell.ModNone) - case tcell.KeyDown: - return tcell.NewEventKey(tcell.KeyTab, 0, tcell.ModNone) - } - } - return event - }) -} - func buildWizardPage(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { - welcomeText := tview.NewTextView(). - SetText(fmt.Sprintf("ProxSave - By TIS24DEV\n%s\n", decryptWizardSubtitle)). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - navInstructions := tview.NewTextView(). - SetText("\n" + decryptNavText). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(content, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(fmt.Sprintf(" %s ", title)). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - return flex + return tui.BuildScreen(tui.ScreenSpec{ + Title: title, + HeaderText: fmt.Sprintf("ProxSave - By TIS24DEV\n%s\n", decryptWizardSubtitle), + NavText: decryptNavText, + ConfigPath: configPath, + BuildSig: buildSig, + TitleColor: tui.ProxmoxOrange, + BorderColor: tui.ProxmoxOrange, + BackgroundColor: tcell.ColorBlack, + }, content) } diff --git a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go new file mode 100644 index 00000000..e7b78c18 --- /dev/null +++ b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go @@ -0,0 +1,301 @@ +package orchestrator + +import ( + "archive/tar" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "filippo.io/age" + "github.com/gdamore/tcell/v2" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/tui" + "github.com/tis24dev/proxsave/internal/types" +) + +var decryptTUIE2EMu sync.Mutex + +type timedSimKey struct { + Key tcell.Key + R rune + Mod tcell.ModMask + Wait time.Duration +} + +type decryptTUIFixture struct { + Config *config.Config + ConfigPath string + BackupDir string + BaseDir string + DestinationDir string + ArchivePlaintext []byte + Secret string + EncryptedArchive string + ExpectedBundlePath string + ExpectedArchiveName string + ExpectedChecksum string +} + +func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) { + t.Helper() + + decryptTUIE2EMu.Lock() + orig := newTUIApp + done := make(chan struct{}) + var injectWG sync.WaitGroup + t.Cleanup(func() { + close(done) + injectWG.Wait() + newTUIApp = orig + decryptTUIE2EMu.Unlock() + }) + + screen := tcell.NewSimulationScreen("UTF-8") + if err := screen.Init(); err != nil { + t.Fatalf("screen.Init: %v", err) + } + screen.SetSize(120, 40) + + var once sync.Once + newTUIApp = func() *tui.App { + app := tui.NewApp() + app.SetScreen(screen) + + once.Do(func() { + injectWG.Add(1) + go func() { + defer injectWG.Done() + + for _, k := range keys { + if k.Wait > 0 { + timer := time.NewTimer(k.Wait) + select { + case <-done: + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: + } + } + mod := k.Mod + if mod == 0 { + mod = tcell.ModNone + } + select { + case <-done: + return + default: + } + screen.InjectKey(k.Key, k.R, mod) + } + }() + }) + + return app + } +} + +func createDecryptTUIEncryptedFixture(t *testing.T) *decryptTUIFixture { + t.Helper() + + backupDir := t.TempDir() + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, "backup.env") + if err := os.WriteFile(configPath, []byte("BACKUP_PATH="+backupDir+"\nBASE_DIR="+baseDir+"\n"), 0o600); err != nil { + t.Fatalf("write config placeholder: %v", err) + } + + passphrase := "Decrypt123!" + recipientStr, err := deriveDeterministicRecipientFromPassphrase(passphrase) + if err != nil { + t.Fatalf("deriveDeterministicRecipientFromPassphrase: %v", err) + } + recipient, err := age.ParseX25519Recipient(recipientStr) + if err != nil { + t.Fatalf("age.ParseX25519Recipient: %v", err) + } + + plaintext := []byte("proxsave decrypt tui e2e plaintext\n") + archivePath := filepath.Join(backupDir, "backup.tar.xz.age") + archiveFile, err := os.Create(archivePath) + if err != nil { + t.Fatalf("create encrypted archive: %v", err) + } + + encWriter, err := age.Encrypt(archiveFile, recipient) + if err != nil { + _ = archiveFile.Close() + t.Fatalf("age.Encrypt: %v", err) + } + if _, err := encWriter.Write(plaintext); err != nil { + _ = encWriter.Close() + _ = archiveFile.Close() + t.Fatalf("write plaintext to age writer: %v", err) + } + if err := encWriter.Close(); err != nil { + _ = archiveFile.Close() + t.Fatalf("close age writer: %v", err) + } + if err := archiveFile.Close(); err != nil { + t.Fatalf("close encrypted archive: %v", err) + } + + encryptedBytes, err := os.ReadFile(archivePath) + if err != nil { + t.Fatalf("read encrypted archive: %v", err) + } + + createdAt := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC) + manifest := &backup.Manifest{ + ArchivePath: archivePath, + CreatedAt: createdAt, + Hostname: "node1", + EncryptionMode: "age", + ProxmoxType: "pve", + } + manifestData, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + if err := os.WriteFile(archivePath+".metadata", manifestData, 0o640); err != nil { + t.Fatalf("write manifest sidecar: %v", err) + } + if err := os.WriteFile(archivePath+".sha256", checksumLineForBytes(filepath.Base(archivePath), encryptedBytes), 0o640); err != nil { + t.Fatalf("write checksum sidecar: %v", err) + } + + checksum := sha256.Sum256(plaintext) + expectedArchiveName := "backup.tar.xz" + destinationDir := filepath.Join(baseDir, "decrypt") + + return &decryptTUIFixture{ + Config: &config.Config{ + BackupPath: backupDir, + BaseDir: baseDir, + SecondaryEnabled: false, + CloudEnabled: false, + }, + ConfigPath: configPath, + BackupDir: backupDir, + BaseDir: baseDir, + DestinationDir: destinationDir, + ArchivePlaintext: plaintext, + Secret: passphrase, + EncryptedArchive: archivePath, + ExpectedBundlePath: filepath.Join(destinationDir, expectedArchiveName+".decrypted.bundle.tar"), + ExpectedArchiveName: expectedArchiveName, + ExpectedChecksum: hex.EncodeToString(checksum[:]), + } +} + +func successDecryptTUISequence(secret string) []timedSimKey { + keys := []timedSimKey{ + {Key: tcell.KeyEnter, Wait: 1 * time.Second}, + {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond}, + } + + for _, r := range secret { + keys = append(keys, timedSimKey{ + Key: tcell.KeyRune, + R: r, + Wait: 35 * time.Millisecond, + }) + } + + keys = append(keys, + timedSimKey{Key: tcell.KeyTab, Wait: 150 * time.Millisecond}, + timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, + timedSimKey{Key: tcell.KeyTab, Wait: 500 * time.Millisecond}, + timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, + ) + + return keys +} + +func abortDecryptTUISequence() []timedSimKey { + return []timedSimKey{ + {Key: tcell.KeyEnter, Wait: 1 * time.Second}, + {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond}, + {Key: tcell.KeyRune, R: '0', Wait: 500 * time.Millisecond}, + {Key: tcell.KeyTab, Wait: 150 * time.Millisecond}, + {Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, + } +} + +func runDecryptWorkflowTUIForTest(t *testing.T, ctx context.Context, cfg *config.Config, configPath string) error { + t.Helper() + + logger := logging.New(types.LogLevelError, false) + logger.SetOutput(io.Discard) + + errCh := make(chan error, 1) + go func() { + errCh <- RunDecryptWorkflowTUI(ctx, cfg, logger, "1.0.0", configPath, "test-build") + }() + + waitTimeout := 30 * time.Second + if deadline, ok := ctx.Deadline(); ok { + waitTimeout = time.Until(deadline) + 2*time.Second + if waitTimeout < 2*time.Second { + waitTimeout = 2 * time.Second + } + } + timer := time.NewTimer(waitTimeout) + defer timer.Stop() + + select { + case err := <-errCh: + return err + case <-timer.C: + if err := ctx.Err(); err != nil { + t.Fatalf("RunDecryptWorkflowTUI did not return within %s (context state: %v)", waitTimeout, err) + return nil + } + t.Fatalf("RunDecryptWorkflowTUI did not return within %s", waitTimeout) + return nil + } +} + +func readTarEntries(t *testing.T, tarPath string) map[string][]byte { + t.Helper() + + file, err := os.Open(tarPath) + if err != nil { + t.Fatalf("open tar %s: %v", tarPath, err) + } + defer file.Close() + + tr := tar.NewReader(file) + entries := make(map[string][]byte) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("read tar header from %s: %v", tarPath, err) + } + data, err := io.ReadAll(tr) + if err != nil { + t.Fatalf("read tar entry %s: %v", hdr.Name, err) + } + entries[hdr.Name] = data + } + return entries +} + +func checksumLineForArchiveHex(filename, checksumHex string) string { + return fmt.Sprintf("%s %s\n", checksumHex, filename) +} diff --git a/internal/orchestrator/decrypt_tui_e2e_test.go b/internal/orchestrator/decrypt_tui_e2e_test.go new file mode 100644 index 00000000..925b81d0 --- /dev/null +++ b/internal/orchestrator/decrypt_tui_e2e_test.go @@ -0,0 +1,95 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" +) + +func TestRunDecryptWorkflowTUI_SuccessLocalEncrypted(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + fixture := createDecryptTUIEncryptedFixture(t) + withTimedSimAppSequence(t, successDecryptTUISequence(fixture.Secret)) + + ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) + defer cancel() + + if err := runDecryptWorkflowTUIForTest(t, ctx, fixture.Config, fixture.ConfigPath); err != nil { + t.Fatalf("RunDecryptWorkflowTUI error: %v", err) + } + + if _, err := os.Stat(fixture.ExpectedBundlePath); err != nil { + t.Fatalf("expected decrypted bundle at %s: %v", fixture.ExpectedBundlePath, err) + } + + entries := readTarEntries(t, fixture.ExpectedBundlePath) + + archiveData, ok := entries[fixture.ExpectedArchiveName] + if !ok { + t.Fatalf("bundle missing archive entry %s", fixture.ExpectedArchiveName) + } + if string(archiveData) != string(fixture.ArchivePlaintext) { + t.Fatalf("archive entry content mismatch: got %q want %q", string(archiveData), string(fixture.ArchivePlaintext)) + } + + metadataName := fixture.ExpectedArchiveName + ".metadata" + metadataData, ok := entries[metadataName] + if !ok { + t.Fatalf("bundle missing metadata entry %s", metadataName) + } + + var manifest backup.Manifest + if err := json.Unmarshal(metadataData, &manifest); err != nil { + t.Fatalf("unmarshal metadata entry %s: %v", metadataName, err) + } + if manifest.EncryptionMode != "none" { + t.Fatalf("metadata EncryptionMode=%q; want %q", manifest.EncryptionMode, "none") + } + expectedArchivePath := filepath.Join(fixture.DestinationDir, fixture.ExpectedArchiveName) + if manifest.ArchivePath != expectedArchivePath { + t.Fatalf("metadata ArchivePath=%q; want %q", manifest.ArchivePath, expectedArchivePath) + } + if manifest.SHA256 != fixture.ExpectedChecksum { + t.Fatalf("metadata SHA256=%q; want %q", manifest.SHA256, fixture.ExpectedChecksum) + } + + checksumName := fixture.ExpectedArchiveName + ".sha256" + checksumData, ok := entries[checksumName] + if !ok { + t.Fatalf("bundle missing checksum entry %s", checksumName) + } + expectedChecksumLine := checksumLineForArchiveHex(fixture.ExpectedArchiveName, fixture.ExpectedChecksum) + if string(checksumData) != expectedChecksumLine { + t.Fatalf("checksum entry=%q; want %q", string(checksumData), expectedChecksumLine) + } +} + +func TestRunDecryptWorkflowTUI_AbortAtSecretPrompt(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + fixture := createDecryptTUIEncryptedFixture(t) + withTimedSimAppSequence(t, abortDecryptTUISequence()) + + ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) + defer cancel() + + err := runDecryptWorkflowTUIForTest(t, ctx, fixture.Config, fixture.ConfigPath) + if !errors.Is(err, ErrDecryptAborted) { + t.Fatalf("RunDecryptWorkflowTUI error=%v; want %v", err, ErrDecryptAborted) + } + + if _, err := os.Stat(fixture.ExpectedBundlePath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected no decrypted bundle at %s, stat err=%v", fixture.ExpectedBundlePath, err) + } +} diff --git a/internal/orchestrator/decrypt_tui_simulation_test.go b/internal/orchestrator/decrypt_tui_simulation_test.go index 9a65f1c8..b278f4e1 100644 --- a/internal/orchestrator/decrypt_tui_simulation_test.go +++ b/internal/orchestrator/decrypt_tui_simulation_test.go @@ -1,22 +1,23 @@ package orchestrator import ( + "context" "testing" "github.com/gdamore/tcell/v2" ) -func TestPromptDecryptIdentity_CancelReturnsAborted(t *testing.T) { - // Focus starts on the password field; tab to Cancel and submit. +func TestTUIWorkflowUIPromptDecryptSecret_CancelReturnsAborted(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyTab, tcell.KeyEnter}) - _, err := promptDecryptIdentity("backup", "/tmp/config.env", "sig", "") + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, err := ui.PromptDecryptSecret(context.Background(), "backup", "") if err != ErrDecryptAborted { t.Fatalf("err=%v; want %v", err, ErrDecryptAborted) } } -func TestPromptDecryptIdentity_PassphraseReturnsIdentity(t *testing.T) { +func TestTUIWorkflowUIPromptDecryptSecret_PassphraseReturnsSecret(t *testing.T) { passphrase := "test passphrase" var seq []simKey @@ -26,11 +27,46 @@ func TestPromptDecryptIdentity_PassphraseReturnsIdentity(t *testing.T) { seq = append(seq, simKey{Key: tcell.KeyTab}, simKey{Key: tcell.KeyEnter}) withSimAppSequence(t, seq) - ids, err := promptDecryptIdentity("backup", "/tmp/config.env", "sig", "") + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + secret, err := ui.PromptDecryptSecret(context.Background(), "backup", "") if err != nil { - t.Fatalf("promptDecryptIdentity error: %v", err) + t.Fatalf("PromptDecryptSecret error: %v", err) } - if len(ids) == 0 { - t.Fatalf("expected at least one identity") + if secret != passphrase { + t.Fatalf("secret=%q; want %q", secret, passphrase) + } +} + +func TestTUIWorkflowUIPromptDecryptSecret_PreservesSurroundingSpaces(t *testing.T) { + passphrase := " test passphrase " + + var seq []simKey + for _, r := range passphrase { + seq = append(seq, simKey{Key: tcell.KeyRune, R: r}) + } + seq = append(seq, simKey{Key: tcell.KeyTab}, simKey{Key: tcell.KeyEnter}) + withSimAppSequence(t, seq) + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + secret, err := ui.PromptDecryptSecret(context.Background(), "backup", "") + if err != nil { + t.Fatalf("PromptDecryptSecret error: %v", err) + } + if secret != passphrase { + t.Fatalf("secret=%q; want %q", secret, passphrase) + } +} + +func TestTUIWorkflowUIPromptDecryptSecret_ZeroInputAborts(t *testing.T) { + withSimAppSequence(t, []simKey{ + {Key: tcell.KeyRune, R: '0'}, + {Key: tcell.KeyTab}, + {Key: tcell.KeyEnter}, + }) + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, err := ui.PromptDecryptSecret(context.Background(), "backup", "") + if err != ErrDecryptAborted { + t.Fatalf("err=%v; want %v", err, ErrDecryptAborted) } } diff --git a/internal/orchestrator/decrypt_tui_test.go b/internal/orchestrator/decrypt_tui_test.go index f08f9a04..d94b78ea 100644 --- a/internal/orchestrator/decrypt_tui_test.go +++ b/internal/orchestrator/decrypt_tui_test.go @@ -1,18 +1,12 @@ package orchestrator import ( - "context" - "errors" - "os" - "path/filepath" "testing" "time" "github.com/rivo/tview" "github.com/tis24dev/proxsave/internal/backup" - "github.com/tis24dev/proxsave/internal/logging" - "github.com/tis24dev/proxsave/internal/types" ) func TestNormalizeProxmoxVersion(t *testing.T) { @@ -66,194 +60,6 @@ func TestFilterEncryptedCandidates(t *testing.T) { } } -func TestEnsureWritablePathTUI_ReturnsCleanMissingPath(t *testing.T) { - originalFS := restoreFS - restoreFS = osFS{} - defer func() { restoreFS = originalFS }() - - tmp := t.TempDir() - target := filepath.Join(tmp, "subdir", "file.txt") - dirty := target + string(filepath.Separator) + ".." + string(filepath.Separator) + "file.txt" - - path, err := ensureWritablePathTUI(dirty, "test file", "cfg", "sig") - if err != nil { - t.Fatalf("ensureWritablePathTUI returned error: %v", err) - } - if path != target { - t.Fatalf("ensureWritablePathTUI path=%q, want %q", path, target) - } -} - -func TestEnsureWritablePathTUIOverwriteExisting(t *testing.T) { - tmp := t.TempDir() - target := filepath.Join(tmp, "existing.tar") - if err := os.WriteFile(target, []byte("payload"), 0o640); err != nil { - t.Fatalf("write existing file: %v", err) - } - - restore := stubPromptOverwriteAction(func(path, desc, failure, configPath, buildSig string) (string, error) { - if failure != "" { - t.Fatalf("unexpected failure message: %s", failure) - } - return pathActionOverwrite, nil - }) - defer restore() - - got, err := ensureWritablePathTUI(target, "archive", "cfg", "sig") - if err != nil { - t.Fatalf("ensureWritablePathTUI error: %v", err) - } - if got != target { - t.Fatalf("path = %q, want %q", got, target) - } - if _, err := os.Stat(target); !errors.Is(err, os.ErrNotExist) { - t.Fatalf("existing file should be removed, stat err=%v", err) - } -} - -func TestEnsureWritablePathTUINewPath(t *testing.T) { - tmp := t.TempDir() - existing := filepath.Join(tmp, "current.tar") - if err := os.WriteFile(existing, []byte("payload"), 0o640); err != nil { - t.Fatalf("write existing file: %v", err) - } - nextPath := filepath.Join(tmp, "new.tar") - - var promptCalls int - restorePrompt := stubPromptOverwriteAction(func(path, desc, failure, configPath, buildSig string) (string, error) { - promptCalls++ - if failure != "" { - t.Fatalf("unexpected failure message: %s", failure) - } - return pathActionNew, nil - }) - defer restorePrompt() - - restoreNew := stubPromptNewPath(func(current, configPath, buildSig string) (string, error) { - if filepath.Clean(current) != filepath.Clean(existing) { - t.Fatalf("promptNewPath received %q, want %q", current, existing) - } - return nextPath, nil - }) - defer restoreNew() - - got, err := ensureWritablePathTUI(existing, "bundle", "cfg", "sig") - if err != nil { - t.Fatalf("ensureWritablePathTUI error: %v", err) - } - if got != filepath.Clean(nextPath) { - t.Fatalf("path=%q, want %q", got, nextPath) - } - if promptCalls != 1 { - t.Fatalf("expected 1 prompt call, got %d", promptCalls) - } -} - -func TestEnsureWritablePathTUIAbortOnCancel(t *testing.T) { - path := mustCreateExistingFile(t) - restore := stubPromptOverwriteAction(func(path, desc, failure, configPath, buildSig string) (string, error) { - return pathActionCancel, nil - }) - defer restore() - - if _, err := ensureWritablePathTUI(path, "bundle", "cfg", "sig"); !errors.Is(err, ErrDecryptAborted) { - t.Fatalf("expected ErrDecryptAborted, got %v", err) - } -} - -func TestEnsureWritablePathTUIPropagatesPromptErrors(t *testing.T) { - path := mustCreateExistingFile(t) - wantErr := errors.New("boom") - restore := stubPromptOverwriteAction(func(path, desc, failure, configPath, buildSig string) (string, error) { - return "", wantErr - }) - defer restore() - - if _, err := ensureWritablePathTUI(path, "bundle", "cfg", "sig"); !errors.Is(err, wantErr) { - t.Fatalf("expected %v, got %v", wantErr, err) - } -} - -func TestEnsureWritablePathTUINewPathAbort(t *testing.T) { - path := mustCreateExistingFile(t) - restorePrompt := stubPromptOverwriteAction(func(path, desc, failure, configPath, buildSig string) (string, error) { - return pathActionNew, nil - }) - defer restorePrompt() - - restoreNew := stubPromptNewPath(func(current, configPath, buildSig string) (string, error) { - return "", ErrDecryptAborted - }) - defer restoreNew() - - if _, err := ensureWritablePathTUI(path, "bundle", "cfg", "sig"); !errors.Is(err, ErrDecryptAborted) { - t.Fatalf("expected ErrDecryptAborted, got %v", err) - } -} - -func TestPreparePlainBundleTUICopiesRawArtifacts(t *testing.T) { - logger := logging.New(types.LogLevelError, false) - tmp := t.TempDir() - rawArchive := filepath.Join(tmp, "backup.tar") - rawMetadata := rawArchive + ".metadata" - rawChecksum := rawArchive + ".sha256" - - if err := os.WriteFile(rawArchive, []byte("payload-data"), 0o640); err != nil { - t.Fatalf("write archive: %v", err) - } - if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { - t.Fatalf("write metadata: %v", err) - } - if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { - t.Fatalf("write checksum: %v", err) - } - - cand := &decryptCandidate{ - Manifest: &backup.Manifest{ - ArchivePath: rawArchive, - EncryptionMode: "none", - CreatedAt: time.Now(), - Hostname: "node1", - }, - Source: sourceRaw, - RawArchivePath: rawArchive, - RawMetadataPath: rawMetadata, - RawChecksumPath: rawChecksum, - DisplayBase: "test-backup", - } - - ctx := context.Background() - prepared, err := preparePlainBundleTUI(ctx, cand, "1.0.0", logger, "cfg", "sig") - if err != nil { - t.Fatalf("preparePlainBundleTUI error: %v", err) - } - defer prepared.Cleanup() - - if prepared.ArchivePath == "" { - t.Fatalf("expected archive path to be set") - } - if prepared.Manifest.EncryptionMode != "none" { - t.Fatalf("expected manifest encryption mode none, got %s", prepared.Manifest.EncryptionMode) - } - if prepared.Manifest.ScriptVersion != "1.0.0" { - t.Fatalf("expected script version to propagate, got %s", prepared.Manifest.ScriptVersion) - } - if _, err := os.Stat(prepared.ArchivePath); err != nil { - t.Fatalf("expected staged archive to exist: %v", err) - } - if prepared.Checksum == "" { - t.Fatalf("expected checksum to be computed") - } -} - -func TestPreparePlainBundleTUIRejectsInvalidCandidate(t *testing.T) { - logger := logging.New(types.LogLevelError, false) - ctx := context.Background() - if _, err := preparePlainBundleTUI(ctx, nil, "", logger, "cfg", "sig"); err == nil { - t.Fatalf("expected error for nil candidate") - } -} - func TestBuildWizardPageReturnsFlex(t *testing.T) { content := tview.NewBox() page := buildWizardPage("Title", "/etc/proxsave/backup.env", "sig", content) @@ -264,25 +70,3 @@ func TestBuildWizardPageReturnsFlex(t *testing.T) { t.Fatalf("expected *tview.Flex, got %T", page) } } - -func stubPromptOverwriteAction(fn func(path, description, failureMessage, configPath, buildSig string) (string, error)) func() { - orig := promptOverwriteActionFunc - promptOverwriteActionFunc = fn - return func() { promptOverwriteActionFunc = orig } -} - -func stubPromptNewPath(fn func(current, configPath, buildSig string) (string, error)) func() { - orig := promptNewPathInputFunc - promptNewPathInputFunc = fn - return func() { promptNewPathInputFunc = orig } -} - -func mustCreateExistingFile(t *testing.T) string { - t.Helper() - tmp := t.TempDir() - path := filepath.Join(tmp, "existing.dat") - if err := os.WriteFile(path, []byte("data"), 0o640); err != nil { - t.Fatalf("write %s: %v", path, err) - } - return path -} diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index 7de3f697..7adc905f 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -19,6 +19,10 @@ func selectBackupCandidateWithUI(ctx context.Context, ui BackupSelectionUI, cfg done := logging.DebugStart(logger, "select backup candidate (ui)", "requireEncrypted=%v", requireEncrypted) defer func() { done(err) }() + if ui == nil { + return nil, fmt.Errorf("backup selection UI not available") + } + pathOptions := buildDecryptPathOptions(cfg, logger) if len(pathOptions) == 0 { return nil, fmt.Errorf("no backup paths configured in backup.env") @@ -177,6 +181,13 @@ func decryptArchiveWithSecretPrompt(ctx context.Context, encryptedPath, outputPa func preparePlainBundleWithUI(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, ui interface { PromptDecryptSecret(ctx context.Context, displayName, previousError string) (string, error) }) (bundle *preparedBundle, err error) { + if cand == nil || cand.Manifest == nil { + return nil, fmt.Errorf("invalid backup candidate") + } + if ui == nil { + return nil, fmt.Errorf("decrypt workflow UI not available") + } + done := logging.DebugStart(logger, "prepare plain bundle (ui)", "source=%v rclone=%v", cand.Source, cand.IsRclone) defer func() { done(err) }() return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { @@ -191,6 +202,9 @@ func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l if logger == nil { logger = logging.GetDefaultLogger() } + if ui == nil { + return fmt.Errorf("decrypt workflow UI not available") + } done := logging.DebugStart(logger, "decrypt workflow (ui)", "version=%s", version) defer func() { done(err) }() defer func() { diff --git a/internal/orchestrator/decrypt_workflow_ui_test.go b/internal/orchestrator/decrypt_workflow_ui_test.go new file mode 100644 index 00000000..27a07c74 --- /dev/null +++ b/internal/orchestrator/decrypt_workflow_ui_test.go @@ -0,0 +1,298 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +type fakeDecryptWorkflowUI struct { + resolveExistingPathFn func(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) +} + +func (f *fakeDecryptWorkflowUI) RunTask(ctx context.Context, title, initialMessage string, run func(ctx context.Context, report ProgressReporter) error) error { + panic("unexpected RunTask call") +} + +func (f *fakeDecryptWorkflowUI) ShowMessage(ctx context.Context, title, message string) error { + panic("unexpected ShowMessage call") +} + +func (f *fakeDecryptWorkflowUI) ShowError(ctx context.Context, title, message string) error { + panic("unexpected ShowError call") +} + +func (f *fakeDecryptWorkflowUI) SelectBackupSource(ctx context.Context, options []decryptPathOption) (decryptPathOption, error) { + panic("unexpected SelectBackupSource call") +} + +func (f *fakeDecryptWorkflowUI) SelectBackupCandidate(ctx context.Context, candidates []*decryptCandidate) (*decryptCandidate, error) { + panic("unexpected SelectBackupCandidate call") +} + +func (f *fakeDecryptWorkflowUI) PromptDestinationDir(ctx context.Context, defaultDir string) (string, error) { + panic("unexpected PromptDestinationDir call") +} + +func (f *fakeDecryptWorkflowUI) ResolveExistingPath(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { + if f.resolveExistingPathFn == nil { + panic("unexpected ResolveExistingPath call") + } + return f.resolveExistingPathFn(ctx, path, description, failure) +} + +func (f *fakeDecryptWorkflowUI) PromptDecryptSecret(ctx context.Context, displayName, previousError string) (string, error) { + panic("unexpected PromptDecryptSecret call") +} + +type countingSecretPrompter struct { + calls int +} + +func (c *countingSecretPrompter) PromptDecryptSecret(ctx context.Context, displayName, previousError string) (string, error) { + c.calls++ + return "unused", nil +} + +func TestEnsureWritablePathWithUI_ReturnsCleanMissingPath(t *testing.T) { + originalFS := restoreFS + restoreFS = osFS{} + defer func() { restoreFS = originalFS }() + + tmp := t.TempDir() + target := filepath.Join(tmp, "subdir", "file.txt") + dirty := target + string(filepath.Separator) + ".." + string(filepath.Separator) + "file.txt" + + got, err := ensureWritablePathWithUI(context.Background(), &fakeDecryptWorkflowUI{}, dirty, "test file") + if err != nil { + t.Fatalf("ensureWritablePathWithUI error: %v", err) + } + if got != target { + t.Fatalf("ensureWritablePathWithUI path=%q, want %q", got, target) + } +} + +func TestEnsureWritablePathWithUI_OverwriteExisting(t *testing.T) { + tmp := t.TempDir() + target := filepath.Join(tmp, "existing.tar") + if err := os.WriteFile(target, []byte("payload"), 0o640); err != nil { + t.Fatalf("write existing file: %v", err) + } + + ui := &fakeDecryptWorkflowUI{ + resolveExistingPathFn: func(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { + if path != target { + t.Fatalf("path=%q, want %q", path, target) + } + if failure != "" { + t.Fatalf("unexpected failure message: %s", failure) + } + return PathDecisionOverwrite, "", nil + }, + } + + got, err := ensureWritablePathWithUI(context.Background(), ui, target, "archive") + if err != nil { + t.Fatalf("ensureWritablePathWithUI error: %v", err) + } + if got != target { + t.Fatalf("path=%q, want %q", got, target) + } + if _, err := os.Stat(target); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("existing file should be removed, stat err=%v", err) + } +} + +func TestEnsureWritablePathWithUI_NewPath(t *testing.T) { + tmp := t.TempDir() + existing := filepath.Join(tmp, "current.tar") + if err := os.WriteFile(existing, []byte("payload"), 0o640); err != nil { + t.Fatalf("write existing file: %v", err) + } + nextPath := filepath.Join(tmp, "next.tar") + + var calls int + ui := &fakeDecryptWorkflowUI{ + resolveExistingPathFn: func(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { + calls++ + if path != existing { + t.Fatalf("path=%q, want %q", path, existing) + } + return PathDecisionNewPath, nextPath, nil + }, + } + + got, err := ensureWritablePathWithUI(context.Background(), ui, existing, "bundle") + if err != nil { + t.Fatalf("ensureWritablePathWithUI error: %v", err) + } + if got != filepath.Clean(nextPath) { + t.Fatalf("path=%q, want %q", got, filepath.Clean(nextPath)) + } + if calls != 1 { + t.Fatalf("expected 1 ResolveExistingPath call, got %d", calls) + } +} + +func TestEnsureWritablePathWithUI_AbortOnCancelDecision(t *testing.T) { + path := mustCreateExistingFile(t) + ui := &fakeDecryptWorkflowUI{ + resolveExistingPathFn: func(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { + return PathDecisionCancel, "", nil + }, + } + + if _, err := ensureWritablePathWithUI(context.Background(), ui, path, "bundle"); !errors.Is(err, ErrDecryptAborted) { + t.Fatalf("expected ErrDecryptAborted, got %v", err) + } +} + +func TestEnsureWritablePathWithUI_PropagatesPromptErrors(t *testing.T) { + path := mustCreateExistingFile(t) + wantErr := errors.New("boom") + ui := &fakeDecryptWorkflowUI{ + resolveExistingPathFn: func(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { + return PathDecisionCancel, "", wantErr + }, + } + + if _, err := ensureWritablePathWithUI(context.Background(), ui, path, "bundle"); !errors.Is(err, wantErr) { + t.Fatalf("expected %v, got %v", wantErr, err) + } +} + +func TestPreparePlainBundleWithUICopiesRawArtifacts(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + tmp := t.TempDir() + rawArchive := filepath.Join(tmp, "backup.tar") + rawMetadata := rawArchive + ".metadata" + rawChecksum := rawArchive + ".sha256" + + if err := os.WriteFile(rawArchive, []byte("payload-data"), 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: rawArchive, + EncryptionMode: "none", + CreatedAt: time.Now(), + Hostname: "node1", + }, + Source: sourceRaw, + RawArchivePath: rawArchive, + RawMetadataPath: rawMetadata, + RawChecksumPath: rawChecksum, + DisplayBase: "test-backup", + } + + ctx := context.Background() + prompter := &countingSecretPrompter{} + prepared, err := preparePlainBundleWithUI(ctx, cand, "1.0.0", logger, prompter) + if err != nil { + t.Fatalf("preparePlainBundleWithUI error: %v", err) + } + defer prepared.Cleanup() + + if prepared.ArchivePath == "" { + t.Fatalf("expected archive path to be set") + } + if prepared.Manifest.EncryptionMode != "none" { + t.Fatalf("expected manifest encryption mode none, got %s", prepared.Manifest.EncryptionMode) + } + if prepared.Manifest.ScriptVersion != "1.0.0" { + t.Fatalf("expected script version to propagate, got %s", prepared.Manifest.ScriptVersion) + } + if _, err := os.Stat(prepared.ArchivePath); err != nil { + t.Fatalf("expected staged archive to exist: %v", err) + } + if prepared.Checksum == "" { + t.Fatalf("expected checksum to be computed") + } + if prompter.calls != 0 { + t.Fatalf("PromptDecryptSecret should not be called for plain backups, got %d calls", prompter.calls) + } +} + +func TestPreparePlainBundleWithUIRejectsInvalidCandidate(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + ctx := context.Background() + prompter := &countingSecretPrompter{} + if _, err := preparePlainBundleWithUI(ctx, nil, "", logger, prompter); err == nil { + t.Fatalf("expected error for nil candidate") + } +} + +func TestPreparePlainBundleWithUIRejectsMissingUI(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + tmp := t.TempDir() + rawArchive := filepath.Join(tmp, "backup.tar") + rawMetadata := rawArchive + ".metadata" + rawChecksum := rawArchive + ".sha256" + + if err := os.WriteFile(rawArchive, []byte("payload-data"), 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: rawArchive, + EncryptionMode: "none", + CreatedAt: time.Now(), + Hostname: "node1", + }, + Source: sourceRaw, + RawArchivePath: rawArchive, + RawMetadataPath: rawMetadata, + RawChecksumPath: rawChecksum, + DisplayBase: "test-backup", + } + + if _, err := preparePlainBundleWithUI(context.Background(), cand, "1.0.0", logger, nil); err == nil { + t.Fatalf("expected error for missing UI") + } +} + +func TestRunDecryptWorkflowWithUIRejectsMissingUI(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{} + + err := runDecryptWorkflowWithUI(context.Background(), cfg, logger, "1.0.0", nil) + if err == nil { + t.Fatal("expected error for missing UI") + } + if got, want := err.Error(), "decrypt workflow UI not available"; got != want { + t.Fatalf("error=%q, want %q", got, want) + } +} + +func mustCreateExistingFile(t *testing.T) string { + t.Helper() + + tmp := t.TempDir() + path := filepath.Join(tmp, "existing.dat") + if err := os.WriteFile(path, []byte("data"), 0o640); err != nil { + t.Fatalf("write %s: %v", path, err) + } + return path +} diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go index 986051fa..b2fbfb3e 100644 --- a/internal/orchestrator/deps_test.go +++ b/internal/orchestrator/deps_test.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "testing" "time" "github.com/tis24dev/proxsave/internal/logging" @@ -36,8 +37,34 @@ func NewFakeFS() *FakeFS { func (f *FakeFS) onDisk(path string) string { clean := filepath.Clean(path) - clean = strings.TrimPrefix(clean, string(filepath.Separator)) - return filepath.Join(f.Root, clean) + root := filepath.Clean(f.Root) + if clean == root || strings.HasPrefix(clean, root+string(filepath.Separator)) { + return clean + } + + mapped := clean + if filepath.IsAbs(mapped) { + fsRoot := filepath.VolumeName(mapped) + string(filepath.Separator) + rel, err := filepath.Rel(fsRoot, mapped) + if err != nil { + return root + } + mapped = rel + } + + candidate := filepath.Join(root, mapped) + rel, err := filepath.Rel(root, candidate) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return root + } + return filepath.Join(root, rel) +} + +func (f *FakeFS) Cleanup() error { + if f == nil { + return nil + } + return os.RemoveAll(f.Root) } // AddFile creates a file with content. @@ -250,3 +277,31 @@ func (f *FakeCommandRunner) RunStream(ctx context.Context, name string, stdin io } return io.NopCloser(strings.NewReader(string(out))), nil } + +func TestFakeFSOnDiskMapsAbsolutePathsUnderRoot(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = fakeFS.Cleanup() }) + + got := fakeFS.onDisk("/etc/pve/corosync.conf") + want := filepath.Join(fakeFS.Root, "etc", "pve", "corosync.conf") + if got != want { + t.Fatalf("onDisk=%q, want %q", got, want) + } +} + +func TestFakeFSOnDiskBlocksUpwardTraversal(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = fakeFS.Cleanup() }) + + got := fakeFS.onDisk("../x") + rel, err := filepath.Rel(fakeFS.Root, got) + if err != nil { + t.Fatalf("filepath.Rel error: %v", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + t.Fatalf("onDisk escaped root: got %q (rel %q)", got, rel) + } + if got != fakeFS.Root { + t.Fatalf("onDisk=%q, want sandbox root %q", got, fakeFS.Root) + } +} diff --git a/internal/orchestrator/encryption.go b/internal/orchestrator/encryption.go index aacfbb46..7a642e88 100644 --- a/internal/orchestrator/encryption.go +++ b/internal/orchestrator/encryption.go @@ -6,9 +6,11 @@ import ( "context" "errors" "fmt" + "io" "os" "path/filepath" "strings" + "syscall" "time" "unicode" @@ -58,47 +60,8 @@ func (o *Orchestrator) EnsureAgeRecipientsReady(ctx context.Context) error { } func (o *Orchestrator) prepareAgeRecipients(ctx context.Context) ([]age.Recipient, error) { - if o.cfg == nil || !o.cfg.EncryptArchive { - return nil, nil - } - - if o.ageRecipientCache != nil && !o.forceNewAgeRecipient { - return cloneRecipients(o.ageRecipientCache), nil - } - - recipients, candidatePath, err := o.collectRecipientStrings() - if err != nil { - return nil, err - } - - if len(recipients) == 0 { - if !o.isInteractiveShell() { - o.logger.Error("Encryption setup requires interaction. Run the script interactively to complete the AGE recipient setup, then re-run in automated mode.") - o.logger.Debug("HINT Set AGE_RECIPIENT or AGE_RECIPIENT_FILE to bypass the interactive setup and re-run.") - return nil, fmt.Errorf("age recipients not configured") - } - - wizardRecipients, savedPath, err := o.runAgeSetupWizard(ctx, candidatePath) - if err != nil { - return nil, err - } - recipients = append(recipients, wizardRecipients...) - if o.cfg.AgeRecipientFile == "" { - o.cfg.AgeRecipientFile = savedPath - } - } - - if len(recipients) == 0 { - return nil, fmt.Errorf("no AGE recipients configured after setup") - } - - parsed, err := parseRecipientStrings(recipients) - if err != nil { - return nil, err - } - o.ageRecipientCache = cloneRecipients(parsed) - o.forceNewAgeRecipient = false - return cloneRecipients(parsed), nil + recipients, _, err := o.prepareAgeRecipientsWithUI(ctx, nil) + return recipients, err } func (o *Orchestrator) collectRecipientStrings() ([]string, string, error) { @@ -129,94 +92,22 @@ func (o *Orchestrator) collectRecipientStrings() ([]string, string, error) { // runAgeSetupWizard collects AGE recipients interactively. // Returns (fileRecipients, savedPath, error) func (o *Orchestrator) runAgeSetupWizard(ctx context.Context, candidatePath string) ([]string, string, error) { - reader := bufio.NewReader(os.Stdin) - targetPath := candidatePath - if targetPath == "" { - targetPath = o.defaultAgeRecipientFile() + if o == nil { + return nil, "", fmt.Errorf("orchestrator is required") } - o.logger.Info("Encryption setup: no AGE recipients found, starting interactive wizard") - if targetPath == "" { - return nil, "", fmt.Errorf("unable to determine default path for AGE recipients") - } - - // Create a child context for the wizard to handle Ctrl+C locally wizardCtx, wizardCancel := context.WithCancel(ctx) defer wizardCancel() - recipientPath := targetPath - if o.forceNewAgeRecipient && recipientPath != "" { - if _, err := os.Stat(recipientPath); err == nil { - fmt.Printf("WARNING: this will remove the existing AGE recipients stored at %s. Existing backups remain decryptable with your old private key.\n", recipientPath) - confirm, errPrompt := promptYesNoAge(wizardCtx, reader, fmt.Sprintf("Delete %s and enter a new recipient? [y/N]: ", recipientPath)) - if errPrompt != nil { - return nil, "", errPrompt - } - if !confirm { - return nil, "", fmt.Errorf("operation aborted by user") - } - if err := backupExistingRecipientFile(recipientPath); err != nil { - fmt.Printf("NOTE: %v\n", err) - } - } else if !errors.Is(err, os.ErrNotExist) { - return nil, "", fmt.Errorf("failed to inspect existing AGE recipients at %s: %w", recipientPath, err) - } - } - - recipients := make([]string, 0) - for { - fmt.Println("\n[1] Use an existing AGE public key") - fmt.Println("[2] Generate an AGE public key using a personal passphrase/password — not stored on the server") - fmt.Println("[3] Generate an AGE public key from an existing personal private key — not stored on the server") - fmt.Println("[4] Exit setup") - option, err := promptOptionAge(wizardCtx, reader, "Select an option [1-4]: ") - if err != nil { - return nil, "", err - } - if option == "4" { - return nil, "", ErrAgeRecipientSetupAborted - } - - var value string - switch option { - case "1": - value, err = promptPublicRecipientAge(wizardCtx, reader) - case "2": - value, err = promptPassphraseRecipientAge(wizardCtx) - if err == nil { - o.logger.Info("Derived deterministic AGE public key from passphrase (no secrets stored)") - } - case "3": - value, err = promptPrivateKeyRecipientAge(wizardCtx) - } - if err != nil { - o.logger.Warning("Encryption setup: %v", err) - continue - } - if value != "" { - recipients = append(recipients, value) - } - - more, err := promptYesNoAge(wizardCtx, reader, "Add another recipient? [y/N]: ") - if err != nil { - return nil, "", err - } - if !more { - break - } - } - - if len(recipients) == 0 { - return nil, "", fmt.Errorf("no recipients provided") - } - - if err := writeRecipientFile(targetPath, dedupeRecipientStrings(recipients)); err != nil { + recipients, result, err := o.runAgeSetupWorkflow(wizardCtx, candidatePath, newCLIAgeSetupUI(bufio.NewReader(os.Stdin), o.logger)) + if err != nil { return nil, "", err } - - o.logger.Info("Saved AGE recipient to %s", targetPath) - o.logger.Info("Reminder: keep the AGE private key offline; the server stores only recipients.") - return recipients, targetPath, nil + savedPath := "" + if result != nil { + savedPath = result.RecipientPath + } + return recipients, savedPath, nil } func (o *Orchestrator) defaultAgeRecipientFile() string { @@ -262,6 +153,16 @@ func promptPublicRecipientAge(ctx context.Context, reader *bufio.Reader) (string } func promptPrivateKeyRecipientAge(ctx context.Context) (string, error) { + secret, err := promptPrivateKeyValueAge(ctx) + if err != nil { + return "", err + } + defer resetString(&secret) + + return ParseAgePrivateKeyRecipient(secret) +} + +func promptPrivateKeyValueAge(ctx context.Context) (string, error) { fmt.Print("Paste your AGE private key (not stored; input is not echoed). Press Enter when done: ") secretBytes, err := input.ReadPasswordWithContext(ctx, readPassword, int(os.Stdin.Fd())) fmt.Println() @@ -271,15 +172,14 @@ func promptPrivateKeyRecipientAge(ctx context.Context) (string, error) { defer zeroBytes(secretBytes) secret := strings.TrimSpace(string(secretBytes)) - defer resetString(&secret) if secret == "" { return "", fmt.Errorf("private key cannot be empty") } - identity, err := age.ParseX25519Identity(secret) - if err != nil { - return "", fmt.Errorf("invalid AGE private key: %w", err) + if err := ValidateAgePrivateKeyString(secret); err != nil { + resetString(&secret) + return "", err } - return identity.Recipient().String(), nil + return secret, nil } // promptPassphraseRecipient derives a deterministic AGE public key from a passphrase @@ -420,19 +320,23 @@ func readRecipientFile(path string) ([]string, error) { } func writeRecipientFile(path string, recipients []string) error { + return writeRecipientFileWithDeps(osFS{}, realTimeProvider{}, path, recipients) +} + +func writeRecipientFileWithDeps(fs FS, tp TimeProvider, path string, recipients []string) error { + if fs == nil { + fs = osFS{} + } if len(recipients) == 0 { return fmt.Errorf("no recipients to write") } - if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + if err := fs.MkdirAll(filepath.Dir(path), 0o700); err != nil { return fmt.Errorf("create recipient directory: %w", err) } content := strings.Join(recipients, "\n") + "\n" - if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + if err := writeFileAtomicWithDeps(fs, tp, path, []byte(content), 0o600); err != nil { return fmt.Errorf("write recipient file: %w", err) } - if err := os.Chmod(path, 0o600); err != nil { - return fmt.Errorf("chmod recipient file: %w", err) - } return nil } @@ -456,26 +360,153 @@ func mapInputAbortToAgeAbort(err error) error { } func backupExistingRecipientFile(path string) error { + _, err := backupExistingRecipientFileWithDeps(osFS{}, realTimeProvider{}, path) + return err +} + +func backupExistingRecipientFileWithDeps(fs FS, tp TimeProvider, path string) (string, error) { + if fs == nil { + fs = osFS{} + } if path == "" { - return nil + return "", nil } - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - return nil + info, err := fs.Stat(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return "", nil } + return "", err + } + if info.IsDir() { + return "", fmt.Errorf("recipient path is a directory: %s", path) + } + perm := info.Mode().Perm() + if perm == 0 { + perm = 0o600 + } + ts := recipientTime(tp) + backupPath := fmt.Sprintf("%s.bak-%s", path, ts.Format("20060102-150405.000000000")) + if err := copyRecipientFileWithDeps(fs, path, backupPath, perm); err != nil { + return "", fmt.Errorf("backup recipient file: %w", err) + } + return backupPath, nil +} + +func writeFileAtomicWithDeps(fs FS, tp TimeProvider, path string, data []byte, perm os.FileMode) error { + if fs == nil { + fs = osFS{} + } + perm &= 0o7777 + if perm == 0 { + perm = 0o600 + } + + tmpPath := fmt.Sprintf("%s.proxsave.tmp.%d", path, recipientTime(tp).UnixNano()) + tmpFile, err := fs.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL|os.O_TRUNC, perm) + if err != nil { return err } - backupPath := fmt.Sprintf("%s.bak-%s", path, time.Now().Format("20060102-150405")) - if err := os.Rename(path, backupPath); err != nil { - if removeErr := os.Remove(path); removeErr != nil { - return fmt.Errorf("failed to backup recipient file: %w (also failed to remove: %v)", err, removeErr) + + writeErr := func() error { + if len(data) != 0 { + if _, err := tmpFile.Write(data); err != nil { + return err + } + } + if err := tmpFile.Chmod(perm); err != nil { + return err } - return fmt.Errorf("renamed recipient file failed, removed original: %w", err) + return tmpFile.Sync() + }() + + closeErr := tmpFile.Close() + if writeErr != nil { + _ = fs.Remove(tmpPath) + return writeErr } - return nil + if closeErr != nil { + _ = fs.Remove(tmpPath) + return closeErr + } + + if err := fs.Rename(tmpPath, path); err != nil { + _ = fs.Remove(tmpPath) + return err + } + + return syncDirectoryWithDeps(fs, filepath.Dir(path)) +} + +func copyRecipientFileWithDeps(fs FS, src, dest string, perm os.FileMode) error { + if fs == nil { + fs = osFS{} + } + + in, err := fs.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := fs.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_EXCL, perm) + if err != nil { + return err + } + + copyErr := func() error { + if _, err := io.Copy(out, in); err != nil { + return err + } + if err := out.Chmod(perm); err != nil { + return err + } + return out.Sync() + }() + + closeErr := out.Close() + if copyErr != nil { + _ = fs.Remove(dest) + return copyErr + } + if closeErr != nil { + _ = fs.Remove(dest) + return closeErr + } + + return syncDirectoryWithDeps(fs, filepath.Dir(dest)) } -// BackupAgeRecipientFile backs up an existing AGE recipient file (if present). +func syncDirectoryWithDeps(fs FS, dir string) error { + if fs == nil { + fs = osFS{} + } + + df, err := fs.Open(dir) + if err != nil { + return fmt.Errorf("open dir %s: %w", dir, err) + } + + syncErr := df.Sync() + closeErr := df.Close() + if syncErr != nil { + if errors.Is(syncErr, syscall.EINVAL) || errors.Is(syncErr, syscall.ENOTSUP) { + return closeErr + } + return fmt.Errorf("sync dir %s: %w", dir, syncErr) + } + return closeErr +} + +func recipientTime(tp TimeProvider) time.Time { + if tp != nil { + return tp.Now() + } + return time.Now() +} + +// BackupAgeRecipientFile backs up an existing AGE recipient file (if present) +// without removing the active file. func BackupAgeRecipientFile(path string) error { return backupExistingRecipientFile(path) } @@ -495,6 +526,25 @@ func ValidateRecipientString(value string) error { return err } +// ValidateAgePrivateKeyString checks whether a private AGE identity is valid. +func ValidateAgePrivateKeyString(value string) error { + _, err := ParseAgePrivateKeyRecipient(value) + return err +} + +// ParseAgePrivateKeyRecipient validates a private AGE identity and returns its public recipient. +func ParseAgePrivateKeyRecipient(value string) (string, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "", fmt.Errorf("private key cannot be empty") + } + identity, err := age.ParseX25519Identity(trimmed) + if err != nil { + return "", fmt.Errorf("invalid AGE private key: %w", err) + } + return identity.Recipient().String(), nil +} + // DedupeRecipientStrings removes empty values and duplicates from recipient strings. func DedupeRecipientStrings(values []string) []string { return dedupeRecipientStrings(values) diff --git a/internal/orchestrator/encryption_exported_test.go b/internal/orchestrator/encryption_exported_test.go index 912f991f..980634d9 100644 --- a/internal/orchestrator/encryption_exported_test.go +++ b/internal/orchestrator/encryption_exported_test.go @@ -118,8 +118,19 @@ func TestBackupAgeRecipientFileExported(t *testing.T) { if err != nil || len(matches) != 1 { t.Fatalf("expected backup file, got %v err=%v", matches, err) } - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Fatalf("original path should have been moved, stat err=%v", err) + original, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%s): %v", path, err) + } + if got := string(original); got != "old" { + t.Fatalf("original content=%q; want %q", got, "old") + } + backup, err := os.ReadFile(matches[0]) + if err != nil { + t.Fatalf("ReadFile(%s): %v", matches[0], err) + } + if got := string(backup); got != "old" { + t.Fatalf("backup content=%q; want %q", got, "old") } } @@ -249,9 +260,13 @@ func TestRunAgeSetupWizard_ExitReturnsAborted(t *testing.T) { func TestRunAgeSetupWizard_Option1WritesFile(t *testing.T) { tmp := t.TempDir() + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } inputFile := filepath.Join(tmp, "stdin.txt") // Option 1 -> recipient -> no more recipients. - if err := os.WriteFile(inputFile, []byte("1\nage1alpha\nn\n"), 0o600); err != nil { + if err := os.WriteFile(inputFile, []byte("1\n"+id.Recipient().String()+"\nn\n"), 0o600); err != nil { t.Fatalf("write stdin: %v", err) } f, err := os.Open(inputFile) @@ -272,14 +287,14 @@ func TestRunAgeSetupWizard_Option1WritesFile(t *testing.T) { if savedPath == "" { t.Fatalf("expected saved path") } - if len(out) != 1 || out[0] != "age1alpha" { - t.Fatalf("out=%v; want %v", out, []string{"age1alpha"}) + if len(out) != 1 || out[0] != id.Recipient().String() { + t.Fatalf("out=%v; want %v", out, []string{id.Recipient().String()}) } data, err := os.ReadFile(savedPath) if err != nil { t.Fatalf("read saved: %v", err) } - if string(data) != "age1alpha\n" { - t.Fatalf("saved content=%q; want %q", string(data), "age1alpha\n") + if string(data) != id.Recipient().String()+"\n" { + t.Fatalf("saved content=%q; want %q", string(data), id.Recipient().String()+"\n") } } diff --git a/internal/orchestrator/encryption_more_test.go b/internal/orchestrator/encryption_more_test.go index 415c0360..e1dcfa0b 100644 --- a/internal/orchestrator/encryption_more_test.go +++ b/internal/orchestrator/encryption_more_test.go @@ -123,7 +123,7 @@ func TestPrepareAgeRecipients_InteractiveWizardSetsRecipientFile(t *testing.T) { } } -func TestRunAgeSetupWizard_ForceNewRecipientBacksUpExistingFile(t *testing.T) { +func TestRunAgeSetupWizard_ForceNewRecipientAbortKeepsExistingFile(t *testing.T) { tmp := t.TempDir() target := filepath.Join(tmp, "identity", "age", "recipient.txt") if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { @@ -175,21 +175,18 @@ func TestRunAgeSetupWizard_ForceNewRecipientBacksUpExistingFile(t *testing.T) { } matches, err := filepath.Glob(target + ".bak-*") - if err != nil || len(matches) != 1 { - t.Fatalf("expected backup file, got %v err=%v", matches, err) + if err != nil { + t.Fatalf("Glob(%s): %v", target+".bak-*", err) } - - // Ensure original was moved away. - if _, err := os.Stat(target); !os.IsNotExist(err) { - t.Fatalf("expected original to be moved, stat err=%v", err) + if len(matches) != 0 { + t.Fatalf("expected no backup file on abort, got %v", matches) } - // Ensure the old recipient didn't get replaced during abort. - data, err := os.ReadFile(matches[0]) + data, err := os.ReadFile(target) if err != nil { - t.Fatalf("ReadFile backup: %v", err) + t.Fatalf("ReadFile(%s): %v", target, err) } if strings.TrimSpace(string(data)) != "old" { - t.Fatalf("backup content=%q want=%q", strings.TrimSpace(string(data)), "old") + t.Fatalf("original content=%q want=%q", strings.TrimSpace(string(data)), "old") } } diff --git a/internal/orchestrator/encryption_test.go b/internal/orchestrator/encryption_test.go index 2e5a8144..e46a75b3 100644 --- a/internal/orchestrator/encryption_test.go +++ b/internal/orchestrator/encryption_test.go @@ -173,7 +173,18 @@ func TestBackupExistingRecipientFileCreatesBackup(t *testing.T) { if err != nil || len(matches) != 1 { t.Fatalf("expected backup file, got %v err=%v", matches, err) } - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Fatalf("original path should have been moved, stat err=%v", err) + original, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%s): %v", path, err) + } + if got := string(original); got != "old" { + t.Fatalf("original content=%q; want %q", got, "old") + } + backup, err := os.ReadFile(matches[0]) + if err != nil { + t.Fatalf("ReadFile(%s): %v", matches[0], err) + } + if got := string(backup); got != "old" { + t.Fatalf("backup content=%q; want %q", got, "old") } } diff --git a/internal/orchestrator/helpers_test.go b/internal/orchestrator/helpers_test.go index a01050c9..ea01bff9 100644 --- a/internal/orchestrator/helpers_test.go +++ b/internal/orchestrator/helpers_test.go @@ -794,48 +794,24 @@ func TestFormatBytesHR(t *testing.T) { func TestCalculateUsagePercent(t *testing.T) { tests := []struct { name string - freeBytes uint64 + usedBytes uint64 total uint64 want float64 }{ {"zero total", 0, 0, 0.0}, {"50% used", 500, 1000, 50.0}, - {"100% full", 0, 1000, 100.0}, - {"empty disk", 1000, 1000, 0.0}, - {"25% used", 750, 1000, 25.0}, + {"100% full", 1000, 1000, 100.0}, + {"empty disk", 0, 1000, 0.0}, + {"25% used", 250, 1000, 25.0}, + {"used exceeds total", 1500, 1000, 100.0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := calculateUsagePercent(tt.freeBytes, tt.total) + got := calculateUsagePercent(tt.usedBytes, tt.total) if got != tt.want { t.Errorf("calculateUsagePercent(%d, %d) = %f; want %f", - tt.freeBytes, tt.total, got, tt.want) - } - }) - } -} - -func TestCalculateUsedBytes(t *testing.T) { - tests := []struct { - name string - freeBytes uint64 - total uint64 - want uint64 - }{ - {"zero total", 0, 0, 0}, - {"normal usage", 300, 1000, 700}, - {"full disk", 0, 1000, 1000}, - {"empty disk", 1000, 1000, 0}, - {"free > total (invalid)", 1500, 1000, 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := calculateUsedBytes(tt.freeBytes, tt.total) - if got != tt.want { - t.Errorf("calculateUsedBytes(%d, %d) = %d; want %d", - tt.freeBytes, tt.total, got, tt.want) + tt.usedBytes, tt.total, got, tt.want) } }) } diff --git a/internal/orchestrator/network_apply_workflow_ui.go b/internal/orchestrator/network_apply_workflow_ui.go index b150c897..716c4021 100644 --- a/internal/orchestrator/network_apply_workflow_ui.go +++ b/internal/orchestrator/network_apply_workflow_ui.go @@ -553,7 +553,7 @@ func (u *tuiWorkflowUI) ConfirmAction(ctx context.Context, title, message, yesLa if timeout > 0 { return promptYesNoTUIWithCountdown(ctx, u.logger, title, u.configPath, u.buildSig, message, yesLabel, noLabel, timeout) } - return promptYesNoTUIFunc(title, u.configPath, u.buildSig, message, yesLabel, noLabel) + return promptYesNoTUIFunc(ctx, title, u.configPath, u.buildSig, message, yesLabel, noLabel) } func (u *tuiWorkflowUI) RepairNICNames(ctx context.Context, archivePath string) (*nicRepairResult, error) { @@ -564,7 +564,7 @@ func (u *tuiWorkflowUI) PromptNetworkCommit(ctx context.Context, remaining time. if err := ctx.Err(); err != nil { return false, err } - committed, err := promptNetworkCommitTUI(remaining, health, nicRepair, diagnosticsDir, u.configPath, u.buildSig) + committed, err := promptNetworkCommitTUI(ctx, remaining, health, nicRepair, diagnosticsDir, u.configPath, u.buildSig) if err != nil && errors.Is(err, input.ErrInputAborted) { return false, err } diff --git a/internal/orchestrator/notification_adapter.go b/internal/orchestrator/notification_adapter.go index 2b4605ae..d68fcbb8 100644 --- a/internal/orchestrator/notification_adapter.go +++ b/internal/orchestrator/notification_adapter.go @@ -160,17 +160,19 @@ func (n *NotificationAdapter) convertBackupStatsToNotificationData(stats *Backup compressionRatio = (1.0 - float64(stats.CompressedSize)/float64(stats.UncompressedSize)) * 100.0 } + n.warnOnInconsistentUsageStats("local", stats.LocalUsedSpace, stats.LocalTotalSpace) localFree := formatBytesHR(stats.LocalFreeSpace) - localUsed := formatBytesHR(calculateUsedBytes(stats.LocalFreeSpace, stats.LocalTotalSpace)) - localPercent := formatPercentString(calculateUsagePercent(stats.LocalFreeSpace, stats.LocalTotalSpace)) + localUsed := formatBytesHR(stats.LocalUsedSpace) + localPercent := formatPercentString(calculateUsagePercent(stats.LocalUsedSpace, stats.LocalTotalSpace)) secondaryFree := "" secondaryUsed := "" secondaryPercent := "" if stats.SecondaryEnabled { + n.warnOnInconsistentUsageStats("secondary", stats.SecondaryUsedSpace, stats.SecondaryTotalSpace) secondaryFree = formatBytesHR(stats.SecondaryFreeSpace) - secondaryUsed = formatBytesHR(calculateUsedBytes(stats.SecondaryFreeSpace, stats.SecondaryTotalSpace)) - secondaryPercent = formatPercentString(calculateUsagePercent(stats.SecondaryFreeSpace, stats.SecondaryTotalSpace)) + secondaryUsed = formatBytesHR(stats.SecondaryUsedSpace) + secondaryPercent = formatPercentString(calculateUsagePercent(stats.SecondaryUsedSpace, stats.SecondaryTotalSpace)) } // Parse log file for categories - use ParseLogCounts as primary source @@ -224,7 +226,7 @@ func (n *NotificationAdapter) convertBackupStatsToNotificationData(stats *Backup LocalUsed: localUsed, LocalPercent: localPercent, LocalSpaceBytes: stats.LocalFreeSpace, - LocalUsagePercent: calculateUsagePercent(stats.LocalFreeSpace, stats.LocalTotalSpace), + LocalUsagePercent: calculateUsagePercent(stats.LocalUsedSpace, stats.LocalTotalSpace), // Local retention info LocalRetentionPolicy: stats.LocalRetentionPolicy, @@ -247,7 +249,7 @@ func (n *NotificationAdapter) convertBackupStatsToNotificationData(stats *Backup SecondaryUsed: secondaryUsed, SecondaryPercent: secondaryPercent, SecondarySpaceBytes: stats.SecondaryFreeSpace, - SecondaryUsagePercent: calculateUsagePercent(stats.SecondaryFreeSpace, stats.SecondaryTotalSpace), + SecondaryUsagePercent: calculateUsagePercent(stats.SecondaryUsedSpace, stats.SecondaryTotalSpace), // Secondary retention info SecondaryRetentionPolicy: stats.SecondaryRetentionPolicy, @@ -318,20 +320,31 @@ func formatBytesHR(bytes uint64) string { return fmt.Sprintf("%.2f %s", val, units[exp]) } -// calculateUsagePercent calculates the usage percentage -func calculateUsagePercent(freeBytes, totalBytes uint64) float64 { +// calculateUsagePercent calculates the usage percentage from used and total bytes. +func calculateUsagePercent(usedBytes, totalBytes uint64) float64 { if totalBytes == 0 { return 0.0 } - usedBytes := totalBytes - freeBytes + if usedBytes >= totalBytes { + return 100.0 + } return (float64(usedBytes) / float64(totalBytes)) * 100.0 } -func calculateUsedBytes(freeBytes, totalBytes uint64) uint64 { - if totalBytes == 0 || totalBytes <= freeBytes { - return 0 +func (n *NotificationAdapter) warnOnInconsistentUsageStats(location string, usedBytes, totalBytes uint64) { + if n == nil || n.logger == nil { + return + } + if totalBytes == 0 { + if usedBytes > 0 { + n.logger.Warning("%s storage usage stats inconsistent: used=%d total=%d; reporting 0%% usage for display because total capacity is unknown", location, usedBytes, totalBytes) + } + return + } + if usedBytes <= totalBytes { + return } - return totalBytes - freeBytes + n.logger.Warning("%s storage usage stats inconsistent: used=%d total=%d; clamping percentage to 100%% for display", location, usedBytes, totalBytes) } func formatPercentString(percent float64) string { diff --git a/internal/orchestrator/notification_adapter_test.go b/internal/orchestrator/notification_adapter_test.go index a4c84c37..e7b53ee1 100644 --- a/internal/orchestrator/notification_adapter_test.go +++ b/internal/orchestrator/notification_adapter_test.go @@ -210,12 +210,14 @@ func TestConvertBackupStatsToNotificationData(t *testing.T) { CompressedSize: 4000, UncompressedSize: 8000, LocalBackups: 2, - LocalFreeSpace: 1024, - LocalTotalSpace: 2048, + LocalFreeSpace: 500, + LocalUsedSpace: 500, + LocalTotalSpace: 1000, SecondaryEnabled: true, SecondaryBackups: 1, - SecondaryFreeSpace: 2048, - SecondaryTotalSpace: 4096, + SecondaryFreeSpace: 1500, + SecondaryUsedSpace: 2500, + SecondaryTotalSpace: 4000, CloudEnabled: true, CloudBackups: 3, MaxLocalBackups: 10, @@ -267,6 +269,12 @@ func TestConvertBackupStatsToNotificationData(t *testing.T) { if data.EmailStatus != "disabled" || data.TelegramStatus != "N/A" { t.Fatalf("Email/Telegram status unexpected: %q / %q", data.EmailStatus, data.TelegramStatus) } + if data.LocalUsed != "500 B" || data.LocalPercent != "50.0%" { + t.Fatalf("local usage should use provided used-space stats, got used=%q percent=%q", data.LocalUsed, data.LocalPercent) + } + if data.SecondaryUsed != "2.44 KB" || data.SecondaryPercent != "62.5%" { + t.Fatalf("secondary usage should use provided used-space stats, got used=%q percent=%q", data.SecondaryUsed, data.SecondaryPercent) + } } func TestFormatHelpers(t *testing.T) { @@ -276,11 +284,11 @@ func TestFormatHelpers(t *testing.T) { if got := formatBytesHR(1024); got != "1.00 KB" { t.Fatalf("formatBytesHR(1024) = %q; want 1.00 KB", got) } - if got := calculateUsagePercent(25, 100); got != 75 { - t.Fatalf("calculateUsagePercent = %f; want 75", got) + if got := calculateUsagePercent(25, 100); got != 25 { + t.Fatalf("calculateUsagePercent = %f; want 25", got) } - if got := calculateUsedBytes(25, 100); got != 75 { - t.Fatalf("calculateUsedBytes = %d; want 75", got) + if got := calculateUsagePercent(125, 100); got != 100 { + t.Fatalf("calculateUsagePercent should clamp at 100, got %f", got) } if got := formatPercentString(12.345); got != "12.3%" { t.Fatalf("formatPercentString = %q; want 12.3%%", got) @@ -357,6 +365,59 @@ func TestConvertBackupStatsUsesLogCountsAndCompressionFallback(t *testing.T) { } } +func TestConvertBackupStatsToNotificationDataWarnsOnInconsistentUsageStats(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + var buf bytes.Buffer + logger.SetOutput(&buf) + + adapter := NewNotificationAdapter(&stubNotifier{name: "Email", enabled: true}, logger) + stats := sampleBackupStats() + stats.LocalUsedSpace = 1500 + stats.LocalTotalSpace = 1000 + stats.SecondaryUsedSpace = 4500 + stats.SecondaryTotalSpace = 4000 + + data := adapter.convertBackupStatsToNotificationData(stats) + + if data.LocalUsagePercent != 100 || data.SecondaryUsagePercent != 100 { + t.Fatalf("usage percent should still clamp for display, got local=%f secondary=%f", data.LocalUsagePercent, data.SecondaryUsagePercent) + } + logOutput := buf.String() + if !strings.Contains(logOutput, "local storage usage stats inconsistent") { + t.Fatalf("expected local inconsistency warning, got %q", logOutput) + } + if !strings.Contains(logOutput, "secondary storage usage stats inconsistent") { + t.Fatalf("expected secondary inconsistency warning, got %q", logOutput) + } +} + +func TestConvertBackupStatsToNotificationDataWarnsWhenUsedIsSetButTotalIsZero(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + var buf bytes.Buffer + logger.SetOutput(&buf) + + adapter := NewNotificationAdapter(&stubNotifier{name: "Email", enabled: true}, logger) + stats := sampleBackupStats() + stats.LocalUsedSpace = 1500 + stats.LocalTotalSpace = 0 + stats.SecondaryEnabled = true + stats.SecondaryUsedSpace = 4500 + stats.SecondaryTotalSpace = 0 + + data := adapter.convertBackupStatsToNotificationData(stats) + + if data.LocalUsagePercent != 0 || data.SecondaryUsagePercent != 0 { + t.Fatalf("usage percent should remain 0 when total capacity is unknown, got local=%f secondary=%f", data.LocalUsagePercent, data.SecondaryUsagePercent) + } + logOutput := buf.String() + if !strings.Contains(logOutput, "local storage usage stats inconsistent: used=1500 total=0") { + t.Fatalf("expected local zero-total warning, got %q", logOutput) + } + if !strings.Contains(logOutput, "secondary storage usage stats inconsistent: used=4500 total=0") { + t.Fatalf("expected secondary zero-total warning, got %q", logOutput) + } +} + func sampleBackupStats() *BackupStats { return &BackupStats{ ExitCode: 0, @@ -366,12 +427,14 @@ func sampleBackupStats() *BackupStats { ArchivePath: "/var/tmp/backup.tar", CompressedSize: 12345, LocalBackups: 1, - LocalFreeSpace: 1024, - LocalTotalSpace: 2048, + LocalFreeSpace: 500, + LocalUsedSpace: 500, + LocalTotalSpace: 1000, SecondaryEnabled: true, SecondaryBackups: 1, - SecondaryFreeSpace: 2048, - SecondaryTotalSpace: 4096, + SecondaryFreeSpace: 1500, + SecondaryUsedSpace: 2500, + SecondaryTotalSpace: 4000, CloudEnabled: true, CloudBackups: 1, Timestamp: time.Now(), diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index 271e88ca..08d714a6 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -106,9 +106,11 @@ type BackupStats struct { SecondaryEnabled bool LocalBackups int LocalFreeSpace uint64 + LocalUsedSpace uint64 LocalTotalSpace uint64 SecondaryBackups int SecondaryFreeSpace uint64 + SecondaryUsedSpace uint64 SecondaryTotalSpace uint64 CloudEnabled bool CloudBackups int @@ -1119,15 +1121,29 @@ func (o *Orchestrator) createBundle(ctx context.Context, archivePath string) (bu bundlePath = archivePath + ".bundle.tar" logger.Debug("Creating bundle with native Go tar: %s (files: %v)", bundlePath, associated) - // Create tar archive using native Go archive/tar - outFile, err := fs.Create(bundlePath) + // Write to a temporary file in the target directory and rename on success. + outFile, err := fs.CreateTemp(dir, fmt.Sprintf("%s.tmp-*", filepath.Base(bundlePath))) if err != nil { - return "", fmt.Errorf("failed to create bundle file: %w", err) + return "", fmt.Errorf("failed to create temp bundle file: %w", err) } - defer outFile.Close() + tempBundle := outFile.Name() + var tw *tar.Writer + removeTemp := true + defer func() { + if tw != nil { + _ = tw.Close() + tw = nil + } + if outFile != nil { + _ = outFile.Close() + outFile = nil + } + if removeTemp { + _ = fs.Remove(tempBundle) + } + }() - tw := tar.NewWriter(outFile) - defer tw.Close() + tw = tar.NewWriter(outFile) // Add each associated file to the tar archive for _, filename := range associated { @@ -1169,15 +1185,33 @@ func (o *Orchestrator) createBundle(ctx context.Context, archivePath string) (bu // Close tar writer to flush if err := tw.Close(); err != nil { + tw = nil return "", fmt.Errorf("failed to finalize tar archive: %w", err) } + tw = nil + + if err := outFile.Sync(); err != nil { + return "", fmt.Errorf("failed to sync bundle file: %w", err) + } if err := outFile.Close(); err != nil { + outFile = nil return "", fmt.Errorf("failed to close bundle file: %w", err) } + outFile = nil + + if err := fs.Rename(tempBundle, bundlePath); err != nil { + return "", fmt.Errorf("failed to rename temp bundle file: %w", err) + } + removeTemp = false + if err := syncDirectoryWithDeps(fs, dir); err != nil { + _ = fs.Remove(bundlePath) + return "", fmt.Errorf("failed to sync bundle directory: %w", err) + } // Verify bundle was created if _, err := fs.Stat(bundlePath); err != nil { + _ = fs.Remove(bundlePath) return "", fmt.Errorf("bundle file not created: %w", err) } diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go index a3903602..c774821b 100644 --- a/internal/orchestrator/restore.go +++ b/internal/orchestrator/restore.go @@ -37,6 +37,8 @@ var ( restoreGlob = filepath.Glob ) +const restoreTempPattern = ".proxsave-tmp-*" + // RestoreAbortInfo contains information about an aborted restore with network rollback. type RestoreAbortInfo struct { NetworkRollbackArmed bool @@ -1574,18 +1576,34 @@ func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, } // extractDirectory creates a directory with proper permissions and timestamps -func extractDirectory(target string, header *tar.Header, logger *logging.Logger) error { - if err := restoreFS.MkdirAll(target, os.FileMode(header.Mode)); err != nil { +func extractDirectory(target string, header *tar.Header, logger *logging.Logger) (retErr error) { + // Create with an owner-accessible mode first so the directory can be opened + // before applying restrictive archive permissions. + if err := restoreFS.MkdirAll(target, 0o700); err != nil { return fmt.Errorf("create directory: %w", err) } - // Set ownership - if err := os.Chown(target, header.Uid, header.Gid); err != nil { - logger.Debug("Failed to chown directory %s: %v", target, err) + dirFile, err := restoreFS.Open(target) + if err != nil { + return fmt.Errorf("open directory: %w", err) } + defer func() { + if dirFile == nil { + return + } + if err := dirFile.Close(); err != nil && retErr == nil { + retErr = fmt.Errorf("close directory: %w", err) + } + }() - // Set permissions explicitly - if err := os.Chmod(target, os.FileMode(header.Mode)); err != nil { + // Apply metadata on the opened directory handle so logical FS paths + // (e.g. FakeFS-backed test roots) do not leak through to host paths. + // Ownership remains best-effort to match the previous restore behavior on + // unprivileged runs and filesystems that do not support chown. + if err := atomicFileChown(dirFile, header.Uid, header.Gid); err != nil { + logger.Debug("Failed to chown directory %s: %v", target, err) + } + if err := atomicFileChmod(dirFile, os.FileMode(header.Mode)); err != nil { return fmt.Errorf("chmod directory: %w", err) } @@ -1598,36 +1616,70 @@ func extractDirectory(target string, header *tar.Header, logger *logging.Logger) } // extractRegularFile extracts a regular file with content and timestamps -func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header, logger *logging.Logger) error { - // Remove existing file if it exists - _ = restoreFS.Remove(target) +func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header, logger *logging.Logger) (retErr error) { + tmpPath := "" + var outFile *os.File + appendDeferredErr := func(prefix string, err error) { + if err == nil { + return + } + wrapped := fmt.Errorf("%s: %w", prefix, err) + if retErr == nil { + retErr = wrapped + return + } + retErr = errors.Join(retErr, wrapped) + } + closeOutFile := func() error { + if outFile == nil { + return nil + } + err := outFile.Close() + outFile = nil + return err + } - // Create the file - outFile, err := restoreFS.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode)) + // Write to a sibling temp file first so a truncated archive entry cannot clobber + // an existing target before the content is fully copied and closed. + outFile, err := restoreFS.CreateTemp(filepath.Dir(target), restoreTempPattern) if err != nil { return fmt.Errorf("create file: %w", err) } - defer outFile.Close() + tmpPath = outFile.Name() + defer func() { + appendDeferredErr("close file", closeOutFile()) + if tmpPath != "" { + if err := restoreFS.Remove(tmpPath); err != nil && logger != nil { + logger.Debug("Failed to remove temp file %s: %v", tmpPath, err) + } + } + }() // Copy content if _, err := io.Copy(outFile, tarReader); err != nil { return fmt.Errorf("write file content: %w", err) } - // Close before setting attributes - if err := outFile.Close(); err != nil { - return fmt.Errorf("close file: %w", err) + // Set metadata on the temp file before replacing the target so failures do not + // leave the final path in a partially restored state. + // Ownership remains best-effort to match the previous restore behavior on + // unprivileged runs and filesystems that do not support chown. + if err := atomicFileChown(outFile, header.Uid, header.Gid); err != nil { + logger.Debug("Failed to chown file %s: %v", target, err) + } + if err := atomicFileChmod(outFile, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("chmod file: %w", err) } - // Set ownership - if err := os.Chown(target, header.Uid, header.Gid); err != nil { - logger.Debug("Failed to chown file %s: %v", target, err) + // Close before renaming into place. + if err := closeOutFile(); err != nil { + return fmt.Errorf("close file: %w", err) } - // Set permissions explicitly - if err := os.Chmod(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("chmod file: %w", err) + if err := restoreFS.Rename(tmpPath, target); err != nil { + return fmt.Errorf("replace file: %w", err) } + tmpPath = "" // Set timestamps (mtime, atime, ctime via syscall) if err := setTimestamps(target, header); err != nil { diff --git a/internal/orchestrator/restore_access_control_ui_additional_test.go b/internal/orchestrator/restore_access_control_ui_additional_test.go index 079a35ef..e8427514 100644 --- a/internal/orchestrator/restore_access_control_ui_additional_test.go +++ b/internal/orchestrator/restore_access_control_ui_additional_test.go @@ -177,11 +177,7 @@ func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("uses systemd-run when available", func(t *testing.T) { binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemd-run") fakeCmd := &FakeCommandRunner{} @@ -204,11 +200,7 @@ func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemd-run") fakeCmd := &FakeCommandRunner{ @@ -242,11 +234,7 @@ func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("background timer failure returns error", func(t *testing.T) { emptyBin := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", emptyBin) fakeCmd := &FakeCommandRunner{ Errors: map[string]error{}, @@ -283,11 +271,7 @@ func TestArmAccessControlRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { restoreTime = fakeTime emptyBin := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", emptyBin) fakeCmd := &FakeCommandRunner{} restoreCmd = fakeCmd @@ -375,11 +359,7 @@ func TestDisarmAccessControlRollback_RemovesMarkerScriptAndStopsTimer(t *testing restoreFS = fakeFS binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemctl") fakeCmd := &FakeCommandRunner{} @@ -844,4 +824,3 @@ func TestMaybeApplyAccessControlWithUI_BranchCoverage(t *testing.T) { } }) } - diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go index 0faf0241..f8d7fb1c 100644 --- a/internal/orchestrator/restore_errors_test.go +++ b/internal/orchestrator/restore_errors_test.go @@ -671,6 +671,45 @@ func TestExtractDirectory_WithTimestamps(t *testing.T) { } } +func TestExtractDirectory_AppliesRestrictiveModeAfterOpen(t *testing.T) { + orig := restoreFS + t.Cleanup(func() { restoreFS = orig }) + restoreFS = osFS{} + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + destRoot := t.TempDir() + target := filepath.Join(destRoot, "locked") + t.Cleanup(func() { + if err := os.Chmod(target, 0o700); err != nil && !errors.Is(err, os.ErrNotExist) { + t.Fatalf("restore directory permissions for cleanup: %v", err) + } + }) + + header := &tar.Header{ + Name: "locked", + Mode: 0, + Uid: os.Getuid(), + Gid: os.Getgid(), + ModTime: time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC), + AccessTime: time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC), + } + + if err := extractDirectory(target, header, logger); err != nil { + t.Fatalf("extractDirectory failed with restrictive mode: %v", err) + } + + info, err := os.Stat(target) + if err != nil { + t.Fatalf("stat failed: %v", err) + } + if !info.IsDir() { + t.Fatalf("expected directory") + } + if info.Mode().Perm() != 0 { + t.Fatalf("directory mode = %o, want %o", info.Mode().Perm(), 0) + } +} + // -------------------------------------------------------------------------- // resetFailedService test // -------------------------------------------------------------------------- @@ -785,17 +824,39 @@ func (a *alwaysFailCommandRunner) RunStream(ctx context.Context, name string, st return nil, a.err } +type trackingOpenFileFS struct { + FS + lastOpened *os.File +} + +func (f *trackingOpenFileFS) OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) { + file, err := f.FS.OpenFile(path, flag, perm) + if err == nil { + f.lastOpened = file + } + return file, err +} + +func (f *trackingOpenFileFS) CreateTemp(dir, pattern string) (*os.File, error) { + file, err := f.FS.CreateTemp(dir, pattern) + if err == nil { + f.lastOpened = file + } + return file, err +} + // -------------------------------------------------------------------------- // ErrorInjectingFS - FS wrapper that can inject errors // -------------------------------------------------------------------------- type ErrorInjectingFS struct { - base FS - mkdirAllErr error - openFileErr error - symlinkErr error - readlinkErr error - linkErr error + base FS + mkdirAllErr error + openFileErr error + createTempErr error + symlinkErr error + readlinkErr error + linkErr error } func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) } @@ -810,6 +871,9 @@ func (f *ErrorInjectingFS) Remove(path string) error { return func (f *ErrorInjectingFS) RemoveAll(path string) error { return f.base.RemoveAll(path) } func (f *ErrorInjectingFS) ReadDir(path string) ([]os.DirEntry, error) { return f.base.ReadDir(path) } func (f *ErrorInjectingFS) CreateTemp(dir, pattern string) (*os.File, error) { + if f.createTempErr != nil { + return nil, f.createTempErr + } return f.base.CreateTemp(dir, pattern) } func (f *ErrorInjectingFS) MkdirTemp(dir, pattern string) (string, error) { @@ -887,7 +951,7 @@ func TestExtractDirectory_MkdirAllFails(t *testing.T) { // extractRegularFile error tests // -------------------------------------------------------------------------- -func TestExtractRegularFile_OpenFileFails(t *testing.T) { +func TestExtractRegularFile_CreateTempFails(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) @@ -895,8 +959,8 @@ func TestExtractRegularFile_OpenFileFails(t *testing.T) { t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) restoreFS = &ErrorInjectingFS{ - base: fakeFS, - openFileErr: fmt.Errorf("permission denied"), + base: fakeFS, + createTempErr: fmt.Errorf("permission denied"), } header := &tar.Header{ @@ -916,7 +980,8 @@ func TestExtractRegularFile_CopyFails(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) - restoreFS = osFS{} + trackingFS := &trackingOpenFileFS{FS: osFS{}} + restoreFS = trackingFS dir := t.TempDir() target := filepath.Join(dir, "testfile.txt") @@ -942,6 +1007,79 @@ func TestExtractRegularFile_CopyFails(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "write file content") { t.Fatalf("expected io.Copy error, got: %v", err) } + if trackingFS.lastOpened == nil { + t.Fatalf("expected tracked output file") + } + if closeErr := trackingFS.lastOpened.Close(); !errors.Is(closeErr, os.ErrClosed) { + t.Fatalf("output file close after copy failure = %v, want ErrClosed", closeErr) + } + tempMatches, err := filepath.Glob(filepath.Join(filepath.Dir(target), restoreTempPattern)) + if err != nil { + t.Fatalf("glob temp files: %v", err) + } + if len(tempMatches) != 0 { + t.Fatalf("temporary files should be removed after copy failure, found %v", tempMatches) + } + if _, err := os.Stat(target); !os.IsNotExist(err) { + t.Fatalf("target %s should not exist after copy failure, stat err=%v", target, err) + } +} + +func TestExtractRegularFile_CopyFailsPreservesExistingTarget(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = osFS{} + + dir := t.TempDir() + target := filepath.Join(dir, "testfile.txt") + if err := os.WriteFile(target, []byte("keep me"), 0o600); err != nil { + t.Fatalf("seed target: %v", err) + } + + header := &tar.Header{ + Name: "testfile.txt", + Mode: 0o644, + Size: 100, + } + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + _ = tw.WriteHeader(header) + _, _ = tw.Write([]byte("short")) + tw.Close() + + tr := tar.NewReader(&buf) + _, _ = tr.Next() + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + err := extractRegularFile(tr, target, header, logger) + if err == nil || !strings.Contains(err.Error(), "write file content") { + t.Fatalf("expected io.Copy error, got: %v", err) + } + + data, err := os.ReadFile(target) + if err != nil { + t.Fatalf("read preserved target: %v", err) + } + if string(data) != "keep me" { + t.Fatalf("target content = %q, want preserved original", string(data)) + } + + info, err := os.Stat(target) + if err != nil { + t.Fatalf("stat preserved target: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Fatalf("preserved target mode = %o, want %o", info.Mode().Perm(), 0o600) + } + + tempMatches, err := filepath.Glob(filepath.Join(filepath.Dir(target), restoreTempPattern)) + if err != nil { + t.Fatalf("glob temp files: %v", err) + } + if len(tempMatches) != 0 { + t.Fatalf("temporary files should be removed after copy failure, found %v", tempMatches) + } } // -------------------------------------------------------------------------- @@ -1609,7 +1747,8 @@ func TestExtractDirectory_SuccessWithTimestamps(t *testing.T) { func TestExtractRegularFile_Success(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) - restoreFS = osFS{} + trackingFS := &trackingOpenFileFS{FS: osFS{}} + restoreFS = trackingFS dir := t.TempDir() target := filepath.Join(dir, "file.txt") @@ -1651,6 +1790,26 @@ func TestExtractRegularFile_Success(t *testing.T) { if string(data) != "hello world" { t.Fatalf("expected 'hello world', got: %q", string(data)) } + info, err := os.Stat(target) + if err != nil { + t.Fatalf("stat file: %v", err) + } + if info.Mode().Perm() != 0o644 { + t.Fatalf("file mode = %o, want %o", info.Mode().Perm(), 0o644) + } + if trackingFS.lastOpened == nil { + t.Fatalf("expected tracked output file") + } + if closeErr := trackingFS.lastOpened.Close(); !errors.Is(closeErr, os.ErrClosed) { + t.Fatalf("output file close after success = %v, want ErrClosed", closeErr) + } + tempMatches, err := filepath.Glob(filepath.Join(filepath.Dir(target), restoreTempPattern)) + if err != nil { + t.Fatalf("glob temp files: %v", err) + } + if len(tempMatches) != 0 { + t.Fatalf("temporary files should be removed after success, found %v", tempMatches) + } } // -------------------------------------------------------------------------- diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go index 97e007bb..11f47ecb 100644 --- a/internal/orchestrator/restore_firewall_additional_test.go +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -560,11 +560,7 @@ func TestRestartPVEFirewallService_CommandFallbacks(t *testing.T) { t.Cleanup(func() { restoreCmd = origCmd }) binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemctl") writeExecutable(t, binDir, "pve-firewall") @@ -623,9 +619,7 @@ func TestRestartPVEFirewallService_CommandFallbacks(t *testing.T) { restoreCmd = fake emptyBin := t.TempDir() - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } + t.Setenv("PATH", emptyBin) if err := restartPVEFirewallService(context.Background()); err == nil { t.Fatalf("expected error") @@ -663,11 +657,7 @@ func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("uses systemd-run when available", func(t *testing.T) { binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemd-run") fakeCmd := &FakeCommandRunner{} @@ -690,11 +680,7 @@ func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemd-run") fakeCmd := &FakeCommandRunner{ @@ -728,11 +714,7 @@ func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) { t.Run("background timer failure returns error", func(t *testing.T) { emptyBin := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", emptyBin) fakeCmd := &FakeCommandRunner{ Errors: map[string]error{}, @@ -764,11 +746,7 @@ func TestDisarmFirewallRollback_RemovesMarkerAndStopsTimer(t *testing.T) { restoreFS = fakeFS binDir := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", binDir); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", binDir) writeExecutable(t, binDir, "systemctl") fakeCmd := &FakeCommandRunner{} @@ -1566,11 +1544,7 @@ func TestArmFirewallRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { restoreTime = fakeTime emptyBin := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", emptyBin) fakeCmd := &FakeCommandRunner{} restoreCmd = fakeCmd @@ -1658,11 +1632,7 @@ func TestDisarmFirewallRollback_MissingMarkerAndNoSystemctl(t *testing.T) { restoreFS = fakeFS emptyBin := t.TempDir() - oldPath := os.Getenv("PATH") - if err := os.Setenv("PATH", emptyBin); err != nil { - t.Fatalf("set PATH: %v", err) - } - t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + t.Setenv("PATH", emptyBin) fakeCmd := &FakeCommandRunner{} restoreCmd = fakeCmd diff --git a/internal/orchestrator/restore_ha.go b/internal/orchestrator/restore_ha.go index d69db668..286bf1fc 100644 --- a/internal/orchestrator/restore_ha.go +++ b/internal/orchestrator/restore_ha.go @@ -17,6 +17,16 @@ const defaultHARollbackTimeout = 180 * time.Second var ErrHAApplyNotCommitted = errors.New("HA configuration not committed") +var ( + haApplyGeteuid = os.Geteuid + haIsMounted = isMounted + haIsRealRestoreFS = isRealRestoreFS + + haArmRollback = armHARollback + haDisarmRollback = disarmHARollback + haApplyFromStage = applyPVEHAFromStage +) + type HAApplyNotCommittedError struct { RollbackLog string RollbackMarker string @@ -99,7 +109,7 @@ func maybeApplyPVEHAWithUI( if ui == nil { return fmt.Errorf("restore UI not available") } - if !isRealRestoreFS(restoreFS) { + if !haIsRealRestoreFS(restoreFS) { logger.Debug("Skipping PVE HA restore: non-system filesystem in use") return nil } @@ -107,7 +117,7 @@ func maybeApplyPVEHAWithUI( logger.Info("Dry run enabled: skipping PVE HA restore") return nil } - if os.Geteuid() != 0 { + if haApplyGeteuid() != 0 { logger.Warning("Skipping PVE HA restore: requires root privileges") return nil } @@ -132,7 +142,7 @@ func maybeApplyPVEHAWithUI( } etcPVE := "/etc/pve" - mounted, mountErr := isMounted(etcPVE) + mounted, mountErr := haIsMounted(etcPVE) if mountErr != nil { logger.Warning("PVE HA restore: unable to check pmxcfs mount (%s): %v", etcPVE, mountErr) } @@ -214,21 +224,21 @@ func maybeApplyPVEHAWithUI( if rollbackPath != "" { logger.Info("") logger.Info("Arming HA rollback timer (%ds)...", int(defaultHARollbackTimeout.Seconds())) - rollbackHandle, err = armHARollback(ctx, logger, rollbackPath, defaultHARollbackTimeout, "/tmp/proxsave") + rollbackHandle, err = haArmRollback(ctx, logger, rollbackPath, defaultHARollbackTimeout, "/tmp/proxsave") if err != nil { return fmt.Errorf("arm HA rollback: %w", err) } logger.Info("HA rollback log: %s", rollbackHandle.logPath) } - applied, err := applyPVEHAFromStage(logger, stageRoot) + applied, err := haApplyFromStage(logger, stageRoot) if err != nil { return err } if len(applied) == 0 { logger.Info("PVE HA restore: no changes applied (stage contained no HA config entries)") if rollbackHandle != nil { - disarmHARollback(ctx, logger, rollbackHandle) + haDisarmRollback(ctx, logger, rollbackHandle) } return nil } @@ -238,7 +248,7 @@ func maybeApplyPVEHAWithUI( return nil } - remaining := rollbackHandle.remaining(time.Now()) + remaining := rollbackHandle.remaining(nowRestore()) if remaining <= 0 { return buildHAApplyNotCommittedError(rollbackHandle) } @@ -260,7 +270,7 @@ func maybeApplyPVEHAWithUI( } if commit { - disarmHARollback(ctx, logger, rollbackHandle) + haDisarmRollback(ctx, logger, rollbackHandle) logger.Info("HA changes committed.") return nil } @@ -360,7 +370,7 @@ func armHARollback(ctx context.Context, logger *logging.Logger, backupPath strin markerPath: filepath.Join(baseDir, fmt.Sprintf("ha_rollback_pending_%s", timestamp)), scriptPath: filepath.Join(baseDir, fmt.Sprintf("ha_rollback_%s.sh", timestamp)), logPath: filepath.Join(baseDir, fmt.Sprintf("ha_rollback_%s.log", timestamp)), - armedAt: time.Now(), + armedAt: nowRestore(), timeout: timeout, } diff --git a/internal/orchestrator/restore_ha_additional_test.go b/internal/orchestrator/restore_ha_additional_test.go new file mode 100644 index 00000000..bca75e95 --- /dev/null +++ b/internal/orchestrator/restore_ha_additional_test.go @@ -0,0 +1,830 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +type haTestEnv struct { + fs *FakeFS + cmd *FakeCommandRunner + fakeTime *FakeTime + plan *RestorePlan + stageRoot string + logger *logging.Logger +} + +func setupHATestEnv(t *testing.T) *haTestEnv { + t.Helper() + + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := haApplyGeteuid + origMounted := haIsMounted + origRealFS := haIsRealRestoreFS + origArm := haArmRollback + origDisarm := haDisarmRollback + origApply := haApplyFromStage + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + haApplyGeteuid = origGeteuid + haIsMounted = origMounted + haIsRealRestoreFS = origRealFS + haArmRollback = origArm + haDisarmRollback = origDisarm + haApplyFromStage = origApply + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + fakeTime := &FakeTime{Current: time.Unix(100, 0)} + restoreTime = fakeTime + + haApplyGeteuid = func() int { return 0 } + haIsMounted = func(path string) (bool, error) { return true, nil } + haIsRealRestoreFS = func(fs FS) bool { return true } + haArmRollback = armHARollback + haDisarmRollback = disarmHARollback + haApplyFromStage = applyPVEHAFromStage + + return &haTestEnv{ + fs: fakeFS, + cmd: fakeCmd, + fakeTime: fakeTime, + plan: &RestorePlan{ + SystemType: SystemTypePVE, + NormalCategories: []Category{{ID: "pve_ha"}}, + }, + stageRoot: "/stage", + logger: newTestLogger(), + } +} + +func TestHAApplyNotCommittedErrorHelpers(t *testing.T) { + env := setupHATestEnv(t) + + var nilErr *HAApplyNotCommittedError + if nilErr.Error() != ErrHAApplyNotCommitted.Error() { + t.Fatalf("nil error string = %q, want %q", nilErr.Error(), ErrHAApplyNotCommitted.Error()) + } + if got := (&HAApplyNotCommittedError{}).Error(); got != ErrHAApplyNotCommitted.Error() { + t.Fatalf("error string = %q, want %q", got, ErrHAApplyNotCommitted.Error()) + } + + errValue := (&HAApplyNotCommittedError{}).Unwrap() + if errValue != ErrHAApplyNotCommitted { + t.Fatalf("unwrap = %v, want %v", errValue, ErrHAApplyNotCommitted) + } + if !errors.Is(&HAApplyNotCommittedError{}, ErrHAApplyNotCommitted) { + t.Fatalf("expected HAApplyNotCommittedError to match ErrHAApplyNotCommitted") + } + + now := env.fakeTime.Current + if got := (*haRollbackHandle)(nil).remaining(now); got != 0 { + t.Fatalf("nil remaining = %s, want 0", got) + } + + handle := &haRollbackHandle{armedAt: now.Add(-time.Second), timeout: 3 * time.Second} + if got := handle.remaining(now); got != 2*time.Second { + t.Fatalf("remaining = %s, want %s", got, 2*time.Second) + } + + handle.armedAt = now.Add(-5 * time.Second) + if got := handle.remaining(now); got != 0 { + t.Fatalf("expired remaining = %s, want 0", got) + } +} + +func TestBuildHAApplyNotCommittedError_ReflectsMarkerState(t *testing.T) { + env := setupHATestEnv(t) + + empty := buildHAApplyNotCommittedError(nil) + if empty.RollbackArmed || empty.RollbackMarker != "" || empty.RollbackLog != "" || !empty.RollbackDeadline.IsZero() { + t.Fatalf("unexpected empty error fields: %#v", empty) + } + + handle := &haRollbackHandle{ + markerPath: " /tmp/proxsave/ha.marker ", + logPath: " /tmp/proxsave/ha.log ", + armedAt: env.fakeTime.Current, + timeout: 3 * time.Second, + } + + built := buildHAApplyNotCommittedError(handle) + if built.RollbackArmed { + t.Fatalf("expected rollback to be unarmed when marker is absent") + } + if built.RollbackMarker != "/tmp/proxsave/ha.marker" || built.RollbackLog != "/tmp/proxsave/ha.log" { + t.Fatalf("unexpected trimmed fields: %#v", built) + } + if !built.RollbackDeadline.Equal(env.fakeTime.Current.Add(3 * time.Second)) { + t.Fatalf("RollbackDeadline=%s want %s", built.RollbackDeadline, env.fakeTime.Current.Add(3*time.Second)) + } + + if err := env.fs.AddFile("/tmp/proxsave/ha.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + built = buildHAApplyNotCommittedError(handle) + if !built.RollbackArmed { + t.Fatalf("expected rollback to be armed when marker exists") + } +} + +func TestStageHasPVEHAConfig_DetectsFilesAndErrors(t *testing.T) { + env := setupHATestEnv(t) + + ok, err := stageHasPVEHAConfig(env.stageRoot) + if err != nil { + t.Fatalf("stageHasPVEHAConfig error: %v", err) + } + if ok { + t.Fatalf("expected ok=false when stage is empty") + } + + if err := env.fs.AddFile(env.stageRoot+"/etc/pve/ha/groups.cfg", []byte("grp\n")); err != nil { + t.Fatalf("add groups.cfg: %v", err) + } + ok, err = stageHasPVEHAConfig(env.stageRoot) + if err != nil { + t.Fatalf("stageHasPVEHAConfig error: %v", err) + } + if !ok { + t.Fatalf("expected ok=true when staged HA config exists") + } + + restoreFS = statFailFS{ + FS: env.fs, + failPath: env.stageRoot + "/etc/pve/ha/resources.cfg", + err: fmt.Errorf("boom"), + } + if _, err := stageHasPVEHAConfig(env.stageRoot); err == nil || !strings.Contains(err.Error(), "stat") { + t.Fatalf("expected wrapped stat error, got %v", err) + } +} + +func TestBuildHARollbackScript_QuotesPaths(t *testing.T) { + script := buildHARollbackScript("/tmp/marker path", "/tmp/backup's.tar.gz", "/tmp/log path") + if !strings.Contains(script, "MARKER='/tmp/marker path'") { + t.Fatalf("expected MARKER to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "LOG='/tmp/log path'") { + t.Fatalf("expected LOG to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/tmp/backup'\\''s.tar.gz'") { + t.Fatalf("expected BACKUP to escape single quotes, got script:\n%s", script) + } + if !strings.HasSuffix(script, "\n") { + t.Fatalf("expected script to end with newline") + } +} + +func TestArmHARollback_CoversSchedulingPaths(t *testing.T) { + t.Run("rejects invalid input", func(t *testing.T) { + env := setupHATestEnv(t) + if _, err := armHARollback(context.Background(), env.logger, " ", time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for empty backup path") + } + if _, err := armHARollback(context.Background(), env.logger, "/backup.tgz", 0, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for invalid timeout") + } + }) + + t.Run("fails when rollback directory cannot be created", func(t *testing.T) { + env := setupHATestEnv(t) + restoreFS = mkdirAllFailFS{ + FS: env.fs, + failPath: "/tmp/proxsave", + err: fmt.Errorf("boom"), + } + if _, err := armHARollback(context.Background(), env.logger, "/backup.tgz", time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "create rollback directory") { + t.Fatalf("expected mkdir failure, got %v", err) + } + }) + + t.Run("fails when marker or script cannot be written", func(t *testing.T) { + env := setupHATestEnv(t) + t.Setenv("PATH", t.TempDir()) + + timestamp := env.fakeTime.Current.Format("20060102_150405") + markerPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("ha_rollback_pending_%s", timestamp)) + restoreFS = writeFileFailFS{ + FS: env.fs, + failPath: markerPath, + err: fmt.Errorf("disk full"), + } + if _, err := armHARollback(context.Background(), env.logger, "/backup.tgz", time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("expected marker write failure, got %v", err) + } + + env = setupHATestEnv(t) + t.Setenv("PATH", t.TempDir()) + timestamp = env.fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("ha_rollback_%s.sh", timestamp)) + restoreFS = writeFileFailFS{ + FS: env.fs, + failPath: scriptPath, + err: fmt.Errorf("disk full"), + } + if _, err := armHARollback(context.Background(), env.logger, "/backup.tgz", time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("expected script write failure, got %v", err) + } + }) + + t.Run("background timer writes marker and script when systemd-run unavailable", func(t *testing.T) { + env := setupHATestEnv(t) + t.Setenv("PATH", t.TempDir()) + + handle, err := armHARollback(context.Background(), env.logger, "/backup.tgz", 2*time.Second, "") + if err != nil { + t.Fatalf("armHARollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.workDir != "/tmp/proxsave" { + t.Fatalf("workDir=%q, want %q", handle.workDir, "/tmp/proxsave") + } + if !handle.armedAt.Equal(env.fakeTime.Current) { + t.Fatalf("armedAt=%s, want %s", handle.armedAt, env.fakeTime.Current) + } + if _, err := env.fs.Stat(handle.markerPath); err != nil { + t.Fatalf("expected marker file, stat err=%v", err) + } + script, err := env.fs.ReadFile(handle.scriptPath) + if err != nil { + t.Fatalf("read rollback script: %v", err) + } + if !strings.Contains(string(script), "BACKUP=/backup.tgz") { + t.Fatalf("expected backup path in script, got:\n%s", string(script)) + } + + wantBackground := "sh -c nohup sh -c 'sleep 2; /bin/sh " + handle.scriptPath + "' >/dev/null 2>&1 &" + calls := env.cmd.CallsList() + if len(calls) != 1 || calls[0] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("sub-second timeout rounds up to one second", func(t *testing.T) { + env := setupHATestEnv(t) + t.Setenv("PATH", t.TempDir()) + + handle, err := armHARollback(context.Background(), env.logger, "/backup.tgz", 100*time.Millisecond, "/tmp/proxsave") + if err != nil { + t.Fatalf("armHARollback error: %v", err) + } + wantBackground := "sh -c nohup sh -c 'sleep 1; /bin/sh " + handle.scriptPath + "' >/dev/null 2>&1 &" + calls := env.cmd.CallsList() + if len(calls) != 1 || calls[0] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("systemd-run failure falls back to background timer", func(t *testing.T) { + env := setupHATestEnv(t) + binDir := t.TempDir() + t.Setenv("PATH", binDir) + writeExecutable(t, binDir, "systemd-run") + + timestamp := env.fakeTime.Current.Format("20060102_150405") + unitName := "proxsave-ha-rollback-" + timestamp + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("ha_rollback_%s.sh", timestamp)) + systemdKey := "systemd-run --unit=" + unitName + " --on-active=2s /bin/sh " + scriptPath + env.cmd.Errors = map[string]error{ + systemdKey: fmt.Errorf("boom"), + } + + handle, err := armHARollback(context.Background(), env.logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armHARollback error: %v", err) + } + if handle.unitName != "" { + t.Fatalf("expected unitName to be cleared after systemd-run failure, got %q", handle.unitName) + } + + wantBackground := "sh -c nohup sh -c 'sleep 2; /bin/sh " + scriptPath + "' >/dev/null 2>&1 &" + calls := env.cmd.CallsList() + if len(calls) != 2 || calls[0] != systemdKey || calls[1] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("background timer failure returns error", func(t *testing.T) { + env := setupHATestEnv(t) + t.Setenv("PATH", t.TempDir()) + + timestamp := env.fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("ha_rollback_%s.sh", timestamp)) + backgroundKey := "sh -c nohup sh -c 'sleep 1; /bin/sh " + scriptPath + "' >/dev/null 2>&1 &" + env.cmd.Errors = map[string]error{ + backgroundKey: fmt.Errorf("boom"), + } + + if _, err := armHARollback(context.Background(), env.logger, "/backup.tgz", time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestDisarmHARollback_RemovesMarkerAndStopsTimer(t *testing.T) { + env := setupHATestEnv(t) + + binDir := t.TempDir() + t.Setenv("PATH", binDir) + writeExecutable(t, binDir, "systemctl") + + handle := &haRollbackHandle{ + markerPath: "/tmp/proxsave/ha.marker", + unitName: "proxsave-ha-rollback-test", + scriptPath: "/tmp/proxsave/ha.sh", + logPath: "/tmp/proxsave/ha.log", + } + if err := env.fs.AddFile(handle.markerPath, []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + if err := env.fs.AddFile(handle.scriptPath, []byte("#!/bin/sh\n")); err != nil { + t.Fatalf("add script: %v", err) + } + + disarmHARollback(context.Background(), env.logger, handle) + + if _, err := env.fs.Stat(handle.markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected marker removed; stat err=%v", err) + } + if _, err := env.fs.Stat(handle.scriptPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected script removed; stat err=%v", err) + } + + timerUnit := handle.unitName + ".timer" + want1 := "systemctl stop " + timerUnit + want2 := "systemctl reset-failed " + handle.unitName + ".service " + timerUnit + calls := env.cmd.CallsList() + if len(calls) != 2 || calls[0] != want1 || calls[1] != want2 { + t.Fatalf("unexpected calls: %#v", calls) + } + + disarmHARollback(context.Background(), env.logger, nil) +} + +func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { + env := setupHATestEnv(t) + stageWithHA := env.stageRoot + "/etc/pve/ha/resources.cfg" + if err := env.fs.AddFile(stageWithHA, []byte("res\n")); err != nil { + t.Fatalf("add staged HA config: %v", err) + } + + t.Run("nil plan returns nil", func(t *testing.T) { + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, nil, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("errors when ui missing", func(t *testing.T) { + if err := maybeApplyPVEHAWithUI(context.Background(), nil, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("skips on non-system restore fs", func(t *testing.T) { + haIsRealRestoreFS = func(fs FS) bool { return false } + t.Cleanup(func() { haIsRealRestoreFS = func(fs FS) bool { return true } }) + + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("dry run, non-root, empty stage and cluster restore all skip", func(t *testing.T) { + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, true); err != nil { + t.Fatalf("expected nil on dry run, got %v", err) + } + + haApplyGeteuid = func() int { return 1000 } + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil for non-root, got %v", err) + } + haApplyGeteuid = func() int { return 0 } + + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, " ", false); err != nil { + t.Fatalf("expected nil for empty stageRoot, got %v", err) + } + + plan := *env.plan + plan.NeedsClusterRestore = true + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, &plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil for cluster restore, got %v", err) + } + }) + + t.Run("skips when stage has no HA config or mount unavailable", func(t *testing.T) { + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, "/empty", false); err != nil { + t.Fatalf("expected nil when stage has no HA config, got %v", err) + } + + haIsMounted = func(path string) (bool, error) { return false, nil } + t.Cleanup(func() { haIsMounted = func(path string) (bool, error) { return true, nil } }) + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil when /etc/pve is not mounted, got %v", err) + } + }) + + t.Run("stage detection and initial prompt errors are propagated", func(t *testing.T) { + restoreFS = statFailFS{ + FS: env.fs, + failPath: env.stageRoot + "/etc/pve/ha/resources.cfg", + err: fmt.Errorf("boom"), + } + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil { + t.Fatalf("expected staged stat error") + } + + restoreFS = env.fs + haIsMounted = func(path string) (bool, error) { return false, fmt.Errorf("boom") } + if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil when mount check warns then skips, got %v", err) + } + haIsMounted = func(path string) (bool, error) { return true, nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{err: input.ErrInputAborted}}, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected apply prompt error, got %v", err) + } + }) + + t.Run("user skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: false}}, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("proceed without rollback applies and returns", func(t *testing.T) { + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + appliedCalled := false + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + appliedCalled = true + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: true}, + }, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + if !appliedCalled { + t.Fatalf("expected HA apply to be called") + } + }) + + t.Run("full rollback and no rollback prompts can be declined or fail", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: false}, + }, + } + safetyBackup := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, safetyBackup, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + + ui = &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {err: fmt.Errorf("boom")}, + }, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, safetyBackup, nil, env.stageRoot, false); err == nil { + t.Fatalf("expected full rollback prompt error") + } + + ui = &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: false}, + }, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + + ui = &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {err: fmt.Errorf("boom")}, + }, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil { + t.Fatalf("expected no-rollback prompt error") + } + }) + + t.Run("full rollback backup is used when HA rollback backup missing", func(t *testing.T) { + markerPath := "/tmp/proxsave/ha-full.marker" + disarmed := false + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + if backupPath != "/backups/full.tgz" { + t.Fatalf("backupPath=%q, want %q", backupPath, "/backups/full.tgz") + } + handle := &haRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ha-full.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + haDisarmRollback = func(ctx context.Context, logger *logging.Logger, handle *haRollbackHandle) { + disarmed = true + disarmHARollback(ctx, logger, handle) + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: true}, + {ok: true}, + }, + } + safetyBackup := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, safetyBackup, nil, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + if !disarmed { + t.Fatalf("expected rollback to be disarmed on commit") + } + }) + + t.Run("no changes applied disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/ha-empty.marker" + disarmed := false + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + handle := &haRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ha-empty.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + haDisarmRollback = func(ctx context.Context, logger *logging.Logger, handle *haRollbackHandle) { + disarmed = true + disarmHARollback(ctx, logger, handle) + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/ha.tgz"} + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, rollback, env.stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + if !disarmed { + t.Fatalf("expected rollback to be disarmed when nothing was applied") + } + }) + + t.Run("apply errors are propagated", func(t *testing.T) { + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return nil, fmt.Errorf("boom") + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: true}, + }, + } + if err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil { + t.Fatalf("expected apply error") + } + }) +} + +func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { + env := setupHATestEnv(t) + if err := env.fs.AddFile(env.stageRoot+"/etc/pve/ha/resources.cfg", []byte("res\n")); err != nil { + t.Fatalf("add staged HA config: %v", err) + } + + baseRollback := &SafetyBackupResult{BackupPath: "/backups/ha.tgz"} + + makeHandle := func(markerPath string, armedAt time.Time) *haRollbackHandle { + handle := &haRollbackHandle{ + markerPath: markerPath, + scriptPath: markerPath + ".sh", + logPath: markerPath + ".log", + armedAt: armedAt, + timeout: defaultHARollbackTimeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle + } + + t.Run("rollback choice returns typed error", func(t *testing.T) { + markerPath := "/tmp/proxsave/ha-rollback.marker" + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + return makeHandle(markerPath, nowRestore()), nil + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {ok: false}, + }, + } + err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, baseRollback, env.stageRoot, false) + if err == nil || !errors.Is(err, ErrHAApplyNotCommitted) { + t.Fatalf("expected ErrHAApplyNotCommitted, got %v", err) + } + var typed *HAApplyNotCommittedError + if !errors.As(err, &typed) || typed == nil { + t.Fatalf("expected typed HAApplyNotCommittedError, got %T", err) + } + if !typed.RollbackArmed || typed.RollbackMarker != markerPath { + t.Fatalf("unexpected typed error fields: %#v", typed) + } + }) + + t.Run("commit prompt abort returns abort error", func(t *testing.T) { + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + return makeHandle("/tmp/proxsave/ha-abort.marker", nowRestore()), nil + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {err: input.ErrInputAborted}, + }, + } + err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, baseRollback, env.stageRoot, false) + if err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected input abort, got %v", err) + } + }) + + t.Run("commit prompt failure returns typed error", func(t *testing.T) { + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + return makeHandle("/tmp/proxsave/ha-fail.marker", nowRestore()), nil + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, + {err: fmt.Errorf("boom")}, + }, + } + err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, baseRollback, env.stageRoot, false) + if err == nil || !errors.Is(err, ErrHAApplyNotCommitted) { + t.Fatalf("expected ErrHAApplyNotCommitted, got %v", err) + } + }) + + t.Run("expired rollback handle returns typed error without commit prompt", func(t *testing.T) { + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + return makeHandle("/tmp/proxsave/ha-expired.marker", nowRestore().Add(-defaultHARollbackTimeout-time.Second)), nil + } + haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/ha/resources.cfg"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, + } + err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, baseRollback, env.stageRoot, false) + if err == nil || !errors.Is(err, ErrHAApplyNotCommitted) { + t.Fatalf("expected ErrHAApplyNotCommitted, got %v", err) + } + if ui.calls != 1 { + t.Fatalf("expected only initial apply prompt, got %d calls", ui.calls) + } + }) + + t.Run("arm rollback failure is wrapped", func(t *testing.T) { + haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { + return nil, fmt.Errorf("boom") + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, + } + err := maybeApplyPVEHAWithUI(context.Background(), ui, env.logger, env.plan, nil, baseRollback, env.stageRoot, false) + if err == nil || !strings.Contains(err.Error(), "arm HA rollback") { + t.Fatalf("expected wrapped arm error, got %v", err) + } + }) +} + +func TestApplyPVEHAFromStage_BranchCoverage(t *testing.T) { + t.Run("blank stage root returns nil", func(t *testing.T) { + if applied, err := applyPVEHAFromStage(newTestLogger(), " "); err != nil || len(applied) != 0 { + t.Fatalf("applied=%#v err=%v; want nil,nil", applied, err) + } + }) + + t.Run("ensure dir, staged stat, copy and remove failures are propagated", func(t *testing.T) { + env := setupHATestEnv(t) + stageRoot := env.stageRoot + if err := env.fs.AddFile(stageRoot+"/etc/pve/ha/resources.cfg", []byte("res\n")); err != nil { + t.Fatalf("add resources.cfg: %v", err) + } + + restoreFS = mkdirAllFailFS{ + FS: env.fs, + failPath: "/etc/pve/ha", + err: fmt.Errorf("boom"), + } + if _, err := applyPVEHAFromStage(env.logger, stageRoot); err == nil || !strings.Contains(err.Error(), "ensure /etc/pve/ha") { + t.Fatalf("expected ensure error, got %v", err) + } + + restoreFS = statFailFS{ + FS: env.fs, + failPath: stageRoot + "/etc/pve/ha/resources.cfg", + err: fmt.Errorf("boom"), + } + if _, err := applyPVEHAFromStage(env.logger, stageRoot); err == nil || !strings.Contains(err.Error(), "stat") { + t.Fatalf("expected stage stat error, got %v", err) + } + + restoreFS = readFileFailFS{ + FS: env.fs, + failPath: stageRoot + "/etc/pve/ha/resources.cfg", + err: fmt.Errorf("boom"), + } + if _, err := applyPVEHAFromStage(env.logger, stageRoot); err == nil { + t.Fatalf("expected copy error") + } + + if err := env.fs.AddFile("/etc/pve/ha/groups.cfg", []byte("grp\n")); err != nil { + t.Fatalf("add existing groups.cfg: %v", err) + } + restoreFS = removeFailFS{ + FS: env.fs, + failPath: "/etc/pve/ha/groups.cfg", + err: fmt.Errorf("boom"), + } + if _, err := applyPVEHAFromStage(env.logger, stageRoot); err == nil || !strings.Contains(err.Error(), "remove /etc/pve/ha/groups.cfg") { + t.Fatalf("expected remove error, got %v", err) + } + }) +} diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index d65e4499..2a9c37a5 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -63,6 +63,94 @@ func TestExtractTarEntry_BlocksPathTraversal(t *testing.T) { } } +func TestExtractPlainArchive_WithFakeFS_RestoresFiles(t *testing.T) { + origRestoreFS := restoreFS + fakeFS := NewFakeFS() + restoreFS = fakeFS + t.Cleanup(func() { + restoreFS = origRestoreFS + _ = fakeFS.Cleanup() + }) + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + logger := logging.New(types.LogLevelDebug, false) + if err := extractPlainArchive(context.Background(), "/bundle.tar", "/", logger, nil); err != nil { + t.Fatalf("extractPlainArchive: %v", err) + } + + data, err := fakeFS.ReadFile("/etc/hosts") + if err != nil { + t.Fatalf("expected restored /etc/hosts: %v", err) + } + if string(data) != "127.0.0.1 localhost\n" { + t.Fatalf("hosts=%q; want %q", string(data), "127.0.0.1 localhost\n") + } +} + +func TestExtractSelectiveArchive_WithFakeFSMkdirTemp_RestoresIntoTempDir(t *testing.T) { + origRestoreFS := restoreFS + fakeFS := NewFakeFS() + restoreFS = fakeFS + t.Cleanup(func() { + restoreFS = origRestoreFS + _ = fakeFS.Cleanup() + }) + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/fstab": "UUID=root / ext4 defaults 0 1\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-") + if err != nil { + t.Fatalf("MkdirTemp: %v", err) + } + t.Cleanup(func() { _ = restoreFS.RemoveAll(fsTempDir) }) + + logger := logging.New(types.LogLevelDebug, false) + categories := []Category{{ + ID: "filesystem", + Name: "Filesystem Configuration", + Paths: []string{ + "./etc/fstab", + }, + }} + if _, err := extractSelectiveArchive(context.Background(), "/bundle.tar", fsTempDir, categories, RestoreModeCustom, logger); err != nil { + t.Fatalf("extractSelectiveArchive: %v", err) + } + + backupFstab := filepath.Join(fsTempDir, "etc", "fstab") + data, err := fakeFS.ReadFile(backupFstab) + if err != nil { + t.Fatalf("expected extracted backup fstab: %v", err) + } + if string(data) != "UUID=root / ext4 defaults 0 1\n" { + t.Fatalf("fstab=%q; want %q", string(data), "UUID=root / ext4 defaults 0 1\n") + } +} + func TestParsePoolNameFromUnit(t *testing.T) { tests := []struct { name string diff --git a/internal/orchestrator/restore_tui.go b/internal/orchestrator/restore_tui.go index 8b41afdc..48618c28 100644 --- a/internal/orchestrator/restore_tui.go +++ b/internal/orchestrator/restore_tui.go @@ -50,7 +50,7 @@ func RunRestoreWorkflowTUI(ctx context.Context, cfg *config.Config, logger *logg return nil } -func selectRestoreModeTUI(systemType SystemType, configPath, buildSig, backupSummary string) (RestoreMode, error) { +func selectRestoreModeTUI(ctx context.Context, systemType SystemType, configPath, buildSig, backupSummary string) (RestoreMode, error) { app := newTUIApp() var selected RestoreMode var aborted bool @@ -134,7 +134,7 @@ func selectRestoreModeTUI(systemType SystemType, configPath, buildSig, backupSum page := buildRestoreWizardPage("Select restore mode", configPath, buildSig, content) app.SetRoot(page, true).SetFocus(form.Form) - if err := app.Run(); err != nil { + if err := app.RunWithContext(ctx); err != nil { return "", err } if aborted || selected == "" { @@ -143,7 +143,7 @@ func selectRestoreModeTUI(systemType SystemType, configPath, buildSig, backupSum return selected, nil } -func selectPBSRestoreBehaviorTUI(configPath, buildSig, backupSummary string) (PBSRestoreBehavior, error) { +func selectPBSRestoreBehaviorTUI(ctx context.Context, configPath, buildSig, backupSummary string) (PBSRestoreBehavior, error) { app := newTUIApp() var selected PBSRestoreBehavior var aborted bool @@ -218,7 +218,7 @@ func selectPBSRestoreBehaviorTUI(configPath, buildSig, backupSummary string) (PB page := buildRestoreWizardPage("PBS restore behavior", configPath, buildSig, content) app.SetRoot(page, true).SetFocus(form.Form) - if err := app.Run(); err != nil { + if err := app.RunWithContext(ctx); err != nil { return PBSRestoreBehaviorUnspecified, err } if aborted || selected == PBSRestoreBehaviorUnspecified { @@ -253,7 +253,7 @@ func filterAndSortCategoriesForSystem(available []Category, systemType SystemTyp return relevant } -func selectCategoriesTUI(available []Category, systemType SystemType, configPath, buildSig string) ([]Category, error) { +func selectCategoriesTUI(ctx context.Context, available []Category, systemType SystemType, configPath, buildSig string) ([]Category, error) { relevant := filterAndSortCategoriesForSystem(available, systemType) if len(relevant) == 0 { @@ -345,7 +345,8 @@ func selectCategoriesTUI(available []Category, systemType SystemType, configPath page := buildRestoreWizardPage("Select restore categories", configPath, buildSig, content) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return nil, err } if goBack { @@ -360,9 +361,13 @@ func selectCategoriesTUI(available []Category, systemType SystemType, configPath return chosen, nil } -func promptCompatibilityTUI(configPath, buildSig string, compatErr error) (bool, error) { - message := fmt.Sprintf("Compatibility check reported:\n\n[red]%v[white]\n\nContinuing may cause system instability.\n\nDo you want to continue anyway?", compatErr) +func promptCompatibilityTUI(ctx context.Context, configPath, buildSig string, compatErr error) (bool, error) { + message := fmt.Sprintf( + "Compatibility check reported:\n\n[red]%s[white]\n\nContinuing may cause system instability.\n\nDo you want to continue anyway?", + tview.Escape(fmt.Sprint(compatErr)), + ) return promptYesNoTUIFunc( + ctx, "Compatibility warning", configPath, buildSig, @@ -372,9 +377,13 @@ func promptCompatibilityTUI(configPath, buildSig string, compatErr error) (bool, ) } -func promptContinueWithoutSafetyBackupTUI(configPath, buildSig string, cause error) (bool, error) { - message := fmt.Sprintf("Failed to create safety backup:\n\n[red]%v[white]\n\nWithout a safety backup, it will be harder to rollback changes.\n\nContinue without safety backup?", cause) +func promptContinueWithoutSafetyBackupTUI(ctx context.Context, configPath, buildSig string, cause error) (bool, error) { + message := fmt.Sprintf( + "Failed to create safety backup:\n\n[red]%s[white]\n\nWithout a safety backup, it will be harder to rollback changes.\n\nContinue without safety backup?", + tview.Escape(fmt.Sprint(cause)), + ) return promptYesNoTUIFunc( + ctx, "Safety backup failed", configPath, buildSig, @@ -384,9 +393,10 @@ func promptContinueWithoutSafetyBackupTUI(configPath, buildSig string, cause err ) } -func promptContinueWithPBSServicesTUI(configPath, buildSig string) (bool, error) { +func promptContinueWithPBSServicesTUI(ctx context.Context, configPath, buildSig string) (bool, error) { message := "Unable to stop Proxmox Backup Server services automatically.\n\nContinuing the restore while services are running may lead to inconsistent state.\n\nContinue restore with PBS services still running?" return promptYesNoTUIFunc( + ctx, "PBS services running", configPath, buildSig, @@ -432,6 +442,7 @@ func maybeRepairNICNamesTUI(ctx context.Context, logger *logging.Logger, archive b.WriteString("Skip NIC name repair and keep restored interface names?") skip, err := promptYesNoTUIFunc( + ctx, "NIC naming overrides", configPath, buildSig, @@ -471,6 +482,7 @@ func maybeRepairNICNamesTUI(ctx context.Context, logger *logging.Logger, archive b.WriteString("\nApply NIC rename mapping even for conflicts?") ok, err := promptYesNoTUIFunc( + ctx, "NIC name conflicts", configPath, buildSig, @@ -498,7 +510,7 @@ func maybeRepairNICNamesTUI(ctx context.Context, logger *logging.Logger, archive return result } -func promptClusterRestoreModeTUI(configPath, buildSig string) (int, error) { +func promptClusterRestoreModeTUI(ctx context.Context, configPath, buildSig string) (int, error) { app := newTUIApp() var choice int var aborted bool @@ -545,7 +557,7 @@ func promptClusterRestoreModeTUI(configPath, buildSig string) (int, error) { page := buildRestoreWizardPage("Cluster restore mode", configPath, buildSig, form.Form) app.SetRoot(page, true).SetFocus(form.Form) - if err := app.Run(); err != nil { + if err := app.RunWithContext(ctx); err != nil { return 0, err } if aborted { @@ -612,7 +624,7 @@ func buildRestorePlanText(config *SelectiveRestoreConfig) string { return b.String() } -func showRestorePlanTUI(config *SelectiveRestoreConfig, configPath, buildSig string) error { +func showRestorePlanTUI(ctx context.Context, config *SelectiveRestoreConfig, configPath, buildSig string) error { if config == nil { return fmt.Errorf("restore configuration not available") } @@ -648,7 +660,8 @@ func showRestorePlanTUI(config *SelectiveRestoreConfig, configPath, buildSig str page := buildRestoreWizardPage("Restore plan", configPath, buildSig, content) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return err } if aborted || !proceed { @@ -657,7 +670,7 @@ func showRestorePlanTUI(config *SelectiveRestoreConfig, configPath, buildSig str return nil } -func confirmRestoreTUI(configPath, buildSig string) (bool, error) { +func confirmRestoreTUI(ctx context.Context, configPath, buildSig string) (bool, error) { app := newTUIApp() var confirmed bool var aborted bool @@ -689,7 +702,8 @@ func confirmRestoreTUI(configPath, buildSig string) (bool, error) { page := buildRestoreWizardPage("Confirm restore", configPath, buildSig, content) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return false, err } if aborted { @@ -699,7 +713,7 @@ func confirmRestoreTUI(configPath, buildSig string) (bool, error) { return false, ErrRestoreAborted } // Second-stage explicit overwrite confirmation - ok, err := confirmOverwriteTUI(configPath, buildSig) + ok, err := confirmOverwriteTUI(ctx, configPath, buildSig) if err != nil { return false, err } @@ -709,7 +723,7 @@ func confirmRestoreTUI(configPath, buildSig string) (bool, error) { return true, nil } -func promptYesNoTUI(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { +func promptYesNoTUI(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { app := newTUIApp() var result bool var cancelled bool @@ -740,7 +754,8 @@ func promptYesNoTUI(title, configPath, buildSig, message, yesLabel, noLabel stri page := buildRestoreWizardPage(title, configPath, buildSig, content) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return false, err } if cancelled { @@ -809,10 +824,6 @@ func promptYesNoTUIWithCountdown(ctx context.Context, logger *logging.Logger, ti select { case <-stopCh: return - case <-ctx.Done(): - cancelled = true - app.Stop() - return case <-ticker.C: left := time.Until(deadline) if left <= 0 { @@ -827,7 +838,8 @@ func promptYesNoTUIWithCountdown(ctx context.Context, logger *logging.Logger, ti }() } - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return false, err } if timedOut { @@ -840,7 +852,7 @@ func promptYesNoTUIWithCountdown(ctx context.Context, logger *logging.Logger, ti return result, nil } -func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, nicRepair *nicRepairResult, diagnosticsDir, configPath, buildSig string) (bool, error) { +func promptNetworkCommitTUI(ctx context.Context, timeout time.Duration, health networkHealthReport, nicRepair *nicRepairResult, diagnosticsDir, configPath, buildSig string) (bool, error) { app := newTUIApp() var committed bool var cancelled bool @@ -873,7 +885,13 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n var b strings.Builder for _, check := range report.Checks { color := healthColor(check.Severity) - b.WriteString(fmt.Sprintf("- [%s]%s[white] %s: %s\n", color, check.Severity.String(), check.Name, check.Message)) + b.WriteString(fmt.Sprintf( + "- [%s]%s[white] %s: %s\n", + color, + check.Severity.String(), + tview.Escape(check.Name), + tview.Escape(check.Message), + )) } return strings.TrimRight(b.String(), "\n") } @@ -886,7 +904,7 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n return fmt.Sprintf("NIC repair: [green]APPLIED[white] (%d file(s))", len(r.ChangedFiles)) } if r.SkippedReason != "" { - return fmt.Sprintf("NIC repair: [yellow]SKIPPED[white] (%s)", r.SkippedReason) + return fmt.Sprintf("NIC repair: [yellow]SKIPPED[white] (%s)", tview.Escape(r.SkippedReason)) } return "" } @@ -897,7 +915,7 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n } var b strings.Builder for _, m := range r.AppliedNICMap { - b.WriteString(fmt.Sprintf("- %s -> %s\n", m.OldName, m.NewName)) + b.WriteString(fmt.Sprintf("- %s -> %s\n", tview.Escape(m.OldName), tview.Escape(m.NewName))) } return strings.TrimRight(b.String(), "\n") } @@ -918,7 +936,7 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n diagInfo := "" if strings.TrimSpace(diagnosticsDir) != "" { - diagInfo = fmt.Sprintf("\n\nDiagnostics saved under:\n%s", diagnosticsDir) + diagInfo = fmt.Sprintf("\n\nDiagnostics saved under:\n%s", tview.Escape(diagnosticsDir)) } infoText.SetText(fmt.Sprintf("Rollback in [yellow]%ds[white] (deadline: [yellow]%s[white]).\n\n%sNetwork health: [%s]%s[white]\n%s%s\n\nType COMMIT or press the button to keep the new network configuration.\nIf you do nothing, rollback will be automatic.", @@ -977,9 +995,11 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n } }() - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { close(stopCh) ticker.Stop() + <-done return false, err } close(stopCh) @@ -992,9 +1012,10 @@ func promptNetworkCommitTUI(timeout time.Duration, health networkHealthReport, n return committed, nil } -func confirmOverwriteTUI(configPath, buildSig string) (bool, error) { +func confirmOverwriteTUI(ctx context.Context, configPath, buildSig string) (bool, error) { message := "This operation will overwrite existing configuration files on this system.\n\nAre you sure you want to proceed with the restore?" return promptYesNoTUIFunc( + ctx, "Confirm overwrite", configPath, buildSig, @@ -1005,53 +1026,14 @@ func confirmOverwriteTUI(configPath, buildSig string) (bool, error) { } func buildRestoreWizardPage(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { - welcomeText := tview.NewTextView(). - SetText(fmt.Sprintf("ProxSave - By TIS24DEV\n%s\n", restoreWizardSubtitle)). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - navInstructions := tview.NewTextView(). - SetText("\n" + restoreNavText). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(content, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(fmt.Sprintf(" %s ", title)). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - return flex + return tui.BuildScreen(tui.ScreenSpec{ + Title: title, + HeaderText: fmt.Sprintf("ProxSave - By TIS24DEV\n%s\n", restoreWizardSubtitle), + NavText: restoreNavText, + ConfigPath: configPath, + BuildSig: buildSig, + TitleColor: tui.ProxmoxOrange, + BorderColor: tui.ProxmoxOrange, + BackgroundColor: tcell.ColorBlack, + }, content) } diff --git a/internal/orchestrator/restore_tui_simulation_test.go b/internal/orchestrator/restore_tui_simulation_test.go index 2ec843c3..519dd649 100644 --- a/internal/orchestrator/restore_tui_simulation_test.go +++ b/internal/orchestrator/restore_tui_simulation_test.go @@ -1,6 +1,7 @@ package orchestrator import ( + "context" "testing" "github.com/gdamore/tcell/v2" @@ -9,7 +10,7 @@ import ( func TestPromptYesNoTUI_YesReturnsTrue(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyEnter}) - ok, err := promptYesNoTUI("Title", "/tmp/config.env", "sig", "Message", "Yes", "No") + ok, err := promptYesNoTUI(context.Background(), "Title", "/tmp/config.env", "sig", "Message", "Yes", "No") if err != nil { t.Fatalf("promptYesNoTUI error: %v", err) } @@ -21,7 +22,7 @@ func TestPromptYesNoTUI_YesReturnsTrue(t *testing.T) { func TestPromptYesNoTUI_NoReturnsFalse(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyEnter}) - ok, err := promptYesNoTUI("Title", "/tmp/config.env", "sig", "Message", "Yes", "No") + ok, err := promptYesNoTUI(context.Background(), "Title", "/tmp/config.env", "sig", "Message", "Yes", "No") if err != nil { t.Fatalf("promptYesNoTUI error: %v", err) } @@ -40,7 +41,7 @@ func TestShowRestorePlanTUI_ContinueReturnsNil(t *testing.T) { {Name: "Alpha", Type: CategoryTypePVE, Description: "First", Paths: []string{"./etc/alpha"}}, }, } - if err := showRestorePlanTUI(cfg, "/tmp/config.env", "sig"); err != nil { + if err := showRestorePlanTUI(context.Background(), cfg, "/tmp/config.env", "sig"); err != nil { t.Fatalf("showRestorePlanTUI error: %v", err) } } @@ -55,21 +56,21 @@ func TestShowRestorePlanTUI_CancelReturnsAborted(t *testing.T) { {Name: "Alpha", Type: CategoryTypePVE, Description: "First", Paths: []string{"./etc/alpha"}}, }, } - err := showRestorePlanTUI(cfg, "/tmp/config.env", "sig") + err := showRestorePlanTUI(context.Background(), cfg, "/tmp/config.env", "sig") if err != ErrRestoreAborted { t.Fatalf("err=%v; want %v", err, ErrRestoreAborted) } } func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { return true, nil }) defer restore() withSimApp(t, []tcell.Key{tcell.KeyEnter}) - ok, err := confirmRestoreTUI("/tmp/config.env", "sig") + ok, err := confirmRestoreTUI(context.Background(), "/tmp/config.env", "sig") if err != nil { t.Fatalf("confirmRestoreTUI error: %v", err) } @@ -79,14 +80,14 @@ func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { } func TestConfirmRestoreTUI_OverwriteDeclinedReturnsFalse(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { return false, nil }) defer restore() withSimApp(t, []tcell.Key{tcell.KeyEnter}) - ok, err := confirmRestoreTUI("/tmp/config.env", "sig") + ok, err := confirmRestoreTUI(context.Background(), "/tmp/config.env", "sig") if err != nil { t.Fatalf("confirmRestoreTUI error: %v", err) } @@ -108,7 +109,7 @@ func TestSelectCategoriesTUI_SelectsAtLeastOne(t *testing.T) { tcell.KeyEnter, // submit }) - got, err := selectCategoriesTUI(available, SystemTypePVE, "/tmp/config.env", "sig") + got, err := selectCategoriesTUI(context.Background(), available, SystemTypePVE, "/tmp/config.env", "sig") if err != nil { t.Fatalf("selectCategoriesTUI error: %v", err) } @@ -123,7 +124,7 @@ func TestSelectCategoriesTUI_BackReturnsErrRestoreBackToMode(t *testing.T) { } withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyEnter}) - _, err := selectCategoriesTUI(available, SystemTypePVE, "/tmp/config.env", "sig") + _, err := selectCategoriesTUI(context.Background(), available, SystemTypePVE, "/tmp/config.env", "sig") if err != errRestoreBackToMode { t.Fatalf("err=%v; want %v", err, errRestoreBackToMode) } @@ -140,7 +141,7 @@ func TestSelectCategoriesTUI_CancelReturnsAborted(t *testing.T) { tcell.KeyEnter, }) - _, err := selectCategoriesTUI(available, SystemTypePVE, "/tmp/config.env", "sig") + _, err := selectCategoriesTUI(context.Background(), available, SystemTypePVE, "/tmp/config.env", "sig") if err != ErrRestoreAborted { t.Fatalf("err=%v; want %v", err, ErrRestoreAborted) } diff --git a/internal/orchestrator/restore_tui_test.go b/internal/orchestrator/restore_tui_test.go index d7f1523a..d93b321a 100644 --- a/internal/orchestrator/restore_tui_test.go +++ b/internal/orchestrator/restore_tui_test.go @@ -1,6 +1,7 @@ package orchestrator import ( + "context" "errors" "strings" "testing" @@ -84,7 +85,7 @@ func TestBuildRestoreWizardPageReturnsFlex(t *testing.T) { } func TestPromptCompatibilityTUIUsesWarningText(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if title != "Compatibility warning" { t.Fatalf("unexpected title %q", title) } @@ -98,14 +99,29 @@ func TestPromptCompatibilityTUIUsesWarningText(t *testing.T) { }) defer restore() - ok, err := promptCompatibilityTUI("cfg", "sig", errors.New("boom")) + ok, err := promptCompatibilityTUI(context.Background(), "cfg", "sig", errors.New("boom")) + if err != nil || !ok { + t.Fatalf("promptCompatibilityTUI returned %v, %v", ok, err) + } +} + +func TestPromptCompatibilityTUIEscapesBracketedWarningText(t *testing.T) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + if !strings.Contains(message, tview.Escape("bad [warning]")) { + t.Fatalf("expected escaped bracketed warning, got %q", message) + } + return true, nil + }) + defer restore() + + ok, err := promptCompatibilityTUI(context.Background(), "cfg", "sig", errors.New("bad [warning]")) if err != nil || !ok { t.Fatalf("promptCompatibilityTUI returned %v, %v", ok, err) } } func TestPromptContinueWithoutSafetyBackupTUI(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if title != "Safety backup failed" { t.Fatalf("unexpected title %q", title) } @@ -116,7 +132,25 @@ func TestPromptContinueWithoutSafetyBackupTUI(t *testing.T) { }) defer restore() - ok, err := promptContinueWithoutSafetyBackupTUI("cfg", "sig", errors.New("failure")) + ok, err := promptContinueWithoutSafetyBackupTUI(context.Background(), "cfg", "sig", errors.New("failure")) + if err != nil { + t.Fatalf("promptContinueWithoutSafetyBackupTUI error: %v", err) + } + if ok { + t.Fatalf("expected false decision") + } +} + +func TestPromptContinueWithoutSafetyBackupTUIEscapesBracketedCause(t *testing.T) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + if !strings.Contains(message, tview.Escape("bad [cause]")) { + t.Fatalf("expected escaped bracketed cause, got %q", message) + } + return false, nil + }) + defer restore() + + ok, err := promptContinueWithoutSafetyBackupTUI(context.Background(), "cfg", "sig", errors.New("bad [cause]")) if err != nil { t.Fatalf("promptContinueWithoutSafetyBackupTUI error: %v", err) } @@ -126,7 +160,7 @@ func TestPromptContinueWithoutSafetyBackupTUI(t *testing.T) { } func TestPromptContinueWithPBSServicesTUI(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if title != "PBS services running" { t.Fatalf("unexpected title %q", title) } @@ -137,14 +171,14 @@ func TestPromptContinueWithPBSServicesTUI(t *testing.T) { }) defer restore() - ok, err := promptContinueWithPBSServicesTUI("cfg", "sig") + ok, err := promptContinueWithPBSServicesTUI(context.Background(), "cfg", "sig") if err != nil || !ok { t.Fatalf("promptContinueWithPBSServicesTUI returned %v, %v", ok, err) } } func TestConfirmOverwriteTUI(t *testing.T) { - restore := stubPromptYesNo(func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if title != "Confirm overwrite" { t.Fatalf("unexpected title %q", title) } @@ -158,13 +192,13 @@ func TestConfirmOverwriteTUI(t *testing.T) { }) defer restore() - ok, err := confirmOverwriteTUI("cfg", "sig") + ok, err := confirmOverwriteTUI(context.Background(), "cfg", "sig") if err != nil || !ok { t.Fatalf("confirmOverwriteTUI returned %v, %v", ok, err) } } -func stubPromptYesNo(fn func(title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error)) func() { +func stubPromptYesNo(fn func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error)) func() { orig := promptYesNoTUIFunc promptYesNoTUIFunc = fn return func() { promptYesNoTUIFunc = orig } diff --git a/internal/orchestrator/restore_workflow_integration_test.go b/internal/orchestrator/restore_workflow_integration_test.go index 294c8d5c..0835aa2a 100644 --- a/internal/orchestrator/restore_workflow_integration_test.go +++ b/internal/orchestrator/restore_workflow_integration_test.go @@ -58,9 +58,7 @@ func TestRunSafeClusterApply_PveshNotFound(t *testing.T) { reader := bufio.NewReader(strings.NewReader("0\n")) // Force PATH empty so LookPath fails - origPath := os.Getenv("PATH") t.Setenv("PATH", "") - defer os.Setenv("PATH", origPath) if err := runSafeClusterApply(context.Background(), reader, t.TempDir(), logger); err != nil { t.Fatalf("expected nil when pvesh missing, got %v", err) diff --git a/internal/orchestrator/storage_adapter.go b/internal/orchestrator/storage_adapter.go index 6e473bf1..344d4ac7 100644 --- a/internal/orchestrator/storage_adapter.go +++ b/internal/orchestrator/storage_adapter.go @@ -252,6 +252,7 @@ func (s *StorageAdapter) applyStorageStats(storageStats *storage.StorageStats, r case storage.LocationPrimary: stats.LocalBackups = storageStats.TotalBackups stats.LocalFreeSpace = clampInt64ToUint64(storageStats.AvailableSpace) + stats.LocalUsedSpace = clampInt64ToUint64(storageStats.UsedSpace) stats.LocalTotalSpace = clampInt64ToUint64(storageStats.TotalSpace) // Populate retention info stats.LocalRetentionPolicy = retentionConfig.Policy @@ -273,6 +274,7 @@ func (s *StorageAdapter) applyStorageStats(storageStats *storage.StorageStats, r } stats.SecondaryBackups = storageStats.TotalBackups stats.SecondaryFreeSpace = clampInt64ToUint64(storageStats.AvailableSpace) + stats.SecondaryUsedSpace = clampInt64ToUint64(storageStats.UsedSpace) stats.SecondaryTotalSpace = clampInt64ToUint64(storageStats.TotalSpace) // Populate retention info stats.SecondaryRetentionPolicy = retentionConfig.Policy diff --git a/internal/orchestrator/storage_adapter_test.go b/internal/orchestrator/storage_adapter_test.go index 9f6a6deb..0b3360bb 100644 --- a/internal/orchestrator/storage_adapter_test.go +++ b/internal/orchestrator/storage_adapter_test.go @@ -275,6 +275,7 @@ func TestStorageAdapterSync_NonCriticalStoreErrorFinalizesErrorAndContinues(t *t return &storage.StorageStats{ TotalBackups: 3, AvailableSpace: 10, + UsedSpace: 7, TotalSpace: 20, }, nil }, @@ -298,6 +299,9 @@ func TestStorageAdapterSync_NonCriticalStoreErrorFinalizesErrorAndContinues(t *t if stats.SecondaryBackups != 3 { t.Fatalf("SecondaryBackups = %d; want 3", stats.SecondaryBackups) } + if stats.SecondaryUsedSpace != 7 { + t.Fatalf("SecondaryUsedSpace = %d; want 7", stats.SecondaryUsedSpace) + } if stats.SecondaryRetentionPolicy != "simple" { t.Fatalf("SecondaryRetentionPolicy = %q; want simple", stats.SecondaryRetentionPolicy) } @@ -325,6 +329,7 @@ func TestStorageAdapterSync_NonCriticalRetentionErrorFinalizesWarning(t *testing return &storage.StorageStats{ TotalBackups: 1, AvailableSpace: 5, + UsedSpace: 4, TotalSpace: 10, }, nil }, @@ -340,6 +345,9 @@ func TestStorageAdapterSync_NonCriticalRetentionErrorFinalizesWarning(t *testing if got := stats.LocalStatus; got != "warning" { t.Fatalf("LocalStatus = %q; want warning", got) } + if stats.LocalUsedSpace != 4 { + t.Fatalf("LocalUsedSpace = %d; want 4", stats.LocalUsedSpace) + } if stats.LocalRetentionPolicy != "simple" { t.Fatalf("LocalRetentionPolicy = %q; want simple", stats.LocalRetentionPolicy) } diff --git a/internal/orchestrator/telegram_setup_bootstrap.go b/internal/orchestrator/telegram_setup_bootstrap.go new file mode 100644 index 00000000..1e5429cf --- /dev/null +++ b/internal/orchestrator/telegram_setup_bootstrap.go @@ -0,0 +1,104 @@ +package orchestrator + +import ( + "os" + "strings" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/identity" +) + +const defaultTelegramServerAPIHost = "https://bot.tis24.it:1443" + +type TelegramSetupEligibility int + +const ( + TelegramSetupEligibilityUnknown TelegramSetupEligibility = iota + TelegramSetupEligibleCentralized + TelegramSetupSkipDisabled + TelegramSetupSkipConfigError + TelegramSetupSkipPersonalMode + TelegramSetupSkipIdentityUnavailable +) + +type TelegramSetupBootstrap struct { + Eligibility TelegramSetupEligibility + + ConfigLoaded bool + ConfigError string + + TelegramEnabled bool + TelegramMode string + ServerAPIHost string + + ServerID string + IdentityFile string + IdentityPersisted bool + IdentityDetectError string +} + +var ( + telegramSetupBootstrapLoadConfig = config.LoadConfig + telegramSetupBootstrapIdentityDetect = identity.Detect + telegramSetupBootstrapStat = os.Stat +) + +func BuildTelegramSetupBootstrap(configPath, baseDir string) (TelegramSetupBootstrap, error) { + state := TelegramSetupBootstrap{} + + cfg, err := telegramSetupBootstrapLoadConfig(configPath) + if err != nil { + state.Eligibility = TelegramSetupSkipConfigError + state.ConfigError = err.Error() + return state, nil + } + + state.ConfigLoaded = true + if cfg != nil { + state.TelegramEnabled = cfg.TelegramEnabled + state.TelegramMode = strings.ToLower(strings.TrimSpace(cfg.TelegramBotType)) + state.ServerAPIHost = strings.TrimSpace(cfg.TelegramServerAPIHost) + } + + if !state.TelegramEnabled { + state.Eligibility = TelegramSetupSkipDisabled + return state, nil + } + + if state.TelegramMode == "" { + state.TelegramMode = "centralized" + } + if state.ServerAPIHost == "" { + state.ServerAPIHost = defaultTelegramServerAPIHost + } + + if state.TelegramMode == "personal" { + state.Eligibility = TelegramSetupSkipPersonalMode + return state, nil + } + + info, err := telegramSetupBootstrapIdentityDetect(baseDir, nil) + if err != nil { + state.Eligibility = TelegramSetupSkipIdentityUnavailable + state.IdentityDetectError = err.Error() + return state, nil + } + + if info != nil { + state.ServerID = strings.TrimSpace(info.ServerID) + state.IdentityFile = strings.TrimSpace(info.IdentityFile) + if state.IdentityFile != "" { + if _, statErr := telegramSetupBootstrapStat(state.IdentityFile); statErr == nil { + state.IdentityPersisted = true + } + } + } + + if state.ServerID == "" { + state.Eligibility = TelegramSetupSkipIdentityUnavailable + return state, nil + } + + state.Eligibility = TelegramSetupEligibleCentralized + return state, nil +} diff --git a/internal/orchestrator/telegram_setup_bootstrap_test.go b/internal/orchestrator/telegram_setup_bootstrap_test.go new file mode 100644 index 00000000..ad655665 --- /dev/null +++ b/internal/orchestrator/telegram_setup_bootstrap_test.go @@ -0,0 +1,212 @@ +package orchestrator + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/identity" + "github.com/tis24dev/proxsave/internal/logging" +) + +func stubTelegramSetupBootstrapDeps(t *testing.T) { + t.Helper() + + origLoadConfig := telegramSetupBootstrapLoadConfig + origIdentityDetect := telegramSetupBootstrapIdentityDetect + origStat := telegramSetupBootstrapStat + + t.Cleanup(func() { + telegramSetupBootstrapLoadConfig = origLoadConfig + telegramSetupBootstrapIdentityDetect = origIdentityDetect + telegramSetupBootstrapStat = origStat + }) +} + +func TestBuildTelegramSetupBootstrap_ConfigLoadFailureSkips(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return nil, errors.New("parse failed") + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + t.Fatalf("identity detect should not run on config failure") + return nil, nil + } + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupSkipConfigError { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupSkipConfigError) + } + if state.ConfigError == "" { + t.Fatalf("expected ConfigError to be set") + } + if state.ConfigLoaded { + t.Fatalf("expected ConfigLoaded=false") + } +} + +func TestBuildTelegramSetupBootstrap_DisabledSkips(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return &config.Config{TelegramEnabled: false}, nil + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + t.Fatalf("identity detect should not run when telegram is disabled") + return nil, nil + } + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupSkipDisabled { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupSkipDisabled) + } + if state.TelegramEnabled { + t.Fatalf("expected TelegramEnabled=false") + } +} + +func TestBuildTelegramSetupBootstrap_PersonalModeSkips(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return &config.Config{ + TelegramEnabled: true, + TelegramBotType: " Personal ", + TelegramServerAPIHost: "", + }, nil + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + t.Fatalf("identity detect should not run in personal mode") + return nil, nil + } + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupSkipPersonalMode { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupSkipPersonalMode) + } + if state.TelegramMode != "personal" { + t.Fatalf("TelegramMode=%q, want personal", state.TelegramMode) + } + if state.ServerAPIHost != defaultTelegramServerAPIHost { + t.Fatalf("ServerAPIHost=%q, want %q", state.ServerAPIHost, defaultTelegramServerAPIHost) + } +} + +func TestBuildTelegramSetupBootstrap_IdentityErrorSkips(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return &config.Config{ + TelegramEnabled: true, + TelegramBotType: "centralized", + }, nil + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + return nil, errors.New("detect failed") + } + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupSkipIdentityUnavailable { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupSkipIdentityUnavailable) + } + if state.IdentityDetectError == "" { + t.Fatalf("expected IdentityDetectError to be set") + } + if state.ServerAPIHost != defaultTelegramServerAPIHost { + t.Fatalf("ServerAPIHost=%q, want %q", state.ServerAPIHost, defaultTelegramServerAPIHost) + } +} + +func TestBuildTelegramSetupBootstrap_EmptyServerIDSkips(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return &config.Config{ + TelegramEnabled: true, + TelegramBotType: "centralized", + TelegramServerAPIHost: "https://api.example.test", + }, nil + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + return &identity.Info{ServerID: " ", IdentityFile: " /tmp/id "}, nil + } + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupSkipIdentityUnavailable { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupSkipIdentityUnavailable) + } + if state.ServerID != "" { + t.Fatalf("ServerID=%q, want empty", state.ServerID) + } + if state.IdentityFile != "/tmp/id" { + t.Fatalf("IdentityFile=%q, want /tmp/id", state.IdentityFile) + } +} + +func TestBuildTelegramSetupBootstrap_EligibleCentralized(t *testing.T) { + stubTelegramSetupBootstrapDeps(t) + + identityFile := filepath.Join(t.TempDir(), ".server_identity") + if err := os.WriteFile(identityFile, []byte("id"), 0o600); err != nil { + t.Fatalf("write identity file: %v", err) + } + + telegramSetupBootstrapLoadConfig = func(path string) (*config.Config, error) { + return &config.Config{ + TelegramEnabled: true, + TelegramBotType: " ", + TelegramServerAPIHost: " https://api.example.test ", + }, nil + } + telegramSetupBootstrapIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { + return &identity.Info{ + ServerID: " 123456789 ", + IdentityFile: " " + identityFile + " ", + }, nil + } + telegramSetupBootstrapStat = os.Stat + + state, err := BuildTelegramSetupBootstrap("/fake/backup.env", t.TempDir()) + if err != nil { + t.Fatalf("BuildTelegramSetupBootstrap error: %v", err) + } + if state.Eligibility != TelegramSetupEligibleCentralized { + t.Fatalf("Eligibility=%v, want %v", state.Eligibility, TelegramSetupEligibleCentralized) + } + if !state.ConfigLoaded { + t.Fatalf("expected ConfigLoaded=true") + } + if state.TelegramMode != "centralized" { + t.Fatalf("TelegramMode=%q, want centralized", state.TelegramMode) + } + if state.ServerAPIHost != "https://api.example.test" { + t.Fatalf("ServerAPIHost=%q, want https://api.example.test", state.ServerAPIHost) + } + if state.ServerID != "123456789" { + t.Fatalf("ServerID=%q, want 123456789", state.ServerID) + } + if state.IdentityFile != identityFile { + t.Fatalf("IdentityFile=%q, want %q", state.IdentityFile, identityFile) + } + if !state.IdentityPersisted { + t.Fatalf("expected IdentityPersisted=true") + } +} diff --git a/internal/orchestrator/tui_screen_env.go b/internal/orchestrator/tui_screen_env.go new file mode 100644 index 00000000..879a5083 --- /dev/null +++ b/internal/orchestrator/tui_screen_env.go @@ -0,0 +1,24 @@ +package orchestrator + +import ( + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type tuiPageBuilder func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive + +type tuiScreenEnv struct { + configPath string + buildSig string + logger *logging.Logger + buildPage tuiPageBuilder +} + +func (e tuiScreenEnv) page(title string, content tview.Primitive) tview.Primitive { + buildPage := e.buildPage + if buildPage == nil { + buildPage = buildWizardPage + } + return buildPage(title, e.configPath, e.buildSig, content) +} diff --git a/internal/orchestrator/tui_simulation_test.go b/internal/orchestrator/tui_simulation_test.go index 27dd3d0a..b846286d 100644 --- a/internal/orchestrator/tui_simulation_test.go +++ b/internal/orchestrator/tui_simulation_test.go @@ -1,21 +1,27 @@ package orchestrator import ( + "context" + "errors" + "sync" "testing" "time" "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" "github.com/tis24dev/proxsave/internal/tui" ) +const simAppInitialDrawTimeout = 2 * time.Second + type simKey struct { Key tcell.Key R rune Mod tcell.ModMask } -func withSimAppSequence(t *testing.T, keys []simKey) { +func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} { t.Helper() orig := newTUIApp @@ -25,68 +31,300 @@ func withSimAppSequence(t *testing.T, keys []simKey) { } screen.SetSize(120, 40) + drawCh := make(chan struct{}, 8) + done := make(chan struct{}) + var injectOnce sync.Once + var injectWG sync.WaitGroup + newTUIApp = func() *tui.App { app := tui.NewApp() app.SetScreen(screen) + readyCh := make(chan struct{}) + var readyOnce sync.Once + app.SetAfterDrawFunc(func(screen tcell.Screen) { + readyOnce.Do(func() { + close(readyCh) + drawCh <- struct{}{} + }) + }) + + injectOnce.Do(func() { + injectWG.Add(1) + go func() { + defer injectWG.Done() + + timer := time.NewTimer(simAppInitialDrawTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + }() - go func() { - // Wait for app.Run() to start event processing. - time.Sleep(50 * time.Millisecond) - for _, k := range keys { - mod := k.Mod - if mod == 0 { - mod = tcell.ModNone + select { + case <-readyCh: + case <-done: + return + case <-timer.C: + return } - screen.InjectKey(k.Key, k.R, mod) - time.Sleep(10 * time.Millisecond) - } - }() + + for _, k := range keys { + mod := k.Mod + if mod == 0 { + mod = tcell.ModNone + } + select { + case <-done: + return + default: + } + screen.InjectKey(k.Key, k.R, mod) + } + }() + }) return app } t.Cleanup(func() { + close(done) + injectWG.Wait() newTUIApp = orig }) + + return drawCh } -func withSimApp(t *testing.T, keys []tcell.Key) { +func withSimApp(t *testing.T, keys []tcell.Key) <-chan struct{} { t.Helper() seq := make([]simKey, 0, len(keys)) for _, k := range keys { seq = append(seq, simKey{Key: k}) } - withSimAppSequence(t, seq) + return withSimAppSequence(t, seq) } func TestPromptOverwriteAction_SelectsOverwrite(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyEnter}) - got, err := promptOverwriteAction("/tmp/existing", "file", "", "/tmp/config.env", "sig") + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + decision, newPath, err := promptExistingPathDecisionTUI(context.Background(), ui.screenEnv(), "/tmp/existing", "file", "") + if err != nil { + t.Fatalf("promptExistingPathDecisionTUI error: %v", err) + } + if decision != PathDecisionOverwrite { + t.Fatalf("decision=%v; want %v", decision, PathDecisionOverwrite) + } + if newPath != "" { + t.Fatalf("newPath=%q; want empty", newPath) + } +} + +func TestPromptNewPathInput_ContinueReturnsEditedPath(t *testing.T) { + withSimAppSequence(t, []simKey{ + {Key: tcell.KeyRune, R: '/'}, + {Key: tcell.KeyRune, R: 'a'}, + {Key: tcell.KeyRune, R: 'l'}, + {Key: tcell.KeyRune, R: 't'}, + {Key: tcell.KeyTab}, + {Key: tcell.KeyEnter}, + }) + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + got, err := promptNewPathInputTUI(context.Background(), ui.screenEnv(), "/tmp/newpath") if err != nil { - t.Fatalf("promptOverwriteAction error: %v", err) + t.Fatalf("promptNewPathInputTUI error: %v", err) } - if got != pathActionOverwrite { - t.Fatalf("choice=%q; want %q", got, pathActionOverwrite) + if got != "/tmp/newpath/alt" { + t.Fatalf("path=%q; want %q", got, "/tmp/newpath/alt") } } -func TestPromptNewPathInput_ContinueReturnsDefault(t *testing.T) { - // Move focus to Continue button then submit. - withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyEnter}) +func TestPromptNewPathInputTUI_UsesProvidedBuilder(t *testing.T) { + withSimAppSequence(t, []simKey{ + {Key: tcell.KeyRune, R: '/'}, + {Key: tcell.KeyRune, R: 'a'}, + {Key: tcell.KeyRune, R: 'l'}, + {Key: tcell.KeyRune, R: 't'}, + {Key: tcell.KeyTab}, + {Key: tcell.KeyEnter}, + }) + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + builderCalls := 0 + var gotTitle, gotConfigPath, gotBuildSig string + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + gotTitle = title + gotConfigPath = configPath + gotBuildSig = buildSig + return buildRestoreWizardPage(title, configPath, buildSig, content) + } + + got, err := promptNewPathInputTUI(context.Background(), ui.screenEnv(), "/tmp/newpath") + if err != nil { + t.Fatalf("promptNewPathInputTUI error: %v", err) + } + if got != "/tmp/newpath/alt" { + t.Fatalf("path=%q; want %q", got, "/tmp/newpath/alt") + } + if builderCalls != 1 { + t.Fatalf("builderCalls=%d; want 1", builderCalls) + } + if gotTitle != "Choose destination path" { + t.Fatalf("title=%q; want %q", gotTitle, "Choose destination path") + } + if gotConfigPath != "/tmp/config.env" { + t.Fatalf("configPath=%q; want %q", gotConfigPath, "/tmp/config.env") + } + if gotBuildSig != "sig" { + t.Fatalf("buildSig=%q; want %q", gotBuildSig, "sig") + } +} + +func TestPromptExistingPathDecisionTUI_ContextCanceledWhileRunning(t *testing.T) { + drawCh := withSimAppSequence(t, nil) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-drawCh + cancel() + }() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, _, err := promptExistingPathDecisionTUI(ctx, ui.screenEnv(), "/tmp/existing", "file", "") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } +} + +func TestPromptExistingPathDecisionTUI_NewPathContextCanceledWhileRunning(t *testing.T) { + drawCh := withSimApp(t, []tcell.Key{tcell.KeyRight, tcell.KeyEnter}) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + // This flow produces two draw events: the first read waits for the initial + // promptExistingPathDecisionTUI dialog render, and the second waits for the + // secondary "new path" dialog opened after selecting that option. Cancel + // only after both drawCh reads so the test simulates context cancellation + // while the second dialog is already running. + <-drawCh + <-drawCh + cancel() + }() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, _, err := promptExistingPathDecisionTUI(ctx, ui.screenEnv(), "/tmp/existing", "file", "") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } +} + +func TestPromptDecryptSecretTUI_ContextCanceledWhileRunning(t *testing.T) { + drawCh := withSimAppSequence(t, nil) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-drawCh + cancel() + }() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, err := promptDecryptSecretTUI(ctx, ui.screenEnv(), "backup", "") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } +} + +func TestPromptExistingPathDecisionTUI_PassesBuilderToNestedPrompt(t *testing.T) { + withSimApp(t, []tcell.Key{tcell.KeyRight, tcell.KeyEnter}) + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + builderCalls := 0 + var gotTitles []string + var gotConfigPaths []string + var gotBuildSigs []string + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + gotTitles = append(gotTitles, title) + gotConfigPaths = append(gotConfigPaths, configPath) + gotBuildSigs = append(gotBuildSigs, buildSig) + return buildRestoreWizardPage(title, configPath, buildSig, content) + } + restore := stubTUINewPathInputPrompt(func(ctx context.Context, env tuiScreenEnv, defaultPath string) (string, error) { + if page := env.page("Spy", tview.NewBox()); page == nil { + t.Fatalf("expected non-nil page") + } + return "/tmp/existing/alt", nil + }) + defer restore() - got, err := promptNewPathInput("/tmp/newpath", "/tmp/config.env", "sig") + decision, newPath, err := promptExistingPathDecisionTUI(context.Background(), ui.screenEnv(), "/tmp/existing", "file", "") if err != nil { - t.Fatalf("promptNewPathInput error: %v", err) + t.Fatalf("promptExistingPathDecisionTUI error: %v", err) + } + if decision != PathDecisionNewPath { + t.Fatalf("decision=%v; want %v", decision, PathDecisionNewPath) + } + if newPath != "/tmp/existing/alt" { + t.Fatalf("newPath=%q; want %q", newPath, "/tmp/existing/alt") + } + if builderCalls != 2 { + t.Fatalf("builderCalls=%d; want 2", builderCalls) + } + if gotTitles[0] != "Destination path" || gotTitles[1] != "Spy" { + t.Fatalf("titles=%v; want %v", gotTitles, []string{"Destination path", "Spy"}) + } + for i, configPath := range gotConfigPaths { + if configPath != "/tmp/config.env" { + t.Fatalf("configPath[%d]=%q; want %q", i, configPath, "/tmp/config.env") + } + } + for i, buildSig := range gotBuildSigs { + if buildSig != "sig" { + t.Fatalf("buildSig[%d]=%q; want %q", i, buildSig, "sig") + } + } +} + +func TestPromptDecryptSecretTUI_UsesProvidedBuilder(t *testing.T) { + withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyTab, tcell.KeyEnter}) + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + builderCalls := 0 + var gotTitle, gotConfigPath, gotBuildSig string + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + gotTitle = title + gotConfigPath = configPath + gotBuildSig = buildSig + return buildRestoreWizardPage(title, configPath, buildSig, content) + } + + _, err := promptDecryptSecretTUI(context.Background(), ui.screenEnv(), "backup", "") + if err != ErrDecryptAborted { + t.Fatalf("err=%v; want %v", err, ErrDecryptAborted) + } + if builderCalls != 1 { + t.Fatalf("builderCalls=%d; want 1", builderCalls) + } + if gotTitle != "Decrypt key" { + t.Fatalf("title=%q; want %q", gotTitle, "Decrypt key") + } + if gotConfigPath != "/tmp/config.env" { + t.Fatalf("configPath=%q; want %q", gotConfigPath, "/tmp/config.env") } - if got != "/tmp/newpath" { - t.Fatalf("path=%q; want %q", got, "/tmp/newpath") + if gotBuildSig != "sig" { + t.Fatalf("buildSig=%q; want %q", gotBuildSig, "sig") } } func TestSelectRestoreModeTUI_SelectsStorage(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyDown, tcell.KeyEnter}) - mode, err := selectRestoreModeTUI(SystemTypePVE, "/tmp/config.env", "sig", "backup") + mode, err := selectRestoreModeTUI(context.Background(), SystemTypePVE, "/tmp/config.env", "sig", "backup") if err != nil { t.Fatalf("selectRestoreModeTUI error: %v", err) } @@ -98,7 +336,7 @@ func TestSelectRestoreModeTUI_SelectsStorage(t *testing.T) { func TestPromptClusterRestoreModeTUI_SelectsRecovery(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyDown, tcell.KeyEnter}) - choice, err := promptClusterRestoreModeTUI("/tmp/config.env", "sig") + choice, err := promptClusterRestoreModeTUI(context.Background(), "/tmp/config.env", "sig") if err != nil { t.Fatalf("promptClusterRestoreModeTUI error: %v", err) } @@ -111,7 +349,7 @@ func TestPromptClusterRestoreModeTUI_CancelAborts(t *testing.T) { // Switch focus to the Cancel button then submit. withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyEnter}) - _, err := promptClusterRestoreModeTUI("/tmp/config.env", "sig") + _, err := promptClusterRestoreModeTUI(context.Background(), "/tmp/config.env", "sig") if err == nil { t.Fatalf("expected abort error") } diff --git a/internal/orchestrator/workflow_ui_cli.go b/internal/orchestrator/workflow_ui_cli.go index 1d303c76..ec734558 100644 --- a/internal/orchestrator/workflow_ui_cli.go +++ b/internal/orchestrator/workflow_ui_cli.go @@ -133,8 +133,9 @@ func (u *cliWorkflowUI) ResolveExistingPath(ctx context.Context, path, descripti if err != nil { return PathDecisionCancel, "", err } - trimmed := strings.TrimSpace(newPath) - if trimmed == "" { + trimmed, err := validateDistinctNewPathInput(newPath, current) + if err != nil { + fmt.Println(err.Error()) continue } return PathDecisionNewPath, filepath.Clean(trimmed), nil diff --git a/internal/orchestrator/workflow_ui_cli_test.go b/internal/orchestrator/workflow_ui_cli_test.go new file mode 100644 index 00000000..bad0856e --- /dev/null +++ b/internal/orchestrator/workflow_ui_cli_test.go @@ -0,0 +1,84 @@ +package orchestrator + +import ( + "bufio" + "bytes" + "context" + "io" + "os" + "strings" + "testing" +) + +func captureCLIStdout(t *testing.T, fn func()) (captured string) { + t.Helper() + + oldStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe: %v", err) + } + os.Stdout = w + t.Cleanup(func() { + os.Stdout = oldStdout + }) + + var buf bytes.Buffer + done := make(chan struct{}) + go func() { + _, _ = io.Copy(&buf, r) + close(done) + }() + defer func() { + os.Stdout = oldStdout + _ = w.Close() + <-done + _ = r.Close() + captured = buf.String() + }() + + fn() + return +} + +func TestCLIWorkflowUIResolveExistingPath_RejectsEquivalentNormalizedPath(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("2\n/tmp/out/\n2\n /tmp/out/../alt \n")) + ui := newCLIWorkflowUI(reader, nil) + + var ( + decision ExistingPathDecision + newPath string + err error + ) + output := captureCLIStdout(t, func() { + decision, newPath, err = ui.ResolveExistingPath(context.Background(), "/tmp/out", "archive", "") + }) + if err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if decision != PathDecisionNewPath { + t.Fatalf("decision=%v, want %v", decision, PathDecisionNewPath) + } + if newPath != "/tmp/alt" { + t.Fatalf("newPath=%q, want %q", newPath, "/tmp/alt") + } + if !strings.Contains(output, "path must be different from existing path") { + t.Fatalf("expected validation message in output, got %q", output) + } +} + +func TestCLIWorkflowUIResolveExistingPath_EmptyPathRetriesUntilValid(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("2\n \n2\n/tmp/next\n")) + ui := newCLIWorkflowUI(reader, nil) + + decision, newPath, err := ui.ResolveExistingPath(context.Background(), "/tmp/out", "archive", "") + if err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if decision != PathDecisionNewPath { + t.Fatalf("decision=%v, want %v", decision, PathDecisionNewPath) + } + if newPath != "/tmp/next" { + t.Fatalf("newPath=%q, want %q", newPath, "/tmp/next") + } +} diff --git a/internal/orchestrator/workflow_ui_tui_decrypt.go b/internal/orchestrator/workflow_ui_tui_decrypt.go index 88c83b3f..cae0c3f4 100644 --- a/internal/orchestrator/workflow_ui_tui_decrypt.go +++ b/internal/orchestrator/workflow_ui_tui_decrypt.go @@ -18,7 +18,7 @@ type tuiWorkflowUI struct { configPath string buildSig string logger *logging.Logger - buildPage func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive + buildPage tuiPageBuilder selectedBackupSummary string } @@ -44,6 +44,15 @@ func newTUIRestoreWorkflowUI(configPath, buildSig string, logger *logging.Logger return ui } +func (u *tuiWorkflowUI) screenEnv() tuiScreenEnv { + return tuiScreenEnv{ + configPath: u.configPath, + buildSig: u.buildSig, + logger: u.logger, + buildPage: u.buildPage, + } +} + func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage string, run func(ctx context.Context, report ProgressReporter) error) error { if ctx == nil { ctx = context.Background() @@ -55,7 +64,7 @@ func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage strin app := newTUIApp() messageView := tview.NewTextView(). - SetText(strings.TrimSpace(initialMessage)). + SetText(tview.Escape(strings.TrimSpace(initialMessage))). SetTextAlign(tview.AlignCenter). SetTextColor(tcell.ColorWhite). SetDynamicColors(true) @@ -85,7 +94,7 @@ func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage strin return } app.QueueUpdateDraw(func() { - messageView.SetText(message) + messageView.SetText(tview.Escape(message)) }) } @@ -97,7 +106,8 @@ func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage strin }) }() - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(taskCtx); err != nil { cancel() <-done return err @@ -109,18 +119,18 @@ func (u *tuiWorkflowUI) RunTask(ctx context.Context, title, initialMessage strin } func (u *tuiWorkflowUI) ShowMessage(ctx context.Context, title, message string) error { - return u.showOKModal(title, message, tui.ProxmoxOrange) + return u.showOKModal(ctx, title, message, tui.ProxmoxOrange) } func (u *tuiWorkflowUI) ShowError(ctx context.Context, title, message string) error { - return u.showOKModal(title, fmt.Sprintf("%s %s", tui.SymbolError, message), tui.ErrorRed) + return u.showOKModal(ctx, title, fmt.Sprintf("%s %s", tui.SymbolError, message), tui.ErrorRed) } -func (u *tuiWorkflowUI) showOKModal(title, message string, borderColor tcell.Color) error { +func (u *tuiWorkflowUI) showOKModal(ctx context.Context, title, message string, borderColor tcell.Color) error { app := newTUIApp() modal := tview.NewModal(). - SetText(fmt.Sprintf("%s\n\n[yellow]Press ENTER to continue[white]", strings.TrimSpace(message))). + SetText(fmt.Sprintf("%s\n\n[yellow]Press ENTER to continue[white]", tview.Escape(strings.TrimSpace(message)))). AddButtons([]string{"OK"}). SetDoneFunc(func(buttonIndex int, buttonLabel string) { app.Stop() @@ -134,7 +144,8 @@ func (u *tuiWorkflowUI) showOKModal(title, message string, borderColor tcell.Col SetBackgroundColor(tcell.ColorBlack) page := u.buildPage(title, u.configPath, u.buildSig, modal) - return app.SetRoot(page, true).SetFocus(modal).Run() + app.SetRoot(page, true).SetFocus(modal) + return app.RunWithContext(ctx) } func (u *tuiWorkflowUI) SelectBackupSource(ctx context.Context, options []decryptPathOption) (decryptPathOption, error) { @@ -188,7 +199,8 @@ func (u *tuiWorkflowUI) SelectBackupSource(ctx context.Context, options []decryp page := u.buildPage("Select backup source", u.configPath, u.buildSig, form.Form) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return decryptPathOption{}, err } if aborted || strings.TrimSpace(selected.Path) == "" { @@ -323,7 +335,8 @@ func (u *tuiWorkflowUI) SelectBackupCandidate(ctx context.Context, candidates [] page := u.buildPage("Select backup", u.configPath, u.buildSig, form.Form) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return nil, err } if aborted || selected == nil { @@ -365,7 +378,8 @@ func (u *tuiWorkflowUI) PromptDestinationDir(ctx context.Context, defaultDir str page := u.buildPage("Destination directory", u.configPath, u.buildSig, form.Form) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return "", err } if cancelled { @@ -375,83 +389,22 @@ func (u *tuiWorkflowUI) PromptDestinationDir(ctx context.Context, defaultDir str } func (u *tuiWorkflowUI) ResolveExistingPath(ctx context.Context, path, description, failure string) (ExistingPathDecision, string, error) { - action, err := promptOverwriteActionFunc(path, description, failure, u.configPath, u.buildSig) + decision, newPath, err := tuiPromptExistingPathDecision(ctx, u.screenEnv(), path, description, failure) if err != nil { return PathDecisionCancel, "", err } - switch action { - case pathActionOverwrite: - return PathDecisionOverwrite, "", nil - case pathActionNew: - newPath, err := promptNewPathInputFunc(path, u.configPath, u.buildSig) - if err != nil { - return PathDecisionCancel, "", err - } - return PathDecisionNewPath, filepath.Clean(newPath), nil - default: - return PathDecisionCancel, "", ErrDecryptAborted + if decision != PathDecisionNewPath { + return decision, "", nil + } + trimmed := strings.TrimSpace(newPath) + if trimmed == "" { + return decision, "", nil } + return decision, filepath.Clean(trimmed), nil } func (u *tuiWorkflowUI) PromptDecryptSecret(ctx context.Context, displayName, previousError string) (string, error) { - app := newTUIApp() - var ( - secret string - cancelled bool - ) - - name := strings.TrimSpace(displayName) - if name == "" { - name = "selected backup" - } - - infoMessage := fmt.Sprintf("Provide the AGE secret key or passphrase used for [yellow]%s[white].", name) - if strings.TrimSpace(previousError) != "" { - infoMessage = fmt.Sprintf("%s\n\n[red]%s[white]", infoMessage, strings.TrimSpace(previousError)) - } - - infoText := tview.NewTextView(). - SetText(infoMessage). - SetWrap(true). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true) - - form := components.NewForm(app) - label := "Key or passphrase:" - form.AddPasswordField(label, 64) - form.SetOnSubmit(func(values map[string]string) error { - raw := strings.TrimSpace(values[label]) - if raw == "" { - return fmt.Errorf("key or passphrase cannot be empty") - } - if raw == "0" { - cancelled = true - return nil - } - secret = raw - return nil - }) - form.SetOnCancel(func() { - cancelled = true - }) - form.AddSubmitButton("Continue") - form.AddCancelButton("Cancel") - enableFormNavigation(form, nil) - - content := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(infoText, 0, 2, false). - AddItem(form.Form, 0, 1, true) - - page := u.buildPage("Decrypt key", u.configPath, u.buildSig, content) - form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { - return "", err - } - if cancelled { - return "", ErrDecryptAborted - } - return secret, nil + return tuiPromptDecryptSecret(ctx, u.screenEnv(), displayName, previousError) } func backupSummaryForUI(cand *decryptCandidate) string { diff --git a/internal/orchestrator/workflow_ui_tui_decrypt_prompts.go b/internal/orchestrator/workflow_ui_tui_decrypt_prompts.go new file mode 100644 index 00000000..792ef227 --- /dev/null +++ b/internal/orchestrator/workflow_ui_tui_decrypt_prompts.go @@ -0,0 +1,206 @@ +package orchestrator + +import ( + "context" + "fmt" + "path/filepath" + "strings" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/tui" + "github.com/tis24dev/proxsave/internal/tui/components" +) + +var ( + tuiPromptExistingPathDecision = promptExistingPathDecisionTUI + tuiPromptNewPathInput = promptNewPathInputTUI + tuiPromptDecryptSecret = promptDecryptSecretTUI +) + +func promptExistingPathDecisionTUI(ctx context.Context, env tuiScreenEnv, path, description, failureMessage string) (ExistingPathDecision, string, error) { + app := newTUIApp() + decision := PathDecisionCancel + + message := fmt.Sprintf( + "The %s [yellow]%s[white] already exists.\nSelect how you want to proceed.", + tview.Escape(description), + tview.Escape(path), + ) + if strings.TrimSpace(failureMessage) != "" { + message = fmt.Sprintf("%s\n\n[red]%s[white]", message, tview.Escape(strings.TrimSpace(failureMessage))) + } + message += "\n\n[yellow]Use ←→ or TAB to switch buttons | ENTER to confirm[white]" + + modal := tview.NewModal(). + SetText(message). + AddButtons([]string{"Overwrite", "Use different path", "Cancel"}). + SetDoneFunc(func(buttonIndex int, buttonLabel string) { + switch buttonLabel { + case "Overwrite": + decision = PathDecisionOverwrite + case "Use different path": + decision = PathDecisionNewPath + default: + decision = PathDecisionCancel + } + app.Stop() + }) + + modal.SetBorder(true). + SetTitle(" Existing file "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.WarningYellow). + SetBorderColor(tui.WarningYellow). + SetBackgroundColor(tcell.ColorBlack) + + page := env.page("Destination path", modal) + app.SetRoot(page, true).SetFocus(modal) + if err := app.RunWithContext(ctx); err != nil { + return PathDecisionCancel, "", err + } + if decision != PathDecisionNewPath { + return decision, "", nil + } + + newPath, err := tuiPromptNewPathInput(ctx, env, path) + if err != nil { + if err == ErrDecryptAborted { + return PathDecisionCancel, "", nil + } + return PathDecisionCancel, "", err + } + return PathDecisionNewPath, filepath.Clean(newPath), nil +} + +func promptNewPathInputTUI(ctx context.Context, env tuiScreenEnv, defaultPath string) (string, error) { + app := newTUIApp() + var newPath string + var cancelled bool + + form := components.NewForm(app) + label := "New path" + form.AddInputFieldWithValidation(label, defaultPath, 64, func(value string) error { + _, err := validateDistinctNewPathInput(value, defaultPath) + return err + }) + form.SetOnSubmit(func(values map[string]string) error { + trimmed, err := validateDistinctNewPathInput(values[label], defaultPath) + if err != nil { + return err + } + newPath = trimmed + return nil + }) + form.SetOnCancel(func() { + cancelled = true + }) + form.AddSubmitButton("Continue") + form.AddCancelButton("Cancel") + enableFormNavigation(form, nil) + + helper := tview.NewTextView(). + SetText("Provide a writable filesystem path for the decrypted files."). + SetWrap(true). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true) + + content := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(helper, 3, 0, false). + AddItem(form.Form, 0, 1, true) + + page := env.page("Choose destination path", content) + form.SetParentView(page) + + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { + return "", err + } + if cancelled { + return "", ErrDecryptAborted + } + return newPath, nil +} + +func validateDistinctNewPathInput(value, defaultPath string) (string, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "", fmt.Errorf("path cannot be empty") + } + + trimmedDefault := strings.TrimSpace(defaultPath) + if trimmedDefault != "" && filepath.Clean(trimmed) == filepath.Clean(trimmedDefault) { + return "", fmt.Errorf("path must be different from existing path") + } + + return trimmed, nil +} + +func promptDecryptSecretTUI(ctx context.Context, env tuiScreenEnv, displayName, previousError string) (string, error) { + app := newTUIApp() + var ( + secret string + cancelled bool + ) + + name := strings.TrimSpace(displayName) + if name == "" { + name = "selected backup" + } + + infoMessage := fmt.Sprintf( + "Provide the AGE secret key or passphrase used for [yellow]%s[white].\n\n"+ + "Enter [yellow]0[white] to exit or use [yellow]Cancel[white].", + tview.Escape(name), + ) + if strings.TrimSpace(previousError) != "" { + infoMessage = fmt.Sprintf("%s\n\n[red]%s[white]", infoMessage, tview.Escape(strings.TrimSpace(previousError))) + } + + infoText := tview.NewTextView(). + SetText(infoMessage). + SetWrap(true). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true) + + form := components.NewForm(app) + label := "Key or passphrase:" + form.AddPasswordField(label, 64) + form.SetOnSubmit(func(values map[string]string) error { + raw := values[label] + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return fmt.Errorf("key or passphrase cannot be empty") + } + if trimmed == "0" { + cancelled = true + return nil + } + secret = raw + return nil + }) + form.SetOnCancel(func() { + cancelled = true + }) + form.AddSubmitButton("Continue") + form.AddCancelButton("Cancel") + enableFormNavigation(form, nil) + + content := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(infoText, 0, 2, false). + AddItem(form.Form, 0, 1, true) + + page := env.page("Decrypt key", content) + form.SetParentView(page) + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { + return "", err + } + if cancelled { + return "", ErrDecryptAborted + } + return secret, nil +} diff --git a/internal/orchestrator/workflow_ui_tui_decrypt_test.go b/internal/orchestrator/workflow_ui_tui_decrypt_test.go new file mode 100644 index 00000000..6517594e --- /dev/null +++ b/internal/orchestrator/workflow_ui_tui_decrypt_test.go @@ -0,0 +1,254 @@ +package orchestrator + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" +) + +func stubTUIExistingPathDecisionPrompt(fn func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error)) func() { + orig := tuiPromptExistingPathDecision + tuiPromptExistingPathDecision = fn + return func() { tuiPromptExistingPathDecision = orig } +} + +func TestTUIWorkflowUIResolveExistingPath_Overwrite(t *testing.T) { + restore := stubTUIExistingPathDecisionPrompt(func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + if path != "/tmp/archive.tar" { + t.Fatalf("path=%q, want /tmp/archive.tar", path) + } + if description != "archive" { + t.Fatalf("description=%q, want archive", description) + } + if env.configPath != "/tmp/config.env" { + t.Fatalf("configPath=%q, want /tmp/config.env", env.configPath) + } + if env.buildSig != "sig" { + t.Fatalf("buildSig=%q, want sig", env.buildSig) + } + return PathDecisionOverwrite, "", nil + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + decision, newPath, err := ui.ResolveExistingPath(context.Background(), "/tmp/archive.tar", "archive", "") + if err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if decision != PathDecisionOverwrite { + t.Fatalf("decision=%v, want %v", decision, PathDecisionOverwrite) + } + if newPath != "" { + t.Fatalf("newPath=%q, want empty", newPath) + } +} + +func TestTUIWorkflowUIResolveExistingPath_NewPathIsCleaned(t *testing.T) { + restore := stubTUIExistingPathDecisionPrompt(func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + return PathDecisionNewPath, "/tmp/out/../out/final.tar", nil + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + decision, newPath, err := ui.ResolveExistingPath(context.Background(), "/tmp/archive.tar", "archive", "") + if err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if decision != PathDecisionNewPath { + t.Fatalf("decision=%v, want %v", decision, PathDecisionNewPath) + } + if newPath != filepath.Clean("/tmp/out/../out/final.tar") { + t.Fatalf("newPath=%q, want %q", newPath, filepath.Clean("/tmp/out/../out/final.tar")) + } +} + +func TestTUIWorkflowUIResolveExistingPath_WhitespaceNewPathStaysEmpty(t *testing.T) { + restore := stubTUIExistingPathDecisionPrompt(func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + return PathDecisionNewPath, " \t ", nil + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + decision, newPath, err := ui.ResolveExistingPath(context.Background(), "/tmp/archive.tar", "archive", "") + if err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if decision != PathDecisionNewPath { + t.Fatalf("decision=%v, want %v", decision, PathDecisionNewPath) + } + if newPath != "" { + t.Fatalf("newPath=%q, want empty", newPath) + } +} + +func TestTUIWorkflowUIResolveExistingPath_PropagatesError(t *testing.T) { + wantErr := errors.New("boom") + restore := stubTUIExistingPathDecisionPrompt(func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + return PathDecisionCancel, "", wantErr + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + if _, _, err := ui.ResolveExistingPath(context.Background(), "/tmp/archive.tar", "archive", ""); !errors.Is(err, wantErr) { + t.Fatalf("expected %v, got %v", wantErr, err) + } +} + +func TestTUIWorkflowUIResolveExistingPath_PassesContext(t *testing.T) { + called := false + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + restore := stubTUIExistingPathDecisionPrompt(func(gotCtx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + called = true + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + return PathDecisionOverwrite, "", nil + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + if _, _, err := ui.ResolveExistingPath(ctx, "/tmp/archive.tar", "archive", ""); err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if !called { + t.Fatalf("expected prompt to be called") + } +} + +func TestTUIRestoreWorkflowUIResolveExistingPath_PassesBuilder(t *testing.T) { + builderCalls := 0 + restore := stubTUIExistingPathDecisionPrompt(func(ctx context.Context, env tuiScreenEnv, path, description, failure string) (ExistingPathDecision, string, error) { + if page := env.page("Spy", tview.NewBox()); page == nil { + t.Fatalf("expected non-nil page") + } + return PathDecisionOverwrite, "", nil + }) + defer restore() + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + return tview.NewBox() + } + + if _, _, err := ui.ResolveExistingPath(context.Background(), "/tmp/archive.tar", "archive", ""); err != nil { + t.Fatalf("ResolveExistingPath error: %v", err) + } + if builderCalls != 1 { + t.Fatalf("builderCalls=%d, want 1", builderCalls) + } +} + +func stubTUIDecryptSecretPrompt(fn func(ctx context.Context, env tuiScreenEnv, displayName, previousError string) (string, error)) func() { + orig := tuiPromptDecryptSecret + tuiPromptDecryptSecret = fn + return func() { tuiPromptDecryptSecret = orig } +} + +func stubTUINewPathInputPrompt(fn func(ctx context.Context, env tuiScreenEnv, defaultPath string) (string, error)) func() { + orig := tuiPromptNewPathInput + tuiPromptNewPathInput = fn + return func() { tuiPromptNewPathInput = orig } +} + +func TestTUIWorkflowUIPromptDecryptSecret_PassesContext(t *testing.T) { + called := false + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + restore := stubTUIDecryptSecretPrompt(func(gotCtx context.Context, env tuiScreenEnv, displayName, previousError string) (string, error) { + called = true + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + return "secret", nil + }) + defer restore() + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + got, err := ui.PromptDecryptSecret(ctx, "archive", "") + if err != nil { + t.Fatalf("PromptDecryptSecret error: %v", err) + } + if got != "secret" { + t.Fatalf("secret=%q, want %q", got, "secret") + } + if !called { + t.Fatalf("expected prompt to be called") + } +} + +func TestTUIRestoreWorkflowUIPromptDecryptSecret_PassesBuilder(t *testing.T) { + builderCalls := 0 + restore := stubTUIDecryptSecretPrompt(func(ctx context.Context, env tuiScreenEnv, displayName, previousError string) (string, error) { + if page := env.page("Spy", tview.NewBox()); page == nil { + t.Fatalf("expected non-nil page") + } + return "secret", nil + }) + defer restore() + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + return tview.NewBox() + } + + got, err := ui.PromptDecryptSecret(context.Background(), "archive", "") + if err != nil { + t.Fatalf("PromptDecryptSecret error: %v", err) + } + if got != "secret" { + t.Fatalf("secret=%q, want %q", got, "secret") + } + if builderCalls != 1 { + t.Fatalf("builderCalls=%d, want 1", builderCalls) + } +} + +func TestTUIWorkflowUIPromptDestinationDir_ContinueReturnsCleanPath(t *testing.T) { + withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyEnter}) + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + got, err := ui.PromptDestinationDir(context.Background(), "/tmp/out/../out") + if err != nil { + t.Fatalf("PromptDestinationDir error: %v", err) + } + if got != "/tmp/out" { + t.Fatalf("destination=%q, want %q", got, "/tmp/out") + } +} + +func TestTUIWorkflowUIPromptDestinationDir_CancelReturnsAborted(t *testing.T) { + withSimApp(t, []tcell.Key{tcell.KeyTab, tcell.KeyTab, tcell.KeyEnter}) + + ui := newTUIWorkflowUI("/tmp/config.env", "sig", nil) + _, err := ui.PromptDestinationDir(context.Background(), "/tmp/out") + if !errors.Is(err, ErrDecryptAborted) { + t.Fatalf("err=%v, want %v", err, ErrDecryptAborted) + } +} + +func TestValidateDistinctNewPathInputRejectsEquivalentNormalizedPath(t *testing.T) { + _, err := validateDistinctNewPathInput("/tmp/out/", "/tmp/out") + if err == nil { + t.Fatalf("expected validation error") + } + if err.Error() != "path must be different from existing path" { + t.Fatalf("err=%q, want %q", err.Error(), "path must be different from existing path") + } +} + +func TestValidateDistinctNewPathInputAcceptsDifferentPath(t *testing.T) { + got, err := validateDistinctNewPathInput(" /tmp/out/alt ", "/tmp/out") + if err != nil { + t.Fatalf("validateDistinctNewPathInput error: %v", err) + } + if got != "/tmp/out/alt" { + t.Fatalf("path=%q, want %q", got, "/tmp/out/alt") + } +} diff --git a/internal/orchestrator/workflow_ui_tui_restore.go b/internal/orchestrator/workflow_ui_tui_restore.go index 350bd829..b9a18d47 100644 --- a/internal/orchestrator/workflow_ui_tui_restore.go +++ b/internal/orchestrator/workflow_ui_tui_restore.go @@ -14,31 +14,31 @@ import ( ) func (u *tuiWorkflowUI) SelectRestoreMode(ctx context.Context, systemType SystemType) (RestoreMode, error) { - return selectRestoreModeTUI(systemType, u.configPath, u.buildSig, strings.TrimSpace(u.selectedBackupSummary)) + return selectRestoreModeTUI(ctx, systemType, u.configPath, u.buildSig, strings.TrimSpace(u.selectedBackupSummary)) } func (u *tuiWorkflowUI) SelectCategories(ctx context.Context, available []Category, systemType SystemType) ([]Category, error) { - return selectCategoriesTUI(available, systemType, u.configPath, u.buildSig) + return selectCategoriesTUI(ctx, available, systemType, u.configPath, u.buildSig) } func (u *tuiWorkflowUI) SelectPBSRestoreBehavior(ctx context.Context) (PBSRestoreBehavior, error) { - return selectPBSRestoreBehaviorTUI(u.configPath, u.buildSig, strings.TrimSpace(u.selectedBackupSummary)) + return selectPBSRestoreBehaviorTUI(ctx, u.configPath, u.buildSig, strings.TrimSpace(u.selectedBackupSummary)) } func (u *tuiWorkflowUI) ShowRestorePlan(ctx context.Context, config *SelectiveRestoreConfig) error { - return showRestorePlanTUI(config, u.configPath, u.buildSig) + return showRestorePlanTUI(ctx, config, u.configPath, u.buildSig) } func (u *tuiWorkflowUI) ConfirmRestore(ctx context.Context) (bool, error) { - return confirmRestoreTUI(u.configPath, u.buildSig) + return confirmRestoreTUI(ctx, u.configPath, u.buildSig) } func (u *tuiWorkflowUI) ConfirmCompatibility(ctx context.Context, warning error) (bool, error) { - return promptCompatibilityTUI(u.configPath, u.buildSig, warning) + return promptCompatibilityTUI(ctx, u.configPath, u.buildSig, warning) } func (u *tuiWorkflowUI) SelectClusterRestoreMode(ctx context.Context) (ClusterRestoreMode, error) { - choice, err := promptClusterRestoreModeTUI(u.configPath, u.buildSig) + choice, err := promptClusterRestoreModeTUI(ctx, u.configPath, u.buildSig) if err != nil { return ClusterRestoreAbort, err } @@ -53,11 +53,11 @@ func (u *tuiWorkflowUI) SelectClusterRestoreMode(ctx context.Context) (ClusterRe } func (u *tuiWorkflowUI) ConfirmContinueWithoutSafetyBackup(ctx context.Context, cause error) (bool, error) { - return promptContinueWithoutSafetyBackupTUI(u.configPath, u.buildSig, cause) + return promptContinueWithoutSafetyBackupTUI(ctx, u.configPath, u.buildSig, cause) } func (u *tuiWorkflowUI) ConfirmContinueWithPBSServicesRunning(ctx context.Context) (bool, error) { - return promptContinueWithPBSServicesTUI(u.configPath, u.buildSig) + return promptContinueWithPBSServicesTUI(ctx, u.configPath, u.buildSig) } func (u *tuiWorkflowUI) ConfirmFstabMerge(ctx context.Context, title, message string, timeout time.Duration, defaultYes bool) (bool, error) { @@ -118,10 +118,11 @@ func (u *tuiWorkflowUI) SelectExportNode(ctx context.Context, exportRoot, curren form.AddCancelButton("Cancel") enableFormNavigation(form, nil) - page := buildRestoreWizardPage("Select export node", u.configPath, u.buildSig, form.Form) + page := u.buildPage("Select export node", u.configPath, u.buildSig, form.Form) form.SetParentView(page) - if err := app.SetRoot(page, true).SetFocus(form.Form).Run(); err != nil { + app.SetRoot(page, true).SetFocus(form.Form) + if err := app.RunWithContext(ctx); err != nil { return "", err } if cancelled { @@ -139,15 +140,15 @@ func (u *tuiWorkflowUI) ConfirmApplyVMConfigs(ctx context.Context, sourceNode, c } else { message = fmt.Sprintf("Found %d VM/CT configs for exported node %s.\nThey will be applied to current node %s.\n\nApply them via pvesh now?", count, sourceNode, currentNode) } - return promptYesNoTUIFunc("Apply VM/CT configs", u.configPath, u.buildSig, message, "Apply via API", "Skip") + return promptYesNoTUIFunc(ctx, "Apply VM/CT configs", u.configPath, u.buildSig, message, "Apply via API", "Skip") } func (u *tuiWorkflowUI) ConfirmApplyStorageCfg(ctx context.Context, storageCfgPath string) (bool, error) { message := fmt.Sprintf("Storage configuration found:\n\n%s\n\nApply storage.cfg via pvesh now?", strings.TrimSpace(storageCfgPath)) - return promptYesNoTUIFunc("Apply storage.cfg", u.configPath, u.buildSig, message, "Apply via API", "Skip") + return promptYesNoTUIFunc(ctx, "Apply storage.cfg", u.configPath, u.buildSig, message, "Apply via API", "Skip") } func (u *tuiWorkflowUI) ConfirmApplyDatacenterCfg(ctx context.Context, datacenterCfgPath string) (bool, error) { message := fmt.Sprintf("Datacenter configuration found:\n\n%s\n\nApply datacenter.cfg via pvesh now?", strings.TrimSpace(datacenterCfgPath)) - return promptYesNoTUIFunc("Apply datacenter.cfg", u.configPath, u.buildSig, message, "Apply via API", "Skip") + return promptYesNoTUIFunc(ctx, "Apply datacenter.cfg", u.configPath, u.buildSig, message, "Apply via API", "Skip") } diff --git a/internal/orchestrator/workflow_ui_tui_restore_test.go b/internal/orchestrator/workflow_ui_tui_restore_test.go new file mode 100644 index 00000000..c82d320a --- /dev/null +++ b/internal/orchestrator/workflow_ui_tui_restore_test.go @@ -0,0 +1,31 @@ +package orchestrator + +import ( + "context" + "testing" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" +) + +func TestTUIRestoreWorkflowUISelectExportNode_UsesBuildPage(t *testing.T) { + withSimApp(t, []tcell.Key{tcell.KeyEnter}) + + ui := newTUIRestoreWorkflowUI("/tmp/config.env", "sig", nil) + builderCalls := 0 + ui.buildPage = func(title, configPath, buildSig string, content tview.Primitive) tview.Primitive { + builderCalls++ + return buildRestoreWizardPage(title, configPath, buildSig, content) + } + + got, err := ui.SelectExportNode(context.Background(), t.TempDir(), "node0", []string{"node1"}) + if err != nil { + t.Fatalf("SelectExportNode error: %v", err) + } + if got != "node1" { + t.Fatalf("node=%q, want %q", got, "node1") + } + if builderCalls != 1 { + t.Fatalf("builderCalls=%d, want 1", builderCalls) + } +} diff --git a/internal/orchestrator/workflow_ui_tui_shared.go b/internal/orchestrator/workflow_ui_tui_shared.go new file mode 100644 index 00000000..2f8cd448 --- /dev/null +++ b/internal/orchestrator/workflow_ui_tui_shared.go @@ -0,0 +1,47 @@ +package orchestrator + +import ( + "github.com/gdamore/tcell/v2" + + "github.com/tis24dev/proxsave/internal/tui/components" +) + +func enableFormNavigation(form *components.Form, dropdownOpen *bool) { + if form == nil || form.Form == nil { + return + } + form.Form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event == nil { + return event + } + if dropdownOpen != nil && *dropdownOpen { + return event + } + + formItemIndex, buttonIndex := form.Form.GetFocusedItemIndex() + isOnButton := formItemIndex < 0 && buttonIndex >= 0 + isOnField := formItemIndex >= 0 + + if isOnButton { + switch event.Key() { + case tcell.KeyLeft, tcell.KeyUp: + return tcell.NewEventKey(tcell.KeyBacktab, 0, tcell.ModNone) + case tcell.KeyRight, tcell.KeyDown: + return tcell.NewEventKey(tcell.KeyTab, 0, tcell.ModNone) + } + } else if isOnField { + // If focused item is a ListFormItem, let it handle navigation internally. + if _, ok := form.Form.GetFormItem(formItemIndex).(*components.ListFormItem); ok { + return event + } + // For other form fields, convert arrows to tab navigation. + switch event.Key() { + case tcell.KeyUp: + return tcell.NewEventKey(tcell.KeyBacktab, 0, tcell.ModNone) + case tcell.KeyDown: + return tcell.NewEventKey(tcell.KeyTab, 0, tcell.ModNone) + } + } + return event + }) +} diff --git a/internal/safefs/safefs.go b/internal/safefs/safefs.go index 1ca2bed5..61c8c806 100644 --- a/internal/safefs/safefs.go +++ b/internal/safefs/safefs.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/fs" + "math" "os" "syscall" "time" @@ -172,3 +173,35 @@ func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.St return stat, err }) } + +// SpaceUsageFromStatfs converts statfs counters into total, user-available, and +// actually-used byte counts. "Available" tracks Bavail (space a non-root user can +// allocate), while "used" tracks Blocks-Bfree (space already consumed). +func SpaceUsageFromStatfs(stat syscall.Statfs_t) (totalBytes, availableBytes, usedBytes int64) { + totalBytes = statfsBlocksToBytes(stat.Blocks, stat.Bsize) + availableBytes = statfsBlocksToBytes(stat.Bavail, stat.Bsize) + + if stat.Blocks > stat.Bfree { + usedBytes = statfsBlocksToBytes(stat.Blocks-stat.Bfree, stat.Bsize) + } + if availableBytes > totalBytes { + availableBytes = totalBytes + } + if usedBytes > totalBytes { + usedBytes = totalBytes + } + + return totalBytes, availableBytes, usedBytes +} + +func statfsBlocksToBytes(blocks uint64, blockSize int64) int64 { + if blocks == 0 || blockSize <= 0 { + return 0 + } + + size := uint64(blockSize) + if blocks > uint64(math.MaxInt64)/size { + return math.MaxInt64 + } + return int64(blocks * size) +} diff --git a/internal/safefs/safefs_test.go b/internal/safefs/safefs_test.go index 27d87b82..97c0224b 100644 --- a/internal/safefs/safefs_test.go +++ b/internal/safefs/safefs_test.go @@ -3,6 +3,7 @@ package safefs import ( "context" "errors" + "math" "os" "sync/atomic" "syscall" @@ -153,6 +154,87 @@ func TestStatfs_ReturnsTimeoutError(t *testing.T) { } } +func TestSpaceUsageFromStatfsUsesBfreeForUsedSpace(t *testing.T) { + stat := syscall.Statfs_t{ + Blocks: 100, + Bfree: 20, + Bavail: 15, + Bsize: 4096, + } + + total, available, used := SpaceUsageFromStatfs(stat) + + if total != 100*4096 { + t.Fatalf("total = %d; want %d", total, 100*4096) + } + if available != 15*4096 { + t.Fatalf("available = %d; want %d", available, 15*4096) + } + if used != 80*4096 { + t.Fatalf("used = %d; want %d", used, 80*4096) + } + if used == total-available { + t.Fatalf("used should not be derived from Bavail when reserved blocks exist") + } +} + +func TestSpaceUsageFromStatfsClampsInconsistentCounters(t *testing.T) { + stat := syscall.Statfs_t{ + Blocks: 100, + Bfree: 150, + Bavail: 125, + Bsize: 1024, + } + + total, available, used := SpaceUsageFromStatfs(stat) + + if total != 100*1024 { + t.Fatalf("total = %d; want %d", total, 100*1024) + } + if available != total { + t.Fatalf("available = %d; want clamp to total %d", available, total) + } + if used != 0 { + t.Fatalf("used = %d; want 0", used) + } +} + +func TestSpaceUsageFromStatfsClampsNegativeByteCounts(t *testing.T) { + stat := syscall.Statfs_t{ + Blocks: 10, + Bfree: 2, + Bavail: 1, + Bsize: -4096, + } + + total, available, used := SpaceUsageFromStatfs(stat) + + if total != 0 || available != 0 || used != 0 { + t.Fatalf("negative byte counts should clamp to zero, got total=%d available=%d used=%d", total, available, used) + } +} + +func TestSpaceUsageFromStatfsSaturatesOverflowingProducts(t *testing.T) { + stat := syscall.Statfs_t{ + Blocks: 1<<63 - 1, + Bfree: 0, + Bavail: 1<<63 - 1, + Bsize: 4096, + } + + total, available, used := SpaceUsageFromStatfs(stat) + + if total != math.MaxInt64 { + t.Fatalf("total = %d; want %d", total, math.MaxInt64) + } + if available != math.MaxInt64 { + t.Fatalf("available = %d; want %d", available, math.MaxInt64) + } + if used != math.MaxInt64 { + t.Fatalf("used = %d; want %d", used, math.MaxInt64) + } +} + func TestStat_PropagatesContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() diff --git a/internal/security/security.go b/internal/security/security.go index 9edd76ad..caa80bb7 100644 --- a/internal/security/security.go +++ b/internal/security/security.go @@ -330,6 +330,18 @@ func (c *Checker) verifyBinaryIntegrity() { return } + info, err := os.Lstat(c.execPath) + if err != nil { + c.addError("Cannot stat executable %s: %v", c.execPath, err) + return + } + + if info.Mode()&os.ModeSymlink != 0 { + c.addError("Executable %s is a symlink", c.execPath) + return + } + + hashFile := c.execPath + ".md5" f, err := os.Open(c.execPath) if err != nil { c.addError("Cannot open executable %s: %v", c.execPath, err) @@ -337,31 +349,23 @@ func (c *Checker) verifyBinaryIntegrity() { } defer f.Close() - info, err := f.Stat() + openedInfo, err := f.Stat() if err != nil { - c.addError("Cannot stat executable %s: %v", c.execPath, err) + c.addError("Cannot stat opened executable %s: %v", c.execPath, err) return } - - if info.Mode()&os.ModeSymlink != 0 { - c.addError("Executable %s is a symlink", c.execPath) + if !os.SameFile(info, openedInfo) { + c.addError("Executable %s changed during integrity check; aborting", c.execPath) return } + openedInfo = c.ensureOwnershipAndPermFromFD(f, openedInfo, 0o700, fmt.Sprintf("Executable %s", c.execPath)) - c.ensureOwnershipAndPerm(c.execPath, info, 0o700, fmt.Sprintf("Executable %s", c.execPath)) - - hashFile := c.execPath + ".md5" currentHash, err := checksumReader(f) if err != nil { c.addWarning("Unable to calculate hash for %s: %v", c.execPath, err) return } - if _, err := f.Seek(0, io.SeekStart); err != nil { - c.addWarning("Unable to rewind file for %s: %v", c.execPath, err) - return - } - if _, err := os.Stat(hashFile); errors.Is(err, os.ErrNotExist) { if c.cfg.AutoUpdateHashes { if err := os.WriteFile(hashFile, []byte(currentHash), 0o600); err != nil { @@ -1038,6 +1042,58 @@ func (c *Checker) ensureOwnershipAndPerm(path string, info os.FileInfo, expected return info } +func (c *Checker) ensureOwnershipAndPermFromFD(f *os.File, info os.FileInfo, expectedPerm os.FileMode, description string) os.FileInfo { + if f == nil { + return nil + } + + path := f.Name() + if path == "" { + path = "" + } + + var err error + if info == nil { + info, err = f.Stat() + if err != nil { + c.addWarning("Cannot stat %s: %v", path, err) + return nil + } + } + + if expectedPerm != 0 { + if perm := info.Mode().Perm(); perm != expectedPerm { + c.bannerWarning(fmt.Sprintf("incorrect permissions on %s (current %o, expected %o)", path, perm, expectedPerm)) + if c.cfg.AutoFixPermissions { + if err := syscall.Fchmod(int(f.Fd()), uint32(expectedPerm)); err != nil { + c.addWarning("Failed to adjust permissions on %s: %v", path, err) + } else { + c.logger.Info("Adjusted permissions on %s to %o", path, expectedPerm) + info, _ = f.Stat() + } + } else { + c.addWarning("%s should have permissions %o (current %o)", description, expectedPerm, perm) + } + } + } + + if info != nil && !isOwnedByRoot(info) { + c.bannerWarning(fmt.Sprintf("incorrect ownership on %s (required root:root)", path)) + if c.cfg.AutoFixPermissions { + if err := syscall.Fchown(int(f.Fd()), 0, 0); err != nil { + c.addWarning("Failed to set ownership root:root on %s: %v", path, err) + } else { + c.logger.Info("Adjusted ownership on %s to root:root", path) + info, _ = f.Stat() + } + } else { + c.addWarning("%s should be owned by root:root", description) + } + } + + return info +} + var kernelProcessPrefixes = []string{ "kworker", "kthreadd", "kswapd", "rcu_", "migration", "watchdog", "ksoftirqd", "khugepaged", "kcompactd", "khubd", "kdevtmpfs", "netns", "writeback", "crypto", "bioset", "kblockd", diff --git a/internal/security/security_test.go b/internal/security/security_test.go index 09998d0e..8876062f 100644 --- a/internal/security/security_test.go +++ b/internal/security/security_test.go @@ -1251,6 +1251,41 @@ func TestEnsureOwnershipAndPermAutoFix(t *testing.T) { } } +func TestEnsureOwnershipAndPermFromFDAutoFix(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "testfile") + if err := os.WriteFile(testFile, []byte("test"), 0777); err != nil { + t.Fatal(err) + } + + f, err := os.Open(testFile) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + t.Fatal(err) + } + + checker := &Checker{ + logger: newSecurityTestLogger(), + cfg: &config.Config{AutoFixPermissions: true}, + result: &Result{}, + } + + checker.ensureOwnershipAndPermFromFD(f, info, 0600, "test file") + + refreshed, err := os.Stat(testFile) + if err != nil { + t.Fatal(err) + } + if refreshed.Mode().Perm() != 0600 { + t.Errorf("permissions should have been fixed to 0600, got %o", refreshed.Mode().Perm()) + } +} + func TestEnsureOwnershipAndPermSymlink(t *testing.T) { tmpDir := t.TempDir() targetFile := filepath.Join(tmpDir, "target") @@ -1531,8 +1566,6 @@ func TestVerifyBinaryIntegritySymlinkError(t *testing.T) { t.Fatal(err) } - // Note: The current implementation checks Mode()&os.ModeSymlink after os.Open - // which doesn't detect symlinks properly. This test documents the behavior. checker := &Checker{ logger: newSecurityTestLogger(), cfg: &config.Config{AutoUpdateHashes: true}, @@ -1542,8 +1575,12 @@ func TestVerifyBinaryIntegritySymlinkError(t *testing.T) { checker.verifyBinaryIntegrity() - // The function opens the file and then stats - symlink is followed by Open - // This is expected behavior given the current implementation + if !containsIssue(checker.result, "is a symlink") { + t.Fatalf("expected symlink error, issues=%+v", checker.result.Issues) + } + if _, err := os.Stat(symlinkFile + ".md5"); err == nil { + t.Fatal("hash file should not be created for symlink executable") + } } func TestVerifyBinaryIntegrityOpenError(t *testing.T) { @@ -1556,8 +1593,8 @@ func TestVerifyBinaryIntegrityOpenError(t *testing.T) { checker.verifyBinaryIntegrity() - if !containsIssue(checker.result, "Cannot open executable") { - t.Errorf("expected error about cannot open executable, got %+v", checker.result.Issues) + if !containsIssue(checker.result, "Cannot stat executable") { + t.Errorf("expected error about cannot stat executable, got %+v", checker.result.Issues) } } diff --git a/internal/storage/backup_files.go b/internal/storage/backup_files.go index 05d8ab61..2aaa0c4b 100644 --- a/internal/storage/backup_files.go +++ b/internal/storage/backup_files.go @@ -2,15 +2,24 @@ package storage import "strings" +const bundleSuffix = ".bundle.tar" + // trimBundleSuffix removes the .bundle.tar suffix from a path if present. // It returns the trimmed path and whether the suffix was removed. func trimBundleSuffix(path string) (string, bool) { - if strings.HasSuffix(path, ".bundle.tar") { - return strings.TrimSuffix(path, ".bundle.tar"), true + if strings.HasSuffix(path, bundleSuffix) { + return strings.TrimSuffix(path, bundleSuffix), true } return path, false } +// bundlePathFor returns the canonical bundle path for either a raw archive path +// or a path that already points to a bundle. +func bundlePathFor(path string) string { + base, _ := trimBundleSuffix(path) + return base + bundleSuffix +} + // buildBackupCandidatePaths returns the list of files that belong to a backup. // When includeBundle is true, both the bundle and the legacy single-file layout // are included so retention can clean up either form. @@ -29,8 +38,9 @@ func buildBackupCandidatePaths(base string, includeBundle bool) []string { files := make([]string, 0, 5) if includeBundle { - if add(base + ".bundle.tar") { - files = append(files, base+".bundle.tar") + bundlePath := bundlePathFor(base) + if add(bundlePath) { + files = append(files, bundlePath) } } candidates := []string{ diff --git a/internal/storage/cloud.go b/internal/storage/cloud.go index e7cd1898..17e07ed2 100644 --- a/internal/storage/cloud.go +++ b/internal/storage/cloud.go @@ -25,8 +25,17 @@ import ( const ( cloudUploadModeSequential = "sequential" cloudUploadModeParallel = "parallel" + cloudRetryBackoffMax = 30 * time.Second ) +var cloudRetryBackoffSchedule = [...]time.Duration{ + 2 * time.Second, + 4 * time.Second, + 8 * time.Second, + 16 * time.Second, + cloudRetryBackoffMax, +} + type CloudStorage struct { config *config.Config logger *logging.Logger @@ -37,6 +46,7 @@ type CloudStorage struct { parallelVerify bool execCommand func(ctx context.Context, name string, args ...string) ([]byte, error) lookPath func(string) (string, error) + waitForRetry func(context.Context, time.Duration) error sleep func(time.Duration) lastRet RetentionSummary remoteFilesMu sync.RWMutex @@ -111,6 +121,40 @@ func remoteBaseName(ref string) string { return path.Base(trimmed) } +// Bound exponential retry delays so large attempt counts stay safe and predictable. +func cloudRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + index := attempt - 1 + if index >= len(cloudRetryBackoffSchedule) { + return cloudRetryBackoffSchedule[len(cloudRetryBackoffSchedule)-1] + } + return cloudRetryBackoffSchedule[index] +} + +func waitForRetryContext(ctx context.Context, d time.Duration) error { + if d <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // NewCloudStorage creates a new cloud storage instance func NewCloudStorage(cfg *config.Config, logger *logging.Logger) (*CloudStorage, error) { // Normalize CloudRemote and CloudRemotePath into: @@ -143,6 +187,7 @@ func NewCloudStorage(cfg *config.Config, logger *logging.Logger) (*CloudStorage, parallelVerify: cfg.CloudParallelVerify, execCommand: defaultExecCommand, lookPath: exec.LookPath, + waitForRetry: waitForRetryContext, sleep: time.Sleep, }, nil } @@ -314,6 +359,10 @@ func (c *CloudStorage) checkRemoteAccessible(ctx context.Context) error { lastErr = err + if err := ctx.Err(); err != nil { + return err + } + // If the context timed out, wrap as timeout error if timeoutCtx.Err() == context.DeadlineExceeded { return &remoteCheckError{ @@ -324,11 +373,23 @@ func (c *CloudStorage) checkRemoteAccessible(ctx context.Context) error { } if attempt < maxAttempts { - // Exponential backoff: 2s, 4s, 8s, ... - waitTime := time.Duration(1< %s (timeout: %ds)", filename, - utils.FormatBytes(stat.Size()), + utils.FormatBytes(primaryStat.Size()), c.remoteLabel(), c.config.RcloneTimeoutOperation) c.logger.Debug("Cloud storage: upload retries=%d threads=%d bwlimit=%s", @@ -606,7 +686,7 @@ func (c *CloudStorage) Store(ctx context.Context, backupFile string, metadata *t tasks := make([]uploadTask, 0, 4) tasks = append(tasks, uploadTask{ - local: backupFile, + local: primaryFile, remote: remoteFile, verify: true, }) @@ -628,16 +708,6 @@ func (c *CloudStorage) Store(ctx context.Context, backupFile string, metadata *t verify: c.parallelVerify, }) } - } else { - // Upload bundle file - bundleFile := backupFile + ".bundle.tar" - if _, err := os.Stat(bundleFile); err == nil { - tasks = append(tasks, uploadTask{ - local: bundleFile, - remote: c.remotePathFor(filepath.Base(bundleFile)), - verify: c.parallelVerify, - }) - } } logging.DebugStep(c.logger, "cloud store", "upload tasks=%d mode=%s", len(tasks), c.uploadMode) @@ -723,11 +793,13 @@ func (c *CloudStorage) uploadWithRetry(ctx context.Context, localFile, remoteFil break } - // Wait before retry (exponential backoff) + // Keep retry delays bounded and avoid shift/multiplication overflow. if attempt < c.config.RcloneRetries { - waitTime := time.Duration(1<= 200*time.Millisecond { + t.Fatalf("waitForRetryContext() returned after %v, want prompt cancellation", elapsed) + } + + <-done +} + func TestCloudStorageDetectFilesystem_RcloneMissingReturnsRecoverableError(t *testing.T) { cfg := &config.Config{ CloudEnabled: true, @@ -260,7 +306,7 @@ func TestCloudStorageUploadWithRetryEventuallySucceeds(t *testing.T) { }, } cs.execCommand = queue.exec - cs.sleep = func(time.Duration) {} + cs.waitForRetry = func(context.Context, time.Duration) error { return nil } if err := cs.uploadWithRetry(context.Background(), "/tmp/local.tar", "remote:local.tar"); err != nil { t.Fatalf("uploadWithRetry() error = %v", err) @@ -270,6 +316,87 @@ func TestCloudStorageUploadWithRetryEventuallySucceeds(t *testing.T) { } } +func TestCloudStorageUploadWithRetryUsesCappedBackoff(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + RcloneRetries: 6, + RcloneTimeoutOperation: 5, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("copy failed 1")}, + {name: "rclone", err: errors.New("copy failed 2")}, + {name: "rclone", err: errors.New("copy failed 3")}, + {name: "rclone", err: errors.New("copy failed 4")}, + {name: "rclone", err: errors.New("copy failed 5")}, + {name: "rclone", out: "ok"}, + }, + } + cs.execCommand = queue.exec + + var waits []time.Duration + cs.waitForRetry = func(_ context.Context, d time.Duration) error { + waits = append(waits, d) + return nil + } + + if err := cs.uploadWithRetry(context.Background(), "/tmp/local.tar", "remote:local.tar"); err != nil { + t.Fatalf("uploadWithRetry() error = %v", err) + } + + wantWaits := []time.Duration{ + 2 * time.Second, + 4 * time.Second, + 8 * time.Second, + 16 * time.Second, + 30 * time.Second, + } + if len(waits) != len(wantWaits) { + t.Fatalf("sleep calls = %d, want %d", len(waits), len(wantWaits)) + } + for i := range wantWaits { + if waits[i] != wantWaits[i] { + t.Fatalf("sleep[%d] = %v, want %v", i, waits[i], wantWaits[i]) + } + } +} + +func TestCloudStorageUploadWithRetryReturnsContextErrorDuringBackoff(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + RcloneRetries: 3, + RcloneTimeoutOperation: 5, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("copy failed")}, + }, + } + cs.execCommand = queue.exec + + ctx, cancel := context.WithCancel(context.Background()) + cs.waitForRetry = func(ctx context.Context, _ time.Duration) error { + cancel() + return ctx.Err() + } + + err := cs.uploadWithRetry(ctx, "/tmp/local.tar", "remote:local.tar") + if !errors.Is(err, context.Canceled) { + t.Fatalf("uploadWithRetry() error = %v, want context.Canceled", err) + } + if len(queue.calls) != 1 { + t.Fatalf("expected 1 upload attempt before cancellation, got %d", len(queue.calls)) + } +} + func TestCloudStorageListParsesBackups(t *testing.T) { cfg := &config.Config{ CloudEnabled: true, @@ -432,6 +559,77 @@ func TestCloudStorageStoreUploadsWithRemotePrefix(t *testing.T) { } } +func TestCloudStorageStorePrefersBundleWhenPresent(t *testing.T) { + tmpDir := t.TempDir() + backupFile := filepath.Join(tmpDir, "pbs1-backup.tar.zst") + bundleFile := bundlePathFor(backupFile) + writeTestFile(t, backupFile, "primary") + writeTestFile(t, bundleFile, "bundle") + + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + BundleAssociatedFiles: true, + RcloneRetries: 1, + RcloneTimeoutOperation: 10, + } + + cs := newCloudStorageForTest(cfg) + cs.sleep = func(time.Duration) {} + remoteFile := cs.remotePathFor(filepath.Base(bundleFile)) + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", args: []string{"copyto", "--progress", "--stats", "10s", bundleFile, remoteFile}}, + {name: "rclone", args: []string{"lsl", remoteFile}, out: "6 2025-11-13 10:00:00 pbs1-backup.tar.zst.bundle.tar"}, + {name: "rclone", args: []string{"lsl", "remote:"}, out: "6 2025-11-13 10:00:00 pbs1-backup.tar.zst.bundle.tar"}, + }, + } + cs.execCommand = queue.exec + + if err := cs.Store(context.Background(), backupFile, nil); err != nil { + t.Fatalf("Store() error = %v", err) + } + if len(queue.calls) != 3 { + t.Fatalf("expected 3 rclone calls, got %d", len(queue.calls)) + } +} + +func TestCloudStorageStoreBundleInputSkipsDoubleBundleUpload(t *testing.T) { + tmpDir := t.TempDir() + bundleFile := filepath.Join(tmpDir, "pbs1-backup.tar.zst.bundle.tar") + writeTestFile(t, bundleFile, "bundle") + writeTestFile(t, bundleFile+".bundle.tar", "decoy") + + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + BundleAssociatedFiles: true, + RcloneRetries: 1, + RcloneTimeoutOperation: 10, + } + + cs := newCloudStorageForTest(cfg) + cs.sleep = func(time.Duration) {} + remoteFile := cs.remotePathFor(filepath.Base(bundleFile)) + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", args: []string{"copyto", "--progress", "--stats", "10s", bundleFile, remoteFile}}, + {name: "rclone", args: []string{"lsl", remoteFile}, out: "6 2025-11-13 10:00:00 pbs1-backup.tar.zst.bundle.tar"}, + {name: "rclone", args: []string{"lsl", "remote:"}, out: "6 2025-11-13 10:00:00 pbs1-backup.tar.zst.bundle.tar"}, + }, + } + cs.execCommand = queue.exec + + if err := cs.Store(context.Background(), bundleFile, nil); err != nil { + t.Fatalf("Store() error = %v", err) + } + if len(queue.calls) != 3 { + t.Fatalf("expected 3 rclone calls, got %d", len(queue.calls)) + } +} + func TestCloudStorageStorePrimaryFailure(t *testing.T) { tmpDir := t.TempDir() backupFile := filepath.Join(tmpDir, "pbs1-backup.tar.zst") @@ -844,7 +1042,7 @@ func TestRemoteDirRef(t *testing.T) { } } -func TestCloudStorageApplyGFSRetentionDeletesMarkedBackups(t *testing.T) { +func TestCloudStorageApplyGFSRetentionKeepsMinimumDailyBackup(t *testing.T) { cfg := &config.Config{ CloudEnabled: true, CloudRemote: "remote", @@ -862,7 +1060,6 @@ func TestCloudStorageApplyGFSRetentionDeletesMarkedBackups(t *testing.T) { queue := &commandQueue{ t: t, queue: []queuedResponse{ - {name: "rclone", args: []string{"deletefile", "remote:alpha-backup.tar.zst"}}, {name: "rclone", args: []string{"deletefile", "remote:beta-backup.tar.zst"}}, }, } @@ -879,15 +1076,15 @@ func TestCloudStorageApplyGFSRetentionDeletesMarkedBackups(t *testing.T) { if err != nil { t.Fatalf("applyGFSRetention() error = %v", err) } - if deleted != 2 { - t.Fatalf("applyGFSRetention() deleted = %d, want 2", deleted) + if deleted != 1 { + t.Fatalf("applyGFSRetention() deleted = %d, want 1", deleted) } - if len(queue.calls) != 2 { - t.Fatalf("expected 2 delete commands, got %d", len(queue.calls)) + if len(queue.calls) != 1 { + t.Fatalf("expected 1 delete command, got %d", len(queue.calls)) } summary := cs.LastRetentionSummary() - if summary.BackupsDeleted != 2 || summary.BackupsRemaining != 0 { + if summary.BackupsDeleted != 1 || summary.BackupsRemaining != 1 { t.Fatalf("unexpected retention summary: %+v", summary) } } @@ -995,7 +1192,7 @@ func TestCloudStorageCheckWithTimeoutNoFallback(t *testing.T) { }, } cs.execCommand = queue.exec - cs.sleep = func(time.Duration) {} // Disable sleep for fast tests + cs.waitForRetry = func(context.Context, time.Duration) error { return nil } // Disable retry wait for fast tests err := cs.checkRemoteAccessible(context.Background()) if err == nil { @@ -1030,7 +1227,12 @@ func TestCloudStorageCheckWithNetworkErrorNoFallback(t *testing.T) { }, } cs.execCommand = queue.exec - cs.sleep = func(time.Duration) {} // Disable sleep for fast tests + + var waits []time.Duration + cs.waitForRetry = func(_ context.Context, d time.Duration) error { + waits = append(waits, d) + return nil + } err := cs.checkRemoteAccessible(context.Background()) if err == nil { @@ -1041,6 +1243,81 @@ func TestCloudStorageCheckWithNetworkErrorNoFallback(t *testing.T) { if len(queue.calls) != 3 { t.Fatalf("expected 3 rclone calls (lsf per attempt), got %d", len(queue.calls)) } + if len(waits) != 2 { + t.Fatalf("expected 2 backoff waits, got %d", len(waits)) + } + if waits[0] != 2*time.Second || waits[1] != 4*time.Second { + t.Fatalf("unexpected backoff waits: got %v", waits) + } +} + +func TestCloudStorageCheckRemoteAccessibleReturnsContextErrorDuringBackoff(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + CloudWriteHealthCheck: false, + RcloneTimeoutConnection: 30, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("exit 1"), out: "dial tcp: connection refused"}, + }, + } + cs.execCommand = queue.exec + + ctx, cancel := context.WithCancel(context.Background()) + cs.waitForRetry = func(ctx context.Context, _ time.Duration) error { + cancel() + return ctx.Err() + } + + err := cs.checkRemoteAccessible(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("checkRemoteAccessible() error = %v, want context.Canceled", err) + } + if len(queue.calls) != 1 { + t.Fatalf("expected 1 remote check attempt before cancellation, got %d", len(queue.calls)) + } +} + +func TestCloudStorageCheckRemoteAccessiblePropagatesParentDeadline(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + CloudWriteHealthCheck: false, + RcloneTimeoutConnection: 30, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("exit 1"), out: "dial tcp: connection refused"}, + }, + } + cs.execCommand = queue.exec + cs.waitForRetry = func(ctx context.Context, _ time.Duration) error { + <-ctx.Done() + return ctx.Err() + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := cs.checkRemoteAccessible(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("checkRemoteAccessible() error = %v, want context.DeadlineExceeded", err) + } + var rcErr *remoteCheckError + if errors.As(err, &rcErr) { + t.Fatalf("checkRemoteAccessible() error = %v, want parent deadline error, not remoteCheckError", err) + } + if len(queue.calls) != 1 { + t.Fatalf("expected 1 remote check attempt before parent deadline, got %d", len(queue.calls)) + } } // Test backward compatibility: CLOUD_WRITE_HEALTHCHECK=true skips list check @@ -1109,7 +1386,7 @@ func TestCloudStorageCheckBothListAndWriteFail(t *testing.T) { }, } cs.execCommand = queue.exec - cs.sleep = func(time.Duration) {} // Disable sleep for fast tests + cs.waitForRetry = func(context.Context, time.Duration) error { return nil } // Disable retry wait for fast tests err := cs.checkRemoteAccessible(context.Background()) if err == nil { diff --git a/internal/storage/local.go b/internal/storage/local.go index c7664462..b9cc03f9 100644 --- a/internal/storage/local.go +++ b/internal/storage/local.go @@ -16,6 +16,7 @@ import ( "github.com/tis24dev/proxsave/internal/backup" "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/safefs" "github.com/tis24dev/proxsave/internal/types" ) @@ -460,6 +461,7 @@ func (l *LocalStorage) ApplyRetention(ctx context.Context, config RetentionConfi // applyGFSRetention applies GFS (Grandfather-Father-Son) retention policy func (l *LocalStorage) applyGFSRetention(ctx context.Context, backups []*types.BackupMetadata, config RetentionConfig) (int, error) { + config = EffectiveGFSRetentionConfig(config) l.logger.Debug("Applying GFS retention policy (daily=%d, weekly=%d, monthly=%d, yearly=%d)", config.Daily, config.Weekly, config.Monthly, config.Yearly) @@ -661,20 +663,9 @@ func (l *LocalStorage) GetStats(ctx context.Context) (stats *StorageStats, err e // Get available/total space using statfs var stat syscall.Statfs_t if err := syscall.Statfs(l.basePath, &stat); err == nil { - available := int64(stat.Bavail) * int64(stat.Bsize) - total := int64(stat.Blocks) * int64(stat.Bsize) - if available < 0 { - available = 0 - } - if total < 0 { - total = 0 - } + total, available, used := safefs.SpaceUsageFromStatfs(stat) stats.AvailableSpace = available stats.TotalSpace = total - used := total - available - if used < 0 { - used = 0 - } stats.UsedSpace = used } diff --git a/internal/storage/retention.go b/internal/storage/retention.go index 2939bb85..38bbf3b1 100644 --- a/internal/storage/retention.go +++ b/internal/storage/retention.go @@ -73,6 +73,7 @@ func ClassifyBackupsGFS(backups []*types.BackupMetadata, config RetentionConfig) if len(backups) == 0 { return make(map[*types.BackupMetadata]RetentionCategory) } + config = EffectiveGFSRetentionConfig(config) // Sort by timestamp descending (newest first) sort.Slice(backups, func(i, j int) bool { @@ -87,20 +88,15 @@ func ClassifyBackupsGFS(backups []*types.BackupMetadata, config RetentionConfig) // 1. DAILY: Keep the last N backups (newest first) dailyLimit := config.Daily - if dailyLimit < 0 { - dailyLimit = 0 - } dailyCount := 0 dailyCutIndex := len(backups) - if dailyLimit > 0 { - for i, b := range backups { - if dailyCount >= dailyLimit { - dailyCutIndex = i - break - } - classification[b] = CategoryDaily - dailyCount++ + for i, b := range backups { + if dailyCount >= dailyLimit { + dailyCutIndex = i + break } + classification[b] = CategoryDaily + dailyCount++ } if dailyCount < dailyLimit { dailyCutIndex = len(backups) diff --git a/internal/storage/retention_normalize.go b/internal/storage/retention_normalize.go index fa86cf3b..517e7741 100644 --- a/internal/storage/retention_normalize.go +++ b/internal/storage/retention_normalize.go @@ -4,6 +4,18 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +// EffectiveGFSRetentionConfig returns the effective GFS configuration without side effects. +// It applies the same value normalization used by GFS retention execution paths, but does not log. +// Callers are responsible for invoking it only for configurations that should use GFS semantics. +func EffectiveGFSRetentionConfig(cfg RetentionConfig) RetentionConfig { + effective := cfg + if effective.Daily <= 0 { + effective.Daily = 1 + } + + return effective +} + // NormalizeGFSRetentionConfig applies the required adjustments to the GFS configuration // before running retention. Currently: // - ensures the DAILY tier is at least 1 (minimum accepted value) @@ -14,12 +26,11 @@ func NormalizeGFSRetentionConfig(logger *logging.Logger, backendName string, cfg return cfg } - effective := cfg - if effective.Daily <= 0 { + effective := EffectiveGFSRetentionConfig(cfg) + if effective.Daily != cfg.Daily { if logger != nil { logger.Info("%s: RETENTION_DAILY is %d or not set, enforcing minimum of 1 daily backup", backendName, cfg.Daily) } - effective.Daily = 1 } return effective diff --git a/internal/storage/retention_test.go b/internal/storage/retention_test.go index 73337fad..d761f2fe 100644 --- a/internal/storage/retention_test.go +++ b/internal/storage/retention_test.go @@ -192,24 +192,38 @@ func TestClassifyBackupsGFS_DailyOnly(t *testing.T) { func TestClassifyBackupsGFS_ZeroDaily(t *testing.T) { now := time.Now() backups := []*types.BackupMetadata{ - {Timestamp: now.Add(-24 * time.Hour)}, - {Timestamp: now.Add(-48 * time.Hour)}, + {Timestamp: now.Add(-1 * time.Hour)}, + {Timestamp: now.Add(-8 * 24 * time.Hour)}, + {Timestamp: now.Add(-15 * 24 * time.Hour)}, } config := RetentionConfig{ Daily: 0, Weekly: 1, Monthly: 0, - Yearly: 0, + Yearly: -1, } classification := ClassifyBackupsGFS(backups, config) stats := GetRetentionStats(classification) - // With Daily=0, backups should go to weekly/monthly/yearly or delete - if stats[CategoryDaily] != 0 { - t.Errorf("Expected 0 daily backups with Daily=0, got %d", stats[CategoryDaily]) + // In GFS, Daily=0 is normalized to 1 to ensure the current period is protected. + if stats[CategoryDaily] != 1 { + t.Errorf("Expected 1 daily backup with Daily=0, got %d", stats[CategoryDaily]) + } + if stats[CategoryWeekly] != 1 { + t.Errorf("Expected 1 weekly backup with Daily=0, got %d", stats[CategoryWeekly]) + } + if stats[CategoryDelete] != 1 { + t.Errorf("Expected 1 backup marked for deletion with Daily=0, got %d", stats[CategoryDelete]) + } + + if classification[backups[0]] != CategoryDaily { + t.Errorf("Expected newest backup to be daily with Daily=0, got %v", classification[backups[0]]) + } + if classification[backups[1]] != CategoryWeekly { + t.Errorf("Expected previous-week backup to be weekly with Daily=0, got %v", classification[backups[1]]) } } @@ -439,7 +453,7 @@ func TestClassifyBackupsGFS_NegativeDaily(t *testing.T) { } config := RetentionConfig{ - Daily: -5, // Negative should be treated as 0 + Daily: -5, // Negative values are normalized to the minimum daily tier. Weekly: 0, Monthly: 0, Yearly: -1, // Disable yearly retention so older-year backups aren't implicitly kept. @@ -449,12 +463,12 @@ func TestClassifyBackupsGFS_NegativeDaily(t *testing.T) { stats := GetRetentionStats(classification) - if stats[CategoryDaily] != 0 { - t.Errorf("Expected 0 daily backups with negative daily config, got %d", stats[CategoryDaily]) + if stats[CategoryDaily] != 1 { + t.Errorf("Expected 1 daily backup with negative daily config, got %d", stats[CategoryDaily]) } - if stats[CategoryDelete] != 2 { - t.Errorf("Expected all backups marked for deletion, got %d", stats[CategoryDelete]) + if stats[CategoryDelete] != 1 { + t.Errorf("Expected 1 backup marked for deletion, got %d", stats[CategoryDelete]) } } diff --git a/internal/storage/secondary.go b/internal/storage/secondary.go index ca5a669a..b05143e2 100644 --- a/internal/storage/secondary.go +++ b/internal/storage/secondary.go @@ -13,6 +13,7 @@ import ( "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/safefs" "github.com/tis24dev/proxsave/internal/types" "github.com/tis24dev/proxsave/pkg/utils" ) @@ -107,14 +108,20 @@ func (s *SecondaryStorage) Store(ctx context.Context, backupFile string, metadat return err } + bundleEnabled := s.config != nil && s.config.BundleAssociatedFiles + sourceFile := backupFile + if bundleEnabled { + sourceFile = bundlePathFor(sourceFile) + } + // Verify source file exists - if _, err := os.Stat(backupFile); err != nil { - s.logger.Debug("Secondary storage: source file %s not found", backupFile) - s.logger.Warning("WARNING: Secondary storage - backup file not found: %s: %v", backupFile, err) + if _, err := os.Stat(sourceFile); err != nil { + s.logger.Debug("Secondary storage: source file %s not found", sourceFile) + s.logger.Warning("WARNING: Secondary storage - backup file not found: %s: %v", sourceFile, err) return &StorageError{ Location: LocationSecondary, Operation: "store", - Path: backupFile, + Path: sourceFile, Err: fmt.Errorf("source file not found: %w", err), IsCritical: false, Recoverable: false, @@ -136,18 +143,18 @@ func (s *SecondaryStorage) Store(ctx context.Context, backupFile string, metadat } // Determine destination filename - destFile := filepath.Join(s.basePath, filepath.Base(backupFile)) + destFile := filepath.Join(s.basePath, filepath.Base(sourceFile)) s.logger.Debug("Secondary Storage: Start copy...") - s.logger.Debug("Copying backup to secondary storage: %s -> %s", filepath.Base(backupFile), s.basePath) + s.logger.Debug("Copying backup to secondary storage: %s -> %s", filepath.Base(sourceFile), s.basePath) - if err := s.copyFile(ctx, backupFile, destFile); err != nil { - s.logger.Warning("WARNING: Secondary Storage: File copy failed for %s: %v", filepath.Base(backupFile), err) + if err := s.copyFile(ctx, sourceFile, destFile); err != nil { + s.logger.Warning("WARNING: Secondary Storage: File copy failed for %s: %v", filepath.Base(sourceFile), err) s.logger.Warning("WARNING: Secondary Storage: Backup not saved to %s", s.basePath) return &StorageError{ Location: LocationSecondary, Operation: "store", - Path: backupFile, + Path: sourceFile, Err: fmt.Errorf("copy failed: %w", err), IsCritical: false, Recoverable: true, @@ -155,7 +162,7 @@ func (s *SecondaryStorage) Store(ctx context.Context, backupFile string, metadat } // Copy associated files if not bundled - if !s.config.BundleAssociatedFiles { + if !bundleEnabled { associatedFiles := []string{ backupFile + ".sha256", backupFile + ".metadata", @@ -181,16 +188,6 @@ func (s *SecondaryStorage) Store(ctx context.Context, backupFile string, metadat s.logger.Warning("WARNING: Secondary Storage: %d associated file(s) failed to copy: %v", len(failedAssoc), failedAssoc) } - } else { - // Copy bundle file - bundleFile := backupFile + ".bundle.tar" - if _, err := os.Stat(bundleFile); err == nil { - destBundle := filepath.Join(s.basePath, filepath.Base(bundleFile)) - if err := s.copyFile(ctx, bundleFile, destBundle); err != nil { - s.logger.Warning("WARNING: Secondary Storage: Failed to copy bundle %s: %v", - filepath.Base(bundleFile), err) - } - } } // Set permissions on destination (best effort) @@ -527,6 +524,7 @@ func (s *SecondaryStorage) ApplyRetention(ctx context.Context, config RetentionC // applyGFSRetention applies GFS (Grandfather-Father-Son) retention policy func (s *SecondaryStorage) applyGFSRetention(ctx context.Context, backups []*types.BackupMetadata, config RetentionConfig) (int, error) { + config = EffectiveGFSRetentionConfig(config) s.logger.Debug("Applying GFS retention policy (daily=%d, weekly=%d, monthly=%d, yearly=%d)", config.Daily, config.Weekly, config.Monthly, config.Yearly) @@ -727,20 +725,9 @@ func (s *SecondaryStorage) GetStats(ctx context.Context) (stats *StorageStats, e // Get available/total space using statfs var stat syscall.Statfs_t if err := syscall.Statfs(s.basePath, &stat); err == nil { - available := int64(stat.Bavail) * int64(stat.Bsize) - total := int64(stat.Blocks) * int64(stat.Bsize) - if available < 0 { - available = 0 - } - if total < 0 { - total = 0 - } + total, available, used := safefs.SpaceUsageFromStatfs(stat) stats.AvailableSpace = available stats.TotalSpace = total - used := total - available - if used < 0 { - used = 0 - } stats.UsedSpace = used } diff --git a/internal/storage/secondary_test.go b/internal/storage/secondary_test.go index 19b15fc8..22ce64bf 100644 --- a/internal/storage/secondary_test.go +++ b/internal/storage/secondary_test.go @@ -342,7 +342,7 @@ func TestSecondaryStorage_Store_AssociatedCopyFailuresAreNonFatal(t *testing.T) } } -func TestSecondaryStorage_Store_BundleCopyFailureIsNonFatal(t *testing.T) { +func TestSecondaryStorage_Store_BundleSourceCopyFailureReturnsRecoverableError(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) srcDir := t.TempDir() destDir := t.TempDir() @@ -367,15 +367,46 @@ func TestSecondaryStorage_Store_BundleCopyFailureIsNonFatal(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - if err := storage.Store(context.Background(), backupFile, &types.BackupMetadata{}); err != nil { - t.Fatalf("Store() error = %v; want nil (non-fatal bundle failure)", err) + err := storage.Store(context.Background(), backupFile, &types.BackupMetadata{}) + if err == nil { + t.Fatalf("expected error") + } + var se *StorageError + if !errors.As(err, &se) { + t.Fatalf("expected StorageError, got %T: %v", err, err) + } + if !se.Recoverable { + t.Fatalf("expected recoverable StorageError, got %+v", se) } - if _, err := os.Stat(filepath.Join(destDir, filepath.Base(backupFile))); err != nil { - t.Fatalf("expected backup to be copied: %v", err) + if _, err := os.Stat(filepath.Join(destDir, filepath.Base(backupFile))); !os.IsNotExist(err) { + t.Fatalf("raw backup should not be copied when bundling is enabled, err=%v", err) } if _, err := os.Stat(filepath.Join(destDir, filepath.Base(bundleDir))); !os.IsNotExist(err) { - t.Fatalf("expected bundle not to be copied due to forced failure, err=%v", err) + t.Fatalf("bundle should not be copied due to forced source failure, err=%v", err) + } +} + +func TestSecondaryStorage_Store_NilConfigTreatsSourceAsUnbundled(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + srcDir := t.TempDir() + destDir := t.TempDir() + storage := &SecondaryStorage{ + logger: logger, + basePath: destDir, + } + + backupFile := filepath.Join(srcDir, "node-backup-20240102-030405.tar.zst") + if err := os.WriteFile(backupFile, []byte("data"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + if err := storage.Store(context.Background(), backupFile, &types.BackupMetadata{}); err != nil { + t.Fatalf("Store() error = %v", err) + } + + if _, err := os.Stat(filepath.Join(destDir, filepath.Base(backupFile))); err != nil { + t.Fatalf("expected raw backup to be copied, err=%v", err) } } diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go index 439e82e2..3234d1de 100644 --- a/internal/storage/storage_test.go +++ b/internal/storage/storage_test.go @@ -48,6 +48,66 @@ func TestNormalizeGFSRetentionConfigEnforcesDailyMinimum(t *testing.T) { } } +func TestNormalizeGFSRetentionConfigLeavesNonGFSUnchanged(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + var buf bytes.Buffer + logger.SetOutput(&buf) + + cfg := RetentionConfig{ + Policy: "simple", + MaxBackups: 7, + Daily: 0, + Weekly: 4, + } + + effective := NormalizeGFSRetentionConfig(logger, "Test Storage", cfg) + + if effective != cfg { + t.Fatalf("NormalizeGFSRetentionConfig() = %+v; want %+v", effective, cfg) + } + if buf.Len() != 0 { + t.Fatalf("expected no log output for non-GFS policy, got: %s", buf.String()) + } +} + +func TestNormalizeGFSRetentionConfigDoesNotLogWhenDailyAlreadyValid(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + var buf bytes.Buffer + logger.SetOutput(&buf) + + cfg := RetentionConfig{ + Policy: "gfs", + Daily: 3, + Weekly: 4, + } + + effective := NormalizeGFSRetentionConfig(logger, "Test Storage", cfg) + + if effective != cfg { + t.Fatalf("NormalizeGFSRetentionConfig() = %+v; want %+v", effective, cfg) + } + if buf.Len() != 0 { + t.Fatalf("expected no log output when daily is already valid, got: %s", buf.String()) + } +} + +func TestEffectiveGFSRetentionConfigEnforcesDailyMinimumWithoutLogging(t *testing.T) { + cfg := RetentionConfig{ + Policy: "gfs", + Daily: 0, + Weekly: 4, + } + + effective := EffectiveGFSRetentionConfig(cfg) + + if effective.Daily != 1 { + t.Fatalf("EffectiveGFSRetentionConfig() Daily = %d; want 1", effective.Daily) + } + if effective.Weekly != cfg.Weekly { + t.Fatalf("EffectiveGFSRetentionConfig() Weekly = %d; want %d", effective.Weekly, cfg.Weekly) + } +} + func TestLocalStorageListSkipsAssociatedFilesAndSortsByTimestamp(t *testing.T) { t.Parallel() @@ -1407,20 +1467,72 @@ func TestSecondaryStorageStoreHandlesBundles(t *testing.T) { } destBackup := filepath.Join(destDir, filepath.Base(backupFile)) - if _, err := os.Stat(destBackup); err != nil { - t.Fatalf("expected backup to be copied: %v", err) + if _, err := os.Stat(destBackup); !os.IsNotExist(err) { + t.Fatalf("raw backup should not be copied when bundling is enabled, err=%v", err) } destBundle := filepath.Join(destDir, filepath.Base(backupFile)+".bundle.tar") if _, err := os.Stat(destBundle); err != nil { t.Fatalf("expected bundle to be copied: %v", err) } + if data, err := os.ReadFile(destBundle); err != nil { + t.Fatalf("read copied bundle: %v", err) + } else if string(data) != "bundle" { + t.Fatalf("copied bundle = %q, want %q", string(data), "bundle") + } if _, err := os.Stat(filepath.Join(destDir, filepath.Base(backupFile)+".metadata")); !os.IsNotExist(err) { t.Fatalf("metadata should not be copied when bundling is enabled, err=%v", err) } } +func TestSecondaryStorageStoreBundleInputSkipsDoubleBundleCopy(t *testing.T) { + t.Parallel() + + srcDir := t.TempDir() + destDir := t.TempDir() + cfg := &config.Config{ + SecondaryEnabled: true, + SecondaryPath: destDir, + BundleAssociatedFiles: true, + } + storage := newSecondaryStorageForTest(t, cfg) + + bundleFile := filepath.Join(srcDir, "node-bundle-backup-20240202-020202.tar.zst.bundle.tar") + if err := os.WriteFile(bundleFile, []byte("bundle"), 0o600); err != nil { + t.Fatalf("write bundle: %v", err) + } + doubleBundle := bundleFile + ".bundle.tar" + if err := os.WriteFile(doubleBundle, []byte("decoy"), 0o600); err != nil { + t.Fatalf("write double bundle decoy: %v", err) + } + + if err := storage.Store(context.Background(), bundleFile, &types.BackupMetadata{}); err != nil { + t.Fatalf("Store() error = %v", err) + } + + destBundle := filepath.Join(destDir, filepath.Base(bundleFile)) + if _, err := os.Stat(destBundle); err != nil { + t.Fatalf("expected bundle to be copied: %v", err) + } + originalBundleData, err := os.ReadFile(bundleFile) + if err != nil { + t.Fatalf("read original bundle: %v", err) + } + copiedBundleData, err := os.ReadFile(destBundle) + if err != nil { + t.Fatalf("read copied bundle: %v", err) + } + if string(copiedBundleData) != string(originalBundleData) { + t.Fatalf("copied bundle contents = %q, want %q", string(copiedBundleData), string(originalBundleData)) + } + + destDoubleBundle := filepath.Join(destDir, filepath.Base(doubleBundle)) + if _, err := os.Stat(destDoubleBundle); !os.IsNotExist(err) { + t.Fatalf("double bundle decoy should not be copied, err=%v", err) + } +} + func TestSecondaryStorageStoreHonorsContextCancellation(t *testing.T) { t.Parallel() @@ -1526,8 +1638,11 @@ func TestSecondaryStorageGetStatsIncludesFilesystemInfo(t *testing.T) { if stats.TotalSpace == 0 || stats.AvailableSpace == 0 { t.Fatalf("expected filesystem stats to be populated (TotalSpace=%d, AvailableSpace=%d)", stats.TotalSpace, stats.AvailableSpace) } - if stats.UsedSpace != stats.TotalSpace-stats.AvailableSpace { - t.Fatalf("UsedSpace mismatch: got %d want %d", stats.UsedSpace, stats.TotalSpace-stats.AvailableSpace) + if stats.AvailableSpace > stats.TotalSpace { + t.Fatalf("AvailableSpace = %d, should not exceed TotalSpace = %d", stats.AvailableSpace, stats.TotalSpace) + } + if stats.UsedSpace < 0 || stats.UsedSpace > stats.TotalSpace { + t.Fatalf("UsedSpace = %d, should be within [0, %d]", stats.UsedSpace, stats.TotalSpace) } } diff --git a/internal/testutil/age_setup_ui_stub.go b/internal/testutil/age_setup_ui_stub.go new file mode 100644 index 00000000..1ddf37bd --- /dev/null +++ b/internal/testutil/age_setup_ui_stub.go @@ -0,0 +1,48 @@ +package testutil + +import ( + "context" + "errors" +) + +var ErrAgeSetupUIStubAborted = errors.New("age setup ui stub aborted") + +// AgeSetupUIStub is a reusable scripted UI double for age recipient setup flows. +type AgeSetupUIStub[T any] struct { + Overwrite bool + Drafts []*T + AddMore []bool + AbortErr error + + OverwriteCalls int + CollectCalls int + AddCalls int +} + +func (u *AgeSetupUIStub[T]) ConfirmOverwriteExistingRecipient(ctx context.Context, recipientPath string) (bool, error) { + u.OverwriteCalls++ + return u.Overwrite, nil +} + +func (u *AgeSetupUIStub[T]) CollectRecipientDraft(ctx context.Context, recipientPath string) (*T, error) { + u.CollectCalls++ + if len(u.Drafts) == 0 { + if u.AbortErr != nil { + return nil, u.AbortErr + } + return nil, ErrAgeSetupUIStubAborted + } + draft := u.Drafts[0] + u.Drafts = u.Drafts[1:] + return draft, nil +} + +func (u *AgeSetupUIStub[T]) ConfirmAddAnotherRecipient(ctx context.Context, currentCount int) (bool, error) { + u.AddCalls++ + if len(u.AddMore) == 0 { + return false, nil + } + next := u.AddMore[0] + u.AddMore = u.AddMore[1:] + return next, nil +} diff --git a/internal/tui/abort_context_test.go b/internal/tui/abort_context_test.go index d0e775d4..93778c1c 100644 --- a/internal/tui/abort_context_test.go +++ b/internal/tui/abort_context_test.go @@ -2,12 +2,35 @@ package tui import ( "context" + "errors" + "sync" "testing" "time" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) +func newSimulationApp(t *testing.T) (*App, tcell.SimulationScreen, <-chan struct{}) { + t.Helper() + screen := tcell.NewSimulationScreen("UTF-8") + if err := screen.Init(); err != nil { + t.Fatalf("screen.Init: %v", err) + } + + app := NewApp() + started := make(chan struct{}) + var startedOnce sync.Once + app.SetAfterDrawFunc(func(screen tcell.Screen) { + startedOnce.Do(func() { + close(started) + }) + }) + app.SetScreen(screen) + app.SetRoot(tview.NewBox(), true) + return app, screen, started +} + func TestSetAbortContext_GetAbortContextRoundTrip(t *testing.T) { SetAbortContext(nil) if got := getAbortContext(); got != nil { @@ -93,6 +116,143 @@ func TestAppStop_DelegatesToEmbeddedApplication(t *testing.T) { app.Stop() } +func TestAppRunWithContext_CanceledBeforeRun(t *testing.T) { + app := &App{Application: tview.NewApplication()} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if err := app.RunWithContext(ctx); !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v want %v", err, context.Canceled) + } +} + +func TestAppRunWithContext_NilReceiverReturnsNil(t *testing.T) { + var app *App + if err := app.RunWithContext(context.Background()); err != nil { + t.Fatalf("err=%v want nil", err) + } +} + +func TestAppRunWithContext_NilContextRunsUntilStopped(t *testing.T) { + app, _, started := newSimulationApp(t) + done := make(chan error, 1) + + go func() { + done <- app.RunWithContext(nil) + }() + + select { + case err := <-done: + t.Fatalf("RunWithContext(nil) returned before app started: %v", err) + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for app to start") + } + + select { + case err := <-done: + t.Fatalf("RunWithContext(nil) returned before Stop: %v", err) + case <-time.After(50 * time.Millisecond): + } + + app.Stop() + + select { + case err := <-done: + if err != nil { + t.Fatalf("err=%v want nil", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for RunWithContext(nil) to return after Stop") + } +} + +func TestAppRunWithContext_ReturnsNilWhenStoppedWithoutCancellation(t *testing.T) { + app, _, started := newSimulationApp(t) + done := make(chan error, 1) + + go func() { + done <- app.RunWithContext(context.Background()) + }() + + select { + case err := <-done: + t.Fatalf("RunWithContext(context.Background()) returned before app started: %v", err) + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for app to start") + } + + select { + case err := <-done: + t.Fatalf("RunWithContext(context.Background()) returned before Stop: %v", err) + case <-time.After(50 * time.Millisecond): + } + + app.Stop() + + select { + case err := <-done: + if err != nil { + t.Fatalf("err=%v want nil", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for RunWithContext(context.Background()) to return after Stop") + } +} + +func TestAppRunWithContext_StopsOnCancel(t *testing.T) { + app, _, _ := newSimulationApp(t) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + if err := app.RunWithContext(ctx); !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v want %v", err, context.Canceled) + } +} + +func TestAppRunWithContext_PropagatesRunErrorWithoutCancellation(t *testing.T) { + app, _, _ := newSimulationApp(t) + runErr := errors.New("run failed") + eventErr := tcell.NewEventError(runErr) + + go func() { + time.Sleep(50 * time.Millisecond) + app.QueueEvent(eventErr) + }() + + if err := app.RunWithContext(context.Background()); err != eventErr { + t.Fatalf("err=%v want %v", err, eventErr) + } +} + +func TestAppRunWithContext_PrefersContextErrorWhenCanceledDuringRunError(t *testing.T) { + app, _, _ := newSimulationApp(t) + runErr := errors.New("run failed") + + var stopOnce sync.Once + app.stopHook = func() { + stopOnce.Do(func() { + app.stopHook = nil + app.QueueEvent(tcell.NewEventError(runErr)) + }) + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + if err := app.RunWithContext(ctx); !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v want %v", err, context.Canceled) + } +} + func TestSetRootWithTitle_SetsBoxTitleAndBorderColor(t *testing.T) { app := &App{Application: tview.NewApplication()} box := tview.NewBox() diff --git a/internal/tui/app.go b/internal/tui/app.go index 91660136..e190f67f 100644 --- a/internal/tui/app.go +++ b/internal/tui/app.go @@ -1,6 +1,9 @@ package tui import ( + "context" + "sync/atomic" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) @@ -50,6 +53,43 @@ func (a *App) Stop() { } } +func (a *App) RunWithContext(ctx context.Context) error { + if a == nil { + return nil + } + if ctx == nil { + return a.Run() + } + if err := ctx.Err(); err != nil { + return err + } + + done := make(chan struct{}) + defer close(done) + + var state atomic.Int32 + go func() { + select { + case <-ctx.Done(): + if state.CompareAndSwap(0, 1) { + a.Stop() + } + case <-done: + } + }() + + if err := a.Run(); err != nil { + if state.CompareAndSwap(0, 2) { + return err + } + return ctx.Err() + } + if state.CompareAndSwap(0, 2) { + return nil + } + return ctx.Err() +} + // SetRootWithTitle sets the root primitive with a styled title func (a *App) SetRootWithTitle(root tview.Primitive, title string) *App { if box, ok := root.(*tview.Box); ok { diff --git a/internal/tui/screen.go b/internal/tui/screen.go new file mode 100644 index 00000000..0a030b19 --- /dev/null +++ b/internal/tui/screen.go @@ -0,0 +1,90 @@ +package tui + +import ( + "fmt" + "strings" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" +) + +type ScreenSpec struct { + Title string + HeaderText string + NavText string + ConfigPath string + BuildSig string + TitleColor tcell.Color + BorderColor tcell.Color + BackgroundColor tcell.Color +} + +func BuildScreen(spec ScreenSpec, content tview.Primitive) tview.Primitive { + if content == nil { + content = tview.NewBox() + } + configPath := strings.TrimSpace(spec.ConfigPath) + buildSig := strings.TrimSpace(spec.BuildSig) + escapedHeaderText := tview.Escape(spec.HeaderText) + escapedConfigPath := tview.Escape(configPath) + escapedBuildSig := tview.Escape(buildSig) + + welcomeText := tview.NewTextView(). + SetText(escapedHeaderText). + SetTextColor(ProxmoxLight). + SetDynamicColors(true) + welcomeText.SetBorder(false) + + navText := strings.TrimSpace(spec.NavText) + if navText != "" { + navText = "\n" + navText + } + separator := tview.NewTextView(). + SetText(strings.Repeat("─", 80)). + SetTextColor(ProxmoxOrange) + separator.SetBorder(false) + + flex := tview.NewFlex(). + SetDirection(tview.FlexRow) + flex.AddItem(welcomeText, 5, 0, false) + if navText != "" { + navInstructions := tview.NewTextView(). + SetText(navText). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + navInstructions.SetBorder(false) + flex.AddItem(navInstructions, 2, 0, false) + } + flex.AddItem(separator, 1, 0, false) + flex.AddItem(content, 0, 1, true) + + if configPath != "" { + configPathText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", escapedConfigPath)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + configPathText.SetBorder(false) + flex.AddItem(configPathText, 1, 0, false) + } + + if buildSig != "" { + buildSigText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", escapedBuildSig)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + buildSigText.SetBorder(false) + flex.AddItem(buildSigText, 1, 0, false) + } + + flex.SetBorder(true). + SetTitle(fmt.Sprintf(" %s ", spec.Title)). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(spec.TitleColor). + SetBorderColor(spec.BorderColor). + SetBackgroundColor(spec.BackgroundColor) + + return flex +} diff --git a/internal/tui/screen_test.go b/internal/tui/screen_test.go new file mode 100644 index 00000000..8a2e4b10 --- /dev/null +++ b/internal/tui/screen_test.go @@ -0,0 +1,200 @@ +package tui + +import ( + "reflect" + "strings" + "testing" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" +) + +func primitiveContainsText(p tview.Primitive, want string) bool { + return primitiveContainsTextWithVisited(p, want, map[uintptr]struct{}{}) +} + +func primitiveContainsTextWithVisited(p tview.Primitive, want string, visited map[uintptr]struct{}) bool { + switch v := p.(type) { + case nil: + return false + case *tview.TextView: + return strings.Contains(v.GetTitle(), want) || strings.Contains(v.GetText(false), want) + case *tview.Box: + return strings.Contains(v.GetTitle(), want) + case *tview.Button: + return strings.Contains(v.GetTitle(), want) || strings.Contains(v.GetLabel(), want) + case *tview.Flex: + if strings.Contains(v.GetTitle(), want) { + return true + } + for i := 0; i < v.GetItemCount(); i++ { + if primitiveContainsTextWithVisited(v.GetItem(i), want, visited) { + return true + } + } + return false + case *tview.Form: + if strings.Contains(v.GetTitle(), want) { + return true + } + for i := 0; i < v.GetFormItemCount(); i++ { + if strings.Contains(v.GetFormItem(i).GetLabel(), want) { + return true + } + } + for i := 0; i < v.GetButtonCount(); i++ { + if strings.Contains(v.GetButton(i).GetLabel(), want) { + return true + } + } + case *tview.List: + if strings.Contains(v.GetTitle(), want) { + return true + } + for i := 0; i < v.GetItemCount(); i++ { + main, secondary := v.GetItemText(i) + if strings.Contains(main, want) || strings.Contains(secondary, want) { + return true + } + } + return false + case *tview.Pages: + if strings.Contains(v.GetTitle(), want) { + return true + } + for _, name := range v.GetPageNames(true) { + if primitiveContainsTextWithVisited(v.GetPage(name), want, visited) { + return true + } + } + return false + case *tview.Frame: + if strings.Contains(v.GetTitle(), want) { + return true + } + if primitiveContainsTextWithVisited(v.GetPrimitive(), want, visited) { + return true + } + case *tview.Modal: + if strings.Contains(v.GetTitle(), want) { + return true + } + default: + } + + return reflectedValueContainsText(reflect.ValueOf(p), want, visited) +} + +func reflectedValueContainsText(v reflect.Value, want string, visited map[uintptr]struct{}) bool { + if !v.IsValid() { + return false + } + for v.Kind() == reflect.Interface { + if v.IsNil() { + return false + } + v = v.Elem() + } + if v.Kind() == reflect.Pointer { + if v.IsNil() { + return false + } + ptr := v.Pointer() + if ptr != 0 { + if _, seen := visited[ptr]; seen { + return false + } + visited[ptr] = struct{}{} + } + return reflectedValueContainsText(v.Elem(), want, visited) + } + + switch v.Kind() { + case reflect.String: + return strings.Contains(v.String(), want) + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + if reflectedValueContainsText(v.Field(i), want, visited) { + return true + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + if reflectedValueContainsText(v.Index(i), want, visited) { + return true + } + } + } + return false +} + +func TestBuildScreenOmitsEmptyOptionalFooters(t *testing.T) { + page := BuildScreen(ScreenSpec{ + Title: "Title", + HeaderText: "Header", + NavText: "Navigation", + ConfigPath: "", + BuildSig: "sig", + TitleColor: ProxmoxOrange, + BorderColor: ProxmoxOrange, + BackgroundColor: tcell.ColorBlack, + }, tview.NewBox()) + + if primitiveContainsText(page, "Configuration file:") { + t.Fatalf("did not expect configuration footer when ConfigPath is empty") + } + if !primitiveContainsText(page, "Build Signature:") { + t.Fatalf("expected build signature footer") + } +} + +func TestBuildScreenEscapesHeaderText(t *testing.T) { + page := BuildScreen(ScreenSpec{ + Title: "Title", + HeaderText: "Header[prod]", + NavText: "", + ConfigPath: "", + BuildSig: "", + TitleColor: ProxmoxOrange, + BorderColor: ProxmoxOrange, + BackgroundColor: tcell.ColorBlack, + }, tview.NewBox()) + + if !primitiveContainsText(page, tview.Escape("Header[prod]")) { + t.Fatalf("expected escaped header text") + } +} + +func TestPrimitiveContainsTextFindsBoxTitle(t *testing.T) { + box := tview.NewBox().SetTitle("Box Title") + + if !primitiveContainsText(box, "Box Title") { + t.Fatalf("expected box title to be discovered") + } +} + +func TestPrimitiveContainsTextFindsFormLabelsAndButtons(t *testing.T) { + form := tview.NewForm(). + AddInputField("Token", "", 0, nil, nil). + AddButton("Save", nil) + + if !primitiveContainsText(form, "Token") { + t.Fatalf("expected form label to be discovered") + } + if !primitiveContainsText(form, "Save") { + t.Fatalf("expected form button label to be discovered") + } +} + +func TestPrimitiveContainsTextFallsBackToModalText(t *testing.T) { + modal := tview.NewModal(). + SetText("Danger zone"). + AddButtons([]string{"Continue"}) + + if !primitiveContainsText(modal, "Danger zone") { + t.Fatalf("expected modal text to be discovered") + } + if !primitiveContainsText(modal, "Continue") { + t.Fatalf("expected modal button label to be discovered") + } +} diff --git a/internal/tui/wizard/age.go b/internal/tui/wizard/age.go index 524fcf1c..356415fb 100644 --- a/internal/tui/wizard/age.go +++ b/internal/tui/wizard/age.go @@ -25,14 +25,22 @@ type AgeSetupData struct { RecipientKey string // The final recipient key to save } +const ( + ageSetupTypeExisting = "existing" + ageSetupTypePassphrase = "passphrase" + ageSetupTypePrivateKey = "privatekey" +) + var ( // ErrAgeSetupCancelled is returned when the user aborts the AGE setup wizard. ErrAgeSetupCancelled = errors.New("encryption setup aborted by user") ) var ( - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - return app.SetRoot(root, true).SetFocus(focus).Run() + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) } ageMkdirAll = os.MkdirAll ageWriteFile = os.WriteFile @@ -70,53 +78,20 @@ func validatePrivateKey(value string) (string, error) { if key == "" { return "", fmt.Errorf("private key cannot be empty") } - if !strings.HasPrefix(key, "AGE-SECRET-KEY-1") { - return "", fmt.Errorf("private key must start with 'AGE-SECRET-KEY-1'") + if err := orchestrator.ValidateAgePrivateKeyString(key); err != nil { + return "", err } return key, nil } // ConfirmRecipientOverwrite shows a TUI modal to confirm overwriting an existing AGE recipient. -func ConfirmRecipientOverwrite(recipientPath, configPath, buildSig string) (bool, error) { +func ConfirmRecipientOverwrite(ctx context.Context, recipientPath, configPath, buildSig string) (bool, error) { app := tui.NewApp() overwrite := false - - welcomeText := tview.NewTextView(). - SetText("ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n" + - "Configure encryption for your backups using the AGE encryption tool.\n" + - "Choose how you want to set up your encryption key.\n"). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - navInstructions := tview.NewTextView(). - SetText("\n[yellow]Navigation:[white] Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to select | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) + escapedRecipientPath := tview.Escape(recipientPath) modal := tview.NewModal(). - SetText(fmt.Sprintf("Existing recipient:\n[yellow]%s[white]\n\nOverwrite with a new one?", recipientPath)). + SetText(fmt.Sprintf("Existing recipient:\n[yellow]%s[white]\n\nOverwrite with a new one?", escapedRecipientPath)). AddButtons([]string{"Overwrite", "Cancel"}). SetDoneFunc(func(buttonIndex int, buttonLabel string) { if buttonLabel == "Overwrite" { @@ -132,23 +107,18 @@ func ConfirmRecipientOverwrite(recipientPath, configPath, buildSig string) (bool SetBorderColor(tui.WarningYellow). SetBackgroundColor(tcell.ColorBlack) - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(modal, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" AGE Encryption Setup "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - if err := ageWizardRunner(app, flex, modal); err != nil { + flex := buildWizardScreen( + "AGE Encryption Setup", + "ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n"+ + "Configure encryption for your backups using the AGE encryption tool.\n"+ + "Choose how you want to set up your encryption key.\n", + "[yellow]Navigation:[white] Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to select | Mouse clicks enabled", + configPath, + buildSig, + modal, + ) + + if err := ageWizardRunner(ctx, app, flex, modal); err != nil { return false, err } @@ -156,43 +126,10 @@ func ConfirmRecipientOverwrite(recipientPath, configPath, buildSig string) (bool } // ConfirmAddRecipient asks whether to add another AGE recipient. -func ConfirmAddRecipient(configPath, buildSig string, count int) (bool, error) { +func ConfirmAddRecipient(ctx context.Context, configPath, buildSig string, count int) (bool, error) { app := tui.NewApp() addAnother := false - welcomeText := tview.NewTextView(). - SetText("ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n" + - "Add one or more AGE recipients for encryption.\n"). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - navInstructions := tview.NewTextView(). - SetText("\n[yellow]Navigation:[white] Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to select | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - message := fmt.Sprintf("Recipient(s) added: %d\n\nAdd another recipient?", count) modal := tview.NewModal(). SetText(message). @@ -211,23 +148,17 @@ func ConfirmAddRecipient(configPath, buildSig string, count int) (bool, error) { SetBorderColor(tui.ProxmoxOrange). SetBackgroundColor(tcell.ColorBlack) - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(modal, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" AGE Encryption Setup "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - if err := ageWizardRunner(app, flex, modal); err != nil { + flex := buildWizardScreen( + "AGE Encryption Setup", + "ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n"+ + "Add one or more AGE recipients for encryption.\n", + "[yellow]Navigation:[white] Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to select | Mouse clicks enabled", + configPath, + buildSig, + modal, + ) + + if err := ageWizardRunner(ctx, app, flex, modal); err != nil { return false, err } @@ -245,43 +176,6 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig // Build the form form := components.NewForm(app) - // Welcome text - welcomeText := tview.NewTextView(). - SetText("ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n" + - "Configure encryption for your backups using the AGE encryption tool.\n" + - "Choose how you want to set up your encryption key.\n"). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - // Navigation instructions - navInstructions := tview.NewTextView(). - SetText("\n[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - // Add separator - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - // Setup type dropdown var setupType string var publicKeyField, passphraseField, passphraseConfirmField, privateKeyField *tview.InputField @@ -295,7 +189,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig }, func(option string, index int) { switch index { case 0: - setupType = "existing" + setupType = ageSetupTypeExisting if publicKeyField != nil { publicKeyField.SetDisabled(false) } @@ -309,7 +203,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig privateKeyField.SetDisabled(true) } case 1: - setupType = "passphrase" + setupType = ageSetupTypePassphrase if publicKeyField != nil { publicKeyField.SetDisabled(true) } @@ -323,7 +217,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig privateKeyField.SetDisabled(true) } case 2: - setupType = "privatekey" + setupType = ageSetupTypePrivateKey if publicKeyField != nil { publicKeyField.SetDisabled(true) } @@ -386,7 +280,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig form.Form.AddFormItem(privateKeyField) // Initialize with "existing" type selected - setupType = "existing" + setupType = ageSetupTypeExisting passphraseField.SetDisabled(true) passphraseConfirmField.SetDisabled(true) privateKeyField.SetDisabled(true) @@ -396,7 +290,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig data.SetupType = setupType switch setupType { - case "existing": + case ageSetupTypeExisting: publicKey, err := validatePublicKey(publicKeyField.GetText()) if err != nil { return err @@ -404,14 +298,14 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig data.PublicKey = publicKey data.RecipientKey = publicKey - case "passphrase": + case ageSetupTypePassphrase: passphrase, err := validatePassphrase(passphraseField.GetText(), passphraseConfirmField.GetText()) if err != nil { return err } data.Passphrase = passphrase - case "privatekey": + case ageSetupTypePrivateKey: privateKey, err := validatePrivateKey(privateKeyField.GetText()) if err != nil { return err @@ -463,31 +357,25 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig return event }) - // Create layout - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(form.Form, 0, 1, true) - // Footers - flex.AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" AGE Encryption Setup "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) + flex := buildWizardScreen( + "AGE Encryption Setup", + "ProxSave - By TIS24DEV\nAGE Encryption Setup\n\n"+ + "Configure encryption for your backups using the AGE encryption tool.\n"+ + "Choose how you want to set up your encryption key.\n", + "[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled", + configPath, + buildSig, + form.Form, + ) // Set the parent view for inline error display, then add buttons form.SetParentView(flex) form.AddSubmitButton("Continue") form.AddCancelButton("Cancel") - // Run the app - ignore errors from normal app termination - _ = ageWizardRunner(app, flex, form.Form) + if err := ageWizardRunner(ctx, app, flex, form.Form); err != nil { + return nil, err + } if data == nil { return nil, ErrAgeSetupCancelled diff --git a/internal/tui/wizard/age_test.go b/internal/tui/wizard/age_test.go index 1a64a1fc..4e5b0b50 100644 --- a/internal/tui/wizard/age_test.go +++ b/internal/tui/wizard/age_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "filippo.io/age" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "golang.org/x/crypto/ssh" @@ -82,15 +83,21 @@ func TestValidatePassphrase(t *testing.T) { } func TestValidatePrivateKey(t *testing.T) { + identity, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + cases := []struct { name string input string want string wantErr bool }{ - {name: "valid", input: " AGE-SECRET-KEY-1abc ", want: "AGE-SECRET-KEY-1abc"}, + {name: "valid", input: " " + identity.String() + " ", want: identity.String()}, {name: "empty", input: "", wantErr: true}, {name: "wrong prefix", input: "SECRET", wantErr: true}, + {name: "invalid body", input: "AGE-SECRET-KEY-1invalid", wantErr: true}, } for _, tc := range cases { @@ -220,13 +227,13 @@ func TestConfirmRecipientOverwriteSelection(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { done := extractModalDone(focus.(*tview.Modal)) done(0, tc.button) return nil } - got, err := ConfirmRecipientOverwrite("/tmp/recipient.age", "/etc/proxsave/.env", "sig-xyz") + got, err := ConfirmRecipientOverwrite(context.Background(), "/tmp/recipient.age", "/etc/proxsave/.env", "sig-xyz") if err != nil { t.Fatalf("ConfirmRecipientOverwrite returned error: %v", err) } @@ -242,12 +249,12 @@ func TestConfirmRecipientOverwriteModalIncludesRecipientPath(t *testing.T) { defer func() { ageWizardRunner = originalRunner }() var modalText string - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { modalText = extractModalText(focus.(*tview.Modal)) return nil } - _, err := ConfirmRecipientOverwrite("/var/lib/proxsave/recipient.age", "/etc/.env", "sig") + _, err := ConfirmRecipientOverwrite(context.Background(), "/var/lib/proxsave/recipient.age", "/etc/.env", "sig") if err != nil { t.Fatalf("ConfirmRecipientOverwrite returned error: %v", err) } @@ -260,11 +267,30 @@ func TestConfirmRecipientOverwriteRunnerError(t *testing.T) { originalRunner := ageWizardRunner defer func() { ageWizardRunner = originalRunner }() - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + return errors.New("boom") + } + + if _, err := ConfirmRecipientOverwrite(context.Background(), "/tmp/recipient.age", "/etc/.env", "sig"); err == nil { + t.Fatalf("expected error from runner") + } +} + +func TestConfirmRecipientOverwritePassesContextToRunner(t *testing.T) { + originalRunner := ageWizardRunner + defer func() { ageWizardRunner = originalRunner }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ageWizardRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("ctx=%p; want %p", gotCtx, ctx) + } return errors.New("boom") } - if _, err := ConfirmRecipientOverwrite("/tmp/recipient.age", "/etc/.env", "sig"); err == nil { + if _, err := ConfirmRecipientOverwrite(ctx, "/tmp/recipient.age", "/etc/.env", "sig"); err == nil { t.Fatalf("expected error from runner") } } @@ -285,13 +311,13 @@ func TestConfirmAddRecipientSelection(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { done := extractModalDone(focus.(*tview.Modal)) done(0, tc.button) return nil } - got, err := ConfirmAddRecipient("/etc/proxsave/.env", "sig-xyz", 2) + got, err := ConfirmAddRecipient(context.Background(), "/etc/proxsave/.env", "sig-xyz", 2) if err != nil { t.Fatalf("ConfirmAddRecipient returned error: %v", err) } @@ -307,12 +333,12 @@ func TestConfirmAddRecipientModalIncludesCount(t *testing.T) { defer func() { ageWizardRunner = originalRunner }() var modalText string - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { modalText = extractModalText(focus.(*tview.Modal)) return nil } - _, err := ConfirmAddRecipient("/etc/proxsave/.env", "sig", 3) + _, err := ConfirmAddRecipient(context.Background(), "/etc/proxsave/.env", "sig", 3) if err != nil { t.Fatalf("ConfirmAddRecipient returned error: %v", err) } @@ -321,6 +347,25 @@ func TestConfirmAddRecipientModalIncludesCount(t *testing.T) { } } +func TestConfirmAddRecipientPassesContextToRunner(t *testing.T) { + originalRunner := ageWizardRunner + defer func() { ageWizardRunner = originalRunner }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ageWizardRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("ctx=%p; want %p", gotCtx, ctx) + } + return errors.New("boom") + } + + if _, err := ConfirmAddRecipient(ctx, "/etc/proxsave/.env", "sig", 3); err == nil { + t.Fatalf("expected error from runner") + } +} + func TestRunAgeSetupWizardExistingKey(t *testing.T) { validAge := "age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p" data, err := runAgeWizardTest(t, func(form *tview.Form) { @@ -333,7 +378,7 @@ func TestRunAgeSetupWizardExistingKey(t *testing.T) { if err != nil { t.Fatalf("RunAgeSetupWizard returned error: %v", err) } - if data.SetupType != "existing" { + if data.SetupType != ageSetupTypeExisting { t.Fatalf("unexpected setup type: %s", data.SetupType) } if data.RecipientKey != validAge { @@ -354,7 +399,7 @@ func TestRunAgeSetupWizardPassphrase(t *testing.T) { if err != nil { t.Fatalf("RunAgeSetupWizard returned error: %v", err) } - if data.SetupType != "passphrase" { + if data.SetupType != ageSetupTypePassphrase { t.Fatalf("unexpected setup type: %s", data.SetupType) } if data.Passphrase != "CorrectHorse1!" { @@ -366,20 +411,25 @@ func TestRunAgeSetupWizardPassphrase(t *testing.T) { } func TestRunAgeSetupWizardPrivateKey(t *testing.T) { + identity, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("GenerateX25519Identity: %v", err) + } + data, err := runAgeWizardTest(t, func(form *tview.Form) { drop := form.GetFormItem(0).(*tview.DropDown) drop.SetCurrentOption(2) privateField := form.GetFormItem(4).(*tview.InputField) - privateField.SetText("AGE-SECRET-KEY-1valid") + privateField.SetText(identity.String()) pressFormButton(t, form, "Continue") }) if err != nil { t.Fatalf("RunAgeSetupWizard returned error: %v", err) } - if data.SetupType != "privatekey" { + if data.SetupType != ageSetupTypePrivateKey { t.Fatalf("unexpected setup type: %s", data.SetupType) } - if data.PrivateKey != "AGE-SECRET-KEY-1valid" { + if data.PrivateKey != identity.String() { t.Fatalf("expected private key saved, got %q", data.PrivateKey) } if data.Passphrase != "" || data.PublicKey != "" { @@ -399,10 +449,24 @@ func TestRunAgeSetupWizardCancel(t *testing.T) { } } +func TestRunAgeSetupWizardRunnerError(t *testing.T) { + originalRunner := ageWizardRunner + defer func() { ageWizardRunner = originalRunner }() + + expected := errors.New("boom") + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + return expected + } + + if _, err := RunAgeSetupWizard(context.Background(), "/tmp/recipient.age", "/etc/proxsave/config.env", "sig-test"); !errors.Is(err, expected) { + t.Fatalf("err=%v; want %v", err, expected) + } +} + func runAgeWizardTest(t *testing.T, configure func(form *tview.Form)) (*AgeSetupData, error) { t.Helper() originalRunner := ageWizardRunner - ageWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + ageWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { form, ok := focus.(*tview.Form) if !ok { t.Fatalf("expected *tview.Form focus, got %T", focus) @@ -414,6 +478,23 @@ func runAgeWizardTest(t *testing.T, configure func(form *tview.Form)) (*AgeSetup return RunAgeSetupWizard(context.Background(), "/tmp/recipient.age", "/etc/proxsave/config.env", "sig-test") } +func TestRunAgeSetupWizard_PassesContextToRunner(t *testing.T) { + originalRunner := ageWizardRunner + defer func() { ageWizardRunner = originalRunner }() + + ctx := t.Context() + ageWizardRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("ctx=%p; want %p", gotCtx, ctx) + } + return ErrAgeSetupCancelled + } + + if _, err := RunAgeSetupWizard(ctx, "/tmp/recipient.age", "/etc/proxsave/config.env", "sig-test"); !errors.Is(err, ErrAgeSetupCancelled) { + t.Fatalf("err=%v; want %v", err, ErrAgeSetupCancelled) + } +} + func pressFormButton(t *testing.T, form *tview.Form, label string) { t.Helper() index := form.GetButtonIndex(label) diff --git a/internal/tui/wizard/age_ui_adapter.go b/internal/tui/wizard/age_ui_adapter.go new file mode 100644 index 00000000..50769772 --- /dev/null +++ b/internal/tui/wizard/age_ui_adapter.go @@ -0,0 +1,68 @@ +package wizard + +import ( + "context" + "errors" + "fmt" + + "github.com/tis24dev/proxsave/internal/orchestrator" +) + +type ageSetupUIAdapter struct { + configPath string + buildSig string +} + +func NewAgeSetupUI(configPath, buildSig string) orchestrator.AgeSetupUI { + return &ageSetupUIAdapter{ + configPath: configPath, + buildSig: buildSig, + } +} + +func (a *ageSetupUIAdapter) ConfirmOverwriteExistingRecipient(ctx context.Context, recipientPath string) (bool, error) { + if err := ctx.Err(); err != nil { + return false, err + } + return ConfirmRecipientOverwrite(ctx, recipientPath, a.configPath, a.buildSig) +} + +func (a *ageSetupUIAdapter) CollectRecipientDraft(ctx context.Context, recipientPath string) (*orchestrator.AgeRecipientDraft, error) { + data, err := RunAgeSetupWizard(ctx, recipientPath, a.configPath, a.buildSig) + if err != nil { + if errors.Is(err, ErrAgeSetupCancelled) { + return nil, orchestrator.ErrAgeRecipientSetupAborted + } + return nil, err + } + if data == nil { + return nil, orchestrator.ErrAgeRecipientSetupAborted + } + + switch data.SetupType { + case ageSetupTypeExisting: + return &orchestrator.AgeRecipientDraft{ + Kind: orchestrator.AgeRecipientInputExisting, + PublicKey: data.PublicKey, + }, nil + case ageSetupTypePassphrase: + return &orchestrator.AgeRecipientDraft{ + Kind: orchestrator.AgeRecipientInputPassphrase, + Passphrase: data.Passphrase, + }, nil + case ageSetupTypePrivateKey: + return &orchestrator.AgeRecipientDraft{ + Kind: orchestrator.AgeRecipientInputPrivateKey, + PrivateKey: data.PrivateKey, + }, nil + default: + return nil, fmt.Errorf("unknown AGE setup type: %s", data.SetupType) + } +} + +func (a *ageSetupUIAdapter) ConfirmAddAnotherRecipient(ctx context.Context, currentCount int) (bool, error) { + if err := ctx.Err(); err != nil { + return false, err + } + return ConfirmAddRecipient(ctx, a.configPath, a.buildSig, currentCount) +} diff --git a/internal/tui/wizard/age_ui_adapter_test.go b/internal/tui/wizard/age_ui_adapter_test.go new file mode 100644 index 00000000..f20f1d7f --- /dev/null +++ b/internal/tui/wizard/age_ui_adapter_test.go @@ -0,0 +1,92 @@ +package wizard + +import ( + "context" + "errors" + "testing" + + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/orchestrator" + "github.com/tis24dev/proxsave/internal/tui" +) + +func registerAgeWizardRunner(t *testing.T, runner func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error) { + t.Helper() + + originalRunner := ageWizardRunner + ageWizardRunner = runner + t.Cleanup(func() { + ageWizardRunner = originalRunner + }) +} + +func TestAgeSetupUIAdapterCollectRecipientDraftCancelMapsAbort(t *testing.T) { + registerAgeWizardRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + form, ok := focus.(*tview.Form) + if !ok { + t.Fatalf("expected *tview.Form focus, got %T", focus) + } + pressFormButton(t, form, "Cancel") + return nil + }) + + ui := NewAgeSetupUI("/etc/proxsave/config.env", "sig-test") + draft, err := ui.CollectRecipientDraft(context.Background(), "/tmp/recipient.age") + if !errors.Is(err, orchestrator.ErrAgeRecipientSetupAborted) { + t.Fatalf("err=%v; want %v", err, orchestrator.ErrAgeRecipientSetupAborted) + } + if draft != nil { + t.Fatalf("draft=%+v; want nil", draft) + } +} + +func TestAgeSetupUIAdapterCollectRecipientDraftRunnerError(t *testing.T) { + expected := errors.New("boom") + registerAgeWizardRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + return expected + }) + + ui := NewAgeSetupUI("/etc/proxsave/config.env", "sig-test") + if _, err := ui.CollectRecipientDraft(context.Background(), "/tmp/recipient.age"); !errors.Is(err, expected) { + t.Fatalf("err=%v; want %v", err, expected) + } +} + +func TestAgeSetupUIAdapterConfirmOverwriteExistingRecipientCanceledContext(t *testing.T) { + registerAgeWizardRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatal("ageWizardRunner should not be called when context is already canceled") + return nil + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ui := NewAgeSetupUI("/etc/proxsave/config.env", "sig-test") + confirmed, err := ui.ConfirmOverwriteExistingRecipient(ctx, "/tmp/recipient.age") + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } + if confirmed { + t.Fatalf("confirmed=%t; want false", confirmed) + } +} + +func TestAgeSetupUIAdapterConfirmAddAnotherRecipientCanceledContext(t *testing.T) { + registerAgeWizardRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatal("ageWizardRunner should not be called when context is already canceled") + return nil + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ui := NewAgeSetupUI("/etc/proxsave/config.env", "sig-test") + confirmed, err := ui.ConfirmAddAnotherRecipient(ctx, 1) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } + if confirmed { + t.Fatalf("confirmed=%t; want false", confirmed) + } +} diff --git a/internal/tui/wizard/install.go b/internal/tui/wizard/install.go index 66484b99..19c661de 100644 --- a/internal/tui/wizard/install.go +++ b/internal/tui/wizard/install.go @@ -6,30 +6,29 @@ import ( "errors" "fmt" "os" - "path/filepath" - "strconv" "strings" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "github.com/tis24dev/proxsave/internal/config" + cronutil "github.com/tis24dev/proxsave/internal/cron" "github.com/tis24dev/proxsave/internal/tui" "github.com/tis24dev/proxsave/internal/tui/components" "github.com/tis24dev/proxsave/pkg/utils" ) type installWizardPrefill struct { - SecondaryEnabled bool - SecondaryPath string - SecondaryLogPath string - CloudEnabled bool - CloudRemote string - CloudLogPath string - FirewallEnabled bool - TelegramEnabled bool - EmailEnabled bool - EncryptionEnabled bool + SecondaryEnabled bool + SecondaryPath string + SecondaryLogPath string + CloudEnabled bool + CloudRemote string + CloudLogPath string + FirewallEnabled bool + TelegramEnabled bool + EmailEnabled bool + EncryptionEnabled bool } // InstallWizardData holds the collected installation data @@ -52,16 +51,26 @@ type InstallWizardData struct { type ExistingConfigAction int const ( - ExistingConfigOverwrite ExistingConfigAction = iota // Start from embedded template (overwrite) - ExistingConfigEdit // Keep existing file as base and edit - ExistingConfigSkip // Leave the file untouched and skip wizard + ExistingConfigOverwrite ExistingConfigAction = iota // Start from embedded template (overwrite) + ExistingConfigEdit // Keep existing file as base and edit + ExistingConfigKeepContinue // Leave file untouched and continue installation + ExistingConfigCancel // Abort installation ) var ( // ErrInstallCancelled is returned when the user aborts the install wizard. - ErrInstallCancelled = errors.New("installation aborted by user") - checkExistingConfigRunner = func(app *tui.App, root, focus tview.Primitive) error { - return app.SetRoot(root, true).SetFocus(focus).Run() + ErrInstallCancelled = errors.New("installation aborted by user") + // ErrNilInstallData is returned when ApplyInstallData or its validators receive a nil payload. + ErrNilInstallData = errors.New("install wizard data cannot be nil") + runInstallWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) + } + checkExistingConfigRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) } ) @@ -71,7 +80,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu data := &InstallWizardData{ BaseDir: baseDir, ConfigPath: configPath, - CronTime: "02:00", + CronTime: cronutil.DefaultTime, EnableEncryption: false, // Default to disabled BackupFirewallRules: &defaultFirewallRules, } @@ -83,29 +92,6 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // Build the form form := components.NewForm(app) - // Welcome text - welcomeText := tview.NewTextView(). - SetText("Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n" + - "This wizard will guide you through configuring your backup system for Proxmox.\n" + - "All settings can be changed later by editing the configuration file."). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - // Navigation instructions - navInstructions := tview.NewTextView(). - SetText("[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - // Add separator - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - // Track if any dropdown is currently open var dropdownOpen bool @@ -327,15 +313,10 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // Collect data data.EnableSecondaryStorage = secondaryEnabled if secondaryEnabled { - data.SecondaryPath = secondaryPathField.GetText() - data.SecondaryLogPath = secondaryLogField.GetText() - - // Validate paths - if !filepath.IsAbs(data.SecondaryPath) { - return fmt.Errorf("secondary backup path must be absolute") - } - if !filepath.IsAbs(data.SecondaryLogPath) { - return fmt.Errorf("secondary log path must be absolute") + data.SecondaryPath = strings.TrimSpace(secondaryPathField.GetText()) + data.SecondaryLogPath = strings.TrimSpace(secondaryLogField.GetText()) + if err := validateSecondaryInstallData(data); err != nil { + return err } } @@ -370,24 +351,11 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // Get encryption setting data.EnableEncryption = values["Enable Backup Encryption (AGE)"] == "Yes" - // Cron time validation (HH:MM) - cron := strings.TrimSpace(cronField.GetText()) - if cron == "" { - cron = "02:00" - } - parts := strings.Split(cron, ":") - if len(parts) != 2 { - return fmt.Errorf("cron time must be in HH:MM format") - } - hour, err := strconv.Atoi(parts[0]) - if err != nil || hour < 0 || hour > 23 { - return fmt.Errorf("cron hour must be between 00 and 23") - } - minute, err := strconv.Atoi(parts[1]) - if err != nil || minute < 0 || minute > 59 { - return fmt.Errorf("cron minute must be between 00 and 59") + normalizedCron, err := cronutil.NormalizeTime(cronField.GetText(), cronutil.DefaultTime) + if err != nil { + return err } - data.CronTime = fmt.Sprintf("%02d:%02d", hour, minute) + data.CronTime = normalizedCron return nil }) @@ -437,40 +405,20 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu return event }) - // Config path footer - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - - // Create layout - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(form.Form, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" ProxSave Installation "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - // Run the app - ignore errors from normal app termination - _ = app.SetRoot(flex, true).SetFocus(form.Form).Run() + flex := buildWizardScreen( + "ProxSave Installation", + "Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n"+ + "This wizard will guide you through configuring your backup system for Proxmox.\n"+ + "All settings can be changed later by editing the configuration file.", + "[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled", + configPath, + buildSig, + form.Form, + ) + + if err := runInstallWizardRunner(ctx, app, flex, form.Form); err != nil { + return nil, err + } if data == nil { return nil, ErrInstallCancelled @@ -482,6 +430,10 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // ApplyInstallData applies the collected data to the config template. // If baseTemplate is empty, the embedded default template is used. func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, error) { + if data == nil { + return "", ErrNilInstallData + } + template := baseTemplate editingExisting := strings.TrimSpace(baseTemplate) != "" existingValues := map[string]string{} @@ -491,6 +443,9 @@ func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, err if strings.TrimSpace(template) == "" { template = config.DefaultEnvTemplate() } + if err := validateSecondaryInstallData(data); err != nil { + return "", err + } // BASE_DIR is auto-detected at runtime from the executable/config location. // Keep it out of backup.env to avoid pinning the installation to a specific path. @@ -500,13 +455,12 @@ func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, err template = unsetEnvValue(template, "CRON_MINUTE") // Apply secondary storage - if data.EnableSecondaryStorage { - template = setEnvValue(template, "SECONDARY_ENABLED", "true") - template = setEnvValue(template, "SECONDARY_PATH", data.SecondaryPath) - template = setEnvValue(template, "SECONDARY_LOG_PATH", data.SecondaryLogPath) - } else { - template = setEnvValue(template, "SECONDARY_ENABLED", "false") - } + template = config.ApplySecondaryStorageSettings( + template, + data.EnableSecondaryStorage, + data.SecondaryPath, + data.SecondaryLogPath, + ) // Apply cloud storage if data.EnableCloudStorage { @@ -562,6 +516,22 @@ func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, err return template, nil } +func validateSecondaryInstallData(data *InstallWizardData) error { + if data == nil { + return ErrNilInstallData + } + if !data.EnableSecondaryStorage { + return nil + } + if err := config.ValidateRequiredSecondaryPath(data.SecondaryPath); err != nil { + return err + } + if err := config.ValidateOptionalSecondaryLogPath(data.SecondaryLogPath); err != nil { + return err + } + return nil +} + // setEnvValue sets or updates an environment variable in the template func setEnvValue(template, key, value string) string { return utils.SetEnvValue(template, key, value) @@ -684,42 +654,16 @@ func readTemplateBool(values map[string]string, keys ...string) bool { } // CheckExistingConfig checks if config file exists and asks how to proceed -func CheckExistingConfig(configPath string, buildSig string) (ExistingConfigAction, error) { - if _, err := os.Stat(configPath); err == nil { +func CheckExistingConfig(ctx context.Context, configPath string, buildSig string) (ExistingConfigAction, error) { + if info, err := os.Stat(configPath); err == nil { + if !info.Mode().IsRegular() { + return ExistingConfigCancel, fmt.Errorf("configuration file path is not a regular file: %s", configPath) + } + // File exists, ask how to proceed app := tui.NewApp() - action := ExistingConfigSkip - - // Welcome text (same as main wizard) - welcomeText := tview.NewTextView(). - SetText("Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n" + - "This wizard will guide you through configuring your backup system for Proxmox.\n" + - "All settings can be changed later by editing the configuration file."). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - // Navigation instructions (no dropdowns in this view) - navInstructions := tview.NewTextView(). - SetText("[yellow]Navigation:[white] Press [yellow]TAB[white] or [yellow]↑↓[white] to move between fields | " + - "Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to submit | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - - // Separator - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) + action := ExistingConfigCancel + escapedConfigPath := tview.Escape(configPath) // Confirmation modal modal := tview.NewModal(). @@ -727,16 +671,19 @@ func CheckExistingConfig(configPath string, buildSig string) (ExistingConfigActi "Choose how to proceed:\n"+ "[yellow]Overwrite[white] - Start from embedded template\n"+ "[yellow]Edit existing[white] - Keep current file as base\n"+ - "[yellow]Keep & exit[white] - Leave file untouched, exit wizard", configPath)). - AddButtons([]string{"Overwrite", "Edit existing", "Keep & exit"}). + "[yellow]Keep & continue[white] - Leave file untouched, continue install\n"+ + "[yellow]Cancel[white] - Exit installation", escapedConfigPath)). + AddButtons([]string{"Overwrite", "Edit existing", "Keep & continue", "Cancel"}). SetDoneFunc(func(buttonIndex int, buttonLabel string) { switch buttonLabel { case "Overwrite": action = ExistingConfigOverwrite case "Edit existing": action = ExistingConfigEdit + case "Keep & continue": + action = ExistingConfigKeepContinue default: - action = ExistingConfigSkip + action = ExistingConfigCancel } app.Stop() }) @@ -747,29 +694,26 @@ func CheckExistingConfig(configPath string, buildSig string) (ExistingConfigActi SetTitleColor(tui.WarningYellow). SetBorderColor(tui.WarningYellow). SetBackgroundColor(tcell.ColorBlack) - - // Create layout with welcome text at top - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(modal, 0, 1, true). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" ProxSave Installation "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - // Run the modal - ignore errors from normal app termination - _ = checkExistingConfigRunner(app, flex, modal) + modal.SetFocus(2) + + flex := buildWizardScreen( + "ProxSave Installation", + "Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n"+ + "This wizard will guide you through configuring your backup system for Proxmox.\n"+ + "All settings can be changed later by editing the configuration file.", + "[yellow]Navigation:[white] Press [yellow]TAB[white] or [yellow]↑↓[white] to move between fields | Use [yellow]←→[white] on buttons | Press [yellow]ENTER[white] to submit | Mouse clicks enabled", + "", + buildSig, + modal, + ) + + if err := checkExistingConfigRunner(ctx, app, flex, modal); err != nil { + return ExistingConfigCancel, err + } return action, nil } else if !os.IsNotExist(err) { - return ExistingConfigSkip, err + return ExistingConfigCancel, err } return ExistingConfigOverwrite, nil // File doesn't exist, proceed diff --git a/internal/tui/wizard/install_test.go b/internal/tui/wizard/install_test.go index 4a7c9602..1e22a4dc 100644 --- a/internal/tui/wizard/install_test.go +++ b/internal/tui/wizard/install_test.go @@ -1,13 +1,17 @@ package wizard import ( + "context" + "errors" "os" "path/filepath" "strings" "testing" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" + cronutil "github.com/tis24dev/proxsave/internal/cron" "github.com/tis24dev/proxsave/internal/tui" ) @@ -102,6 +106,110 @@ func TestApplyInstallDataDefaultsBaseTemplate(t *testing.T) { } } +func TestApplyInstallDataRejectsNilData(t *testing.T) { + _, err := ApplyInstallData("", nil) + if !errors.Is(err, ErrNilInstallData) { + t.Fatalf("ApplyInstallData error = %v, want %v", err, ErrNilInstallData) + } +} + +func TestApplyInstallDataAllowsEmptySecondaryLogPath(t *testing.T) { + data := &InstallWizardData{ + BaseDir: "/tmp/base", + EnableSecondaryStorage: true, + SecondaryPath: "/mnt/sec", + SecondaryLogPath: "", + } + + result, err := ApplyInstallData("", data) + if err != nil { + t.Fatalf("ApplyInstallData returned error: %v", err) + } + if !strings.Contains(result, "SECONDARY_ENABLED=true") { + t.Fatalf("expected secondary enabled in result:\n%s", result) + } + if !strings.Contains(result, "SECONDARY_PATH=/mnt/sec") { + t.Fatalf("expected secondary path in result:\n%s", result) + } + if !strings.Contains(result, "SECONDARY_LOG_PATH=") { + t.Fatalf("expected empty secondary log path in result:\n%s", result) + } +} + +func TestApplyInstallDataDisabledSecondaryClearsExistingValues(t *testing.T) { + baseTemplate := strings.Join([]string{ + "SECONDARY_ENABLED=true", + "SECONDARY_PATH=/mnt/old-secondary", + "SECONDARY_LOG_PATH=/mnt/old-secondary/logs", + "TELEGRAM_ENABLED=false", + "EMAIL_ENABLED=false", + "ENCRYPT_ARCHIVE=false", + "", + }, "\n") + data := &InstallWizardData{ + BaseDir: "/tmp/base", + EnableSecondaryStorage: false, + } + + result, err := ApplyInstallData(baseTemplate, data) + if err != nil { + t.Fatalf("ApplyInstallData returned error: %v", err) + } + + for _, needle := range []string{ + "SECONDARY_ENABLED=false", + "SECONDARY_PATH=", + "SECONDARY_LOG_PATH=", + } { + if !strings.Contains(result, needle) { + t.Fatalf("expected %q in result:\n%s", needle, result) + } + } + if strings.Contains(result, "/mnt/old-secondary") { + t.Fatalf("expected old secondary values to be cleared:\n%s", result) + } +} + +func TestApplyInstallDataRejectsInvalidSecondaryPath(t *testing.T) { + data := &InstallWizardData{ + BaseDir: "/tmp/base", + EnableSecondaryStorage: true, + SecondaryPath: "relative/path", + } + + _, err := ApplyInstallData("", data) + if err == nil { + t.Fatal("expected ApplyInstallData to fail") + } + if got, want := err.Error(), "SECONDARY_PATH must be an absolute local filesystem path"; got != want { + t.Fatalf("ApplyInstallData error = %q, want %q", got, want) + } +} + +func TestApplyInstallDataRejectsInvalidSecondaryLogPath(t *testing.T) { + data := &InstallWizardData{ + BaseDir: "/tmp/base", + EnableSecondaryStorage: true, + SecondaryPath: "/mnt/sec", + SecondaryLogPath: "remote:/logs", + } + + _, err := ApplyInstallData("", data) + if err == nil { + t.Fatal("expected ApplyInstallData to fail") + } + if got, want := err.Error(), "SECONDARY_LOG_PATH must be an absolute local filesystem path"; got != want { + t.Fatalf("ApplyInstallData error = %q, want %q", got, want) + } +} + +func TestValidateSecondaryInstallDataRejectsNilData(t *testing.T) { + err := validateSecondaryInstallData(nil) + if !errors.Is(err, ErrNilInstallData) { + t.Fatalf("validateSecondaryInstallData error = %v, want %v", err, ErrNilInstallData) + } +} + func TestApplyInstallDataCronAndNotifications(t *testing.T) { baseTemplate := "CRON_SCHEDULE=\nCRON_HOUR=\nCRON_MINUTE=\nTELEGRAM_ENABLED=true\nEMAIL_ENABLED=false\nENCRYPT_ARCHIVE=true\n" data := &InstallWizardData{ @@ -132,6 +240,40 @@ func TestApplyInstallDataCronAndNotifications(t *testing.T) { assertContains("ENCRYPT_ARCHIVE", "false") } +func TestRunInstallWizardBlankCronIgnoresEnvOverride(t *testing.T) { + t.Setenv("CRON_SCHEDULE", "5 1 * * *") + + originalRunner := runInstallWizardRunner + t.Cleanup(func() { runInstallWizardRunner = originalRunner }) + + runInstallWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + if ctx != t.Context() { + t.Fatalf("ctx=%p; want %p", ctx, t.Context()) + } + form, ok := focus.(*tview.Form) + if !ok { + t.Fatalf("focus primitive = %T, want *tview.Form", focus) + } + button := form.GetButton(0) + if button == nil { + t.Fatal("expected install button") + } + button.InputHandler()(tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone), nil) + return nil + } + + data, err := RunInstallWizard(t.Context(), "/tmp/proxsave/backup.env", "/opt/proxsave", "sig", "") + if err != nil { + t.Fatalf("RunInstallWizard returned error: %v", err) + } + if data == nil { + t.Fatal("expected wizard data") + } + if data.CronTime != cronutil.DefaultTime { + t.Fatalf("CronTime = %q, want %q", data.CronTime, cronutil.DefaultTime) + } +} + func TestCheckExistingConfigActions(t *testing.T) { tmp := t.TempDir() configPath := filepath.Join(tmp, "prox.env") @@ -149,19 +291,20 @@ func TestCheckExistingConfigActions(t *testing.T) { }{ {name: "overwrite", button: "Overwrite", want: ExistingConfigOverwrite}, {name: "edit existing", button: "Edit existing", want: ExistingConfigEdit}, - {name: "keep", button: "Keep & exit", want: ExistingConfigSkip}, + {name: "keep continue", button: "Keep & continue", want: ExistingConfigKeepContinue}, + {name: "cancel", button: "Cancel", want: ExistingConfigCancel}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - checkExistingConfigRunner = func(app *tui.App, root, focus tview.Primitive) error { + checkExistingConfigRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { done := extractModalDone(focus.(*tview.Modal)) done(0, tc.button) return nil } - action, err := CheckExistingConfig(configPath, "sig-abc") + action, err := CheckExistingConfig(context.Background(), configPath, "sig-abc") if err != nil { t.Fatalf("CheckExistingConfig returned error: %v", err) } @@ -174,7 +317,7 @@ func TestCheckExistingConfigActions(t *testing.T) { func TestCheckExistingConfigMissingFileDefaultsToOverwrite(t *testing.T) { configPath := filepath.Join(t.TempDir(), "absent.env") - action, err := CheckExistingConfig(configPath, "sig") + action, err := CheckExistingConfig(context.Background(), configPath, "sig") if err != nil { t.Fatalf("CheckExistingConfig returned error: %v", err) } @@ -185,11 +328,127 @@ func TestCheckExistingConfigMissingFileDefaultsToOverwrite(t *testing.T) { func TestCheckExistingConfigPropagatesStatErrors(t *testing.T) { pathWithNul := string([]byte{0}) - action, err := CheckExistingConfig(pathWithNul, "sig") + action, err := CheckExistingConfig(context.Background(), pathWithNul, "sig") if err == nil { t.Fatalf("expected error for invalid path") } - if action != ExistingConfigSkip { - t.Fatalf("expected skip action on stat error, got %v", action) + if action != ExistingConfigCancel { + t.Fatalf("expected cancel action on stat error, got %v", action) + } +} + +func TestCheckExistingConfigRejectsNonRegularPath(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config-dir") + if err := os.Mkdir(configPath, 0o755); err != nil { + t.Fatalf("failed to create directory: %v", err) + } + + action, err := CheckExistingConfig(context.Background(), configPath, "sig") + if err == nil { + t.Fatal("expected error for non-regular config path") + } + if err.Error() != "configuration file path is not a regular file: "+configPath { + t.Fatalf("unexpected error: %v", err) + } + if action != ExistingConfigCancel { + t.Fatalf("expected cancel action on non-regular path, got %v", action) + } +} + +func TestCheckExistingConfigDefaultsFocusToKeepContinue(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "prox.env") + if err := os.WriteFile(configPath, []byte("base"), 0o600); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + originalRunner := checkExistingConfigRunner + t.Cleanup(func() { checkExistingConfigRunner = originalRunner }) + + checkExistingConfigRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + modal, ok := focus.(*tview.Modal) + if !ok { + t.Fatalf("focus=%T; want *tview.Modal", focus) + } + + var form *tview.Form + modal.Focus(func(p tview.Primitive) { + var ok bool + form, ok = p.(*tview.Form) + if !ok { + t.Fatalf("delegate focus=%T; want *tview.Form", p) + } + }) + + formItem, button := form.GetFocusedItemIndex() + if formItem != -1 || button != 2 { + t.Fatalf("focused item=(%d,%d); want (-1,2)", formItem, button) + } + + done := extractModalDone(modal) + done(0, "Keep & continue") + return nil + } + + action, err := CheckExistingConfig(context.Background(), configPath, "sig-abc") + if err != nil { + t.Fatalf("CheckExistingConfig returned error: %v", err) + } + if action != ExistingConfigKeepContinue { + t.Fatalf("action=%v; want %v", action, ExistingConfigKeepContinue) + } +} + +func TestCheckExistingConfigPropagatesRunnerErrors(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "prox.env") + if err := os.WriteFile(configPath, []byte("base"), 0o600); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + originalRunner := checkExistingConfigRunner + t.Cleanup(func() { checkExistingConfigRunner = originalRunner }) + + expectedErr := errors.New("ui runner failure") + checkExistingConfigRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + return expectedErr + } + + action, err := CheckExistingConfig(context.Background(), configPath, "sig") + if !errors.Is(err, expectedErr) { + t.Fatalf("expected runner error %v, got %v", expectedErr, err) + } + if action != ExistingConfigCancel { + t.Fatalf("expected cancel action on runner error, got %v", action) + } +} + +func TestCheckExistingConfigPassesContextToRunner(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "prox.env") + if err := os.WriteFile(configPath, []byte("base"), 0o600); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + originalRunner := checkExistingConfigRunner + t.Cleanup(func() { checkExistingConfigRunner = originalRunner }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expectedErr := errors.New("ui runner failure") + checkExistingConfigRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("ctx=%p; want %p", gotCtx, ctx) + } + return expectedErr + } + + action, err := CheckExistingConfig(ctx, configPath, "sig") + if !errors.Is(err, expectedErr) { + t.Fatalf("expected runner error %v, got %v", expectedErr, err) + } + if action != ExistingConfigCancel { + t.Fatalf("expected cancel action on runner error, got %v", action) } } diff --git a/internal/tui/wizard/new_install.go b/internal/tui/wizard/new_install.go index e799db91..db7dab10 100644 --- a/internal/tui/wizard/new_install.go +++ b/internal/tui/wizard/new_install.go @@ -1,7 +1,10 @@ package wizard import ( + "context" "fmt" + "os" + "path/filepath" "strings" "github.com/gdamore/tcell/v2" @@ -10,48 +13,44 @@ import ( "github.com/tis24dev/proxsave/internal/tui" ) -var confirmNewInstallRunner = func(app *tui.App, root, focus tview.Primitive) error { - return app.SetRoot(root, true).SetFocus(focus).Run() +var confirmNewInstallRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) +} + +func formatPreservedEntries(baseDir string, entries []string) string { + formatted := make([]string, 0, len(entries)) + for _, entry := range entries { + trimmed := strings.TrimSpace(entry) + if trimmed == "" { + continue + } + if !strings.HasSuffix(trimmed, "/") { + resolved := filepath.Join(baseDir, trimmed) + if fi, err := os.Stat(resolved); err == nil && fi.IsDir() { + trimmed += "/" + } + } + formatted = append(formatted, trimmed) + } + if len(formatted) == 0 { + return "(none)" + } + return strings.Join(formatted, " ") } // ConfirmNewInstall shows a TUI confirmation before wiping baseDir for --new-install. -func ConfirmNewInstall(baseDir string, buildSig string) (bool, error) { +func ConfirmNewInstall(ctx context.Context, baseDir string, buildSig string, preservedEntries []string) (bool, error) { app := tui.NewApp() proceed := false - - // Header text (align with main install wizard) - welcomeText := tview.NewTextView(). - SetText("Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n" + - "This wizard will guide you through configuring your backup system for Proxmox.\n" + - "All settings can be changed later by editing the configuration file."). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - welcomeText.SetBorder(false) - - // Build signature line - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true) - buildSigText.SetBorder(false) - - // Navigation instructions - navInstructions := tview.NewTextView(). - SetText("[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - navInstructions.SetBorder(false) - - // Separator - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) + preservedText := formatPreservedEntries(baseDir, preservedEntries) + escapedBaseDir := tview.Escape(baseDir) + escapedPreservedText := tview.Escape(preservedText) // Confirmation modal modal := tview.NewModal(). - SetText(fmt.Sprintf("Base directory to reset:\n[yellow]%s[white]\n\nThis keeps [yellow]build/ env/ identity/[white]\nbut deletes everything else.\n\nContinue?", baseDir)). + SetText(fmt.Sprintf("Base directory to reset:\n[yellow]%s[white]\n\nThis keeps [yellow]%s[white]\nbut deletes everything else.\n\nContinue?", escapedBaseDir, escapedPreservedText)). AddButtons([]string{"Continue", "Cancel"}). SetDoneFunc(func(buttonIndex int, buttonLabel string) { if buttonLabel == "Continue" { @@ -67,24 +66,20 @@ func ConfirmNewInstall(baseDir string, buildSig string) (bool, error) { SetBorderColor(tui.WarningYellow). SetBackgroundColor(tcell.ColorBlack) - // Layout - flex := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(welcomeText, 5, 0, false). - AddItem(navInstructions, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(modal, 0, 1, true). - AddItem(buildSigText, 1, 0, false) - - flex.SetBorder(true). - SetTitle(" ProxSave New Install "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) + flex := buildWizardScreen( + "ProxSave New Install", + "Welcome to ProxSave Installation Wizard - By TIS24DEV\n\n"+ + "This wizard will guide you through configuring your backup system for Proxmox.\n"+ + "All settings can be changed later by editing the configuration file.", + "[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to open dropdowns | ←→ on buttons | ENTER to submit | Mouse clicks enabled", + "", + buildSig, + modal, + ) - // Run the app - ignore errors from normal app termination - _ = confirmNewInstallRunner(app, flex, modal) + if err := confirmNewInstallRunner(ctx, app, flex, modal); err != nil { + return false, err + } return proceed, nil } diff --git a/internal/tui/wizard/new_install_test.go b/internal/tui/wizard/new_install_test.go index ebe3b18c..359234e4 100644 --- a/internal/tui/wizard/new_install_test.go +++ b/internal/tui/wizard/new_install_test.go @@ -1,6 +1,10 @@ package wizard import ( + "context" + "errors" + "os" + "path/filepath" "strings" "testing" @@ -9,17 +13,120 @@ import ( "github.com/tis24dev/proxsave/internal/tui" ) -func TestConfirmNewInstallContinue(t *testing.T) { +func testPreservedEntries() []string { + return []string{"build", "env", "identity"} +} + +func registerConfirmNewInstallRunner(t *testing.T, runner func(context.Context, *tui.App, tview.Primitive, tview.Primitive) error) { + t.Helper() originalRunner := confirmNewInstallRunner - defer func() { confirmNewInstallRunner = originalRunner }() + confirmNewInstallRunner = runner + t.Cleanup(func() { + confirmNewInstallRunner = originalRunner + }) +} + +func wizardPrimitiveContainsText(p tview.Primitive, want string) bool { + switch v := p.(type) { + case *tview.TextView: + return strings.Contains(v.GetText(false), want) + case *tview.Flex: + for i := 0; i < v.GetItemCount(); i++ { + if wizardPrimitiveContainsText(v.GetItem(i), want) { + return true + } + } + } + return false +} - confirmNewInstallRunner = func(app *tui.App, root, focus tview.Primitive) error { +func withWorkingDir(t *testing.T, dir string) { + t.Helper() + original, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd failed: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%q) failed: %v", dir, err) + } + t.Cleanup(func() { + if err := os.Chdir(original); err != nil { + t.Errorf("restoring working directory to %q failed: %v", original, err) + } + }) +} + +func TestFormatPreservedEntries(t *testing.T) { + tempDir := t.TempDir() + if err := os.Mkdir(filepath.Join(tempDir, "build"), 0o755); err != nil { + t.Fatalf("Mkdir(build) failed: %v", err) + } + if err := os.Mkdir(filepath.Join(tempDir, "identity"), 0o755); err != nil { + t.Fatalf("Mkdir(identity) failed: %v", err) + } + if err := os.WriteFile(filepath.Join(tempDir, "backup.env"), []byte("TEST=1\n"), 0o644); err != nil { + t.Fatalf("WriteFile(backup.env) failed: %v", err) + } + withWorkingDir(t, tempDir) + + tests := []struct { + name string + entries []string + want string + }{ + { + name: "adds slash only for directories", + entries: []string{" build ", "backup.env", " missing ", " identity"}, + want: "build/ backup.env missing identity/", + }, + { + name: "returns none for nil input", + entries: nil, + want: "(none)", + }, + { + name: "returns none for blank entries", + entries: []string{"", " ", "\t"}, + want: "(none)", + }, + { + name: "preserves existing trailing slash without doubling", + entries: []string{"build/", " identity/ ", "backup.env"}, + want: "build/ identity/ backup.env", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatPreservedEntries(tempDir, tt.entries); got != tt.want { + t.Fatalf("formatPreservedEntries(%v) = %q, want %q", tt.entries, got, tt.want) + } + }) + } +} + +func TestFormatPreservedEntriesResolvesAgainstBaseDirNotCWD(t *testing.T) { + baseDir := t.TempDir() + if err := os.Mkdir(filepath.Join(baseDir, "build"), 0o755); err != nil { + t.Fatalf("Mkdir(build) failed: %v", err) + } + + otherDir := t.TempDir() + withWorkingDir(t, otherDir) + + if got := formatPreservedEntries(baseDir, []string{"build"}); got != "build/" { + t.Fatalf("formatPreservedEntries should resolve against baseDir: got %q, want %q", got, "build/") + } +} + +func TestConfirmNewInstallContinue(t *testing.T) { + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { done := extractModalDone(focus.(*tview.Modal)) done(0, "Continue") return nil - } + }) - proceed, err := ConfirmNewInstall("/opt/proxmox", "sig-123") + proceed, err := ConfirmNewInstall(context.Background(), "/opt/proxmox", "sig-123", testPreservedEntries()) if err != nil { t.Fatalf("ConfirmNewInstall error: %v", err) } @@ -29,16 +136,13 @@ func TestConfirmNewInstallContinue(t *testing.T) { } func TestConfirmNewInstallCancel(t *testing.T) { - originalRunner := confirmNewInstallRunner - defer func() { confirmNewInstallRunner = originalRunner }() - - confirmNewInstallRunner = func(app *tui.App, root, focus tview.Primitive) error { + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { done := extractModalDone(focus.(*tview.Modal)) done(1, "Cancel") return nil - } + }) - proceed, err := ConfirmNewInstall("/opt/proxmox", "sig-123") + proceed, err := ConfirmNewInstall(context.Background(), "/opt/proxmox", "sig-123", testPreservedEntries()) if err != nil { t.Fatalf("ConfirmNewInstall error: %v", err) } @@ -48,16 +152,13 @@ func TestConfirmNewInstallCancel(t *testing.T) { } func TestConfirmNewInstallMessageIncludesBaseDir(t *testing.T) { - originalRunner := confirmNewInstallRunner - defer func() { confirmNewInstallRunner = originalRunner }() - var captured string - confirmNewInstallRunner = func(app *tui.App, root, focus tview.Primitive) error { + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { captured = extractModalText(focus.(*tview.Modal)) return nil - } + }) - _, err := ConfirmNewInstall("/var/lib/data", "build-sig") + _, err := ConfirmNewInstall(context.Background(), "/var/lib/data", "build-sig", testPreservedEntries()) if err != nil { t.Fatalf("ConfirmNewInstall error: %v", err) } @@ -65,3 +166,124 @@ func TestConfirmNewInstallMessageIncludesBaseDir(t *testing.T) { t.Fatalf("expected modal text to mention base dir, got %q", captured) } } + +func TestConfirmNewInstallMessageIncludesPreservedEntries(t *testing.T) { + baseDir := t.TempDir() + for _, dir := range []string{"build", "env", "identity"} { + if err := os.Mkdir(filepath.Join(baseDir, dir), 0o755); err != nil { + t.Fatalf("Mkdir(%s) failed: %v", dir, err) + } + } + + var captured string + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + captured = extractModalText(focus.(*tview.Modal)) + return nil + }) + + _, err := ConfirmNewInstall(context.Background(), baseDir, "build-sig", testPreservedEntries()) + if err != nil { + t.Fatalf("ConfirmNewInstall error: %v", err) + } + if !strings.Contains(captured, "build/ env/ identity/") { + t.Fatalf("expected modal text to mention preserved entries, got %q", captured) + } +} + +func TestConfirmNewInstallMessageUsesNoneWhenEntriesAreBlank(t *testing.T) { + var captured string + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + captured = extractModalText(focus.(*tview.Modal)) + return nil + }) + + _, err := ConfirmNewInstall(context.Background(), "/var/lib/data", "build-sig", []string{"", " ", "\t"}) + if err != nil { + t.Fatalf("ConfirmNewInstall error: %v", err) + } + if !strings.Contains(captured, "(none)") { + t.Fatalf("expected modal text to mention (none), got %q", captured) + } +} + +func TestConfirmNewInstallMessageEscapesDynamicColorMarkup(t *testing.T) { + baseDir := "/var/lib/[prod]" + preservedEntries := []string{" build[0] ", " identity] "} + + var captured string + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + captured = extractModalText(focus.(*tview.Modal)) + return nil + }) + + _, err := ConfirmNewInstall(context.Background(), baseDir, "build-sig", preservedEntries) + if err != nil { + t.Fatalf("ConfirmNewInstall error: %v", err) + } + if !strings.Contains(captured, tview.Escape(baseDir)) { + t.Fatalf("expected escaped base dir in modal text, got %q", captured) + } + + wantPreserved := tview.Escape(formatPreservedEntries(baseDir, preservedEntries)) + if !strings.Contains(captured, wantPreserved) { + t.Fatalf("expected escaped preserved entries %q in modal text, got %q", wantPreserved, captured) + } +} + +func TestConfirmNewInstallPropagatesRunnerError(t *testing.T) { + expectedErr := errors.New("runner failed") + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + return expectedErr + }) + + _, err := ConfirmNewInstall(context.Background(), "/opt/proxmox", "sig-123", testPreservedEntries()) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected error %v, got %v", expectedErr, err) + } +} + +func TestConfirmNewInstallPassesContextToRunner(t *testing.T) { + ctx := t.Context() + registerConfirmNewInstallRunner(t, func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + done := extractModalDone(focus.(*tview.Modal)) + done(0, "Continue") + return nil + }) + + proceed, err := ConfirmNewInstall(ctx, "/opt/proxmox", "sig-123", testPreservedEntries()) + if err != nil { + t.Fatalf("ConfirmNewInstall error: %v", err) + } + if !proceed { + t.Fatalf("expected proceed=true when Continue is selected") + } +} + +func TestConfirmNewInstallBuildsWizardScreenWithEscapedBuildSignature(t *testing.T) { + buildSig := "sig-[123]" + + var root tview.Primitive + var focus tview.Primitive + registerConfirmNewInstallRunner(t, func(ctx context.Context, app *tui.App, gotRoot, gotFocus tview.Primitive) error { + root = gotRoot + focus = gotFocus + return nil + }) + + _, err := ConfirmNewInstall(context.Background(), "/opt/proxmox", buildSig, testPreservedEntries()) + if err != nil { + t.Fatalf("ConfirmNewInstall error: %v", err) + } + if root == nil { + t.Fatalf("expected wizard root to be passed to runner") + } + if _, ok := focus.(*tview.Modal); !ok { + t.Fatalf("expected modal focus, got %T", focus) + } + if !wizardPrimitiveContainsText(root, tview.Escape(buildSig)) { + t.Fatalf("expected root screen to include escaped build signature %q", tview.Escape(buildSig)) + } +} diff --git a/internal/tui/wizard/post_install_audit_tui.go b/internal/tui/wizard/post_install_audit_tui.go index b46587d7..e81105fa 100644 --- a/internal/tui/wizard/post_install_audit_tui.go +++ b/internal/tui/wizard/post_install_audit_tui.go @@ -16,8 +16,10 @@ import ( ) var ( - postInstallAuditWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - return app.SetRoot(root, true).SetFocus(focus).Run() + postInstallAuditWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) } ) @@ -42,40 +44,6 @@ type PostInstallAuditResult struct { func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildSig string) (result PostInstallAuditResult, err error) { app := tui.NewApp() - titleText := tview.NewTextView(). - SetText("ProxSave - Post-install Check\n\n" + - "Detect optional components that are enabled but not configured on this node.\n" + - "This helps reduce WARNING noise and exit code 1 runs when features are unused.\n"). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - titleText.SetBorder(false) - - nav := tview.NewTextView(). - SetText("[yellow]Navigation:[white] ↑↓ to move | ENTER/SPACE to toggle | ←→ on buttons | ENTER to select"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - nav.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - pages := tview.NewPages() confirmRun := false @@ -137,23 +105,18 @@ func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildS pages.AddPage("confirm", confirm, true, true) pages.AddPage("running", running, true, false) - layout := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(titleText, 5, 0, false). - AddItem(nav, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(pages, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - layout.SetBorder(true). - SetTitle(" ProxSave "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) - - if runErr := postInstallAuditWizardRunner(app, layout, confirm); runErr != nil { + layout := buildWizardScreen( + "ProxSave", + "ProxSave - Post-install Check\n\n"+ + "Detect optional components that are enabled but not configured on this node.\n"+ + "This helps reduce WARNING noise and exit code 1 runs when features are unused.\n", + "[yellow]Navigation:[white] ↑↓ to move | ENTER/SPACE to toggle | ←→ on buttons | ENTER to select", + configPath, + buildSig, + pages, + ) + + if runErr := postInstallAuditWizardRunner(ctx, app, layout, confirm); runErr != nil { return PostInstallAuditResult{}, runErr } @@ -228,11 +191,11 @@ func showAuditReview(app *tui.App, pages *tview.Pages, configPath string, sugges b.WriteString("[yellow]Detected warnings:[white]\n\n") for _, msg := range s.Messages { b.WriteString("- ") - b.WriteString(msg) + b.WriteString(tview.Escape(msg)) b.WriteString("\n") } b.WriteString("\n") - b.WriteString(fmt.Sprintf("If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", s.Key)) + b.WriteString(fmt.Sprintf("If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", tview.Escape(s.Key))) details.SetText(b.String()) } diff --git a/internal/tui/wizard/post_install_audit_tui_test.go b/internal/tui/wizard/post_install_audit_tui_test.go new file mode 100644 index 00000000..4fd992cb --- /dev/null +++ b/internal/tui/wizard/post_install_audit_tui_test.go @@ -0,0 +1,33 @@ +package wizard + +import ( + "context" + "testing" + + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/tui" +) + +func TestRunPostInstallAuditWizard_PassesContextToRunner(t *testing.T) { + origRunner := postInstallAuditWizardRunner + t.Cleanup(func() { + postInstallAuditWizardRunner = origRunner + }) + + ctx := t.Context() + postInstallAuditWizardRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + return nil + } + + result, err := RunPostInstallAuditWizard(ctx, "/tmp/proxsave", "/tmp/backup.env", "sig") + if err != nil { + t.Fatalf("RunPostInstallAuditWizard error: %v", err) + } + if result.Ran { + t.Fatalf("expected Ran=false when runner exits without selecting an action") + } +} diff --git a/internal/tui/wizard/screen.go b/internal/tui/wizard/screen.go new file mode 100644 index 00000000..59f9120b --- /dev/null +++ b/internal/tui/wizard/screen.go @@ -0,0 +1,21 @@ +package wizard + +import ( + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/tui" +) + +func buildWizardScreen(title, headerText, navText, configPath, buildSig string, content tview.Primitive) tview.Primitive { + return tui.BuildScreen(tui.ScreenSpec{ + Title: title, + HeaderText: headerText, + NavText: navText, + ConfigPath: configPath, + BuildSig: buildSig, + TitleColor: tui.ProxmoxOrange, + BorderColor: tui.ProxmoxOrange, + BackgroundColor: tcell.ColorBlack, + }, content) +} diff --git a/internal/tui/wizard/telegram_setup_tui.go b/internal/tui/wizard/telegram_setup_tui.go index 90e20980..4e8fcc62 100644 --- a/internal/tui/wizard/telegram_setup_tui.go +++ b/internal/tui/wizard/telegram_setup_tui.go @@ -3,47 +3,34 @@ package wizard import ( "context" "fmt" - "os" "strings" "sync" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" - "github.com/tis24dev/proxsave/internal/config" - "github.com/tis24dev/proxsave/internal/identity" "github.com/tis24dev/proxsave/internal/notify" + "github.com/tis24dev/proxsave/internal/orchestrator" "github.com/tis24dev/proxsave/internal/tui" ) var ( - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - return app.SetRoot(root, true).SetFocus(focus).Run() + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + app.SetRoot(root, true) + app.SetFocus(focus) + return app.RunWithContext(ctx) } - telegramSetupLoadConfig = config.LoadConfig - telegramSetupReadFile = os.ReadFile - telegramSetupStat = os.Stat - telegramSetupIdentityDetect = identity.Detect - telegramSetupCheckRegistration = notify.CheckTelegramRegistration - telegramSetupQueueUpdateDraw = func(app *tui.App, f func()) { app.QueueUpdateDraw(f) } - telegramSetupGo = func(fn func()) { go fn() } + telegramSetupBuildBootstrap = orchestrator.BuildTelegramSetupBootstrap + telegramSetupCheckRegistration = notify.CheckTelegramRegistration + telegramSetupQueueUpdateDraw = func(app *tui.App, f func()) { app.QueueUpdateDraw(f) } + telegramSetupGo = func(fn func()) { go fn() } ) type TelegramSetupResult struct { - Shown bool - - ConfigLoaded bool - ConfigError string - - TelegramEnabled bool - TelegramMode string - ServerAPIHost string + orchestrator.TelegramSetupBootstrap - ServerID string - IdentityFile string - IdentityPersisted bool - IdentityDetectError string + Shown bool CheckAttempts int Verified bool @@ -55,89 +42,22 @@ type TelegramSetupResult struct { } func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig string) (TelegramSetupResult, error) { - result := TelegramSetupResult{Shown: true} - - cfg, cfgErr := telegramSetupLoadConfig(configPath) - if cfgErr != nil { - result.ConfigLoaded = false - result.ConfigError = cfgErr.Error() - // Fall back to raw env parsing so the wizard can still run even when the full - // config parser fails for unrelated keys. - if configBytes, readErr := telegramSetupReadFile(configPath); readErr == nil { - values := parseEnvTemplate(string(configBytes)) - result.TelegramEnabled = readTemplateBool(values, "TELEGRAM_ENABLED") - result.TelegramMode = strings.ToLower(strings.TrimSpace(readTemplateString(values, "BOT_TELEGRAM_TYPE"))) - } - } else { - result.ConfigLoaded = true - result.TelegramEnabled = cfg.TelegramEnabled - result.TelegramMode = strings.ToLower(strings.TrimSpace(cfg.TelegramBotType)) - result.ServerAPIHost = strings.TrimSpace(cfg.TelegramServerAPIHost) + state, err := telegramSetupBuildBootstrap(configPath, baseDir) + if err != nil { + return TelegramSetupResult{}, err } - - if !result.TelegramEnabled { + result := TelegramSetupResult{ + TelegramSetupBootstrap: state, + Shown: true, + } + if result.Eligibility != orchestrator.TelegramSetupEligibleCentralized { result.Shown = false return result, nil } - if result.TelegramMode == "" { - result.TelegramMode = "centralized" - } - if result.ServerAPIHost == "" { - // Fallback (keeps behavior aligned with internal/config defaults). - result.ServerAPIHost = "https://bot.tis24.it:1443" - } - - idInfo, idErr := telegramSetupIdentityDetect(baseDir, nil) - if idErr != nil { - result.IdentityDetectError = idErr.Error() - } - if idInfo != nil { - result.ServerID = strings.TrimSpace(idInfo.ServerID) - result.IdentityFile = strings.TrimSpace(idInfo.IdentityFile) - if result.IdentityFile != "" { - if _, err := telegramSetupStat(result.IdentityFile); err == nil { - result.IdentityPersisted = true - } - } - } app := tui.NewApp() pages := tview.NewPages() - titleText := tview.NewTextView(). - SetText("ProxSave - Telegram Setup\n\n" + - "Telegram notifications are enabled.\n" + - "Complete the bot pairing now to avoid warning noise and skipped notifications.\n"). - SetTextColor(tui.ProxmoxLight). - SetDynamicColors(true) - titleText.SetBorder(false) - - nav := tview.NewTextView(). - SetText("[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to select | ESC to exit"). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - nav.SetBorder(false) - - separator := tview.NewTextView(). - SetText(strings.Repeat("─", 80)). - SetTextColor(tui.ProxmoxOrange) - separator.SetBorder(false) - - configPathText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - configPathText.SetBorder(false) - - buildSigText := tview.NewTextView(). - SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). - SetTextColor(tcell.ColorWhite). - SetDynamicColors(true). - SetTextAlign(tview.AlignCenter) - buildSigText.SetBorder(false) - instructions := tview.NewTextView(). SetDynamicColors(true). SetWrap(true) @@ -173,43 +93,26 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s return s[:max] + "...(truncated)" } - modeLabel := result.TelegramMode - if modeLabel == "" { - modeLabel = "centralized" - } - var b strings.Builder - b.WriteString(fmt.Sprintf("[yellow]Mode:[white] %s\n", modeLabel)) - if !result.ConfigLoaded && result.ConfigError != "" { - b.WriteString(fmt.Sprintf("[red]WARNING:[white] failed to load config: %s\n\n", truncate(result.ConfigError, 200))) - } - if result.TelegramMode == "personal" { - b.WriteString("\nPersonal mode uses your own bot.\n\n") - b.WriteString("This installer does not guide the personal bot setup.\n") - b.WriteString("Edit backup.env and set:\n") - b.WriteString(" - TELEGRAM_BOT_TOKEN\n") - b.WriteString(" - TELEGRAM_CHAT_ID\n\n") - b.WriteString("Then run ProxSave once to validate notifications.\n") - } else { - b.WriteString("\n1) Open Telegram and start [yellow]@ProxmoxAN_bot[white]\n") - b.WriteString("2) Send the [yellow]Server ID[white] below (digits only)\n") - b.WriteString("3) Press [yellow]Check[white] to verify\n\n") - b.WriteString("If the check fails, you can press Check again.\n") - b.WriteString("You can also Skip verification and complete pairing later.\n") - } + b.WriteString("[yellow]Mode:[white] centralized\n") + b.WriteString("\n1) Open Telegram and start [yellow]@ProxmoxAN_bot[white]\n") + b.WriteString("2) Send the [yellow]Server ID[white] below (digits only)\n") + b.WriteString("3) Press [yellow]Check[white] to verify\n\n") + b.WriteString("If the check fails, you can press Check again.\n") + b.WriteString("You can also Skip verification and complete pairing later.\n") instructions.SetText(b.String()) - serverIDLine := "[red]Server ID unavailable.[white]" - if result.ServerID != "" { - serverIDLine = fmt.Sprintf("[yellow]%s[white]", result.ServerID) - } + escapedServerID := tview.Escape(result.ServerID) + serverIDLine := fmt.Sprintf("[yellow]%s[white]", escapedServerID) identityLine := "" if result.IdentityFile != "" { persisted := "not persisted" if result.IdentityPersisted { persisted = "persisted" } - identityLine = fmt.Sprintf("\n[gray]Identity file:[white] %s ([yellow]%s[white])", result.IdentityFile, persisted) + escapedIdentityFile := tview.Escape(result.IdentityFile) + escapedPersisted := tview.Escape(persisted) + identityLine = fmt.Sprintf("\n[gray]Identity file:[white] %s ([yellow]%s[white])", escapedIdentityFile, escapedPersisted) } serverIDView.SetText(serverIDLine + identityLine) @@ -217,17 +120,7 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s statusView.SetText(text) } - initialStatus := "[yellow]Not checked yet.[white]\n\nPress [yellow]Check[white] after sending the Server ID to the bot." - if result.TelegramMode == "personal" { - initialStatus = "[yellow]No centralized pairing check for personal mode.[white]" - } - if result.ServerID == "" && result.TelegramMode != "personal" { - initialStatus = "[red]Cannot check registration: Server ID missing.[white]" - if result.IdentityDetectError != "" { - initialStatus += "\n\n" + truncate(result.IdentityDetectError, 200) - } - } - setStatus(initialStatus) + setStatus("[yellow]Not checked yet.[white]\n\nPress [yellow]Check[white] after sending the Server ID to the bot.") var mu sync.Mutex checking := false @@ -256,10 +149,6 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s var refreshButtons func() checkHandler := func() { - if result.TelegramMode == "personal" || strings.TrimSpace(result.ServerID) == "" { - return - } - mu.Lock() if checking || closing { mu.Unlock() @@ -291,7 +180,7 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s if status.Code == 200 && status.Error == nil { result.Verified = true - setStatus(fmt.Sprintf("[green]✓ Linked successfully.[white]\n\n%s", status.Message)) + setStatus(fmt.Sprintf("[green]✓ Linked successfully.[white]\n\n%s", tview.Escape(status.Message))) if refreshButtons != nil { refreshButtons() } @@ -311,30 +200,20 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s default: hint = "\n\nYou can press Check again, or Skip verification and complete pairing later." } - setStatus(fmt.Sprintf("[yellow]%s[white]%s", truncate(msg, 300), hint)) + setStatus(fmt.Sprintf("[yellow]%s[white]%s", tview.Escape(truncate(msg, 300)), hint)) }) }) } refreshButtons = func() { form.ClearButtons() - - // Centralized mode pairing only works when the Server ID is available. - if result.TelegramMode != "personal" && strings.TrimSpace(result.ServerID) != "" { - form.AddButton("Check", checkHandler) - } - - switch { - case result.TelegramMode == "personal": - form.AddButton("Continue", func() { doClose(false) }) - case strings.TrimSpace(result.ServerID) == "": + form.AddButton("Check", checkHandler) + if result.Verified { form.AddButton("Continue", func() { doClose(false) }) - case result.Verified: - form.AddButton("Continue", func() { doClose(false) }) - default: - // Until verification succeeds, require an explicit skip to leave without pairing. - form.AddButton("Skip", func() { doClose(true) }) + return } + // Until verification succeeds, require an explicit skip to leave without pairing. + form.AddButton("Skip", func() { doClose(true) }) } refreshButtons() @@ -354,25 +233,23 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s pages.AddPage("main", body, true, true) - layout := tview.NewFlex(). - SetDirection(tview.FlexRow). - AddItem(titleText, 5, 0, false). - AddItem(nav, 2, 0, false). - AddItem(separator, 1, 0, false). - AddItem(pages, 0, 1, true). - AddItem(configPathText, 1, 0, false). - AddItem(buildSigText, 1, 0, false) - - layout.SetBorder(true). - SetTitle(" ProxSave "). - SetTitleAlign(tview.AlignCenter). - SetTitleColor(tui.ProxmoxOrange). - SetBorderColor(tui.ProxmoxOrange). - SetBackgroundColor(tcell.ColorBlack) + layout := buildWizardScreen( + "ProxSave", + "ProxSave - Telegram Setup\n\n"+ + "Telegram notifications are enabled.\n"+ + "Complete the bot pairing now to avoid warning noise and skipped notifications.\n", + "[yellow]Navigation:[white] TAB/↑↓ to move | ENTER to select | ESC to exit", + configPath, + buildSig, + pages, + ) app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEscape { - if result.TelegramMode != "personal" && strings.TrimSpace(result.ServerID) != "" && !result.Verified { + mu.Lock() + verified := result.Verified + mu.Unlock() + if !verified { doClose(true) } else { doClose(false) @@ -382,7 +259,7 @@ func RunTelegramSetupWizard(ctx context.Context, baseDir, configPath, buildSig s return event }) - if runErr := telegramSetupWizardRunner(app, layout, form); runErr != nil { + if runErr := telegramSetupWizardRunner(ctx, app, layout, form); runErr != nil { return TelegramSetupResult{}, runErr } diff --git a/internal/tui/wizard/telegram_setup_tui_test.go b/internal/tui/wizard/telegram_setup_tui_test.go index 19e7a8f4..399cd18f 100644 --- a/internal/tui/wizard/telegram_setup_tui_test.go +++ b/internal/tui/wizard/telegram_setup_tui_test.go @@ -3,18 +3,16 @@ package wizard import ( "context" "errors" - "os" - "path/filepath" "strings" "testing" + "time" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" - "github.com/tis24dev/proxsave/internal/config" - "github.com/tis24dev/proxsave/internal/identity" "github.com/tis24dev/proxsave/internal/logging" "github.com/tis24dev/proxsave/internal/notify" + "github.com/tis24dev/proxsave/internal/orchestrator" "github.com/tis24dev/proxsave/internal/tui" ) @@ -22,20 +20,14 @@ func stubTelegramSetupDeps(t *testing.T) { t.Helper() origRunner := telegramSetupWizardRunner - origLoadConfig := telegramSetupLoadConfig - origReadFile := telegramSetupReadFile - origStat := telegramSetupStat - origIdentityDetect := telegramSetupIdentityDetect + origBuildBootstrap := telegramSetupBuildBootstrap origCheckRegistration := telegramSetupCheckRegistration origQueueUpdateDraw := telegramSetupQueueUpdateDraw origGo := telegramSetupGo t.Cleanup(func() { telegramSetupWizardRunner = origRunner - telegramSetupLoadConfig = origLoadConfig - telegramSetupReadFile = origReadFile - telegramSetupStat = origStat - telegramSetupIdentityDetect = origIdentityDetect + telegramSetupBuildBootstrap = origBuildBootstrap telegramSetupCheckRegistration = origCheckRegistration telegramSetupQueueUpdateDraw = origQueueUpdateDraw telegramSetupGo = origGo @@ -45,17 +37,90 @@ func stubTelegramSetupDeps(t *testing.T) { telegramSetupQueueUpdateDraw = func(app *tui.App, f func()) { f() } } +func eligibleTelegramSetupBootstrap() orchestrator.TelegramSetupBootstrap { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupEligibleCentralized, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + ServerID: "123456789", + IdentityFile: "/tmp/.server_identity", + IdentityPersisted: false, + } +} + +func extractTelegramSetupViews(t *testing.T, root tview.Primitive) (*tview.TextView, *tview.TextView, *tview.Form) { + t.Helper() + + layout, ok := root.(*tview.Flex) + if !ok { + t.Fatalf("expected root *tview.Flex, got %T", root) + } + + var pages *tview.Pages + for i := 0; i < layout.GetItemCount(); i++ { + candidate, ok := layout.GetItem(i).(*tview.Pages) + if !ok { + continue + } + if pages != nil { + t.Fatal("expected a single pages container in telegram setup layout") + } + pages = candidate + } + if pages == nil { + t.Fatal("expected pages container in telegram setup layout") + } + + _, bodyPrimitive := pages.GetFrontPage() + body, ok := bodyPrimitive.(*tview.Flex) + if !ok { + t.Fatalf("expected body *tview.Flex, got %T", bodyPrimitive) + } + + var serverIDView, statusView *tview.TextView + var form *tview.Form + for i := 0; i < body.GetItemCount(); i++ { + switch item := body.GetItem(i).(type) { + case *tview.TextView: + switch strings.TrimSpace(item.GetTitle()) { + case "Server ID": + serverIDView = item + case "Status": + statusView = item + } + case *tview.Form: + if strings.TrimSpace(item.GetTitle()) == "Actions" { + form = item + } + } + } + + if serverIDView == nil { + t.Fatal("expected Server ID view in telegram setup body") + } + if statusView == nil { + t.Fatal("expected Status view in telegram setup body") + } + if form == nil { + t.Fatal("expected Actions form in telegram setup body") + } + + return serverIDView, statusView, form +} + func TestRunTelegramSetupWizard_DisabledSkipsUIAndRunnerNotCalled(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{TelegramEnabled: false}, nil - } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - t.Fatalf("identity detect should not be called when telegram is disabled") - return nil, nil + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipDisabled, + ConfigLoaded: true, + TelegramEnabled: false, + }, nil } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { t.Fatalf("runner should not be called when telegram is disabled") return nil } @@ -75,17 +140,17 @@ func TestRunTelegramSetupWizard_DisabledSkipsUIAndRunnerNotCalled(t *testing.T) } } -func TestRunTelegramSetupWizard_ConfigLoadAndReadFailSkipsUI(t *testing.T) { +func TestRunTelegramSetupWizard_ConfigErrorSkipsUI(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return nil, errors.New("parse failed") - } - telegramSetupReadFile = func(path string) ([]byte, error) { - return nil, errors.New("read failed") + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipConfigError, + ConfigError: "parse failed", + }, nil } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - t.Fatalf("runner should not be called when env cannot be read") + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatalf("runner should not be called when config bootstrap failed") return nil } @@ -104,33 +169,53 @@ func TestRunTelegramSetupWizard_ConfigLoadAndReadFailSkipsUI(t *testing.T) { } } -func TestRunTelegramSetupWizard_FallbackPersonalMode_Continue(t *testing.T) { +func TestRunTelegramSetupWizard_PersonalModeSkipsUI(t *testing.T) { stubTelegramSetupDeps(t) - identityFile := filepath.Join(t.TempDir(), ".server_identity") - if err := os.WriteFile(identityFile, []byte("id"), 0o600); err != nil { - t.Fatalf("write identity file: %v", err) + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipPersonalMode, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "personal", + ServerAPIHost: "https://bot.tis24.it:1443", + }, nil + } + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatalf("runner should not be called in personal mode") + return nil } - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return nil, errors.New(strings.Repeat("x", 250)) + result, err := RunTelegramSetupWizard(context.Background(), t.TempDir(), "/fake/backup.env", "sig") + if err != nil { + t.Fatalf("RunTelegramSetupWizard error: %v", err) } - telegramSetupReadFile = func(path string) ([]byte, error) { - return []byte("TELEGRAM_ENABLED=true\nBOT_TELEGRAM_TYPE=Personal\n"), nil + if result.Shown { + t.Fatalf("expected wizard to not be shown") } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: " 123 ", IdentityFile: " " + identityFile + " "}, nil + if result.TelegramMode != "personal" { + t.Fatalf("TelegramMode=%q, want personal", result.TelegramMode) } - telegramSetupStat = os.Stat - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - form := focus.(*tview.Form) - if form.GetButtonIndex("Check") != -1 { - t.Fatalf("expected no Check button in personal mode") - } - if form.GetButtonIndex("Continue") == -1 { - t.Fatalf("expected Continue button in personal mode") - } - pressFormButton(t, form, "Continue") + if !result.TelegramEnabled { + t.Fatalf("expected TelegramEnabled=true") + } +} + +func TestRunTelegramSetupWizard_IdentityUnavailableSkipsUI(t *testing.T) { + stubTelegramSetupDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{ + Eligibility: orchestrator.TelegramSetupSkipIdentityUnavailable, + ConfigLoaded: true, + TelegramEnabled: true, + TelegramMode: "centralized", + ServerAPIHost: "https://api.example.test", + IdentityDetectError: "detect failed", + }, nil + } + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatalf("runner should not be called when server ID is unavailable") return nil } @@ -138,56 +223,70 @@ func TestRunTelegramSetupWizard_FallbackPersonalMode_Continue(t *testing.T) { if err != nil { t.Fatalf("RunTelegramSetupWizard error: %v", err) } - if !result.Shown { - t.Fatalf("expected wizard to be shown") + if result.Shown { + t.Fatalf("expected wizard to not be shown") } - if result.ConfigLoaded { - t.Fatalf("expected ConfigLoaded=false for fallback mode") + if result.IdentityDetectError == "" { + t.Fatalf("expected IdentityDetectError to be set") } - if result.ConfigError == "" { - t.Fatalf("expected ConfigError to be set") + if result.ServerID != "" { + t.Fatalf("ServerID=%q, want empty", result.ServerID) } - if !result.TelegramEnabled { - t.Fatalf("expected TelegramEnabled=true") +} + +func TestRunTelegramSetupWizard_PropagatesBootstrapError(t *testing.T) { + stubTelegramSetupDeps(t) + + expectedErr := errors.New("bootstrap failed") + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return orchestrator.TelegramSetupBootstrap{}, expectedErr } - if result.TelegramMode != "personal" { - t.Fatalf("TelegramMode=%q, want personal", result.TelegramMode) + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + t.Fatalf("runner should not be called when bootstrap returns an error") + return nil } - if result.ServerAPIHost != "https://bot.tis24.it:1443" { - t.Fatalf("ServerAPIHost=%q, want default", result.ServerAPIHost) + + result, err := RunTelegramSetupWizard(context.Background(), t.TempDir(), "/fake/backup.env", "sig") + if !errors.Is(err, expectedErr) { + t.Fatalf("expected error %v, got %v", expectedErr, err) } - if result.ServerID != "123" { - t.Fatalf("ServerID=%q, want 123", result.ServerID) + if result != (TelegramSetupResult{}) { + t.Fatalf("expected empty result on bootstrap error, got %#v", result) } - if !result.IdentityPersisted { - t.Fatalf("expected IdentityPersisted=true") +} + +func TestRunTelegramSetupWizard_PassesContextToRunner(t *testing.T) { + stubTelegramSetupDeps(t) + + ctx := t.Context() + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil } - if result.Verified { - t.Fatalf("expected Verified=false") + telegramSetupWizardRunner = func(gotCtx context.Context, app *tui.App, root, focus tview.Primitive) error { + if gotCtx != ctx { + t.Fatalf("got context %p, want %p", gotCtx, ctx) + } + form := focus.(*tview.Form) + pressFormButton(t, form, "Skip") + return nil } - if result.SkippedVerification { - t.Fatalf("expected SkippedVerification=false") + + result, err := RunTelegramSetupWizard(ctx, t.TempDir(), "/fake/backup.env", "sig") + if err != nil { + t.Fatalf("RunTelegramSetupWizard error: %v", err) } - if result.CheckAttempts != 0 { - t.Fatalf("CheckAttempts=%d, want 0", result.CheckAttempts) + if !result.SkippedVerification { + t.Fatalf("expected SkippedVerification=true") } } func TestRunTelegramSetupWizard_CentralizedSuccess_RequiresCheckBeforeContinue(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: " ", - TelegramServerAPIHost: " https://api.example.test ", - }, nil - } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: " 987654321 ", IdentityFile: " /missing "}, nil - } - telegramSetupStat = func(path string) (os.FileInfo, error) { - return nil, os.ErrNotExist + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + state := eligibleTelegramSetupBootstrap() + state.ServerID = "987654321" + return state, nil } telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { if serverAPIHost != "https://api.example.test" { @@ -198,7 +297,7 @@ func TestRunTelegramSetupWizard_CentralizedSuccess_RequiresCheckBeforeContinue(t } return notify.TelegramRegistrationStatus{Code: 200, Message: "ok"} } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { form := focus.(*tview.Form) if form.GetButtonIndex("Continue") != -1 { t.Fatalf("expected no Continue button before verification") @@ -255,31 +354,102 @@ func TestRunTelegramSetupWizard_CentralizedSuccess_RequiresCheckBeforeContinue(t } } +func TestRunTelegramSetupWizard_ShowsPersistedIdentityState(t *testing.T) { + stubTelegramSetupDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + state := eligibleTelegramSetupBootstrap() + state.IdentityPersisted = true + return state, nil + } + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + serverIDView, _, form := extractTelegramSetupViews(t, root) + text := serverIDView.GetText(true) + if !strings.Contains(text, "persisted") { + t.Fatalf("expected persisted identity state, got %q", text) + } + if strings.Contains(text, "not persisted") { + t.Fatalf("did not expect non-persisted label, got %q", text) + } + + pressFormButton(t, form, "Skip") + return nil + } + + result, err := RunTelegramSetupWizard(context.Background(), t.TempDir(), "/fake/backup.env", "sig") + if err != nil { + t.Fatalf("RunTelegramSetupWizard error: %v", err) + } + if !result.IdentityPersisted { + t.Fatalf("expected IdentityPersisted=true") + } + if !result.SkippedVerification { + t.Fatalf("expected SkippedVerification=true") + } +} + +func TestRunTelegramSetupWizard_EscapesBracketedServerIdentityValues(t *testing.T) { + stubTelegramSetupDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + state := eligibleTelegramSetupBootstrap() + state.ServerID = "srv[42]" + state.IdentityFile = "/tmp/identity[prod].key" + return state, nil + } + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + serverIDView, _, form := extractTelegramSetupViews(t, root) + + rawText := serverIDView.GetText(false) + if !strings.Contains(rawText, tview.Escape("srv[42]")) { + t.Fatalf("expected escaped server ID in raw text, got %q", rawText) + } + if !strings.Contains(rawText, tview.Escape("/tmp/identity[prod].key")) { + t.Fatalf("expected escaped identity file in raw text, got %q", rawText) + } + + plainText := serverIDView.GetText(true) + if !strings.Contains(plainText, "srv[42]") { + t.Fatalf("expected literal server ID in plain text, got %q", plainText) + } + if !strings.Contains(plainText, "/tmp/identity[prod].key") { + t.Fatalf("expected literal identity file in plain text, got %q", plainText) + } + + pressFormButton(t, form, "Skip") + return nil + } + + result, err := RunTelegramSetupWizard(context.Background(), t.TempDir(), "/fake/backup.env", "sig") + if err != nil { + t.Fatalf("RunTelegramSetupWizard error: %v", err) + } + if !result.SkippedVerification { + t.Fatalf("expected SkippedVerification=true") + } +} + func TestRunTelegramSetupWizard_CentralizedFailure_CanRetryAndSkip(t *testing.T) { stubTelegramSetupDeps(t) var calls int - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil - } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: "111222333"}, nil + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + state := eligibleTelegramSetupBootstrap() + state.ServerID = "111222333" + return state, nil } telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { calls++ - if calls == 1 { + switch calls { + case 1: return notify.TelegramRegistrationStatus{Code: 403, Error: errors.New("not registered")} - } - if calls == 2 { + case 2: return notify.TelegramRegistrationStatus{Code: 422, Message: "invalid"} + default: + return notify.TelegramRegistrationStatus{Code: 500, Message: "oops"} } - return notify.TelegramRegistrationStatus{Code: 500, Message: "oops"} } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { form := focus.(*tview.Form) pressFormButton(t, form, "Check") pressFormButton(t, form, "Check") @@ -315,36 +485,36 @@ func TestRunTelegramSetupWizard_CentralizedFailure_CanRetryAndSkip(t *testing.T) } } -func TestRunTelegramSetupWizard_CentralizedMissingServerID_ExitsOnEscWithoutSkipping(t *testing.T) { +func TestRunTelegramSetupWizard_TruncatesLongFailureMessage(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil + longMessage := strings.Repeat("x", 320) + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return nil, errors.New("detect failed") + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + return notify.TelegramRegistrationStatus{ + Code: 500, + Message: " " + longMessage + " ", + } } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - form := focus.(*tview.Form) - if form.GetButtonIndex("Check") != -1 { - t.Fatalf("expected no Check button without Server ID") + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + _, statusView, form := extractTelegramSetupViews(t, root) + + pressFormButton(t, form, "Check") + + text := statusView.GetText(true) + if !strings.Contains(text, "...(truncated)") { + t.Fatalf("expected truncated status, got %q", text) } - if form.GetButtonIndex("Skip") != -1 { - t.Fatalf("expected no Skip button without Server ID") + if !strings.Contains(text, "Skip verification and complete pairing later.") { + t.Fatalf("expected retry/skip hint, got %q", text) } - if form.GetButtonIndex("Continue") == -1 { - t.Fatalf("expected Continue button without Server ID") + if strings.Contains(text, longMessage) { + t.Fatalf("expected long message to be truncated, got %q", text) } - capture := app.GetInputCapture() - if capture == nil { - t.Fatalf("expected input capture to be set") - } - capture(tcell.NewEventKey(tcell.KeyEscape, 0, tcell.ModNone)) + pressFormButton(t, form, "Skip") return nil } @@ -352,39 +522,49 @@ func TestRunTelegramSetupWizard_CentralizedMissingServerID_ExitsOnEscWithoutSkip if err != nil { t.Fatalf("RunTelegramSetupWizard error: %v", err) } - if result.SkippedVerification { - t.Fatalf("expected SkippedVerification=false") - } if result.Verified { t.Fatalf("expected Verified=false") } - if result.CheckAttempts != 0 { - t.Fatalf("CheckAttempts=%d, want 0", result.CheckAttempts) + if result.CheckAttempts != 1 { + t.Fatalf("CheckAttempts=%d, want 1", result.CheckAttempts) } - if result.ServerID != "" { - t.Fatalf("ServerID=%q, want empty", result.ServerID) + if result.LastStatusCode != 500 { + t.Fatalf("LastStatusCode=%d, want 500", result.LastStatusCode) } - if result.IdentityDetectError == "" { - t.Fatalf("expected IdentityDetectError to be set") + if result.LastStatusMessage != " "+longMessage+" " { + t.Fatalf("LastStatusMessage=%q, want original message", result.LastStatusMessage) } } -func TestRunTelegramSetupWizard_CentralizedMissingServerID_CanContinueButton(t *testing.T) { +func TestRunTelegramSetupWizard_EscapesBracketedStatusMessage(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil + bracketedMessage := "bad [status] from body" + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: ""}, nil + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + return notify.TelegramRegistrationStatus{ + Code: 500, + Message: bracketedMessage, + } } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { - form := focus.(*tview.Form) - pressFormButton(t, form, "Continue") + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + _, statusView, form := extractTelegramSetupViews(t, root) + + pressFormButton(t, form, "Check") + + rawText := statusView.GetText(false) + if !strings.Contains(rawText, tview.Escape(bracketedMessage)) { + t.Fatalf("expected escaped status text in raw view, got %q", rawText) + } + + plainText := statusView.GetText(true) + if !strings.Contains(plainText, bracketedMessage) { + t.Fatalf("expected literal bracketed message in plain text, got %q", plainText) + } + + pressFormButton(t, form, "Skip") return nil } @@ -392,31 +572,18 @@ func TestRunTelegramSetupWizard_CentralizedMissingServerID_CanContinueButton(t * if err != nil { t.Fatalf("RunTelegramSetupWizard error: %v", err) } - if result.SkippedVerification { - t.Fatalf("expected SkippedVerification=false") - } - if result.Verified { - t.Fatalf("expected Verified=false") - } - if result.CheckAttempts != 0 { - t.Fatalf("CheckAttempts=%d, want 0", result.CheckAttempts) + if result.LastStatusMessage != bracketedMessage { + t.Fatalf("LastStatusMessage=%q, want %q", result.LastStatusMessage, bracketedMessage) } } func TestRunTelegramSetupWizard_CentralizedEscSkipsWhenNotVerified(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: "123456"}, nil - } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { capture := app.GetInputCapture() if capture == nil { t.Fatalf("expected input capture to be set") @@ -446,20 +613,55 @@ func TestRunTelegramSetupWizard_CentralizedEscSkipsWhenNotVerified(t *testing.T) } } -func TestRunTelegramSetupWizard_PropagatesRunnerError(t *testing.T) { +func TestRunTelegramSetupWizard_CentralizedEscAfterVerificationDoesNotSkip(t *testing.T) { stubTelegramSetupDeps(t) - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil + } + telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { + return notify.TelegramRegistrationStatus{Code: 200, Message: "ok"} + } + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { + _, _, form := extractTelegramSetupViews(t, root) + + pressFormButton(t, form, "Check") + if form.GetButtonIndex("Continue") == -1 { + t.Fatalf("expected Continue button after verification") + } + + capture := app.GetInputCapture() + if capture == nil { + t.Fatalf("expected input capture to be set") + } + if got := capture(tcell.NewEventKey(tcell.KeyEscape, 0, tcell.ModNone)); got != nil { + t.Fatalf("expected ESC to be consumed, got %#v", got) + } + return nil + } + + result, err := RunTelegramSetupWizard(context.Background(), t.TempDir(), "/fake/backup.env", "sig") + if err != nil { + t.Fatalf("RunTelegramSetupWizard error: %v", err) } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: "123456"}, nil + if !result.Verified { + t.Fatalf("expected Verified=true") + } + if result.SkippedVerification { + t.Fatalf("expected SkippedVerification=false after ESC on verified flow") + } + if result.CheckAttempts != 1 { + t.Fatalf("CheckAttempts=%d, want 1", result.CheckAttempts) + } +} + +func TestRunTelegramSetupWizard_PropagatesRunnerError(t *testing.T) { + stubTelegramSetupDeps(t) + + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + return eligibleTelegramSetupBootstrap(), nil } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { return errors.New("runner failed") } @@ -480,22 +682,16 @@ func TestRunTelegramSetupWizard_CheckIgnoredWhileChecking_AndUpdateSuppressedAft telegramSetupGo = func(fn func()) { pending = fn } telegramSetupQueueUpdateDraw = func(app *tui.App, f func()) { f() } - - telegramSetupLoadConfig = func(path string) (*config.Config, error) { - return &config.Config{ - TelegramEnabled: true, - TelegramBotType: "centralized", - TelegramServerAPIHost: "https://api.example.test", - }, nil - } - telegramSetupIdentityDetect = func(baseDir string, logger *logging.Logger) (*identity.Info, error) { - return &identity.Info{ServerID: "999888777"}, nil + telegramSetupBuildBootstrap = func(configPath, baseDir string) (orchestrator.TelegramSetupBootstrap, error) { + state := eligibleTelegramSetupBootstrap() + state.ServerID = "999888777" + return state, nil } telegramSetupCheckRegistration = func(ctx context.Context, serverAPIHost, serverID string, logger *logging.Logger) notify.TelegramRegistrationStatus { checkCalls++ return notify.TelegramRegistrationStatus{Code: 200, Message: "ok"} } - telegramSetupWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + telegramSetupWizardRunner = func(ctx context.Context, app *tui.App, root, focus tview.Primitive) error { form := focus.(*tview.Form) pressFormButton(t, form, "Check") @@ -503,10 +699,10 @@ func TestRunTelegramSetupWizard_CheckIgnoredWhileChecking_AndUpdateSuppressedAft t.Fatalf("expected pending check goroutine") } - pressFormButton(t, form, "Check") // should be ignored while checking=true - pressFormButton(t, form, "Skip") // closes the wizard + pressFormButton(t, form, "Check") + pressFormButton(t, form, "Skip") - pending() // simulate late completion after closing + pending() return nil } @@ -527,3 +723,49 @@ func TestRunTelegramSetupWizard_CheckIgnoredWhileChecking_AndUpdateSuppressedAft t.Fatalf("checkCalls=%d, want 1", checkCalls) } } + +func TestTelegramSetupDefaultWrappers(t *testing.T) { + done := make(chan struct{}) + telegramSetupGo(func() { + close(done) + }) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for telegramSetupGo") + } + + app := tui.NewApp() + app.SetScreen(tcell.NewSimulationScreen("UTF-8")) + root := tview.NewBox() + + updateQueued := make(chan struct{}) + updateDone := make(chan struct{}) + go func() { + close(updateQueued) + telegramSetupQueueUpdateDraw(app, func() { + close(updateDone) + app.Stop() + }) + }() + <-updateQueued + + go func() { + select { + case <-updateDone: + return + case <-time.After(100 * time.Millisecond): + app.Stop() + } + }() + + if err := telegramSetupWizardRunner(context.Background(), app, root, root); err != nil { + t.Fatalf("telegramSetupWizardRunner error: %v", err) + } + + select { + case <-updateDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for telegramSetupQueueUpdateDraw") + } +}