diff --git a/.github/instructions/codacy.instructions.md b/.github/instructions/codacy.instructions.md index 9e27a080..a5391f1f 100644 --- a/.github/instructions/codacy.instructions.md +++ b/.github/instructions/codacy.instructions.md @@ -7,9 +7,9 @@ Configuration for AI behavior when interacting with Codacy's MCP Server ## CRITICAL: After ANY successful `edit_file` or `reapply` operation - YOU MUST IMMEDIATELY run the `codacy_cli_analyze` tool from Codacy's MCP Server for each file that was edited, with: - - `rootPath`: set to the workspace path - - `file`: set to the path of the edited file - - `tool`: leave empty or unset +- `rootPath`: set to the workspace path +- `file`: set to the path of the edited file +- `tool`: leave empty or unset - If any issues are found in the new edits, propose and apply fixes for them. - > NOTE: Failure to follow this rule is considered a critical error. @@ -24,8 +24,8 @@ Configuration for AI behavior when interacting with Codacy's MCP Server ## When there are no Codacy MCP Server tools available, or the MCP Server is not reachable - Suggest the user the following troubleshooting steps: - - Try to reset the MCP on the extension - - If the user is using VSCode, suggest them to review their Copilot > MCP settings in Github, under their organization or personal account. Refer them to Settings > Copilot > Enable MCP servers in Copilot. Suggested URL (https://github.com/settings/copilot/features) or https://github.com/organizations/{organization-name}/settings/copilot/features (This can only be done by their organization admins / owners) +- Try to reset the MCP on the extension +- If the user is using VSCode, suggest them to review their Copilot > MCP settings in Github, under their organization or personal account. Refer them to Settings > Copilot > Enable MCP servers in Copilot. Suggested URL (https://github.com/settings/copilot/features) or https://github.com/organizations/{organization-name}/settings/copilot/features (This can only be done by their organization admins / owners) - If none of the above steps work, suggest the user to contact Codacy support ## Trying to call a tool that needs a rootPath as a parameter @@ -33,24 +33,24 @@ Configuration for AI behavior when interacting with Codacy's MCP Server ## CRITICAL: Dependencies and Security Checks - IMMEDIATELY after ANY of these actions: - - Running npm/yarn/pnpm install - - Adding dependencies to package.json - - Adding requirements to requirements.txt - - Adding dependencies to pom.xml - - Adding dependencies to build.gradle - - Any other package manager operations +- Running npm/yarn/pnpm install +- Adding dependencies to package.json +- Adding requirements to requirements.txt +- Adding dependencies to pom.xml +- Adding dependencies to build.gradle +- Any other package manager operations - You MUST run the `codacy_cli_analyze` tool with: - - `rootPath`: set to the workspace path - - `tool`: set to "trivy" - - `file`: leave empty or unset +- `rootPath`: set to the workspace path +- `tool`: set to "trivy" +- `file`: leave empty or unset - If any vulnerabilities are found because of the newly added packages: - - Stop all other operations - - Propose and apply fixes for the security issues - - Only continue with the original task after security issues are resolved +- Stop all other operations +- Propose and apply fixes for the security issues +- Only continue with the original task after security issues are resolved - EXAMPLE: - - After: npm install react-markdown - - Do: Run codacy_cli_analyze with trivy - - Before: Continuing with any other tasks +- After: npm install react-markdown +- Do: Run codacy_cli_analyze with trivy +- Before: Continuing with any other tasks ## General - Repeat the relevant steps for each modified file. diff --git a/.github/workflows/autotag.yml b/.github/workflows/autotag.yml index 29407b1c..654aee19 100644 --- a/.github/workflows/autotag.yml +++ b/.github/workflows/autotag.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd with: fetch-depth: 0 # necessario per leggere commit + tag diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index ccac49ca..f9abf810 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -11,10 +11,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - name: Setup Go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c with: go-version-file: 'go.mod' @@ -34,7 +34,7 @@ jobs: go test $(go list ./... | grep -v -E '/cmd/|/pbs$|/bech32$|^github.com/tis24dev/proxsave$') -coverprofile=coverage.out - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v6 + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6 with: token: ${{ secrets.CODECOV_TOKEN }} files: coverage.out diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index 2301c6f5..3a2f2f4b 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - name: Dependency Review - uses: actions/dependency-review-action@v4 + uses: actions/dependency-review-action@2031cfc080254a8a887f58cffee85186f0e49e48 with: # Blocca solo severity critical (zero-touch per gli altri) fail-on-severity: critical diff --git a/.github/workflows/race.yml b/.github/workflows/race.yml new file mode 100644 index 00000000..4ad0e9bb --- /dev/null +++ b/.github/workflows/race.yml @@ -0,0 +1,57 @@ +name: Race Detector + +run-name: Race detector - ${{ github.ref_name }} + +"on": + push: + branches: + - main + - dev + pull_request: {} + workflow_dispatch: + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash + +jobs: + race: + name: Go race detector + runs-on: ubuntu-latest + timeout-minutes: 30 + + env: + CGO_ENABLED: "1" + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd + with: + persist-credentials: false + + - name: Setup Go + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c + with: + go-version-file: go.mod + cache: true + cache-dependency-path: | + go.sum + **/go.sum + + - name: Show Go environment + run: | + go version + go env GOTOOLCHAIN CGO_ENABLED GOOS GOARCH + + - name: Download dependencies + run: go mod download + + - name: Run race detector + run: go test -race -count=1 ./... diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d4c15779..72ade9ea 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,7 +24,7 @@ jobs: # CHECKOUT (fetch-depth 0 per changelog e GoReleaser) ######################################## - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd with: fetch-depth: 0 @@ -45,7 +45,7 @@ jobs: # SETUP GO ######################################## - name: Set up Go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c with: go-version-file: 'go.mod' @@ -62,7 +62,7 @@ jobs: # INSTALL SYFT (per SBOM CycloneDX via GoReleaser) ######################################## - name: Install Syft (for SBOM generation) - uses: anchore/sbom-action/download-syft@v0 + uses: anchore/sbom-action/download-syft@e22c389904149dbc22b58101806040fa8d37a610 # v0 with: syft-version: v1.19.0 @@ -70,7 +70,7 @@ jobs: # GORELEASER ######################################## - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v7 + uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 # v7 with: version: latest workdir: ${{ github.workspace }} @@ -82,6 +82,6 @@ jobs: # ATTESTAZIONE PROVENIENZA BUILD ######################################## - name: Attest Build Provenance - uses: actions/attest-build-provenance@v4 + uses: actions/attest-build-provenance@a2bbfa25375fe432b6a289bc6b6cd05ecd0c4c32 with: subject-path: build/proxsave_* diff --git a/.github/workflows/security-ultimate.yml b/.github/workflows/security-ultimate.yml index 45023a66..18fe733c 100644 --- a/.github/workflows/security-ultimate.yml +++ b/.github/workflows/security-ultimate.yml @@ -21,13 +21,13 @@ jobs: # CHECKOUT ######################################## - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd ######################################## # GO 1.25 — MAIN TOOLCHAIN ######################################## - name: Set up Go (from go.mod) - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c with: go-version-file: 'go.mod' @@ -56,7 +56,7 @@ jobs: # GOSEC — RUN USING GO 1.21 (NO DOCKER) ######################################## - name: Set up Go 1.21 for GoSec - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c with: go-version: "1.21" @@ -88,7 +88,7 @@ jobs: # UPLOAD SARIF ######################################## - name: Upload GoSec SARIF - uses: github/codeql-action/upload-sarif@v4 + uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e with: sarif_file: gosec.sarif @@ -96,7 +96,7 @@ jobs: # RESTORE GO 1.25 FOR CODEQL ######################################## - name: Restore Go 1.25 for CodeQL - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c with: go-version-file: 'go.mod' @@ -104,7 +104,7 @@ jobs: # CODEQL ######################################## - name: Initialize CodeQL - uses: github/codeql-action/init@v4 + uses: github/codeql-action/init@68bde559dea0fdcac2102bfdf6230c5f70eb485e with: languages: go @@ -114,4 +114,4 @@ jobs: go build ./... - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v4 + uses: github/codeql-action/analyze@68bde559dea0fdcac2102bfdf6230c5f70eb485e diff --git a/.github/workflows/sync-dev.yml b/.github/workflows/sync-dev.yml index 223a39cc..6c33326f 100644 --- a/.github/workflows/sync-dev.yml +++ b/.github/workflows/sync-dev.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd with: fetch-depth: 0 diff --git a/cmd/proxsave/main.go b/cmd/proxsave/main.go index a5c3f2cd..fdec4d60 100644 --- a/cmd/proxsave/main.go +++ b/cmd/proxsave/main.go @@ -10,7 +10,7 @@ import ( const ( defaultLegacyEnvPath = "/opt/proxsave/env/backup.env" legacyEnvFallbackPath = "/opt/proxmox-backup/env/backup.env" - goRuntimeMinVersion = "1.25.5" + goRuntimeMinVersion = "1.25.10" networkPreflightTimeout = 2 * time.Second bytesPerMegabyte int64 = 1024 * 1024 defaultDirPerm = 0o755 diff --git a/cmd/proxsave/main_defers.go b/cmd/proxsave/main_defers.go index e138610a..2df746b8 100644 --- a/cmd/proxsave/main_defers.go +++ b/cmd/proxsave/main_defers.go @@ -85,7 +85,11 @@ func closeRunProfiling(rt *appRuntime) { logging.Warning("Failed to create heap profile file: %v", err) return } - defer f.Close() + defer func() { + if err := f.Close(); err != nil { + logging.Warning("Failed to close heap profile file: %v", err) + } + }() if err := pprof.WriteHeapProfile(f); err != nil { logging.Warning("Failed to write heap profile: %v", err) } diff --git a/cmd/proxsave/main_runtime.go b/cmd/proxsave/main_runtime.go index f1a174d8..96dc9b8b 100644 --- a/cmd/proxsave/main_runtime.go +++ b/cmd/proxsave/main_runtime.go @@ -265,10 +265,10 @@ func buildHeapProfilePath(rt *appRuntime) string { // checkGoRuntimeVersion ensures the running binary was built with at least the specified Go version (semver: major.minor.patch). func checkGoRuntimeVersion(minimum string) error { - rt := runtime.Version() // e.g., "go1.25.4" + rt := runtime.Version() // e.g., "go1.25.10" // Normalize versions to x.y.z parse := func(v string) (int, int, int) { - // Accept forms: go1.25.4, go1.25, 1.25.4, 1.25 + // Accept forms: go1.25.10, go1.25, 1.25.10, 1.25 v = strings.TrimPrefix(v, "go") parts := strings.Split(v, ".") toInt := func(s string) int { n, _ := strconv.Atoi(s); return n } diff --git a/cmd/proxsave/runtime_helpers.go b/cmd/proxsave/runtime_helpers.go index 7a778d66..2de51d52 100644 --- a/cmd/proxsave/runtime_helpers.go +++ b/cmd/proxsave/runtime_helpers.go @@ -84,10 +84,7 @@ func detectExecInfo() ExecInfo { originalDir := dir baseDir := "" - for { - if dir == "" || dir == "." || dir == string(filepath.Separator) { - break - } + for dir != "" && dir != "." { if info, err := os.Stat(filepath.Join(dir, "env")); err == nil && info.IsDir() { baseDir = dir break @@ -1085,7 +1082,7 @@ func executableHash() string { if err != nil { return "" } - defer f.Close() + defer func() { _ = f.Close() }() h := sha256.New() if _, err := io.Copy(h, f); err != nil { return "" diff --git a/cmd/proxsave/upgrade.go b/cmd/proxsave/upgrade.go index b1bab7f6..dbfcefb3 100644 --- a/cmd/proxsave/upgrade.go +++ b/cmd/proxsave/upgrade.go @@ -227,7 +227,11 @@ func downloadAndInstallLatest(ctx context.Context, execPath string, bootstrap *l if err != nil { return "", fmt.Errorf("cannot create temp dir: %w", err) } - defer os.RemoveAll(tmpDir) + defer func() { + if removeErr := os.RemoveAll(tmpDir); removeErr != nil { + bootstrap.Debug("Failed to remove temporary upgrade directory %s: %v", tmpDir, removeErr) + } + }() logging.DebugStepBootstrap(bootstrap, "upgrade download/install", "temp dir=%s", tmpDir) archivePath := filepath.Join(tmpDir, filename) @@ -288,7 +292,7 @@ func fetchLatestRelease(ctx context.Context) (string, string, error) { if err != nil { return "", "", fmt.Errorf("failed to fetch latest release: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) @@ -381,7 +385,7 @@ func downloadFile(ctx context.Context, url, dest string, bootstrap *logging.Boot if err != nil { return fmt.Errorf("download failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() logging.DebugStepBootstrap(bootstrap, "upgrade download", "status=%s", resp.Status) if resp.StatusCode != http.StatusOK { @@ -393,7 +397,7 @@ func downloadFile(ctx context.Context, url, dest string, bootstrap *logging.Boot if err != nil { return fmt.Errorf("cannot create file %s: %w", dest, err) } - defer out.Close() + defer closeIntoErr(&err, out, "close downloaded file") written, err := io.Copy(out, resp.Body) if err != nil { @@ -436,7 +440,7 @@ func verifyChecksum(archivePath, checksumPath, filename string, bootstrap *loggi if err != nil { return fmt.Errorf("cannot open archive for checksum: %w", err) } - defer f.Close() + defer closeIntoErr(&err, f, "close archive for checksum") hasher := sha256.New() if _, err := io.Copy(hasher, f); err != nil { @@ -459,13 +463,13 @@ func extractBinaryFromTar(archivePath, targetName, destPath string, bootstrap *l if err != nil { return fmt.Errorf("cannot open archive: %w", err) } - defer f.Close() + defer closeIntoErr(&err, f, "close release archive") gzr, err := gzip.NewReader(f) if err != nil { return fmt.Errorf("cannot create gzip reader: %w", err) } - defer gzr.Close() + defer closeIntoErr(&err, gzr, "close release gzip reader") tr := tar.NewReader(gzr) for { @@ -489,7 +493,7 @@ func extractBinaryFromTar(archivePath, targetName, destPath string, bootstrap *l return fmt.Errorf("cannot create extracted binary: %w", err) } if _, err := io.Copy(tmpFile, tr); err != nil { - tmpFile.Close() + _ = tmpFile.Close() return fmt.Errorf("cannot write extracted binary: %w", err) } if err := tmpFile.Close(); err != nil { @@ -513,7 +517,7 @@ func installBinary(srcPath, destPath string, bootstrap *logging.BootstrapLogger) if err != nil { return fmt.Errorf("cannot open extracted binary: %w", err) } - defer src.Close() + defer closeIntoErr(&err, src, "close extracted binary") dst, err := os.OpenFile(tmpDest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) if err != nil { @@ -521,7 +525,7 @@ func installBinary(srcPath, destPath string, bootstrap *logging.BootstrapLogger) } if _, err := io.Copy(dst, src); err != nil { - dst.Close() + _ = dst.Close() return fmt.Errorf("cannot copy binary to temp target: %w", err) } if err := dst.Close(); err != nil { @@ -534,6 +538,15 @@ func installBinary(srcPath, destPath string, bootstrap *logging.BootstrapLogger) return nil } +func closeIntoErr(errp *error, closer io.Closer, operation string) { + if errp == nil || closer == nil { + return + } + if closeErr := closer.Close(); closeErr != nil && *errp == nil { + *errp = fmt.Errorf("%s: %w", operation, closeErr) + } +} + func printUpgradeFooter(upgradeErr error, version, configPath, baseDir, telegramCode, permStatus, permMessage string, cfgUpgradeResult *config.UpgradeResult, cfgUpgradeErr error) { colorReset := "\033[0m" diff --git a/docs/INSTALL.md b/docs/INSTALL.md index e09b6b79..86a0ff10 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -154,8 +154,8 @@ For more details, see [CLI Reference - Binary Upgrade](CLI_REFERENCE.md#binary-u ```bash # Install Go (if building from source) -wget https://go.dev/dl/go1.25.4.linux-amd64.tar.gz -tar -C /usr/local -xzf go1.25.4.linux-amd64.tar.gz +wget https://go.dev/dl/go1.25.10.linux-amd64.tar.gz +tar -C /usr/local -xzf go1.25.10.linux-amd64.tar.gz export PATH=$PATH:/usr/local/go/bin # Install rclone (for cloud storage) @@ -168,7 +168,7 @@ apt update && apt install -y git apt update && apt install -y make # Verify installations -go version # Should show go1.25+ +go version # Should show go1.25.10+ rclone version # Should show rclone v1.50+ git --version # Should show git 2.47.3+ make --version # Should show make 4.4.1+ diff --git a/docs/RESTORE_GUIDE.md b/docs/RESTORE_GUIDE.md index 05bc7172..b514c52e 100644 --- a/docs/RESTORE_GUIDE.md +++ b/docs/RESTORE_GUIDE.md @@ -92,6 +92,9 @@ Examples: - `dual` backup on `pbs` host: restore `PBS + Common` - `pve` backup on `dual` host: restore `PVE + Common` +When compatibility is partial, ProxSave automatically filters selectable +restore categories to the roles supported by the current host. + `unknown` hosts can still use export-oriented or common-only workflows, but ProxSave warns because role-specific compatibility cannot be verified. @@ -2499,9 +2502,10 @@ systemctl restart proxmox-backup proxmox-backup-proxy **Q: Can I restore PVE backup to PBS system (or vice versa)?** -A: Direct cross-role restore is still not recommended. PVE and PBS have -different role-specific configurations. However, ProxSave now evaluates -compatibility by **role overlap**: +A: Pure cross-role restore (no role overlap) is not recommended; however +ProxSave supports restores when roles overlap. PVE and PBS have different +role-specific configurations. ProxSave now evaluates compatibility by +**role overlap**: - `pve` ↔ `pbs`: only common categories are sensible - `dual` → `pve`: PVE + Common can be restored diff --git a/embed.go b/embed.go index b493a4c1..22dce604 100644 --- a/embed.go +++ b/embed.go @@ -1,3 +1,4 @@ +// Package proxmoxbackup embeds installable project documentation. package proxmoxbackup import ( @@ -34,7 +35,9 @@ var installableDocs = func() []DocAsset { }() // InstallableDocs returns the list of documentation files embedded in the -// binary that should be written to the installation root. +// binary that should be written to the installation root. The returned DocAsset +// slice is copied, but DocAsset.Data shares the embedded installableDocs backing +// data; callers must treat Data as read-only or copy it before mutation. func InstallableDocs() []DocAsset { out := make([]DocAsset, len(installableDocs)) copy(out, installableDocs) diff --git a/go.mod b/go.mod index 8bc70ffe..3016aaff 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/tis24dev/proxsave -go 1.25.0 - -toolchain go1.25.9 +go 1.25.10 require ( filippo.io/age v1.3.1 diff --git a/internal/backup/archiver.go b/internal/backup/archiver.go index b63d7826..5a99c3ef 100644 --- a/internal/backup/archiver.go +++ b/internal/backup/archiver.go @@ -15,6 +15,7 @@ import ( "time" "filippo.io/age" + "github.com/tis24dev/proxsave/internal/closeerr" "github.com/tis24dev/proxsave/internal/logging" "github.com/tis24dev/proxsave/internal/safeexec" "github.com/tis24dev/proxsave/internal/types" @@ -22,6 +23,8 @@ import ( var lookPath = exec.LookPath +var closeIntoErr = closeerr.CloseIntoErr + // ArchiverDeps groups external dependencies used by Archiver. type ArchiverDeps struct { LookPath func(string) (string, error) @@ -443,7 +446,7 @@ func (a *Archiver) createGzipArchive(ctx context.Context, sourceDir, outputPath if err != nil { return fmt.Errorf("failed to create output file: %w", err) } - defer outFile.Close() + defer closeIntoErr(&err, outFile, "close output archive") writer, finalizeEncryption, err := a.wrapEncryptionWriter(outFile) if err != nil { @@ -464,7 +467,7 @@ func (a *Archiver) createGzipArchive(ctx context.Context, sourceDir, outputPath if err != nil { return fmt.Errorf("failed to create gzip writer: %w", err) } - defer gzWriter.Close() + defer closeIntoErr(&err, gzWriter, "close gzip writer") // Stream tar content into gzip writer if err := a.writeTar(ctx, sourceDir, gzWriter); err != nil { @@ -493,7 +496,7 @@ func (a *Archiver) createTarArchive(ctx context.Context, sourceDir, outputPath s if err != nil { return fmt.Errorf("failed to create output file: %w", err) } - defer outFile.Close() + defer closeIntoErr(&err, outFile, "close output archive") writer, finalizeEncryption, err := a.wrapEncryptionWriter(outFile) if err != nil { @@ -576,7 +579,7 @@ func (a *Archiver) createXZArchive(ctx context.Context, sourceDir, outputPath st if err != nil { return fmt.Errorf("failed to create output file: %w", err) } - defer outFile.Close() + defer closeIntoErr(&err, outFile, "close output archive") pr, pw := io.Pipe() cmd.Stdin = pr @@ -600,15 +603,15 @@ func (a *Archiver) createXZArchive(ctx context.Context, sourceDir, outputPath st defer close(errChan) err := a.writeTar(ctx, sourceDir, pw) if err != nil { - pw.CloseWithError(err) + _ = pw.CloseWithError(err) } else { - pw.Close() + err = pw.Close() } errChan <- err }() if err := cmd.Start(); err != nil { - pw.Close() + _ = pw.Close() if startErr := <-errChan; startErr != nil { return startErr } @@ -649,7 +652,7 @@ func (a *Archiver) createZstdArchive(ctx context.Context, sourceDir, outputPath if err != nil { return fmt.Errorf("failed to create output file: %w", err) } - defer outFile.Close() + defer closeIntoErr(&err, outFile, "close output archive") pr, pw := io.Pipe() cmd.Stdin = pr @@ -673,15 +676,15 @@ func (a *Archiver) createZstdArchive(ctx context.Context, sourceDir, outputPath defer close(errChan) err := a.writeTar(ctx, sourceDir, pw) if err != nil { - pw.CloseWithError(err) + _ = pw.CloseWithError(err) } else { - pw.Close() + err = pw.Close() } errChan <- err }() if err := cmd.Start(); err != nil { - pw.Close() + _ = pw.Close() if startErr := <-errChan; startErr != nil { return startErr } @@ -730,7 +733,7 @@ func (a *Archiver) pipeTarThroughCommand(ctx context.Context, sourceDir, outputP if err != nil { return fmt.Errorf("failed to create output file: %w", err) } - defer outFile.Close() + defer closeIntoErr(&err, outFile, "close output archive") pr, pw := io.Pipe() cmd.Stdin = pr @@ -756,16 +759,15 @@ func (a *Archiver) pipeTarThroughCommand(ctx context.Context, sourceDir, outputP go func() { defer close(errChan) if err := a.writeTar(ctx, sourceDir, pw); err != nil { - pw.CloseWithError(err) + _ = pw.CloseWithError(err) errChan <- err return } - pw.Close() - errChan <- nil + errChan <- pw.Close() }() if err := cmd.Start(); err != nil { - pw.Close() + _ = pw.Close() if startErr := <-errChan; startErr != nil { return startErr } @@ -889,12 +891,16 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi a.logger.Warning("Failed to open file %s: %v", path, err) return nil } - defer file.Close() if _, err := io.Copy(tarWriter, file); err != nil { + _ = file.Close() a.logger.Warning("Failed to write file %s to archive: %v", path, err) return nil } + if err := file.Close(); err != nil { + a.logger.Warning("Failed to close file %s after archiving: %v", path, err) + return nil + } a.logger.Debug("Added file to archive: %s", archivePath) } else if linkInfo.Mode()&os.ModeSymlink != 0 { diff --git a/internal/backup/archiver_test.go b/internal/backup/archiver_test.go index 6a29842f..fb09de31 100644 --- a/internal/backup/archiver_test.go +++ b/internal/backup/archiver_test.go @@ -83,9 +83,15 @@ func TestCreateTarArchive(t *testing.T) { // Create test files testDir := filepath.Join(tempDir, "source") - os.MkdirAll(filepath.Join(testDir, "subdir"), 0755) - os.WriteFile(filepath.Join(testDir, "file1.txt"), []byte("content1"), 0644) - os.WriteFile(filepath.Join(testDir, "subdir", "file2.txt"), []byte("content2"), 0644) + if err := os.MkdirAll(filepath.Join(testDir, "subdir"), 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(filepath.Join(testDir, "file1.txt"), []byte("content1"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if err := os.WriteFile(filepath.Join(testDir, "subdir", "file2.txt"), []byte("content2"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } // Create archive outputPath := filepath.Join(tempDir, "test.tar") @@ -138,7 +144,7 @@ func TestCreateTarArchiveRespectsExcludePatterns(t *testing.T) { if err != nil { t.Fatalf("open archive: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() found := map[string]bool{} tr := tar.NewReader(f) @@ -178,8 +184,12 @@ func TestCreateGzipArchive(t *testing.T) { // Create test files testDir := filepath.Join(tempDir, "source") - os.MkdirAll(testDir, 0755) - os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test content"), 0644) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test content"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } // Create archive outputPath := filepath.Join(tempDir, "test.tar.gz") @@ -199,7 +209,7 @@ func TestCreateGzipArchive(t *testing.T) { if err != nil { t.Fatalf("Failed to open archive: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() _, err = gzip.NewReader(f) if err != nil { @@ -364,13 +374,19 @@ func TestVerifyArchive(t *testing.T) { // Create a test archive testDir := filepath.Join(tempDir, "source") - os.MkdirAll(testDir, 0755) - os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test"), 0644) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } outputPath := filepath.Join(tempDir, "test.tar") ctx := context.Background() - archiver.CreateArchive(ctx, testDir, outputPath) + if err := archiver.CreateArchive(ctx, testDir, outputPath); err != nil { + t.Fatalf("CreateArchive failed: %v", err) + } // Verify it if err := archiver.VerifyArchive(ctx, outputPath); err != nil { @@ -405,14 +421,20 @@ func TestGetArchiveSize(t *testing.T) { // Create a test archive testDir := filepath.Join(tempDir, "source") - os.MkdirAll(testDir, 0755) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } content := []byte("test content with some length") - os.WriteFile(filepath.Join(testDir, "file.txt"), content, 0644) + if err := os.WriteFile(filepath.Join(testDir, "file.txt"), content, 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } outputPath := filepath.Join(tempDir, "test.tar") ctx := context.Background() - archiver.CreateArchive(ctx, testDir, outputPath) + if err := archiver.CreateArchive(ctx, testDir, outputPath); err != nil { + t.Fatalf("CreateArchive failed: %v", err) + } // Get size size, err := archiver.GetArchiveSize(outputPath) @@ -437,8 +459,12 @@ func TestDryRunMode(t *testing.T) { // Create test files testDir := filepath.Join(tempDir, "source") - os.MkdirAll(testDir, 0755) - os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test"), 0644) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(filepath.Join(testDir, "file.txt"), []byte("test"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } // Try to create archive in dry-run mode outputPath := filepath.Join(tempDir, "test.tar") @@ -539,9 +565,13 @@ func TestContextCancellation(t *testing.T) { // Create a large test directory to ensure cancellation can happen testDir := filepath.Join(tempDir, "source") - os.MkdirAll(testDir, 0755) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } for i := 0; i < 100; i++ { - os.WriteFile(filepath.Join(testDir, fmt.Sprintf("file%d.txt", i)), []byte("test content"), 0644) + if err := os.WriteFile(filepath.Join(testDir, fmt.Sprintf("file%d.txt", i)), []byte("test content"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } } // Create a context that we'll cancel immediately @@ -564,7 +594,7 @@ func verifyTarContent(tarPath string, expectedFiles []string) error { if err != nil { return err } - defer f.Close() + defer func() { _ = f.Close() }() tr := tar.NewReader(f) found := make(map[string]bool) @@ -677,12 +707,12 @@ func TestEncryptedArchiveRejectsWrongIdentity(t *testing.T) { } } -func decryptArchiveForTest(src, dst string, identity age.Identity) error { +func decryptArchiveForTest(src, dst string, identity age.Identity) (err error) { in, err := os.Open(src) if err != nil { return err } - defer in.Close() + defer closeIntoErr(&err, in, "close encrypted test archive") reader, err := age.Decrypt(in, identity) if err != nil { @@ -693,7 +723,7 @@ func decryptArchiveForTest(src, dst string, identity age.Identity) error { if err != nil { return err } - defer out.Close() + defer closeIntoErr(&err, out, "close decrypted test archive") if _, err := io.Copy(out, reader); err != nil { return err diff --git a/internal/backup/archiver_verification_test.go b/internal/backup/archiver_verification_test.go index c1536995..223822b6 100644 --- a/internal/backup/archiver_verification_test.go +++ b/internal/backup/archiver_verification_test.go @@ -246,14 +246,16 @@ func TestVerifyGzipArchive_CorruptedTar(t *testing.T) { if err != nil { t.Fatal(err) } - defer file.Close() + defer func() { _ = file.Close() }() gzipWriter := gzip.NewWriter(file) _, err = gzipWriter.Write([]byte("corrupted tar content")) if err != nil { t.Fatal(err) } - gzipWriter.Close() + if err := gzipWriter.Close(); err != nil { + t.Fatal(err) + } logger := logging.New(types.LogLevelInfo, false) archiver := &Archiver{logger: logger} @@ -352,7 +354,7 @@ func TestVerifyGzipArchive_ValidTarContent(t *testing.T) { if err != nil { t.Fatal(err) } - defer file.Close() + defer func() { _ = file.Close() }() gzipWriter := gzip.NewWriter(file) tarWriter := tar.NewWriter(gzipWriter) @@ -370,8 +372,12 @@ func TestVerifyGzipArchive_ValidTarContent(t *testing.T) { t.Fatal(err) } - tarWriter.Close() - gzipWriter.Close() + if err := tarWriter.Close(); err != nil { + t.Fatal(err) + } + if err := gzipWriter.Close(); err != nil { + t.Fatal(err) + } logger := logging.New(types.LogLevelInfo, false) archiver := &Archiver{logger: logger} diff --git a/internal/backup/checksum.go b/internal/backup/checksum.go index 9b3ac7f2..c318dd8e 100644 --- a/internal/backup/checksum.go +++ b/internal/backup/checksum.go @@ -62,14 +62,14 @@ func ParseChecksumData(data []byte) (string, error) { } // GenerateChecksum calculates SHA256 checksum of a file -func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath string) (string, error) { +func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath string) (checksum string, err error) { logger.Debug("Generating SHA256 checksum for: %s", filePath) file, err := os.Open(filePath) if err != nil { return "", fmt.Errorf("failed to open file: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close checksum source file") hash := sha256.New() @@ -98,7 +98,7 @@ func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath stri } } - checksum := hex.EncodeToString(hash.Sum(nil)) + checksum = hex.EncodeToString(hash.Sum(nil)) logger.Debug("Generated checksum: %s", checksum) return checksum, nil } diff --git a/internal/backup/collector.go b/internal/backup/collector.go index b87366a3..aa3e6ddd 100644 --- a/internal/backup/collector.go +++ b/internal/backup/collector.go @@ -854,7 +854,7 @@ func (c *Collector) removeExistingSymlinkDestination(dest string) error { return nil } -func (c *Collector) copyRegularFile(src, dest, description string, info os.FileInfo) error { +func (c *Collector) copyRegularFile(src, dest, description string, info os.FileInfo) (err error) { if err := c.prepareCopyDestination(src, dest); err != nil { c.incFilesFailed() return err @@ -865,7 +865,7 @@ func (c *Collector) copyRegularFile(src, dest, description string, info os.FileI c.incFilesFailed() return fmt.Errorf("failed to open %s: %w", src, err) } - defer srcFile.Close() + defer closeIntoErr(&err, srcFile, "close source file") written, err := copyRegularFileContents(srcFile, src, dest) if err != nil { @@ -990,6 +990,7 @@ type commandRunOptions struct { critical bool logCollection bool handleSystemctlStatus bool + debugNonCritical bool } type commandRunResult struct { @@ -1047,6 +1048,25 @@ func (c *Collector) runAndClassifyCommand(ctx context.Context, spec CommandSpec, out, err := c.depRunCommand(runCtx, spec.Name, spec.Args...) result.output = out if err != nil { + if isContextCancellationError(runCtx, err) { + if isNonCriticalPveshDeadline(ctx, runCtx, spec, opts.critical) { + result.classification = commandRunNonCriticalFailure + result.outputSummary = summarizeCommandOutputText(string(out)) + timeoutSeconds := 0 + if c.config != nil { + timeoutSeconds = c.config.PveshTimeoutSeconds + } + if opts.debugNonCritical { + c.logger.Debug("Skipping %s: command `%s` timed out after %d seconds. Non-critical; backup continues. Output: %s", + opts.description, cmdString, timeoutSeconds, result.outputSummary) + } else { + c.logger.Warning("Skipping %s: command `%s` timed out after %d seconds. Non-critical; backup continues. Output: %s", + opts.description, cmdString, timeoutSeconds, result.outputSummary) + } + return result, nil + } + return result, err + } result.outputSummary = summarizeCommandOutputText(string(out)) if opts.critical { c.incFilesFailed() @@ -1123,6 +1143,13 @@ func (c *Collector) runAndClassifyCommand(ctx context.Context, spec CommandSpec, err, result.outputSummary, ) + } else if opts.debugNonCritical { + c.logger.Debug("Skipping %s: command `%s` failed (%v). Non-critical; backup continues. Output: %s", + opts.description, + cmdString, + err, + result.outputSummary, + ) } else { c.logger.Warning("Skipping %s: command `%s` failed (%v). Non-critical; backup continues. Ensure the required CLI is available and has proper permissions. Output: %s", opts.description, @@ -1138,6 +1165,16 @@ func (c *Collector) runAndClassifyCommand(ctx context.Context, spec CommandSpec, return result, nil } +func isNonCriticalPveshDeadline(parentCtx, runCtx context.Context, spec CommandSpec, critical bool) bool { + if parentCtx == nil || runCtx == nil { + return false + } + return spec.Name == "pvesh" && + !critical && + parentCtx.Err() == nil && + errors.Is(runCtx.Err(), context.DeadlineExceeded) +} + func (c *Collector) safeCmdOutput(ctx context.Context, spec CommandSpec, output, description string, critical bool) error { result, err := c.runAndClassifyCommand(ctx, spec, commandRunOptions{ output: output, @@ -1161,6 +1198,29 @@ func (c *Collector) safeCmdOutput(ctx context.Context, spec CommandSpec, output, return nil } +func (c *Collector) safeCmdOutputBestEffort(ctx context.Context, spec CommandSpec, output, description string) error { + result, err := c.runAndClassifyCommand(ctx, spec, commandRunOptions{ + output: output, + description: description, + caller: "safeCmdOutputBestEffort", + logCollection: true, + debugNonCritical: true, + }) + if err != nil { + return err + } + if result.classification != commandRunSucceeded { + return nil + } + + if err := c.writeReportFile(output, result.output); err != nil { + return err + } + + c.logger.Debug("Successfully collected %s via command: %s", description, spec.String()) + return nil +} + // safeCmdOutputWithPBSAuth executes a command with PBS authentication environment variables // This enables automatic authentication for proxmox-backup-client commands func (c *Collector) safeCmdOutputWithPBSAuth(ctx context.Context, spec CommandSpec, output, description string, critical bool) error { diff --git a/internal/backup/collector_bricks_pve.go b/internal/backup/collector_bricks_pve.go index 4218cd0c..00482356 100644 --- a/internal/backup/collector_bricks_pve.go +++ b/internal/backup/collector_bricks_pve.go @@ -124,8 +124,7 @@ func newPVERuntimeBricks() []collectionBrick { if err != nil { return err } - state.collector.collectPVEACLRuntime(ctx, commandsDir) - return nil + return state.collector.collectPVEACLRuntime(ctx, commandsDir) }, }, { @@ -136,8 +135,7 @@ func newPVERuntimeBricks() []collectionBrick { if err != nil { return err } - state.collector.collectPVEClusterRuntime(ctx, commandsDir, state.pve.clustered) - return nil + return state.collector.collectPVEClusterRuntime(ctx, commandsDir, state.pve.clustered) }, }, { diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index 44cd9553..5e7bfc41 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -284,7 +284,7 @@ func (c *Collector) collectPBSCoreRuntime(ctx context.Context, commandsDir strin func (c *Collector) collectPBSNodeRuntime(ctx context.Context, commandsDir string) error { if c.config.BackupPBSNodeConfig { - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "node", "show", "--output-format=json"), filepath.Join(commandsDir, "node_config.json"), "Node configuration", @@ -295,7 +295,7 @@ func (c *Collector) collectPBSNodeRuntime(ctx context.Context, commandsDir strin func (c *Collector) collectPBSNetworkRuntime(ctx context.Context, commandsDir string) error { if c.config.BackupPBSNetworkConfig { - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "network", "list", "--output-format=json"), filepath.Join(commandsDir, "network_list.json"), "Network configuration", @@ -327,11 +327,13 @@ func (c *Collector) collectPBSDatastoreStatusRuntime(ctx context.Context, comman continue } dsKey := ds.pathKey() - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "datastore", "show", cliName, "--output-format=json"), filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", dsKey)), fmt.Sprintf("Datastore %s status", ds.Name), - false) + false); err != nil { + return err + } } return nil } @@ -577,45 +579,41 @@ func (c *Collector) collectPBSTapeDrivesRuntime(ctx context.Context, commandsDir if !enabled { return nil } - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-tape", "drive", "list", "--output-format=json"), filepath.Join(commandsDir, "tape_drives.json"), "Tape drives", false) - return nil } func (c *Collector) collectPBSTapeChangersRuntime(ctx context.Context, commandsDir string, enabled bool) error { if !enabled { return nil } - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-tape", "changer", "list", "--output-format=json"), filepath.Join(commandsDir, "tape_changers.json"), "Tape changers", false) - return nil } func (c *Collector) collectPBSTapePoolsRuntime(ctx context.Context, commandsDir string, enabled bool) error { if !enabled { return nil } - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-tape", "pool", "list", "--output-format=json"), filepath.Join(commandsDir, "tape_pools.json"), "Tape pools", false) - return nil } func (c *Collector) collectPBSDisksRuntime(ctx context.Context, commandsDir string) error { - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "disk", "list", "--output-format=json"), filepath.Join(commandsDir, "disk_list.json"), "Disk list", false) - return nil } func (c *Collector) collectPBSCertInfoRuntime(ctx context.Context, commandsDir string) error { @@ -630,25 +628,23 @@ func (c *Collector) collectPBSTrafficControlRuntime(ctx context.Context, command if !c.config.BackupPBSTrafficControl { return nil } - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "traffic-control", "list", "--output-format=json"), filepath.Join(commandsDir, "traffic_control.json"), "Traffic control rules", false) - return nil } func (c *Collector) collectPBSRecentTasksRuntime(ctx context.Context, commandsDir string) error { - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "task", "list", "--limit", "50", "--output-format=json"), filepath.Join(commandsDir, "recent_tasks.json"), "Recent tasks", false) - return nil } func (c *Collector) collectPBSS3EndpointsRuntime(ctx context.Context, commandsDir string) ([]string, error) { - if !(c.config.BackupDatastoreConfigs && c.config.BackupPBSS3Endpoints) { + if !c.config.BackupDatastoreConfigs || !c.config.BackupPBSS3Endpoints { return nil, nil } raw, err := c.captureCommandOutput(ctx, @@ -667,7 +663,7 @@ func (c *Collector) collectPBSS3EndpointsRuntime(ctx context.Context, commandsDi } func (c *Collector) collectPBSS3EndpointBucketsRuntime(ctx context.Context, commandsDir string, endpointIDs []string) error { - if !(c.config.BackupDatastoreConfigs && c.config.BackupPBSS3Endpoints) { + if !c.config.BackupDatastoreConfigs || !c.config.BackupPBSS3Endpoints { return nil } for _, id := range uniqueSortedStrings(endpointIDs) { diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index da5715e1..858c0281 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -343,11 +343,13 @@ func (c *Collector) collectPBSDatastoreCLIConfigs(ctx context.Context, state *pb for _, ds := range state.datastores { dsKey := ds.pathKey() if cliName := ds.cliName(); cliName != "" && !ds.isOverride() { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("proxmox-backup-manager", "datastore", "show", cliName, "--output-format=json"), filepath.Join(state.datastoreDir, fmt.Sprintf("%s_config.json", dsKey)), fmt.Sprintf("Datastore %s configuration", ds.Name), - false) + false); err != nil { + return err + } continue } c.logger.Debug("Skipping datastore CLI config for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path) @@ -548,10 +550,6 @@ func (c *Collector) runPBSPXARStep(ctx context.Context, state *pbsPxarState, fn dsWorkers = 1 } - parentCtx := ctx - ctx, cancel := context.WithCancel(parentCtx) - defer cancel() - var ( wg sync.WaitGroup sem = make(chan struct{}, dsWorkers) @@ -578,7 +576,6 @@ func (c *Collector) runPBSPXARStep(ctx context.Context, state *pbsPxarState, fn errMu.Lock() if firstErr == nil { firstErr = err - cancel() } errMu.Unlock() } @@ -590,7 +587,7 @@ func (c *Collector) runPBSPXARStep(ctx context.Context, state *pbsPxarState, fn if firstErr != nil { return firstErr } - if err := parentCtx.Err(); err != nil { + if err := ctx.Err(); err != nil { return err } return nil @@ -665,15 +662,15 @@ func (c *Collector) collectPBSPXARMetadataForDatastore(ctx context.Context, ds p func (c *Collector) writePxarSubdirReport(ctx context.Context, target string, ds pbsDatastore, ioTimeout time.Duration) error { c.logger.Debug("Writing PXAR subdirectory report for datastore %s", ds.Name) var builder strings.Builder - builder.WriteString(fmt.Sprintf("# Datastore subdirectories in %s generated on %s\n", ds.Path, time.Now().Format(time.RFC1123))) - builder.WriteString(fmt.Sprintf("# Datastore: %s\n", ds.Name)) + fmt.Fprintf(&builder, "# Datastore subdirectories in %s generated on %s\n", ds.Path, time.Now().Format(time.RFC1123)) + fmt.Fprintf(&builder, "# Datastore: %s\n", ds.Name) entries, err := safefs.ReadDir(ctx, ds.Path, ioTimeout) if err != nil { if errors.Is(err, safefs.ErrTimeout) { return err } - builder.WriteString(fmt.Sprintf("# Unable to read datastore path: %v\n", err)) + fmt.Fprintf(&builder, "# Unable to read datastore path: %v\n", err) return c.writeReportFile(target, []byte(builder.String())) } @@ -702,8 +699,8 @@ func (c *Collector) writePxarListReport(ctx context.Context, target string, ds p basePath := filepath.Join(ds.Path, subDir) var builder strings.Builder - builder.WriteString(fmt.Sprintf("# List of .pxar files in %s generated on %s\n", basePath, time.Now().Format(time.RFC1123))) - builder.WriteString(fmt.Sprintf("# Datastore: %s, Subdirectory: %s\n", ds.Name, subDir)) + fmt.Fprintf(&builder, "# List of .pxar files in %s generated on %s\n", basePath, time.Now().Format(time.RFC1123)) + fmt.Fprintf(&builder, "# Datastore: %s, Subdirectory: %s\n", ds.Name, subDir) builder.WriteString("# Format: permissions size date name\n") entries, err := safefs.ReadDir(ctx, basePath, ioTimeout) @@ -711,7 +708,7 @@ func (c *Collector) writePxarListReport(ctx context.Context, target string, ds p if errors.Is(err, safefs.ErrTimeout) { return err } - builder.WriteString(fmt.Sprintf("# Unable to read directory: %v\n", err)) + fmt.Fprintf(&builder, "# Unable to read directory: %v\n", err) if writeErr := c.writeReportFile(target, []byte(builder.String())); writeErr != nil { return writeErr } @@ -756,11 +753,11 @@ func (c *Collector) writePxarListReport(ctx context.Context, target string, ds p builder.WriteString("# No .pxar files found\n") } else { for _, file := range files { - builder.WriteString(fmt.Sprintf("%s %d %s %s\n", + fmt.Fprintf(&builder, "%s %d %s %s\n", file.mode.String(), file.size, file.time.Format("2006-01-02 15:04:05"), - file.name)) + file.name) } } diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index 18313254..6353ea44 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -463,17 +463,21 @@ func (c *Collector) collectPVECoreRuntime(ctx context.Context, commandsDir strin return fmt.Errorf("failed to get PVE version (critical): %w", err) } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvenode", "config", "get"), filepath.Join(commandsDir, "node_config.txt"), "Node configuration", - false) + false); err != nil { + return err + } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/version", "--output-format=json"), filepath.Join(commandsDir, "api_version.json"), "API version", - false) + false); err != nil { + return err + } if nodeData, err := c.captureCommandOutput(ctx, commandSpec("pvesh", "get", "/nodes", "--output-format=json"), @@ -499,68 +503,90 @@ func (c *Collector) collectPVECoreRuntime(ctx context.Context, commandsDir strin return nil } -func (c *Collector) collectPVEACLRuntime(ctx context.Context, commandsDir string) { +func (c *Collector) collectPVEACLRuntime(ctx context.Context, commandsDir string) error { if !c.config.BackupPVEACL { - return + return nil } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pveum", "user", "list", "--output-format=json"), filepath.Join(commandsDir, "pve_users.json"), "PVE users", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pveum", "group", "list", "--output-format=json"), filepath.Join(commandsDir, "pve_groups.json"), "PVE groups", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pveum", "role", "list", "--output-format=json"), filepath.Join(commandsDir, "pve_roles.json"), "PVE roles", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pveum", "pool", "list", "--output-format=json"), filepath.Join(commandsDir, "pools.json"), "PVE resource pools", - false) + false); err != nil { + return err + } + return nil } -func (c *Collector) collectPVEClusterRuntime(ctx context.Context, commandsDir string, clustered bool) { +func (c *Collector) collectPVEClusterRuntime(ctx context.Context, commandsDir string, clustered bool) error { if clustered && c.config.BackupClusterConfig { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvecm", "status"), filepath.Join(commandsDir, "cluster_status.txt"), "Cluster status", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pvecm", "nodes"), filepath.Join(commandsDir, "cluster_nodes.txt"), "Cluster nodes", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/cluster/ha/status", "--output-format=json"), filepath.Join(commandsDir, "ha_status.json"), "HA status", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/cluster/mapping/pci", "--output-format=json"), filepath.Join(commandsDir, "mapping_pci.json"), "PCI resource mappings", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/cluster/mapping/usb", "--output-format=json"), filepath.Join(commandsDir, "mapping_usb.json"), "USB resource mappings", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/cluster/mapping/dir", "--output-format=json"), filepath.Join(commandsDir, "mapping_dir.json"), "Directory resource mappings", - false) + false); err != nil { + return err + } } else if clustered && !c.config.BackupClusterConfig { c.logger.Debug("Skipping cluster runtime commands: BACKUP_CLUSTER_CONFIG=false (clustered=%v)", clustered) } + return nil } func (c *Collector) collectPVEStorageRuntime(ctx context.Context, commandsDir string, info *pveRuntimeInfo) error { @@ -570,11 +596,13 @@ func (c *Collector) collectPVEStorageRuntime(ctx context.Context, commandsDir st nodeName = hostname } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", fmt.Sprintf("/nodes/%s/disks/list", nodeName), "--output-format=json"), filepath.Join(commandsDir, "disks_list.json"), "Disks list", - false) + false); err != nil { + return err + } storageJSONPath := filepath.Join(commandsDir, "storage_status.json") if storageData, err := c.captureCommandOutput(ctx, @@ -595,11 +623,13 @@ func (c *Collector) collectPVEStorageRuntime(ctx context.Context, commandsDir st } } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvesm", "status"), filepath.Join(commandsDir, "pvesm_status.txt"), "Storage manager status", - false) + false); err != nil { + return err + } return nil } @@ -729,17 +759,21 @@ func (c *Collector) collectPVEGuestInventory(ctx context.Context) error { nodeName = hostname } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", fmt.Sprintf("/nodes/%s/qemu", nodeName), "--output-format=json"), filepath.Join(commandsDir, "qemu_vms.json"), "QEMU VMs list", - false) + false); err != nil { + return err + } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", fmt.Sprintf("/nodes/%s/lxc", nodeName), "--output-format=json"), filepath.Join(commandsDir, "lxc_containers.json"), "LXC containers list", - false) + false); err != nil { + return err + } return nil } @@ -1298,13 +1332,13 @@ func (c *Collector) writePVEStorageSummary(ctx context.Context, storages []pveSt summary.WriteString("\n# Format: TYPE|NAME|PATH|CONTENT\n\n") for _, storage := range storages { - summary.WriteString(fmt.Sprintf("%s|%s|%s|%s\n", + fmt.Fprintf(&summary, "%s|%s|%s|%s\n", storage.Type, storage.Name, storage.Path, - storage.Content)) + storage.Content) } - summary.WriteString(fmt.Sprintf("\n# Total datastores processed: %d\n", len(storages))) + fmt.Fprintf(&summary, "\n# Total datastores processed: %d\n", len(storages)) return c.writeReportFile(filepath.Join(c.pveDatastoresBaseDir(), "detected_datastores.txt"), []byte(summary.String())) } @@ -1519,7 +1553,7 @@ func newPatternWriter(storageName, storagePath, analysisDir, pattern string, dry time.Now().Format(time.RFC3339), ) if _, err := writer.WriteString(header); err != nil { - file.Close() + _ = file.Close() return nil, err } return &patternWriter{ @@ -1611,7 +1645,7 @@ func (c *Collector) copyBackupSample(ctx context.Context, src, destDir, descript return c.safeCopyFile(ctx, src, dest, description) } -func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir string, writers []*patternWriter, totalFiles, totalSize int64) error { +func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir string, writers []*patternWriter, totalFiles, totalSize int64) (err error) { // Skip file creation in dry-run mode if c.dryRun { c.logger.Debug("[DRY RUN] Would write backup summary for datastore: %s", storage.Name) @@ -1623,35 +1657,65 @@ func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir str if err != nil { return err } - defer file.Close() + defer closeIntoErr(&err, file, "close PVE backup summary") writer := bufio.NewWriter(file) - fmt.Fprintf(writer, "# PVE Backup Files Summary for datastore: %s\n", storage.Name) - fmt.Fprintf(writer, "# Path: %s\n", storage.Path) - fmt.Fprintf(writer, "# Generated on: %s\n\n", time.Now().Format(time.RFC3339)) + writeSummaryf := func(format string, args ...any) error { + _, err := fmt.Fprintf(writer, format, args...) + return err + } + if err := writeSummaryf("# PVE Backup Files Summary for datastore: %s\n", storage.Name); err != nil { + return err + } + if err := writeSummaryf("# Path: %s\n", storage.Path); err != nil { + return err + } + if err := writeSummaryf("# Generated on: %s\n\n", time.Now().Format(time.RFC3339)); err != nil { + return err + } for _, w := range writers { - fmt.Fprintf(writer, "## Files matching pattern: %s\n", w.pattern) + if err := writeSummaryf("## Files matching pattern: %s\n", w.pattern); err != nil { + return err + } if w.count == 0 { - fmt.Fprintln(writer, " No files found") - fmt.Fprintln(writer) + if _, err := fmt.Fprintln(writer, " No files found"); err != nil { + return err + } + if _, err := fmt.Fprintln(writer); err != nil { + return err + } continue } - fmt.Fprintf(writer, " Files: %d\n", w.count) + if err := writeSummaryf(" Files: %d\n", w.count); err != nil { + return err + } if w.errorCount > 0 { - fmt.Fprintf(writer, " Successfully analyzed: %d\n", w.count-w.errorCount) - fmt.Fprintf(writer, " Files with errors: %d\n", w.errorCount) + if err := writeSummaryf(" Successfully analyzed: %d\n", w.count-w.errorCount); err != nil { + return err + } + if err := writeSummaryf(" Files with errors: %d\n", w.errorCount); err != nil { + return err + } + } + if err := writeSummaryf(" Total size: %s\n\n", FormatBytes(w.totalSize)); err != nil { + return err } - fmt.Fprintf(writer, " Total size: %s\n\n", FormatBytes(w.totalSize)) } - fmt.Fprintln(writer, "## Overall Summary") - fmt.Fprintf(writer, "Total backup files: %d\n", totalFiles) - fmt.Fprintf(writer, "Total backup size: %s\n", FormatBytes(totalSize)) + if _, err := fmt.Fprintln(writer, "## Overall Summary"); err != nil { + return err + } + if err := writeSummaryf("Total backup files: %d\n", totalFiles); err != nil { + return err + } + if err := writeSummaryf("Total backup size: %s\n", FormatBytes(totalSize)); err != nil { + return err + } if err := writer.Flush(); err != nil { return err } - return file.Close() + return nil } func (c *Collector) collectPVECephConfigSnapshot(ctx context.Context) error { @@ -1712,11 +1776,13 @@ func (c *Collector) collectPVECephRuntime(ctx context.Context) error { } for _, command := range commands { - c.captureCommandOutput(ctx, + if _, err := c.captureCommandOutput(ctx, command.cmd, filepath.Join(cephDir, command.file), command.desc, - false) + false); err != nil { + return err + } } return nil @@ -1971,7 +2037,7 @@ func (c *Collector) parseStorageConfigEntries() []pveStorageEntry { if err != nil { return nil } - defer file.Close() + defer func() { _ = file.Close() }() scanner := bufio.NewScanner(file) var ( diff --git a/internal/backup/collector_pve_additional_test.go b/internal/backup/collector_pve_additional_test.go index 8c4e915c..70d8cfc8 100644 --- a/internal/backup/collector_pve_additional_test.go +++ b/internal/backup/collector_pve_additional_test.go @@ -21,7 +21,7 @@ func TestPatternWriterWrite_DryRunCountsOnly(t *testing.T) { if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() if _, err := f.WriteString("payload"); err != nil { t.Fatalf("WriteString: %v", err) diff --git a/internal/backup/collector_pve_util_test.go b/internal/backup/collector_pve_util_test.go index b6451873..a3c6305a 100644 --- a/internal/backup/collector_pve_util_test.go +++ b/internal/backup/collector_pve_util_test.go @@ -719,6 +719,9 @@ func TestIsClusteredPVE(t *testing.T) { cfg.PVEConfigPath = pveDir cfg.CorosyncConfigPath = "" collector := NewCollector(logger, cfg, tmpDir, "pve", false) + collector.deps.LookPath = func(string) (string, error) { + return "", os.ErrNotExist + } clustered, err := collector.isClusteredPVE(context.Background()) if err != nil { diff --git a/internal/backup/collector_system.go b/internal/backup/collector_system.go index 13c4ab93..78158f3b 100644 --- a/internal/backup/collector_system.go +++ b/internal/backup/collector_system.go @@ -57,6 +57,79 @@ func (c *Collector) detectZFSUsage() (bool, string) { return true, strings.Join(indicators, ",") } +func (c *Collector) collectBestEffortProbe(ctx context.Context, spec CommandSpec, output, description string, available func() (bool, string)) error { + if err := ctx.Err(); err != nil { + return err + } + if _, err := c.depLookPath(spec.Name); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + c.logger.Debug("Skipping %s: command %s not available: %v", description, spec.Name, err) + return nil + } + if available != nil { + ok, reason := available() + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if !ok { + if reason == "" { + reason = "required capability not detected" + } + c.logger.Debug("Skipping %s: %s", description, reason) + return nil + } + } + if err := c.safeCmdOutputBestEffort(ctx, spec, output, description); err != nil { + if isContextCancellationError(ctx, err) { + return err + } + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + c.logger.Debug("Skipping %s: %v", description, err) + return nil + } + if err := ctx.Err(); err != nil { + return err + } + return nil +} + +func (c *Collector) systemctlProbeAvailable() (bool, string) { + if _, err := c.depStat(c.systemPath("/run/systemd/system")); err == nil { + return true, "" + } + if data, err := os.ReadFile(c.systemPath("/proc/1/comm")); err == nil && strings.TrimSpace(string(data)) == "systemd" { + return true, "" + } + if ctxInfo := c.depDetectUnprivilegedContainer(); ctxInfo.Detected { + return false, "systemd runtime not detected in container: " + ctxInfo.Details + } + return false, "systemd runtime not detected" +} + +func (c *Collector) dmidecodeProbeAvailable() (bool, string) { + if os.Geteuid() != 0 { + return false, fmt.Sprintf("dmidecode requires root privileges (euid=%d)", os.Geteuid()) + } + if _, err := c.depStat(c.systemPath("/sys/firmware/dmi/tables")); err == nil { + return true, "" + } + if _, err := c.depStat(c.systemPath("/dev/mem")); err == nil { + return true, "" + } + return false, "DMI tables not accessible" +} + +func (c *Collector) sensorsProbeAvailable() (bool, string) { + if _, err := c.depStat(c.systemPath("/sys/class/hwmon")); err == nil { + return true, "" + } + return false, "hardware sensor sysfs not detected" +} + // CollectSystemInfo collects common system information (both PVE and PBS) func (c *Collector) CollectSystemInfo(ctx context.Context) error { c.logger.Info("Collecting system information") @@ -526,11 +599,13 @@ func (c *Collector) collectSystemCoreRuntime(ctx context.Context, commandsDir st return fmt.Errorf("failed to get kernel version (critical): %w", err) } - c.safeCmdOutput(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("hostname", "-f"), filepath.Join(commandsDir, "hostname.txt"), "Hostname", - false) + nil); err != nil { + return err + } return nil } @@ -608,16 +683,20 @@ func (c *Collector) collectSystemNetworkLinksRuntime(ctx context.Context, comman } func (c *Collector) collectSystemNetworkNeighborsRuntime(ctx context.Context, commandsDir string) error { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("ip", "neigh", "show"), filepath.Join(commandsDir, "ip_neigh.txt"), "Neighbor table", - false) - c.safeCmdOutput(ctx, + false); err != nil { + return err + } + if err := c.safeCmdOutput(ctx, commandSpec("ip", "-6", "neigh", "show"), filepath.Join(commandsDir, "ip6_neigh.txt"), "Neighbor table (IPv6)", - false) + false); err != nil { + return err + } return nil } @@ -692,11 +771,13 @@ func (c *Collector) collectSystemStorageMountsRuntime(ctx context.Context, comma return err } - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("mount"), filepath.Join(commandsDir, "mount.txt"), "Mounted filesystems", - false) + false); err != nil { + return err + } return nil } @@ -752,11 +833,13 @@ func (c *Collector) collectSystemComputeBusInventoryRuntime(ctx context.Context, return err } - c.safeCmdOutput(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("lsusb"), filepath.Join(commandsDir, "lsusb.txt"), "USB devices", - false) + nil); err != nil { + return err + } return nil } @@ -766,17 +849,19 @@ func (c *Collector) collectSystemServicesRuntime(ctx context.Context, commandsDi return nil } - if err := c.collectCommandMulti(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("systemctl", "list-units", "--type=service", "--all"), filepath.Join(commandsDir, "systemctl_services.txt"), "Systemd services", - false); err != nil { + c.systemctlProbeAvailable); err != nil { return err } - c.safeCmdOutput(ctx, commandSpec("systemctl", "list-unit-files", "--type=service"), + if err := c.collectBestEffortProbe(ctx, commandSpec("systemctl", "list-unit-files", "--type=service"), filepath.Join(commandsDir, "systemctl_service_files.txt"), - "Systemd service files", false) + "Systemd service files", c.systemctlProbeAvailable); err != nil { + return err + } return nil } @@ -875,10 +960,13 @@ func (c *Collector) collectSystemFirewallUFWRuntime(ctx context.Context, command commandSpec("ufw", "status", "verbose"), filepath.Join(commandsDir, "ufw_status.txt"), "UFW status") - c.collectCommandOptional(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("systemctl", "status", "--no-pager", "ufw"), filepath.Join(commandsDir, "systemctl_ufw.txt"), - "systemctl ufw") + "systemctl ufw", + c.systemctlProbeAvailable); err != nil { + return err + } return nil } @@ -896,10 +984,13 @@ func (c *Collector) collectSystemFirewallFirewalldRuntime(ctx context.Context, c commandSpec("firewall-cmd", "--list-all"), filepath.Join(commandsDir, "firewalld_list_all.txt"), "firewalld rules") - c.collectCommandOptional(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("systemctl", "status", "--no-pager", "firewalld"), filepath.Join(commandsDir, "systemctl_firewalld.txt"), - "systemctl firewalld") + "systemctl firewalld", + c.systemctlProbeAvailable); err != nil { + return err + } return nil } @@ -909,11 +1000,13 @@ func (c *Collector) collectSystemKernelModulesRuntime(ctx context.Context, comma return nil } - c.safeCmdOutput(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("lsmod"), filepath.Join(commandsDir, "lsmod.txt"), "Loaded kernel modules", - false) + nil); err != nil { + return err + } return nil } @@ -922,12 +1015,11 @@ func (c *Collector) collectSystemSysctlRuntime(ctx context.Context, commandsDir return nil } - c.safeCmdOutput(ctx, + return c.safeCmdOutput(ctx, commandSpec("sysctl", "-a"), filepath.Join(commandsDir, "sysctl.txt"), "Sysctl values", false) - return nil } func (c *Collector) collectSystemZFSRuntime(ctx context.Context, commandsDir string) error { @@ -980,25 +1072,31 @@ func (c *Collector) collectSystemLVMRuntime(ctx context.Context, commandsDir str return err } if _, err := c.depLookPath("pvs"); err == nil { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("pvs"), filepath.Join(commandsDir, "lvm_pvs.txt"), "LVM physical volumes", - false) + false); err != nil { + return err + } } if _, err := c.depLookPath("vgs"); err == nil { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("vgs"), filepath.Join(commandsDir, "lvm_vgs.txt"), "LVM volume groups", - false) + false); err != nil { + return err + } } if _, err := c.depLookPath("lvs"); err == nil { - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("lvs"), filepath.Join(commandsDir, "lvm_lvs.txt"), "LVM logical volumes", - false) + false); err != nil { + return err + } } return nil } @@ -1020,8 +1118,8 @@ func (c *Collector) buildNetworkReport(ctx context.Context, commandsDir string) now := time.Now().Format(time.RFC3339) hostname, _ := os.Hostname() b.WriteString("Proxsave Network Report\n") - b.WriteString(fmt.Sprintf("Timestamp: %s\n", now)) - b.WriteString(fmt.Sprintf("Hostname: %s\n", hostname)) + fmt.Fprintf(&b, "Timestamp: %s\n", now) + fmt.Fprintf(&b, "Hostname: %s\n", hostname) b.WriteString("\n") appendFile := func(title, path string) { @@ -1032,7 +1130,7 @@ func (c *Collector) buildNetworkReport(ctx context.Context, commandsDir string) if err != nil || len(data) == 0 { return } - b.WriteString(fmt.Sprintf("## %s (%s)\n", title, path)) + fmt.Fprintf(&b, "## %s (%s)\n", title, path) b.Write(data) if !strings.HasSuffix(string(data), "\n") { b.WriteString("\n") @@ -1160,18 +1258,22 @@ func (c *Collector) collectKernelInfo(ctx context.Context) error { c.logger.Debug("Collecting kernel information into %s", commandsDir) // Kernel command line - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("cat", c.systemPath("/proc/cmdline")), filepath.Join(commandsDir, "kernel_cmdline.txt"), "Kernel command line", - false) + false); err != nil { + return err + } // Kernel version details - c.safeCmdOutput(ctx, + if err := c.safeCmdOutput(ctx, commandSpec("cat", c.systemPath("/proc/version")), filepath.Join(commandsDir, "kernel_version.txt"), "Kernel version details", - false) + false); err != nil { + return err + } c.logger.Debug("Kernel information snapshot completed") return nil @@ -1182,30 +1284,31 @@ func (c *Collector) collectHardwareInfo(ctx context.Context) error { commandsDir := c.proxsaveCommandsDir("system") c.logger.Debug("Collecting hardware inventory into %s", commandsDir) - // DMI decode (requires root) - c.safeCmdOutput(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("dmidecode"), filepath.Join(commandsDir, "dmidecode.txt"), "Hardware DMI information", - false) + c.dmidecodeProbeAvailable); err != nil { + return err + } - // Hardware sensors (if available) - if _, err := c.depStat(c.systemPath("/usr/bin/sensors")); err == nil { - c.safeCmdOutput(ctx, - commandSpec("sensors"), - filepath.Join(commandsDir, "sensors.txt"), - "Hardware sensors", - false) + if err := c.collectBestEffortProbe(ctx, + commandSpec("sensors"), + filepath.Join(commandsDir, "sensors.txt"), + "Hardware sensors", + c.sensorsProbeAvailable); err != nil { + return err } // SMART status for disks (if available) if _, err := c.depStat(c.systemPath("/usr/sbin/smartctl")); err == nil { - // Get list of disks - c.safeCmdOutput(ctx, + if err := c.collectBestEffortProbe(ctx, commandSpec("smartctl", "--scan"), filepath.Join(commandsDir, "smartctl_scan.txt"), "SMART scan", - false) + nil); err != nil { + return err + } } c.logger.Debug("Hardware information snapshot completed") diff --git a/internal/backup/collector_system_test.go b/internal/backup/collector_system_test.go index 2983a478..3da01a2f 100644 --- a/internal/backup/collector_system_test.go +++ b/internal/backup/collector_system_test.go @@ -1,6 +1,7 @@ package backup import ( + "bytes" "context" "errors" "os" @@ -28,6 +29,170 @@ func TestEnsureSystemPathAddsDefaults(t *testing.T) { } } +func TestCollectSystemKernelModulesRuntimeBestEffort(t *testing.T) { + var log bytes.Buffer + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(&log) + + tempDir := t.TempDir() + config := GetDefaultCollectorConfig() + calls := 0 + collector := NewCollectorWithDeps(logger, config, tempDir, types.ProxmoxUnknown, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + if name == "lsmod" { + return "/usr/sbin/lsmod", nil + } + return "", os.ErrNotExist + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "lsmod" { + t.Fatalf("unexpected command %s", name) + } + calls++ + return []byte("lsmod failed"), errors.New("lsmod failed") + }, + DetectUnprivilegedContainer: func() (bool, string) { return false, "" }, + }) + + commandsDir := filepath.Join(tempDir, "commands") + if err := collector.collectSystemKernelModulesRuntime(context.Background(), commandsDir); err != nil { + t.Fatalf("collectSystemKernelModulesRuntime returned error: %v", err) + } + if calls != 1 { + t.Fatalf("lsmod calls=%d; want 1", calls) + } + if logger.WarningCount() != 0 { + t.Fatalf("expected lsmod failure to stay below warning level, warnings=%d log=%s", logger.WarningCount(), log.String()) + } + if _, err := os.Stat(filepath.Join(commandsDir, "lsmod.txt")); !os.IsNotExist(err) { + t.Fatalf("expected no lsmod output file on failure, stat err: %v", err) + } +} + +func TestCollectSystemKernelModulesRuntimePropagatesCommandCancellation(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(&bytes.Buffer{}) + + tempDir := t.TempDir() + config := GetDefaultCollectorConfig() + collector := NewCollectorWithDeps(logger, config, tempDir, types.ProxmoxUnknown, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + if name == "lsmod" { + return "/usr/sbin/lsmod", nil + } + return "", os.ErrNotExist + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "lsmod" { + t.Fatalf("unexpected command %s", name) + } + return nil, context.Canceled + }, + DetectUnprivilegedContainer: func() (bool, string) { return false, "" }, + }) + + err := collector.collectSystemKernelModulesRuntime(context.Background(), filepath.Join(tempDir, "commands")) + if !errors.Is(err, context.Canceled) { + t.Fatalf("collectSystemKernelModulesRuntime error=%v; want %v", err, context.Canceled) + } +} + +func TestCollectSystemKernelModulesRuntimePropagatesCommandDeadline(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(&bytes.Buffer{}) + + tempDir := t.TempDir() + config := GetDefaultCollectorConfig() + collector := NewCollectorWithDeps(logger, config, tempDir, types.ProxmoxUnknown, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + if name == "lsmod" { + return "/usr/sbin/lsmod", nil + } + return "", os.ErrNotExist + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "lsmod" { + t.Fatalf("unexpected command %s", name) + } + return nil, context.DeadlineExceeded + }, + DetectUnprivilegedContainer: func() (bool, string) { return false, "" }, + }) + + err := collector.collectSystemKernelModulesRuntime(context.Background(), filepath.Join(tempDir, "commands")) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("collectSystemKernelModulesRuntime error=%v; want %v", err, context.DeadlineExceeded) + } +} + +func TestCollectBestEffortProbePropagatesCanceledContext(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(&bytes.Buffer{}) + + collector := NewCollector(logger, GetDefaultCollectorConfig(), t.TempDir(), types.ProxmoxUnknown, false) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := collector.collectBestEffortProbe(ctx, commandSpec("lsusb"), filepath.Join(t.TempDir(), "lsusb.txt"), "USB devices", nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("collectBestEffortProbe error=%v; want %v", err, context.Canceled) + } +} + +func TestCollectHardwareInfoSmartctlScanBestEffort(t *testing.T) { + var log bytes.Buffer + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(&log) + + tempDir := t.TempDir() + smartctlMarker := filepath.Join(tempDir, "smartctl") + if err := os.WriteFile(smartctlMarker, []byte("#!/bin/sh\n"), 0o755); err != nil { + t.Fatalf("write smartctl marker: %v", err) + } + smartctlInfo, err := os.Stat(smartctlMarker) + if err != nil { + t.Fatalf("stat smartctl marker: %v", err) + } + config := GetDefaultCollectorConfig() + calls := 0 + collector := NewCollectorWithDeps(logger, config, tempDir, types.ProxmoxUnknown, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + if name == "smartctl" { + return "/usr/sbin/smartctl", nil + } + return "", os.ErrNotExist + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "smartctl" || len(args) != 1 || args[0] != "--scan" { + t.Fatalf("unexpected command %s %v", name, args) + } + calls++ + return []byte("smartctl failed"), errors.New("smartctl failed") + }, + Stat: func(path string) (os.FileInfo, error) { + if strings.HasSuffix(path, "/usr/sbin/smartctl") { + return smartctlInfo, nil + } + return nil, os.ErrNotExist + }, + DetectUnprivilegedContainer: func() (bool, string) { return false, "" }, + }) + + if err := collector.collectHardwareInfo(context.Background()); err != nil { + t.Fatalf("collectHardwareInfo returned error: %v", err) + } + if calls != 1 { + t.Fatalf("smartctl calls=%d; want 1", calls) + } + if logger.WarningCount() != 0 { + t.Fatalf("expected smartctl failure to stay below warning level, warnings=%d log=%s", logger.WarningCount(), log.String()) + } + output := filepath.Join(collector.proxsaveCommandsDir("system"), "smartctl_scan.txt") + if _, err := os.Stat(output); !os.IsNotExist(err) { + t.Fatalf("expected no smartctl output file on failure, stat err: %v", err) + } +} + func TestEnsureSystemPathDeduplicates(t *testing.T) { t.Setenv("PATH", "/usr/bin:/usr/bin:/usr/sbin:/usr/sbin") diff --git a/internal/backup/collector_test.go b/internal/backup/collector_test.go index c4c92379..7086e14f 100644 --- a/internal/backup/collector_test.go +++ b/internal/backup/collector_test.go @@ -187,9 +187,15 @@ func TestCollectorSafeCopyDir(t *testing.T) { // Create test source directory with files srcDir := filepath.Join(tempDir, "source") - os.MkdirAll(filepath.Join(srcDir, "subdir"), 0755) - os.WriteFile(filepath.Join(srcDir, "file1.txt"), []byte("content1"), 0644) - os.WriteFile(filepath.Join(srcDir, "subdir", "file2.txt"), []byte("content2"), 0644) + if err := os.MkdirAll(filepath.Join(srcDir, "subdir"), 0755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(filepath.Join(srcDir, "file1.txt"), []byte("content1"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if err := os.WriteFile(filepath.Join(srcDir, "subdir", "file2.txt"), []byte("content2"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } if err := os.Chmod(srcDir, 0700); err != nil { t.Fatalf("Failed to chmod source dir: %v", err) } @@ -378,7 +384,9 @@ func TestCollectorDryRun(t *testing.T) { // Create a test file and try to copy it srcFile := filepath.Join(tempDir, "source.txt") - os.WriteFile(srcFile, []byte("test"), 0644) + if err := os.WriteFile(srcFile, []byte("test"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } destFile := filepath.Join(tempDir, "dryrun", "dest.txt") ctx := context.Background() @@ -486,7 +494,9 @@ func TestGetStats(t *testing.T) { // Perform an operation testDir := filepath.Join(tempDir, "test") - collector.ensureDir(testDir) + if err := collector.ensureDir(testDir); err != nil { + t.Fatalf("ensureDir failed: %v", err) + } // Check stats updated stats = collector.GetStats() @@ -1418,6 +1428,59 @@ func TestSafeCmdOutputHonorsContextCancellation(t *testing.T) { } } +func TestSafeCmdOutputSwallowsNonCriticalPveshDeadline(t *testing.T) { + logger := logging.New(types.LogLevelWarning, false) + cfg := GetDefaultCollectorConfig() + cfg.PveshTimeoutSeconds = 1 + tmp := t.TempDir() + deps := CollectorDeps{ + LookPath: func(string) (string, error) { return "/usr/bin/pvesh", nil }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name != "pvesh" { + t.Fatalf("unexpected command %s", name) + } + <-ctx.Done() + return []byte("timeout"), ctx.Err() + }, + } + c := NewCollectorWithDeps(logger, cfg, tmp, types.ProxmoxUnknown, false, deps) + + output := filepath.Join(tmp, "pvesh.txt") + err := c.safeCmdOutput(context.Background(), commandSpec("pvesh", "get", "/nodes"), output, "pvesh nodes", false) + if err != nil { + t.Fatalf("expected non-critical pvesh timeout to be skipped, got %v", err) + } + if _, err := os.Stat(output); !os.IsNotExist(err) { + t.Fatalf("expected no output file on timeout, stat err=%v", err) + } + if logger.WarningCount() != 1 { + t.Fatalf("expected one warning for skipped pvesh timeout, got %d", logger.WarningCount()) + } +} + +func TestSafeCmdOutputPropagatesParentCancellationDuringPvesh(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + cfg := GetDefaultCollectorConfig() + cfg.PveshTimeoutSeconds = 15 + tmp := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + deps := CollectorDeps{ + LookPath: func(string) (string, error) { return "/usr/bin/pvesh", nil }, + RunCommand: func(runCtx context.Context, name string, args ...string) ([]byte, error) { + cancel() + <-runCtx.Done() + return nil, runCtx.Err() + }, + } + c := NewCollectorWithDeps(logger, cfg, tmp, types.ProxmoxUnknown, false, deps) + + err := c.safeCmdOutput(ctx, commandSpec("pvesh", "get", "/nodes"), filepath.Join(tmp, "pvesh.txt"), "pvesh nodes", false) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected parent context cancellation, got %v", err) + } +} + func TestSafeCmdOutputReturnsErrorOnEmptyCommand(t *testing.T) { logger := logging.New(types.LogLevelError, false) cfg := GetDefaultCollectorConfig() diff --git a/internal/backup/optimizations.go b/internal/backup/optimizations.go index f31943f8..90e5d0ea 100644 --- a/internal/backup/optimizations.go +++ b/internal/backup/optimizations.go @@ -150,12 +150,12 @@ func shouldSkipDedupPath(rel string) bool { } } -func hashFile(path string) (string, error) { +func hashFile(path string) (sum string, err error) { f, err := os.Open(path) if err != nil { return "", err } - defer f.Close() + defer closeIntoErr(&err, f, "close file for hash") hasher := sha256.New() if _, err := io.Copy(hasher, f); err != nil { @@ -243,7 +243,7 @@ func chunkLargeFiles(ctx context.Context, logger *logging.Logger, root string, c return nil } -func splitFile(path, destBase string, chunkSize int64) error { +func splitFile(path, destBase string, chunkSize int64) (err error) { if err := os.MkdirAll(filepath.Dir(destBase), defaultChunkDirPerm); err != nil { return err } @@ -252,7 +252,7 @@ func splitFile(path, destBase string, chunkSize int64) error { if err != nil { return err } - defer in.Close() + defer closeIntoErr(&err, in, "close source file") buf := make([]byte, chunkBufferSize) index := 0 @@ -270,22 +270,32 @@ func splitFile(path, destBase string, chunkSize int64) error { return nil } -func writeChunk(src *os.File, chunkPath string, buf []byte, limit int64) (bool, error) { - out, err := os.OpenFile(chunkPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, defaultChunkFilePerm) - if err != nil { - return false, err +func writeChunk(src *os.File, chunkPath string, buf []byte, limit int64) (done bool, err error) { + if limit <= 0 { + return true, nil } - defer out.Close() - + var out *os.File + defer func() { + if out != nil { + closeIntoErr(&err, out, "close chunk file") + } + }() var written int64 for written < limit { remaining := limit - written - if remaining < int64(len(buf)) { - buf = buf[:remaining] + readBuf := buf + if remaining < int64(len(readBuf)) { + readBuf = readBuf[:remaining] } - n, err := src.Read(buf) + n, err := src.Read(readBuf) if n > 0 { - if _, wErr := out.Write(buf[:n]); wErr != nil { + if out == nil { + out, err = os.OpenFile(chunkPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, defaultChunkFilePerm) + if err != nil { + return false, err + } + } + if _, wErr := out.Write(readBuf[:n]); wErr != nil { return false, wErr } written += int64(n) diff --git a/internal/backup/optimizations_bench_test.go b/internal/backup/optimizations_bench_test.go index 8cd04131..a26f74b7 100644 --- a/internal/backup/optimizations_bench_test.go +++ b/internal/backup/optimizations_bench_test.go @@ -83,7 +83,7 @@ func BenchmarkPrefilterFiles(b *testing.B) { } } -func writeFileOfSize(path string, size int64) error { +func writeFileOfSize(path string, size int64) (err error) { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return err } @@ -91,7 +91,7 @@ func writeFileOfSize(path string, size int64) error { if err != nil { return err } - defer f.Close() + defer closeIntoErr(&err, f, "close benchmark file") chunk := bytes.Repeat([]byte("x"), 32*1024) var written int64 @@ -134,14 +134,18 @@ func copyDir(src, dst string) error { } dstFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640) if err != nil { - srcFile.Close() + _ = srcFile.Close() return err } - _, err = io.Copy(dstFile, srcFile) - srcFile.Close() - if err != nil { + _, copyErr := io.Copy(dstFile, srcFile) + closeSrcErr := srcFile.Close() + if copyErr != nil { _ = dstFile.Close() - return err + return copyErr + } + if closeSrcErr != nil { + _ = dstFile.Close() + return closeSrcErr } return dstFile.Close() }) diff --git a/internal/checks/checks.go b/internal/checks/checks.go index 1a0a59eb..f9caaa86 100644 --- a/internal/checks/checks.go +++ b/internal/checks/checks.go @@ -375,11 +375,14 @@ func (c *Checker) CheckLockFile() CheckResult { result.Message = result.Error.Error() return result } - defer f.Close() hostname, _ := os.Hostname() lockContent := fmt.Sprintf("pid=%d\nhost=%s\ntime=%s\n", os.Getpid(), hostname, time.Now().Format(time.RFC3339)) if _, err := f.WriteString(lockContent); err != nil { + if closeErr := f.Close(); closeErr != nil { + c.logger.Warning("Failed to close lock file %s: %v", lockPath, closeErr) + } + c.removePartialLockFile(lockPath) result.Error = fmt.Errorf("failed to write lock file: %w", err) result.Message = result.Error.Error() return result @@ -387,6 +390,13 @@ func (c *Checker) CheckLockFile() CheckResult { if err := syncFile(f); err != nil { c.logger.Warning("Failed to sync lock file %s: %v", lockPath, err) } + if err := f.Close(); err != nil { + c.logger.Warning("Failed to close lock file %s: %v", lockPath, err) + c.removePartialLockFile(lockPath) + result.Error = fmt.Errorf("failed to close lock file: %w", err) + result.Message = result.Error.Error() + return result + } } else { c.logger.Info("[DRY RUN] Would create lock file: %s", lockPath) } @@ -397,6 +407,12 @@ func (c *Checker) CheckLockFile() CheckResult { return result } +func (c *Checker) removePartialLockFile(lockPath string) { + if err := osRemove(lockPath); err != nil && !os.IsNotExist(err) { + c.logger.Warning("Failed to remove partial lock file %s: %v", lockPath, err) + } +} + // CheckPermissions verifies write permissions on required directories func (c *Checker) CheckPermissions() CheckResult { result := CheckResult{ @@ -425,8 +441,11 @@ func (c *Checker) CheckPermissions() CheckResult { for attempt := 1; attempt <= maxAttempts; attempt++ { f, err := createTestFile(testFile) if err == nil { - f.Close() - lastErr = nil + if closeErr := f.Close(); closeErr != nil { + lastErr = closeErr + } else { + lastErr = nil + } break } @@ -565,7 +584,7 @@ func (c *Checker) CheckTempDirectory() CheckResult { if err != nil { if !os.IsNotExist(err) { result.Code = "STAT_FAILED" - result.Error = fmt.Errorf("Temp directory check failed - path: %s: %w", tempRoot, err) + result.Error = fmt.Errorf("temp directory check failed - path: %s: %w", tempRoot, err) result.Message = result.Error.Error() return result } @@ -574,7 +593,7 @@ func (c *Checker) CheckTempDirectory() CheckResult { c.logger.Debug("Temp directory not found, creating: %s", tempRoot) if err := osMkdirAll(tempRoot, 0o755); err != nil { result.Code = "CREATE_FAILED" - result.Error = fmt.Errorf("Temp directory creation failed - path: %s: %w", tempRoot, err) + result.Error = fmt.Errorf("temp directory creation failed - path: %s: %w", tempRoot, err) result.Message = result.Error.Error() return result } @@ -583,7 +602,7 @@ func (c *Checker) CheckTempDirectory() CheckResult { info, err = osStat(tempRoot) if err != nil { result.Code = "VERIFY_FAILED" - result.Error = fmt.Errorf("Temp directory verification failed - path: %s: %w", tempRoot, err) + result.Error = fmt.Errorf("temp directory verification failed - path: %s: %w", tempRoot, err) result.Message = result.Error.Error() return result } @@ -593,7 +612,7 @@ func (c *Checker) CheckTempDirectory() CheckResult { if !info.IsDir() { result.Code = "NOT_DIRECTORY" - result.Error = fmt.Errorf("Temp path is not a directory - path: %s", tempRoot) + result.Error = fmt.Errorf("temp path is not a directory - path: %s", tempRoot) result.Message = result.Error.Error() return result } @@ -603,22 +622,24 @@ func (c *Checker) CheckTempDirectory() CheckResult { testFile := filepath.Join(tempRoot, ".proxsave-permission-test") if err := osWriteFile(testFile, []byte("test"), 0o600); err != nil { result.Code = "NOT_WRITABLE" - result.Error = fmt.Errorf("Temp directory not writable - path: %s: %w", tempRoot, err) + result.Error = fmt.Errorf("temp directory not writable - path: %s: %w", tempRoot, err) result.Message = result.Error.Error() return result } - defer osRemove(testFile) + defer func() { _ = osRemove(testFile) }() // Test symlink support c.logger.Debug("Testing symlink support: %s", tempRoot) testSymlink := filepath.Join(tempRoot, ".proxsave-symlink-test") if err := osSymlink(testFile, testSymlink); err != nil { result.Code = "NO_SYMLINK_SUPPORT" - result.Error = fmt.Errorf("Temp directory does not support symlinks - path: %s: %w", tempRoot, err) + result.Error = fmt.Errorf("temp directory does not support symlinks - path: %s: %w", tempRoot, err) result.Message = result.Error.Error() return result } - osRemove(testSymlink) + if err := osRemove(testSymlink); err != nil && !os.IsNotExist(err) { + c.logger.Warning("Failed to remove temp symlink test %s: %v", testSymlink, err) + } result.Passed = true result.Message = fmt.Sprintf("%s writable with symlink support", tempRoot) diff --git a/internal/checks/checks_test.go b/internal/checks/checks_test.go index 1dbc6ffd..f9283e64 100644 --- a/internal/checks/checks_test.go +++ b/internal/checks/checks_test.go @@ -109,7 +109,9 @@ func TestCheckLockFile(t *testing.T) { } // Clean up - checker.ReleaseLock() + if err := checker.ReleaseLock(); err != nil { + t.Fatalf("ReleaseLock failed: %v", err) + } } func TestCheckLockFileStaleLock(t *testing.T) { @@ -146,7 +148,9 @@ func TestCheckLockFileStaleLock(t *testing.T) { } // Clean up - checker.ReleaseLock() + if err := checker.ReleaseLock(); err != nil { + t.Fatalf("ReleaseLock failed: %v", err) + } } func TestCheckLockFile_RemovesLockWhenProcessIsGone(t *testing.T) { @@ -624,7 +628,9 @@ func TestRunAllChecks(t *testing.T) { } // Clean up - checker.ReleaseLock() + if err := checker.ReleaseLock(); err != nil { + t.Fatalf("ReleaseLock failed: %v", err) + } } func TestRunAllChecksSkipPermissionCheck(t *testing.T) { @@ -797,7 +803,9 @@ func TestCheckDiskSpaceForEstimate(t *testing.T) { func TestCheckTempDirectory_Success(t *testing.T) { // Ensure /tmp/proxsave exists for the test tempRoot := filepath.Join("/tmp", "proxsave") - os.MkdirAll(tempRoot, 0o755) + if err := os.MkdirAll(tempRoot, 0o755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } config := GetDefaultCheckerConfig(t.TempDir(), t.TempDir(), t.TempDir()) logger := logging.New(types.LogLevelDebug, false) @@ -859,7 +867,9 @@ func TestCheckTempDirectory_NotWritable(t *testing.T) { func TestCheckTempDirectory_SymlinkSupport(t *testing.T) { // Verify that the temp directory check includes symlink validation tempRoot := filepath.Join("/tmp", "proxsave") - os.MkdirAll(tempRoot, 0o755) + if err := os.MkdirAll(tempRoot, 0o755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } config := GetDefaultCheckerConfig(t.TempDir(), t.TempDir(), t.TempDir()) logger := logging.New(types.LogLevelDebug, false) @@ -883,7 +893,9 @@ func TestCheckTempDirectory_SymlinkSupport(t *testing.T) { func TestRunAllChecks_IncludesTempDirectory(t *testing.T) { // Ensure /tmp/proxsave exists - os.MkdirAll(filepath.Join("/tmp", "proxsave"), 0o755) + if err := os.MkdirAll(filepath.Join("/tmp", "proxsave"), 0o755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } backupPath := t.TempDir() logPath := t.TempDir() @@ -1114,10 +1126,6 @@ func TestCheckLockFile_RemoveStaleLockFails(t *testing.T) { } func TestCheckLockFile_WriteFails(t *testing.T) { - if _, err := os.Stat("/dev/full"); err != nil { - t.Skipf("/dev/full not available: %v", err) - } - logger := logging.New(types.LogLevelInfo, false) logger.SetOutput(io.Discard) @@ -1131,7 +1139,7 @@ func TestCheckLockFile_WriteFails(t *testing.T) { origOpen := osOpenFile t.Cleanup(func() { osOpenFile = origOpen }) osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { - return os.OpenFile("/dev/full", os.O_WRONLY, 0) + return os.OpenFile(name, os.O_CREATE|os.O_EXCL|os.O_RDONLY, perm) } checker := NewChecker(logger, config) @@ -1142,6 +1150,9 @@ func TestCheckLockFile_WriteFails(t *testing.T) { if result.Error == nil || !strings.Contains(result.Error.Error(), "failed to write lock file") { t.Fatalf("expected write lock file error, got: %v", result.Error) } + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Fatalf("expected partial lock file to be removed, stat err: %v", err) + } } func TestCheckLockFile_SyncWarningDoesNotFail(t *testing.T) { @@ -1166,6 +1177,36 @@ func TestCheckLockFile_SyncWarningDoesNotFail(t *testing.T) { } } +func TestCheckLockFile_CloseFailsRemovesPartialLock(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + logger.SetOutput(io.Discard) + + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, ".backup.lock") + + config := GetDefaultCheckerConfig(tmpDir, tmpDir, tmpDir) + config.LockFilePath = lockPath + config.MaxLockAge = time.Hour + + origSync := syncFile + t.Cleanup(func() { syncFile = origSync }) + syncFile = func(f *os.File) error { + return f.Close() + } + + checker := NewChecker(logger, config) + result := checker.CheckLockFile() + if result.Passed { + t.Fatalf("expected CheckLockFile to fail, got passed") + } + if result.Error == nil || !strings.Contains(result.Error.Error(), "failed to close lock file") { + t.Fatalf("expected close lock file error, got: %v", result.Error) + } + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Fatalf("expected partial lock file to be removed, stat err: %v", err) + } +} + func TestCheckLockFile_DefaultLockPath_DryRun(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) logger.SetOutput(io.Discard) diff --git a/internal/cli/args.go b/internal/cli/args.go index 6919c964..3892ba58 100644 --- a/internal/cli/args.go +++ b/internal/cli/args.go @@ -210,26 +210,29 @@ func ShowVersion() { } func printHelp(w io.Writer, argv0 string) { - fmt.Fprintf(w, "Usage: %s [options]\n\n", argv0) - fmt.Fprintln(w, "ProxSave") - fmt.Fprintln(w, "") - fmt.Fprintln(w, "Options:") + _, _ = fmt.Fprintf(w, "Usage: %s [options]\n\n", argv0) + _, _ = fmt.Fprintln(w, "ProxSave") + _, _ = fmt.Fprintln(w, "") + _, _ = fmt.Fprintln(w, "Options:") + previousOutput := flag.CommandLine.Output() + flag.CommandLine.SetOutput(w) + defer flag.CommandLine.SetOutput(previousOutput) flag.PrintDefaults() - fmt.Fprintln(w, "") - fmt.Fprintln(w, "Examples:") - fmt.Fprintf(w, " %s -c /path/to/config.env\n", argv0) - fmt.Fprintf(w, " %s --dry-run --log-level debug\n", argv0) - fmt.Fprintf(w, " %s --version\n", argv0) + _, _ = fmt.Fprintln(w, "") + _, _ = fmt.Fprintln(w, "Examples:") + _, _ = fmt.Fprintf(w, " %s -c /path/to/config.env\n", argv0) + _, _ = fmt.Fprintf(w, " %s --dry-run --log-level debug\n", argv0) + _, _ = fmt.Fprintf(w, " %s --version\n", argv0) } func printVersion(w io.Writer) { - fmt.Fprintln(w, "ProxSave") + _, _ = fmt.Fprintln(w, "ProxSave") v := version.String() if strings.TrimSpace(v) == "" { v = "0.0.0-dev" } - fmt.Fprintf(w, "Version: %s\n", v) + _, _ = fmt.Fprintf(w, "Version: %s\n", v) build := "development" commit := strings.TrimSpace(version.Commit) @@ -242,9 +245,9 @@ func printVersion(w io.Writer) { case date != "": build = date } - fmt.Fprintf(w, "Build: %s\n", build) + _, _ = fmt.Fprintf(w, "Build: %s\n", build) - fmt.Fprintln(w, "Author: tis24dev") + _, _ = fmt.Fprintln(w, "Author: tis24dev") } type stringFlag struct { diff --git a/internal/closeerr/closeerr.go b/internal/closeerr/closeerr.go new file mode 100644 index 00000000..60920bfa --- /dev/null +++ b/internal/closeerr/closeerr.go @@ -0,0 +1,19 @@ +package closeerr + +import ( + "errors" + "fmt" + "io" + "os" +) + +// CloseIntoErr closes closer and stores the close failure in errp only when no +// earlier error is present. +func CloseIntoErr(errp *error, closer io.Closer, operation string) { + if errp == nil || closer == nil { + return + } + if closeErr := closer.Close(); closeErr != nil && !errors.Is(closeErr, os.ErrClosed) && *errp == nil { + *errp = fmt.Errorf("%s: %w", operation, closeErr) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 07532c9d..7e014cfe 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1419,7 +1419,7 @@ func isLocalPBSHost(host string) bool { hostShort = hostShort[:dot] } return hostShort == currentShort && - (strings.Index(host, ".") < 0 || strings.Index(currentHost, ".") < 0) + (!strings.Contains(host, ".") || !strings.Contains(currentHost, ".")) } func normalizePBSHost(host string) string { @@ -1498,14 +1498,18 @@ func (c *Config) BuildWebhookConfig() *WebhookConfig { } } -func parseEnvFile(path string) (map[string]string, error) { +func parseEnvFile(path string) (raw map[string]string, err error) { file, err := os.Open(path) if err != nil { return nil, fmt.Errorf("cannot open config file: %w", err) } - defer file.Close() + defer func() { + if closeErr := file.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("close config file: %w", closeErr) + } + }() - raw := make(map[string]string) + raw = make(map[string]string) scanner := bufio.NewScanner(file) for scanner.Scan() { diff --git a/internal/config/upgrade.go b/internal/config/upgrade.go index d319a4af..0a402f47 100644 --- a/internal/config/upgrade.go +++ b/internal/config/upgrade.go @@ -385,14 +385,12 @@ func computeConfigUpgrade(configPath string) (*UpgradeResult, string, []byte, er ops := make([]insertOp, 0, len(missingEntries)) unanchored := make([]templateEntry, 0) for _, entry := range missingEntries { - insertIndex := appendIndex - if prev, ok := findPrevAnchor(entry.index); ok { - insertIndex = prev - } else { + prev, ok := findPrevAnchor(entry.index) + if !ok { unanchored = append(unanchored, entry) continue } - insertIndex = normalizeInsertIndex(insertIndex) + insertIndex := normalizeInsertIndex(prev) ops = append(ops, insertOp{ index: insertIndex, lines: entry.lines, diff --git a/internal/environment/detect.go b/internal/environment/detect.go index a6fd841d..73c66029 100644 --- a/internal/environment/detect.go +++ b/internal/environment/detect.go @@ -143,12 +143,6 @@ func detectEnvironmentInfo() (*EnvironmentInfo, error) { return info, fmt.Errorf("unable to detect Proxmox environment") } -// detectProxmox is retained as a compatibility wrapper for legacy call sites and tests. -func detectProxmox() (types.ProxmoxType, string, error) { - info, err := detectEnvironmentInfo() - return info.Type, info.Version, err -} - func resolveType(hasPVE, hasPBS bool) types.ProxmoxType { switch { case hasPVE && hasPBS: @@ -421,58 +415,58 @@ func writeDetectionDebug() string { path := filepath.Join(debugDir, fmt.Sprintf("proxmox_detection_debug_%d.log", now.Unix())) var builder strings.Builder - builder.WriteString(fmt.Sprintf("=== Proxmox Detection Failure Debug - %s ===\n", now.Format("2006-01-02 15:04:05"))) - builder.WriteString(fmt.Sprintf("Current PATH: %s\n", os.Getenv("PATH"))) + fmt.Fprintf(&builder, "=== Proxmox Detection Failure Debug - %s ===\n", now.Format("2006-01-02 15:04:05")) + fmt.Fprintf(&builder, "Current PATH: %s\n", os.Getenv("PATH")) if u, err := userCurrentFunc(); err == nil { - builder.WriteString(fmt.Sprintf("Current USER: %s\n", u.Username)) + fmt.Fprintf(&builder, "Current USER: %s\n", u.Username) } else { builder.WriteString("Current USER: unknown\n") } if cwd, err := getwdFunc(); err == nil { - builder.WriteString(fmt.Sprintf("Current PWD: %s\n", cwd)) + fmt.Fprintf(&builder, "Current PWD: %s\n", cwd) } - builder.WriteString(fmt.Sprintf("Shell: %s\n\n", os.Getenv("SHELL"))) + fmt.Fprintf(&builder, "Shell: %s\n\n", os.Getenv("SHELL")) builder.WriteString("=== Command availability check ===\n") - builder.WriteString(fmt.Sprintf("command -v pveversion: %s\n", lookPathOrNotFound("pveversion"))) - builder.WriteString(fmt.Sprintf("command -v proxmox-backup-manager: %s\n", lookPathOrNotFound("proxmox-backup-manager"))) + fmt.Fprintf(&builder, "command -v pveversion: %s\n", lookPathOrNotFound("pveversion")) + fmt.Fprintf(&builder, "command -v proxmox-backup-manager: %s\n", lookPathOrNotFound("proxmox-backup-manager")) builder.WriteString("\n") builder.WriteString("=== File existence check ===\n") - builder.WriteString(fmt.Sprintf("%s exists: %s\n", "/usr/bin/pveversion", boolToYes(fileExists("/usr/bin/pveversion")))) - builder.WriteString(fmt.Sprintf("%s executable: %s\n", "/usr/bin/pveversion", boolToYes(isExecutable("/usr/bin/pveversion")))) - builder.WriteString(fmt.Sprintf("%s exists: %s\n", "/usr/sbin/pveversion", boolToYes(fileExists("/usr/sbin/pveversion")))) - builder.WriteString(fmt.Sprintf("%s executable: %s\n", "/usr/sbin/pveversion", boolToYes(isExecutable("/usr/sbin/pveversion")))) - builder.WriteString(fmt.Sprintf("%s exists: %s\n", "/usr/bin/proxmox-backup-manager", boolToYes(fileExists("/usr/bin/proxmox-backup-manager")))) - builder.WriteString(fmt.Sprintf("%s executable: %s\n", "/usr/bin/proxmox-backup-manager", boolToYes(isExecutable("/usr/bin/proxmox-backup-manager")))) + fmt.Fprintf(&builder, "%s exists: %s\n", "/usr/bin/pveversion", boolToYes(fileExists("/usr/bin/pveversion"))) + fmt.Fprintf(&builder, "%s executable: %s\n", "/usr/bin/pveversion", boolToYes(isExecutable("/usr/bin/pveversion"))) + fmt.Fprintf(&builder, "%s exists: %s\n", "/usr/sbin/pveversion", boolToYes(fileExists("/usr/sbin/pveversion"))) + fmt.Fprintf(&builder, "%s executable: %s\n", "/usr/sbin/pveversion", boolToYes(isExecutable("/usr/sbin/pveversion"))) + fmt.Fprintf(&builder, "%s exists: %s\n", "/usr/bin/proxmox-backup-manager", boolToYes(fileExists("/usr/bin/proxmox-backup-manager"))) + fmt.Fprintf(&builder, "%s executable: %s\n", "/usr/bin/proxmox-backup-manager", boolToYes(isExecutable("/usr/bin/proxmox-backup-manager"))) builder.WriteString("\n") builder.WriteString("=== Directory existence check ===\n") for _, dir := range append(pveDirCandidates, pbsDirCandidates...) { - builder.WriteString(fmt.Sprintf("%s exists: %s\n", dir, boolToYes(dirExists(dir)))) + fmt.Fprintf(&builder, "%s exists: %s\n", dir, boolToYes(dirExists(dir))) } builder.WriteString("\n") builder.WriteString("=== Version file check ===\n") - builder.WriteString(fmt.Sprintf("%s exists: %s\n", pveLegacyFile, boolToYes(fileExists(pveLegacyFile)))) + fmt.Fprintf(&builder, "%s exists: %s\n", pveLegacyFile, boolToYes(fileExists(pveLegacyFile))) if content := readAndTrim(pveLegacyFile); content != "" { - builder.WriteString(fmt.Sprintf("%s content: %s\n", pveLegacyFile, content)) + fmt.Fprintf(&builder, "%s content: %s\n", pveLegacyFile, content) } - builder.WriteString(fmt.Sprintf("%s exists: %s\n", pveVersionFile, boolToYes(fileExists(pveVersionFile)))) + fmt.Fprintf(&builder, "%s exists: %s\n", pveVersionFile, boolToYes(fileExists(pveVersionFile))) if content := readAndTrim(pveVersionFile); content != "" { - builder.WriteString(fmt.Sprintf("%s content: %s\n", pveVersionFile, content)) + fmt.Fprintf(&builder, "%s content: %s\n", pveVersionFile, content) } - builder.WriteString(fmt.Sprintf("%s exists: %s\n", pbsVersionFile, boolToYes(fileExists(pbsVersionFile)))) + fmt.Fprintf(&builder, "%s exists: %s\n", pbsVersionFile, boolToYes(fileExists(pbsVersionFile))) if content := readAndTrim(pbsVersionFile); content != "" { - builder.WriteString(fmt.Sprintf("%s content: %s\n", pbsVersionFile, content)) + fmt.Fprintf(&builder, "%s content: %s\n", pbsVersionFile, content) } builder.WriteString("\n") builder.WriteString("=== APT source files check ===\n") for _, source := range append(pveSourceFiles, pbsSourceFiles...) { - builder.WriteString(fmt.Sprintf("%s exists: %s\n", source, boolToYes(fileExists(source)))) + fmt.Fprintf(&builder, "%s exists: %s\n", source, boolToYes(fileExists(source))) } builder.WriteString("\n") diff --git a/internal/environment/unprivileged.go b/internal/environment/unprivileged.go index ed27be80..fb840f70 100644 --- a/internal/environment/unprivileged.go +++ b/internal/environment/unprivileged.go @@ -335,23 +335,6 @@ func formatIDMapDetails(label string, info IDMapOutsideZeroInfo) string { } } -func formatFileValueDetails(label string, info FileValueInfo) string { - label = strings.TrimSpace(label) - if label == "" { - label = "value" - } - switch { - case info.OK && strings.TrimSpace(info.Value) != "": - return fmt.Sprintf("%s=%s", label, strings.TrimSpace(info.Value)) - case info.OK: - return fmt.Sprintf("%s=empty", label) - case info.ReadError != "": - return fmt.Sprintf("%s=unavailable(err=%s)", label, info.ReadError) - default: - return fmt.Sprintf("%s=unavailable", label) - } -} - func formatSimpleDetails(label, value, emptyValue string) string { label = strings.TrimSpace(label) if label == "" { diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 755dc24b..22a9d546 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -474,7 +474,7 @@ func buildSystemData(macs []string, logger *logging.Logger) string { } if builder.Len() == 0 { - builder.WriteString(fmt.Sprintf("fallback-%d-%d", time.Now().Unix(), os.Getpid())) + fmt.Fprintf(&builder, "fallback-%d-%d", time.Now().Unix(), os.Getpid()) logDebug(logger, "Identity: buildSystemData: WARNING: used fallback seed (unexpected)") } @@ -494,10 +494,10 @@ func encodeProtectedServerIDWithMACs(serverID string, macs []string, primaryMAC var builder strings.Builder builder.WriteString("# ProxSave Backup System Configuration\n") - builder.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().Format(time.RFC3339))) + fmt.Fprintf(&builder, "# Generated: %s\n", time.Now().Format(time.RFC3339)) builder.WriteString("# DO NOT MODIFY THIS FILE MANUALLY\n") builder.WriteString("# Format: proxsave-identity-v2\n") - builder.WriteString(fmt.Sprintf("SYSTEM_CONFIG_DATA=\"%s\"\n", encoded)) + fmt.Fprintf(&builder, "SYSTEM_CONFIG_DATA=\"%s\"\n", encoded) builder.WriteString("# End of configuration\n") content := builder.String() diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go index e0639abd..25bad342 100644 --- a/internal/identity/identity_test.go +++ b/internal/identity/identity_test.go @@ -1138,7 +1138,8 @@ func TestReadAddrAssignType(t *testing.T) { } func TestIsBridgeInterfaceByName(t *testing.T) { - // On non-Linux or without sysfs, falls back to name-based detection + // On non-Linux, detection falls back to interface names. On Linux, + // sysfs decides the result and these synthetic names may not exist. tests := []struct { name string want bool @@ -1155,16 +1156,21 @@ func TestIsBridgeInterfaceByName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // This will use name-based fallback if sysfs not available got := isBridgeInterface(tt.name) - // On Linux with sysfs, result may differ, so we just check it doesn't panic - _ = got + if runtime.GOOS == "linux" { + t.Logf("sysfs bridge detection for %q returned %v", tt.name, got) + return + } + if got != tt.want { + t.Fatalf("isBridgeInterface(%q)=%v; want %v", tt.name, got, tt.want) + } }) } } func TestIsWirelessInterfaceByName(t *testing.T) { - // On non-Linux or without sysfs, falls back to name-based detection + // On non-Linux, detection falls back to interface names. On Linux, + // sysfs decides the result and these synthetic names may not exist. tests := []struct { name string want bool @@ -1179,9 +1185,12 @@ func TestIsWirelessInterfaceByName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := isWirelessInterface(tt.name) - // Check name-based fallback behavior - if strings.HasPrefix(strings.ToLower(tt.name), "wl") && !got { - // May or may not work depending on sysfs + if runtime.GOOS == "linux" { + t.Logf("sysfs wireless detection for %q returned %v", tt.name, got) + return + } + if got != tt.want { + t.Fatalf("isWirelessInterface(%q)=%v; want %v", tt.name, got, tt.want) } }) } diff --git a/internal/input/input_test.go b/internal/input/input_test.go index 535f0482..f82af39d 100644 --- a/internal/input/input_test.go +++ b/internal/input/input_test.go @@ -218,7 +218,7 @@ func TestReadLineWithContext_ReturnsLine(t *testing.T) { func TestReadLineWithContext_NilContextWorks(t *testing.T) { reader := bufio.NewReader(strings.NewReader("hello\n")) - got, err := ReadLineWithContext(nil, reader) + got, err := ReadLineWithContext(nil, reader) //nolint:staticcheck // Verifies the documented nil context fallback. if err != nil { t.Fatalf("ReadLineWithContext error: %v", err) } @@ -229,8 +229,8 @@ func TestReadLineWithContext_NilContextWorks(t *testing.T) { func TestReadLineWithContext_CancelledReturnsAborted(t *testing.T) { pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() + defer func() { _ = pr.Close() }() + defer func() { _ = pw.Close() }() reader := bufio.NewReader(pr) ctx, cancel := context.WithCancel(context.Background()) @@ -258,8 +258,8 @@ func TestReadLineWithContext_CancelledReturnsAborted(t *testing.T) { func TestReadLineWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) { pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() + defer func() { _ = pr.Close() }() + defer func() { _ = pw.Close() }() reader := bufio.NewReader(pr) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) @@ -364,7 +364,7 @@ func TestReadPasswordWithContext_NilContextWorks(t *testing.T) { readPassword := func(fd int) ([]byte, error) { return []byte("secret"), nil } - got, err := ReadPasswordWithContext(nil, readPassword, 0) + got, err := ReadPasswordWithContext(nil, readPassword, 0) //nolint:staticcheck // Verifies the documented nil context fallback. if err != nil { t.Fatalf("ReadPasswordWithContext error: %v", err) } diff --git a/internal/logging/logger.go b/internal/logging/logger.go index ebb80512..2e80b719 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -72,7 +72,10 @@ func (l *Logger) OpenLogFile(logPath string) error { defer l.mu.Unlock() // If a log file is already open, close it first. if l.logFile != nil { - l.logFile.Close() + if err := l.logFile.Close(); err != nil { + return fmt.Errorf("failed to close existing log file: %w", err) + } + l.logFile = nil } // Create the log file (O_CREATE|O_WRONLY|O_APPEND). @@ -197,11 +200,11 @@ func (l *Logger) logWithLabel(level types.LogLevel, label string, colorOverride } // Write to stdout with colors. - fmt.Fprint(l.output, outputStdout) + _, _ = fmt.Fprint(l.output, outputStdout) // If a log file is open, write there too (without colors). if l.logFile != nil { - fmt.Fprint(l.logFile, outputFile) + _, _ = fmt.Fprint(l.logFile, outputFile) } } @@ -333,7 +336,7 @@ func (l *Logger) AppendRaw(message string) { types.LogLevelInfo.String(), message, ) - fmt.Fprint(l.logFile, output) + _, _ = fmt.Fprint(l.logFile, output) } // Package-level default logger diff --git a/internal/logging/session_test.go b/internal/logging/session_test.go index 233c9264..03fa6c7a 100644 --- a/internal/logging/session_test.go +++ b/internal/logging/session_test.go @@ -42,7 +42,7 @@ func TestDetectHostname(t *testing.T) { for _, r := range host { isLower := r >= 'a' && r <= 'z' isDigit := r >= '0' && r <= '9' - if !(isLower || isDigit || r == '-') { + if !isLower && !isDigit && r != '-' { t.Fatalf("unexpected rune %q in hostname %q", r, host) } } diff --git a/internal/metrics/prometheus.go b/internal/metrics/prometheus.go index d06a8eb2..c1e55240 100644 --- a/internal/metrics/prometheus.go +++ b/internal/metrics/prometheus.go @@ -48,7 +48,7 @@ func NewPrometheusExporter(textfileDir string, logger *logging.Logger) *Promethe } // Export writes the given metrics snapshot to proxmox_backup.prom in textfileDir. -func (pe *PrometheusExporter) Export(m *BackupMetrics) error { +func (pe *PrometheusExporter) Export(m *BackupMetrics) (err error) { if pe == nil || m == nil { return nil } @@ -68,13 +68,48 @@ func (pe *PrometheusExporter) Export(m *BackupMetrics) error { if err != nil { return fmt.Errorf("create metrics file %s: %w", tmpPath, err) } - defer f.Close() + defer func() { + if f == nil { + return + } + if closeErr := f.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("close metrics file %s: %w", tmpPath, closeErr) + } + }() + + var writeErr error + wrap := func(err error) error { + if err == nil { + return nil + } + if writeErr == nil { + writeErr = fmt.Errorf("write metrics file %s: %w", tmpPath, err) + } + return writeErr + } + writef := func(format string, a ...any) error { + if writeErr != nil { + return writeErr + } + _, err := fmt.Fprintf(f, format, a...) + return wrap(err) + } // Helper to write a single metric with HELP/TYPE - writeMetric := func(name, mtype, help, value string) { - fmt.Fprintf(f, "# HELP %s %s\n", name, help) - fmt.Fprintf(f, "# TYPE %s %s\n", name, mtype) - fmt.Fprintln(f, value) + writeMetric := func(name, mtype, help, value string) error { + if writeErr != nil { + return writeErr + } + if err := writef("# HELP %s %s\n", name, help); err != nil { + return err + } + if err := writef("# TYPE %s %s\n", name, mtype); err != nil { + return err + } + if err := writef("%s\n", value); err != nil { + return err + } + return nil } // Timestamps @@ -93,105 +128,146 @@ func (pe *PrometheusExporter) Export(m *BackupMetrics) error { } // Core metrics - writeMetric( + if err := writeMetric( "proxmox_backup_start_time_seconds", "gauge", "Unix timestamp of backup start", fmt.Sprintf("proxmox_backup_start_time_seconds %.0f", startTs), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_end_time_seconds", "gauge", "Unix timestamp of backup end", fmt.Sprintf("proxmox_backup_end_time_seconds %.0f", endTs), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_duration_seconds", "gauge", "Duration of last backup in seconds", fmt.Sprintf("proxmox_backup_duration_seconds %.2f", m.Duration.Seconds()), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_exit_code", "gauge", "Exit code of last backup", fmt.Sprintf("proxmox_backup_exit_code %d", m.ExitCode), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_status", "gauge", "Status of last backup (0=success,1=warning,2=error)", fmt.Sprintf("proxmox_backup_status %d", status), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_errors_total", "gauge", "Total number of errors in last backup", fmt.Sprintf("proxmox_backup_errors_total %d", m.ErrorCount), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_warnings_total", "gauge", "Total number of warnings in last backup", fmt.Sprintf("proxmox_backup_warnings_total %d", m.WarningCount), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_bytes_collected", "gauge", "Total number of bytes collected during last backup", fmt.Sprintf("proxmox_backup_bytes_collected %d", m.BytesCollected), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_archive_size_bytes", "gauge", "Size of last backup archive in bytes", fmt.Sprintf("proxmox_backup_archive_size_bytes %d", m.ArchiveSize), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_files_collected_total", "gauge", "Total files successfully collected during last backup", fmt.Sprintf("proxmox_backup_files_collected_total %d", m.FilesCollected), - ) + ); err != nil { + return err + } - writeMetric( + if err := writeMetric( "proxmox_backup_files_failed_total", "gauge", "Total files that failed to collect during last backup", fmt.Sprintf("proxmox_backup_files_failed_total %d", m.FilesFailed), - ) + ); err != nil { + return err + } // Per-location backup counts - fmt.Fprintf(f, "# HELP proxmox_backup_backups_total Number of backups per location\n") - fmt.Fprintf(f, "# TYPE proxmox_backup_backups_total gauge\n") - fmt.Fprintf(f, "proxmox_backup_backups_total{location=\"local\"} %d\n", m.LocalBackups) - fmt.Fprintf(f, "proxmox_backup_backups_total{location=\"secondary\"} %d\n", m.SecBackups) - fmt.Fprintf(f, "proxmox_backup_backups_total{location=\"cloud\"} %d\n", m.CloudBackups) + if err := writef("# HELP proxmox_backup_backups_total Number of backups per location\n"); err != nil { + return err + } + if err := writef("# TYPE proxmox_backup_backups_total gauge\n"); err != nil { + return err + } + if err := writef("proxmox_backup_backups_total{location=\"local\"} %d\n", m.LocalBackups); err != nil { + return err + } + if err := writef("proxmox_backup_backups_total{location=\"secondary\"} %d\n", m.SecBackups); err != nil { + return err + } + if err := writef("proxmox_backup_backups_total{location=\"cloud\"} %d\n", m.CloudBackups); err != nil { + return err + } // Static info metric with labels - fmt.Fprintf(f, "# HELP proxmox_backup_info Static information about this backup instance\n") - fmt.Fprintf(f, "# TYPE proxmox_backup_info gauge\n") - fmt.Fprintf( - f, + if err := writef("# HELP proxmox_backup_info Static information about this backup instance\n"); err != nil { + return err + } + if err := writef("# TYPE proxmox_backup_info gauge\n"); err != nil { + return err + } + if err := writef( "proxmox_backup_info{hostname=%q,proxmox_type=%q,proxmox_version=%q,script_version=%q} 1\n", m.Hostname, m.ProxmoxType, m.ProxmoxVersion, m.ScriptVersion, - ) + ); err != nil { + return err + } if err := f.Sync(); err != nil { return fmt.Errorf("sync metrics file %s: %w", tmpPath, err) } + if err := f.Close(); err != nil { + return fmt.Errorf("close metrics file %s: %w", tmpPath, err) + } + f = nil if err := os.Rename(tmpPath, finalPath); err != nil { return fmt.Errorf("rename metrics file to %s: %w", finalPath, err) diff --git a/internal/notify/email.go b/internal/notify/email.go index da44380f..e04644da 100644 --- a/internal/notify/email.go +++ b/internal/notify/email.go @@ -1151,9 +1151,9 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody if toHeader == "" { toHeader = "root" } - email.WriteString(fmt.Sprintf("To: %s\n", toHeader)) - email.WriteString(fmt.Sprintf("From: %s\n", e.config.From)) - email.WriteString(fmt.Sprintf("Subject: =?UTF-8?B?%s?=\n", encodedSubject)) + fmt.Fprintf(&email, "To: %s\n", toHeader) + fmt.Fprintf(&email, "From: %s\n", e.config.From) + fmt.Fprintf(&email, "Subject: =?UTF-8?B?%s?=\n", encodedSubject) email.WriteString("MIME-Version: 1.0\n") // Decide whether to attach log file @@ -1170,16 +1170,16 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody mixedBoundary := "mixed_boundary_42" altBoundary := "alt_boundary_42" - email.WriteString(fmt.Sprintf("Content-Type: multipart/mixed; boundary=\"%s\"\n", mixedBoundary)) + fmt.Fprintf(&email, "Content-Type: multipart/mixed; boundary=\"%s\"\n", mixedBoundary) email.WriteString("\n") // First part: multipart/alternative with text and HTML bodies - email.WriteString(fmt.Sprintf("--%s\n", mixedBoundary)) - email.WriteString(fmt.Sprintf("Content-Type: multipart/alternative; boundary=\"%s\"\n", altBoundary)) + fmt.Fprintf(&email, "--%s\n", mixedBoundary) + fmt.Fprintf(&email, "Content-Type: multipart/alternative; boundary=\"%s\"\n", altBoundary) email.WriteString("\n") // Plain text part - email.WriteString(fmt.Sprintf("--%s\n", altBoundary)) + fmt.Fprintf(&email, "--%s\n", altBoundary) email.WriteString("Content-Type: text/plain; charset=UTF-8\n") email.WriteString("Content-Transfer-Encoding: quoted-printable\n") email.WriteString("\n") @@ -1187,14 +1187,14 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody email.WriteString("\n\n") // HTML part - email.WriteString(fmt.Sprintf("--%s\n", altBoundary)) + fmt.Fprintf(&email, "--%s\n", altBoundary) email.WriteString("Content-Type: text/html; charset=UTF-8\n") email.WriteString("Content-Transfer-Encoding: quoted-printable\n") email.WriteString("\n") email.WriteString(encodeQuotedPrintableBody(htmlBody)) email.WriteString("\n\n") - email.WriteString(fmt.Sprintf("--%s--\n", altBoundary)) + fmt.Fprintf(&email, "--%s--\n", altBoundary) email.WriteString("\n") // Second part: log file attachment (Base64 encoded) @@ -1203,9 +1203,9 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody filename = "backup.log" } - email.WriteString(fmt.Sprintf("--%s\n", mixedBoundary)) - email.WriteString(fmt.Sprintf("Content-Type: text/plain; charset=UTF-8; name=\"%s\"\n", filename)) - email.WriteString(fmt.Sprintf("Content-Disposition: attachment; filename=\"%s\"\n", filename)) + fmt.Fprintf(&email, "--%s\n", mixedBoundary) + fmt.Fprintf(&email, "Content-Type: text/plain; charset=UTF-8; name=\"%s\"\n", filename) + fmt.Fprintf(&email, "Content-Disposition: attachment; filename=\"%s\"\n", filename) email.WriteString("Content-Transfer-Encoding: base64\n") email.WriteString("\n") @@ -1220,18 +1220,18 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody email.WriteString("\n") } email.WriteString("\n") - email.WriteString(fmt.Sprintf("--%s--\n", mixedBoundary)) + fmt.Fprintf(&email, "--%s--\n", mixedBoundary) } } if !attachLog { // Fallback / default: simple multipart/alternative (no attachment) altBoundary := "boundary42" - email.WriteString(fmt.Sprintf("Content-Type: multipart/alternative; boundary=\"%s\"\n", altBoundary)) + fmt.Fprintf(&email, "Content-Type: multipart/alternative; boundary=\"%s\"\n", altBoundary) email.WriteString("\n") // Plain text part - email.WriteString(fmt.Sprintf("--%s\n", altBoundary)) + fmt.Fprintf(&email, "--%s\n", altBoundary) email.WriteString("Content-Type: text/plain; charset=UTF-8\n") email.WriteString("Content-Transfer-Encoding: quoted-printable\n") email.WriteString("\n") @@ -1239,14 +1239,14 @@ func (e *EmailNotifier) buildEmailMessage(recipient, subject, htmlBody, textBody email.WriteString("\n\n") // HTML part - email.WriteString(fmt.Sprintf("--%s\n", altBoundary)) + fmt.Fprintf(&email, "--%s\n", altBoundary) email.WriteString("Content-Type: text/html; charset=UTF-8\n") email.WriteString("Content-Transfer-Encoding: quoted-printable\n") email.WriteString("\n") email.WriteString(encodeQuotedPrintableBody(htmlBody)) email.WriteString("\n\n") - email.WriteString(fmt.Sprintf("--%s--\n", altBoundary)) + fmt.Fprintf(&email, "--%s--\n", altBoundary) } e.logger.Debug("Email message built (%d bytes)", email.Len()) diff --git a/internal/notify/email_delivery_methods_test.go b/internal/notify/email_delivery_methods_test.go index 10a73fea..d9807683 100644 --- a/internal/notify/email_delivery_methods_test.go +++ b/internal/notify/email_delivery_methods_test.go @@ -42,7 +42,7 @@ func TestEmailNotifier_RelayNoFallback_ReturnsError(t *testing.T) { // Force relay failure. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary"}`)) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) })) defer server.Close() @@ -136,7 +136,7 @@ func TestEmailNotifier_RelayFallback_UsesPMFOnly(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary"}`)) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) })) defer server.Close() diff --git a/internal/notify/email_relay.go b/internal/notify/email_relay.go index dd17a19d..26cd84c3 100644 --- a/internal/notify/email_relay.go +++ b/internal/notify/email_relay.go @@ -131,7 +131,7 @@ func sendViaCloudRelay( // Read response body body, err := io.ReadAll(resp.Body) - resp.Body.Close() + closeErr := resp.Body.Close() if err != nil { if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr @@ -139,6 +139,10 @@ func sendViaCloudRelay( lastErr = fmt.Errorf("failed to read response: %w", err) continue } + if closeErr != nil { + lastErr = fmt.Errorf("failed to close response body: %w", closeErr) + continue + } // Log raw response body for all status codes (aids future diagnosis) logger.Debug("Cloud relay: HTTP %d response (%d bytes): %s", resp.StatusCode, len(body), string(body)) diff --git a/internal/notify/email_relay_test.go b/internal/notify/email_relay_test.go index 94a9b148..99467e4f 100644 --- a/internal/notify/email_relay_test.go +++ b/internal/notify/email_relay_test.go @@ -154,7 +154,7 @@ func TestSendViaCloudRelay_StatusHandling(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ w.WriteHeader(tt.statusCode) - w.Write([]byte(tt.body)) + _, _ = w.Write([]byte(tt.body)) })) defer server.Close() @@ -195,7 +195,7 @@ func TestSendViaCloudRelay_RetryOnServerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if atomic.AddInt32(&attempts, 1) < 3 { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary"}`)) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) return } w.WriteHeader(http.StatusOK) @@ -238,7 +238,7 @@ func TestSendViaCloudRelay_StopsRetryingWhenContextCanceled(t *testing.T) { cancel() } w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary"}`)) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) })) defer server.Close() diff --git a/internal/notify/gotify.go b/internal/notify/gotify.go index be2aa370..44acc3b6 100644 --- a/internal/notify/gotify.go +++ b/internal/notify/gotify.go @@ -143,7 +143,7 @@ func (g *GotifyNotifier) Send(ctx context.Context, data *NotificationData) (*Not result.Duration = time.Since(start) return result, nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() result.Metadata["status_code"] = resp.StatusCode respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) diff --git a/internal/notify/gotify_test.go b/internal/notify/gotify_test.go index 0f1f9c06..a019f90b 100644 --- a/internal/notify/gotify_test.go +++ b/internal/notify/gotify_test.go @@ -128,7 +128,7 @@ func TestGotifySendSuccessAndFailure(t *testing.T) { // Now force server to return 500 to trigger failure path. serverFail := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("fail")) + _, _ = w.Write([]byte("fail")) })) defer serverFail.Close() diff --git a/internal/notify/telegram.go b/internal/notify/telegram.go index 7b45f38e..f1cb14e3 100644 --- a/internal/notify/telegram.go +++ b/internal/notify/telegram.go @@ -199,7 +199,7 @@ func (t *TelegramNotifier) fetchCentralizedCredentials(ctx context.Context) (str if err != nil { return "", "", fmt.Errorf("API request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() // Read response body, err := io.ReadAll(resp.Body) @@ -245,61 +245,61 @@ func (t *TelegramNotifier) buildMessage(data *NotificationData) string { // Tool name and version header version := strings.TrimSpace(data.ScriptVersion) if version != "" { - msg.WriteString(fmt.Sprintf("ProxSave - v%s\n\n", version)) + fmt.Fprintf(&msg, "ProxSave - v%s\n\n", version) } else { msg.WriteString("ProxSave\n\n") } // Header with status and hostname statusEmoji := GetStatusEmoji(data.Status) - msg.WriteString(fmt.Sprintf("%s Backup %s - %s\n\n", + fmt.Fprintf(&msg, "%s Backup %s - %s\n\n", statusEmoji, data.ProxmoxType.String(), - data.Hostname)) + data.Hostname) // Storage status localEmoji := GetStorageEmoji(data.LocalStatus) - msg.WriteString(fmt.Sprintf("%s Local (%s backups)\n", localEmoji, data.LocalStatusSummary)) + fmt.Fprintf(&msg, "%s Local (%s backups)\n", localEmoji, data.LocalStatusSummary) if data.SecondaryEnabled { secondaryEmoji := GetStorageEmoji(data.SecondaryStatus) - msg.WriteString(fmt.Sprintf("%s Secondary (%s backups)\n", secondaryEmoji, data.SecondaryStatusSummary)) + fmt.Fprintf(&msg, "%s Secondary (%s backups)\n", secondaryEmoji, data.SecondaryStatusSummary) } else { msg.WriteString("➖ Secondary (disabled)\n") } if data.CloudEnabled { cloudEmoji := GetStorageEmoji(data.CloudStatus) - msg.WriteString(fmt.Sprintf("%s Cloud (%s backups)\n", cloudEmoji, data.CloudStatusSummary)) + fmt.Fprintf(&msg, "%s Cloud (%s backups)\n", cloudEmoji, data.CloudStatusSummary) } else { msg.WriteString("➖ Cloud (disabled)\n") } // Email status emailEmoji := GetStorageEmoji(data.EmailStatus) - msg.WriteString(fmt.Sprintf("%s Email\n\n", emailEmoji)) + fmt.Fprintf(&msg, "%s Email\n\n", emailEmoji) // File counts - msg.WriteString(fmt.Sprintf("📁 Included files: %d\n", data.FilesIncluded)) + fmt.Fprintf(&msg, "📁 Included files: %d\n", data.FilesIncluded) if data.FilesMissing > 0 { - msg.WriteString(fmt.Sprintf("⚠️ Missing files: %d\n", data.FilesMissing)) + fmt.Fprintf(&msg, "⚠️ Missing files: %d\n", data.FilesMissing) } msg.WriteString("\n") // Disk space msg.WriteString("💾 Available space:\n") - msg.WriteString(fmt.Sprintf("🔹 Local: %s\n", data.LocalFree)) + fmt.Fprintf(&msg, "🔹 Local: %s\n", data.LocalFree) if data.SecondaryEnabled && data.SecondaryFree != "" { - msg.WriteString(fmt.Sprintf("🔹 Secondary: %s\n", data.SecondaryFree)) + fmt.Fprintf(&msg, "🔹 Secondary: %s\n", data.SecondaryFree) } msg.WriteString("\n") // Backup metadata - msg.WriteString(fmt.Sprintf("📅 Backup date: %s\n", data.BackupDate.Format("2006-01-02 15:04"))) - msg.WriteString(fmt.Sprintf("⏱️ Duration: %s\n\n", FormatDuration(data.BackupDuration))) + fmt.Fprintf(&msg, "📅 Backup date: %s\n", data.BackupDate.Format("2006-01-02 15:04")) + fmt.Fprintf(&msg, "⏱️ Duration: %s\n\n", FormatDuration(data.BackupDuration)) // Exit code - msg.WriteString(fmt.Sprintf("🔢 Exit code: %d", data.ExitCode)) + fmt.Fprintf(&msg, "🔢 Exit code: %d", data.ExitCode) // Optional version update information if data.NewVersionAvailable && strings.TrimSpace(data.LatestVersion) != "" { @@ -307,9 +307,9 @@ func (t *TelegramNotifier) buildMessage(data *NotificationData) string { current := strings.TrimSpace(data.CurrentVersion) if current != "" { - msg.WriteString(fmt.Sprintf("New version: %s (current: %s)\n", data.LatestVersion, current)) + fmt.Fprintf(&msg, "New version: %s (current: %s)\n", data.LatestVersion, current) } else { - msg.WriteString(fmt.Sprintf("New version: %s\n", data.LatestVersion)) + fmt.Fprintf(&msg, "New version: %s\n", data.LatestVersion) } msg.WriteString("Run 'proxsave --upgrade'\n") } @@ -340,7 +340,7 @@ func (t *TelegramNotifier) sendToTelegram(ctx context.Context, botToken, chatID, if err != nil { return fmt.Errorf("api request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() // Check response if resp.StatusCode != 200 { diff --git a/internal/notify/telegram_registration.go b/internal/notify/telegram_registration.go index 69a4fd39..4730094b 100644 --- a/internal/notify/telegram_registration.go +++ b/internal/notify/telegram_registration.go @@ -70,7 +70,7 @@ func CheckTelegramRegistration(ctx context.Context, serverAPIHost, serverID stri logTelegramRegistrationDebug(logger, "Telegram registration: request failed: %v", err) return TelegramRegistrationStatus{Message: "Connection failed", Error: err} } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) logTelegramRegistrationDebug(logger, "Telegram registration: response status=%d bodyBytes=%d bodyPreview=%q", resp.StatusCode, len(body), truncateTelegramRegistrationBody(body, 200)) diff --git a/internal/notify/telegram_registration_test.go b/internal/notify/telegram_registration_test.go index bf38776c..9e1dfa77 100644 --- a/internal/notify/telegram_registration_test.go +++ b/internal/notify/telegram_registration_test.go @@ -39,7 +39,7 @@ func TestCheckTelegramRegistrationResponses(t *testing.T) { t.Run(tt.name, func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(tt.statusCode) - w.Write([]byte(tt.name)) + _, _ = w.Write([]byte(tt.name)) })) defer server.Close() diff --git a/internal/notify/templates.go b/internal/notify/templates.go index d8ed667e..0bdbfb14 100644 --- a/internal/notify/templates.go +++ b/internal/notify/templates.go @@ -20,53 +20,53 @@ func BuildEmailPlainText(data *NotificationData) string { var body strings.Builder statusEmoji := GetStatusEmoji(data.Status) - body.WriteString(fmt.Sprintf("%s %s BACKUP REPORT - %s\n", - statusEmoji, strings.ToUpper(data.ProxmoxType.String()), strings.ToUpper(data.Status.String()))) - body.WriteString(fmt.Sprintf("Hostname: %s\n", data.Hostname)) - body.WriteString(fmt.Sprintf("Date: %s\n\n", data.BackupDate.Format("2006-01-02 15:04:05"))) + fmt.Fprintf(&body, "%s %s BACKUP REPORT - %s\n", + statusEmoji, strings.ToUpper(data.ProxmoxType.String()), strings.ToUpper(data.Status.String())) + fmt.Fprintf(&body, "Hostname: %s\n", data.Hostname) + fmt.Fprintf(&body, "Date: %s\n\n", data.BackupDate.Format("2006-01-02 15:04:05")) body.WriteString("BACKUP STATUS:\n") - body.WriteString(fmt.Sprintf(" Local: %s backups (%s free)\n", data.LocalStatusSummary, data.LocalFree)) + fmt.Fprintf(&body, " Local: %s backups (%s free)\n", data.LocalStatusSummary, data.LocalFree) if data.SecondaryEnabled { - body.WriteString(fmt.Sprintf(" Secondary: %s backups (%s free)\n", data.SecondaryStatusSummary, data.SecondaryFree)) + fmt.Fprintf(&body, " Secondary: %s backups (%s free)\n", data.SecondaryStatusSummary, data.SecondaryFree) } if data.CloudEnabled { - body.WriteString(fmt.Sprintf(" Cloud: %s backups\n", data.CloudStatusSummary)) + fmt.Fprintf(&body, " Cloud: %s backups\n", data.CloudStatusSummary) } body.WriteString("\n") body.WriteString("BACKUP DETAILS:\n") - body.WriteString(fmt.Sprintf(" Backup File: %s\n", data.BackupFile)) - body.WriteString(fmt.Sprintf(" Size: %s\n", data.BackupSizeHR)) - body.WriteString(fmt.Sprintf(" Included Files: %d\n", data.FilesIncluded)) - body.WriteString(fmt.Sprintf(" Missing Files: %d\n", data.FilesMissing)) - body.WriteString(fmt.Sprintf(" Duration: %s\n", FormatDuration(data.BackupDuration))) - body.WriteString(fmt.Sprintf(" Compression: %s (level %d, ratio %.2f%%)\n", - data.CompressionType, data.CompressionLevel, data.CompressionRatio)) + fmt.Fprintf(&body, " Backup File: %s\n", data.BackupFile) + fmt.Fprintf(&body, " Size: %s\n", data.BackupSizeHR) + fmt.Fprintf(&body, " Included Files: %d\n", data.FilesIncluded) + fmt.Fprintf(&body, " Missing Files: %d\n", data.FilesMissing) + fmt.Fprintf(&body, " Duration: %s\n", FormatDuration(data.BackupDuration)) + fmt.Fprintf(&body, " Compression: %s (level %d, ratio %.2f%%)\n", + data.CompressionType, data.CompressionLevel, data.CompressionRatio) body.WriteString("\n") body.WriteString("ISSUES:\n") - body.WriteString(fmt.Sprintf(" Errors: %d\n", data.ErrorCount)) - body.WriteString(fmt.Sprintf(" Warnings: %d\n", data.WarningCount)) - body.WriteString(fmt.Sprintf(" Total Issues: %d\n", data.TotalIssues)) + fmt.Fprintf(&body, " Errors: %d\n", data.ErrorCount) + fmt.Fprintf(&body, " Warnings: %d\n", data.WarningCount) + fmt.Fprintf(&body, " Total Issues: %d\n", data.TotalIssues) if data.LogFilePath != "" { - body.WriteString(fmt.Sprintf(" Log: %s\n", data.LogFilePath)) + fmt.Fprintf(&body, " Log: %s\n", data.LogFilePath) } body.WriteString("\n") if len(data.LogCategories) > 0 { body.WriteString("ISSUE DETAILS:\n") for _, cat := range data.LogCategories { - body.WriteString(fmt.Sprintf(" - [%s] %s (count: %d)\n", cat.Type, cat.Label, cat.Count)) + fmt.Fprintf(&body, " - [%s] %s (count: %d)\n", cat.Type, cat.Label, cat.Count) if cat.Example != "" { - body.WriteString(fmt.Sprintf(" Example: %s\n", cat.Example)) + fmt.Fprintf(&body, " Example: %s\n", cat.Example) } } body.WriteString("\n") } - body.WriteString(fmt.Sprintf("Exit Code: %d\n", data.ExitCode)) - body.WriteString(fmt.Sprintf("Script Version: %s\n", data.ScriptVersion)) + fmt.Fprintf(&body, "Exit Code: %d\n", data.ExitCode) + fmt.Fprintf(&body, "Script Version: %s\n", data.ScriptVersion) return body.String() } @@ -101,7 +101,7 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("\n") html.WriteString("\n\n") html.WriteString(" \n") - html.WriteString(fmt.Sprintf(" %s Backup Report\n", proxmoxType)) + fmt.Fprintf(&html, " %s Backup Report\n", proxmoxType) html.WriteString(" \n") @@ -111,22 +111,22 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("
\n") // Header - html.WriteString(fmt.Sprintf("
\n", statusColor)) - html.WriteString(fmt.Sprintf("

%s Backup Report - %s

\n", proxmoxType, statusText)) - html.WriteString(fmt.Sprintf("

%s - %s

\n", data.Hostname, data.BackupDate.Format("2006-01-02 15:04:05"))) + fmt.Fprintf(&html, "
\n", statusColor) + fmt.Fprintf(&html, "

%s Backup Report - %s

\n", proxmoxType, statusText) + fmt.Fprintf(&html, "

%s - %s

\n", data.Hostname, data.BackupDate.Format("2006-01-02 15:04:05")) html.WriteString("
\n") // Content html.WriteString("
\n") // Backup Status Section - html.WriteString(fmt.Sprintf("
\n", backupPathsColor)) + fmt.Fprintf(&html, "
\n", backupPathsColor) // Local Storage html.WriteString("
\n") html.WriteString("

Local Storage

\n") html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s %s backups\n", GetStorageEmoji(data.LocalStatus), data.LocalStatusSummary)) + fmt.Fprintf(&html, " %s %s backups\n", GetStorageEmoji(data.LocalStatus), data.LocalStatusSummary) html.WriteString("
\n") if data.LocalFree != "" && data.LocalFree != "N/A" { barColor := "normal" @@ -136,11 +136,11 @@ func BuildEmailHTML(data *NotificationData) string { barColor = "warning" } html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s\n", data.LocalUsed)) + fmt.Fprintf(&html, " %s\n", data.LocalUsed) html.WriteString("
\n") - html.WriteString(fmt.Sprintf("
\n", barColor, data.LocalUsagePercent)) + fmt.Fprintf(&html, "
\n", barColor, data.LocalUsagePercent) html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s free (%s used)\n", data.LocalFree, data.LocalPercent)) + fmt.Fprintf(&html, " %s free (%s used)\n", data.LocalFree, data.LocalPercent) html.WriteString("
\n") } html.WriteString("
\n") @@ -150,7 +150,7 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("
\n") html.WriteString("

Secondary Storage

\n") html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s %s backups\n", GetStorageEmoji(data.SecondaryStatus), data.SecondaryStatusSummary)) + fmt.Fprintf(&html, " %s %s backups\n", GetStorageEmoji(data.SecondaryStatus), data.SecondaryStatusSummary) html.WriteString("
\n") if data.SecondaryEnabled && data.SecondaryFree != "" && data.SecondaryFree != "N/A" { barColor := "normal" @@ -160,11 +160,11 @@ func BuildEmailHTML(data *NotificationData) string { barColor = "warning" } html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s\n", data.SecondaryUsed)) + fmt.Fprintf(&html, " %s\n", data.SecondaryUsed) html.WriteString("
\n") - html.WriteString(fmt.Sprintf("
\n", barColor, data.SecondaryUsagePercent)) + fmt.Fprintf(&html, "
\n", barColor, data.SecondaryUsagePercent) html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s free (%s used)\n", data.SecondaryFree, data.SecondaryPercent)) + fmt.Fprintf(&html, " %s free (%s used)\n", data.SecondaryFree, data.SecondaryPercent) html.WriteString("
\n") } html.WriteString("
\n") @@ -174,7 +174,7 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("
\n") html.WriteString("

Cloud Storage

\n") html.WriteString("
\n") - html.WriteString(fmt.Sprintf(" %s %s backups\n", GetStorageEmoji(data.CloudStatus), data.CloudStatusSummary)) + fmt.Fprintf(&html, " %s %s backups\n", GetStorageEmoji(data.CloudStatus), data.CloudStatusSummary) html.WriteString("
\n") html.WriteString("
\n") @@ -210,10 +210,10 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString(" \n") html.WriteString("
\n") html.WriteString("

Error and Warning Summary

\n") - html.WriteString(fmt.Sprintf("
\n", errorSummaryColor)) - html.WriteString(fmt.Sprintf("

Total Issues: %d

\n", data.TotalIssues)) - html.WriteString(fmt.Sprintf("

Errors: %d

\n", data.ErrorCount)) - html.WriteString(fmt.Sprintf("

Warnings: %d

\n", data.WarningCount)) + fmt.Fprintf(&html, "
\n", errorSummaryColor) + fmt.Fprintf(&html, "

Total Issues: %d

\n", data.TotalIssues) + fmt.Fprintf(&html, "

Errors: %d

\n", data.ErrorCount) + fmt.Fprintf(&html, "

Warnings: %d

\n", data.WarningCount) html.WriteString("
\n") if len(data.LogCategories) > 0 { @@ -225,9 +225,9 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString(" \n") for _, cat := range data.LogCategories { html.WriteString(" \n") - html.WriteString(fmt.Sprintf(" %s\n", escapeHTML(cat.Label))) - html.WriteString(fmt.Sprintf(" %s\n", escapeHTML(cat.Type))) - html.WriteString(fmt.Sprintf(" %d\n", cat.Count)) + fmt.Fprintf(&html, " %s\n", escapeHTML(cat.Label)) + fmt.Fprintf(&html, " %s\n", escapeHTML(cat.Type)) + fmt.Fprintf(&html, " %d\n", cat.Count) html.WriteString(" \n") } html.WriteString(" \n") @@ -235,7 +235,7 @@ func BuildEmailHTML(data *NotificationData) string { // Show log file path after the table if data.LogFilePath != "" { - html.WriteString(fmt.Sprintf("

Full log available at: %s

\n", escapeHTML(data.LogFilePath))) + fmt.Fprintf(&html, "

Full log available at: %s

\n", escapeHTML(data.LogFilePath)) } html.WriteString("
\n") @@ -246,10 +246,10 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("

System Recommendations

\n") html.WriteString("
\n") if data.LocalUsagePercent > 85 { - html.WriteString(fmt.Sprintf("

⚠️ Local storage is %.1f%% full. Consider cleaning old backups or expanding storage capacity.

\n", data.LocalUsagePercent)) + fmt.Fprintf(&html, "

⚠️ Local storage is %.1f%% full. Consider cleaning old backups or expanding storage capacity.

\n", data.LocalUsagePercent) } if data.SecondaryEnabled && data.SecondaryUsagePercent > 85 { - html.WriteString(fmt.Sprintf("

⚠️ Secondary storage is %.1f%% full. Consider cleaning old backups or expanding storage capacity.

\n", data.SecondaryUsagePercent)) + fmt.Fprintf(&html, "

⚠️ Secondary storage is %.1f%% full. Consider cleaning old backups or expanding storage capacity.

\n", data.SecondaryUsagePercent) } html.WriteString("
\n") html.WriteString("
\n") @@ -259,7 +259,7 @@ func BuildEmailHTML(data *NotificationData) string { html.WriteString("
\n") html.WriteString("
\n") html.WriteString("

This is an automated message from the Proxmox Backup Script.

\n") - html.WriteString(fmt.Sprintf("

Generated on %s by backup script v%s

\n", data.BackupDate.Format("2006-01-02 15:04:05"), data.ScriptVersion)) + fmt.Fprintf(&html, "

Generated on %s by backup script v%s

\n", data.BackupDate.Format("2006-01-02 15:04:05"), data.ScriptVersion) html.WriteString("
\n") html.WriteString("
\n") diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go index 9a020bc3..2470e3e8 100644 --- a/internal/notify/webhook.go +++ b/internal/notify/webhook.go @@ -387,7 +387,7 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We // Read response body w.logger.Debug("Reading response body...") body, err := io.ReadAll(resp.Body) - resp.Body.Close() + closeErr := resp.Body.Close() if err != nil { if ctxErr := ctx.Err(); ctxErr != nil { @@ -397,6 +397,11 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We w.logger.Warning("Failed to read response body: %v", err) continue } + if closeErr != nil { + lastErr = fmt.Errorf("failed to close response body: %w", closeErr) + w.logger.Warning("Failed to close response body: %v", closeErr) + continue + } w.logger.Debug("Received HTTP %d in %dms", resp.StatusCode, requestDuration.Milliseconds()) if len(body) > 0 { diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go index da918043..bfa26079 100644 --- a/internal/notify/webhook_test.go +++ b/internal/notify/webhook_test.go @@ -236,7 +236,7 @@ func TestWebhookNotifier_Send_Success(t *testing.T) { // Respond with success w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"ok"}`)) + _, _ = w.Write([]byte(`{"status":"ok"}`)) })) defer server.Close() @@ -336,11 +336,11 @@ func TestWebhookNotifier_Send_Retry(t *testing.T) { attempts++ if attempts < 3 { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary failure"}`)) + _, _ = w.Write([]byte(`{"error":"temporary failure"}`)) return } w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"ok"}`)) + _, _ = w.Write([]byte(`{"status":"ok"}`)) })) defer server.Close() diff --git a/internal/orchestrator/additional_helpers_test.go b/internal/orchestrator/additional_helpers_test.go index 39286f81..905ce07a 100644 --- a/internal/orchestrator/additional_helpers_test.go +++ b/internal/orchestrator/additional_helpers_test.go @@ -1492,8 +1492,11 @@ func TestRunGoBackupConfigValidationError(t *testing.T) { orch := New(logger, false) tempDir := t.TempDir() orch.SetBackupConfig(tempDir, tempDir, types.CompressionType("invalid"), 1, 0, "standard", nil) + setSmallBackupTestConfig(t, orch, tempDir) - stats, err := orch.RunGoBackup(context.Background(), &environment.EnvironmentInfo{Type: types.ProxmoxUnknown, Version: "unknown"}, "host-invalid") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + stats, err := orch.RunGoBackup(ctx, &environment.EnvironmentInfo{Type: types.ProxmoxUnknown, Version: "unknown"}, "host-invalid") if err == nil { t.Fatalf("expected error for invalid compression type") } diff --git a/internal/orchestrator/backup_run_phases.go b/internal/orchestrator/backup_run_phases.go index 0e681583..e21e88af 100644 --- a/internal/orchestrator/backup_run_phases.go +++ b/internal/orchestrator/backup_run_phases.go @@ -126,7 +126,7 @@ func (o *Orchestrator) prepareBackupWorkspace(run *backupRunContext, workspace * o.logger.Debug("Creating temporary directory for collection output") workspace.tempRoot = filepath.Join("/tmp", "proxsave") if err := workspace.fs.MkdirAll(workspace.tempRoot, 0o755); err != nil { - return fmt.Errorf("Temp directory creation failed - path: %s: %w", workspace.tempRoot, err) + return fmt.Errorf("temp directory creation failed - path: %s: %w", workspace.tempRoot, err) } tempDir, err := workspace.fs.MkdirTemp(workspace.tempRoot, fmt.Sprintf("proxsave-%s-%s-", run.hostname, run.timestamp)) diff --git a/internal/orchestrator/backup_safety.go b/internal/orchestrator/backup_safety.go index 3455dc47..992b6996 100644 --- a/internal/orchestrator/backup_safety.go +++ b/internal/orchestrator/backup_safety.go @@ -60,13 +60,13 @@ func createSafetyBackup(logger *logging.Logger, selectedCategories []Category, d if err != nil { return nil, fmt.Errorf("create backup archive: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close backup archive") gzWriter := gzip.NewWriter(file) - defer gzWriter.Close() + defer closeIntoErr(&err, gzWriter, "close gzip writer") tarWriter := tar.NewWriter(gzWriter) - defer tarWriter.Close() + defer closeIntoErr(&err, tarWriter, "close tar writer") result = &SafetyBackupResult{ BackupPath: backupArchive, @@ -225,12 +225,12 @@ func CreatePVEAccessControlRollbackBackup(logger *logging.Logger, selectedCatego } // backupFile adds a single file to the tar archive -func backupFile(tw *tar.Writer, sourcePath, archivePath string, result *SafetyBackupResult, logger *logging.Logger) error { +func backupFile(tw *tar.Writer, sourcePath, archivePath string, result *SafetyBackupResult, logger *logging.Logger) (err error) { file, err := safetyFS.Open(sourcePath) if err != nil { return err } - defer file.Close() + defer closeIntoErr(&err, file, "close source file") info, err := file.Stat() if err != nil { @@ -333,13 +333,13 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str if err != nil { return fmt.Errorf("open backup: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close backup archive") gzReader, err := gzip.NewReader(file) if err != nil { return fmt.Errorf("create gzip reader: %w", err) } - defer gzReader.Close() + defer closeIntoErr(&err, gzReader, "close gzip reader") tarReader := tar.NewReader(gzReader) filesRestored := 0 @@ -400,7 +400,10 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str } // Remove existing file/symlink before creating new one - safetyFS.Remove(target) + if err := safetyFS.Remove(target); err != nil && !os.IsNotExist(err) { + logger.Warning("Cannot remove existing path %s before symlink restore: %v", target, err) + continue + } // Create the symlink if err := safetyFS.Symlink(linkTarget, target); err != nil { @@ -412,12 +415,16 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str actualTarget, err := safetyFS.Readlink(target) if err != nil { logger.Warning("Cannot read created symlink %s: %v", target, err) - safetyFS.Remove(target) // Clean up the symlink + if removeErr := safetyFS.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) { + logger.Warning("Cannot remove unreadable symlink %s: %v", target, removeErr) + } continue } if _, err := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), actualTarget); err != nil { - safetyFS.Remove(target) + if removeErr := safetyFS.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) { + logger.Warning("Cannot remove unsafe symlink %s: %v", target, removeErr) + } if isPathSecurityError(err) { logger.Warning("Removing symlink %s -> %s: target escapes root after creation: %v", target, actualTarget, err) @@ -438,11 +445,16 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str } if _, err := io.Copy(outFile, tarReader); err != nil { - outFile.Close() + if closeErr := outFile.Close(); closeErr != nil { + logger.Warning("Cannot close partially restored file %s: %v", target, closeErr) + } logger.Warning("Cannot write file %s: %v", target, err) continue } - outFile.Close() + if err := outFile.Close(); err != nil { + logger.Warning("Cannot close restored file %s: %v", target, err) + continue + } filesRestored++ logger.Debug("Restored: %s", header.Name) diff --git a/internal/orchestrator/backup_safety_glob_test.go b/internal/orchestrator/backup_safety_glob_test.go index dfdcc485..44480365 100644 --- a/internal/orchestrator/backup_safety_glob_test.go +++ b/internal/orchestrator/backup_safety_glob_test.go @@ -45,13 +45,13 @@ func TestCreateSafetyBackup_ExpandsGlobPaths(t *testing.T) { if err != nil { t.Fatalf("open backup: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() gzReader, err := gzip.NewReader(f) if err != nil { t.Fatalf("gzip reader: %v", err) } - defer gzReader.Close() + defer func() { _ = gzReader.Close() }() tr := tar.NewReader(gzReader) seen := map[string]bool{} diff --git a/internal/orchestrator/backup_safety_test.go b/internal/orchestrator/backup_safety_test.go index 83044440..256e6bcf 100644 --- a/internal/orchestrator/backup_safety_test.go +++ b/internal/orchestrator/backup_safety_test.go @@ -225,7 +225,7 @@ func TestBackupFileAndDirectory(t *testing.T) { if err != nil { t.Fatalf("gzip reader error: %v", err) } - defer reader.Close() + defer func() { _ = reader.Close() }() tr := tar.NewReader(reader) var files []string @@ -663,13 +663,13 @@ func TestCreateSafetyBackupArchivesSelectedPaths(t *testing.T) { if err != nil { t.Fatalf("open archive: %v", err) } - defer archiveFile.Close() + defer func() { _ = archiveFile.Close() }() gzr, err := gzip.NewReader(archiveFile) if err != nil { t.Fatalf("gzip reader: %v", err) } - defer gzr.Close() + defer func() { _ = gzr.Close() }() tr := tar.NewReader(gzr) var entries []string @@ -1416,7 +1416,7 @@ func TestRestoreSafetyBackup_FileCreationError(t *testing.T) { if err := os.Chmod(subDir, 0o444); err != nil { t.Fatalf("chmod: %v", err) } - t.Cleanup(func() { os.Chmod(subDir, 0o755) }) + t.Cleanup(func() { _ = os.Chmod(subDir, 0o755) }) err := RestoreSafetyBackup(logger, backupPath, restoreDir) // Should not fail, just log warning @@ -1823,8 +1823,12 @@ func TestBackupDirectory_WalkError(t *testing.T) { t.Fatal("expected error for non-existent directory") } - tw.Close() - gzw.Close() + if err := tw.Close(); err != nil { + t.Fatalf("tar writer close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip writer close failed: %v", err) + } } // ===================================== @@ -1846,8 +1850,12 @@ func TestBackupFile_OpenError(t *testing.T) { t.Fatal("expected error for non-existent file") } - tw.Close() - gzw.Close() + if err := tw.Close(); err != nil { + t.Fatalf("tar writer close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip writer close failed: %v", err) + } } func TestBackupFile_LargeFile(t *testing.T) { diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go index 05225a8b..2e6fb0e6 100644 --- a/internal/orchestrator/backup_sources.go +++ b/internal/orchestrator/backup_sources.go @@ -525,18 +525,19 @@ func readBoundedChecksumLine(reader io.Reader) ([]byte, bool, error) { return nil, false, err } -func parseLocalChecksumFile(checksumPath string) (string, error) { +func parseLocalChecksumFile(checksumPath string) (checksum string, err error) { file, err := restoreFS.Open(checksumPath) if err != nil { return "", fmt.Errorf("read checksum file %s: %w", checksumPath, err) } - defer file.Close() + defer closeIntoErr(&err, file, "close checksum file") - data, _, err := readBoundedChecksumLine(file) + var data []byte + data, _, err = readBoundedChecksumLine(file) if err != nil { return "", fmt.Errorf("read checksum file %s: %w", checksumPath, err) } - checksum, err := backup.ParseChecksumData(data) + checksum, err = backup.ParseChecksumData(data) if err != nil { return "", fmt.Errorf("parse checksum file %s: %w", checksumPath, err) } diff --git a/internal/orchestrator/bundle_test.go b/internal/orchestrator/bundle_test.go index be0ba164..e8c1da25 100644 --- a/internal/orchestrator/bundle_test.go +++ b/internal/orchestrator/bundle_test.go @@ -154,7 +154,7 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { if err != nil { t.Fatalf("open bundle: %v", err) } - defer bundleFile.Close() + defer func() { _ = bundleFile.Close() }() tr := tar.NewReader(bundleFile) foundFiles := make(map[string]bool) diff --git a/internal/orchestrator/close_error.go b/internal/orchestrator/close_error.go new file mode 100644 index 00000000..8167e7dc --- /dev/null +++ b/internal/orchestrator/close_error.go @@ -0,0 +1,5 @@ +package orchestrator + +import "github.com/tis24dev/proxsave/internal/closeerr" + +var closeIntoErr = closeerr.CloseIntoErr diff --git a/internal/orchestrator/decompress_reader_test.go b/internal/orchestrator/decompress_reader_test.go index ddf86fef..814f7df6 100644 --- a/internal/orchestrator/decompress_reader_test.go +++ b/internal/orchestrator/decompress_reader_test.go @@ -16,8 +16,8 @@ func TestCreateDecompressionReaderUnsupported(t *testing.T) { if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer os.Remove(f.Name()) - defer f.Close() + defer func() { _ = os.Remove(f.Name()) }() + defer func() { _ = f.Close() }() if _, err := createDecompressionReader(context.Background(), f, f.Name()); err == nil { t.Fatalf("expected error for unsupported extension") @@ -29,8 +29,8 @@ func TestCreateDecompressionReaderTar(t *testing.T) { if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer os.Remove(f.Name()) - defer f.Close() + defer func() { _ = os.Remove(f.Name()) }() + defer func() { _ = f.Close() }() reader, err := createDecompressionReader(context.Background(), f, f.Name()) if err != nil { @@ -117,14 +117,14 @@ func TestCreateDecompressionReaderUsesStreamingRunnerForCompressedFormats(t *tes if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer os.Remove(f.Name()) - defer f.Close() + defer func() { _ = os.Remove(f.Name()) }() + defer func() { _ = f.Close() }() reader, err := createDecompressionReader(context.Background(), f, f.Name()) if err != nil { t.Fatalf("createDecompressionReader(%s) error: %v", tt.ext, err) } - defer reader.Close() + defer func() { _ = reader.Close() }() out, err := io.ReadAll(reader) if err != nil { diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go index f55cde81..3028f941 100644 --- a/internal/orchestrator/decrypt.go +++ b/internal/orchestrator/decrypt.go @@ -136,12 +136,12 @@ func promptPathSelection(ctx context.Context, reader *bufio.Reader, options []de } } -func inspectBundleManifest(bundlePath string) (*backup.Manifest, error) { +func inspectBundleManifest(bundlePath string) (manifest *backup.Manifest, err error) { file, err := restoreFS.Open(bundlePath) if err != nil { return nil, fmt.Errorf("open bundle: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close bundle") tr := tar.NewReader(file) for { @@ -194,7 +194,7 @@ func inspectRcloneBundleManifest(ctx context.Context, remotePath string, logger if err != nil { return nil, fmt.Errorf("open rclone stream: %w", err) } - defer stdout.Close() + defer func() { _ = stdout.Close() }() var stderr bytes.Buffer cmd.Stderr = &stderr @@ -413,11 +413,15 @@ func downloadRcloneBackup(ctx context.Context, remotePath string, logger *loggin return "", nil, fmt.Errorf("failed to create temp file: %w", err) } tmpPath = tmpFile.Name() - tmpFile.Close() - cleanup = func() { logger.Debug("Removing temporary rclone download: %s", tmpPath) - os.Remove(tmpPath) + if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) { + logger.Debug("Failed to remove temporary rclone download %s: %v", tmpPath, err) + } + } + if err := tmpFile.Close(); err != nil { + cleanup() + return "", nil, fmt.Errorf("close temp file: %w", err) } logger.Info("Downloading backup from cloud storage: %s", remotePath) @@ -496,7 +500,7 @@ func extractBundleToWorkdirWithLogger(bundlePath, workDir string, logger *loggin if err != nil { return stagedFiles{}, fmt.Errorf("open bundle: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close bundle") tr := tar.NewReader(file) extracted := 0 @@ -531,10 +535,14 @@ func extractBundleToWorkdirWithLogger(bundlePath, workDir string, logger *loggin return stagedFiles{}, fmt.Errorf("extract %s: %w", hdr.Name, err) } if _, err := io.Copy(out, tr); err != nil { - out.Close() + if closeErr := out.Close(); closeErr != nil { + return stagedFiles{}, fmt.Errorf("write %s: %w (close: %v)", hdr.Name, err, closeErr) + } return stagedFiles{}, fmt.Errorf("write %s: %w", hdr.Name, err) } - out.Close() + if err := out.Close(); err != nil { + return stagedFiles{}, fmt.Errorf("close extracted %s: %w", hdr.Name, err) + } extracted++ switch { @@ -592,8 +600,6 @@ func rcloneCopyTo(ctx context.Context, remotePath, localPath string, showProgres } func copyRawArtifactsToWorkdirWithLogger(ctx context.Context, cand *backupCandidate, workDir string, logger *logging.Logger) (staged stagedFiles, err error) { - done := logging.DebugStart(logger, "stage raw artifacts", "archive=%s workdir=%s rclone=%v", cand.RawArchivePath, workDir, cand.IsRclone) - defer func() { done(err) }() if ctx == nil { ctx = context.Background() } @@ -601,6 +607,9 @@ func copyRawArtifactsToWorkdirWithLogger(ctx context.Context, cand *backupCandid return stagedFiles{}, fmt.Errorf("candidate is nil") } + done := logging.DebugStart(logger, "stage raw artifacts", "archive=%s workdir=%s rclone=%v", cand.RawArchivePath, workDir, cand.IsRclone) + defer func() { done(err) }() + archiveBase := filepath.Base(cand.RawArchivePath) metaBase := filepath.Base(cand.RawMetadataPath) sumBase := "" @@ -682,18 +691,18 @@ func parseIdentityInput(input string) ([]age.Identity, error) { return deriveDeterministicIdentitiesFromPassphrase(input) } -func decryptWithIdentity(src, dst string, identities ...age.Identity) error { +func decryptWithIdentity(src, dst string, identities ...age.Identity) (err error) { in, err := restoreFS.Open(src) if err != nil { return fmt.Errorf("open encrypted archive: %w", err) } - defer in.Close() + defer closeIntoErr(&err, in, "close encrypted archive") out, err := restoreFS.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o640) if err != nil { return fmt.Errorf("create decrypted archive: %w", err) } - defer out.Close() + defer closeIntoErr(&err, out, "close decrypted archive") reader, err := age.Decrypt(in, identities...) if err != nil { diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go index a952e3b0..1cb08bc9 100644 --- a/internal/orchestrator/decrypt_test.go +++ b/internal/orchestrator/decrypt_test.go @@ -574,7 +574,7 @@ func createTestBundleAt(t *testing.T, bundlePath string, entries []bundleEntry) if err != nil { t.Fatalf("create bundle: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() tw := tar.NewWriter(f) for _, entry := range entries { @@ -1496,14 +1496,26 @@ func TestDecryptWithIdentity_CreateOutputError(t *testing.T) { // Create encrypted file encPath := filepath.Join(dir, "file.age") - f, _ := os.Create(encPath) - w, _ := age.Encrypt(f, id.Recipient()) - w.Write([]byte("data")) - w.Close() - f.Close() + f, err := os.Create(encPath) + if err != nil { + t.Fatalf("create encrypted file: %v", err) + } + w, err := age.Encrypt(f, id.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("data")); err != nil { + t.Fatalf("write encrypted file: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted file: %v", err) + } // Try to write to nonexistent directory - err := decryptWithIdentity(encPath, "/nonexistent/dir/out", id) + err = decryptWithIdentity(encPath, "/nonexistent/dir/out", id) if err == nil { t.Fatal("expected error for nonexistent output directory") } @@ -1524,14 +1536,26 @@ func TestDecryptWithIdentity_WrongIdentity(t *testing.T) { // Create encrypted file with correct identity encPath := filepath.Join(dir, "file.age") outPath := filepath.Join(dir, "file.out") - f, _ := os.Create(encPath) - w, _ := age.Encrypt(f, correctID.Recipient()) - w.Write([]byte("data")) - w.Close() - f.Close() + f, err := os.Create(encPath) + if err != nil { + t.Fatalf("create encrypted file: %v", err) + } + w, err := age.Encrypt(f, correctID.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("data")); err != nil { + t.Fatalf("write encrypted file: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted file: %v", err) + } // Try to decrypt with wrong identity - err := decryptWithIdentity(encPath, outPath, wrongID) + err = decryptWithIdentity(encPath, outPath, wrongID) if err == nil { t.Fatal("expected error for wrong identity") } @@ -1577,11 +1601,23 @@ func TestDecryptArchiveWithPrompts_EmptyInputRetries(t *testing.T) { // Create encrypted file encPath := filepath.Join(dir, "file.age") outPath := filepath.Join(dir, "file.out") - f, _ := os.Create(encPath) - w, _ := age.Encrypt(f, id.Recipient()) - w.Write([]byte("data")) - w.Close() - f.Close() + f, err := os.Create(encPath) + if err != nil { + t.Fatalf("create encrypted file: %v", err) + } + w, err := age.Encrypt(f, id.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("data")); err != nil { + t.Fatalf("write encrypted file: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted file: %v", err) + } // First return empty, then correct key inputs := [][]byte{ @@ -1602,7 +1638,7 @@ func TestDecryptArchiveWithPrompts_EmptyInputRetries(t *testing.T) { logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) - err := decryptArchiveWithPrompts(context.Background(), nil, encPath, outPath, logger) + err = decryptArchiveWithPrompts(context.Background(), nil, encPath, outPath, logger) if err != nil { t.Fatalf("decryptArchiveWithPrompts error: %v", err) } @@ -1799,16 +1835,27 @@ func TestSelectDecryptCandidate_RequireEncryptedFiltersPlain(t *testing.T) { // Create dir with encrypted backup encDir := t.TempDir() archive := filepath.Join(encDir, "enc.tar.xz.age.bundle.tar") - f, _ := os.Create(archive) + f, err := os.Create(archive) + if err != nil { + t.Fatalf("create archive: %v", err) + } tw := tar.NewWriter(f) manifestData, _ := json.Marshal(&backup.Manifest{ ArchivePath: filepath.Join(encDir, "enc.tar.xz.age"), EncryptionMode: "age", }) - tw.WriteHeader(&tar.Header{Name: "enc.metadata", Size: int64(len(manifestData)), Mode: 0o600}) - tw.Write(manifestData) - tw.Close() - f.Close() + if err := tw.WriteHeader(&tar.Header{Name: "enc.metadata", Size: int64(len(manifestData)), Mode: 0o600}); err != nil { + t.Fatalf("write manifest header: %v", err) + } + if _, err := tw.Write(manifestData); err != nil { + t.Fatalf("write manifest: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close archive: %v", err) + } cfg := &config.Config{ BackupPath: plainDir, @@ -2173,16 +2220,32 @@ func TestExtractBundleToWorkdir_WithFakeFS(t *testing.T) { } tw := tar.NewWriter(f) content := []byte("archive content") - tw.WriteHeader(&tar.Header{Name: "archive.tar.xz", Size: int64(len(content)), Mode: 0o600}) - tw.Write(content) + if err := tw.WriteHeader(&tar.Header{Name: "archive.tar.xz", Size: int64(len(content)), Mode: 0o600}); err != nil { + t.Fatalf("write archive header: %v", err) + } + if _, err := tw.Write(content); err != nil { + t.Fatalf("write archive content: %v", err) + } meta := []byte("{}") - tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(meta)), Mode: 0o600}) - tw.Write(meta) + if err := tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(meta)), Mode: 0o600}); err != nil { + t.Fatalf("write metadata header: %v", err) + } + if _, err := tw.Write(meta); err != nil { + t.Fatalf("write metadata: %v", err) + } checksum := []byte("abcd1234") - tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o600}) - tw.Write(checksum) - tw.Close() - f.Close() + if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o600}); err != nil { + t.Fatalf("write checksum header: %v", err) + } + if _, err := tw.Write(checksum); err != nil { + t.Fatalf("write checksum: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } // We need to add the bundle to FakeFS - extractBundleToWorkdir uses restoreFS.Open // which translates the path, but the file exists in the real FS, not the fake one. @@ -2285,7 +2348,9 @@ func TestInspectRcloneBundleManifest_TarReadErrorInLoop(t *testing.T) { t.Fatalf("write data: %v", err) } // Don't close properly to leave truncated tar - f.Close() + if err := f.Close(); err != nil { + t.Fatalf("close truncated bundle: %v", err) + } // Create fake rclone that cats the truncated bundle scriptPath := filepath.Join(tmpDir, "rclone") @@ -2327,8 +2392,12 @@ func TestInspectRcloneBundleManifest_UnmarshalError(t *testing.T) { if _, err := tw.Write(invalidJSON); err != nil { t.Fatalf("write data: %v", err) } - tw.Close() - f.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } // Create fake rclone that cats the bundle scriptPath := filepath.Join(tmpDir, "rclone") @@ -2378,8 +2447,12 @@ func TestInspectRcloneBundleManifest_ValidManifest(t *testing.T) { if _, err := tw.Write(manifestData); err != nil { t.Fatalf("write data: %v", err) } - tw.Close() - f.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } // Create fake rclone that cats the bundle scriptPath := filepath.Join(tmpDir, "rclone") @@ -2570,6 +2643,16 @@ func TestCopyRawArtifactsToWorkdir_ContextWorks(t *testing.T) { } } +func TestCopyRawArtifactsToWorkdirWithLogger_NilCandidate(t *testing.T) { + _, err := copyRawArtifactsToWorkdirWithLogger(context.Background(), nil, t.TempDir(), nil) + if err == nil { + t.Fatal("expected error for nil candidate") + } + if !strings.Contains(err.Error(), "candidate is nil") { + t.Fatalf("expected 'candidate is nil' error, got: %v", err) + } +} + func TestCopyRawArtifactsToWorkdir_InvalidRclonePaths(t *testing.T) { origFS := restoreFS restoreFS = osFS{} @@ -2637,11 +2720,23 @@ func TestDecryptArchiveWithPrompts_InvalidIdentityThenValid(t *testing.T) { // Create encrypted file encPath := filepath.Join(dir, "file.age") outPath := filepath.Join(dir, "file.out") - f, _ := os.Create(encPath) - w, _ := age.Encrypt(f, id.Recipient()) - w.Write([]byte("secret data")) - w.Close() - f.Close() + f, err := os.Create(encPath) + if err != nil { + t.Fatalf("create encrypted file: %v", err) + } + w, err := age.Encrypt(f, id.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("secret data")); err != nil { + t.Fatalf("write encrypted file: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted file: %v", err) + } // First return invalid key format, then correct key inputs := [][]byte{ @@ -2661,7 +2756,7 @@ func TestDecryptArchiveWithPrompts_InvalidIdentityThenValid(t *testing.T) { logger := logging.New(types.LogLevelError, false) logger.SetOutput(io.Discard) - err := decryptArchiveWithPrompts(context.Background(), nil, encPath, outPath, logger) + err = decryptArchiveWithPrompts(context.Background(), nil, encPath, outPath, logger) if err != nil { t.Fatalf("decryptArchiveWithPrompts error: %v", err) } @@ -2809,11 +2904,23 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) { // Create an encrypted archive id, _ := age.GenerateX25519Identity() archivePath := filepath.Join(tmpDir, "backup.tar.xz.age") - f, _ := os.Create(archivePath) - w, _ := age.Encrypt(f, id.Recipient()) - w.Write([]byte("encrypted content")) - w.Close() - f.Close() + f, err := os.Create(archivePath) + if err != nil { + t.Fatalf("create encrypted archive: %v", err) + } + w, err := age.Encrypt(f, id.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("encrypted content")); err != nil { + t.Fatalf("write encrypted archive: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted archive: %v", err) + } // Create bundle tar containing the encrypted archive bundlePath := filepath.Join(tmpDir, "backup.bundle.tar") @@ -2822,8 +2929,12 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) { // Add archive archiveContent, _ := os.ReadFile(archivePath) - tw.WriteHeader(&tar.Header{Name: "backup.tar.xz.age", Size: int64(len(archiveContent)), Mode: 0o600}) - tw.Write(archiveContent) + if err := tw.WriteHeader(&tar.Header{Name: "backup.tar.xz.age", Size: int64(len(archiveContent)), Mode: 0o600}); err != nil { + t.Fatalf("write archive header: %v", err) + } + if _, err := tw.Write(archiveContent); err != nil { + t.Fatalf("write archive content: %v", err) + } // Add metadata manifest := &backup.Manifest{ @@ -2831,16 +2942,28 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) { EncryptionMode: "age", } manifestData, _ := json.Marshal(manifest) - tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}) - tw.Write(manifestData) + if err := tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}); err != nil { + t.Fatalf("write metadata header: %v", err) + } + if _, err := tw.Write(manifestData); err != nil { + t.Fatalf("write metadata: %v", err) + } // Add checksum checksumData := checksumLineForBytes("backup.tar.xz.age", archiveContent) - tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) - tw.Write(checksumData) + if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}); err != nil { + t.Fatalf("write checksum header: %v", err) + } + if _, err := tw.Write(checksumData); err != nil { + t.Fatalf("write checksum: %v", err) + } - tw.Close() - bf.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := bf.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } // Create fake rclone scriptPath := filepath.Join(binDir, "rclone") @@ -3086,31 +3209,52 @@ func TestExtractBundleToWorkdir_SkipsDirectories(t *testing.T) { // Create bundle with directory entries dir := t.TempDir() bundlePath := filepath.Join(dir, "bundle.tar") - f, _ := os.Create(bundlePath) + f, err := os.Create(bundlePath) + if err != nil { + t.Fatalf("create bundle: %v", err) + } tw := tar.NewWriter(f) // Add directory entry (should be skipped) - tw.WriteHeader(&tar.Header{ + if err := tw.WriteHeader(&tar.Header{ Name: "subdir/", Mode: 0o755, Typeflag: tar.TypeDir, - }) + }); err != nil { + t.Fatalf("write directory header: %v", err) + } // Add files archiveData := []byte("archive content") - tw.WriteHeader(&tar.Header{Name: "subdir/archive.tar.xz", Size: int64(len(archiveData)), Mode: 0o600}) - tw.Write(archiveData) + if err := tw.WriteHeader(&tar.Header{Name: "subdir/archive.tar.xz", Size: int64(len(archiveData)), Mode: 0o600}); err != nil { + t.Fatalf("write archive header: %v", err) + } + if _, err := tw.Write(archiveData); err != nil { + t.Fatalf("write archive content: %v", err) + } metaData := []byte("{}") - tw.WriteHeader(&tar.Header{Name: "subdir/backup.metadata", Size: int64(len(metaData)), Mode: 0o600}) - tw.Write(metaData) + if err := tw.WriteHeader(&tar.Header{Name: "subdir/backup.metadata", Size: int64(len(metaData)), Mode: 0o600}); err != nil { + t.Fatalf("write metadata header: %v", err) + } + if _, err := tw.Write(metaData); err != nil { + t.Fatalf("write metadata: %v", err) + } sumData := []byte("checksum") - tw.WriteHeader(&tar.Header{Name: "subdir/backup.sha256", Size: int64(len(sumData)), Mode: 0o600}) - tw.Write(sumData) + if err := tw.WriteHeader(&tar.Header{Name: "subdir/backup.sha256", Size: int64(len(sumData)), Mode: 0o600}); err != nil { + t.Fatalf("write checksum header: %v", err) + } + if _, err := tw.Write(sumData); err != nil { + t.Fatalf("write checksum: %v", err) + } - tw.Close() - f.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } staged, err := extractBundleToWorkdirWithLogger(bundlePath, workDir, nil) if err != nil { @@ -3135,27 +3279,46 @@ func TestPreparePlainBundle_SourceBundleAdditional(t *testing.T) { // Create a valid bundle tar with plain archive bundlePath := filepath.Join(dir, "backup.bundle.tar") - f, _ := os.Create(bundlePath) + f, err := os.Create(bundlePath) + if err != nil { + t.Fatalf("create bundle: %v", err) + } tw := tar.NewWriter(f) archiveData := []byte("archive content") - tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o600}) - tw.Write(archiveData) + if err := tw.WriteHeader(&tar.Header{Name: "backup.tar.xz", Size: int64(len(archiveData)), Mode: 0o600}); err != nil { + t.Fatalf("write archive header: %v", err) + } + if _, err := tw.Write(archiveData); err != nil { + t.Fatalf("write archive content: %v", err) + } manifest := &backup.Manifest{ ArchivePath: "/backup.tar.xz", EncryptionMode: "none", } manifestData, _ := json.Marshal(manifest) - tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}) - tw.Write(manifestData) + if err := tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}); err != nil { + t.Fatalf("write metadata header: %v", err) + } + if _, err := tw.Write(manifestData); err != nil { + t.Fatalf("write metadata: %v", err) + } checksumData := checksumLineForBytes("backup.tar.xz", archiveData) - tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) - tw.Write(checksumData) + if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}); err != nil { + t.Fatalf("write checksum header: %v", err) + } + if _, err := tw.Write(checksumData); err != nil { + t.Fatalf("write checksum: %v", err) + } - tw.Close() - f.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } cand := &backupCandidate{ Manifest: manifest, @@ -3249,14 +3412,26 @@ func TestDecryptWithIdentity_WrongKey(t *testing.T) { encPath := filepath.Join(dir, "file.age") outPath := filepath.Join(dir, "file.out") - f, _ := os.Create(encPath) - w, _ := age.Encrypt(f, correctID.Recipient()) - w.Write([]byte("secret data")) - w.Close() - f.Close() + f, err := os.Create(encPath) + if err != nil { + t.Fatalf("create encrypted file: %v", err) + } + w, err := age.Encrypt(f, correctID.Recipient()) + if err != nil { + t.Fatalf("age encrypt: %v", err) + } + if _, err := w.Write([]byte("secret data")); err != nil { + t.Fatalf("write encrypted file: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close encrypted file: %v", err) + } // Try to decrypt with wrong key - err := decryptWithIdentity(encPath, outPath, wrongID) + err = decryptWithIdentity(encPath, outPath, wrongID) if err == nil { t.Fatal("expected error when decrypting with wrong key") } @@ -3505,8 +3680,12 @@ func TestExtractBundleToWorkdir_OpenFileErrorOnExtract(t *testing.T) { if _, err := tw.Write(checksum); err != nil { t.Fatalf("write checksum: %v", err) } - tw.Close() - bundleFile.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := bundleFile.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } workDir := filepath.Join(tmp, "work") if err := os.MkdirAll(workDir, 0o755); err != nil { @@ -3550,12 +3729,23 @@ func TestInspectRcloneBundleManifest_ManifestFoundWithWaitErr(t *testing.T) { // Create a tar file with manifest tarPath := filepath.Join(tmp, "bundle.tar") - tarFile, _ := os.Create(tarPath) + tarFile, err := os.Create(tarPath) + if err != nil { + t.Fatalf("create tar: %v", err) + } tw := tar.NewWriter(tarFile) - tw.WriteHeader(&tar.Header{Name: "test.manifest.json", Size: int64(len(manifestJSON)), Mode: 0o640}) - tw.Write(manifestJSON) - tw.Close() - tarFile.Close() + if err := tw.WriteHeader(&tar.Header{Name: "test.manifest.json", Size: int64(len(manifestJSON)), Mode: 0o640}); err != nil { + t.Fatalf("write manifest header: %v", err) + } + if _, err := tw.Write(manifestJSON); err != nil { + t.Fatalf("write manifest: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := tarFile.Close(); err != nil { + t.Fatalf("close tar file: %v", err) + } // Script that outputs the tar and then exits with error script := fmt.Sprintf(`#!/bin/bash @@ -3963,13 +4153,22 @@ func TestInspectRcloneBundleManifest_ReadManifestError(t *testing.T) { // Create a tar file with a metadata entry that has invalid JSON tarPath := filepath.Join(tmp, "bundle.tar") - tarFile, _ := os.Create(tarPath) + tarFile, err := os.Create(tarPath) + if err != nil { + t.Fatalf("create tar: %v", err) + } tw := tar.NewWriter(tarFile) // Write header with size larger than actual data to cause read error - tw.WriteHeader(&tar.Header{Name: "test.metadata", Size: 1000, Mode: 0o640}) - tw.Write([]byte("partial")) - tw.Close() - tarFile.Close() + if err := tw.WriteHeader(&tar.Header{Name: "test.metadata", Size: 1000, Mode: 0o640}); err != nil { + t.Fatalf("write metadata header: %v", err) + } + if _, err := tw.Write([]byte("partial")); err != nil { + t.Fatalf("write partial metadata: %v", err) + } + _ = tw.Close() + if err := tarFile.Close(); err != nil { + t.Fatalf("close tar file: %v", err) + } script := fmt.Sprintf(`#!/bin/bash cat "%s" @@ -3983,7 +4182,7 @@ cat "%s" ctx := context.Background() logger := logging.New(types.LogLevelError, false) - _, err := inspectRcloneBundleManifest(ctx, "remote:bundle.tar", logger) + _, err = inspectRcloneBundleManifest(ctx, "remote:bundle.tar", logger) if err == nil { t.Fatalf("expected error, got nil") } @@ -4001,10 +4200,17 @@ func TestInspectRcloneBundleManifest_ManifestNilWithWaitErr(t *testing.T) { // Create an empty tar file tarPath := filepath.Join(tmp, "empty.tar") - tarFile, _ := os.Create(tarPath) + tarFile, err := os.Create(tarPath) + if err != nil { + t.Fatalf("create tar: %v", err) + } tw := tar.NewWriter(tarFile) - tw.Close() - tarFile.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := tarFile.Close(); err != nil { + t.Fatalf("close tar file: %v", err) + } script := fmt.Sprintf(`#!/bin/bash cat "%s" @@ -4019,7 +4225,7 @@ exit 1 ctx := context.Background() logger := logging.New(types.LogLevelError, false) - _, err := inspectRcloneBundleManifest(ctx, "remote:bundle.tar", logger) + _, err = inspectRcloneBundleManifest(ctx, "remote:bundle.tar", logger) if err == nil { t.Fatalf("expected error, got nil") } @@ -4040,13 +4246,23 @@ func TestInspectRcloneBundleManifest_SkipsDirectories(t *testing.T) { tw := tar.NewWriter(tarFile) // Add a directory entry - tw.WriteHeader(&tar.Header{Name: "subdir/", Typeflag: tar.TypeDir, Mode: 0o755}) + if err := tw.WriteHeader(&tar.Header{Name: "subdir/", Typeflag: tar.TypeDir, Mode: 0o755}); err != nil { + t.Fatalf("write directory header: %v", err) + } // Add manifest - tw.WriteHeader(&tar.Header{Name: "subdir/test.metadata", Size: int64(len(manifestJSON)), Mode: 0o640}) - tw.Write(manifestJSON) - tw.Close() - tarFile.Close() + if err := tw.WriteHeader(&tar.Header{Name: "subdir/test.metadata", Size: int64(len(manifestJSON)), Mode: 0o640}); err != nil { + t.Fatalf("write manifest header: %v", err) + } + if _, err := tw.Write(manifestJSON); err != nil { + t.Fatalf("write manifest: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := tarFile.Close(); err != nil { + t.Fatalf("close tar file: %v", err) + } rcloneScript := filepath.Join(tmp, "rclone") script := fmt.Sprintf(`#!/bin/bash @@ -4116,15 +4332,26 @@ func TestExtractBundleToWorkdir_RelPathError(t *testing.T) { // Create a tar with an entry that would cause filepath.Rel to fail // This is hard to trigger naturally, but we can test the escape check bundlePath := filepath.Join(tmp, "bundle.tar") - bundleFile, _ := os.Create(bundlePath) + bundleFile, err := os.Create(bundlePath) + if err != nil { + t.Fatalf("create bundle: %v", err) + } tw := tar.NewWriter(bundleFile) // Add file with path traversal attempt archiveData := []byte("archive content") - tw.WriteHeader(&tar.Header{Name: "../../../etc/passwd", Size: int64(len(archiveData)), Mode: 0o640}) - tw.Write(archiveData) - tw.Close() - bundleFile.Close() + if err := tw.WriteHeader(&tar.Header{Name: "../../../etc/passwd", Size: int64(len(archiveData)), Mode: 0o640}); err != nil { + t.Fatalf("write traversal header: %v", err) + } + if _, err := tw.Write(archiveData); err != nil { + t.Fatalf("write traversal content: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := bundleFile.Close(); err != nil { + t.Fatalf("close bundle file: %v", err) + } workDir := filepath.Join(tmp, "work") if err := os.MkdirAll(workDir, 0o755); err != nil { @@ -4136,7 +4363,7 @@ func TestExtractBundleToWorkdir_RelPathError(t *testing.T) { defer func() { restoreFS = orig }() logger := logging.New(types.LogLevelError, false) - _, err := extractBundleToWorkdirWithLogger(bundlePath, workDir, logger) + _, err = extractBundleToWorkdirWithLogger(bundlePath, workDir, logger) if err == nil { t.Fatalf("expected error for path traversal, got nil") } @@ -4261,7 +4488,7 @@ func (f *fakeStatThenRemoveFS) Stat(path string) (os.FileInfo, error) { } // After stat succeeds, remove the file so GenerateChecksum can't open it if strings.Contains(path, "proxmox-decrypt") && strings.HasSuffix(path, ".tar.xz") { - os.Remove(path) + _ = os.Remove(path) } return info, nil } @@ -4318,7 +4545,9 @@ func TestPreparePlainBundle_MkdirAllErrorAfterRcloneDownload(t *testing.T) { // Create fake rclone that downloads a valid bundle fakeRclone := filepath.Join(tmp, "rclone") bundleDir := filepath.Join(tmp, "bundles") - os.MkdirAll(bundleDir, 0o755) + if err := os.MkdirAll(bundleDir, 0o755); err != nil { + t.Fatalf("mkdir bundle dir: %v", err) + } // Create the bundle that will be "downloaded" sourceBundlePath := filepath.Join(bundleDir, "backup.bundle.tar") @@ -4334,7 +4563,9 @@ if [[ "$1" == "copyto" ]]; then fi exit 0 `, sourceBundlePath) - os.WriteFile(fakeRclone, []byte(script), 0o755) + if err := os.WriteFile(fakeRclone, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } prependPathEnv(t, tmp) diff --git a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go index 9dd6dc32..556cb635 100644 --- a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go +++ b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go @@ -74,8 +74,8 @@ func (s *notifyingSimulationScreen) snapshotState() timedSimScreenSnapshot { } func (s *notifyingSimulationScreen) captureLocked() { - cells, width, height := s.SimulationScreen.GetContents() - cursorX, cursorY, cursorVisible := s.SimulationScreen.GetCursor() + cells, width, height := s.GetContents() + cursorX, cursorY, cursorVisible := s.GetCursor() s.snapshot = timedSimScreenSnapshot{ cells: cloneSimCells(cells), width: width, @@ -609,7 +609,7 @@ func readTarEntries(t *testing.T, tarPath string) map[string][]byte { if err != nil { t.Fatalf("open tar %s: %v", tarPath, err) } - defer file.Close() + defer func() { _ = file.Close() }() tr := tar.NewReader(file) entries := make(map[string][]byte) diff --git a/internal/orchestrator/deps.go b/internal/orchestrator/deps.go index 9a5099a4..f18fe2e2 100644 --- a/internal/orchestrator/deps.go +++ b/internal/orchestrator/deps.go @@ -155,7 +155,7 @@ func (osCommandRunner) RunStream(ctx context.Context, name string, stdin io.Read return nil, err } if err := cmd.Start(); err != nil { - stdout.Close() + _ = stdout.Close() return nil, err } return &waitReadCloser{ReadCloser: stdout, wait: cmd.Wait}, nil diff --git a/internal/orchestrator/deps_additional_test.go b/internal/orchestrator/deps_additional_test.go index 7ff5b57b..7b26981a 100644 --- a/internal/orchestrator/deps_additional_test.go +++ b/internal/orchestrator/deps_additional_test.go @@ -116,7 +116,7 @@ func TestConsolePrompterWrappers(t *testing.T) { _, _ = w.WriteString("1\n") _ = w.Close() os.Stdin = r - defer r.Close() + defer func() { _ = r.Close() }() mode, err := (consolePrompter{}).SelectRestoreMode(context.Background(), logger, SystemTypePVE) if err != nil { @@ -141,7 +141,7 @@ func TestConsolePrompterWrappers(t *testing.T) { _, _ = w.WriteString("a\nc\n") _ = w.Close() os.Stdin = r - defer r.Close() + defer func() { _ = r.Close() }() cats, err := (consolePrompter{}).SelectCategories(context.Background(), logger, available, SystemTypePVE) if err != nil { @@ -160,7 +160,7 @@ func TestConsolePrompterWrappers(t *testing.T) { _, _ = w.WriteString("RESTORE\n") _ = w.Close() os.Stdin = r - defer r.Close() + defer func() { _ = r.Close() }() ok, err := (consolePrompter{}).ConfirmRestore(context.Background(), logger) if err != nil { diff --git a/internal/orchestrator/directory_recreation.go b/internal/orchestrator/directory_recreation.go index 870f3a25..060b6e6e 100644 --- a/internal/orchestrator/directory_recreation.go +++ b/internal/orchestrator/directory_recreation.go @@ -1,16 +1,9 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. package orchestrator import ( - "bufio" "errors" "fmt" - "io" - "os" - "os/user" - "path/filepath" - "strconv" - "strings" - "syscall" "github.com/tis24dev/proxsave/internal/logging" ) @@ -23,75 +16,20 @@ var ( // RecreateStorageDirectories parses storage.cfg and recreates storage directories (PVE) func RecreateStorageDirectories(logger *logging.Logger) error { - // Check if file exists - if _, err := os.Stat(storageCfgPath); err != nil { - if os.IsNotExist(err) { - logger.Debug("No storage.cfg found, skipping storage directory recreation") - return nil - } - return fmt.Errorf("stat storage.cfg: %w", err) - } - - logger.Info("Parsing storage.cfg to recreate storage directories...") - - file, err := os.Open(storageCfgPath) + entries, err := loadPVEStorageEntries(storageCfgPath, logger) if err != nil { - return fmt.Errorf("open storage.cfg: %w", err) + return err } - defer file.Close() - - scanner := bufio.NewScanner(file) - var currentStorage string - var currentPath string - var currentType string directoriesCreated := 0 - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - - // Skip comments and empty lines - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Check for storage definition start (e.g., "dir: local") - if strings.Contains(line, ":") && !strings.Contains(line, "=") { - parts := strings.Fields(line) - if len(parts) >= 2 { - currentType = strings.TrimSuffix(parts[0], ":") - currentStorage = strings.TrimSuffix(parts[1], ":") - currentPath = "" - } + for _, entry := range entries { + if err := createPVEStorageStructure(entry.Path, entry.Type, logger); err != nil { + logger.Warning("Failed to create storage structure for %s: %v", entry.Name, err) continue } - // Parse path directive - if strings.HasPrefix(line, "path ") { - parts := strings.Fields(line) - if len(parts) >= 2 { - currentPath = parts[1] - } - } - - // When we have both storage name and path, create the directory structure - if currentStorage != "" && currentPath != "" && currentType != "" { - if err := createPVEStorageStructure(currentPath, currentType, logger); err != nil { - logger.Warning("Failed to create storage structure for %s: %v", currentStorage, err) - } else { - directoriesCreated++ - logger.Debug("Created storage structure: %s (%s) at %s", currentStorage, currentType, currentPath) - } - - // Reset for next storage - currentStorage = "" - currentPath = "" - currentType = "" - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("read storage.cfg: %w", err) + directoriesCreated++ + logger.Debug("Created storage structure: %s (%s) at %s", entry.Name, entry.Type, entry.Path) } if directoriesCreated > 0 { @@ -101,119 +39,24 @@ func RecreateStorageDirectories(logger *logging.Logger) error { return nil } -// createPVEStorageStructure creates the directory structure for a PVE storage -func createPVEStorageStructure(basePath, storageType string, logger *logging.Logger) error { - // Create base directory - if err := os.MkdirAll(basePath, 0750); err != nil { - return fmt.Errorf("create base directory: %w", err) - } - - // Create subdirectories based on storage type - switch storageType { - case "dir": - // Standard directory storage needs these subdirectories - subdirs := []string{"dump", "images", "template", "snippets", "private"} - for _, subdir := range subdirs { - path := filepath.Join(basePath, subdir) - if err := os.MkdirAll(path, 0750); err != nil { - logger.Warning("Failed to create %s: %v", path, err) - } - } - - case "nfs", "cifs": - // Network storage - subdirs := []string{"dump", "images", "template"} - for _, subdir := range subdirs { - path := filepath.Join(basePath, subdir) - if err := os.MkdirAll(path, 0750); err != nil { - logger.Warning("Failed to create %s: %v", path, err) - } - } - - default: - // For other storage types, just ensure base path exists - logger.Debug("Storage type %s does not require subdirectories", storageType) - } - - // Set ownership to root:root (already the case when running as root) - // PVE typically uses root:root for storage directories - - return nil -} - // RecreateDatastoreDirectories parses datastore.cfg and recreates datastore directories (PBS) func RecreateDatastoreDirectories(logger *logging.Logger) error { - // Check if file exists - if _, err := os.Stat(datastoreCfgPath); err != nil { - if os.IsNotExist(err) { - logger.Debug("No datastore.cfg found, skipping datastore directory recreation") - return nil - } - return fmt.Errorf("stat datastore.cfg: %w", err) - } - - if err := normalizePBSDatastoreCfg(datastoreCfgPath, logger); err != nil { - logger.Warning("PBS datastore.cfg normalization failed: %v", err) - } - - logger.Info("Parsing datastore.cfg to recreate datastore directories...") - - file, err := os.Open(datastoreCfgPath) + entries, err := loadPBSDatastoreEntries(datastoreCfgPath, logger) if err != nil { - return fmt.Errorf("open datastore.cfg: %w", err) + return err } - defer file.Close() - - scanner := bufio.NewScanner(file) - var currentDatastore string - var currentPath string directoriesCreated := 0 - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - - // Skip comments and empty lines - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Check for datastore definition start (e.g., "datastore: backup") - if strings.HasPrefix(line, "datastore:") { - parts := strings.Fields(line) - if len(parts) >= 2 { - currentDatastore = strings.TrimSuffix(parts[1], ":") - currentPath = "" - } + for _, entry := range entries { + created, err := createPBSDatastoreStructure(entry.Path, entry.Name, logger) + if err != nil { + logger.Warning("Failed to create datastore structure for %s: %v", entry.Name, err) continue } - - // Parse path directive - if strings.HasPrefix(line, "path ") { - parts := strings.Fields(line) - if len(parts) >= 2 { - currentPath = parts[1] - } + if created { + directoriesCreated++ + logger.Debug("Created datastore structure: %s at %s", entry.Name, entry.Path) } - - // When we have both datastore name and path, create the directory - if currentDatastore != "" && currentPath != "" { - created, err := createPBSDatastoreStructure(currentPath, currentDatastore, logger) - if err != nil { - logger.Warning("Failed to create datastore structure for %s: %v", currentDatastore, err) - } else if created { - directoriesCreated++ - logger.Debug("Created datastore structure: %s at %s", currentDatastore, currentPath) - } - - // Reset for next datastore - currentDatastore = "" - currentPath = "" - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("read datastore.cfg: %w", err) } if directoriesCreated > 0 { @@ -223,622 +66,6 @@ func RecreateDatastoreDirectories(logger *logging.Logger) error { return nil } -// createPBSDatastoreStructure creates the directory structure for a PBS datastore. -// It returns true when ProxSave made filesystem changes for this datastore path. -func createPBSDatastoreStructure(basePath, datastoreName string, logger *logging.Logger) (bool, error) { - done := logging.DebugStart(logger, "pbs datastore directory recreation", "datastore=%s path=%s", datastoreName, basePath) - var err error - defer func() { done(err) }() - - changed := false - - // ZFS SAFETY: if ZFS is detected and this path looks like a ZFS mountpoint, avoid creating the datastore directory - // when it does not exist yet. On ZFS systems the directory is typically created by mounting/importing the pool; - // creating it ourselves can "shadow" the intended mountpoint and leads to confusing restore outcomes. - if isLikelyZFSMountPoint(basePath, logger) { - if _, statErr := os.Stat(basePath); statErr != nil { - if os.IsNotExist(statErr) { - logger.Warning("PBS datastore preflight: %s looks like a ZFS mountpoint and does not exist yet; skipping directory creation to avoid shadowing a not-yet-imported pool", basePath) - err = nil - return false, nil - } - logger.Warning("PBS datastore preflight: unable to stat potential ZFS mountpoint %s: %v; skipping any datastore filesystem changes", basePath, statErr) - err = nil - return false, nil - } - } - - dataUnknown := false - hasData, dataErr := pbsDatastoreHasData(basePath) - if dataErr != nil { - dataUnknown = true - logger.Warning("PBS datastore preflight: unable to determine whether %s contains datastore data: %v", basePath, dataErr) - } - - onRootFS, existingPath, devErr := isPathOnRootFilesystem(basePath) - if devErr != nil { - logger.Warning("PBS datastore preflight: unable to determine filesystem device for %s: %v", basePath, devErr) - } - logging.DebugStep( - logger, - "pbs datastore preflight", - "path=%s existing=%s on_rootfs=%t has_data=%t data_unknown=%t", - basePath, - existingPath, - onRootFS, - hasData, - dataUnknown, - ) - - // IMPORTANT SAFETY GUARD: - // If the datastore path looks like a mountpoint location (e.g. under /mnt) but resolves to the root filesystem - // and contains no datastore data, we assume the disk/pool is not mounted and refuse to write. This prevents - // accidentally creating datastore scaffolding on "/" during restore. - if onRootFS && (isSuspiciousDatastoreMountLocation(basePath) || isLikelyZFSMountPoint(basePath, logger)) && (dataUnknown || !hasData) { - logger.Warning("PBS datastore preflight: %s resolves to the root filesystem (mount missing?) — skipping datastore directory initialization to avoid writing to the wrong disk", basePath) - logger.Info("Mount/import the datastore disk/pool first, then restart PBS services.") - if _, zfsErr := os.Stat(zpoolCachePath); zfsErr == nil { - logger.Info("ZFS detected: if this datastore was on ZFS, you may need to import the pool first (e.g. `zpool import` then `zpool import `).") - } - err = nil - return false, nil - } - - // If we cannot reliably inspect the datastore path, we refuse to mutate it to avoid risking real datastore data. - if dataUnknown { - logger.Warning("PBS datastore preflight: datastore path inspection failed — skipping any datastore filesystem changes to avoid risking existing data") - err = nil - return false, nil - } - - // If the datastore already contains chunk/index data, avoid any modifications to prevent touching real backup data. - // We only validate and report issues. - if hasData { - if warn := validatePBSDatastoreReadOnly(basePath); warn != "" { - logger.Warning("PBS datastore preflight: %s", warn) - } - logger.Info("PBS datastore preflight: datastore %s appears to contain data; skipping directory/permission changes to avoid risking datastore contents", datastoreName) - err = nil - return false, nil - } - - // If the datastore root contains any entries outside of the expected PBS scaffolding, do not touch it. - // This keeps ProxSave conservative: only initialize truly empty/uninitialized datastore directories. - unexpected, unexpectedErr := pbsDatastoreHasUnexpectedEntries(basePath) - if unexpectedErr != nil { - logger.Warning("PBS datastore preflight: unable to inspect %s contents: %v; skipping any datastore filesystem changes to avoid risking unrelated data", basePath, unexpectedErr) - err = nil - return false, nil - } - if unexpected { - logger.Warning("PBS datastore preflight: %s is not empty (unexpected entries present); skipping any datastore filesystem changes to avoid risking unrelated data", basePath) - err = nil - return false, nil - } - - dirsToFix, err := computeMissingDirs(basePath) - if err != nil { - return false, fmt.Errorf("compute missing dirs: %w", err) - } - - // Create base directory - if err := os.MkdirAll(basePath, 0750); err != nil { - return false, fmt.Errorf("create base directory: %w", err) - } - if len(dirsToFix) > 0 { - changed = true - } - - // PBS datastores need these subdirectories - subdirs := []string{".chunks", ".index"} - for _, subdir := range subdirs { - path := filepath.Join(basePath, subdir) - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - changed = true - dirsToFix = append(dirsToFix, path) - } - } - if err := os.MkdirAll(path, 0750); err != nil { - logger.Warning("Failed to create %s: %v", path, err) - } - } - - // Set ownership to backup:backup when possible for directory components created by ProxSave. - // This avoids a common failure mode where parent directories created by MkdirAll remain root-only - // and prevent PBS (backup user) from accessing the datastore path. - if len(dirsToFix) > 0 { - logger.Debug("PBS datastore permissions: applying ownership to %d created path(s) (datastore=%s path=%s)", len(dirsToFix), datastoreName, basePath) - } - for _, dir := range dirsToFix { - if err := setDatastoreOwnership(dir, logger); err != nil { - logger.Warning("Could not set datastore ownership for %s: %v", dir, err) - } - } - - // Always attempt to fix the datastore root itself (even if it pre-existed), since PBS requires - // backup:backup ownership and accessible permissions to function. - if err := setDatastoreOwnership(basePath, logger); err != nil { - logger.Warning("Could not set datastore ownership for %s: %v", basePath, err) - } - - lockChanged, lockErr := ensurePBSDatastoreLockFile(basePath, logger) - if lockErr != nil { - logger.Warning("PBS datastore lock file: %v", lockErr) - } - changed = changed || lockChanged - - return changed, nil -} - -func validatePBSDatastoreReadOnly(datastorePath string) string { - if datastorePath == "" { - return "datastore path is empty" - } - - info, err := os.Stat(datastorePath) - if err != nil { - return fmt.Sprintf("datastore path %s cannot be stat'd: %v", datastorePath, err) - } - if !info.IsDir() { - return fmt.Sprintf("datastore path %s is not a directory (type=%s)", datastorePath, info.Mode()) - } - - chunksPath := filepath.Join(datastorePath, ".chunks") - chunksInfo, err := os.Stat(chunksPath) - if err != nil { - return fmt.Sprintf("datastore %s missing .chunks directory: %v", datastorePath, err) - } - if !chunksInfo.IsDir() { - return fmt.Sprintf("datastore %s .chunks is not a directory (type=%s)", datastorePath, chunksInfo.Mode()) - } - - indexPath := filepath.Join(datastorePath, ".index") - indexInfo, err := os.Stat(indexPath) - if err != nil { - return fmt.Sprintf("datastore %s missing .index directory: %v", datastorePath, err) - } - if !indexInfo.IsDir() { - return fmt.Sprintf("datastore %s .index is not a directory (type=%s)", datastorePath, indexInfo.Mode()) - } - - lockPath := filepath.Join(datastorePath, ".lock") - lockInfo, err := os.Stat(lockPath) - if err != nil { - return fmt.Sprintf("datastore %s missing .lock file: %v", datastorePath, err) - } - if !lockInfo.Mode().IsRegular() { - return fmt.Sprintf("datastore %s .lock is not a regular file (type=%s)", datastorePath, lockInfo.Mode()) - } - - return "" -} - -func ensurePBSDatastoreLockFile(datastorePath string, logger *logging.Logger) (bool, error) { - lockPath := filepath.Join(datastorePath, ".lock") - - info, err := os.Lstat(lockPath) - if err != nil { - if !os.IsNotExist(err) { - return false, fmt.Errorf("stat %s: %w", lockPath, err) - } - - logger.Debug("PBS datastore lock: creating %s", lockPath) - file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o640) - if err != nil { - return false, fmt.Errorf("create %s: %w", lockPath, err) - } - _ = file.Close() - - if err := setDatastoreOwnership(lockPath, logger); err != nil { - return true, fmt.Errorf("chown %s: %w", lockPath, err) - } - return true, nil - } - - if info.Mode()&os.ModeSymlink != 0 { - return false, fmt.Errorf("%s is a symlink; refusing to manage lock file", lockPath) - } - - if info.IsDir() { - changed := false - entries, err := os.ReadDir(lockPath) - if err != nil { - return false, fmt.Errorf("lock path %s is a directory and cannot be read: %w", lockPath, err) - } - - if len(entries) == 0 { - logger.Warning("PBS datastore lock: %s is a directory (invalid); removing and recreating as file", lockPath) - if err := os.Remove(lockPath); err != nil { - return false, fmt.Errorf("remove invalid lock dir %s: %w", lockPath, err) - } - changed = true - } else { - backupPath := fmt.Sprintf("%s.proxsave-dir.%s", lockPath, nowRestore().Format("20060102-150405")) - logger.Warning("PBS datastore lock: %s is a non-empty directory (invalid); renaming to %s and creating lock file", lockPath, backupPath) - if err := os.Rename(lockPath, backupPath); err != nil { - return false, fmt.Errorf("rename invalid lock dir %s -> %s: %w", lockPath, backupPath, err) - } - changed = true - } - - file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o640) - if err != nil { - return changed, fmt.Errorf("create %s: %w", lockPath, err) - } - _ = file.Close() - changed = true - - if err := setDatastoreOwnership(lockPath, logger); err != nil { - return changed, fmt.Errorf("chown %s: %w", lockPath, err) - } - - return changed, nil - } - - if err := setDatastoreOwnership(lockPath, logger); err != nil { - return false, fmt.Errorf("chown %s: %w", lockPath, err) - } - - return false, nil -} - -func normalizePBSDatastoreCfg(path string, logger *logging.Logger) error { - raw, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("read datastore.cfg: %w", err) - } - - normalized, fixed := normalizePBSDatastoreCfgContent(string(raw)) - if fixed == 0 { - logger.Debug("PBS datastore.cfg: formatting looks OK (no normalization needed)") - return nil - } - - if err := os.MkdirAll("/tmp/proxsave", 0o755); err != nil { - return fmt.Errorf("ensure /tmp/proxsave exists: %w", err) - } - - backupPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("datastore.cfg.pre-normalize.%s", nowRestore().Format("20060102-150405"))) - if err := os.WriteFile(backupPath, raw, 0o600); err != nil { - return fmt.Errorf("write backup copy: %w", err) - } - - mode := os.FileMode(0o644) - if info, err := os.Stat(path); err == nil { - mode = info.Mode().Perm() - } - - tmpPath := fmt.Sprintf("%s.proxsave.tmp", path) - if err := os.WriteFile(tmpPath, []byte(normalized), mode); err != nil { - return fmt.Errorf("write normalized datastore.cfg: %w", err) - } - if err := os.Rename(tmpPath, path); err != nil { - _ = os.Remove(tmpPath) - return fmt.Errorf("replace datastore.cfg: %w", err) - } - - logger.Warning("PBS datastore.cfg: fixed %d malformed line(s) (properties must be indented); backup saved to %s", fixed, backupPath) - return nil -} - -func normalizePBSDatastoreCfgContent(content string) (string, int) { - lines := strings.Split(content, "\n") - if len(lines) == 0 { - return content, 0 - } - - inDatastoreBlock := false - fixed := 0 - for i, line := range lines { - trimmed := strings.TrimSpace(line) - if trimmed == "" || strings.HasPrefix(trimmed, "#") { - continue - } - - if strings.HasPrefix(trimmed, "datastore:") { - inDatastoreBlock = true - continue - } - - if !inDatastoreBlock { - continue - } - - if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") { - continue - } - - lines[i] = " " + line - fixed++ - } - - return strings.Join(lines, "\n"), fixed -} - -func computeMissingDirs(target string) ([]string, error) { - path := filepath.Clean(target) - if path == "" || path == "." || path == "/" { - return nil, nil - } - - var missing []string - for { - if path == "" || path == "." || path == "/" { - break - } - _, err := os.Stat(path) - if err == nil { - break - } - if !os.IsNotExist(err) { - return nil, err - } - missing = append(missing, path) - parent := filepath.Dir(path) - if parent == path { - break - } - path = parent - } - - // Reverse so parents come first (top-down), making logs more readable. - for i, j := 0, len(missing)-1; i < j; i, j = i+1, j-1 { - missing[i], missing[j] = missing[j], missing[i] - } - return missing, nil -} - -func pbsDatastoreHasData(datastorePath string) (bool, error) { - if strings.TrimSpace(datastorePath) == "" { - return false, fmt.Errorf("path is empty") - } - info, err := os.Stat(datastorePath) - if err != nil { - if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) { - return false, nil - } - return false, err - } - if !info.IsDir() { - return false, nil - } - - for _, subdir := range []string{".chunks", ".index"} { - has, err := dirHasAnyEntry(filepath.Join(datastorePath, subdir)) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - continue - } - return false, err - } - if has { - return true, nil - } - } - - return false, nil -} - -func pbsDatastoreHasUnexpectedEntries(datastorePath string) (bool, error) { - if strings.TrimSpace(datastorePath) == "" { - return false, nil - } - - info, err := os.Stat(datastorePath) - if err != nil { - if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) { - return false, nil - } - return false, err - } - if !info.IsDir() { - return false, nil - } - - allowed := map[string]struct{}{ - ".chunks": {}, - ".index": {}, - ".lock": {}, - } - - f, err := os.Open(datastorePath) - if err != nil { - return false, err - } - defer f.Close() - - for { - names, err := f.Readdirnames(64) - if err == nil { - for _, name := range names { - if _, ok := allowed[name]; ok { - continue - } - return true, nil - } - continue - } - - if errors.Is(err, io.EOF) { - return false, nil - } - return false, err - } -} - -func dirHasAnyEntry(path string) (bool, error) { - f, err := os.Open(path) - if err != nil { - return false, err - } - defer f.Close() - - _, err = f.Readdirnames(1) - if err == nil { - return true, nil - } - if errors.Is(err, io.EOF) { - return false, nil - } - return false, err -} - -func isConfirmableDatastoreMountRoot(path string) bool { - path = filepath.Clean(path) - switch { - case strings.HasPrefix(path, "/mnt/"): - return true - case strings.HasPrefix(path, "/media/"): - return true - case strings.HasPrefix(path, "/run/media/"): - return true - default: - return false - } -} - -func isSuspiciousDatastoreMountLocation(path string) bool { - // Conservative: only treat typical mount roots as "must be mounted". - // This prevents accidental writes to "/" when a disk/pool wasn't mounted yet. - return isConfirmableDatastoreMountRoot(path) -} - -func isPathOnRootFilesystem(path string) (bool, string, error) { - rootDev, err := deviceID("/") - if err != nil { - return false, "/", err - } - - existing, err := nearestExistingPath(path) - if err != nil { - return false, "", err - } - targetDev, err := deviceID(existing) - if err != nil { - return false, existing, err - } - return rootDev == targetDev, existing, nil -} - -func nearestExistingPath(target string) (string, error) { - path := filepath.Clean(target) - if path == "" || path == "." { - return "", fmt.Errorf("invalid path") - } - - for { - if _, err := os.Stat(path); err == nil { - return path, nil - } else if !os.IsNotExist(err) { - return "", err - } - - parent := filepath.Dir(path) - if parent == path { - return path, nil - } - path = parent - } -} - -func deviceID(path string) (uint64, error) { - info, err := os.Stat(path) - if err != nil { - return 0, err - } - stat, ok := info.Sys().(*syscall.Stat_t) - if !ok || stat == nil { - return 0, fmt.Errorf("unsupported stat type for %s", path) - } - return uint64(stat.Dev), nil -} - -// isLikelyZFSMountPoint checks if a path is likely a ZFS mount point -func isLikelyZFSMountPoint(path string, logger *logging.Logger) bool { - // Check if /etc/zfs/zpool.cache exists (indicates ZFS is used on this system) - if _, err := os.Stat(zpoolCachePath); err != nil { - // No ZFS on this system - return false - } - - // Common ZFS mount point patterns - // PBS datastores on ZFS are typically under /mnt/ or use "backup" in the name - pathLower := strings.ToLower(path) - if strings.HasPrefix(pathLower, "/mnt/") || - strings.Contains(pathLower, "backup") || - strings.Contains(pathLower, "datastore") { - logger.Debug("Path %s matches ZFS mount point pattern", path) - return true - } - - return false -} - -// setDatastoreOwnership sets ownership to backup:backup for PBS datastores -func setDatastoreOwnership(path string, logger *logging.Logger) error { - if os.Geteuid() != 0 { - // Ownership/permission adjustments are best-effort and should not block - // directory recreation when running without privileges (common in CI/tests). - logger.Debug("PBS datastore ownership: running as non-root (euid=%d); skipping chown/chmod for %s", os.Geteuid(), path) - return nil - } - - backupUser, err := user.Lookup("backup") - if err != nil { - // On non-PBS systems the user may not exist; treat as non-fatal. - logger.Debug("PBS datastore ownership: user 'backup' not found; skipping chown for %s", path) - return nil - } - uid, err := strconv.Atoi(backupUser.Uid) - if err != nil { - return fmt.Errorf("parse backup uid: %w", err) - } - gid, err := strconv.Atoi(backupUser.Gid) - if err != nil { - return fmt.Errorf("parse backup gid: %w", err) - } - - logger.Debug("PBS datastore ownership: chown %s to backup:backup (uid=%d gid=%d)", path, uid, gid) - if err := os.Chown(path, uid, gid); err != nil { - if isIgnorableOwnershipError(err) { - logger.Warning("PBS datastore ownership: unable to chown %s to backup:backup (uid=%d gid=%d): %v (continuing)", path, uid, gid, err) - return nil - } - return fmt.Errorf("chown %s: %w", path, err) - } - - info, err := os.Stat(path) - if err != nil { - // Ownership was already applied; ignore stat errors for further chmod adjustments. - return nil - } - if info.IsDir() { - current := info.Mode().Perm() - required := os.FileMode(0o750) - desired := current | required - if desired != current { - logger.Debug("PBS datastore permissions: chmod %s from %o to %o", path, current, desired) - if err := os.Chmod(path, desired); err != nil { - if isIgnorableOwnershipError(err) { - logger.Warning("PBS datastore permissions: unable to chmod %s from %o to %o: %v (continuing)", path, current, desired, err) - return nil - } - return fmt.Errorf("chmod %s: %w", path, err) - } - } - } - - return nil -} - -func isIgnorableOwnershipError(err error) bool { - // Common "can't chown/chmod here" situations: - // - EPERM/EACCES: not permitted (non-root, user namespace restrictions, etc.) - // - EROFS: read-only filesystem - return errors.Is(err, syscall.EPERM) || errors.Is(err, syscall.EACCES) || errors.Is(err, syscall.EROFS) -} - // RecreateDirectoriesFromConfig recreates storage/datastore directories based on system type func RecreateDirectoriesFromConfig(systemType SystemType, logger *logging.Logger) error { logger.Info("Recreating directory structures from configuration...") diff --git a/internal/orchestrator/directory_recreation_config.go b/internal/orchestrator/directory_recreation_config.go new file mode 100644 index 00000000..f1a8aba3 --- /dev/null +++ b/internal/orchestrator/directory_recreation_config.go @@ -0,0 +1,164 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type pveStorageEntry struct { + Name string + Type string + Path string +} + +type pbsDatastoreEntry struct { + Name string + Path string +} + +func loadPVEStorageEntries(path string, logger *logging.Logger) (entries []pveStorageEntry, err error) { + if exists, err := configFileExists(path, "storage.cfg", "storage directory recreation", logger); err != nil || !exists { + return nil, err + } + + logger.Info("Parsing storage.cfg to recreate storage directories...") + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open storage.cfg: %w", err) + } + defer closeIntoErr(&err, file, "close storage.cfg") + + entries, err = parsePVEStorageEntries(file) + if err != nil { + return nil, fmt.Errorf("read storage.cfg: %w", err) + } + return entries, nil +} + +func loadPBSDatastoreEntries(path string, logger *logging.Logger) (entries []pbsDatastoreEntry, err error) { + if exists, err := configFileExists(path, "datastore.cfg", "datastore directory recreation", logger); err != nil || !exists { + return nil, err + } + + if err := normalizePBSDatastoreCfg(path, logger); err != nil { + logger.Warning("PBS datastore.cfg normalization failed: %v", err) + } + + logger.Info("Parsing datastore.cfg to recreate datastore directories...") + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open datastore.cfg: %w", err) + } + defer closeIntoErr(&err, file, "close datastore.cfg") + + entries, err = parsePBSDatastoreEntries(file) + if err != nil { + return nil, fmt.Errorf("read datastore.cfg: %w", err) + } + return entries, nil +} + +func configFileExists(path, label, skipReason string, logger *logging.Logger) (bool, error) { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + logger.Debug("No %s found, skipping %s", label, skipReason) + return false, nil + } + return false, fmt.Errorf("stat %s: %w", label, err) + } + return true, nil +} + +func parsePVEStorageEntries(reader io.Reader) ([]pveStorageEntry, error) { + scanner := bufio.NewScanner(reader) + var entries []pveStorageEntry + var current pveStorageEntry + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isIgnoredConfigLine(line) { + continue + } + if entry, ok := parsePVEStorageHeader(line); ok { + current = entry + continue + } + if path, ok := parseConfigPath(line); ok && current.Name != "" { + current.Path = path + entries = append(entries, current) + current = pveStorageEntry{} + } + } + + return entries, scanner.Err() +} + +func parsePBSDatastoreEntries(reader io.Reader) ([]pbsDatastoreEntry, error) { + scanner := bufio.NewScanner(reader) + var entries []pbsDatastoreEntry + var current pbsDatastoreEntry + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isIgnoredConfigLine(line) { + continue + } + if entry, ok := parsePBSDatastoreHeader(line); ok { + current = entry + continue + } + if path, ok := parseConfigPath(line); ok && current.Name != "" { + current.Path = path + entries = append(entries, current) + current = pbsDatastoreEntry{} + } + } + + return entries, scanner.Err() +} + +func isIgnoredConfigLine(line string) bool { + return line == "" || strings.HasPrefix(line, "#") +} + +func parsePVEStorageHeader(line string) (pveStorageEntry, bool) { + if !strings.Contains(line, ":") || strings.Contains(line, "=") { + return pveStorageEntry{}, false + } + parts := strings.Fields(line) + if len(parts) < 2 { + return pveStorageEntry{}, false + } + return pveStorageEntry{ + Type: strings.TrimSuffix(parts[0], ":"), + Name: strings.TrimSuffix(parts[1], ":"), + }, true +} + +func parsePBSDatastoreHeader(line string) (pbsDatastoreEntry, bool) { + if !strings.HasPrefix(line, "datastore:") { + return pbsDatastoreEntry{}, false + } + parts := strings.Fields(line) + if len(parts) < 2 { + return pbsDatastoreEntry{}, false + } + return pbsDatastoreEntry{Name: strings.TrimSuffix(parts[1], ":")}, true +} + +func parseConfigPath(line string) (string, bool) { + if !strings.HasPrefix(line, "path ") { + return "", false + } + parts := strings.Fields(line) + if len(parts) < 2 { + return "", false + } + return parts[1], true +} diff --git a/internal/orchestrator/directory_recreation_ownership.go b/internal/orchestrator/directory_recreation_ownership.go new file mode 100644 index 00000000..2a338479 --- /dev/null +++ b/internal/orchestrator/directory_recreation_ownership.go @@ -0,0 +1,100 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "fmt" + "os" + "os/user" + "strconv" + + "github.com/tis24dev/proxsave/internal/logging" +) + +// setDatastoreOwnership sets ownership to backup:backup for PBS datastores +func setDatastoreOwnership(path string, logger *logging.Logger) error { + if os.Geteuid() != 0 { + logger.Debug("PBS datastore ownership: running as non-root (euid=%d); skipping chown/chmod for %s", os.Geteuid(), path) + return nil + } + + uid, gid, found, err := lookupBackupOwnership(path, logger) + if err != nil || !found { + return err + } + if err := chownDatastorePath(path, uid, gid, logger); err != nil { + return err + } + return ensureDatastoreDirectoryMode(path, logger) +} + +func lookupBackupOwnership(path string, logger *logging.Logger) (int, int, bool, error) { + backupUser, err := user.Lookup("backup") + if err != nil { + logger.Debug("PBS datastore ownership: user 'backup' not found; skipping chown for %s", path) + return 0, 0, false, nil + } + + uid, err := parseBackupUserID("uid", backupUser.Uid) + if err != nil { + return 0, 0, false, err + } + gid, err := parseBackupUserID("gid", backupUser.Gid) + if err != nil { + return 0, 0, false, err + } + return uid, gid, true, nil +} + +func parseBackupUserID(label, value string) (int, error) { + id, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("parse backup %s: %w", label, err) + } + return id, nil +} + +func chownDatastorePath(path string, uid, gid int, logger *logging.Logger) error { + logger.Debug("PBS datastore ownership: chown %s to backup:backup (uid=%d gid=%d)", path, uid, gid) + if err := os.Chown(path, uid, gid); err != nil { + return handleDatastoreOwnershipError("ownership", path, uid, gid, err, logger) + } + return nil +} + +func handleDatastoreOwnershipError(action, path string, uid, gid int, err error, logger *logging.Logger) error { + if isIgnorableOwnershipError(err) { + logger.Warning("PBS datastore %s: unable to chown %s to backup:backup (uid=%d gid=%d): %v (continuing)", action, path, uid, gid, err) + return nil + } + return fmt.Errorf("chown %s: %w", path, err) +} + +func ensureDatastoreDirectoryMode(path string, logger *logging.Logger) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if !info.IsDir() { + return nil + } + + current := info.Mode().Perm() + desired := current | os.FileMode(0o750) + if desired == current { + return nil + } + + logger.Debug("PBS datastore permissions: chmod %s from %o to %o", path, current, desired) + if err := os.Chmod(path, desired); err != nil { + return handleDatastoreModeError(path, current, desired, err, logger) + } + return nil +} + +func handleDatastoreModeError(path string, current, desired os.FileMode, err error, logger *logging.Logger) error { + if isIgnorableOwnershipError(err) { + logger.Warning("PBS datastore permissions: unable to chmod %s from %o to %o: %v (continuing)", path, current, desired, err) + return nil + } + return fmt.Errorf("chmod %s: %w", path, err) +} diff --git a/internal/orchestrator/directory_recreation_paths.go b/internal/orchestrator/directory_recreation_paths.go new file mode 100644 index 00000000..4c377f51 --- /dev/null +++ b/internal/orchestrator/directory_recreation_paths.go @@ -0,0 +1,158 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func computeMissingDirs(target string) ([]string, error) { + path := filepath.Clean(target) + if isTerminalFilesystemPath(path) { + return nil, nil + } + + missing, err := collectMissingDirs(path) + if err != nil { + return nil, err + } + reverseStrings(missing) + return missing, nil +} + +func collectMissingDirs(path string) ([]string, error) { + var missing []string + for !isTerminalFilesystemPath(path) { + exists, err := pathExistsForMissingDirs(path) + if err != nil || exists { + return missing, err + } + missing = append(missing, path) + + parent := filepath.Dir(path) + if parent == path { + break + } + path = parent + } + return missing, nil +} + +func pathExistsForMissingDirs(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func reverseStrings(values []string) { + for i, j := 0, len(values)-1; i < j; i, j = i+1, j-1 { + values[i], values[j] = values[j], values[i] + } +} + +func isTerminalFilesystemPath(path string) bool { + return path == "" || path == "." || path == "/" +} + +func isConfirmableDatastoreMountRoot(path string) bool { + path = filepath.Clean(path) + switch { + case strings.HasPrefix(path, "/mnt/"): + return true + case strings.HasPrefix(path, "/media/"): + return true + case strings.HasPrefix(path, "/run/media/"): + return true + default: + return false + } +} + +func isSuspiciousDatastoreMountLocation(path string) bool { + return isConfirmableDatastoreMountRoot(path) +} + +func isPathOnRootFilesystem(path string) (bool, string, error) { + rootDev, err := deviceID("/") + if err != nil { + return false, "/", err + } + + existing, err := nearestExistingPath(path) + if err != nil { + return false, "", err + } + targetDev, err := deviceID(existing) + if err != nil { + return false, existing, err + } + return rootDev == targetDev, existing, nil +} + +func nearestExistingPath(target string) (string, error) { + path := filepath.Clean(target) + if path == "" || path == "." { + return "", fmt.Errorf("invalid path") + } + + for { + if _, err := os.Stat(path); err == nil { + return path, nil + } else if !os.IsNotExist(err) { + return "", err + } + + parent := filepath.Dir(path) + if parent == path { + return path, nil + } + path = parent + } +} + +func deviceID(path string) (uint64, error) { + info, err := os.Stat(path) + if err != nil { + return 0, err + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok || stat == nil { + return 0, fmt.Errorf("unsupported stat type for %s", path) + } + return uint64(stat.Dev), nil +} + +// isLikelyZFSMountPoint checks if a path is likely a ZFS mount point +func isLikelyZFSMountPoint(path string, logger *logging.Logger) bool { + if _, err := os.Stat(zpoolCachePath); err != nil { + return false + } + + pathLower := strings.ToLower(path) + if isCommonZFSMountPath(pathLower) { + logger.Debug("Path %s matches ZFS mount point pattern", path) + return true + } + return false +} + +func isCommonZFSMountPath(pathLower string) bool { + return strings.HasPrefix(pathLower, "/mnt/") || + strings.Contains(pathLower, "backup") || + strings.Contains(pathLower, "datastore") +} + +func isIgnorableOwnershipError(err error) bool { + return errors.Is(err, syscall.EPERM) || errors.Is(err, syscall.EACCES) || errors.Is(err, syscall.EROFS) +} diff --git a/internal/orchestrator/directory_recreation_pbs.go b/internal/orchestrator/directory_recreation_pbs.go new file mode 100644 index 00000000..cb6b4875 --- /dev/null +++ b/internal/orchestrator/directory_recreation_pbs.go @@ -0,0 +1,230 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/tis24dev/proxsave/internal/logging" +) + +var pbsDatastoreSubdirs = []string{".chunks", ".index"} + +type pbsDatastorePreflight struct { + basePath string + datastoreName string + existingPath string + zfsLikely bool + onRootFS bool + hasData bool + dataUnknown bool + suspiciousMount bool +} + +// createPBSDatastoreStructure creates the directory structure for a PBS datastore. +// It returns true when ProxSave made filesystem changes for this datastore path. +func createPBSDatastoreStructure(basePath, datastoreName string, logger *logging.Logger) (bool, error) { + done := logging.DebugStart(logger, "pbs datastore directory recreation", "datastore=%s path=%s", datastoreName, basePath) + var err error + defer func() { done(err) }() + + zfsLikely := isLikelyZFSMountPoint(basePath, logger) + if shouldSkipMissingZFSMountPoint(basePath, zfsLikely, logger) { + return false, nil + } + + preflight := inspectPBSDatastore(basePath, datastoreName, zfsLikely, logger) + if shouldSkipUnsafePBSDatastore(preflight, logger) { + return false, nil + } + + changed, err := initializePBSDatastore(basePath, datastoreName, logger) + if err != nil { + return false, err + } + return changed, nil +} + +func shouldSkipMissingZFSMountPoint(basePath string, zfsLikely bool, logger *logging.Logger) bool { + if !zfsLikely { + return false + } + _, statErr := os.Stat(basePath) + if statErr == nil { + return false + } + if os.IsNotExist(statErr) { + logger.Warning("PBS datastore preflight: %s looks like a ZFS mountpoint and does not exist yet; skipping directory creation to avoid shadowing a not-yet-imported pool", basePath) + return true + } + logger.Warning("PBS datastore preflight: unable to stat potential ZFS mountpoint %s: %v; skipping any datastore filesystem changes", basePath, statErr) + return true +} + +func inspectPBSDatastore(basePath, datastoreName string, zfsLikely bool, logger *logging.Logger) pbsDatastorePreflight { + preflight := pbsDatastorePreflight{ + basePath: basePath, + datastoreName: datastoreName, + zfsLikely: zfsLikely, + suspiciousMount: isSuspiciousDatastoreMountLocation(basePath) || zfsLikely, + } + + preflight.hasData, preflight.dataUnknown = inspectPBSDatastoreData(basePath, logger) + preflight.onRootFS, preflight.existingPath = inspectPBSDatastoreDevice(basePath, logger) + logPBSDatastorePreflight(preflight, logger) + return preflight +} + +func inspectPBSDatastoreData(basePath string, logger *logging.Logger) (bool, bool) { + hasData, err := pbsDatastoreHasData(basePath) + if err == nil { + return hasData, false + } + logger.Warning("PBS datastore preflight: unable to determine whether %s contains datastore data: %v", basePath, err) + return false, true +} + +func inspectPBSDatastoreDevice(basePath string, logger *logging.Logger) (bool, string) { + onRootFS, existingPath, err := isPathOnRootFilesystem(basePath) + if err == nil { + return onRootFS, existingPath + } + logger.Warning("PBS datastore preflight: unable to determine filesystem device for %s: %v", basePath, err) + return false, existingPath +} + +func logPBSDatastorePreflight(preflight pbsDatastorePreflight, logger *logging.Logger) { + logging.DebugStep( + logger, + "pbs datastore preflight", + "path=%s existing=%s on_rootfs=%t has_data=%t data_unknown=%t", + preflight.basePath, + preflight.existingPath, + preflight.onRootFS, + preflight.hasData, + preflight.dataUnknown, + ) +} + +func shouldSkipUnsafePBSDatastore(preflight pbsDatastorePreflight, logger *logging.Logger) bool { + if shouldSkipRootFilesystemDatastore(preflight, logger) { + return true + } + if shouldSkipUnknownDatastoreData(preflight, logger) { + return true + } + if shouldSkipExistingDatastoreData(preflight, logger) { + return true + } + return shouldSkipUnexpectedDatastoreEntries(preflight.basePath, logger) +} + +func shouldSkipRootFilesystemDatastore(preflight pbsDatastorePreflight, logger *logging.Logger) bool { + if !preflight.onRootFS || !preflight.suspiciousMount || (!preflight.dataUnknown && preflight.hasData) { + return false + } + + logger.Warning("PBS datastore preflight: %s resolves to the root filesystem (mount missing?) — skipping datastore directory initialization to avoid writing to the wrong disk", preflight.basePath) + logger.Info("Mount/import the datastore disk/pool first, then restart PBS services.") + if _, err := os.Stat(zpoolCachePath); err == nil { + logger.Info("ZFS detected: if this datastore was on ZFS, you may need to import the pool first (e.g. `zpool import` then `zpool import `).") + } + return true +} + +func shouldSkipUnknownDatastoreData(preflight pbsDatastorePreflight, logger *logging.Logger) bool { + if !preflight.dataUnknown { + return false + } + logger.Warning("PBS datastore preflight: datastore path inspection failed — skipping any datastore filesystem changes to avoid risking existing data") + return true +} + +func shouldSkipExistingDatastoreData(preflight pbsDatastorePreflight, logger *logging.Logger) bool { + if !preflight.hasData { + return false + } + if warn := validatePBSDatastoreReadOnly(preflight.basePath); warn != "" { + logger.Warning("PBS datastore preflight: %s", warn) + } + logger.Info("PBS datastore preflight: datastore %s appears to contain data; skipping directory/permission changes to avoid risking datastore contents", preflight.datastoreName) + return true +} + +func shouldSkipUnexpectedDatastoreEntries(basePath string, logger *logging.Logger) bool { + unexpected, err := pbsDatastoreHasUnexpectedEntries(basePath) + if err != nil { + logger.Warning("PBS datastore preflight: unable to inspect %s contents: %v; skipping any datastore filesystem changes to avoid risking unrelated data", basePath, err) + return true + } + if unexpected { + logger.Warning("PBS datastore preflight: %s is not empty (unexpected entries present); skipping any datastore filesystem changes to avoid risking unrelated data", basePath) + return true + } + return false +} + +func initializePBSDatastore(basePath, datastoreName string, logger *logging.Logger) (bool, error) { + dirsToFix, err := computeMissingDirs(basePath) + if err != nil { + return false, fmt.Errorf("compute missing dirs: %w", err) + } + + if err := os.MkdirAll(basePath, 0750); err != nil { + return false, fmt.Errorf("create base directory: %w", err) + } + changed := len(dirsToFix) > 0 + + subdirChanged, dirsToFix := ensurePBSDatastoreSubdirs(basePath, dirsToFix, logger) + applyPBSDatastoreOwnership(basePath, datastoreName, dirsToFix, logger) + + lockChanged, lockErr := ensurePBSDatastoreLockFile(basePath, logger) + if lockErr != nil { + logger.Warning("PBS datastore lock file: %v", lockErr) + } + + return changed || subdirChanged || lockChanged, nil +} + +func ensurePBSDatastoreSubdirs(basePath string, dirsToFix []string, logger *logging.Logger) (bool, []string) { + changed := false + for _, subdir := range pbsDatastoreSubdirs { + path := filepath.Join(basePath, subdir) + if isMissingPath(path) { + changed = true + dirsToFix = append(dirsToFix, path) + } + if err := os.MkdirAll(path, 0750); err != nil { + logger.Warning("Failed to create %s: %v", path, err) + } + } + return changed, dirsToFix +} + +func applyPBSDatastoreOwnership(basePath, datastoreName string, dirsToFix []string, logger *logging.Logger) { + if len(dirsToFix) > 0 { + logger.Debug("PBS datastore permissions: applying ownership to %d created path(s) (datastore=%s path=%s)", len(dirsToFix), datastoreName, basePath) + } + baseProcessed := false + cleanBasePath := filepath.Clean(basePath) + for _, dir := range dirsToFix { + if err := setDatastoreOwnership(dir, logger); err != nil { + logger.Warning("Could not set datastore ownership for %s: %v", dir, err) + } + if filepath.Clean(dir) == cleanBasePath { + baseProcessed = true + } + } + if baseProcessed { + return + } + if err := setDatastoreOwnership(basePath, logger); err != nil { + logger.Warning("Could not set datastore ownership for %s: %v", basePath, err) + } +} + +func isMissingPath(path string) bool { + _, err := os.Stat(path) + return os.IsNotExist(err) +} diff --git a/internal/orchestrator/directory_recreation_pbs_config.go b/internal/orchestrator/directory_recreation_pbs_config.go new file mode 100644 index 00000000..72d59e6f --- /dev/null +++ b/internal/orchestrator/directory_recreation_pbs_config.go @@ -0,0 +1,151 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func normalizePBSDatastoreCfg(path string, logger *logging.Logger) error { + raw, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read datastore.cfg: %w", err) + } + + normalized, fixed := normalizePBSDatastoreCfgContent(string(raw)) + if fixed == 0 { + logger.Debug("PBS datastore.cfg: formatting looks OK (no normalization needed)") + return nil + } + + backupPath, err := writePBSDatastoreCfgBackup(raw) + if err != nil { + return fmt.Errorf("write backup copy: %w", err) + } + + mode := datastoreCfgMode(path) + if err := writePBSDatastoreCfgAtomically(path, []byte(normalized), mode); err != nil { + return fmt.Errorf("write normalized datastore.cfg: %w", err) + } + + logger.Warning("PBS datastore.cfg: fixed %d malformed line(s) (properties must be indented); backup saved to %s", fixed, backupPath) + return nil +} + +func writePBSDatastoreCfgBackup(raw []byte) (backupPath string, err error) { + backupDir, err := os.MkdirTemp("/tmp", "proxsave-") + if err != nil { + return "", err + } + removeBackupDir := true + defer func() { + if err != nil && removeBackupDir { + _ = os.RemoveAll(backupDir) + } + }() + + prefix := fmt.Sprintf("datastore.cfg.pre-normalize.%s-", nowRestore().Format("20060102-150405")) + backupFile, err := os.CreateTemp(backupDir, prefix) + if err != nil { + return "", err + } + backupPath = backupFile.Name() + defer func() { + if err != nil { + _ = backupFile.Close() + _ = os.Remove(backupPath) + } + }() + + if err = backupFile.Chmod(0o600); err != nil { + return "", err + } + if _, err = backupFile.Write(raw); err != nil { + return "", err + } + if err = backupFile.Close(); err != nil { + return "", err + } + removeBackupDir = false + return backupPath, nil +} + +func writePBSDatastoreCfgAtomically(path string, data []byte, mode os.FileMode) (err error) { + tmpFile, err := os.CreateTemp(filepath.Dir(path), "datastore.cfg.proxsave-*") + if err != nil { + return err + } + tmpPath := tmpFile.Name() + defer func() { + if err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + } + }() + + if err = tmpFile.Chmod(mode); err != nil { + return err + } + if _, err = tmpFile.Write(data); err != nil { + return err + } + if err = tmpFile.Close(); err != nil { + return err + } + if err = os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("replace datastore.cfg: %w", err) + } + return nil +} + +func datastoreCfgMode(path string) os.FileMode { + if info, err := os.Stat(path); err == nil { + return info.Mode().Perm() + } + return os.FileMode(0o644) +} + +// normalizePBSDatastoreCfgContent expects PBS datastore.cfg content, where the +// only supported top-level sections are datastore blocks. Once a datastore block +// is seen, subsequent non-comment lines are treated as datastore properties. +func normalizePBSDatastoreCfgContent(content string) (string, int) { + lines := strings.Split(content, "\n") + inDatastoreBlock := false + fixed := 0 + + for i, line := range lines { + startsBlock, needsIndent := classifyPBSDatastoreCfgLine(line, inDatastoreBlock) + if startsBlock { + inDatastoreBlock = true + continue + } + if needsIndent { + lines[i] = " " + line + fixed++ + } + } + + return strings.Join(lines, "\n"), fixed +} + +func classifyPBSDatastoreCfgLine(line string, inDatastoreBlock bool) (bool, bool) { + trimmed := strings.TrimSpace(line) + if isIgnoredConfigLine(trimmed) { + return false, false + } + if strings.HasPrefix(trimmed, "datastore:") { + return true, false + } + if !inDatastoreBlock || hasConfigIndent(line) { + return false, false + } + return false, true +} + +func hasConfigIndent(line string) bool { + return strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") +} diff --git a/internal/orchestrator/directory_recreation_pbs_inspect.go b/internal/orchestrator/directory_recreation_pbs_inspect.go new file mode 100644 index 00000000..58ec84a1 --- /dev/null +++ b/internal/orchestrator/directory_recreation_pbs_inspect.go @@ -0,0 +1,167 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "syscall" +) + +var allowedPBSDatastoreScaffoldEntries = map[string]struct{}{ + ".chunks": {}, + ".index": {}, + ".lock": {}, +} + +func validatePBSDatastoreReadOnly(datastorePath string) string { + if datastorePath == "" { + return "datastore path is empty" + } + if warn := validatePBSDatastoreRoot(datastorePath); warn != "" { + return warn + } + if warn := validatePBSDatastoreSubdir(datastorePath, ".chunks"); warn != "" { + return warn + } + if warn := validatePBSDatastoreSubdir(datastorePath, ".index"); warn != "" { + return warn + } + return validatePBSDatastoreLock(datastorePath) +} + +func validatePBSDatastoreRoot(datastorePath string) string { + info, err := os.Stat(datastorePath) + if err != nil { + return fmt.Sprintf("datastore path %s cannot be stat'd: %v", datastorePath, err) + } + if !info.IsDir() { + return fmt.Sprintf("datastore path %s is not a directory (type=%s)", datastorePath, info.Mode()) + } + return "" +} + +func validatePBSDatastoreSubdir(datastorePath, name string) string { + info, err := os.Stat(filepath.Join(datastorePath, name)) + if err != nil { + return fmt.Sprintf("datastore %s missing %s directory: %v", datastorePath, name, err) + } + if !info.IsDir() { + return fmt.Sprintf("datastore %s %s is not a directory (type=%s)", datastorePath, name, info.Mode()) + } + return "" +} + +func validatePBSDatastoreLock(datastorePath string) string { + info, err := os.Stat(filepath.Join(datastorePath, ".lock")) + if err != nil { + return fmt.Sprintf("datastore %s missing .lock file: %v", datastorePath, err) + } + if !info.Mode().IsRegular() { + return fmt.Sprintf("datastore %s .lock is not a regular file (type=%s)", datastorePath, info.Mode()) + } + return "" +} + +func pbsDatastoreHasData(datastorePath string) (bool, error) { + if strings.TrimSpace(datastorePath) == "" { + return false, fmt.Errorf("path is empty") + } + exists, err := existingDirectoryOrNoData(datastorePath) + if err != nil || !exists { + return false, err + } + return anyPBSDatastoreDataDirHasEntries(datastorePath) +} + +func anyPBSDatastoreDataDirHasEntries(datastorePath string) (bool, error) { + for _, subdir := range pbsDatastoreSubdirs { + has, err := dirHasAnyEntry(filepath.Join(datastorePath, subdir)) + if errors.Is(err, os.ErrNotExist) { + continue + } + if err != nil || has { + return has, err + } + } + return false, nil +} + +func pbsDatastoreHasUnexpectedEntries(datastorePath string) (bool, error) { + if strings.TrimSpace(datastorePath) == "" { + return false, nil + } + exists, err := existingDirectoryOrNoData(datastorePath) + if err != nil || !exists { + return false, err + } + return datastoreContainsUnexpectedEntries(datastorePath) +} + +func existingDirectoryOrNoData(path string) (bool, error) { + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) { + return false, nil + } + return false, err + } + return info.IsDir(), nil +} + +func datastoreContainsUnexpectedEntries(datastorePath string) (unexpected bool, err error) { + f, err := os.Open(datastorePath) + if err != nil { + return false, err + } + defer closeIntoErr(&err, f, "close datastore directory") + return readerContainsUnexpectedEntries(f) +} + +func readerContainsUnexpectedEntries(f *os.File) (bool, error) { + for { + names, err := f.Readdirnames(64) + if err != nil { + return handleDatastoreReaddirError(err) + } + if hasUnexpectedDatastoreName(names) { + return true, nil + } + } +} + +func handleDatastoreReaddirError(err error) (bool, error) { + if errors.Is(err, io.EOF) { + return false, nil + } + return false, err +} + +func hasUnexpectedDatastoreName(names []string) bool { + for _, name := range names { + if _, ok := allowedPBSDatastoreScaffoldEntries[name]; !ok { + return true + } + } + return false +} + +func dirHasAnyEntry(path string) (hasEntry bool, err error) { + f, err := os.Open(path) + if err != nil { + return false, err + } + defer closeIntoErr(&err, f, "close directory") + + _, err = f.Readdirnames(1) + if err == nil { + return true, nil + } + if errors.Is(err, io.EOF) { + return false, nil + } + return false, err +} diff --git a/internal/orchestrator/directory_recreation_pbs_lock.go b/internal/orchestrator/directory_recreation_pbs_lock.go new file mode 100644 index 00000000..5d2bcc9d --- /dev/null +++ b/internal/orchestrator/directory_recreation_pbs_lock.go @@ -0,0 +1,95 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func ensurePBSDatastoreLockFile(datastorePath string, logger *logging.Logger) (bool, error) { + lockPath := datastoreLockPath(datastorePath) + info, err := os.Lstat(lockPath) + if err != nil { + return ensureMissingPBSDatastoreLock(lockPath, err, logger) + } + if info.Mode()&os.ModeSymlink != 0 { + return false, fmt.Errorf("%s is a symlink; refusing to manage lock file", lockPath) + } + if info.IsDir() { + return replacePBSDatastoreLockDirectory(lockPath, logger) + } + return chownExistingPBSDatastoreLock(lockPath, logger) +} + +func datastoreLockPath(datastorePath string) string { + return filepath.Join(datastorePath, ".lock") +} + +func ensureMissingPBSDatastoreLock(lockPath string, statErr error, logger *logging.Logger) (bool, error) { + if !os.IsNotExist(statErr) { + return false, fmt.Errorf("stat %s: %w", lockPath, statErr) + } + + logger.Debug("PBS datastore lock: creating %s", lockPath) + if err := createPBSDatastoreLockFile(lockPath); err != nil { + return false, err + } + if err := setDatastoreOwnership(lockPath, logger); err != nil { + return true, fmt.Errorf("chown %s: %w", lockPath, err) + } + return true, nil +} + +func replacePBSDatastoreLockDirectory(lockPath string, logger *logging.Logger) (bool, error) { + changed, err := removeOrRenamePBSDatastoreLockDirectory(lockPath, logger) + if err != nil { + return false, err + } + if err := createPBSDatastoreLockFile(lockPath); err != nil { + return changed, err + } + if err := setDatastoreOwnership(lockPath, logger); err != nil { + return true, fmt.Errorf("chown %s: %w", lockPath, err) + } + return true, nil +} + +func removeOrRenamePBSDatastoreLockDirectory(lockPath string, logger *logging.Logger) (bool, error) { + entries, err := os.ReadDir(lockPath) + if err != nil { + return false, fmt.Errorf("lock path %s is a directory and cannot be read: %w", lockPath, err) + } + if len(entries) == 0 { + logger.Warning("PBS datastore lock: %s is a directory (invalid); removing and recreating as file", lockPath) + if err := os.Remove(lockPath); err != nil { + return false, err + } + return true, nil + } + + backupPath := fmt.Sprintf("%s.proxsave-dir.%s", lockPath, nowRestore().Format("20060102-150405")) + logger.Warning("PBS datastore lock: %s is a non-empty directory (invalid); renaming to %s and creating lock file", lockPath, backupPath) + if err := os.Rename(lockPath, backupPath); err != nil { + return false, fmt.Errorf("rename invalid lock dir %s -> %s: %w", lockPath, backupPath, err) + } + return true, nil +} + +func createPBSDatastoreLockFile(lockPath string) error { + file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o640) + if err != nil { + return fmt.Errorf("create %s: %w", lockPath, err) + } + _ = file.Close() + return nil +} + +func chownExistingPBSDatastoreLock(lockPath string, logger *logging.Logger) (bool, error) { + if err := setDatastoreOwnership(lockPath, logger); err != nil { + return false, fmt.Errorf("chown %s: %w", lockPath, err) + } + return false, nil +} diff --git a/internal/orchestrator/directory_recreation_pve.go b/internal/orchestrator/directory_recreation_pve.go new file mode 100644 index 00000000..48a187de --- /dev/null +++ b/internal/orchestrator/directory_recreation_pve.go @@ -0,0 +1,41 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. +package orchestrator + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/tis24dev/proxsave/internal/logging" +) + +var pveStorageSubdirs = map[string][]string{ + "dir": {"dump", "images", "template", "snippets", "private"}, + "nfs": {"dump", "images", "template"}, + "cifs": {"dump", "images", "template"}, +} + +// createPVEStorageStructure creates the directory structure for a PVE storage +func createPVEStorageStructure(basePath, storageType string, logger *logging.Logger) error { + if err := os.MkdirAll(basePath, 0750); err != nil { + return fmt.Errorf("create base directory: %w", err) + } + + subdirs, ok := pveStorageSubdirs[storageType] + if !ok { + logger.Debug("Storage type %s does not require subdirectories", storageType) + return nil + } + + createStorageSubdirs(basePath, subdirs, logger) + return nil +} + +func createStorageSubdirs(basePath string, subdirs []string, logger *logging.Logger) { + for _, subdir := range subdirs { + path := filepath.Join(basePath, subdir) + if err := os.MkdirAll(path, 0750); err != nil { + logger.Warning("Failed to create %s: %v", path, err) + } + } +} diff --git a/internal/orchestrator/directory_recreation_test.go b/internal/orchestrator/directory_recreation_test.go index 6a4da559..0ea70867 100644 --- a/internal/orchestrator/directory_recreation_test.go +++ b/internal/orchestrator/directory_recreation_test.go @@ -188,79 +188,85 @@ func TestNormalizePBSDatastoreCfgContentNoChangesWhenValid(t *testing.T) { } func TestRecreateDirectoriesFromConfigRoutes(t *testing.T) { + t.Run("PVE", testRecreateDirectoriesFromConfigPVE) + t.Run("PBS", testRecreateDirectoriesFromConfigPBS) + t.Run("Dual", testRecreateDirectoriesFromConfigDual) + t.Run("Unknown", testRecreateDirectoriesFromConfigUnknown) +} + +func testRecreateDirectoriesFromConfigPVE(t *testing.T) { logger := newTestLogger() + baseDir := filepath.Join(t.TempDir(), "local") + cfg := fmt.Sprintf("dir: local\n path %s\n", baseDir) + cfgPath, restore := overridePath(t, &storageCfgPath, "storage.cfg") + t.Cleanup(restore) + writeFile(t, cfgPath, cfg) - t.Run("PVE", func(t *testing.T) { - baseDir := filepath.Join(t.TempDir(), "local") - cfg := fmt.Sprintf("dir: local\n path %s\n", baseDir) - cfgPath, restore := overridePath(t, &storageCfgPath, "storage.cfg") - t.Cleanup(restore) - writeFile(t, cfgPath, cfg) + if err := RecreateDirectoriesFromConfig(SystemTypePVE, logger); err != nil { + t.Fatalf("RecreateDirectoriesFromConfig PVE: %v", err) + } + if _, err := os.Stat(filepath.Join(baseDir, "images")); err != nil { + t.Fatalf("expected PVE directories to be created: %v", err) + } +} - if err := RecreateDirectoriesFromConfig(SystemTypePVE, logger); err != nil { - t.Fatalf("RecreateDirectoriesFromConfig PVE: %v", err) - } - if _, err := os.Stat(filepath.Join(baseDir, "images")); err != nil { - t.Fatalf("expected PVE directories to be created: %v", err) - } - }) - - t.Run("PBS", func(t *testing.T) { - baseDir := filepath.Join(t.TempDir(), "data") - cfg := fmt.Sprintf("datastore: main\n path %s\n", baseDir) - cfgPath, restore := overridePath(t, &datastoreCfgPath, "datastore.cfg") - t.Cleanup(restore) - writeFile(t, cfgPath, cfg) - - cachePath, cacheRestore := overridePath(t, &zpoolCachePath, "zpool.cache") - t.Cleanup(cacheRestore) - if err := os.RemoveAll(cachePath); err != nil && !os.IsNotExist(err) { - t.Fatalf("cleanup cache path: %v", err) - } +func testRecreateDirectoriesFromConfigPBS(t *testing.T) { + logger := newTestLogger() + baseDir := filepath.Join(t.TempDir(), "data") + cfg := fmt.Sprintf("datastore: main\n path %s\n", baseDir) + cfgPath, restore := overridePath(t, &datastoreCfgPath, "datastore.cfg") + t.Cleanup(restore) + writeFile(t, cfgPath, cfg) + removeZpoolCacheForTest(t) - if err := RecreateDirectoriesFromConfig(SystemTypePBS, logger); err != nil { - t.Fatalf("RecreateDirectoriesFromConfig PBS: %v", err) - } - if _, err := os.Stat(filepath.Join(baseDir, ".chunks")); err != nil { - t.Fatalf("expected PBS directories to be created: %v", err) - } - }) - - t.Run("Dual", func(t *testing.T) { - pveBaseDir := filepath.Join(t.TempDir(), "local") - pveCfg := fmt.Sprintf("dir: local\n path %s\n", pveBaseDir) - pveCfgPath, restorePVE := overridePath(t, &storageCfgPath, "storage.cfg") - t.Cleanup(restorePVE) - writeFile(t, pveCfgPath, pveCfg) - - pbsBaseDir := filepath.Join(t.TempDir(), "data") - pbsCfg := fmt.Sprintf("datastore: main\n path %s\n", pbsBaseDir) - pbsCfgPath, restorePBS := overridePath(t, &datastoreCfgPath, "datastore.cfg") - t.Cleanup(restorePBS) - writeFile(t, pbsCfgPath, pbsCfg) - - cachePath, cacheRestore := overridePath(t, &zpoolCachePath, "zpool.cache") - t.Cleanup(cacheRestore) - if err := os.RemoveAll(cachePath); err != nil && !os.IsNotExist(err) { - t.Fatalf("cleanup cache path: %v", err) - } + if err := RecreateDirectoriesFromConfig(SystemTypePBS, logger); err != nil { + t.Fatalf("RecreateDirectoriesFromConfig PBS: %v", err) + } + if _, err := os.Stat(filepath.Join(baseDir, ".chunks")); err != nil { + t.Fatalf("expected PBS directories to be created: %v", err) + } +} - if err := RecreateDirectoriesFromConfig(SystemTypeDual, logger); err != nil { - t.Fatalf("RecreateDirectoriesFromConfig Dual: %v", err) - } - if _, err := os.Stat(filepath.Join(pveBaseDir, "images")); err != nil { - t.Fatalf("expected PVE directories to be created for dual system: %v", err) - } - if _, err := os.Stat(filepath.Join(pbsBaseDir, ".chunks")); err != nil { - t.Fatalf("expected PBS directories to be created for dual system: %v", err) - } - }) +func testRecreateDirectoriesFromConfigDual(t *testing.T) { + logger := newTestLogger() + pveBaseDir := filepath.Join(t.TempDir(), "local") + pveCfg := fmt.Sprintf("dir: local\n path %s\n", pveBaseDir) + pveCfgPath, restorePVE := overridePath(t, &storageCfgPath, "storage.cfg") + t.Cleanup(restorePVE) + writeFile(t, pveCfgPath, pveCfg) - t.Run("Unknown", func(t *testing.T) { - if err := RecreateDirectoriesFromConfig(SystemTypeUnknown, logger); err != nil { - t.Fatalf("RecreateDirectoriesFromConfig unknown: %v", err) - } - }) + pbsBaseDir := filepath.Join(t.TempDir(), "data") + pbsCfg := fmt.Sprintf("datastore: main\n path %s\n", pbsBaseDir) + pbsCfgPath, restorePBS := overridePath(t, &datastoreCfgPath, "datastore.cfg") + t.Cleanup(restorePBS) + writeFile(t, pbsCfgPath, pbsCfg) + removeZpoolCacheForTest(t) + + if err := RecreateDirectoriesFromConfig(SystemTypeDual, logger); err != nil { + t.Fatalf("RecreateDirectoriesFromConfig Dual: %v", err) + } + if _, err := os.Stat(filepath.Join(pveBaseDir, "images")); err != nil { + t.Fatalf("expected PVE directories to be created for dual system: %v", err) + } + if _, err := os.Stat(filepath.Join(pbsBaseDir, ".chunks")); err != nil { + t.Fatalf("expected PBS directories to be created for dual system: %v", err) + } +} + +func testRecreateDirectoriesFromConfigUnknown(t *testing.T) { + logger := newTestLogger() + if err := RecreateDirectoriesFromConfig(SystemTypeUnknown, logger); err != nil { + t.Fatalf("RecreateDirectoriesFromConfig unknown: %v", err) + } +} + +func removeZpoolCacheForTest(t *testing.T) { + t.Helper() + cachePath, cacheRestore := overridePath(t, &zpoolCachePath, "zpool.cache") + t.Cleanup(cacheRestore) + if err := os.RemoveAll(cachePath); err != nil && !os.IsNotExist(err) { + t.Fatalf("cleanup cache path: %v", err) + } } // Test: RecreateStorageDirectories quando il file non esiste @@ -523,7 +529,7 @@ func TestRecreateDirectoriesFromConfigPVEStatError(t *testing.T) { if err := os.MkdirAll(cfgDir, 0o000); err != nil { t.Skipf("cannot create restricted directory: %v", err) } - defer os.Chmod(cfgDir, 0o755) + defer func() { _ = os.Chmod(cfgDir, 0o755) }() cfgPath := filepath.Join(cfgDir, "storage.cfg") prev := storageCfgPath @@ -549,7 +555,7 @@ func TestRecreateDirectoriesFromConfigPBSStatError(t *testing.T) { if err := os.MkdirAll(cfgDir, 0o000); err != nil { t.Skipf("cannot create restricted directory: %v", err) } - defer os.Chmod(cfgDir, 0o755) + defer func() { _ = os.Chmod(cfgDir, 0o755) }() cfgPath := filepath.Join(cfgDir, "datastore.cfg") prev := datastoreCfgPath diff --git a/internal/orchestrator/encryption.go b/internal/orchestrator/encryption.go index 7a642e88..f5171f19 100644 --- a/internal/orchestrator/encryption.go +++ b/internal/orchestrator/encryption.go @@ -297,14 +297,13 @@ func parseRecipientString(value string) (age.Recipient, error) { } } -func readRecipientFile(path string) ([]string, error) { +func readRecipientFile(path string) (recipients []string, err error) { f, err := os.Open(path) if err != nil { return nil, err } - defer f.Close() + defer closeIntoErr(&err, f, "close recipient file") - var recipients []string scanner := bufio.NewScanner(f) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -438,7 +437,7 @@ func writeFileAtomicWithDeps(fs FS, tp TimeProvider, path string, data []byte, p return syncDirectoryWithDeps(fs, filepath.Dir(path)) } -func copyRecipientFileWithDeps(fs FS, src, dest string, perm os.FileMode) error { +func copyRecipientFileWithDeps(fs FS, src, dest string, perm os.FileMode) (err error) { if fs == nil { fs = osFS{} } @@ -447,7 +446,7 @@ func copyRecipientFileWithDeps(fs FS, src, dest string, perm os.FileMode) error if err != nil { return err } - defer in.Close() + defer closeIntoErr(&err, in, "close recipient source file") out, err := fs.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_EXCL, perm) if err != nil { diff --git a/internal/orchestrator/encryption_exported_test.go b/internal/orchestrator/encryption_exported_test.go index 980634d9..6ad2b956 100644 --- a/internal/orchestrator/encryption_exported_test.go +++ b/internal/orchestrator/encryption_exported_test.go @@ -175,16 +175,22 @@ func TestPrepareAgeRecipients_NoRecipientsNonInteractiveErrors(t *testing.T) { if err != nil { t.Fatalf("pipe stdin: %v", err) } + cleanupClose := func(name string, closeFn func() error) { + t.Cleanup(func() { + if err := closeFn(); err != nil { + t.Errorf("close %s: %v", name, err) + } + }) + } + cleanupClose("stdin reader", inR.Close) + cleanupClose("stdin writer", inW.Close) + outR, outW, err := os.Pipe() if err != nil { - inR.Close() - inW.Close() t.Fatalf("pipe stdout: %v", err) } - defer inR.Close() - defer inW.Close() - defer outR.Close() - defer outW.Close() + cleanupClose("stdout reader", outR.Close) + cleanupClose("stdout writer", outW.Close) os.Stdin = inR os.Stdout = outW diff --git a/internal/orchestrator/fs_atomic.go b/internal/orchestrator/fs_atomic.go index 879f0ce3..c4e6f856 100644 --- a/internal/orchestrator/fs_atomic.go +++ b/internal/orchestrator/fs_atomic.go @@ -105,10 +105,7 @@ func ensureDirExistsWithInheritedMeta(dir string) error { var toCreate []string cur := dir - for { - if cur == existing || cur == "" || cur == "." { - break - } + for cur != existing && cur != "" && cur != "." { toCreate = append([]string{cur}, toCreate...) parent := filepath.Dir(cur) if parent == cur { diff --git a/internal/orchestrator/log_parser.go b/internal/orchestrator/log_parser.go index 42220878..5d49aecb 100644 --- a/internal/orchestrator/log_parser.go +++ b/internal/orchestrator/log_parser.go @@ -19,7 +19,7 @@ func ParseLogCounts(logPath string, categoryLimit int) (categories []notify.LogC if err != nil { return nil, 0, 0 } - defer file.Close() + defer func() { _ = file.Close() }() scanner := bufio.NewScanner(file) buf := make([]byte, 0, 64*1024) diff --git a/internal/orchestrator/mount_guard.go b/internal/orchestrator/mount_guard.go index 1cf699f5..836ab1e8 100644 --- a/internal/orchestrator/mount_guard.go +++ b/internal/orchestrator/mount_guard.go @@ -1,3 +1,4 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. package orchestrator import ( @@ -11,8 +12,6 @@ import ( "strings" "syscall" "time" - - "github.com/tis24dev/proxsave/internal/logging" ) const mountGuardBaseDir = "/var/lib/proxsave/guards" @@ -30,242 +29,69 @@ var ( mountGuardParsePBSDatastoreCfg = parsePBSDatastoreCfgBlocks ) -func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot, destRoot string, dryRun bool) error { - if plan == nil || !plan.SystemType.SupportsPBS() || !plan.HasCategoryID("datastore_pbs") { - return nil - } - if strings.TrimSpace(stageRoot) == "" { - return nil - } - if filepath.Clean(strings.TrimSpace(destRoot)) != string(os.PathSeparator) { - if logger != nil { - logger.Debug("Skipping PBS mount guards: restore destination is not system root (dest=%s)", destRoot) - } - return nil - } - - if dryRun { - if logger != nil { - logger.Info("Dry run enabled: skipping PBS mount guards") - } - return nil - } - if !isRealRestoreFS(restoreFS) { - if logger != nil { - logger.Debug("Skipping PBS mount guards: non-system filesystem in use") - } - return nil - } - if mountGuardGeteuid() != 0 { - if logger != nil { - logger.Warning("Skipping PBS mount guards: requires root privileges") - } - return nil - } - - stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") - data, err := restoreFS.ReadFile(stagePath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return fmt.Errorf("read staged datastore.cfg: %w", err) - } - if strings.TrimSpace(string(data)) == "" { - return nil - } - - normalized, _ := normalizePBSDatastoreCfgContent(string(data)) - blocks, err := mountGuardParsePBSDatastoreCfg(normalized) +func guardMountPoint(ctx context.Context, guardTarget string) error { + target, err := normalizeGuardMountRequest(ctx, guardTarget) if err != nil { return err } - if len(blocks) == 0 { + if err := ensureGuardTargetUnmounted(target); err != nil { + return fmt.Errorf("check mount status: %w", err) + } else if isAlreadyMounted(target) { return nil } - var fstabMounts map[string]struct{} - var mountpointCandidates []string - currentFstab := filepath.Join(destRoot, "etc", "fstab") - if mounts, err := mountGuardFstabMountpointsSet(currentFstab); err != nil { - if logger != nil { - logger.Warning("PBS mount guard: unable to parse current fstab %s: %v (continuing without fstab cross-check)", currentFstab, err) - } - } else { - fstabMounts = mounts - for mp := range mounts { - if mp == "" || mp == "." || mp == string(os.PathSeparator) { - continue - } - if !isConfirmableDatastoreMountRoot(mp) { - continue - } - mountpointCandidates = append(mountpointCandidates, mp) - } - sortByLengthDesc(mountpointCandidates) - } - - protected := make(map[string]struct{}) - for _, block := range blocks { - dsPath := filepath.Clean(strings.TrimSpace(block.Path)) - if dsPath == "" || dsPath == "." || dsPath == string(os.PathSeparator) { - continue - } - - guardTarget := "" - if len(mountpointCandidates) > 0 { - guardTarget = firstFstabMountpointMatch(dsPath, mountpointCandidates) - } - if guardTarget == "" { - guardTarget = pbsMountGuardRootForDatastorePath(dsPath) - } - guardTarget = filepath.Clean(strings.TrimSpace(guardTarget)) - if guardTarget == "" || guardTarget == "." || guardTarget == string(os.PathSeparator) { - continue - } - if _, seen := protected[guardTarget]; seen { - continue - } - - // If we can parse /etc/fstab, only guard mountpoints that exist there. - // This avoids making local (rootfs) datastores immutable by mistake. - if fstabMounts != nil { - if _, ok := fstabMounts[guardTarget]; !ok { - continue - } - } - - if err := mountGuardMkdirAll(guardTarget, 0o755); err != nil { - if logger != nil { - logger.Warning("PBS mount guard: unable to create mountpoint directory %s: %v", guardTarget, err) - } - continue - } - - onRootFS, _, devErr := mountGuardIsPathOnRootFilesystem(guardTarget) - if devErr != nil { - if logger != nil { - logger.Warning("PBS mount guard: unable to determine filesystem device for %s: %v", guardTarget, devErr) - } - continue - } - if !onRootFS { - continue - } - - mounted, mountErr := isMounted(guardTarget) - if mountErr != nil && logger != nil { - logger.Warning("PBS mount guard: unable to check mount status for %s: %v (continuing)", guardTarget, mountErr) - } - if mountErr == nil && mounted { - if logger != nil { - logger.Debug("PBS mount guard: mountpoint %s already mounted, skipping guard", guardTarget) - } - continue - } - - // Best-effort attempt to mount now (the entry may have just been restored to /etc/fstab). - // If the storage is online, this avoids applying guards on mountpoints that would mount cleanly. - mountCtx, cancel := context.WithTimeout(ctx, mountGuardMountAttemptTimeout) - out, attemptErr := restoreCmd.Run(mountCtx, "mount", guardTarget) - cancel() - if attemptErr == nil { - onRootFSNow, _, devErrNow := mountGuardIsPathOnRootFilesystem(guardTarget) - if devErrNow == nil && !onRootFSNow { - if logger != nil { - logger.Info("PBS mount guard: mountpoint %s is now mounted (mount attempt succeeded)", guardTarget) - } - continue - } - if mountedNow, mountErrNow := isMounted(guardTarget); mountErrNow == nil && mountedNow { - if logger != nil { - logger.Info("PBS mount guard: mountpoint %s is now mounted (mount attempt succeeded)", guardTarget) - } - continue - } - } else { - if logger != nil { - if errors.Is(mountCtx.Err(), context.DeadlineExceeded) { - logger.Warning("PBS mount guard: mount attempt timed out for %s after %s", guardTarget, mountGuardMountAttemptTimeout) - } else { - trimmed := strings.TrimSpace(string(out)) - if trimmed != "" { - logger.Debug("PBS mount guard: mount attempt failed for %s: %v (output=%s)", guardTarget, attemptErr, trimmed) - } else { - logger.Debug("PBS mount guard: mount attempt failed for %s: %v", guardTarget, attemptErr) - } - } - } - } - - if logger != nil { - logger.Info("PBS mount guard: mountpoint %s offline, applying guard bind mount", guardTarget) - } - - if err := guardMountPoint(ctx, guardTarget); err != nil { - if logger != nil { - logger.Warning("PBS mount guard: failed to bind-mount guard on %s: %v; falling back to chattr +i", guardTarget, err) - } - if _, fallbackErr := restoreCmd.Run(ctx, "chattr", "+i", guardTarget); fallbackErr != nil { - if logger != nil { - logger.Warning("PBS mount guard: failed to set immutable attribute on %s: %v", guardTarget, fallbackErr) - } - continue - } - protected[guardTarget] = struct{}{} - if logger != nil { - logger.Warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — marked immutable (chattr +i) to prevent writes until storage is available", guardTarget) - } - continue - } - - protected[guardTarget] = struct{}{} - if logger != nil { - if entries, err := mountGuardReadDir(guardTarget); err == nil && len(entries) > 0 { - logger.Warning("PBS mount guard: guard mount point %s is not empty (entries=%d)", guardTarget, len(entries)) - } - logger.Warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — bind-mounted a read-only guard to prevent writes until storage is available", guardTarget) - } + guardDir := guardDirForTarget(target) + if err := ensureGuardDirectories(guardDir, target); err != nil { + return err } - - return nil + return bindReadOnlyGuard(guardDir, target) } -func guardMountPoint(ctx context.Context, guardTarget string) error { +func normalizeGuardMountRequest(ctx context.Context, guardTarget string) (string, error) { if ctx == nil { ctx = context.Background() } if err := ctx.Err(); err != nil { - return err + return "", err } - target := filepath.Clean(strings.TrimSpace(guardTarget)) - if target == "" || target == "." || target == string(os.PathSeparator) { - return fmt.Errorf("invalid guard target: %q", guardTarget) + if !isValidGuardTarget(target) { + return "", fmt.Errorf("invalid guard target: %q", guardTarget) } + return target, nil +} +func ensureGuardTargetUnmounted(target string) error { mounted, err := isMounted(target) if err != nil { - return fmt.Errorf("check mount status: %w", err) + return err } if mounted { return nil } + return nil +} - guardDir := guardDirForTarget(target) +func isAlreadyMounted(target string) bool { + mounted, err := isMounted(target) + return err == nil && mounted +} + +func ensureGuardDirectories(guardDir, target string) error { if err := mountGuardMkdirAll(guardDir, 0o755); err != nil { return fmt.Errorf("mkdir guard dir: %w", err) } if err := mountGuardMkdirAll(target, 0o755); err != nil { return fmt.Errorf("mkdir target: %w", err) } + return nil +} - // Bind mount guard directory over the mountpoint to avoid writes to the underlying rootfs path. +func bindReadOnlyGuard(guardDir, target string) error { if err := mountGuardSysMount(guardDir, target, "", syscall.MS_BIND, ""); err != nil { return fmt.Errorf("bind mount guard: %w", err) } - // Make the bind mount read-only to ensure PBS cannot write backup data to the guard directory. remountFlags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_NODEV | syscall.MS_NOSUID | syscall.MS_NOEXEC) if err := mountGuardSysMount("", target, "", remountFlags, ""); err != nil { _ = mountGuardSysUnmount(target, 0) @@ -355,8 +181,6 @@ func isMountedFromProcMounts(path string) (bool, error) { } func unescapeProcPath(s string) string { - // /proc/self/mountinfo uses octal escapes: \040, \011, \012, \134. - // Keep it minimal: decode any \XYZ sequence where XYZ are octal digits and the value fits into a byte (0-255). if !strings.Contains(s, "\\") { return s } @@ -364,31 +188,39 @@ func unescapeProcPath(s string) string { var b strings.Builder b.Grow(len(s)) for i := 0; i < len(s); { - if s[i] != '\\' || i+3 >= len(s) { + if !hasProcOctalEscapeAt(s, i) { _ = b.WriteByte(s[i]) i++ continue } - oct := s[i+1 : i+4] - if oct[0] < '0' || oct[0] > '7' || oct[1] < '0' || oct[1] > '7' || oct[2] < '0' || oct[2] > '7' { - _ = b.WriteByte(s[i]) - i++ - continue - } - - val := (int(oct[0]-'0') << 6) | (int(oct[1]-'0') << 3) | int(oct[2]-'0') - if val > 255 { - _ = b.WriteByte(s[i]) - i++ - continue - } - _ = b.WriteByte(byte(val)) + _ = b.WriteByte(procOctalEscapeValue(s[i+1 : i+4])) i += 4 } return b.String() } +func hasProcOctalEscapeAt(s string, i int) bool { + return i+3 < len(s) && + s[i] == '\\' && + isOctalDigit(s[i+1]) && + isOctalDigit(s[i+2]) && + isOctalDigit(s[i+3]) && + procOctalEscapeInt(s[i+1:i+4]) <= 255 +} + +func isOctalDigit(b byte) bool { + return b >= '0' && b <= '7' +} + +func procOctalEscapeValue(oct string) byte { + return byte(procOctalEscapeInt(oct)) +} + +func procOctalEscapeInt(oct string) int { + return (int(oct[0]-'0') << 6) | (int(oct[1]-'0') << 3) | int(oct[2]-'0') +} + func fstabMountpointsSet(path string) (map[string]struct{}, error) { entries, _, err := parseFstab(path) if err != nil { @@ -464,17 +296,27 @@ func sortByLengthDesc(items []string) { func firstFstabMountpointMatch(datastorePath string, mountpoints []string) string { ds := filepath.Clean(strings.TrimSpace(datastorePath)) - if ds == "" || ds == "." || ds == string(os.PathSeparator) { + if !isValidGuardTarget(ds) { return "" } for _, mp := range mountpoints { - if mp == "" || mp == "." || mp == string(os.PathSeparator) { - continue - } - if ds == mp || strings.HasPrefix(ds, mp+string(os.PathSeparator)) { + if mountpointContainsDatastore(mp, ds) { return mp } } return "" } + +func mountpointContainsDatastore(mountpoint, datastorePath string) bool { + mp := filepath.Clean(strings.TrimSpace(mountpoint)) + if !isValidGuardTarget(mp) { + return false + } + return datastorePath == mp || strings.HasPrefix(datastorePath, mp+string(os.PathSeparator)) +} + +func isValidGuardTarget(path string) bool { + path = filepath.Clean(strings.TrimSpace(path)) + return path != "" && path != "." && path != string(os.PathSeparator) +} diff --git a/internal/orchestrator/mount_guard_apply.go b/internal/orchestrator/mount_guard_apply.go new file mode 100644 index 00000000..f1bdf308 --- /dev/null +++ b/internal/orchestrator/mount_guard_apply.go @@ -0,0 +1,269 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type pbsMountGuardApply struct { + ctx context.Context + logger *logging.Logger + plan *RestorePlan + stageRoot string + destRoot string + dryRun bool + fstabMounts map[string]struct{} + mountpointCandidates []string + protected map[string]struct{} +} + +func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot, destRoot string, dryRun bool) error { + apply := &pbsMountGuardApply{ + ctx: ctx, + logger: logger, + plan: plan, + stageRoot: stageRoot, + destRoot: destRoot, + dryRun: dryRun, + protected: make(map[string]struct{}), + } + return apply.run() +} + +func (a *pbsMountGuardApply) run() error { + if !a.shouldRun() { + return nil + } + + blocks, err := a.stagedDatastoreBlocks() + if err != nil || len(blocks) == 0 { + return err + } + + a.loadFstabMountpoints() + for _, block := range blocks { + a.applyDatastoreBlock(block) + } + return nil +} + +func (a *pbsMountGuardApply) shouldRun() bool { + if a.plan == nil || !a.plan.SystemType.SupportsPBS() || !a.plan.HasCategoryID("datastore_pbs") { + return false + } + if strings.TrimSpace(a.stageRoot) == "" { + return false + } + if filepath.Clean(strings.TrimSpace(a.destRoot)) != string(os.PathSeparator) { + a.debug("Skipping PBS mount guards: restore destination is not system root (dest=%s)", a.destRoot) + return false + } + return a.runtimeAllowsMountGuards() +} + +func (a *pbsMountGuardApply) runtimeAllowsMountGuards() bool { + if a.dryRun { + a.info("Dry run enabled: skipping PBS mount guards") + return false + } + if !isRealRestoreFS(restoreFS) { + a.debug("Skipping PBS mount guards: non-system filesystem in use") + return false + } + if mountGuardGeteuid() != 0 { + a.warning("Skipping PBS mount guards: requires root privileges") + return false + } + return true +} + +func (a *pbsMountGuardApply) stagedDatastoreBlocks() ([]pbsDatastoreBlock, error) { + stagePath := filepath.Join(a.stageRoot, "etc/proxmox-backup/datastore.cfg") + data, err := restoreFS.ReadFile(stagePath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("read staged datastore.cfg: %w", err) + } + if strings.TrimSpace(string(data)) == "" { + return nil, nil + } + + normalized, _ := normalizePBSDatastoreCfgContent(string(data)) + return mountGuardParsePBSDatastoreCfg(normalized) +} + +func (a *pbsMountGuardApply) loadFstabMountpoints() { + currentFstab := filepath.Join(a.destRoot, "etc", "fstab") + mounts, err := mountGuardFstabMountpointsSet(currentFstab) + if err != nil { + a.warning("PBS mount guard: unable to parse current fstab %s: %v (continuing without fstab cross-check)", currentFstab, err) + return + } + + a.fstabMounts = mounts + for mp := range mounts { + if isValidGuardTarget(mp) && isConfirmableDatastoreMountRoot(mp) { + a.mountpointCandidates = append(a.mountpointCandidates, mp) + } + } + sortByLengthDesc(a.mountpointCandidates) +} + +func (a *pbsMountGuardApply) applyDatastoreBlock(block pbsDatastoreBlock) { + dsPath := filepath.Clean(strings.TrimSpace(block.Path)) + if !isValidGuardTarget(dsPath) { + return + } + + guardTarget := a.guardTargetForDatastore(dsPath) + if !a.shouldProtectTarget(guardTarget) { + return + } + if !a.prepareOfflineGuardTarget(guardTarget) { + return + } + if a.mountAttemptSucceeded(guardTarget) { + return + } + a.protectOfflineTarget(guardTarget) +} + +func (a *pbsMountGuardApply) guardTargetForDatastore(dsPath string) string { + guardTarget := "" + if len(a.mountpointCandidates) > 0 { + guardTarget = firstFstabMountpointMatch(dsPath, a.mountpointCandidates) + } + if guardTarget == "" { + guardTarget = pbsMountGuardRootForDatastorePath(dsPath) + } + return filepath.Clean(strings.TrimSpace(guardTarget)) +} + +func (a *pbsMountGuardApply) shouldProtectTarget(guardTarget string) bool { + if !isValidGuardTarget(guardTarget) { + return false + } + if _, seen := a.protected[guardTarget]; seen { + return false + } + if a.fstabMounts == nil { + return true + } + _, ok := a.fstabMounts[guardTarget] + return ok +} + +func (a *pbsMountGuardApply) prepareOfflineGuardTarget(guardTarget string) bool { + if err := mountGuardMkdirAll(guardTarget, 0o755); err != nil { + a.warning("PBS mount guard: unable to create mountpoint directory %s: %v", guardTarget, err) + return false + } + + onRootFS, _, err := mountGuardIsPathOnRootFilesystem(guardTarget) + if err != nil { + a.warning("PBS mount guard: unable to determine filesystem device for %s: %v", guardTarget, err) + return false + } + if !onRootFS { + return false + } + + mounted, err := isMounted(guardTarget) + if err != nil { + a.warning("PBS mount guard: mount status probe for %s is inconclusive: %v (skipping guard)", guardTarget, err) + return false + } + if mounted { + a.debug("PBS mount guard: mountpoint %s already mounted, skipping guard", guardTarget) + return false + } + return true +} + +func (a *pbsMountGuardApply) mountAttemptSucceeded(guardTarget string) bool { + mountCtx, cancel := context.WithTimeout(a.ctx, mountGuardMountAttemptTimeout) + out, err := restoreCmd.Run(mountCtx, "mount", guardTarget) + cancel() + if err != nil { + a.logMountAttemptFailure(mountCtx, guardTarget, out, err) + return false + } + if a.targetMovedOffRootFS(guardTarget) || a.targetIsMounted(guardTarget) { + a.info("PBS mount guard: mountpoint %s is now mounted (mount attempt succeeded)", guardTarget) + return true + } + return false +} + +func (a *pbsMountGuardApply) targetMovedOffRootFS(guardTarget string) bool { + onRootFS, _, err := mountGuardIsPathOnRootFilesystem(guardTarget) + return err == nil && !onRootFS +} + +func (a *pbsMountGuardApply) targetIsMounted(guardTarget string) bool { + mounted, err := isMounted(guardTarget) + return err == nil && mounted +} + +func (a *pbsMountGuardApply) logMountAttemptFailure(mountCtx context.Context, guardTarget string, out []byte, err error) { + if errors.Is(mountCtx.Err(), context.DeadlineExceeded) { + a.warning("PBS mount guard: mount attempt timed out for %s after %s", guardTarget, mountGuardMountAttemptTimeout) + return + } + if trimmed := strings.TrimSpace(string(out)); trimmed != "" { + a.debug("PBS mount guard: mount attempt failed for %s: %v (output=%s)", guardTarget, err, trimmed) + return + } + a.debug("PBS mount guard: mount attempt failed for %s: %v", guardTarget, err) +} + +func (a *pbsMountGuardApply) protectOfflineTarget(guardTarget string) { + a.info("PBS mount guard: mountpoint %s offline, applying guard bind mount", guardTarget) + if err := guardMountPoint(a.ctx, guardTarget); err != nil { + a.protectOfflineTargetWithChattr(guardTarget, err) + return + } + + a.protected[guardTarget] = struct{}{} + if entries, err := mountGuardReadDir(guardTarget); err == nil && len(entries) > 0 { + a.warning("PBS mount guard: guard mount point %s is not empty (entries=%d)", guardTarget, len(entries)) + } + a.warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — bind-mounted a read-only guard to prevent writes until storage is available", guardTarget) +} + +func (a *pbsMountGuardApply) protectOfflineTargetWithChattr(guardTarget string, bindErr error) { + a.warning("PBS mount guard: failed to bind-mount guard on %s: %v; falling back to chattr +i", guardTarget, bindErr) + if _, err := restoreCmd.Run(a.ctx, "chattr", "+i", guardTarget); err != nil { + a.warning("PBS mount guard: failed to set immutable attribute on %s: %v", guardTarget, err) + return + } + a.protected[guardTarget] = struct{}{} + a.warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — marked immutable (chattr +i) to prevent writes until storage is available", guardTarget) +} + +func (a *pbsMountGuardApply) debug(format string, args ...interface{}) { + if a.logger != nil { + a.logger.Debug(format, args...) + } +} + +func (a *pbsMountGuardApply) info(format string, args ...interface{}) { + if a.logger != nil { + a.logger.Info(format, args...) + } +} + +func (a *pbsMountGuardApply) warning(format string, args ...interface{}) { + if a.logger != nil { + a.logger.Warning(format, args...) + } +} diff --git a/internal/orchestrator/mount_guard_more_test.go b/internal/orchestrator/mount_guard_more_test.go index 67719ac7..cf3e7470 100644 --- a/internal/orchestrator/mount_guard_more_test.go +++ b/internal/orchestrator/mount_guard_more_test.go @@ -111,7 +111,7 @@ func TestSortByLengthDesc(t *testing.T) { if len(items) != 3 { t.Fatalf("unexpected len: %d", len(items)) } - if !(len(items[0]) >= len(items[1]) && len(items[1]) >= len(items[2])) { + if len(items[0]) < len(items[1]) || len(items[1]) < len(items[2]) { t.Fatalf("expected non-increasing lengths, got %#v", items) } } @@ -269,7 +269,7 @@ func TestGuardMountPoint(t *testing.T) { return nil } - if err := guardMountPoint(nil, "/mnt/nilctx"); err != nil { + if err := guardMountPoint(nil, "/mnt/nilctx"); err != nil { //nolint:staticcheck // Verifies the documented nil context fallback. t.Fatalf("unexpected error: %v", err) } }) @@ -726,14 +726,14 @@ func TestMaybeApplyPBSDatastoreMountGuards_FullFlow(t *testing.T) { buildMountinfo := func() string { var b strings.Builder for mp := range mountedTargets { - b.WriteString(fmt.Sprintf("1 2 3:4 / %s rw - ext4 /dev/sda1 rw\n", mp)) + fmt.Fprintf(&b, "1 2 3:4 / %s rw - ext4 /dev/sda1 rw\n", mp) } return b.String() } buildProcMounts := func() string { var b strings.Builder for mp := range mountedTargets { - b.WriteString(fmt.Sprintf("/dev/sda1 %s ext4 rw 0 0\n", mp)) + fmt.Fprintf(&b, "/dev/sda1 %s ext4 rw 0 0\n", mp) } return b.String() } diff --git a/internal/orchestrator/network_apply_countdown_test.go b/internal/orchestrator/network_apply_countdown_test.go index 37a460fd..7fd60ce3 100644 --- a/internal/orchestrator/network_apply_countdown_test.go +++ b/internal/orchestrator/network_apply_countdown_test.go @@ -57,8 +57,8 @@ func TestPromptNetworkCommitWithCountdown_NonCommitInputReturnsFalse(test *testi func TestPromptNetworkCommitWithCountdown_TimeoutReturnsDeadlineExceeded(test *testing.T) { pipeReader, pipeWriter := io.Pipe() - defer pipeReader.Close() - defer pipeWriter.Close() + defer func() { _ = pipeReader.Close() }() + defer func() { _ = pipeWriter.Close() }() reader := bufio.NewReader(pipeReader) logger := logging.New(types.LogLevelInfo, false) diff --git a/internal/orchestrator/network_apply_preflight_rollback_test.go b/internal/orchestrator/network_apply_preflight_rollback_test.go index dcfbaaf0..2a53a2f1 100644 --- a/internal/orchestrator/network_apply_preflight_rollback_test.go +++ b/internal/orchestrator/network_apply_preflight_rollback_test.go @@ -11,6 +11,16 @@ import ( ) func TestApplyNetworkWithRollbackWithUI_RollsBackFilesOnPreflightFailure(t *testing.T) { + fake := setupNetworkPreflightRollbackTest(t) + err := runNetworkPreflightRollbackFailure(t) + if err == nil || !strings.Contains(err.Error(), "network preflight validation failed") { + t.Fatalf("expected preflight error, got %v", err) + } + assertNetworkPreflightRollbackCalls(t, fake.CallsList()) +} + +func setupNetworkPreflightRollbackTest(t *testing.T) *FakeCommandRunner { + t.Helper() origFS := restoreFS origCmd := restoreCmd origTime := restoreTime @@ -28,18 +38,30 @@ func TestApplyNetworkWithRollbackWithUI_RollsBackFilesOnPreflightFailure(t *test restoreTime = &FakeTime{Current: time.Date(2026, 1, 18, 13, 47, 6, 0, time.UTC)} networkDiagnosticsSequence = 0 + installNetworkPreflightRollbackTools(t) + fake := newNetworkPreflightRollbackRunner() + restoreCmd = fake + return fake +} + +func installNetworkPreflightRollbackTools(t *testing.T) { + t.Helper() pathDir := t.TempDir() - ifqueryPath := filepath.Join(pathDir, "ifquery") - if err := os.WriteFile(ifqueryPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { - t.Fatalf("write ifquery: %v", err) - } - ifupPath := filepath.Join(pathDir, "ifup") - if err := os.WriteFile(ifupPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { - t.Fatalf("write ifup: %v", err) - } + writeExecutableTestTool(t, pathDir, "ifquery") + writeExecutableTestTool(t, pathDir, "ifup") t.Setenv("PATH", pathDir+string(os.PathListSeparator)+os.Getenv("PATH")) +} - fake := &FakeCommandRunner{ +func writeExecutableTestTool(t *testing.T, pathDir, name string) { + t.Helper() + toolPath := filepath.Join(pathDir, name) + if err := os.WriteFile(toolPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write %s: %v", name, err) + } +} + +func newNetworkPreflightRollbackRunner() *FakeCommandRunner { + return &FakeCommandRunner{ Outputs: map[string][]byte{ "ip route show default": []byte("default via 192.168.1.1 dev nic1\n"), "ifquery --check -a": []byte("ifquery check output\n"), @@ -49,31 +71,32 @@ func TestApplyNetworkWithRollbackWithUI_RollsBackFilesOnPreflightFailure(t *test "ifup -n -a": fmt.Errorf("exit 1"), }, } - restoreCmd = fake +} +func runNetworkPreflightRollbackFailure(t *testing.T) error { + t.Helper() logger := newTestLogger() rollbackBackup := "/tmp/proxsave/network_rollback_backup_20260118_134651.tar.gz" ui := &fakeRestoreWorkflowUI{confirmAction: true} - err := applyNetworkWithRollbackWithUI( + return applyNetworkWithRollbackWithUI( context.Background(), ui, logger, - rollbackBackup, - rollbackBackup, - "", - "", - defaultNetworkRollbackTimeout, - SystemTypePBS, - false, + networkRollbackUIApplyRequest{ + rollbackBackupPath: rollbackBackup, + networkRollbackPath: rollbackBackup, + timeout: defaultNetworkRollbackTimeout, + systemType: SystemTypePBS, + }, ) - if err == nil || !strings.Contains(err.Error(), "network preflight validation failed") { - t.Fatalf("expected preflight error, got %v", err) - } +} +func assertNetworkPreflightRollbackCalls(t *testing.T, calls []string) { + t.Helper() foundIfupPreflight := false foundRollbackSh := false - for _, call := range fake.CallsList() { + for _, call := range calls { if call == "ifup -n -a" { foundIfupPreflight = true } @@ -82,9 +105,9 @@ func TestApplyNetworkWithRollbackWithUI_RollsBackFilesOnPreflightFailure(t *test } } if !foundIfupPreflight { - t.Fatalf("expected ifup preflight to run; calls=%#v", fake.CallsList()) + t.Fatalf("expected ifup preflight to run; calls=%#v", calls) } if !foundRollbackSh { - t.Fatalf("expected rollback script to be invoked via sh; calls=%#v", fake.CallsList()) + t.Fatalf("expected rollback script to be invoked via sh; calls=%#v", calls) } } diff --git a/internal/orchestrator/network_apply_workflow_ui.go b/internal/orchestrator/network_apply_workflow_ui.go index d33db5a1..97898ae9 100644 --- a/internal/orchestrator/network_apply_workflow_ui.go +++ b/internal/orchestrator/network_apply_workflow_ui.go @@ -1,509 +1,16 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. package orchestrator import ( "context" "errors" "fmt" - "os" "strings" "time" "github.com/tis24dev/proxsave/internal/input" - "github.com/tis24dev/proxsave/internal/logging" ) -func maybeApplyNetworkConfigWithUI(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, plan *RestorePlan, safetyBackup, networkRollbackBackup *SafetyBackupResult, stageRoot, archivePath string, dryRun bool) (err error) { - if !shouldAttemptNetworkApply(plan) { - if logger != nil { - logger.Debug("Network safe apply (UI): skipped (network category not selected)") - } - return nil - } - done := logging.DebugStart(logger, "network safe apply (ui)", "dryRun=%v euid=%d stage=%s archive=%s", dryRun, os.Geteuid(), strings.TrimSpace(stageRoot), strings.TrimSpace(archivePath)) - defer func() { done(err) }() - - if ui == nil { - return fmt.Errorf("restore UI not available") - } - if !isRealRestoreFS(restoreFS) { - logger.Debug("Skipping live network apply: non-system filesystem in use") - return nil - } - if dryRun { - logger.Info("Dry run enabled: skipping live network apply") - return nil - } - if os.Geteuid() != 0 { - logger.Warning("Skipping live network apply: requires root privileges") - return nil - } - - logging.DebugStep(logger, "network safe apply (ui)", "Resolve rollback backup paths") - networkRollbackPath := "" - if networkRollbackBackup != nil { - networkRollbackPath = strings.TrimSpace(networkRollbackBackup.BackupPath) - } - fullRollbackPath := "" - if safetyBackup != nil { - fullRollbackPath = strings.TrimSpace(safetyBackup.BackupPath) - } - logging.DebugStep(logger, "network safe apply (ui)", "Rollback backup resolved: network=%q full=%q", networkRollbackPath, fullRollbackPath) - - if networkRollbackPath == "" && fullRollbackPath == "" { - logger.Warning("Skipping live network apply: rollback backup not available") - if strings.TrimSpace(stageRoot) != "" { - logger.Info("Network configuration is staged; skipping NIC repair/apply due to missing rollback backup.") - return nil - } - - repairNow, err := ui.ConfirmAction( - ctx, - "NIC name repair (recommended)", - "Attempt NIC name repair in restored network config files now (no reload)?\n\nThis will only rewrite /etc/network/interfaces and /etc/network/interfaces.d/* when safe mappings are found.", - "Repair now", - "Skip repair", - 0, - false, - ) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: repairNow=%v", repairNow) - if repairNow { - if repair, err := ui.RepairNICNames(ctx, archivePath); err != nil { - return err - } else if repair != nil && strings.TrimSpace(repair.Summary()) != "" { - _ = ui.ShowMessage(ctx, "NIC repair result", repair.Summary()) - } - } - - logger.Info("Skipping live network apply (you can reboot or apply manually later).") - return nil - } - - logging.DebugStep(logger, "network safe apply (ui)", "Prompt: apply network now with rollback timer") - sourceLine := "Source: /etc/network (will be applied)" - if strings.TrimSpace(stageRoot) != "" { - sourceLine = fmt.Sprintf("Source: %s (will be copied to /etc and applied)", strings.TrimSpace(stageRoot)) - } - message := fmt.Sprintf( - "Network restore: a restored network configuration is ready to apply.\n%s\n\nThis will reload networking immediately (no reboot).\n\nWARNING: This may change the active IP and disconnect SSH/Web sessions.\n\nAfter applying, type COMMIT within %ds or ProxSave will roll back automatically.\n\nRecommendation: run this step from the local console/IPMI, not over SSH.\n\nApply network configuration now?", - sourceLine, - int(defaultNetworkRollbackTimeout.Seconds()), - ) - applyNow, err := ui.ConfirmAction(ctx, "Apply network configuration", message, "Apply now", "Skip apply", 90*time.Second, false) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: applyNow=%v", applyNow) - if !applyNow { - if strings.TrimSpace(stageRoot) == "" { - repairNow, err := ui.ConfirmAction( - ctx, - "NIC name repair (recommended)", - "Attempt NIC name repair in restored network config files now (no reload)?\n\nThis will only rewrite /etc/network/interfaces and /etc/network/interfaces.d/* when safe mappings are found.", - "Repair now", - "Skip repair", - 0, - false, - ) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: repairNow=%v", repairNow) - if repairNow { - if repair, err := ui.RepairNICNames(ctx, archivePath); err != nil { - return err - } else if repair != nil && strings.TrimSpace(repair.Summary()) != "" { - _ = ui.ShowMessage(ctx, "NIC repair result", repair.Summary()) - } - } - } else { - logger.Info("Network configuration is staged (not yet written to /etc); skipping NIC repair prompt.") - } - logger.Info("Skipping live network apply (you can apply later).") - return nil - } - - rollbackPath := networkRollbackPath - if rollbackPath == "" { - logging.DebugStep(logger, "network safe apply (ui)", "Prompt: network-only rollback missing; allow full rollback backup fallback") - ok, err := ui.ConfirmAction( - ctx, - "Network-only rollback not available", - "Network-only rollback backup is not available.\n\nIf you proceed, the rollback timer will use the full safety backup, which may revert other restored categories.\n\nProceed anyway?", - "Proceed with full rollback", - "Skip apply", - 0, - false, - ) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: allowFullRollback=%v", ok) - if !ok { - if strings.TrimSpace(stageRoot) == "" { - repairNow, err := ui.ConfirmAction( - ctx, - "NIC name repair (recommended)", - "Attempt NIC name repair in restored network config files now (no reload)?\n\nThis will only rewrite /etc/network/interfaces and /etc/network/interfaces.d/* when safe mappings are found.", - "Repair now", - "Skip repair", - 0, - false, - ) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: repairNow=%v", repairNow) - if repairNow { - if repair, err := ui.RepairNICNames(ctx, archivePath); err != nil { - return err - } else if repair != nil && strings.TrimSpace(repair.Summary()) != "" { - _ = ui.ShowMessage(ctx, "NIC repair result", repair.Summary()) - } - } - } - logger.Info("Skipping live network apply (you can reboot or apply manually later).") - return nil - } - rollbackPath = fullRollbackPath - } - logging.DebugStep(logger, "network safe apply (ui)", "Selected rollback backup: %s", rollbackPath) - - systemType := SystemTypeUnknown - suppressPVEChecks := false - if plan != nil { - systemType = plan.SystemType - // In cluster RECOVERY restores, PVE services are intentionally stopped and /etc/pve is unmounted - // until the end of the workflow. PVE UI (8006) and corosync/quorum checks are not meaningful here. - suppressPVEChecks = plan.SystemType.SupportsPVE() && plan.NeedsClusterRestore - } - return applyNetworkWithRollbackWithUI(ctx, ui, logger, rollbackPath, networkRollbackPath, stageRoot, archivePath, defaultNetworkRollbackTimeout, systemType, suppressPVEChecks) -} - -func applyNetworkWithRollbackWithUI(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, rollbackBackupPath, networkRollbackPath, stageRoot, archivePath string, timeout time.Duration, systemType SystemType, suppressPVEChecks bool) (err error) { - done := logging.DebugStart( - logger, - "network safe apply (ui)", - "rollbackBackup=%s networkRollback=%s timeout=%s systemType=%s stage=%s suppressPVEChecks=%v", - strings.TrimSpace(rollbackBackupPath), - strings.TrimSpace(networkRollbackPath), - timeout, - systemType, - strings.TrimSpace(stageRoot), - suppressPVEChecks, - ) - defer func() { done(err) }() - - if ui == nil { - return fmt.Errorf("restore UI not available") - } - - logging.DebugStep(logger, "network safe apply (ui)", "Create diagnostics directory") - diagnosticsDir, err := createNetworkDiagnosticsDir() - if err != nil { - logger.Warning("Network diagnostics disabled: %v", err) - diagnosticsDir = "" - } else { - logger.Info("Network diagnostics directory: %s", diagnosticsDir) - } - - logging.DebugStep(logger, "network safe apply (ui)", "Detect management interface (SSH/default route)") - iface, source := detectManagementInterface(ctx, logger) - if iface != "" { - logger.Info("Detected management interface: %s (%s)", iface, source) - } - - if diagnosticsDir != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Capture network snapshot (before)") - if snap, err := writeNetworkSnapshot(ctx, logger, diagnosticsDir, "before", 3*time.Second); err != nil { - logger.Debug("Network snapshot before apply failed: %v", err) - } else { - logger.Debug("Network snapshot (before): %s", snap) - } - - logging.DebugStep(logger, "network safe apply (ui)", "Run baseline health checks (before)") - healthBefore := runNetworkHealthChecks(ctx, networkHealthOptions{ - SystemType: systemType, - Logger: logger, - CommandTimeout: 3 * time.Second, - EnableGatewayPing: false, - ForceSSHRouteCheck: false, - EnableDNSResolve: false, - }) - if path, err := writeNetworkHealthReportFileNamed(diagnosticsDir, "health_before.txt", healthBefore); err != nil { - logger.Debug("Failed to write network health (before) report: %v", err) - } else { - logger.Debug("Network health (before) report: %s", path) - } - } - - if strings.TrimSpace(stageRoot) != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Apply staged network files to system paths (before NIC repair)") - applied, err := applyNetworkFilesFromStage(logger, stageRoot) - if err != nil { - return err - } - if len(applied) > 0 { - logging.DebugStep(logger, "network safe apply (ui)", "Staged network files written: %d", len(applied)) - } - } - - logging.DebugStep(logger, "network safe apply (ui)", "NIC name repair (optional)") - var nicRepair *nicRepairResult - if repair, err := ui.RepairNICNames(ctx, archivePath); err != nil { - logger.Warning("NIC repair failed: %v", err) - } else { - nicRepair = repair - if nicRepair != nil { - if nicRepair.Applied() || nicRepair.SkippedReason != "" { - logger.Info("%s", nicRepair.Summary()) - } else { - logger.Debug("%s", nicRepair.Summary()) - } - } - } - - if strings.TrimSpace(iface) != "" { - if cur, err := currentNetworkEndpoint(ctx, iface, 2*time.Second); err == nil { - if tgt, err := targetNetworkEndpointFromConfig(iface); err == nil { - logger.Info("Network plan: %s -> %s", cur.summary(), tgt.summary()) - } - } - } - - if diagnosticsDir != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Write network plan (current -> target)") - if planText, err := buildNetworkPlanReport(ctx, iface, source, 2*time.Second); err != nil { - logger.Debug("Network plan build failed: %v", err) - } else if strings.TrimSpace(planText) != "" { - if path, err := writeNetworkTextReportFile(diagnosticsDir, "plan.txt", planText+"\n"); err != nil { - logger.Debug("Network plan write failed: %v", err) - } else { - logger.Debug("Network plan: %s", path) - } - } - - logging.DebugStep(logger, "network safe apply (ui)", "Run ifquery diagnostic (pre-apply)") - ifqueryPre := runNetworkIfqueryDiagnostic(ctx, 5*time.Second, logger) - if !ifqueryPre.Skipped { - if path, err := writeNetworkIfqueryDiagnosticReportFile(diagnosticsDir, "ifquery_pre_apply.txt", ifqueryPre); err != nil { - logger.Debug("Failed to write ifquery (pre-apply) report: %v", err) - } else { - logger.Debug("ifquery (pre-apply) report: %s", path) - } - } - } - - logging.DebugStep(logger, "network safe apply (ui)", "Network preflight validation (ifupdown/ifupdown2)") - preflight := runNetworkPreflightValidation(ctx, 5*time.Second, logger) - if diagnosticsDir != "" { - if path, err := writeNetworkPreflightReportFile(diagnosticsDir, preflight); err != nil { - logger.Debug("Failed to write network preflight report: %v", err) - } else { - logger.Debug("Network preflight report: %s", path) - } - } - if !preflight.Ok() { - message := preflight.Summary() - if diagnosticsDir != "" { - message += "\n\nDiagnostics saved under:\n" + diagnosticsDir - } - if out := strings.TrimSpace(preflight.Output); out != "" { - message += "\n\nOutput:\n" + out - } - - if strings.TrimSpace(stageRoot) != "" && strings.TrimSpace(networkRollbackPath) != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Preflight failed in staged mode: rolling back network files automatically") - rollbackLog, rbErr := rollbackNetworkFilesNow(ctx, logger, networkRollbackPath, diagnosticsDir) - if strings.TrimSpace(rollbackLog) != "" { - logger.Info("Network rollback log: %s", rollbackLog) - } - if rbErr != nil { - logger.Error("Network apply aborted: preflight validation failed (%s) and rollback failed: %v", preflight.CommandLine(), rbErr) - return fmt.Errorf("network preflight validation failed; rollback attempt failed: %w", rbErr) - } - if diagnosticsDir != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Capture network snapshot (after rollback)") - if snap, err := writeNetworkSnapshot(ctx, logger, diagnosticsDir, "after_rollback", 3*time.Second); err != nil { - logger.Debug("Network snapshot after rollback failed: %v", err) - } else { - logger.Debug("Network snapshot (after rollback): %s", snap) - } - logging.DebugStep(logger, "network safe apply (ui)", "Run ifquery diagnostic (after rollback)") - ifqueryAfterRollback := runNetworkIfqueryDiagnostic(ctx, 5*time.Second, logger) - if !ifqueryAfterRollback.Skipped { - if path, err := writeNetworkIfqueryDiagnosticReportFile(diagnosticsDir, "ifquery_after_rollback.txt", ifqueryAfterRollback); err != nil { - logger.Debug("Failed to write ifquery (after rollback) report: %v", err) - } else { - logger.Debug("ifquery (after rollback) report: %s", path) - } - } - } - logger.Warning( - "Network apply aborted: preflight validation failed (%s). Rolled back /etc/network/*, /etc/hosts, /etc/hostname, /etc/resolv.conf to the pre-restore state (rollback=%s).", - preflight.CommandLine(), - strings.TrimSpace(networkRollbackPath), - ) - _ = ui.ShowError(ctx, "Network preflight failed", "Network configuration failed preflight and was rolled back automatically.") - return fmt.Errorf("network preflight validation failed; network files rolled back") - } - - if !preflight.Skipped && preflight.ExitError != nil && strings.TrimSpace(networkRollbackPath) != "" { - message += "\n\nRollback restored network config files to the pre-restore configuration now? (recommended)" - rollbackNow, err := ui.ConfirmAction(ctx, "Network preflight failed", message, "Rollback now", "Keep restored files", 0, true) - if err != nil { - return err - } - logging.DebugStep(logger, "network safe apply (ui)", "User choice: rollbackNow=%v", rollbackNow) - if rollbackNow { - logging.DebugStep(logger, "network safe apply (ui)", "Rollback network files now (backup=%s)", strings.TrimSpace(networkRollbackPath)) - rollbackLog, rbErr := rollbackNetworkFilesNow(ctx, logger, networkRollbackPath, diagnosticsDir) - if strings.TrimSpace(rollbackLog) != "" { - logger.Info("Network rollback log: %s", rollbackLog) - } - if rbErr != nil { - logger.Warning("Network rollback failed: %v", rbErr) - return fmt.Errorf("network preflight validation failed; rollback attempt failed: %w", rbErr) - } - logger.Warning("Network files rolled back to pre-restore configuration due to preflight failure") - return fmt.Errorf("network preflight validation failed; network files rolled back") - } - } - return fmt.Errorf("network preflight validation failed; aborting live network apply") - } - - logging.DebugStep(logger, "network safe apply (ui)", "Arm rollback timer BEFORE applying changes") - handle, err := armNetworkRollback(ctx, logger, rollbackBackupPath, timeout, diagnosticsDir) - if err != nil { - return err - } - - logging.DebugStep(logger, "network safe apply (ui)", "Apply network configuration now") - if err := applyNetworkConfig(ctx, logger); err != nil { - logger.Warning("Network apply failed: %v", err) - return err - } - - if diagnosticsDir != "" { - logging.DebugStep(logger, "network safe apply (ui)", "Capture network snapshot (after)") - if snap, err := writeNetworkSnapshot(ctx, logger, diagnosticsDir, "after", 3*time.Second); err != nil { - logger.Debug("Network snapshot after apply failed: %v", err) - } else { - logger.Debug("Network snapshot (after): %s", snap) - } - - logging.DebugStep(logger, "network safe apply (ui)", "Run ifquery diagnostic (post-apply)") - ifqueryPost := runNetworkIfqueryDiagnostic(ctx, 5*time.Second, logger) - if !ifqueryPost.Skipped { - if path, err := writeNetworkIfqueryDiagnosticReportFile(diagnosticsDir, "ifquery_post_apply.txt", ifqueryPost); err != nil { - logger.Debug("Failed to write ifquery (post-apply) report: %v", err) - } else { - logger.Debug("ifquery (post-apply) report: %s", path) - } - } - } - - logging.DebugStep(logger, "network safe apply (ui)", "Run post-apply health checks") - healthOptions := networkHealthOptions{ - SystemType: systemType, - Logger: logger, - CommandTimeout: 3 * time.Second, - EnableGatewayPing: true, - ForceSSHRouteCheck: false, - EnableDNSResolve: true, - LocalPortChecks: defaultNetworkPortChecks(systemType), - } - if suppressPVEChecks { - healthOptions.SystemType = SystemTypeUnknown - healthOptions.LocalPortChecks = nil - } - health := runNetworkHealthChecks(ctx, healthOptions) - if suppressPVEChecks { - health.add("PVE service checks", networkHealthOK, "skipped (cluster database restore in progress; services will be restarted after restore completes)") - } - logNetworkHealthReport(logger, health) - if diagnosticsDir != "" { - if path, err := writeNetworkHealthReportFile(diagnosticsDir, health); err != nil { - logger.Debug("Failed to write network health report: %v", err) - } else { - logger.Debug("Network health report: %s", path) - } - } - - remaining := handle.remaining(time.Now()) - if remaining <= 0 { - logger.Warning("Rollback window already expired; leaving rollback armed") - return nil - } - - logging.DebugStep(logger, "network safe apply (ui)", "Wait for COMMIT (rollback in %ds)", int(remaining.Seconds())) - committed, commitErr := ui.PromptNetworkCommit(ctx, remaining, health, nicRepair, diagnosticsDir) - if commitErr != nil { - logger.Warning("Commit prompt error: %v", commitErr) - return buildNetworkApplyNotCommittedError(ctx, logger, iface, handle) - } - logging.DebugStep(logger, "network safe apply (ui)", "User commit result: committed=%v", committed) - if committed { - if rollbackAlreadyRunning(ctx, logger, handle) { - logger.Warning("Commit received too late: rollback already running. Network configuration NOT committed.") - return buildNetworkApplyNotCommittedError(ctx, logger, iface, handle) - } - disarmNetworkRollback(ctx, logger, handle) - logger.Info("Network configuration committed successfully.") - return nil - } - - // Not committed: keep rollback ARMED. - notCommittedErr := buildNetworkApplyNotCommittedError(ctx, logger, iface, handle) - if strings.TrimSpace(diagnosticsDir) != "" { - rollbackState := "Rollback is ARMED and will run automatically." - if notCommittedErr != nil && !notCommittedErr.RollbackArmed { - rollbackState = "Rollback has executed (or marker cleared)." - } - - observed := "unknown" - original := "unknown" - if notCommittedErr != nil { - if v := strings.TrimSpace(notCommittedErr.RestoredIP); v != "" { - observed = v - } - if v := strings.TrimSpace(notCommittedErr.OriginalIP); v != "" { - original = v - } - } - - reconnectHost := "" - if original != "" && original != "unknown" { - reconnectHost = original - if i := strings.Index(reconnectHost, ","); i >= 0 { - reconnectHost = reconnectHost[:i] - } - if i := strings.Index(reconnectHost, "/"); i >= 0 { - reconnectHost = reconnectHost[:i] - } - reconnectHost = strings.TrimSpace(reconnectHost) - } - - var b strings.Builder - b.WriteString("Network configuration not committed.\n\n") - b.WriteString(rollbackState + "\n\n") - b.WriteString(fmt.Sprintf("IP now (after apply): %s\n", observed)) - if original != "unknown" { - b.WriteString(fmt.Sprintf("Expected after rollback: %s\n", original)) - } - if reconnectHost != "" && reconnectHost != "unknown" { - b.WriteString(fmt.Sprintf("Reconnect using: %s\n", reconnectHost)) - } - b.WriteString("\nDiagnostics saved under:\n") - b.WriteString(strings.TrimSpace(diagnosticsDir)) - - _ = ui.ShowMessage(ctx, "Network rollback", b.String()) - } - return notCommittedErr -} - func (c *cliWorkflowUI) ConfirmAction(ctx context.Context, title, message, yesLabel, noLabel string, timeout time.Duration, defaultYes bool) (bool, error) { _ = yesLabel _ = noLabel diff --git a/internal/orchestrator/network_apply_workflow_ui_prompt.go b/internal/orchestrator/network_apply_workflow_ui_prompt.go new file mode 100644 index 00000000..7dc50aa6 --- /dev/null +++ b/internal/orchestrator/network_apply_workflow_ui_prompt.go @@ -0,0 +1,272 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type networkConfigUIApplyFlow struct { + ctx context.Context + ui RestoreWorkflowUI + logger *logging.Logger + plan *RestorePlan + safetyBackup *SafetyBackupResult + networkRollbackBackup *SafetyBackupResult + stageRoot string + archivePath string + dryRun bool + networkRollbackPath string + fullRollbackPath string +} + +type networkConfigUIApplyRequest struct { + plan *RestorePlan + safetyBackup *SafetyBackupResult + networkRollbackBackup *SafetyBackupResult + stageRoot string + archivePath string + dryRun bool +} + +func maybeApplyNetworkConfigWithUI(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, req networkConfigUIApplyRequest) (err error) { + if !shouldAttemptNetworkApply(req.plan) { + if logger != nil { + logger.Debug("Network safe apply (UI): skipped (network category not selected)") + } + return nil + } + done := logging.DebugStart(logger, "network safe apply (ui)", "dryRun=%v euid=%d stage=%s archive=%s", req.dryRun, os.Geteuid(), strings.TrimSpace(req.stageRoot), strings.TrimSpace(req.archivePath)) + defer func() { done(err) }() + + flow := &networkConfigUIApplyFlow{ + ctx: ctx, + ui: ui, + logger: logger, + plan: req.plan, + safetyBackup: req.safetyBackup, + networkRollbackBackup: req.networkRollbackBackup, + stageRoot: req.stageRoot, + archivePath: req.archivePath, + dryRun: req.dryRun, + } + return flow.run() +} + +func (f *networkConfigUIApplyFlow) run() error { + if err := f.validateRuntime(); err != nil { + return normalizeNetworkApplyRuntimeError(err) + } + f.resolveRollbackPaths() + if f.networkRollbackPath == "" && f.fullRollbackPath == "" { + return f.handleMissingRollbackBackup() + } + return f.confirmAndRunNetworkApply() +} + +func normalizeNetworkApplyRuntimeError(err error) error { + if errors.Is(err, errNetworkApplySkipped) { + return nil + } + return err +} + +func (f *networkConfigUIApplyFlow) confirmAndRunNetworkApply() error { + applyNow, err := f.confirmApplyNow() + if err != nil { + return err + } + if !applyNow { + return f.handleApplySkipped() + } + return f.runConfirmedNetworkApply() +} + +func (f *networkConfigUIApplyFlow) runConfirmedNetworkApply() error { + rollbackPath, err := f.selectRollbackPath() + if err != nil || rollbackPath == "" { + return err + } + systemType, suppressPVEChecks := f.networkApplyOptions() + return applyNetworkWithRollbackWithUI(f.ctx, f.ui, f.logger, networkRollbackUIApplyRequest{ + rollbackBackupPath: rollbackPath, + networkRollbackPath: f.networkRollbackPath, + stageRoot: f.stageRoot, + archivePath: f.archivePath, + timeout: defaultNetworkRollbackTimeout, + systemType: systemType, + suppressPVEChecks: suppressPVEChecks, + }) +} + +func (f *networkConfigUIApplyFlow) validateRuntime() error { + if f.ui == nil { + return fmt.Errorf("restore UI not available") + } + if !isRealRestoreFS(restoreFS) { + f.debug("Skipping live network apply: non-system filesystem in use") + return errNetworkApplySkipped + } + if f.dryRun { + f.info("Dry run enabled: skipping live network apply") + return errNetworkApplySkipped + } + if os.Geteuid() != 0 { + f.warning("Skipping live network apply: requires root privileges") + return errNetworkApplySkipped + } + return nil +} + +func (f *networkConfigUIApplyFlow) resolveRollbackPaths() { + logging.DebugStep(f.logger, "network safe apply (ui)", "Resolve rollback backup paths") + if f.networkRollbackBackup != nil { + f.networkRollbackPath = strings.TrimSpace(f.networkRollbackBackup.BackupPath) + } + if f.safetyBackup != nil { + f.fullRollbackPath = strings.TrimSpace(f.safetyBackup.BackupPath) + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Rollback backup resolved: network=%q full=%q", f.networkRollbackPath, f.fullRollbackPath) +} + +func (f *networkConfigUIApplyFlow) handleMissingRollbackBackup() error { + f.warning("Skipping live network apply: rollback backup not available") + if strings.TrimSpace(f.stageRoot) != "" { + f.info("Network configuration is staged; skipping NIC repair/apply due to missing rollback backup.") + return nil + } + if err := f.promptNICRepair(); err != nil { + return err + } + f.info("Skipping live network apply (you can reboot or apply manually later).") + return nil +} + +func (f *networkConfigUIApplyFlow) confirmApplyNow() (bool, error) { + logging.DebugStep(f.logger, "network safe apply (ui)", "Prompt: apply network now with rollback timer") + sourceLine := "Source: /etc/network (will be applied)" + if strings.TrimSpace(f.stageRoot) != "" { + sourceLine = fmt.Sprintf("Source: %s (will be copied to /etc and applied)", strings.TrimSpace(f.stageRoot)) + } + message := fmt.Sprintf( + "Network restore: a restored network configuration is ready to apply.\n%s\n\nThis will reload networking immediately (no reboot).\n\nWARNING: This may change the active IP and disconnect SSH/Web sessions.\n\nAfter applying, type COMMIT within %ds or ProxSave will roll back automatically.\n\nRecommendation: run this step from the local console/IPMI, not over SSH.\n\nApply network configuration now?", + sourceLine, + int(defaultNetworkRollbackTimeout.Seconds()), + ) + applyNow, err := f.ui.ConfirmAction(f.ctx, "Apply network configuration", message, "Apply now", "Skip apply", 90*time.Second, false) + logging.DebugStep(f.logger, "network safe apply (ui)", "User choice: applyNow=%v", applyNow) + return applyNow, err +} + +func (f *networkConfigUIApplyFlow) handleApplySkipped() error { + if strings.TrimSpace(f.stageRoot) == "" { + if err := f.promptNICRepair(); err != nil { + return err + } + } else { + f.info("Network configuration is staged (not yet written to /etc); skipping NIC repair prompt.") + } + f.info("Skipping live network apply (you can apply later).") + return nil +} + +func (f *networkConfigUIApplyFlow) selectRollbackPath() (string, error) { + if f.networkRollbackPath != "" { + logging.DebugStep(f.logger, "network safe apply (ui)", "Selected rollback backup: %s", f.networkRollbackPath) + return f.networkRollbackPath, nil + } + + ok, err := f.confirmFullRollbackFallback() + if err != nil || !ok { + return f.handleFullRollbackFallbackDeclined(err) + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Selected rollback backup: %s", f.fullRollbackPath) + return f.fullRollbackPath, nil +} + +func (f *networkConfigUIApplyFlow) confirmFullRollbackFallback() (bool, error) { + logging.DebugStep(f.logger, "network safe apply (ui)", "Prompt: network-only rollback missing; allow full rollback backup fallback") + ok, err := f.ui.ConfirmAction( + f.ctx, + "Network-only rollback not available", + "Network-only rollback backup is not available.\n\nIf you proceed, the rollback timer will use the full safety backup, which may revert other restored categories.\n\nProceed anyway?", + "Proceed with full rollback", + "Skip apply", + 0, + false, + ) + logging.DebugStep(f.logger, "network safe apply (ui)", "User choice: allowFullRollback=%v", ok) + return ok, err +} + +func (f *networkConfigUIApplyFlow) handleFullRollbackFallbackDeclined(err error) (string, error) { + if err != nil { + return "", err + } + if strings.TrimSpace(f.stageRoot) == "" { + if repairErr := f.promptNICRepair(); repairErr != nil { + return "", repairErr + } + } + f.info("Skipping live network apply (you can reboot or apply manually later).") + return "", nil +} + +func (f *networkConfigUIApplyFlow) promptNICRepair() error { + repairNow, err := f.ui.ConfirmAction( + f.ctx, + "NIC name repair (recommended)", + "Attempt NIC name repair in restored network config files now (no reload)?\n\nThis will only rewrite /etc/network/interfaces and /etc/network/interfaces.d/* when safe mappings are found.", + "Repair now", + "Skip repair", + 0, + false, + ) + if err != nil { + return err + } + logging.DebugStep(f.logger, "network safe apply (ui)", "User choice: repairNow=%v", repairNow) + if !repairNow { + return nil + } + + repair, err := f.ui.RepairNICNames(f.ctx, f.archivePath) + if err != nil { + return err + } + if repair != nil && strings.TrimSpace(repair.Summary()) != "" { + _ = f.ui.ShowMessage(f.ctx, "NIC repair result", repair.Summary()) + } + return nil +} + +func (f *networkConfigUIApplyFlow) networkApplyOptions() (SystemType, bool) { + if f.plan == nil { + return SystemTypeUnknown, false + } + return f.plan.SystemType, f.plan.SystemType.SupportsPVE() && f.plan.NeedsClusterRestore +} + +func (f *networkConfigUIApplyFlow) debug(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Debug(format, args...) + } +} + +func (f *networkConfigUIApplyFlow) info(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Info(format, args...) + } +} + +func (f *networkConfigUIApplyFlow) warning(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Warning(format, args...) + } +} diff --git a/internal/orchestrator/network_apply_workflow_ui_rollback.go b/internal/orchestrator/network_apply_workflow_ui_rollback.go new file mode 100644 index 00000000..b28ddc8e --- /dev/null +++ b/internal/orchestrator/network_apply_workflow_ui_rollback.go @@ -0,0 +1,509 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/tis24dev/proxsave/internal/logging" +) + +var errNetworkApplySkipped = fmt.Errorf("network apply skipped") + +type networkRollbackUIApplyFlow struct { + ctx context.Context + ui RestoreWorkflowUI + logger *logging.Logger + rollbackBackupPath string + networkRollbackPath string + stageRoot string + archivePath string + timeout time.Duration + systemType SystemType + suppressPVEChecks bool + diagnosticsDir string + iface string + source string + nicRepair *nicRepairResult + handle *networkRollbackHandle + health networkHealthReport +} + +type networkRollbackUIApplyRequest struct { + rollbackBackupPath string + networkRollbackPath string + stageRoot string + archivePath string + timeout time.Duration + systemType SystemType + suppressPVEChecks bool +} + +func applyNetworkWithRollbackWithUI(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, req networkRollbackUIApplyRequest) (err error) { + done := logging.DebugStart( + logger, + "network safe apply (ui)", + "rollbackBackup=%s networkRollback=%s timeout=%s systemType=%s stage=%s suppressPVEChecks=%v", + strings.TrimSpace(req.rollbackBackupPath), + strings.TrimSpace(req.networkRollbackPath), + req.timeout, + req.systemType, + strings.TrimSpace(req.stageRoot), + req.suppressPVEChecks, + ) + defer func() { done(err) }() + + flow := &networkRollbackUIApplyFlow{ + ctx: ctx, + ui: ui, + logger: logger, + rollbackBackupPath: req.rollbackBackupPath, + networkRollbackPath: req.networkRollbackPath, + stageRoot: req.stageRoot, + archivePath: req.archivePath, + timeout: req.timeout, + systemType: req.systemType, + suppressPVEChecks: req.suppressPVEChecks, + } + return flow.run() +} + +func (f *networkRollbackUIApplyFlow) run() error { + if f.ui == nil { + return fmt.Errorf("restore UI not available") + } + f.createDiagnosticsDir() + f.detectManagementInterface() + f.captureBeforeDiagnostics() + if err := f.applyStagedNetworkFiles(); err != nil { + return err + } + f.repairNICNames() + f.logNetworkPlan() + f.writePreApplyDiagnostics() + if err := f.validatePreflight(); err != nil { + return err + } + if err := f.armRollbackAndApply(); err != nil { + return err + } + f.writePostApplyDiagnostics() + f.runPostApplyHealthChecks() + return f.waitForCommit() +} + +func (f *networkRollbackUIApplyFlow) createDiagnosticsDir() { + logging.DebugStep(f.logger, "network safe apply (ui)", "Create diagnostics directory") + dir, err := createNetworkDiagnosticsDir() + if err != nil { + f.warning("Network diagnostics disabled: %v", err) + return + } + f.diagnosticsDir = dir + f.info("Network diagnostics directory: %s", dir) +} + +func (f *networkRollbackUIApplyFlow) detectManagementInterface() { + logging.DebugStep(f.logger, "network safe apply (ui)", "Detect management interface (SSH/default route)") + f.iface, f.source = detectManagementInterface(f.ctx, f.logger) + if f.iface != "" { + f.info("Detected management interface: %s (%s)", f.iface, f.source) + } +} + +func (f *networkRollbackUIApplyFlow) captureBeforeDiagnostics() { + if f.diagnosticsDir == "" { + return + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Capture network snapshot (before)") + if snap, err := writeNetworkSnapshot(f.ctx, f.logger, f.diagnosticsDir, "before", 3*time.Second); err != nil { + f.debug("Network snapshot before apply failed: %v", err) + } else { + f.debug("Network snapshot (before): %s", snap) + } + + logging.DebugStep(f.logger, "network safe apply (ui)", "Run baseline health checks (before)") + healthBefore := runNetworkHealthChecks(f.ctx, networkHealthOptions{ + SystemType: f.systemType, + Logger: f.logger, + CommandTimeout: 3 * time.Second, + EnableGatewayPing: false, + ForceSSHRouteCheck: false, + EnableDNSResolve: false, + }) + if path, err := writeNetworkHealthReportFileNamed(f.diagnosticsDir, "health_before.txt", healthBefore); err != nil { + f.debug("Failed to write network health (before) report: %v", err) + } else { + f.debug("Network health (before) report: %s", path) + } +} + +func (f *networkRollbackUIApplyFlow) applyStagedNetworkFiles() error { + if strings.TrimSpace(f.stageRoot) == "" { + return nil + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Apply staged network files to system paths (before NIC repair)") + applied, err := applyNetworkFilesFromStage(f.logger, f.stageRoot) + if err != nil { + return err + } + if len(applied) > 0 { + logging.DebugStep(f.logger, "network safe apply (ui)", "Staged network files written: %d", len(applied)) + } + return nil +} + +func (f *networkRollbackUIApplyFlow) repairNICNames() { + logging.DebugStep(f.logger, "network safe apply (ui)", "NIC name repair (optional)") + repair, err := f.ui.RepairNICNames(f.ctx, f.archivePath) + if err != nil { + f.warning("NIC repair failed: %v", err) + return + } + f.nicRepair = repair + if repair == nil { + return + } + if repair.Applied() || repair.SkippedReason != "" { + f.info("%s", repair.Summary()) + return + } + f.debug("%s", repair.Summary()) +} + +func (f *networkRollbackUIApplyFlow) logNetworkPlan() { + if strings.TrimSpace(f.iface) == "" { + return + } + cur, curErr := currentNetworkEndpoint(f.ctx, f.iface, 2*time.Second) + tgt, tgtErr := targetNetworkEndpointFromConfig(f.iface) + if curErr == nil && tgtErr == nil { + f.info("Network plan: %s -> %s", cur.summary(), tgt.summary()) + } +} + +func (f *networkRollbackUIApplyFlow) writePreApplyDiagnostics() { + if f.diagnosticsDir == "" { + return + } + f.writeNetworkPlanReport() + f.writeIfqueryDiagnostic("Run ifquery diagnostic (pre-apply)", "ifquery_pre_apply.txt", "pre-apply") +} + +func (f *networkRollbackUIApplyFlow) writeNetworkPlanReport() { + logging.DebugStep(f.logger, "network safe apply (ui)", "Write network plan (current -> target)") + planText, err := buildNetworkPlanReport(f.ctx, f.iface, f.source, 2*time.Second) + if err != nil { + f.debug("Network plan build failed: %v", err) + return + } + if strings.TrimSpace(planText) == "" { + return + } + if path, err := writeNetworkTextReportFile(f.diagnosticsDir, "plan.txt", planText+"\n"); err != nil { + f.debug("Network plan write failed: %v", err) + } else { + f.debug("Network plan: %s", path) + } +} + +func (f *networkRollbackUIApplyFlow) writeIfqueryDiagnostic(step, filename, label string) { + logging.DebugStep(f.logger, "network safe apply (ui)", "%s", step) + result := runNetworkIfqueryDiagnostic(f.ctx, 5*time.Second, f.logger) + if result.Skipped { + return + } + if path, err := writeNetworkIfqueryDiagnosticReportFile(f.diagnosticsDir, filename, result); err != nil { + f.debug("Failed to write ifquery (%s) report: %v", label, err) + } else { + f.debug("ifquery (%s) report: %s", label, path) + } +} + +func (f *networkRollbackUIApplyFlow) validatePreflight() error { + logging.DebugStep(f.logger, "network safe apply (ui)", "Network preflight validation (ifupdown/ifupdown2)") + preflight := runNetworkPreflightValidation(f.ctx, 5*time.Second, f.logger) + f.writePreflightReport(preflight) + if preflight.Ok() { + return nil + } + return f.handlePreflightFailure(preflight) +} + +func (f *networkRollbackUIApplyFlow) writePreflightReport(preflight networkPreflightResult) { + if f.diagnosticsDir == "" { + return + } + if path, err := writeNetworkPreflightReportFile(f.diagnosticsDir, preflight); err != nil { + f.debug("Failed to write network preflight report: %v", err) + } else { + f.debug("Network preflight report: %s", path) + } +} + +func (f *networkRollbackUIApplyFlow) handlePreflightFailure(preflight networkPreflightResult) error { + message := f.preflightFailureMessage(preflight) + if strings.TrimSpace(f.stageRoot) != "" && strings.TrimSpace(f.networkRollbackPath) != "" { + return f.rollbackStagedPreflightFailure(preflight) + } + if f.canAskPreflightRollback(preflight) { + return f.confirmPreflightRollback(message) + } + return fmt.Errorf("network preflight validation failed; aborting live network apply") +} + +func (f *networkRollbackUIApplyFlow) preflightFailureMessage(preflight networkPreflightResult) string { + message := preflight.Summary() + if f.diagnosticsDir != "" { + message += "\n\nDiagnostics saved under:\n" + f.diagnosticsDir + } + if out := strings.TrimSpace(preflight.Output); out != "" { + message += "\n\nOutput:\n" + out + } + return message +} + +func (f *networkRollbackUIApplyFlow) rollbackStagedPreflightFailure(preflight networkPreflightResult) error { + logging.DebugStep(f.logger, "network safe apply (ui)", "Preflight failed in staged mode: rolling back network files automatically") + rollbackLog, rbErr := rollbackNetworkFilesNow(f.ctx, f.logger, f.networkRollbackPath, f.diagnosticsDir) + if strings.TrimSpace(rollbackLog) != "" { + f.info("Network rollback log: %s", rollbackLog) + } + if rbErr != nil { + f.error("Network apply aborted: preflight validation failed (%s) and rollback failed: %v", preflight.CommandLine(), rbErr) + return fmt.Errorf("network preflight validation failed; rollback attempt failed: %w", rbErr) + } + f.captureAfterRollbackDiagnostics() + f.warning( + "Network apply aborted: preflight validation failed (%s). Rolled back /etc/network/*, /etc/hosts, /etc/hostname, /etc/resolv.conf to the pre-restore state (rollback=%s).", + preflight.CommandLine(), + strings.TrimSpace(f.networkRollbackPath), + ) + _ = f.ui.ShowError(f.ctx, "Network preflight failed", "Network configuration failed preflight and was rolled back automatically.") + return fmt.Errorf("network preflight validation failed; network files rolled back") +} + +func (f *networkRollbackUIApplyFlow) captureAfterRollbackDiagnostics() { + if f.diagnosticsDir == "" { + return + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Capture network snapshot (after rollback)") + if snap, err := writeNetworkSnapshot(f.ctx, f.logger, f.diagnosticsDir, "after_rollback", 3*time.Second); err != nil { + f.debug("Network snapshot after rollback failed: %v", err) + } else { + f.debug("Network snapshot (after rollback): %s", snap) + } + f.writeIfqueryDiagnostic("Run ifquery diagnostic (after rollback)", "ifquery_after_rollback.txt", "after rollback") +} + +func (f *networkRollbackUIApplyFlow) canAskPreflightRollback(preflight networkPreflightResult) bool { + return !preflight.Skipped && preflight.ExitError != nil && strings.TrimSpace(f.networkRollbackPath) != "" +} + +func (f *networkRollbackUIApplyFlow) confirmPreflightRollback(message string) error { + message += "\n\nRollback restored network config files to the pre-restore configuration now? (recommended)" + rollbackNow, err := f.ui.ConfirmAction(f.ctx, "Network preflight failed", message, "Rollback now", "Keep restored files", 0, true) + if err != nil { + return err + } + logging.DebugStep(f.logger, "network safe apply (ui)", "User choice: rollbackNow=%v", rollbackNow) + if !rollbackNow { + return fmt.Errorf("network preflight validation failed; aborting live network apply") + } + return f.rollbackPreflightFailureNow() +} + +func (f *networkRollbackUIApplyFlow) rollbackPreflightFailureNow() error { + logging.DebugStep(f.logger, "network safe apply (ui)", "Rollback network files now (backup=%s)", strings.TrimSpace(f.networkRollbackPath)) + rollbackLog, rbErr := rollbackNetworkFilesNow(f.ctx, f.logger, f.networkRollbackPath, f.diagnosticsDir) + if strings.TrimSpace(rollbackLog) != "" { + f.info("Network rollback log: %s", rollbackLog) + } + if rbErr != nil { + f.warning("Network rollback failed: %v", rbErr) + return fmt.Errorf("network preflight validation failed; rollback attempt failed: %w", rbErr) + } + f.warning("Network files rolled back to pre-restore configuration due to preflight failure") + return fmt.Errorf("network preflight validation failed; network files rolled back") +} + +func (f *networkRollbackUIApplyFlow) armRollbackAndApply() error { + logging.DebugStep(f.logger, "network safe apply (ui)", "Arm rollback timer BEFORE applying changes") + handle, err := armNetworkRollback(f.ctx, f.logger, f.rollbackBackupPath, f.timeout, f.diagnosticsDir) + if err != nil { + return err + } + f.handle = handle + + logging.DebugStep(f.logger, "network safe apply (ui)", "Apply network configuration now") + if err := applyNetworkConfig(f.ctx, f.logger); err != nil { + f.warning("Network apply failed: %v", err) + return err + } + return nil +} + +func (f *networkRollbackUIApplyFlow) writePostApplyDiagnostics() { + if f.diagnosticsDir == "" { + return + } + logging.DebugStep(f.logger, "network safe apply (ui)", "Capture network snapshot (after)") + if snap, err := writeNetworkSnapshot(f.ctx, f.logger, f.diagnosticsDir, "after", 3*time.Second); err != nil { + f.debug("Network snapshot after apply failed: %v", err) + } else { + f.debug("Network snapshot (after): %s", snap) + } + f.writeIfqueryDiagnostic("Run ifquery diagnostic (post-apply)", "ifquery_post_apply.txt", "post-apply") +} + +func (f *networkRollbackUIApplyFlow) runPostApplyHealthChecks() { + logging.DebugStep(f.logger, "network safe apply (ui)", "Run post-apply health checks") + healthOptions := networkHealthOptions{ + SystemType: f.systemType, + Logger: f.logger, + CommandTimeout: 3 * time.Second, + EnableGatewayPing: true, + ForceSSHRouteCheck: false, + EnableDNSResolve: true, + LocalPortChecks: defaultNetworkPortChecks(f.systemType), + } + if f.suppressPVEChecks { + healthOptions.SystemType = SystemTypeUnknown + healthOptions.LocalPortChecks = nil + } + f.health = runNetworkHealthChecks(f.ctx, healthOptions) + if f.suppressPVEChecks { + f.health.add("PVE service checks", networkHealthOK, "skipped (cluster database restore in progress; services will be restarted after restore completes)") + } + logNetworkHealthReport(f.logger, f.health) + if f.diagnosticsDir == "" { + return + } + if path, err := writeNetworkHealthReportFile(f.diagnosticsDir, f.health); err != nil { + f.debug("Failed to write network health report: %v", err) + } else { + f.debug("Network health report: %s", path) + } +} + +func (f *networkRollbackUIApplyFlow) waitForCommit() error { + remaining := f.handle.remaining(time.Now()) + if remaining <= 0 { + f.warning("Rollback window already expired; leaving rollback armed") + return nil + } + + logging.DebugStep(f.logger, "network safe apply (ui)", "Wait for COMMIT (rollback in %ds)", int(remaining.Seconds())) + committed, commitErr := f.ui.PromptNetworkCommit(f.ctx, remaining, f.health, f.nicRepair, f.diagnosticsDir) + if commitErr != nil { + f.warning("Commit prompt error: %v", commitErr) + return buildNetworkApplyNotCommittedError(f.ctx, f.logger, f.iface, f.handle) + } + logging.DebugStep(f.logger, "network safe apply (ui)", "User commit result: committed=%v", committed) + if committed { + return f.commitNetworkConfig() + } + return f.handleNetworkNotCommitted() +} + +func (f *networkRollbackUIApplyFlow) commitNetworkConfig() error { + if rollbackAlreadyRunning(f.ctx, f.logger, f.handle) { + f.warning("Commit received too late: rollback already running. Network configuration NOT committed.") + return buildNetworkApplyNotCommittedError(f.ctx, f.logger, f.iface, f.handle) + } + disarmNetworkRollback(f.ctx, f.logger, f.handle) + f.info("Network configuration committed successfully.") + return nil +} + +func (f *networkRollbackUIApplyFlow) handleNetworkNotCommitted() error { + notCommittedErr := buildNetworkApplyNotCommittedError(f.ctx, f.logger, f.iface, f.handle) + f.showNetworkNotCommittedMessage(notCommittedErr) + return notCommittedErr +} + +func (f *networkRollbackUIApplyFlow) showNetworkNotCommittedMessage(notCommittedErr *NetworkApplyNotCommittedError) { + if strings.TrimSpace(f.diagnosticsDir) == "" { + return + } + message := networkNotCommittedMessage(f.diagnosticsDir, notCommittedErr) + _ = f.ui.ShowMessage(f.ctx, "Network rollback", message) +} + +func networkNotCommittedMessage(diagnosticsDir string, notCommittedErr *NetworkApplyNotCommittedError) string { + rollbackState := "Rollback is ARMED and will run automatically." + if notCommittedErr != nil && !notCommittedErr.RollbackArmed { + rollbackState = "Rollback has executed (or marker cleared)." + } + observed, original := networkNotCommittedIPs(notCommittedErr) + reconnectHost := reconnectHostFromOriginalIP(original) + + var b strings.Builder + b.WriteString("Network configuration not committed.\n\n") + b.WriteString(rollbackState + "\n\n") + fmt.Fprintf(&b, "IP now (after apply): %s\n", observed) + if original != "unknown" { + fmt.Fprintf(&b, "Expected after rollback: %s\n", original) + } + if reconnectHost != "" && reconnectHost != "unknown" { + fmt.Fprintf(&b, "Reconnect using: %s\n", reconnectHost) + } + b.WriteString("\nDiagnostics saved under:\n") + b.WriteString(strings.TrimSpace(diagnosticsDir)) + return b.String() +} + +func networkNotCommittedIPs(notCommittedErr *NetworkApplyNotCommittedError) (string, string) { + observed := "unknown" + original := "unknown" + if notCommittedErr == nil { + return observed, original + } + if v := strings.TrimSpace(notCommittedErr.RestoredIP); v != "" { + observed = v + } + if v := strings.TrimSpace(notCommittedErr.OriginalIP); v != "" { + original = v + } + return observed, original +} + +func reconnectHostFromOriginalIP(original string) string { + if original == "" || original == "unknown" { + return "" + } + reconnectHost := original + if i := strings.Index(reconnectHost, ","); i >= 0 { + reconnectHost = reconnectHost[:i] + } + if i := strings.Index(reconnectHost, "/"); i >= 0 { + reconnectHost = reconnectHost[:i] + } + return strings.TrimSpace(reconnectHost) +} + +func (f *networkRollbackUIApplyFlow) debug(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Debug(format, args...) + } +} + +func (f *networkRollbackUIApplyFlow) info(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Info(format, args...) + } +} + +func (f *networkRollbackUIApplyFlow) warning(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Warning(format, args...) + } +} + +func (f *networkRollbackUIApplyFlow) error(format string, args ...interface{}) { + if f.logger != nil { + f.logger.Error(format, args...) + } +} diff --git a/internal/orchestrator/network_diagnostics.go b/internal/orchestrator/network_diagnostics.go index 1509f178..ecd2f9a6 100644 --- a/internal/orchestrator/network_diagnostics.go +++ b/internal/orchestrator/network_diagnostics.go @@ -49,8 +49,8 @@ func writeNetworkSnapshot(ctx context.Context, logger *logging.Logger, diagnosti path = filepath.Join(diagnosticsDir, fmt.Sprintf("%s.txt", label)) var b strings.Builder - b.WriteString(fmt.Sprintf("GeneratedAt: %s\n", nowRestore().Format(time.RFC3339))) - b.WriteString(fmt.Sprintf("Label: %s\n\n", label)) + fmt.Fprintf(&b, "GeneratedAt: %s\n", nowRestore().Format(time.RFC3339)) + fmt.Fprintf(&b, "Label: %s\n\n", label) b.WriteString("=== LIVE NETWORK STATE ===\n\n") commands := [][]string{ @@ -78,7 +78,7 @@ func writeNetworkSnapshot(ctx context.Context, logger *logging.Logger, diagnosti } } if err != nil { - b.WriteString(fmt.Sprintf("ERROR: %v\n", err)) + fmt.Fprintf(&b, "ERROR: %v\n", err) if logger != nil { logger.Debug("Network snapshot command failed: %s: %v", strings.Join(cmd, " "), err) } @@ -128,7 +128,7 @@ func appendCommandSnapshot(ctx context.Context, logger *logging.Logger, b *strin } } if err != nil { - b.WriteString(fmt.Sprintf("ERROR: %v\n", err)) + fmt.Fprintf(b, "ERROR: %v\n", err) if logger != nil { logger.Debug("Network snapshot command failed: %s: %v", strings.Join(cmd, " "), err) } @@ -143,18 +143,18 @@ func appendFileSnapshot(logger *logging.Logger, label string, b *strings.Builder } info, err := restoreFS.Stat(path) if err != nil { - b.WriteString(fmt.Sprintf("ERROR: %v\n\n", err)) + fmt.Fprintf(b, "ERROR: %v\n\n", err) if logger != nil { logging.DebugStep(logger, "network snapshot", "Stat failed (%s): %s: %v", label, path, err) } return } - b.WriteString(fmt.Sprintf("Mode: %s\n", info.Mode().String())) - b.WriteString(fmt.Sprintf("Size: %d\n", info.Size())) - b.WriteString(fmt.Sprintf("ModTime: %s\n\n", info.ModTime().Format(time.RFC3339))) + fmt.Fprintf(b, "Mode: %s\n", info.Mode().String()) + fmt.Fprintf(b, "Size: %d\n", info.Size()) + fmt.Fprintf(b, "ModTime: %s\n\n", info.ModTime().Format(time.RFC3339)) data, err := restoreFS.ReadFile(path) if err != nil { - b.WriteString(fmt.Sprintf("ERROR: %v\n\n", err)) + fmt.Fprintf(b, "ERROR: %v\n\n", err) if logger != nil { logging.DebugStep(logger, "network snapshot", "Read failed (%s): %s: %v", label, path, err) } @@ -168,7 +168,7 @@ func appendFileSnapshot(logger *logging.Logger, label string, b *strings.Builder if maxBytes > 0 && (len(data) == 0 || data[maxBytes-1] != '\n') { b.WriteString("\n") } - b.WriteString(fmt.Sprintf("\n[truncated: %d of %d bytes]\n\n", maxBytes, len(data))) + fmt.Fprintf(b, "\n[truncated: %d of %d bytes]\n\n", maxBytes, len(data)) return } b.Write(data) @@ -185,7 +185,7 @@ func appendDirSnapshot(logger *logging.Logger, label string, b *strings.Builder, } entries, err := restoreFS.ReadDir(dir) if err != nil { - b.WriteString(fmt.Sprintf("ERROR: %v\n\n", err)) + fmt.Fprintf(b, "ERROR: %v\n\n", err) if logger != nil { logging.DebugStep(logger, "network snapshot", "ReadDir failed (%s): %s: %v", label, dir, err) } @@ -220,7 +220,7 @@ func appendDirSnapshot(logger *logging.Logger, label string, b *strings.Builder, logging.DebugStep(logger, "network snapshot", "Dir entries (%s): %s: %s", label, dir, strings.Join(names, ", ")) } for _, e := range list { - b.WriteString(fmt.Sprintf("- %s (%s)\n", e.name, e.mode.String())) + fmt.Fprintf(b, "- %s (%s)\n", e.name, e.mode.String()) } b.WriteString("\n") diff --git a/internal/orchestrator/network_health.go b/internal/orchestrator/network_health.go index 8e4de316..62d01a3f 100644 --- a/internal/orchestrator/network_health.go +++ b/internal/orchestrator/network_health.go @@ -73,7 +73,7 @@ func (r networkHealthReport) Details() string { b.WriteString(r.Summary()) b.WriteString("\n") for _, c := range r.Checks { - b.WriteString(fmt.Sprintf("- [%s] %s: %s\n", c.Severity.String(), c.Name, c.Message)) + fmt.Fprintf(&b, "- [%s] %s: %s\n", c.Severity.String(), c.Name, c.Message) } return strings.TrimRight(b.String(), "\n") } diff --git a/internal/orchestrator/network_plan.go b/internal/orchestrator/network_plan.go index 11c1eb9c..a5c73b9a 100644 --- a/internal/orchestrator/network_plan.go +++ b/internal/orchestrator/network_plan.go @@ -44,12 +44,12 @@ func buildNetworkPlanReport(ctx context.Context, iface, source string, timeout t var b strings.Builder b.WriteString("Network plan\n\n") - b.WriteString(fmt.Sprintf("- Management interface: %s\n", strings.TrimSpace(iface))) + fmt.Fprintf(&b, "- Management interface: %s\n", strings.TrimSpace(iface)) if strings.TrimSpace(source) != "" { - b.WriteString(fmt.Sprintf("- Detection source: %s\n", strings.TrimSpace(source))) + fmt.Fprintf(&b, "- Detection source: %s\n", strings.TrimSpace(source)) } - b.WriteString(fmt.Sprintf("- Current runtime: %s\n", current.summary())) - b.WriteString(fmt.Sprintf("- Target config: %s\n", target.summary())) + fmt.Fprintf(&b, "- Current runtime: %s\n", current.summary()) + fmt.Fprintf(&b, "- Target config: %s\n", target.summary()) return b.String(), nil } diff --git a/internal/orchestrator/network_staged_apply.go b/internal/orchestrator/network_staged_apply.go index da938440..3824acd3 100644 --- a/internal/orchestrator/network_staged_apply.go +++ b/internal/orchestrator/network_staged_apply.go @@ -178,10 +178,6 @@ func copyFileOverlayWithinRoot(src, dest, destRoot string) (bool, error) { return true, nil } -func copySymlinkOverlay(src, dest string) (bool, error) { - return copySymlinkOverlayWithinRoot(src, dest, filepath.Dir(dest)) -} - func copySymlinkOverlayWithinRoot(src, dest, destRoot string) (bool, error) { info, err := restoreFS.Lstat(src) if err != nil { diff --git a/internal/orchestrator/nic_mapping.go b/internal/orchestrator/nic_mapping.go index b77dc273..9726f46c 100644 --- a/internal/orchestrator/nic_mapping.go +++ b/internal/orchestrator/nic_mapping.go @@ -163,7 +163,7 @@ func (r nicRepairResult) Details() string { var b strings.Builder b.WriteString(r.Summary()) if r.BackupDir != "" { - b.WriteString(fmt.Sprintf("\nBackup of pre-repair files: %s", r.BackupDir)) + fmt.Fprintf(&b, "\nBackup of pre-repair files: %s", r.BackupDir) } if len(r.ChangedFiles) > 0 { b.WriteString("\nUpdated files:") @@ -317,7 +317,7 @@ func readArchiveEntry(ctx context.Context, archivePath string, candidates []stri if err != nil { return nil, "", err } - defer file.Close() + defer closeIntoErr(&err, file, "close archive") reader, err := createDecompressionReader(ctx, file, archivePath) if err != nil { diff --git a/internal/orchestrator/nic_mapping_additional_test.go b/internal/orchestrator/nic_mapping_additional_test.go index c4cb6817..057ab97b 100644 --- a/internal/orchestrator/nic_mapping_additional_test.go +++ b/internal/orchestrator/nic_mapping_additional_test.go @@ -552,7 +552,7 @@ func TestPlanAndApplyNICNameRepair_WithFakeInventory(t *testing.T) { t.Fatalf("write interfaces: %v", err) } - // includeConflicts=false: applies only safe mapping (ens2 -> eno1). + // includeConflicts=false: applies only safe mapping (ens20 -> eno1). res, err := applyNICNameRepair(logger, plan, false) if err != nil { t.Fatalf("apply: %v", err) @@ -588,6 +588,15 @@ func TestPlanAndApplyNICNameRepair_WithFakeInventory(t *testing.T) { if err != nil { t.Fatalf("apply conflicts: %v", err) } + if res == nil || res.SkippedReason != "" { + t.Fatalf("conflict result=%+v", res) + } + if len(res.ChangedFiles) != 1 || res.ChangedFiles[0] != "/etc/network/interfaces" { + t.Fatalf("conflict ChangedFiles=%v", res.ChangedFiles) + } + if len(res.AppliedNICMap) != 2 { + t.Fatalf("conflict AppliedNICMap=%+v", res.AppliedNICMap) + } data, err = fakeFS.ReadFile("/etc/network/interfaces") if err != nil { t.Fatalf("read: %v", err) diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index de9a9548..cf097c4e 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -726,10 +726,12 @@ func (o *Orchestrator) createBundle(ctx context.Context, archivePath string) (bu } if _, err := io.Copy(tw, &contextReader{ctx: ctx, r: file}); err != nil { - file.Close() + _ = file.Close() return "", fmt.Errorf("failed to write %s to tar: %w", filename, err) } - file.Close() + if err := file.Close(); err != nil { + return "", fmt.Errorf("failed to close %s: %w", filename, err) + } } // Close tar writer to flush @@ -806,7 +808,7 @@ func (o *Orchestrator) removeAssociatedFiles(archivePath string) error { // encryptArchive was replaced by streaming encryption inside the archiver. // SaveStatsReport writes a JSON report with backup statistics to the log directory. -func (o *Orchestrator) SaveStatsReport(stats *BackupStats) error { +func (o *Orchestrator) SaveStatsReport(stats *BackupStats) (err error) { if stats == nil { return fmt.Errorf("stats cannot be nil") } @@ -833,7 +835,7 @@ func (o *Orchestrator) SaveStatsReport(stats *BackupStats) error { if err != nil { return fmt.Errorf("create stats report: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close stats report") durationSeconds := stats.Duration.Seconds() compressionRatio := stats.CompressionRatio @@ -1060,21 +1062,21 @@ func (o *Orchestrator) writeBackupMetadata(tempDir string, stats *BackupStats) e builder := strings.Builder{} builder.WriteString("# ProxSave Metadata\n") builder.WriteString("# This file enables selective restore functionality in newer restore scripts\n") - builder.WriteString(fmt.Sprintf("VERSION=%s\n", version)) - builder.WriteString(fmt.Sprintf("BACKUP_TYPE=%s\n", stats.ProxmoxType.String())) + fmt.Fprintf(&builder, "VERSION=%s\n", version) + fmt.Fprintf(&builder, "BACKUP_TYPE=%s\n", stats.ProxmoxType.String()) if len(stats.ProxmoxTargets) > 0 { - builder.WriteString(fmt.Sprintf("BACKUP_TARGETS=%s\n", strings.Join(stats.ProxmoxTargets, ","))) + fmt.Fprintf(&builder, "BACKUP_TARGETS=%s\n", strings.Join(stats.ProxmoxTargets, ",")) } - builder.WriteString(fmt.Sprintf("TIMESTAMP=%s\n", stats.Timestamp)) - builder.WriteString(fmt.Sprintf("HOSTNAME=%s\n", stats.Hostname)) + fmt.Fprintf(&builder, "TIMESTAMP=%s\n", stats.Timestamp) + fmt.Fprintf(&builder, "HOSTNAME=%s\n", stats.Hostname) if strings.TrimSpace(stats.PVEVersion) != "" { - builder.WriteString(fmt.Sprintf("PVE_VERSION=%s\n", strings.TrimSpace(stats.PVEVersion))) + fmt.Fprintf(&builder, "PVE_VERSION=%s\n", strings.TrimSpace(stats.PVEVersion)) } if strings.TrimSpace(stats.PBSVersion) != "" { - builder.WriteString(fmt.Sprintf("PBS_VERSION=%s\n", strings.TrimSpace(stats.PBSVersion))) + fmt.Fprintf(&builder, "PBS_VERSION=%s\n", strings.TrimSpace(stats.PBSVersion)) } if stats.ClusterMode != "" { - builder.WriteString(fmt.Sprintf("PVE_CLUSTER_MODE=%s\n", stats.ClusterMode)) + fmt.Fprintf(&builder, "PVE_CLUSTER_MODE=%s\n", stats.ClusterMode) } builder.WriteString("SUPPORTS_SELECTIVE_RESTORE=true\n") builder.WriteString("BACKUP_FEATURES=selective_restore,category_mapping,version_detection,auto_directory_creation\n") @@ -1175,7 +1177,7 @@ func applyCollectorOverrides(cc *backup.CollectorConfig, cfg *config.Config) { cc.PBSPassword = cfg.PBSPassword cc.PBSFingerprint = cfg.PBSFingerprint } -func copyFile(fs FS, src, dest string) error { +func copyFile(fs FS, src, dest string) (err error) { if fs == nil { fs = osFS{} } @@ -1183,13 +1185,13 @@ func copyFile(fs FS, src, dest string) error { if err != nil { return err } - defer in.Close() + defer closeIntoErr(&err, in, "close source file") out, err := fs.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640) if err != nil { return err } - defer out.Close() + defer closeIntoErr(&err, out, "close destination file") if _, err := io.Copy(out, in); err != nil { return err diff --git a/internal/orchestrator/orchestrator_test.go b/internal/orchestrator/orchestrator_test.go index 558b80fb..06a33691 100644 --- a/internal/orchestrator/orchestrator_test.go +++ b/internal/orchestrator/orchestrator_test.go @@ -22,6 +22,20 @@ import ( const testAgeRecipient = "age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p" +func setSmallBackupTestConfig(t *testing.T, orch *Orchestrator, dir string) { + t.Helper() + + configPath := filepath.Join(dir, "backup.env") + if err := os.WriteFile(configPath, []byte("BACKUP_CONFIG_FILE=true\n"), 0o600); err != nil { + t.Fatalf("write test config: %v", err) + } + + orch.SetConfig(&config.Config{ + ConfigPath: configPath, + BackupConfigFile: true, + }) +} + type testStorageTarget struct { err error calls int @@ -44,6 +58,7 @@ func TestRunGoBackupEndToEnd(t *testing.T) { orch := New(logger, false) orch.SetBackupConfig(backupDir, logDir, types.CompressionNone, 0, 0, "standard", nil) + setSmallBackupTestConfig(t, orch, backupDir) checkerConfig := &checks.CheckerConfig{ BackupPath: backupDir, @@ -62,7 +77,8 @@ func TestRunGoBackupEndToEnd(t *testing.T) { checker := checks.NewChecker(logger, checkerConfig) orch.SetChecker(checker) - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() stats, err := orch.RunGoBackup(ctx, &environment.EnvironmentInfo{Type: types.ProxmoxUnknown, Version: "unknown"}, "test-host") if err != nil { t.Fatalf("RunGoBackup failed: %v", err) @@ -172,6 +188,7 @@ func TestRunGoBackupFallbackCompression(t *testing.T) { orch := New(logger, false) orch.SetBackupConfig(backupDir, logDir, types.CompressionXZ, 6, 0, "ultra", nil) + setSmallBackupTestConfig(t, orch, backupDir) checkerConfig := &checks.CheckerConfig{ BackupPath: backupDir, @@ -198,7 +215,8 @@ func TestRunGoBackupFallbackCompression(t *testing.T) { }) t.Cleanup(restore) - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() stats, err := orch.RunGoBackup(ctx, &environment.EnvironmentInfo{Type: types.ProxmoxUnknown, Version: "unknown"}, "fallback-host") if err != nil { t.Fatalf("RunGoBackup failed: %v", err) diff --git a/internal/orchestrator/pbs_notifications_api_apply.go b/internal/orchestrator/pbs_notifications_api_apply.go index 032debba..a2b3651d 100644 --- a/internal/orchestrator/pbs_notifications_api_apply.go +++ b/internal/orchestrator/pbs_notifications_api_apply.go @@ -1,7 +1,9 @@ +// Package orchestrator coordinates backup, restore, decrypt, and notification workflows. package orchestrator import ( "context" + "errors" "fmt" "sort" "strings" @@ -9,235 +11,325 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +type pbsNotificationEndpointSection struct { + section proxmoxNotificationSection + redactFlags []string + redactIndex []int + positional []string + sectionKey string + endpointType string +} + +type pbsNotificationDesiredState struct { + endpoints []pbsNotificationEndpointSection + matchers map[string]proxmoxNotificationSection + matcherNames []string +} + +// gotifyTokenRedactIndex is the token positional index in +// `notification endpoint gotify create ...`. +const gotifyTokenRedactIndex = 6 + func applyPBSNotificationsViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { - cfgRaw, cfgPresent, err := readStageFileOptional(stageRoot, "etc/proxmox-backup/notifications.cfg") - if err != nil { + desired, present, err := loadPBSNotificationDesiredState(stageRoot, logger) + if err != nil || !present { + return err + } + + if strict { + if err := removeExtraPBSNotificationMatchers(ctx, logger, desired.matchers); err != nil { + return err + } + } + if err := syncPBSNotificationEndpoints(ctx, logger, desired.endpoints, strict); err != nil { return err } - if !cfgPresent { - return nil + return syncPBSNotificationMatchers(ctx, desired) +} + +func loadPBSNotificationDesiredState(stageRoot string, logger *logging.Logger) (pbsNotificationDesiredState, bool, error) { + cfgSections, privSections, present, err := readPBSNotificationStageSections(stageRoot) + if err != nil || !present { + return pbsNotificationDesiredState{}, present, err + } + + desired := buildPBSNotificationDesiredState(cfgSections, privSections, logger) + return desired, true, nil +} + +func readPBSNotificationStageSections(stageRoot string) ([]proxmoxNotificationSection, []proxmoxNotificationSection, bool, error) { + cfgRaw, cfgPresent, err := readStageFileOptional(stageRoot, "etc/proxmox-backup/notifications.cfg") + if err != nil || !cfgPresent { + return nil, nil, cfgPresent, err } privRaw, _, err := readStageFileOptional(stageRoot, "etc/proxmox-backup/notifications-priv.cfg") if err != nil { - return err + return nil, nil, true, err } cfgSections, err := parseProxmoxNotificationSections(cfgRaw) if err != nil { - return fmt.Errorf("parse staged notifications.cfg: %w", err) + return nil, nil, true, fmt.Errorf("parse staged notifications.cfg: %w", err) } privSections, err := parseProxmoxNotificationSections(privRaw) if err != nil { - return fmt.Errorf("parse staged notifications-priv.cfg: %w", err) - } - - privByKey := make(map[string][]proxmoxNotificationEntry) - privRedactFlagsByKey := make(map[string][]string) - for _, s := range privSections { - if strings.TrimSpace(s.Type) == "" || strings.TrimSpace(s.Name) == "" { - continue - } - key := fmt.Sprintf("%s:%s", strings.TrimSpace(s.Type), strings.TrimSpace(s.Name)) - privByKey[key] = append([]proxmoxNotificationEntry{}, s.Entries...) - privRedactFlagsByKey[key] = append([]string(nil), notificationRedactFlagsFromEntries(s.Entries)...) - } - - type endpointSection struct { - section proxmoxNotificationSection - redactFlags []string - redactIndex []int - positional []string - sectionKey string - endpointType string + return nil, nil, true, fmt.Errorf("parse staged notifications-priv.cfg: %w", err) } + return cfgSections, privSections, true, nil +} - var endpoints []endpointSection - var matchers []proxmoxNotificationSection +func buildPBSNotificationDesiredState(cfgSections, privSections []proxmoxNotificationSection, logger *logging.Logger) pbsNotificationDesiredState { + privByKey, privRedactFlagsByKey := pbsNotificationPrivMaps(privSections) + desired := pbsNotificationDesiredState{matchers: make(map[string]proxmoxNotificationSection)} - for _, s := range cfgSections { - typ := strings.TrimSpace(s.Type) - name := strings.TrimSpace(s.Name) + for _, section := range cfgSections { + typ := strings.TrimSpace(section.Type) + name := strings.TrimSpace(section.Name) if typ == "" || name == "" { continue } switch typ { case "smtp", "sendmail", "gotify", "webhook": - key := fmt.Sprintf("%s:%s", typ, name) - if priv, ok := privByKey[key]; ok && len(priv) > 0 { - s.Entries = append(s.Entries, priv...) + if endpoint, ok := buildPBSNotificationEndpoint(section, privByKey, privRedactFlagsByKey, logger); ok { + desired.endpoints = append(desired.endpoints, endpoint) } - redactFlags := notificationRedactFlags(s) - if extra := privRedactFlagsByKey[key]; len(extra) > 0 { - redactFlags = append(redactFlags, extra...) - } - - pos := []string{} - entries := s.Entries - - switch typ { - case "smtp": - recipients, remaining, ok := popEntryValue(entries, "recipients", "mailto", "mail-to") - if !ok || strings.TrimSpace(recipients) == "" { - logger.Warning("PBS notifications API apply: smtp endpoint %s missing recipients; skipping", name) - continue - } - pos = append(pos, recipients) - s.Entries = remaining - case "sendmail": - mailto, remaining, ok := popEntryValue(entries, "mailto", "mail-to", "recipients") - if !ok || strings.TrimSpace(mailto) == "" { - logger.Warning("PBS notifications API apply: sendmail endpoint %s missing mailto; skipping", name) - continue - } - pos = append(pos, mailto) - s.Entries = remaining - case "gotify": - server, remaining, ok := popEntryValue(entries, "server") - if !ok || strings.TrimSpace(server) == "" { - logger.Warning("PBS notifications API apply: gotify endpoint %s missing server; skipping", name) - continue - } - token, remaining2, ok := popEntryValue(remaining, "token") - if !ok || strings.TrimSpace(token) == "" { - logger.Warning("PBS notifications API apply: gotify endpoint %s missing token; skipping", name) - continue - } - pos = append(pos, server, token) - s.Entries = remaining2 - case "webhook": - url, remaining, ok := popEntryValue(entries, "url") - if !ok || strings.TrimSpace(url) == "" { - logger.Warning("PBS notifications API apply: webhook endpoint %s missing url; skipping", name) - continue - } - pos = append(pos, url) - s.Entries = remaining - } - - redactIndex := []int(nil) - if typ == "gotify" { - // proxmox-backup-manager notification endpoint gotify create/update - redactIndex = []int{6} - } - - endpoints = append(endpoints, endpointSection{ - section: s, - redactFlags: redactFlags, - redactIndex: redactIndex, - positional: pos, - sectionKey: key, - endpointType: typ, - }) case "matcher": - matchers = append(matchers, s) + desired.matchers[name] = section default: logger.Warning("PBS notifications API apply: unknown section %q (%s); skipping", typ, name) } } - // In strict mode, remove matchers first so endpoint cleanup isn't blocked by references. - desiredMatchers := make(map[string]proxmoxNotificationSection, len(matchers)) - for _, m := range matchers { - name := strings.TrimSpace(m.Name) - if name == "" { + desired.matcherNames = sortedPBSMatcherNames(desired.matchers) + return desired +} + +func pbsNotificationPrivMaps(sections []proxmoxNotificationSection) (map[string][]proxmoxNotificationEntry, map[string][]string) { + privByKey := make(map[string][]proxmoxNotificationEntry) + redactByKey := make(map[string][]string) + for _, section := range sections { + typ := strings.TrimSpace(section.Type) + name := strings.TrimSpace(section.Name) + if typ == "" || name == "" { continue } - desiredMatchers[name] = m + key := pbsNotificationSectionKey(typ, name) + privByKey[key] = append([]proxmoxNotificationEntry{}, section.Entries...) + redactByKey[key] = append([]string(nil), notificationRedactFlagsFromEntries(section.Entries)...) } + return privByKey, redactByKey +} + +func buildPBSNotificationEndpoint(section proxmoxNotificationSection, privByKey map[string][]proxmoxNotificationEntry, privRedactFlagsByKey map[string][]string, logger *logging.Logger) (pbsNotificationEndpointSection, bool) { + typ := strings.TrimSpace(section.Type) + name := strings.TrimSpace(section.Name) + key := pbsNotificationSectionKey(typ, name) - matcherNames := make([]string, 0, len(desiredMatchers)) - for name := range desiredMatchers { - matcherNames = append(matcherNames, name) + if priv := privByKey[key]; len(priv) > 0 { + section.Entries = append(section.Entries, priv...) } - sort.Strings(matcherNames) + positional, entries, ok := pbsEndpointPositionalArgs(typ, name, section.Entries, logger) + if !ok { + return pbsNotificationEndpointSection{}, false + } + section.Entries = entries - if strict { - out, err := runPBSManager(ctx, "notification", "matcher", "list", "--output-format=json") - if err != nil { - return err - } - current, err := parsePBSListIDs(out, "name", "id") - if err != nil { - return fmt.Errorf("parse matcher list: %w", err) + redactFlags := notificationRedactFlags(section) + if extra := privRedactFlagsByKey[key]; len(extra) > 0 { + redactFlags = append(redactFlags, extra...) + } + + return pbsNotificationEndpointSection{ + section: section, + redactFlags: redactFlags, + redactIndex: pbsEndpointRedactIndexes(typ), + positional: positional, + sectionKey: key, + endpointType: typ, + }, true +} + +func pbsEndpointPositionalArgs(typ, name string, entries []proxmoxNotificationEntry, logger *logging.Logger) ([]string, []proxmoxNotificationEntry, bool) { + switch typ { + case "smtp": + return pbsEndpointSinglePositional(typ, name, entries, logger, "recipients", "mailto", "mail-to") + case "sendmail": + return pbsEndpointSinglePositional(typ, name, entries, logger, "mailto", "mail-to", "recipients") + case "gotify": + return pbsGotifyEndpointPositionals(name, entries, logger) + case "webhook": + return pbsEndpointSinglePositional(typ, name, entries, logger, "url") + default: + return nil, entries, false + } +} + +func pbsEndpointSinglePositional(typ, name string, entries []proxmoxNotificationEntry, logger *logging.Logger, keys ...string) ([]string, []proxmoxNotificationEntry, bool) { + value, remaining, ok := popEntryValue(entries, keys...) + if !ok || strings.TrimSpace(value) == "" { + logger.Warning("PBS notifications API apply: %s endpoint %s missing %s; skipping", typ, name, keys[0]) + return nil, entries, false + } + return []string{value}, remaining, true +} + +func pbsGotifyEndpointPositionals(name string, entries []proxmoxNotificationEntry, logger *logging.Logger) ([]string, []proxmoxNotificationEntry, bool) { + server, remaining, ok := popEntryValue(entries, "server") + if !ok || strings.TrimSpace(server) == "" { + logger.Warning("PBS notifications API apply: gotify endpoint %s missing server; skipping", name) + return nil, entries, false + } + token, remaining, ok := popEntryValue(remaining, "token") + if !ok || strings.TrimSpace(token) == "" { + logger.Warning("PBS notifications API apply: gotify endpoint %s missing token; skipping", name) + return nil, entries, false + } + return []string{server, token}, remaining, true +} + +func pbsEndpointRedactIndexes(typ string) []int { + if typ == "gotify" { + return []int{gotifyTokenRedactIndex} + } + return nil +} + +func sortedPBSMatcherNames(matchers map[string]proxmoxNotificationSection) []string { + names := make([]string, 0, len(matchers)) + for name := range matchers { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func removeExtraPBSNotificationMatchers(ctx context.Context, logger *logging.Logger, desired map[string]proxmoxNotificationSection) error { + current, err := listPBSNotificationIDs(ctx, "matcher", "list") + if err != nil { + return err + } + for _, name := range current { + if _, ok := desired[name]; ok { + continue } - for _, name := range current { - if _, ok := desiredMatchers[name]; ok { - continue - } - if _, err := runPBSManager(ctx, "notification", "matcher", "remove", name); err != nil { - // Built-in matchers may not be removable; keep going. - logger.Warning("PBS notifications API apply: matcher remove %s failed (continuing): %v", name, err) - } + if _, err := runPBSManager(ctx, "notification", "matcher", "remove", name); err != nil { + logger.Warning("PBS notifications API apply: matcher remove %s failed (continuing): %v", name, err) } } + return nil +} - // Endpoints first (matchers refer to targets/endpoints). +func syncPBSNotificationEndpoints(ctx context.Context, logger *logging.Logger, endpoints []pbsNotificationEndpointSection, strict bool) error { for _, typ := range []string{"smtp", "sendmail", "gotify", "webhook"} { - desiredNames := make(map[string]endpointSection) - for _, e := range endpoints { - if e.endpointType != typ { - continue - } - name := strings.TrimSpace(e.section.Name) - if name == "" { - continue + desired := pbsEndpointsByName(endpoints, typ) + if strict { + if err := removeExtraPBSNotificationEndpoints(ctx, logger, typ, desired); err != nil { + return err } - desiredNames[name] = e } + if err := upsertPBSNotificationEndpoints(ctx, typ, desired); err != nil { + return err + } + } + return nil +} - names := make([]string, 0, len(desiredNames)) - for name := range desiredNames { - names = append(names, name) +func pbsEndpointsByName(endpoints []pbsNotificationEndpointSection, typ string) map[string]pbsNotificationEndpointSection { + desired := make(map[string]pbsNotificationEndpointSection) + for _, endpoint := range endpoints { + if endpoint.endpointType != typ { + continue + } + name := strings.TrimSpace(endpoint.section.Name) + if name != "" { + desired[name] = endpoint } - sort.Strings(names) + } + return desired +} - if strict { - out, err := runPBSManager(ctx, "notification", "endpoint", typ, "list", "--output-format=json") - if err != nil { - return err - } - current, err := parsePBSListIDs(out, "name", "id") - if err != nil { - return fmt.Errorf("parse endpoint list (%s): %w", typ, err) - } - for _, name := range current { - if _, ok := desiredNames[name]; ok { - continue - } - if _, err := runPBSManager(ctx, "notification", "endpoint", typ, "remove", name); err != nil { - // Built-in endpoints may not be removable; keep going. - logger.Warning("PBS notifications API apply: endpoint remove %s:%s failed (continuing): %v", typ, name, err) - } - } +func removeExtraPBSNotificationEndpoints(ctx context.Context, logger *logging.Logger, typ string, desired map[string]pbsNotificationEndpointSection) error { + current, err := listPBSNotificationIDs(ctx, "endpoint", typ, "list") + if err != nil { + return err + } + for _, name := range current { + if _, ok := desired[name]; ok { + continue + } + if _, err := runPBSManager(ctx, "notification", "endpoint", typ, "remove", name); err != nil { + logger.Warning("PBS notifications API apply: endpoint remove %s:%s failed (continuing): %v", typ, name, err) } + } + return nil +} - for _, name := range names { - e := desiredNames[name] - flags := buildProxmoxManagerFlags(e.section.Entries) - createArgs := append([]string{"notification", "endpoint", typ, "create", name}, e.positional...) - createArgs = append(createArgs, flags...) - if _, err := runPBSManagerRedacted(ctx, createArgs, e.redactFlags, e.redactIndex); err != nil { - updateArgs := append([]string{"notification", "endpoint", typ, "update", name}, e.positional...) - updateArgs = append(updateArgs, flags...) - if _, upErr := runPBSManagerRedacted(ctx, updateArgs, e.redactFlags, e.redactIndex); upErr != nil { - return fmt.Errorf("endpoint %s:%s: %v (create) / %v (update)", typ, name, err, upErr) - } - } +func upsertPBSNotificationEndpoints(ctx context.Context, typ string, desired map[string]pbsNotificationEndpointSection) error { + names := sortedPBSEndpointNames(desired) + for _, name := range names { + if err := upsertPBSNotificationEndpoint(ctx, typ, name, desired[name]); err != nil { + return err } } + return nil +} + +func sortedPBSEndpointNames(desired map[string]pbsNotificationEndpointSection) []string { + names := make([]string, 0, len(desired)) + for name := range desired { + names = append(names, name) + } + sort.Strings(names) + return names +} - // Then matchers. - for _, name := range matcherNames { - m := desiredMatchers[name] - flags := buildProxmoxManagerFlags(m.Entries) - createArgs := append([]string{"notification", "matcher", "create", name}, flags...) - if _, err := runPBSManager(ctx, createArgs...); err != nil { - updateArgs := append([]string{"notification", "matcher", "update", name}, flags...) - if _, upErr := runPBSManager(ctx, updateArgs...); upErr != nil { - return fmt.Errorf("matcher %s: %v (create) / %v (update)", name, err, upErr) - } +func upsertPBSNotificationEndpoint(ctx context.Context, typ, name string, endpoint pbsNotificationEndpointSection) error { + flags := buildProxmoxManagerFlags(endpoint.section.Entries) + createArgs := append([]string{"notification", "endpoint", typ, "create", name}, endpoint.positional...) + createArgs = append(createArgs, flags...) + if _, err := runPBSManagerRedacted(ctx, createArgs, endpoint.redactFlags, endpoint.redactIndex); err != nil { + updateArgs := append([]string{"notification", "endpoint", typ, "update", name}, endpoint.positional...) + updateArgs = append(updateArgs, flags...) + if _, upErr := runPBSManagerRedacted(ctx, updateArgs, endpoint.redactFlags, endpoint.redactIndex); upErr != nil { + return fmt.Errorf("endpoint %s:%s: %w", typ, name, errors.Join(err, upErr)) } } + return nil +} +func syncPBSNotificationMatchers(ctx context.Context, desired pbsNotificationDesiredState) error { + for _, name := range desired.matcherNames { + if err := upsertPBSNotificationMatcher(ctx, name, desired.matchers[name]); err != nil { + return err + } + } + return nil +} + +func upsertPBSNotificationMatcher(ctx context.Context, name string, matcher proxmoxNotificationSection) error { + flags := buildProxmoxManagerFlags(matcher.Entries) + createArgs := append([]string{"notification", "matcher", "create", name}, flags...) + if _, err := runPBSManager(ctx, createArgs...); err != nil { + updateArgs := append([]string{"notification", "matcher", "update", name}, flags...) + if _, upErr := runPBSManager(ctx, updateArgs...); upErr != nil { + return fmt.Errorf("matcher %s: %w", name, errors.Join(err, upErr)) + } + } return nil } + +func listPBSNotificationIDs(ctx context.Context, args ...string) ([]string, error) { + out, err := runPBSManager(ctx, append([]string{"notification"}, args...)...) + if err != nil { + return nil, err + } + current, err := parsePBSListIDs(out, "name", "id") + if err != nil { + return nil, fmt.Errorf("parse %s: %w", strings.Join(args, " "), err) + } + return current, nil +} + +func pbsNotificationSectionKey(typ, name string) string { + return fmt.Sprintf("%s:%s", strings.TrimSpace(typ), strings.TrimSpace(name)) +} diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index 451bdd64..af2ee805 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -404,11 +404,11 @@ func loadPBSDatastoreCfgFromInventory(stageRoot string) (string, string, error) if out.Len() > 0 { out.WriteString("\n") } - out.WriteString(fmt.Sprintf("datastore: %s\n", name)) + fmt.Fprintf(&out, "datastore: %s\n", name) if comment := strings.TrimSpace(ds.Comment); comment != "" { - out.WriteString(fmt.Sprintf(" comment %s\n", comment)) + fmt.Fprintf(&out, " comment %s\n", comment) } - out.WriteString(fmt.Sprintf(" path %s\n", path)) + fmt.Fprintf(&out, " path %s\n", path) } generated := strings.TrimSpace(out.String()) diff --git a/internal/orchestrator/prompts_cli_test.go b/internal/orchestrator/prompts_cli_test.go index 0377ae5c..e23f5a30 100644 --- a/internal/orchestrator/prompts_cli_test.go +++ b/internal/orchestrator/prompts_cli_test.go @@ -85,8 +85,8 @@ func TestPromptYesNoWithCountdown_InputYes(test *testing.T) { func TestPromptYesNoWithCountdown_TimeoutReturnsNo(test *testing.T) { pipeReader, pipeWriter := io.Pipe() - defer pipeReader.Close() - defer pipeWriter.Close() + defer func() { _ = pipeReader.Close() }() + defer func() { _ = pipeWriter.Close() }() reader := bufio.NewReader(pipeReader) logger := logging.New(types.LogLevelInfo, false) diff --git a/internal/orchestrator/resolv_conf_repair.go b/internal/orchestrator/resolv_conf_repair.go index bce82238..6790c08c 100644 --- a/internal/orchestrator/resolv_conf_repair.go +++ b/internal/orchestrator/resolv_conf_repair.go @@ -153,7 +153,7 @@ func readTarEntry(ctx context.Context, archivePath, name string, maxBytes int64) if err != nil { return nil, fmt.Errorf("open archive: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close archive") reader, err := createDecompressionReader(ctx, file, archivePath) if err != nil { @@ -182,7 +182,7 @@ func readTarEntry(ctx context.Context, archivePath, name string, maxBytes int64) if header.Name != wantA && header.Name != wantB { continue } - if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA { + if header.Typeflag != tar.TypeReg { return nil, fmt.Errorf("archive entry %s is not a regular file", header.Name) } diff --git a/internal/orchestrator/restore_archive.go b/internal/orchestrator/restore_archive.go index 363b7387..38759c6c 100644 --- a/internal/orchestrator/restore_archive.go +++ b/internal/orchestrator/restore_archive.go @@ -124,7 +124,11 @@ func runFullRestoreFstabMerge(ctx context.Context, reader *bufio.Reader, archive logger.Warning("Failed to create temp dir for fstab merge: %v", err) return nil } - defer restoreFS.RemoveAll(fsTempDir) + defer func() { + if err := restoreFS.RemoveAll(fsTempDir); err != nil { + logger.Debug("Failed to remove temporary fstab merge directory %s: %v", fsTempDir, err) + } + }() if err := extractFullRestoreFstab(ctx, archivePath, fsTempDir, logger); err != nil { logger.Warning("Failed to extract filesystem config for merge: %v", err) @@ -254,7 +258,7 @@ func extractSelectiveArchive(ctx context.Context, archivePath, destRoot string, // Create detailed log directory logDir := "/tmp/proxsave" - if err := restoreFS.MkdirAll(logDir, 0o755); err != nil { + if err := restoreFS.MkdirAll(logDir, 0o700); err != nil { logger.Warning("Could not create log directory: %v", err) } @@ -267,7 +271,11 @@ func extractSelectiveArchive(ctx context.Context, archivePath, destRoot string, logger.Warning("Could not create detailed log file: %v", err) logFile = nil } else { - defer logFile.Close() + defer func() { + if closeErr := logFile.Close(); closeErr != nil { + logger.Warning("close detailed restore log: %v", closeErr) + } + }() logger.Info("Detailed restore log: %s", logPath) logging.DebugStep(logger, "extract selective archive", "log file=%s", logPath) } diff --git a/internal/orchestrator/restore_archive_entries.go b/internal/orchestrator/restore_archive_entries.go index 0411c84b..44fb074e 100644 --- a/internal/orchestrator/restore_archive_entries.go +++ b/internal/orchestrator/restore_archive_entries.go @@ -209,12 +209,16 @@ func extractSymlink(target string, header *tar.Header, destRoot string, logger * // POST-CREATION VALIDATION: Verify the created symlink's target stays within destRoot actualTarget, err := restoreFS.Readlink(target) if err != nil { - restoreFS.Remove(target) // Clean up + if removeErr := restoreFS.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) { + logger.Debug("Failed to remove symlink %s after readlink error: %v", target, removeErr) + } return fmt.Errorf("read created symlink %s: %w", target, err) } if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil { - restoreFS.Remove(target) + if removeErr := restoreFS.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) { + logger.Debug("Failed to remove unsafe symlink %s: %v", target, removeErr) + } return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err) } diff --git a/internal/orchestrator/restore_archive_extract.go b/internal/orchestrator/restore_archive_extract.go index ea16abb4..6544f15b 100644 --- a/internal/orchestrator/restore_archive_extract.go +++ b/internal/orchestrator/restore_archive_extract.go @@ -43,7 +43,7 @@ func extractArchiveNative(ctx context.Context, opts restoreArchiveOptions) (err if err != nil { return fmt.Errorf("open archive: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close archive") reader, err := createDecompressionReader(ctx, file, opts.archivePath) if err != nil { @@ -97,7 +97,7 @@ func closeAndRemoveRestoreTemp(file *os.File) { if file == nil { return } - file.Close() + _ = file.Close() _ = restoreFS.Remove(file.Name()) } @@ -105,19 +105,19 @@ func (log *restoreExtractionLog) writeHeader(opts restoreArchiveOptions) { if log.logFile == nil { return } - fmt.Fprintf(log.logFile, "=== PROXMOX RESTORE LOG ===\n") - fmt.Fprintf(log.logFile, "Date: %s\n", nowRestore().Format("2006-01-02 15:04:05")) - fmt.Fprintf(log.logFile, "Mode: %s\n", getModeName(opts.mode)) + _, _ = fmt.Fprintf(log.logFile, "=== PROXMOX RESTORE LOG ===\n") + _, _ = fmt.Fprintf(log.logFile, "Date: %s\n", nowRestore().Format("2006-01-02 15:04:05")) + _, _ = fmt.Fprintf(log.logFile, "Mode: %s\n", getModeName(opts.mode)) if len(opts.categories) > 0 { - fmt.Fprintf(log.logFile, "Selected categories: %d categories\n", len(opts.categories)) + _, _ = fmt.Fprintf(log.logFile, "Selected categories: %d categories\n", len(opts.categories)) for _, cat := range opts.categories { - fmt.Fprintf(log.logFile, " - %s (%s)\n", cat.Name, cat.ID) + _, _ = fmt.Fprintf(log.logFile, " - %s (%s)\n", cat.Name, cat.ID) } } else { - fmt.Fprintf(log.logFile, "Selected categories: ALL (full restore)\n") + _, _ = fmt.Fprintf(log.logFile, "Selected categories: ALL (full restore)\n") } - fmt.Fprintf(log.logFile, "Archive: %s\n", filepath.Base(opts.archivePath)) - fmt.Fprintf(log.logFile, "\n") + _, _ = fmt.Fprintf(log.logFile, "Archive: %s\n", filepath.Base(opts.archivePath)) + _, _ = fmt.Fprintf(log.logFile, "\n") } func processRestoreArchiveEntries(ctx context.Context, tarReader *tar.Reader, opts restoreArchiveOptions, extractionLog *restoreExtractionLog) (restoreExtractionStats, error) { @@ -179,13 +179,13 @@ func restoreEntryMatchesCategories(entryName string, categories []Category) bool func (log *restoreExtractionLog) recordSkipped(name, reason string) { if log.skippedTemp != nil { - fmt.Fprintf(log.skippedTemp, "SKIPPED: %s (%s)\n", name, reason) + _, _ = fmt.Fprintf(log.skippedTemp, "SKIPPED: %s (%s)\n", name, reason) } } func (log *restoreExtractionLog) recordRestored(name string) { if log.restoredTemp != nil { - fmt.Fprintf(log.restoredTemp, "RESTORED: %s\n", name) + _, _ = fmt.Fprintf(log.restoredTemp, "RESTORED: %s\n", name) } } @@ -193,19 +193,19 @@ func (log *restoreExtractionLog) writeSummary(stats restoreExtractionStats) { if log.logFile == nil { return } - fmt.Fprintf(log.logFile, "=== FILES RESTORED ===\n") + _, _ = fmt.Fprintf(log.logFile, "=== FILES RESTORED ===\n") log.copyTempEntries(log.restoredTemp, "restored") - fmt.Fprintf(log.logFile, "\n") + _, _ = fmt.Fprintf(log.logFile, "\n") - fmt.Fprintf(log.logFile, "=== FILES SKIPPED ===\n") + _, _ = fmt.Fprintf(log.logFile, "=== FILES SKIPPED ===\n") log.copyTempEntries(log.skippedTemp, "skipped") - fmt.Fprintf(log.logFile, "\n") + _, _ = fmt.Fprintf(log.logFile, "\n") - fmt.Fprintf(log.logFile, "=== SUMMARY ===\n") - fmt.Fprintf(log.logFile, "Total files extracted: %d\n", stats.filesExtracted) - fmt.Fprintf(log.logFile, "Total files skipped: %d\n", stats.filesSkipped) - fmt.Fprintf(log.logFile, "Total files failed: %d\n", stats.filesFailed) - fmt.Fprintf(log.logFile, "Total files in archive: %d\n", stats.filesExtracted+stats.filesSkipped+stats.filesFailed) + _, _ = fmt.Fprintf(log.logFile, "=== SUMMARY ===\n") + _, _ = fmt.Fprintf(log.logFile, "Total files extracted: %d\n", stats.filesExtracted) + _, _ = fmt.Fprintf(log.logFile, "Total files skipped: %d\n", stats.filesSkipped) + _, _ = fmt.Fprintf(log.logFile, "Total files failed: %d\n", stats.filesFailed) + _, _ = fmt.Fprintf(log.logFile, "Total files in archive: %d\n", stats.filesExtracted+stats.filesSkipped+stats.filesFailed) } func (log *restoreExtractionLog) copyTempEntries(tempFile *os.File, label string) { diff --git a/internal/orchestrator/restore_coverage_extra_test.go b/internal/orchestrator/restore_coverage_extra_test.go index e21566cc..d073de57 100644 --- a/internal/orchestrator/restore_coverage_extra_test.go +++ b/internal/orchestrator/restore_coverage_extra_test.go @@ -664,7 +664,7 @@ func TestRunRestoreCommandStream_FallsBackToExecCommand(t *testing.T) { if err != nil { t.Fatalf("runRestoreCommandStream error: %v", err) } - defer reader.Close() + defer func() { _ = reader.Close() }() out, err := io.ReadAll(reader) if err != nil { @@ -700,15 +700,15 @@ func TestExtractTarEntry_SkipsSensitiveSystemPathsOnRootRestore(t *testing.T) { } } -func writeTarFile(path string, files map[string]string) error { +func writeTarFile(path string, files map[string]string) (err error) { f, err := os.Create(path) if err != nil { return err } - defer f.Close() + defer closeIntoErr(&err, f, "close test tar file") tw := tar.NewWriter(f) - defer tw.Close() + defer closeIntoErr(&err, tw, "close test tar writer") for name, content := range files { b := []byte(content) diff --git a/internal/orchestrator/restore_decision.go b/internal/orchestrator/restore_decision.go index 903984f6..dc2208c8 100644 --- a/internal/orchestrator/restore_decision.go +++ b/internal/orchestrator/restore_decision.go @@ -45,6 +45,7 @@ type restoreArchiveInspection struct { const ( restoreDecisionMetadataPath = "var/lib/proxsave-info/backup_metadata.txt" restoreDecisionMetadataMaxBytes = 8 * 1024 + restoreDecisionNulTypeFlag = byte(0) ) // AnalyzeRestoreArchive inspects the archive once and derives trusted restore facts @@ -85,7 +86,7 @@ func inspectRestoreArchiveContents(archivePath string, logger *logging.Logger) ( if err != nil { return nil, fmt.Errorf("open archive: %w", err) } - defer file.Close() + defer closeIntoErr(&err, file, "close archive") reader, err := createDecompressionReader(context.Background(), file, archivePath) if err != nil { @@ -162,7 +163,7 @@ func readRestoreDecisionMetadata(tarReader *tar.Reader, header *tar.Header) ([]b if header == nil { return nil, fmt.Errorf("restore metadata entry is missing a tar header") } - if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA { + if header.Typeflag != tar.TypeReg && header.Typeflag != restoreDecisionNulTypeFlag { return nil, fmt.Errorf("archive entry %s is not a regular file", header.Name) } diff --git a/internal/orchestrator/restore_decision_test.go b/internal/orchestrator/restore_decision_test.go index a6a0bdd5..6d527dc1 100644 --- a/internal/orchestrator/restore_decision_test.go +++ b/internal/orchestrator/restore_decision_test.go @@ -62,6 +62,41 @@ func tarBytes(t *testing.T, files map[string]string) []byte { return buf.Bytes() } +func TestReadRestoreDecisionMetadataAcceptsNulTypeFlag(t *testing.T) { + const nulTypeFlag byte = 0 + + data := []byte("BACKUP_TYPE=pbs\n") + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + if err := tw.WriteHeader(&tar.Header{ + Name: "var/lib/proxsave-info/backup_metadata.txt", + Typeflag: nulTypeFlag, + Mode: 0o640, + Size: int64(len(data)), + }); err != nil { + t.Fatalf("WriteHeader: %v", err) + } + if _, err := tw.Write(data); err != nil { + t.Fatalf("Write: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close tar writer: %v", err) + } + + tr := tar.NewReader(&buf) + header, err := tr.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + got, err := readRestoreDecisionMetadata(tr, header) + if err != nil { + t.Fatalf("readRestoreDecisionMetadata: %v", err) + } + if string(got) != string(data) { + t.Fatalf("metadata=%q; want %q", string(got), string(data)) + } +} + func TestAnalyzeRestoreArchive_UsesInternalMetadataWhenCategoriesAreCommonOnly(t *testing.T) { origRestoreFS := restoreFS t.Cleanup(func() { restoreFS = origRestoreFS }) @@ -184,7 +219,7 @@ func TestCollectRestoreArchiveFacts_RejectsOversizedMetadata(t *testing.T) { if err != nil { t.Fatalf("os.Open: %v", err) } - defer file.Close() + defer func() { _ = file.Close() }() archivePaths, metadata, metadataErr, err := collectRestoreArchiveFacts(tar.NewReader(file)) if err != nil { diff --git a/internal/orchestrator/restore_decompression.go b/internal/orchestrator/restore_decompression.go index 224fea52..507e2968 100644 --- a/internal/orchestrator/restore_decompression.go +++ b/internal/orchestrator/restore_decompression.go @@ -96,7 +96,7 @@ func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader, return nil, fmt.Errorf("create %s pipe: %w", name, err) } if err := cmd.Start(); err != nil { - stdout.Close() + _ = stdout.Close() return nil, fmt.Errorf("start %s: %w", name, err) } return &waitReadCloser{ReadCloser: stdout, wait: cmd.Wait}, nil diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go index bfe6c39e..d83f0a47 100644 --- a/internal/orchestrator/restore_errors_test.go +++ b/internal/orchestrator/restore_errors_test.go @@ -48,14 +48,14 @@ func TestRunRestoreCommandStream_UsesStreamingRunner(t *testing.T) { if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer os.Remove(tmp.Name()) - defer tmp.Close() + defer func() { _ = os.Remove(tmp.Name()) }() + defer func() { _ = tmp.Close() }() reader, err := createXZReader(context.Background(), tmp) if err != nil { t.Fatalf("createXZReader: %v", err) } - defer reader.Close() + defer func() { _ = reader.Close() }() buf, err := io.ReadAll(reader) if err != nil { @@ -1017,7 +1017,7 @@ func TestExtractRegularFile_CopyFails(t *testing.T) { tw := tar.NewWriter(&buf) _ = tw.WriteHeader(header) _, _ = tw.Write([]byte("short")) // Only 5 bytes but header says 100 - tw.Close() + _ = tw.Close() tr := tar.NewReader(&buf) _, _ = tr.Next() @@ -1066,7 +1066,7 @@ func TestExtractRegularFile_CopyFailsPreservesExistingTarget(t *testing.T) { tw := tar.NewWriter(&buf) _ = tw.WriteHeader(header) _, _ = tw.Write([]byte("short")) - tw.Close() + _ = tw.Close() tr := tar.NewReader(&buf) _, _ = tr.Next() @@ -1323,7 +1323,9 @@ func TestSleepWithContext_ContextCanceled(t *testing.T) { cancel() start := time.Now() - sleepWithContext(ctx, 10*time.Second) + if err := sleepWithContext(ctx, 10*time.Second); err == nil { + t.Fatalf("expected cancellation error") + } elapsed := time.Since(start) // Should return immediately due to canceled context @@ -1565,7 +1567,7 @@ func TestCreateDecompressionReader_UnknownExtension(t *testing.T) { if err != nil { t.Fatalf("open: %v", err) } - defer file.Close() + defer func() { _ = file.Close() }() // Should return error for unknown extension _, err = createDecompressionReader(context.Background(), file, filePath) @@ -1833,7 +1835,9 @@ func TestExtractRegularFile_Success(t *testing.T) { if _, err := tw.Write(content); err != nil { t.Fatalf("write content: %v", err) } - tw.Close() + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } tr := tar.NewReader(&buf) if _, err := tr.Next(); err != nil { diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go index beef3796..0a094108 100644 --- a/internal/orchestrator/restore_firewall_additional_test.go +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -193,7 +193,7 @@ func (e staticDirEntry) Name() string { return e.name } func (e staticDirEntry) IsDir() bool { return e.mode.IsDir() } func (e staticDirEntry) Type() fs.FileMode { return e.mode } func (e staticDirEntry) Info() (fs.FileInfo, error) { - return staticFileInfo{name: e.name, mode: e.mode}, nil + return staticFileInfo(e), nil } type scriptedConfirmAction struct { diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index 622057ca..c98e133e 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -1235,7 +1235,9 @@ func TestMinDuration(t *testing.T) { func TestSleepWithContext_Normal(t *testing.T) { ctx := context.Background() start := time.Now() - sleepWithContext(ctx, 50*time.Millisecond) + if err := sleepWithContext(ctx, 50*time.Millisecond); err != nil { + t.Fatalf("sleepWithContext error: %v", err) + } elapsed := time.Since(start) if elapsed < 40*time.Millisecond { t.Fatalf("sleep too short: %v", elapsed) @@ -1246,7 +1248,9 @@ func TestSleepWithContext_Cancelled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() start := time.Now() - sleepWithContext(ctx, 1*time.Second) + if err := sleepWithContext(ctx, 1*time.Second); err == nil { + t.Fatalf("expected cancellation error") + } elapsed := time.Since(start) if elapsed > 100*time.Millisecond { t.Fatalf("sleep should have returned immediately: %v", elapsed) diff --git a/internal/orchestrator/restore_tui.go b/internal/orchestrator/restore_tui.go index 1bb3f5da..4ff09f30 100644 --- a/internal/orchestrator/restore_tui.go +++ b/internal/orchestrator/restore_tui.go @@ -108,8 +108,8 @@ func selectRestoreModeTUI(ctx context.Context, systemType SystemType, configPath listItem := components.NewListFormItem(list). SetLabel("Select restore mode"). SetFieldHeight(8) - form.Form.AddFormItem(listItem) - form.Form.SetFocus(0) + form.AddFormItem(listItem) + form.SetFocus(0) form.SetOnCancel(func() { aborted = true @@ -192,8 +192,8 @@ func selectPBSRestoreBehaviorTUI(ctx context.Context, configPath, buildSig, back listItem := components.NewListFormItem(list). SetLabel("Select PBS restore behavior"). SetFieldHeight(6) - form.Form.AddFormItem(listItem) - form.Form.SetFocus(0) + form.AddFormItem(listItem) + form.SetFocus(0) form.SetOnCancel(func() { aborted = true @@ -298,7 +298,7 @@ func selectCategoriesTUI(ctx context.Context, available []Category, systemType S return event }) - form.Form.AddFormItem(dropdown) + form.AddFormItem(dropdown) if strings.TrimSpace(cat.Description) != "" { desc := tview.NewInputField(). @@ -306,7 +306,7 @@ func selectCategoriesTUI(ctx context.Context, available []Category, systemType S SetFieldWidth(0). SetText(""). SetDisabled(true) - form.Form.AddFormItem(desc) + form.AddFormItem(desc) } } @@ -333,7 +333,7 @@ func selectCategoriesTUI(ctx context.Context, available []Category, systemType S }) // Buttons: Back, Continue, Cancel - form.Form.AddButton("Back", func() { + form.AddButton("Back", func() { goBack = true app.Stop() }) @@ -550,8 +550,8 @@ func promptClusterRestoreModeTUI(ctx context.Context, configPath, buildSig strin listItem := components.NewListFormItem(list). SetLabel("Cluster restore mode"). SetFieldHeight(6) - form.Form.AddFormItem(listItem) - form.Form.SetFocus(0) + form.AddFormItem(listItem) + form.SetFocus(0) form.SetOnCancel(func() { aborted = true @@ -904,13 +904,11 @@ func promptNetworkCommitTUI(ctx context.Context, timeout time.Duration, health n var b strings.Builder for _, check := range report.Checks { color := healthColor(check.Severity) - b.WriteString(fmt.Sprintf( - "- [%s]%s[white] %s: %s\n", + fmt.Fprintf(&b, "- [%s]%s[white] %s: %s\n", color, check.Severity.String(), tview.Escape(check.Name), - tview.Escape(check.Message), - )) + tview.Escape(check.Message)) } return strings.TrimRight(b.String(), "\n") } @@ -934,7 +932,7 @@ func promptNetworkCommitTUI(ctx context.Context, timeout time.Duration, health n } var b strings.Builder for _, m := range r.AppliedNICMap { - b.WriteString(fmt.Sprintf("- %s -> %s\n", tview.Escape(m.OldName), tview.Escape(m.NewName))) + fmt.Fprintf(&b, "- %s -> %s\n", tview.Escape(m.OldName), tview.Escape(m.NewName)) } return strings.TrimRight(b.String(), "\n") } diff --git a/internal/orchestrator/restore_tui_simulation_test.go b/internal/orchestrator/restore_tui_simulation_test.go index bb255362..afba171a 100644 --- a/internal/orchestrator/restore_tui_simulation_test.go +++ b/internal/orchestrator/restore_tui_simulation_test.go @@ -7,6 +7,8 @@ import ( "github.com/gdamore/tcell/v2" ) +type restoreTUITestContextKey struct{} + func TestPromptYesNoTUI_YesReturnsTrue(t *testing.T) { withSimApp(t, []tcell.Key{tcell.KeyEnter}) @@ -63,7 +65,7 @@ func TestShowRestorePlanTUI_CancelReturnsAborted(t *testing.T) { } func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { - expectedCtx := context.WithValue(context.Background(), struct{}{}, "confirm-restore") + expectedCtx := context.WithValue(context.Background(), restoreTUITestContextKey{}, "confirm-restore") restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if ctx != expectedCtx { t.Fatalf("stub received unexpected context: got %v want %v", ctx, expectedCtx) @@ -84,7 +86,7 @@ func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { } func TestConfirmRestoreTUI_OverwriteDeclinedReturnsFalse(t *testing.T) { - expectedCtx := context.WithValue(context.Background(), struct{}{}, "overwrite-declined") + expectedCtx := context.WithValue(context.Background(), restoreTUITestContextKey{}, "overwrite-declined") restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { if ctx != expectedCtx { t.Fatalf("stub received unexpected context: got %v want %v", ctx, expectedCtx) diff --git a/internal/orchestrator/restore_workflow_decision_test.go b/internal/orchestrator/restore_workflow_decision_test.go index efa2f33f..03411b80 100644 --- a/internal/orchestrator/restore_workflow_decision_test.go +++ b/internal/orchestrator/restore_workflow_decision_test.go @@ -27,6 +27,52 @@ func stubPreparedRestoreBundle(archivePath string, manifest *backup.Manifest) fu } } +func TestRunRestoreWorkflow_CleansPreparedBundleWhenPlanningFails(t *testing.T) { + origRestoreSystem := restoreSystem + origPrepare := prepareRestoreBundleFunc + origAnalyze := analyzeRestoreArchiveFunc + t.Cleanup(func() { + restoreSystem = origRestoreSystem + prepareRestoreBundleFunc = origPrepare + analyzeRestoreArchiveFunc = origAnalyze + }) + + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + cleanupCalls := 0 + prepareRestoreBundleFunc = func(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*backupCandidate, *preparedBundle, error) { + return &backupCandidate{ + DisplayBase: "test", + Manifest: &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ProxmoxType: "pve", + ScriptVersion: "vtest", + }, + }, &preparedBundle{ + ArchivePath: "/bundle.tar", + Manifest: backup.Manifest{ArchivePath: "/bundle.tar"}, + cleanup: func() { + cleanupCalls++ + }, + }, nil + } + analyzeRestoreArchiveFunc = func(archivePath string, logger *logging.Logger) ([]Category, *RestoreDecisionInfo, error) { + return nil, &RestoreDecisionInfo{BackupType: SystemTypePVE}, nil + } + + wantErr := errors.New("select mode failed") + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{modeErr: wantErr} + + err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui) + if !errors.Is(err, wantErr) { + t.Fatalf("err=%v; want %v", err, wantErr) + } + if cleanupCalls != 1 { + t.Fatalf("cleanupCalls=%d; want 1", cleanupCalls) + } +} + func TestRunRestoreWorkflow_ClusterPromptUsesArchivePayloadNotManifest(t *testing.T) { origRestoreFS := restoreFS origRestoreCmd := restoreCmd diff --git a/internal/orchestrator/restore_workflow_test.go b/internal/orchestrator/restore_workflow_test.go index e22bbbf1..fe84b4ba 100644 --- a/internal/orchestrator/restore_workflow_test.go +++ b/internal/orchestrator/restore_workflow_test.go @@ -27,10 +27,8 @@ func writeMinimalTar(t *testing.T, dir string) string { if err != nil { t.Fatalf("create tar: %v", err) } - defer f.Close() tw := tar.NewWriter(f) - defer tw.Close() body := []byte("hello\n") hdr := &tar.Header{ @@ -49,6 +47,12 @@ func writeMinimalTar(t *testing.T, dir string) string { if err := tw.Flush(); err != nil { t.Fatalf("flush tar: %v", err) } + if err := tw.Close(); err != nil { + t.Fatalf("close tar writer: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("close tar file: %v", err) + } return path } diff --git a/internal/orchestrator/restore_workflow_ui.go b/internal/orchestrator/restore_workflow_ui.go index 0fd91a27..97a2b20c 100644 --- a/internal/orchestrator/restore_workflow_ui.go +++ b/internal/orchestrator/restore_workflow_ui.go @@ -6,11 +6,7 @@ import ( "errors" "fmt" "io" - "os" - "os/exec" - "path/filepath" "strings" - "time" "github.com/tis24dev/proxsave/internal/backup" "github.com/tis24dev/proxsave/internal/config" @@ -60,1248 +56,30 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l done := logging.DebugStart(logger, "restore workflow (ui)", "version=%s", version) defer func() { done(err) }() + defer func() { err = normalizeRestoreWorkflowUIError(ctx, logger, err) }() - restoreHadWarnings := false - defer func() { - if err == nil { - return - } - if err == io.EOF { - logger.Warning("Restore input closed unexpectedly (EOF). This usually means the interactive UI lost access to stdin/TTY (e.g., SSH disconnect or non-interactive execution). Re-run with --restore --cli from an interactive shell.") - err = ErrRestoreAborted - return - } - if errors.Is(err, input.ErrInputAborted) || - errors.Is(err, ErrDecryptAborted) || - errors.Is(err, ErrAgeRecipientSetupAborted) || - errors.Is(err, context.Canceled) || - (ctx != nil && ctx.Err() != nil) { - err = ErrRestoreAborted - } - }() - - candidate, prepared, err := prepareRestoreBundleFunc(ctx, cfg, logger, version, ui) - if err != nil { - return err - } - defer prepared.Cleanup() - - destRoot := "/" - logger.Info("Restore target: system root (/) — files will be written back to their original paths") - - systemType := restoreSystem.DetectCurrentSystem() - logger.Info("Detected system type: %s", GetSystemTypeString(systemType)) - - availableCategories, decisionInfo, err := analyzeRestoreArchiveFunc(prepared.ArchivePath, logger) - fallbackToFullRestore := false - if err != nil { - logger.Warning("Could not analyze categories: %v", err) - availableCategories = nil - decisionInfo = fallbackRestoreDecisionInfoFromManifest(candidate.Manifest) - fallbackToFullRestore = true - } - if decisionInfo == nil { - decisionInfo = &RestoreDecisionInfo{} - } - - if warn := ValidateCompatibility(systemType, decisionInfo.BackupType); warn != nil { - logger.Warning("Compatibility check: %v", warn) - proceed, perr := ui.ConfirmCompatibility(ctx, warn) - if perr != nil { - return perr - } - if !proceed { - return ErrRestoreAborted - } - } - if fallbackToFullRestore { - logger.Info("Falling back to full restore mode") - return runFullRestoreWithUI(ctx, ui, candidate, prepared, destRoot, logger, cfg.DryRun) - } - - var ( - mode RestoreMode - selectedCategories []Category - ) - for { - mode, err = ui.SelectRestoreMode(ctx, systemType) - if err != nil { - return err - } - - if mode != RestoreModeCustom { - selectedCategories = GetCategoriesForMode(mode, systemType, availableCategories) - break - } - - selectedCategories, err = ui.SelectCategories(ctx, availableCategories, systemType) - if err != nil { - if errors.Is(err, errRestoreBackToMode) { - continue - } - return err - } - break - } - - if mode == RestoreModeCustom { - selectedCategories, err = maybeAddRecommendedCategoriesForTFA(ctx, ui, logger, selectedCategories, availableCategories) - if err != nil { - return err - } - } - - plan := PlanRestore(decisionInfo.ClusterPayload, selectedCategories, systemType, mode) - - if plan.SystemType.SupportsPBS() && - (plan.HasCategoryID("pbs_host") || - plan.HasCategoryID("datastore_pbs") || - plan.HasCategoryID("pbs_remotes") || - plan.HasCategoryID("pbs_jobs") || - plan.HasCategoryID("pbs_notifications") || - plan.HasCategoryID("pbs_access_control") || - plan.HasCategoryID("pbs_tape")) { - behavior, err := ui.SelectPBSRestoreBehavior(ctx) - if err != nil { - return err - } - plan.PBSRestoreBehavior = behavior - logger.Info("PBS restore behavior: %s", behavior.DisplayName()) - } - - if plan.NeedsClusterRestore && plan.ClusterBackup { - logger.Info("Cluster payload detected in backup; enabling guarded restore options for pve_cluster") - choice, promptErr := ui.SelectClusterRestoreMode(ctx) - if promptErr != nil { - return promptErr - } - switch choice { - case ClusterRestoreAbort: - return ErrRestoreAborted - case ClusterRestoreSafe: - plan.ApplyClusterSafeMode(true) - logger.Info("Selected SAFE cluster restore: /var/lib/pve-cluster will be exported only, not written to system") - case ClusterRestoreRecovery: - plan.ApplyClusterSafeMode(false) - logger.Warning("Selected RECOVERY cluster restore: full cluster database will be restored; ensure other nodes are isolated") - default: - return fmt.Errorf("invalid cluster restore mode selected") - } - } - - if plan.HasCategoryID("pve_access_control") || plan.HasCategoryID("pbs_access_control") { - currentHost, hostErr := os.Hostname() - if hostErr == nil && strings.TrimSpace(decisionInfo.BackupHostname) != "" && strings.TrimSpace(currentHost) != "" { - backupHost := strings.TrimSpace(decisionInfo.BackupHostname) - if !strings.EqualFold(strings.TrimSpace(currentHost), backupHost) { - logger.Warning("Access control/TFA: backup hostname=%s current hostname=%s; WebAuthn users may require re-enrollment if the UI origin (FQDN/port) changes", backupHost, currentHost) - } - } - } - - if destRoot != "/" || !isRealRestoreFS(restoreFS) { - if len(plan.StagedCategories) > 0 { - logging.DebugStep(logger, "restore", "Staging disabled (destRoot=%s realFS=%v): extracting %d staged category(ies) directly", destRoot, isRealRestoreFS(restoreFS), len(plan.StagedCategories)) - plan.NormalCategories = append(plan.NormalCategories, plan.StagedCategories...) - plan.StagedCategories = nil - } - } - - restoreConfig := &SelectiveRestoreConfig{ - Mode: mode, - SystemType: systemType, - Metadata: candidate.Manifest, - } - restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, plan.NormalCategories...) - restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, plan.StagedCategories...) - restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, plan.ExportCategories...) - - if err := ui.ShowRestorePlan(ctx, restoreConfig); err != nil { - return err - } - - confirmed, err := ui.ConfirmRestore(ctx) - if err != nil { - return err - } - if !confirmed { - logger.Info("Restore operation cancelled by user") - return ErrRestoreAborted - } - - var safetyBackup *SafetyBackupResult - var networkRollbackBackup *SafetyBackupResult - var firewallRollbackBackup *SafetyBackupResult - var haRollbackBackup *SafetyBackupResult - var accessControlRollbackBackup *SafetyBackupResult - systemWriteCategories := append([]Category{}, plan.NormalCategories...) - systemWriteCategories = append(systemWriteCategories, plan.StagedCategories...) - if len(systemWriteCategories) > 0 { - logger.Info("") - safetyBackup, err = CreateSafetyBackup(logger, systemWriteCategories, destRoot) - if err != nil { - logger.Warning("Failed to create safety backup: %v", err) - cont, perr := ui.ConfirmContinueWithoutSafetyBackup(ctx, err) - if perr != nil { - return perr - } - if !cont { - return ErrRestoreAborted - } - } else { - logger.Info("Safety backup location: %s", safetyBackup.BackupPath) - logger.Info("You can restore from this backup if needed using: tar -xzf %s -C /", safetyBackup.BackupPath) - } - } - - if plan.HasCategoryID("network") { - logger.Info("") - logging.DebugStep(logger, "restore", "Create network-only rollback backup for transactional network apply") - networkRollbackBackup, err = CreateNetworkRollbackBackup(logger, systemWriteCategories, destRoot) - if err != nil { - logger.Warning("Failed to create network rollback backup: %v", err) - } else if networkRollbackBackup != nil && strings.TrimSpace(networkRollbackBackup.BackupPath) != "" { - logger.Info("Network rollback backup location: %s", networkRollbackBackup.BackupPath) - logger.Info("This backup is used for the %ds network rollback timer and only includes network paths.", int(defaultNetworkRollbackTimeout.Seconds())) - } - } - if plan.HasCategoryID("pve_firewall") { - logger.Info("") - logging.DebugStep(logger, "restore", "Create firewall-only rollback backup for transactional firewall apply") - firewallRollbackBackup, err = CreateFirewallRollbackBackup(logger, systemWriteCategories, destRoot) - if err != nil { - logger.Warning("Failed to create firewall rollback backup: %v", err) - } else if firewallRollbackBackup != nil && strings.TrimSpace(firewallRollbackBackup.BackupPath) != "" { - logger.Info("Firewall rollback backup location: %s", firewallRollbackBackup.BackupPath) - logger.Info("This backup is used for the %ds firewall rollback timer and only includes firewall paths.", int(defaultFirewallRollbackTimeout.Seconds())) - } - } - if plan.HasCategoryID("pve_ha") { - logger.Info("") - logging.DebugStep(logger, "restore", "Create HA-only rollback backup for transactional HA apply") - haRollbackBackup, err = CreateHARollbackBackup(logger, systemWriteCategories, destRoot) - if err != nil { - logger.Warning("Failed to create HA rollback backup: %v", err) - } else if haRollbackBackup != nil && strings.TrimSpace(haRollbackBackup.BackupPath) != "" { - logger.Info("HA rollback backup location: %s", haRollbackBackup.BackupPath) - logger.Info("This backup is used for the %ds HA rollback timer and only includes HA paths.", int(defaultHARollbackTimeout.Seconds())) - } - } - if plan.SystemType.SupportsPVE() && plan.ClusterBackup && !plan.NeedsClusterRestore && plan.HasCategoryID("pve_access_control") { - logger.Info("") - logging.DebugStep(logger, "restore", "Create access-control-only rollback backup for optional cluster-safe access control apply") - accessControlRollbackBackup, err = CreatePVEAccessControlRollbackBackup(logger, systemWriteCategories, destRoot) - if err != nil { - logger.Warning("Failed to create access control rollback backup: %v", err) - } else if accessControlRollbackBackup != nil && strings.TrimSpace(accessControlRollbackBackup.BackupPath) != "" { - logger.Info("Access control rollback backup location: %s", accessControlRollbackBackup.BackupPath) - logger.Info("This backup is used for the %ds access control rollback timer and only includes access control paths.", int(defaultAccessControlRollbackTimeout.Seconds())) - } - } - - stageLogPath := "" - stageRoot := "" - - needsClusterRestore := plan.NeedsClusterRestore - clusterServicesStopped := false - pbsServicesStopped := false - needsPBSServices := plan.NeedsPBSServices - - if needsClusterRestore { - logger.Info("") - logger.Info("Preparing system for cluster database restore: stopping PVE services and unmounting /etc/pve") - if err := stopPVEClusterServices(ctx, logger); err != nil { - return err - } - clusterServicesStopped = true - defer func() { - restartCtx, cancel := context.WithTimeout(context.Background(), 2*serviceStartTimeout+2*serviceVerifyTimeout+10*time.Second) - defer cancel() - if err := startPVEClusterServices(restartCtx, logger); err != nil { - logger.Warning("Failed to restart PVE services after restore: %v", err) - } - }() - - if err := unmountEtcPVE(ctx, logger); err != nil { - logger.Warning("Could not unmount /etc/pve: %v", err) - } - } - - if needsPBSServices { - logger.Info("") - logger.Info("Preparing PBS system for restore: stopping proxmox-backup services") - if err := stopPBSServices(ctx, logger); err != nil { - logger.Warning("Unable to stop PBS services automatically: %v", err) - cont, perr := ui.ConfirmContinueWithPBSServicesRunning(ctx) - if perr != nil { - return perr - } - if !cont { - return ErrRestoreAborted - } - logger.Warning("Continuing restore with PBS services still running") - } else { - pbsServicesStopped = true - defer func() { - restartCtx, cancel := context.WithTimeout(context.Background(), 2*serviceStartTimeout+2*serviceVerifyTimeout+10*time.Second) - defer cancel() - if err := startPBSServices(restartCtx, logger); err != nil { - logger.Warning("Failed to restart PBS services after restore: %v", err) - return - } - if err := maybeVerifyAndRepairPBSNotificationsAfterRestore(restartCtx, logger, plan, stageRoot, cfg.DryRun); err != nil { - logger.Warning("PBS notifications verification/repair: %v", err) - } - }() - } - } - - var detailedLogPath string - - needsFilesystemRestore := false - if plan.HasCategoryID("filesystem") { - needsFilesystemRestore = true - var filtered []Category - for _, cat := range plan.NormalCategories { - if cat.ID != "filesystem" { - filtered = append(filtered, cat) - } - } - plan.NormalCategories = filtered - logging.DebugStep(logger, "restore", "Filesystem category intercepted: enabling Smart Merge workflow (skipping generic extraction)") - } - - if len(plan.NormalCategories) > 0 { - logger.Info("") - categoriesForExtraction := plan.NormalCategories - if needsClusterRestore { - logging.DebugStep(logger, "restore", "Cluster RECOVERY shadow-guard: sanitize categories to avoid /etc/pve shadow writes") - sanitized, removed := sanitizeCategoriesForClusterRecovery(categoriesForExtraction) - removedPaths := 0 - for _, paths := range removed { - removedPaths += len(paths) - } - logging.DebugStep( - logger, - "restore", - "Cluster RECOVERY shadow-guard: categories_before=%d categories_after=%d removed_categories=%d removed_paths=%d", - len(categoriesForExtraction), - len(sanitized), - len(removed), - removedPaths, - ) - if len(removed) > 0 { - logger.Warning("Cluster RECOVERY restore: skipping direct restore of /etc/pve paths to prevent shadowing while pmxcfs is stopped/unmounted") - for _, cat := range categoriesForExtraction { - if paths, ok := removed[cat.ID]; ok && len(paths) > 0 { - logger.Warning(" - %s (%s): %s", cat.Name, cat.ID, strings.Join(paths, ", ")) - } - } - logger.Info("These paths are expected to be restored from config.db and become visible after /etc/pve is remounted.") - } else { - logging.DebugStep(logger, "restore", "Cluster RECOVERY shadow-guard: no /etc/pve paths detected in selected categories") - } - categoriesForExtraction = sanitized - } - - if len(categoriesForExtraction) == 0 { - logging.DebugStep(logger, "restore", "Skip system-path extraction: no categories remain after shadow-guard") - logger.Info("No system-path categories remain after cluster shadow-guard; skipping system-path extraction.") - } else { - detailedLogPath, err = extractSelectiveArchive(ctx, prepared.ArchivePath, destRoot, categoriesForExtraction, mode, logger) - if err != nil { - logger.Error("Restore failed: %v", err) - if safetyBackup != nil { - logger.Info("You can rollback using the safety backup at: %s", safetyBackup.BackupPath) - } - return err - } - } - } else { - logger.Info("") - logger.Info("No system-path categories selected for restore (only export categories will be processed).") - } - - // Mount-first: restore /etc/fstab (Smart Merge) before applying PBS datastore configs. - if needsFilesystemRestore { - logger.Info("") - fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-") - if err != nil { - restoreHadWarnings = true - logger.Warning("Failed to create temp dir for fstab merge: %v", err) - } else { - defer restoreFS.RemoveAll(fsTempDir) - fsCat := GetCategoryByID("filesystem", availableCategories) - if fsCat == nil { - logger.Warning("Filesystem category not available in analyzed backup contents; skipping fstab merge") - } else { - fsCategory := []Category{*fsCat} - if _, err := extractSelectiveArchive(ctx, prepared.ArchivePath, fsTempDir, fsCategory, RestoreModeCustom, logger); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("Failed to extract filesystem config for merge: %v", err) - } else { - // Best-effort: extract ProxSave inventory files used for stable fstab device remapping. - // (e.g., blkid/lsblk JSON from var/lib/proxsave-info). - invCategory := []Category{{ - ID: "fstab_inventory", - Name: "Fstab inventory (device mapping)", - Paths: []string{ - "./var/lib/proxsave-info/commands/system/blkid.txt", - "./var/lib/proxsave-info/commands/system/lsblk_json.json", - "./var/lib/proxsave-info/commands/system/lsblk.txt", - "./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", - }, - }} - if err := extractArchiveNative(ctx, restoreArchiveOptions{ - archivePath: prepared.ArchivePath, - destRoot: fsTempDir, - logger: logger, - categories: invCategory, - mode: RestoreModeCustom, - }); err != nil { - logger.Debug("Failed to extract fstab inventory data (continuing): %v", err) - } - - currentFstab := filepath.Join(destRoot, "etc", "fstab") - backupFstab := filepath.Join(fsTempDir, "etc", "fstab") - if err := smartMergeFstabWithUI(ctx, logger, ui, currentFstab, backupFstab, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.") - return err - } - restoreHadWarnings = true - logger.Warning("Smart Fstab Merge failed: %v", err) - } - } - } - } - } - - exportLogPath := "" - exportRoot := "" - if len(plan.ExportCategories) > 0 { - exportRoot = exportDestRoot(cfg.BaseDir) - logger.Info("") - logger.Info("Exporting %d export-only category(ies) to: %s", len(plan.ExportCategories), exportRoot) - if err := restoreFS.MkdirAll(exportRoot, 0o755); err != nil { - return fmt.Errorf("failed to create export directory %s: %w", exportRoot, err) - } - - if exportLog, exErr := extractSelectiveArchive(ctx, prepared.ArchivePath, exportRoot, plan.ExportCategories, RestoreModeCustom, logger); exErr != nil { - if errors.Is(exErr, ErrRestoreAborted) || input.IsAborted(exErr) { - return exErr - } - restoreHadWarnings = true - logger.Warning("Export completed with errors: %v", exErr) - } else { - exportLogPath = exportLog - } - } - - if plan.ClusterSafeMode { - if exportRoot == "" { - logger.Warning("Cluster SAFE mode selected but export directory not available; skipping automatic pvesh apply") - } else { - // Best-effort: extract extra SAFE apply inventory (pools/mappings) used by pvesh apply workflows. - // This keeps SAFE apply usable even when the user did not explicitly export proxsave_info or /etc/pve. - safeInvCategory := []Category{{ - ID: "safe_apply_inventory", - Name: "SAFE apply inventory (pools/mappings)", - Paths: []string{ - "./etc/pve/user.cfg", - "./var/lib/proxsave-info/commands/pve/mapping_pci.json", - "./var/lib/proxsave-info/commands/pve/mapping_usb.json", - "./var/lib/proxsave-info/commands/pve/mapping_dir.json", - }, - }} - if err := extractArchiveNative(ctx, restoreArchiveOptions{ - archivePath: prepared.ArchivePath, - destRoot: exportRoot, - logger: logger, - categories: safeInvCategory, - mode: RestoreModeCustom, - }); err != nil { - logger.Debug("Failed to extract SAFE apply inventory (continuing): %v", err) - } - - if err := runSafeClusterApplyWithUI(ctx, ui, exportRoot, logger, plan); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("Cluster SAFE apply completed with errors: %v", err) - } - } - } - - stageLogPath = "" - stageRoot = "" - if len(plan.StagedCategories) > 0 { - stageRoot = stageDestRoot() - logger.Info("") - logger.Info("Staging %d sensitive category(ies) to: %s", len(plan.StagedCategories), stageRoot) - if err := restoreFS.MkdirAll(stageRoot, 0o755); err != nil { - return fmt.Errorf("failed to create staging directory %s: %w", stageRoot, err) - } - - if stageLog, err := extractSelectiveArchive(ctx, prepared.ArchivePath, stageRoot, plan.StagedCategories, RestoreModeCustom, logger); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("Staging completed with errors: %v", err) - } else { - stageLogPath = stageLog - } - - if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, destRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("PBS mount guard: %v", err) - } - - logger.Info("") - if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("PBS staged config apply: %v", err) - } - if err := maybeApplyPVEConfigsFromStage(ctx, logger, plan, stageRoot, destRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("PVE staged config apply: %v", err) - } - if err := maybeApplyPVESDNFromStage(ctx, logger, plan, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("PVE SDN staged apply: %v", err) - } - if err := maybeApplyAccessControlWithUI(ctx, ui, logger, plan, safetyBackup, accessControlRollbackBackup, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - if errors.Is(err, ErrAccessControlApplyNotCommitted) { - var notCommitted *AccessControlApplyNotCommittedError - rollbackLog := "" - rollbackArmed := false - deadline := time.Time{} - if errors.As(err, ¬Committed) && notCommitted != nil { - rollbackLog = strings.TrimSpace(notCommitted.RollbackLog) - rollbackArmed = notCommitted.RollbackArmed - deadline = notCommitted.RollbackDeadline - } - if rollbackArmed { - logger.Warning("Access control apply not committed; rollback is ARMED and will run automatically.") - } else { - logger.Warning("Access control apply not committed; rollback has executed (or marker cleared).") - } - if !deadline.IsZero() { - logger.Info("Rollback deadline: %s", deadline.Format(time.RFC3339)) - } - if rollbackLog != "" { - logger.Info("Rollback log: %s", rollbackLog) - } - } else { - logger.Warning("Access control staged apply: %v", err) - } - } - if err := maybeApplyNotificationsFromStage(ctx, logger, plan, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("Notifications staged apply: %v", err) - } - } - - if plan.SystemType.SupportsPBS() && plan.HasCategoryID("pbs_notifications") && !pbsServicesStopped { - if err := maybeVerifyAndRepairPBSNotificationsAfterRestore(ctx, logger, plan, stageRoot, cfg.DryRun); err != nil { - restoreHadWarnings = true - logger.Warning("PBS notifications verification/repair: %v", err) - } - } - - stageRootForNetworkApply := stageRoot - if installed, err := maybeInstallNetworkConfigFromStage(ctx, logger, plan, stageRoot, prepared.ArchivePath, networkRollbackBackup, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("Network staged install: %v", err) - } else if installed { - stageRootForNetworkApply = "" - logging.DebugStep(logger, "restore", "Network staged install completed: configuration written to /etc (no reload); live apply will use system paths") - } - - logger.Info("") - categoriesForDirRecreate := append([]Category{}, plan.NormalCategories...) - categoriesForDirRecreate = append(categoriesForDirRecreate, plan.StagedCategories...) - if shouldRecreateDirectories(systemType, categoriesForDirRecreate) { - if err := RecreateDirectoriesFromConfig(systemType, logger); err != nil { - restoreHadWarnings = true - logger.Warning("Failed to recreate directory structures: %v", err) - logger.Warning("You may need to manually create storage/datastore directories") - } - } else { - logger.Debug("Skipping datastore/storage directory recreation (category not selected)") - } - - logger.Info("") - if plan.HasCategoryID("network") { - logger.Info("") - if err := maybeRepairResolvConfAfterRestore(ctx, logger, prepared.ArchivePath, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - return err - } - restoreHadWarnings = true - logger.Warning("DNS resolver repair: %v", err) - } - } - - logger.Info("") - if err := maybeApplyNetworkConfigWithUI(ctx, ui, logger, plan, safetyBackup, networkRollbackBackup, stageRootForNetworkApply, prepared.ArchivePath, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - logger.Info("Restore aborted by user during network apply prompt.") - return err - } - restoreHadWarnings = true - if errors.Is(err, ErrNetworkApplyNotCommitted) { - var notCommitted *NetworkApplyNotCommittedError - observedIP := "unknown" - originalIP := "unknown" - reconnectHost := "" - rollbackLog := "" - rollbackArmed := false - if errors.As(err, ¬Committed) && notCommitted != nil { - if strings.TrimSpace(notCommitted.RestoredIP) != "" { - observedIP = strings.TrimSpace(notCommitted.RestoredIP) - } - if strings.TrimSpace(notCommitted.OriginalIP) != "" { - originalIP = strings.TrimSpace(notCommitted.OriginalIP) - reconnectHost = originalIP - if i := strings.Index(reconnectHost, ","); i >= 0 { - reconnectHost = reconnectHost[:i] - } - if i := strings.Index(reconnectHost, "/"); i >= 0 { - reconnectHost = reconnectHost[:i] - } - reconnectHost = strings.TrimSpace(reconnectHost) - } - rollbackLog = strings.TrimSpace(notCommitted.RollbackLog) - rollbackArmed = notCommitted.RollbackArmed - lastRestoreAbortInfo = &RestoreAbortInfo{ - NetworkRollbackArmed: rollbackArmed, - NetworkRollbackLog: rollbackLog, - NetworkRollbackMarker: strings.TrimSpace(notCommitted.RollbackMarker), - OriginalIP: notCommitted.OriginalIP, - CurrentIP: observedIP, - RollbackDeadline: notCommitted.RollbackDeadline, - } - } - if rollbackArmed { - logger.Warning("Network apply not committed; rollback is ARMED and will run automatically.") - } else { - logger.Warning("Network apply not committed; rollback has executed (or marker cleared).") - } - if reconnectHost != "" && reconnectHost != "unknown" && originalIP != "unknown" { - logger.Warning("IP now (after apply): %s. Expected after rollback: %s. Reconnect using: %s", observedIP, originalIP, reconnectHost) - } else if originalIP != "unknown" { - logger.Warning("IP now (after apply): %s. Expected after rollback: %s", observedIP, originalIP) - } else { - logger.Warning("IP now (after apply): %s", observedIP) - } - if rollbackLog != "" { - logger.Info("Rollback log: %s", rollbackLog) - } - } else { - logger.Warning("Network apply step skipped or failed: %v", err) - } - } - - logger.Info("") - if err := maybeApplyPVEFirewallWithUI(ctx, ui, logger, plan, safetyBackup, firewallRollbackBackup, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - logger.Info("Restore aborted by user during firewall apply prompt.") - return err - } - restoreHadWarnings = true - if errors.Is(err, ErrFirewallApplyNotCommitted) { - var notCommitted *FirewallApplyNotCommittedError - rollbackLog := "" - rollbackArmed := false - deadline := time.Time{} - if errors.As(err, ¬Committed) && notCommitted != nil { - rollbackLog = strings.TrimSpace(notCommitted.RollbackLog) - rollbackArmed = notCommitted.RollbackArmed - deadline = notCommitted.RollbackDeadline - } - if rollbackArmed { - logger.Warning("Firewall apply not committed; rollback is ARMED and will run automatically.") - } else { - logger.Warning("Firewall apply not committed; rollback has executed (or marker cleared).") - } - if !deadline.IsZero() { - logger.Info("Rollback deadline: %s", deadline.Format(time.RFC3339)) - } - if rollbackLog != "" { - logger.Info("Rollback log: %s", rollbackLog) - } - } else { - logger.Warning("Firewall apply step skipped or failed: %v", err) - } - } - - logger.Info("") - if err := maybeApplyPVEHAWithUI(ctx, ui, logger, plan, safetyBackup, haRollbackBackup, stageRoot, cfg.DryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - logger.Info("Restore aborted by user during HA apply prompt.") - return err - } - restoreHadWarnings = true - if errors.Is(err, ErrHAApplyNotCommitted) { - var notCommitted *HAApplyNotCommittedError - rollbackLog := "" - rollbackArmed := false - deadline := time.Time{} - if errors.As(err, ¬Committed) && notCommitted != nil { - rollbackLog = strings.TrimSpace(notCommitted.RollbackLog) - rollbackArmed = notCommitted.RollbackArmed - deadline = notCommitted.RollbackDeadline - } - if rollbackArmed { - logger.Warning("HA apply not committed; rollback is ARMED and will run automatically.") - } else { - logger.Warning("HA apply not committed; rollback has executed (or marker cleared).") - } - if !deadline.IsZero() { - logger.Info("Rollback deadline: %s", deadline.Format(time.RFC3339)) - } - if rollbackLog != "" { - logger.Info("Rollback log: %s", rollbackLog) - } - } else { - logger.Warning("HA apply step skipped or failed: %v", err) - } - } - - logger.Info("") - if restoreHadWarnings { - logger.Warning("Restore completed with warnings.") - } else { - logger.Info("Restore completed successfully.") - } - logger.Info("Temporary decrypted bundle removed.") - - if detailedLogPath != "" { - logger.Info("Detailed restore log: %s", detailedLogPath) - } - if exportRoot != "" { - logger.Info("Export directory: %s", exportRoot) - } - if exportLogPath != "" { - logger.Info("Export detailed log: %s", exportLogPath) - } - if stageRoot != "" { - logger.Info("Staging directory: %s", stageRoot) - } - if stageLogPath != "" { - logger.Info("Staging detailed log: %s", stageLogPath) - } - - if safetyBackup != nil { - logger.Info("Safety backup preserved at: %s", safetyBackup.BackupPath) - logger.Info("Remove it manually if restore was successful: rm %s", safetyBackup.BackupPath) - } - - logger.Info("") - logger.Info("IMPORTANT: You may need to restart services for changes to take effect.") - switch systemType { - case SystemTypeDual: - if needsClusterRestore && clusterServicesStopped { - logger.Info(" PVE services were stopped/restarted during restore; verify status with: pvecm status") - } else { - logger.Info(" PVE services: systemctl restart pve-cluster pvedaemon pveproxy") - } - if pbsServicesStopped { - logger.Info(" PBS services were stopped/restarted during restore; verify status with: systemctl status proxmox-backup proxmox-backup-proxy") - } else { - logger.Info(" PBS services: systemctl restart proxmox-backup-proxy proxmox-backup") - } - case SystemTypePVE: - if needsClusterRestore && clusterServicesStopped { - logger.Info(" PVE services were stopped/restarted during restore; verify status with: pvecm status") - } else { - logger.Info(" PVE services: systemctl restart pve-cluster pvedaemon pveproxy") - } - case SystemTypePBS: - if pbsServicesStopped { - logger.Info(" PBS services were stopped/restarted during restore; verify status with: systemctl status proxmox-backup proxmox-backup-proxy") - } else { - logger.Info(" PBS services: systemctl restart proxmox-backup-proxy proxmox-backup") - } - } - - if hasCategoryID(plan.NormalCategories, "zfs") { - logger.Info("") - if err := checkZFSPoolsAfterRestore(ctx, logger); err != nil { - logger.Warning("ZFS pool check: %v", err) - } - } else { - logger.Debug("Skipping ZFS pool verification (ZFS category not selected)") - } - - logger.Info("") - logger.Warning("⚠ SYSTEM REBOOT RECOMMENDED") - logger.Info("Reboot the node (or at least restart networking and system services) to ensure all restored configurations take effect cleanly.") - - return nil -} - -func maybeAddRecommendedCategoriesForTFA(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, selected []Category, available []Category) ([]Category, error) { - if ui == nil || logger == nil { - return selected, nil - } - if !hasCategoryID(selected, "pve_access_control") && !hasCategoryID(selected, "pbs_access_control") { - return selected, nil - } - - var missing []string - if !hasCategoryID(selected, "network") { - missing = append(missing, "network") - } - if !hasCategoryID(selected, "ssl") { - missing = append(missing, "ssl") - } - if len(missing) == 0 { - return selected, nil - } - - var addCategories []Category - var addNames []string - for _, id := range missing { - cat := GetCategoryByID(id, available) - if cat == nil || !cat.IsAvailable || cat.ExportOnly { - continue - } - addCategories = append(addCategories, *cat) - addNames = append(addNames, cat.Name) - } - if len(addCategories) == 0 { - return selected, nil - } - - message := fmt.Sprintf( - "You selected Access Control without restoring: %s\n\n"+ - "If TFA includes WebAuthn/FIDO2, changing the UI origin (FQDN/hostname or port) may require re-enrollment.\n\n"+ - "For maximum 1:1 compatibility, ProxSave recommends restoring these categories too.\n\n"+ - "Add recommended categories now?", - strings.Join(addNames, ", "), - ) - addNow, err := ui.ConfirmAction(ctx, "TFA/WebAuthn compatibility", message, "Add recommended", "Keep current", 0, true) - if err != nil { - return nil, err - } - if !addNow { - logger.Warning("Access control selected without %s; WebAuthn users may require re-enrollment if the UI origin changes", strings.Join(addNames, ", ")) - return selected, nil - } - - selected = append(selected, addCategories...) - return dedupeCategoriesByID(selected), nil -} - -func dedupeCategoriesByID(categories []Category) []Category { - if len(categories) == 0 { - return categories - } - seen := make(map[string]struct{}, len(categories)) - out := make([]Category, 0, len(categories)) - for _, cat := range categories { - id := strings.TrimSpace(cat.ID) - if id == "" { - out = append(out, cat) - continue - } - if _, ok := seen[id]; ok { - continue - } - seen[id] = struct{}{} - out = append(out, cat) - } - return out -} - -func runFullRestoreWithUI(ctx context.Context, ui RestoreWorkflowUI, candidate *backupCandidate, prepared *preparedBundle, destRoot string, logger *logging.Logger, dryRun bool) error { - if candidate == nil || prepared == nil || prepared.Manifest.ArchivePath == "" { - return fmt.Errorf("invalid restore candidate") - } - - if err := ui.ShowMessage(ctx, "Full restore", "Backup category analysis failed; ProxSave will run a full restore (no selective modes)."); err != nil { - return err - } - - confirmed, err := ui.ConfirmRestore(ctx) - if err != nil { - return err - } - if !confirmed { - return ErrRestoreAborted - } - - safeFstabMerge := destRoot == "/" && isRealRestoreFS(restoreFS) - skipFn := func(name string) bool { - if !safeFstabMerge { - return false - } - clean := strings.TrimPrefix(strings.TrimSpace(name), "./") - clean = strings.TrimPrefix(clean, "/") - return clean == "etc/fstab" - } - - if safeFstabMerge { - logger.Warning("Full restore safety: /etc/fstab will not be overwritten; Smart Merge will be applied after extraction.") - } - - if err := extractPlainArchive(ctx, prepared.ArchivePath, destRoot, logger, skipFn); err != nil { - return err - } - - if safeFstabMerge { - logger.Info("") - fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-") - if err != nil { - logger.Warning("Failed to create temp dir for fstab merge: %v", err) - } else { - defer restoreFS.RemoveAll(fsTempDir) - fsCategory := []Category{{ - ID: "filesystem", - Name: "Filesystem Configuration", - Paths: []string{ - "./etc/fstab", - }, - }} - if err := extractArchiveNative(ctx, restoreArchiveOptions{ - archivePath: prepared.ArchivePath, - destRoot: fsTempDir, - logger: logger, - categories: fsCategory, - mode: RestoreModeCustom, - }); err != nil { - logger.Warning("Failed to extract filesystem config for merge: %v", err) - } else { - currentFstab := filepath.Join(destRoot, "etc", "fstab") - backupFstab := filepath.Join(fsTempDir, "etc", "fstab") - if err := smartMergeFstabWithUI(ctx, logger, ui, currentFstab, backupFstab, dryRun); err != nil { - if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { - logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.") - return err - } - logger.Warning("Smart Fstab Merge failed: %v", err) - } - } - } - } - - logger.Info("Restore completed successfully.") - return nil + workflow := newRestoreUIWorkflowRun(ctx, cfg, logger, version, ui) + return workflow.run() } -func runSafeClusterApplyWithUI(ctx context.Context, ui RestoreWorkflowUI, exportRoot string, logger *logging.Logger, plan *RestorePlan) (err error) { - done := logging.DebugStart(logger, "safe cluster apply (ui)", "export_root=%s", exportRoot) - defer func() { done(err) }() - - if err := ctx.Err(); err != nil { - return err - } - - if ui == nil { - return fmt.Errorf("restore UI not available") - } - - pveshPath, lookErr := exec.LookPath("pvesh") - if lookErr != nil { - logger.Warning("pvesh not found in PATH; skipping SAFE cluster apply") +func normalizeRestoreWorkflowUIError(ctx context.Context, logger *logging.Logger, err error) error { + if err == nil { return nil } - logging.DebugStep(logger, "safe cluster apply (ui)", "pvesh=%s", pveshPath) - - currentNode, _ := os.Hostname() - currentNode = shortHost(currentNode) - if strings.TrimSpace(currentNode) == "" { - currentNode = "localhost" - } - logging.DebugStep(logger, "safe cluster apply (ui)", "current_node=%s", currentNode) - - logger.Info("") - logger.Info("SAFE cluster restore: applying configs via pvesh (node=%s)", currentNode) - - // Datacenter-wide objects (SAFE apply): - // - resource mappings (used by VM configs via mapping=) - // - resource pools (definitions + membership) - if mapErr := maybeApplyPVEClusterResourceMappingsWithUI(ctx, ui, logger, exportRoot); mapErr != nil { - logger.Warning("SAFE apply: resource mappings: %v", mapErr) - } - - pools, poolsErr := readPVEPoolsFromExportUserCfg(exportRoot) - if poolsErr != nil { - logger.Warning("SAFE apply: failed to parse pools from export: %v", poolsErr) - pools = nil - } - applyPools := false - allowPoolMove := false - if len(pools) > 0 { - poolNames := summarizePoolIDs(pools, 10) - message := fmt.Sprintf("Found %d pool(s) in exported user.cfg.\n\nPools: %s\n\nApply pool definitions now? (Membership will be applied later in this SAFE apply flow.)", len(pools), poolNames) - ok, promptErr := ui.ConfirmAction(ctx, "Apply PVE resource pools (merge)", message, "Apply now", "Skip apply", 0, false) - if promptErr != nil { - return promptErr - } - applyPools = ok - logging.DebugStep(logger, "safe cluster apply (ui)", "User choice: apply_pools=%v (pools=%d)", applyPools, len(pools)) - if applyPools { - if anyPoolHasVMs(pools) { - moveMsg := "Allow moving guests from other pools to match the backup? This may change the current pool assignment of existing VMs/CTs." - move, moveErr := ui.ConfirmAction(ctx, "Pools: allow move (VM/CT)", moveMsg, "Allow move", "Don't move", 0, false) - if moveErr != nil { - return moveErr - } - allowPoolMove = move - } - - applied, failed, applyErr := applyPVEPoolsDefinitions(ctx, logger, pools) - if applyErr != nil { - logger.Warning("Pools apply (definitions) encountered errors: %v", applyErr) - } - logger.Info("Pools apply (definitions) completed: ok=%d failed=%d", applied, failed) - } - } - - sourceNode := currentNode - logging.DebugStep(logger, "safe cluster apply (ui)", "List exported node directories under %s", filepath.Join(exportRoot, "etc/pve/nodes")) - exportNodes, nodesErr := listExportNodeDirs(exportRoot) - if nodesErr != nil { - logger.Warning("Failed to inspect exported node directories: %v", nodesErr) - } else if len(exportNodes) > 0 { - logging.DebugStep(logger, "safe cluster apply (ui)", "export_nodes=%s", strings.Join(exportNodes, ",")) - } else { - logging.DebugStep(logger, "safe cluster apply (ui)", "No exported node directories found") - } - - if len(exportNodes) > 0 && !stringSliceContains(exportNodes, sourceNode) { - logging.DebugStep(logger, "safe cluster apply (ui)", "Node mismatch: current_node=%s export_nodes=%s", currentNode, strings.Join(exportNodes, ",")) - logger.Warning("SAFE cluster restore: VM/CT configs not found for current node %s in export; available nodes: %s", currentNode, strings.Join(exportNodes, ", ")) - if len(exportNodes) == 1 { - sourceNode = exportNodes[0] - logging.DebugStep(logger, "safe cluster apply (ui)", "Auto-select source node: %s", sourceNode) - logger.Info("SAFE cluster restore: using exported node %s as VM/CT source, applying to current node %s", sourceNode, currentNode) - } else { - for _, node := range exportNodes { - qemuCount, lxcCount := countVMConfigsForNode(exportRoot, node) - logging.DebugStep(logger, "safe cluster apply (ui)", "Export node candidate: %s (qemu=%d, lxc=%d)", node, qemuCount, lxcCount) - } - selected, selErr := ui.SelectExportNode(ctx, exportRoot, currentNode, exportNodes) - if selErr != nil { - return selErr - } - if strings.TrimSpace(selected) == "" { - logging.DebugStep(logger, "safe cluster apply (ui)", "User selected: skip VM/CT apply (no source node)") - logger.Info("Skipping VM/CT apply (no source node selected)") - sourceNode = "" - } else { - sourceNode = selected - logging.DebugStep(logger, "safe cluster apply (ui)", "User selected source node: %s", sourceNode) - logger.Info("SAFE cluster restore: selected exported node %s as VM/CT source, applying to current node %s", sourceNode, currentNode) - } - } - } - logging.DebugStep(logger, "safe cluster apply (ui)", "Selected VM/CT source node: %q (current_node=%q)", sourceNode, currentNode) - - var vmEntries []vmEntry - if strings.TrimSpace(sourceNode) != "" { - logging.DebugStep(logger, "safe cluster apply (ui)", "Scan VM/CT configs in export (source_node=%s)", sourceNode) - vmEntries, err = scanVMConfigs(exportRoot, sourceNode) - if err != nil { - logger.Warning("Failed to scan VM configs: %v", err) - vmEntries = nil - } else { - logging.DebugStep(logger, "safe cluster apply (ui)", "VM/CT configs found=%d (source_node=%s)", len(vmEntries), sourceNode) - } - } - - if len(vmEntries) > 0 { - applyVMs, promptErr := ui.ConfirmApplyVMConfigs(ctx, sourceNode, currentNode, len(vmEntries)) - if promptErr != nil { - return promptErr - } - logging.DebugStep(logger, "safe cluster apply (ui)", "User choice: apply_vms=%v (entries=%d)", applyVMs, len(vmEntries)) - if applyVMs { - applied, failed := applyVMConfigs(ctx, vmEntries, logger) - logger.Info("VM/CT apply completed: ok=%d failed=%d", applied, failed) - } else { - logger.Info("Skipping VM/CT apply") - } - } else { - if strings.TrimSpace(sourceNode) == "" { - logger.Info("No VM/CT configs applied (no source node selected)") - } else { - logger.Info("No VM/CT configs found for node %s in export", sourceNode) - } - } - - skipStorageDatacenter := plan != nil && plan.HasCategoryID("storage_pve") - if skipStorageDatacenter { - logging.DebugStep(logger, "safe cluster apply (ui)", "Skip storage/datacenter apply: handled by storage_pve staged restore") - logger.Info("Skipping storage/datacenter apply (handled by storage_pve staged restore)") - } else { - storageCfg := filepath.Join(exportRoot, "etc/pve/storage.cfg") - logging.DebugStep(logger, "safe cluster apply (ui)", "Check export: storage.cfg (%s)", storageCfg) - storageInfo, storageErr := restoreFS.Stat(storageCfg) - if storageErr == nil && !storageInfo.IsDir() { - logging.DebugStep(logger, "safe cluster apply (ui)", "storage.cfg found (size=%d)", storageInfo.Size()) - applyStorage, promptErr := ui.ConfirmApplyStorageCfg(ctx, storageCfg) - if promptErr != nil { - return promptErr - } - logging.DebugStep(logger, "safe cluster apply (ui)", "User choice: apply_storage=%v", applyStorage) - if applyStorage { - applied, failed, err := applyStorageCfg(ctx, storageCfg, logger) - logging.DebugStep(logger, "safe cluster apply (ui)", "Storage apply result: ok=%d failed=%d err=%v", applied, failed, err) - if err != nil { - logger.Warning("Storage apply encountered errors: %v", err) - } - logger.Info("Storage apply completed: ok=%d failed=%d", applied, failed) - } else { - logger.Info("Skipping storage.cfg apply") - } - } else { - logging.DebugStep(logger, "safe cluster apply (ui)", "storage.cfg not found (err=%v)", storageErr) - logger.Info("No storage.cfg found in export") - } - - dcCfg := filepath.Join(exportRoot, "etc/pve/datacenter.cfg") - logging.DebugStep(logger, "safe cluster apply (ui)", "Check export: datacenter.cfg (%s)", dcCfg) - dcInfo, dcErr := restoreFS.Stat(dcCfg) - if dcErr == nil && !dcInfo.IsDir() { - logging.DebugStep(logger, "safe cluster apply (ui)", "datacenter.cfg found (size=%d)", dcInfo.Size()) - applyDC, promptErr := ui.ConfirmApplyDatacenterCfg(ctx, dcCfg) - if promptErr != nil { - return promptErr - } - logging.DebugStep(logger, "safe cluster apply (ui)", "User choice: apply_datacenter=%v", applyDC) - if applyDC { - logging.DebugStep(logger, "safe cluster apply (ui)", "Apply datacenter.cfg via pvesh") - if err := runPvesh(ctx, logger, []string{"set", "/cluster/config", "-conf", dcCfg}); err != nil { - logger.Warning("Failed to apply datacenter.cfg: %v", err) - } else { - logger.Info("datacenter.cfg applied successfully") - } - } else { - logger.Info("Skipping datacenter.cfg apply") - } - } else { - logging.DebugStep(logger, "safe cluster apply (ui)", "datacenter.cfg not found (err=%v)", dcErr) - logger.Info("No datacenter.cfg found in export") - } + if err == io.EOF { + logger.Warning("Restore input closed unexpectedly (EOF). This usually means the interactive UI lost access to stdin/TTY (e.g., SSH disconnect or non-interactive execution). Re-run with --restore --cli from an interactive shell.") + return ErrRestoreAborted } - - // Apply pool membership after VM configs and storage/datacenter apply. - if applyPools && len(pools) > 0 { - applied, failed, applyErr := applyPVEPoolsMembership(ctx, logger, pools, allowPoolMove) - if applyErr != nil { - logger.Warning("Pools apply (membership) encountered errors: %v", applyErr) - } - logger.Info("Pools apply (membership) completed: ok=%d failed=%d", applied, failed) + if restoreWorkflowInputAborted(ctx, err) { + return ErrRestoreAborted } - - return nil + return err } -func smartMergeFstabWithUI(ctx context.Context, logger *logging.Logger, ui RestoreWorkflowUI, currentFstabPath, backupFstabPath string, dryRun bool) error { - if logger == nil { - logger = logging.GetDefaultLogger() - } - logger.Info("") - logger.Step("Smart Filesystem Configuration Merge") - logger.Debug("[FSTAB_MERGE] Starting analysis of %s vs backup %s...", currentFstabPath, backupFstabPath) - - currentEntries, currentRaw, err := parseFstab(currentFstabPath) - if err != nil { - return fmt.Errorf("failed to parse current fstab: %w", err) - } - backupEntries, _, err := parseFstab(backupFstabPath) - if err != nil { - return fmt.Errorf("failed to parse backup fstab: %w", err) - } - - remappedCount := 0 - backupRoot := fstabBackupRootFromPath(backupFstabPath) - if backupRoot != "" { - if remapped, count := remapFstabDevicesFromInventory(logger, backupEntries, backupRoot); count > 0 { - backupEntries = remapped - remappedCount = count - logger.Info("Fstab device remap: converted %d entry(ies) from /dev/* to stable UUID/PARTUUID/LABEL based on ProxSave inventory", count) - } else { - backupEntries = remapped - } - } - - analysis := analyzeFstabMerge(logger, currentEntries, backupEntries) - if len(analysis.ProposedMounts) == 0 { - logger.Info("No new safe mounts found to restore. Keeping current fstab.") - return nil - } - - defaultYes := analysis.RootComparable && analysis.RootMatch && (!analysis.SwapComparable || analysis.SwapMatch) - - var msg strings.Builder - msg.WriteString("ProxSave found missing mounts in /etc/fstab.\n\n") - if analysis.RootComparable && !analysis.RootMatch { - msg.WriteString("⚠ Root UUID mismatch: the backup appears to come from a different machine.\n") - } - if analysis.SwapComparable && !analysis.SwapMatch { - msg.WriteString("⚠ Swap mismatch: the current swap configuration will be kept.\n") - } - if remappedCount > 0 { - fmt.Fprintf(&msg, "✓ Remapped %d fstab entry(ies) from /dev/* to stable UUID/PARTUUID/LABEL using ProxSave inventory.\n", remappedCount) - } - msg.WriteString("\nProposed mounts (safe):\n") - for _, mount := range analysis.ProposedMounts { - fmt.Fprintf(&msg, " - %s -> %s (%s)\n", mount.Device, mount.MountPoint, mount.Type) - } - if len(analysis.SkippedMounts) > 0 { - msg.WriteString("\nMounts found but not auto-proposed:\n") - for _, mount := range analysis.SkippedMounts { - fmt.Fprintf(&msg, " - %s -> %s (%s)\n", mount.Device, mount.MountPoint, mount.Type) - } - msg.WriteString("\nHint: verify disks/UUIDs and options (nofail/_netdev) before adding them.\n") - } - - confirmMsg := "Do you want to add the missing mounts (NFS/CIFS and data mounts with verified UUID/LABEL)?" - if strings.TrimSpace(confirmMsg) != "" { - msg.WriteString("\n") - msg.WriteString(confirmMsg) - } - - confirmed, err := ui.ConfirmFstabMerge(ctx, "Smart fstab merge", msg.String(), 90*time.Second, defaultYes) - if err != nil { - return err - } - if !confirmed { - logger.Info("Fstab merge skipped by user.") - return nil - } - - return applyFstabMerge(ctx, logger, currentRaw, currentFstabPath, analysis.ProposedMounts, dryRun) +func restoreWorkflowInputAborted(ctx context.Context, err error) bool { + return errors.Is(err, input.ErrInputAborted) || + errors.Is(err, ErrDecryptAborted) || + errors.Is(err, ErrAgeRecipientSetupAborted) || + errors.Is(err, context.Canceled) || + (ctx != nil && ctx.Err() != nil) } diff --git a/internal/orchestrator/restore_workflow_ui_apply.go b/internal/orchestrator/restore_workflow_ui_apply.go new file mode 100644 index 00000000..7aaeb31b --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_apply.go @@ -0,0 +1,313 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "errors" + "strings" + "time" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func (w *restoreUIWorkflowRun) applyAccessControlFromStage() error { + return maybeApplyAccessControlWithUI(w.ctx, w.ui, w.logger, w.plan, w.safetyBackup, w.accessControlRollbackBackup, w.stageRoot, w.cfg.DryRun) +} + +func (w *restoreUIWorkflowRun) logAccessControlNotCommitted(err error) { + var notCommitted *AccessControlApplyNotCommittedError + rollbackLog := "" + rollbackArmed := false + deadline := time.Time{} + if errors.As(err, ¬Committed) && notCommitted != nil { + rollbackLog = strings.TrimSpace(notCommitted.RollbackLog) + rollbackArmed = notCommitted.RollbackArmed + deadline = notCommitted.RollbackDeadline + } + w.logGenericRollbackNotCommitted("Access control apply", rollbackArmed, deadline, rollbackLog) +} + +func (w *restoreUIWorkflowRun) verifyPBSNotificationsAfterRestore() { + if !w.plan.SystemType.SupportsPBS() || !w.plan.HasCategoryID("pbs_notifications") || w.pbsServicesStopped { + return + } + if err := maybeVerifyAndRepairPBSNotificationsAfterRestore(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun); err != nil { + w.restoreHadWarnings = true + w.logger.Warning("PBS notifications verification/repair: %v", err) + } +} + +func (w *restoreUIWorkflowRun) installNetworkConfigFromStage() error { + w.stageRootForNetworkApply = w.stageRoot + installed, err := maybeInstallNetworkConfigFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.prepared.ArchivePath, w.networkRollbackBackup, w.cfg.DryRun) + if err != nil { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Network staged install: %v", err) + return nil + } + if installed { + w.stageRootForNetworkApply = "" + logging.DebugStep(w.logger, "restore", "Network staged install completed: configuration written to /etc (no reload); live apply will use system paths") + } + return nil +} + +func (w *restoreUIWorkflowRun) recreateStorageDirectories() { + w.logger.Info("") + categories := append([]Category{}, w.plan.NormalCategories...) + categories = append(categories, w.plan.StagedCategories...) + if !shouldRecreateDirectories(w.systemType, categories) { + w.logger.Debug("Skipping datastore/storage directory recreation (category not selected)") + return + } + if err := RecreateDirectoriesFromConfig(w.systemType, w.logger); err != nil { + w.restoreHadWarnings = true + w.logger.Warning("Failed to recreate directory structures: %v", err) + w.logger.Warning("You may need to manually create storage/datastore directories") + } +} + +func (w *restoreUIWorkflowRun) repairDNSAfterRestore() error { + w.logger.Info("") + if !w.plan.HasCategoryID("network") { + return nil + } + w.logger.Info("") + err := maybeRepairResolvConfAfterRestore(w.ctx, w.logger, w.prepared.ArchivePath, w.cfg.DryRun) + if err == nil { + return nil + } + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("DNS resolver repair: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) applyNetworkConfig() error { + w.logger.Info("") + err := maybeApplyNetworkConfigWithUI(w.ctx, w.ui, w.logger, networkConfigUIApplyRequest{ + plan: w.plan, + safetyBackup: w.safetyBackup, + networkRollbackBackup: w.networkRollbackBackup, + stageRoot: w.stageRootForNetworkApply, + archivePath: w.prepared.ArchivePath, + dryRun: w.cfg.DryRun, + }) + if err == nil { + return nil + } + if restoreAbortOrInput(err) { + w.logger.Info("Restore aborted by user during network apply prompt.") + return err + } + w.restoreHadWarnings = true + w.logNetworkApplyError(err) + return nil +} + +func (w *restoreUIWorkflowRun) logNetworkApplyError(err error) { + if !errors.Is(err, ErrNetworkApplyNotCommitted) { + w.logger.Warning("Network apply step skipped or failed: %v", err) + return + } + var notCommitted *NetworkApplyNotCommittedError + if !errors.As(err, ¬Committed) || notCommitted == nil { + w.logger.Warning("Network apply not committed; rollback state unknown.") + return + } + w.saveNetworkAbortInfo(notCommitted) + observedIP, originalIP := networkNotCommittedIPs(notCommitted) + reconnectHost := reconnectHostFromOriginalIP(originalIP) + w.logNetworkRollbackState(notCommitted.RollbackArmed, observedIP, originalIP, reconnectHost) + if rollbackLog := strings.TrimSpace(notCommitted.RollbackLog); rollbackLog != "" { + w.logger.Info("Rollback log: %s", rollbackLog) + } +} + +func (w *restoreUIWorkflowRun) saveNetworkAbortInfo(notCommitted *NetworkApplyNotCommittedError) { + lastRestoreAbortInfo = &RestoreAbortInfo{ + NetworkRollbackArmed: notCommitted.RollbackArmed, + NetworkRollbackLog: strings.TrimSpace(notCommitted.RollbackLog), + NetworkRollbackMarker: strings.TrimSpace(notCommitted.RollbackMarker), + OriginalIP: notCommitted.OriginalIP, + CurrentIP: strings.TrimSpace(notCommitted.RestoredIP), + RollbackDeadline: notCommitted.RollbackDeadline, + } +} + +func (w *restoreUIWorkflowRun) logNetworkRollbackState(armed bool, observedIP, originalIP, reconnectHost string) { + if armed { + w.logger.Warning("Network apply not committed; rollback is ARMED and will run automatically.") + } else { + w.logger.Warning("Network apply not committed; rollback has executed (or marker cleared).") + } + if reconnectHost != "" && reconnectHost != "unknown" && originalIP != "unknown" { + w.logger.Warning("IP now (after apply): %s. Expected after rollback: %s. Reconnect using: %s", observedIP, originalIP, reconnectHost) + } else if originalIP != "unknown" { + w.logger.Warning("IP now (after apply): %s. Expected after rollback: %s", observedIP, originalIP) + } else { + w.logger.Warning("IP now (after apply): %s", observedIP) + } +} + +func (w *restoreUIWorkflowRun) applyFirewallConfig() error { + w.logger.Info("") + err := maybeApplyPVEFirewallWithUI(w.ctx, w.ui, w.logger, w.plan, w.safetyBackup, w.firewallRollbackBackup, w.stageRoot, w.cfg.DryRun) + if err == nil { + return nil + } + if restoreAbortOrInput(err) { + w.logger.Info("Restore aborted by user during firewall apply prompt.") + return err + } + w.restoreHadWarnings = true + w.logFirewallApplyError(err) + return nil +} + +func (w *restoreUIWorkflowRun) logFirewallApplyError(err error) { + if !errors.Is(err, ErrFirewallApplyNotCommitted) { + w.logger.Warning("Firewall apply step skipped or failed: %v", err) + return + } + armed, deadline, rollbackLog := firewallRollbackSummary(err) + w.logGenericRollbackNotCommitted("Firewall apply", armed, deadline, rollbackLog) +} + +func firewallRollbackSummary(err error) (bool, time.Time, string) { + var notCommitted *FirewallApplyNotCommittedError + if errors.As(err, ¬Committed) && notCommitted != nil { + return notCommitted.RollbackArmed, notCommitted.RollbackDeadline, strings.TrimSpace(notCommitted.RollbackLog) + } + return false, time.Time{}, "" +} + +func (w *restoreUIWorkflowRun) applyHAConfig() error { + w.logger.Info("") + err := maybeApplyPVEHAWithUI(w.ctx, w.ui, w.logger, w.plan, w.safetyBackup, w.haRollbackBackup, w.stageRoot, w.cfg.DryRun) + if err == nil { + return nil + } + if restoreAbortOrInput(err) { + w.logger.Info("Restore aborted by user during HA apply prompt.") + return err + } + w.restoreHadWarnings = true + w.logHAApplyError(err) + return nil +} + +func (w *restoreUIWorkflowRun) logHAApplyError(err error) { + if !errors.Is(err, ErrHAApplyNotCommitted) { + w.logger.Warning("HA apply step skipped or failed: %v", err) + return + } + armed, deadline, rollbackLog := haRollbackSummary(err) + w.logGenericRollbackNotCommitted("HA apply", armed, deadline, rollbackLog) +} + +func haRollbackSummary(err error) (bool, time.Time, string) { + var notCommitted *HAApplyNotCommittedError + if errors.As(err, ¬Committed) && notCommitted != nil { + return notCommitted.RollbackArmed, notCommitted.RollbackDeadline, strings.TrimSpace(notCommitted.RollbackLog) + } + return false, time.Time{}, "" +} + +func (w *restoreUIWorkflowRun) logGenericRollbackNotCommitted(label string, armed bool, deadline time.Time, rollbackLog string) { + if armed { + w.logger.Warning("%s not committed; rollback is ARMED and will run automatically.", label) + } else { + w.logger.Warning("%s not committed; rollback has executed (or marker cleared).", label) + } + if !deadline.IsZero() { + w.logger.Info("Rollback deadline: %s", deadline.Format(time.RFC3339)) + } + if rollbackLog != "" { + w.logger.Info("Rollback log: %s", rollbackLog) + } +} + +func (w *restoreUIWorkflowRun) logRestoreCompletion() { + w.logger.Info("") + if w.restoreHadWarnings { + w.logger.Warning("Restore completed with warnings.") + } else { + w.logger.Info("Restore completed successfully.") + } + w.logger.Info("Temporary decrypted bundle removed.") + w.logRestoreArtifacts() +} + +func (w *restoreUIWorkflowRun) logRestoreArtifacts() { + if w.detailedLogPath != "" { + w.logger.Info("Detailed restore log: %s", w.detailedLogPath) + } + if w.exportRoot != "" { + w.logger.Info("Export directory: %s", w.exportRoot) + } + if w.exportLogPath != "" { + w.logger.Info("Export detailed log: %s", w.exportLogPath) + } + if w.stageRoot != "" { + w.logger.Info("Staging directory: %s", w.stageRoot) + } + if w.stageLogPath != "" { + w.logger.Info("Staging detailed log: %s", w.stageLogPath) + } + if w.safetyBackup != nil { + w.logger.Info("Safety backup preserved at: %s", w.safetyBackup.BackupPath) + w.logger.Info("Remove it manually if restore was successful: rm %s", w.safetyBackup.BackupPath) + } +} + +func (w *restoreUIWorkflowRun) logServiceRestartAdvice() { + w.logger.Info("") + w.logger.Info("IMPORTANT: You may need to restart services for changes to take effect.") + switch w.systemType { + case SystemTypeDual: + w.logPVERestartAdvice() + w.logPBSRestartAdvice() + case SystemTypePVE: + w.logPVERestartAdvice() + case SystemTypePBS: + w.logPBSRestartAdvice() + } +} + +func (w *restoreUIWorkflowRun) logPVERestartAdvice() { + if w.needsClusterRestore && w.clusterServicesStopped { + w.logger.Info(" PVE services were stopped/restarted during restore; verify status with: pvecm status") + return + } + w.logger.Info(" PVE services: systemctl restart pve-cluster pvedaemon pveproxy") +} + +func (w *restoreUIWorkflowRun) logPBSRestartAdvice() { + if w.pbsServicesStopped { + w.logger.Info(" PBS services were stopped/restarted during restore; verify status with: systemctl status proxmox-backup proxmox-backup-proxy") + return + } + w.logger.Info(" PBS services: systemctl restart proxmox-backup-proxy proxmox-backup") +} + +func (w *restoreUIWorkflowRun) checkZFSPoolsAfterRestore() { + if hasCategoryID(w.plan.NormalCategories, "zfs") { + w.logger.Info("") + if err := checkZFSPoolsAfterRestore(w.ctx, w.logger); err != nil { + w.logger.Warning("ZFS pool check: %v", err) + } + return + } + w.logger.Debug("Skipping ZFS pool verification (ZFS category not selected)") +} + +func (w *restoreUIWorkflowRun) logRebootRecommendation() { + w.logger.Info("") + w.logger.Warning("⚠ SYSTEM REBOOT RECOMMENDED") + w.logger.Info("Reboot the node (or at least restart networking and system services) to ensure all restored configurations take effect cleanly.") +} diff --git a/internal/orchestrator/restore_workflow_ui_backups_services.go b/internal/orchestrator/restore_workflow_ui_backups_services.go new file mode 100644 index 00000000..f0f2702c --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_backups_services.go @@ -0,0 +1,216 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "strings" + "time" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func (w *restoreUIWorkflowRun) systemWriteCategories() []Category { + categories := append([]Category{}, w.plan.NormalCategories...) + return append(categories, w.plan.StagedCategories...) +} + +func (w *restoreUIWorkflowRun) createRollbackBackups() error { + systemWriteCategories := w.systemWriteCategories() + if err := w.createSafetyBackup(systemWriteCategories); err != nil { + return err + } + w.createNetworkRollbackBackup(systemWriteCategories) + w.createFirewallRollbackBackup(systemWriteCategories) + w.createHARollbackBackup(systemWriteCategories) + w.createAccessControlRollbackBackup(systemWriteCategories) + return nil +} + +func (w *restoreUIWorkflowRun) createSafetyBackup(categories []Category) error { + if len(categories) == 0 { + return nil + } + w.logger.Info("") + backup, err := CreateSafetyBackup(w.logger, categories, w.destRoot) + if err != nil { + w.logger.Warning("Failed to create safety backup: %v", err) + cont, promptErr := w.ui.ConfirmContinueWithoutSafetyBackup(w.ctx, err) + if promptErr != nil { + return promptErr + } + if !cont { + return ErrRestoreAborted + } + return nil + } + w.safetyBackup = backup + w.logger.Info("Safety backup location: %s", backup.BackupPath) + w.logger.Info("You can restore from this backup if needed using: tar -xzf %s -C /", backup.BackupPath) + return nil +} + +func (w *restoreUIWorkflowRun) createNetworkRollbackBackup(categories []Category) { + if !w.plan.HasCategoryID("network") { + return + } + w.logger.Info("") + logging.DebugStep(w.logger, "restore", "Create network-only rollback backup for transactional network apply") + backup, err := CreateNetworkRollbackBackup(w.logger, categories, w.destRoot) + if err != nil { + w.logger.Warning("Failed to create network rollback backup: %v", err) + return + } + w.networkRollbackBackup = backup + if backup != nil && strings.TrimSpace(backup.BackupPath) != "" { + w.logger.Info("Network rollback backup location: %s", backup.BackupPath) + w.logger.Info("This backup is used for the %ds network rollback timer and only includes network paths.", int(defaultNetworkRollbackTimeout.Seconds())) + } +} + +func (w *restoreUIWorkflowRun) createFirewallRollbackBackup(categories []Category) { + if !w.plan.HasCategoryID("pve_firewall") { + return + } + w.logger.Info("") + logging.DebugStep(w.logger, "restore", "Create firewall-only rollback backup for transactional firewall apply") + backup, err := CreateFirewallRollbackBackup(w.logger, categories, w.destRoot) + if err != nil { + w.logger.Warning("Failed to create firewall rollback backup: %v", err) + return + } + w.firewallRollbackBackup = backup + if backup != nil && strings.TrimSpace(backup.BackupPath) != "" { + w.logger.Info("Firewall rollback backup location: %s", backup.BackupPath) + w.logger.Info("This backup is used for the %ds firewall rollback timer and only includes firewall paths.", int(defaultFirewallRollbackTimeout.Seconds())) + } +} + +func (w *restoreUIWorkflowRun) createHARollbackBackup(categories []Category) { + if !w.plan.HasCategoryID("pve_ha") { + return + } + w.logger.Info("") + logging.DebugStep(w.logger, "restore", "Create HA-only rollback backup for transactional HA apply") + backup, err := CreateHARollbackBackup(w.logger, categories, w.destRoot) + if err != nil { + w.logger.Warning("Failed to create HA rollback backup: %v", err) + return + } + w.haRollbackBackup = backup + if backup != nil && strings.TrimSpace(backup.BackupPath) != "" { + w.logger.Info("HA rollback backup location: %s", backup.BackupPath) + w.logger.Info("This backup is used for the %ds HA rollback timer and only includes HA paths.", int(defaultHARollbackTimeout.Seconds())) + } +} + +func (w *restoreUIWorkflowRun) createAccessControlRollbackBackup(categories []Category) { + if !w.shouldCreateAccessControlRollbackBackup() { + return + } + w.logger.Info("") + logging.DebugStep(w.logger, "restore", "Create access-control-only rollback backup for optional cluster-safe access control apply") + backup, err := CreatePVEAccessControlRollbackBackup(w.logger, categories, w.destRoot) + if err != nil { + w.logger.Warning("Failed to create access control rollback backup: %v", err) + return + } + w.accessControlRollbackBackup = backup + if backup != nil && strings.TrimSpace(backup.BackupPath) != "" { + w.logger.Info("Access control rollback backup location: %s", backup.BackupPath) + w.logger.Info("This backup is used for the %ds access control rollback timer and only includes access control paths.", int(defaultAccessControlRollbackTimeout.Seconds())) + } +} + +func (w *restoreUIWorkflowRun) shouldCreateAccessControlRollbackBackup() bool { + return w.plan.SystemType.SupportsPVE() && + w.plan.ClusterBackup && + !w.plan.NeedsClusterRestore && + w.plan.HasCategoryID("pve_access_control") +} + +func (w *restoreUIWorkflowRun) prepareRestoreServices() (func(), error) { + var cleanups []func() + if cleanup, err := w.preparePVEClusterRestore(); err != nil { + return nil, err + } else if cleanup != nil { + cleanups = append(cleanups, cleanup) + } + if cleanup, err := w.preparePBSServices(); err != nil { + return nil, err + } else if cleanup != nil { + cleanups = append(cleanups, cleanup) + } + return func() { + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } + }, nil +} + +func (w *restoreUIWorkflowRun) preparePVEClusterRestore() (func(), error) { + w.needsClusterRestore = w.plan.NeedsClusterRestore + if !w.needsClusterRestore { + return nil, nil + } + w.logger.Info("") + w.logger.Info("Preparing system for cluster database restore: stopping PVE services and unmounting /etc/pve") + if err := stopPVEClusterServices(w.ctx, w.logger); err != nil { + return nil, err + } + w.clusterServicesStopped = true + if err := unmountEtcPVE(w.ctx, w.logger); err != nil { + w.logger.Warning("Could not unmount /etc/pve: %v", err) + } + return w.restartPVEClusterServicesCleanup(), nil +} + +func (w *restoreUIWorkflowRun) restartPVEClusterServicesCleanup() func() { + return func() { + restartCtx, cancel := context.WithTimeout(context.Background(), 2*serviceStartTimeout+2*serviceVerifyTimeout+10*time.Second) + defer cancel() + if err := startPVEClusterServices(restartCtx, w.logger); err != nil { + w.logger.Warning("Failed to restart PVE services after restore: %v", err) + } + } +} + +func (w *restoreUIWorkflowRun) preparePBSServices() (func(), error) { + w.needsPBSServices = w.plan.NeedsPBSServices + if !w.needsPBSServices { + return nil, nil + } + w.logger.Info("") + w.logger.Info("Preparing PBS system for restore: stopping proxmox-backup services") + if err := stopPBSServices(w.ctx, w.logger); err != nil { + return w.confirmContinueWithPBSServicesRunning(err) + } + w.pbsServicesStopped = true + return w.restartPBSServicesCleanup(), nil +} + +func (w *restoreUIWorkflowRun) confirmContinueWithPBSServicesRunning(stopErr error) (func(), error) { + w.logger.Warning("Unable to stop PBS services automatically: %v", stopErr) + cont, err := w.ui.ConfirmContinueWithPBSServicesRunning(w.ctx) + if err != nil { + return nil, err + } + if !cont { + return nil, ErrRestoreAborted + } + w.logger.Warning("Continuing restore with PBS services still running") + return nil, nil +} + +func (w *restoreUIWorkflowRun) restartPBSServicesCleanup() func() { + return func() { + restartCtx, cancel := context.WithTimeout(context.Background(), 2*serviceStartTimeout+2*serviceVerifyTimeout+10*time.Second) + defer cancel() + if err := startPBSServices(restartCtx, w.logger); err != nil { + w.logger.Warning("Failed to restart PBS services after restore: %v", err) + return + } + if err := maybeVerifyAndRepairPBSNotificationsAfterRestore(restartCtx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun); err != nil { + w.logger.Warning("PBS notifications verification/repair: %v", err) + } + } +} diff --git a/internal/orchestrator/restore_workflow_ui_cluster_apply.go b/internal/orchestrator/restore_workflow_ui_cluster_apply.go new file mode 100644 index 00000000..f19658f8 --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_cluster_apply.go @@ -0,0 +1,352 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type safeClusterApplyUIFlow struct { + ctx context.Context + ui RestoreWorkflowUI + exportRoot string + logger *logging.Logger + plan *RestorePlan + currentNode string + sourceNode string + pools []pvePoolSpec + applyPools bool + allowPoolMove bool +} + +func runSafeClusterApplyWithUI(ctx context.Context, ui RestoreWorkflowUI, exportRoot string, logger *logging.Logger, plan *RestorePlan) (err error) { + done := logging.DebugStart(logger, "safe cluster apply (ui)", "export_root=%s", exportRoot) + defer func() { done(err) }() + + flow := &safeClusterApplyUIFlow{ + ctx: ctx, + ui: ui, + exportRoot: exportRoot, + logger: logger, + plan: plan, + } + err = flow.run() + if errors.Is(err, errSafeClusterApplySkipped) { + return nil + } + return err +} + +func (f *safeClusterApplyUIFlow) run() error { + if err := f.validate(); err != nil { + return err + } + f.detectCurrentNode() + f.logger.Info("") + f.logger.Info("SAFE cluster restore: applying configs via pvesh (node=%s)", f.currentNode) + f.applyResourceMappings() + if err := f.preparePools(); err != nil { + return err + } + if err := f.selectVMSourceNode(); err != nil { + return err + } + if err := f.applyVMConfigsFromExport(); err != nil { + return err + } + if err := f.applyStorageAndDatacenter(); err != nil { + return err + } + f.applyPoolMembership() + return nil +} + +func (f *safeClusterApplyUIFlow) validate() error { + if err := f.ctx.Err(); err != nil { + return err + } + if f.ui == nil { + return fmt.Errorf("restore UI not available") + } + pveshPath, err := exec.LookPath("pvesh") + if err != nil { + f.logger.Warning("pvesh not found in PATH; skipping SAFE cluster apply") + return errSafeClusterApplySkipped + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "pvesh=%s", pveshPath) + return nil +} + +var errSafeClusterApplySkipped = fmt.Errorf("safe cluster apply skipped") + +func (f *safeClusterApplyUIFlow) detectCurrentNode() { + currentNode, _ := os.Hostname() + f.currentNode = shortHost(currentNode) + if strings.TrimSpace(f.currentNode) == "" { + f.currentNode = "localhost" + } + f.sourceNode = f.currentNode + logging.DebugStep(f.logger, "safe cluster apply (ui)", "current_node=%s", f.currentNode) +} + +func (f *safeClusterApplyUIFlow) applyResourceMappings() { + if err := maybeApplyPVEClusterResourceMappingsWithUI(f.ctx, f.ui, f.logger, f.exportRoot); err != nil { + f.logger.Warning("SAFE apply: resource mappings: %v", err) + } +} + +func (f *safeClusterApplyUIFlow) preparePools() error { + if err := f.loadPools(); err != nil { + return nil + } + if len(f.pools) == 0 { + return nil + } + if err := f.confirmPoolDefinitions(); err != nil { + return err + } + if f.applyPools { + return f.applyPoolDefinitions() + } + return nil +} + +func (f *safeClusterApplyUIFlow) loadPools() error { + pools, err := readPVEPoolsFromExportUserCfg(f.exportRoot) + if err != nil { + f.logger.Warning("SAFE apply: failed to parse pools from export: %v", err) + f.pools = nil + return err + } + f.pools = pools + return nil +} + +func (f *safeClusterApplyUIFlow) confirmPoolDefinitions() error { + poolNames := summarizePoolIDs(f.pools, 10) + message := fmt.Sprintf("Found %d pool(s) in exported user.cfg.\n\nPools: %s\n\nApply pool definitions now? (Membership will be applied later in this SAFE apply flow.)", len(f.pools), poolNames) + ok, err := f.ui.ConfirmAction(f.ctx, "Apply PVE resource pools (merge)", message, "Apply now", "Skip apply", 0, false) + if err != nil { + return err + } + f.applyPools = ok + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User choice: apply_pools=%v (pools=%d)", f.applyPools, len(f.pools)) + return nil +} + +func (f *safeClusterApplyUIFlow) applyPoolDefinitions() error { + if anyPoolHasVMs(f.pools) { + if err := f.confirmAllowPoolMove(); err != nil { + return err + } + } + applied, failed, err := applyPVEPoolsDefinitions(f.ctx, f.logger, f.pools) + if err != nil { + f.logger.Warning("Pools apply (definitions) encountered errors: %v", err) + } + f.logger.Info("Pools apply (definitions) completed: ok=%d failed=%d", applied, failed) + return nil +} + +func (f *safeClusterApplyUIFlow) confirmAllowPoolMove() error { + moveMsg := "Allow moving guests from other pools to match the backup? This may change the current pool assignment of existing VMs/CTs." + move, err := f.ui.ConfirmAction(f.ctx, "Pools: allow move (VM/CT)", moveMsg, "Allow move", "Don't move", 0, false) + if err != nil { + return err + } + f.allowPoolMove = move + return nil +} + +func (f *safeClusterApplyUIFlow) selectVMSourceNode() error { + exportNodes, err := f.listExportNodes() + if err != nil || len(exportNodes) == 0 || stringSliceContains(exportNodes, f.sourceNode) { + return nil + } + return f.resolveSourceNodeMismatch(exportNodes) +} + +func (f *safeClusterApplyUIFlow) listExportNodes() ([]string, error) { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "List exported node directories under %s", filepath.Join(f.exportRoot, "etc/pve/nodes")) + exportNodes, err := listExportNodeDirs(f.exportRoot) + if err != nil { + f.logger.Warning("Failed to inspect exported node directories: %v", err) + return nil, err + } + if len(exportNodes) > 0 { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "export_nodes=%s", strings.Join(exportNodes, ",")) + } else { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "No exported node directories found") + } + return exportNodes, nil +} + +func (f *safeClusterApplyUIFlow) resolveSourceNodeMismatch(exportNodes []string) error { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Node mismatch: current_node=%s export_nodes=%s", f.currentNode, strings.Join(exportNodes, ",")) + f.logger.Warning("SAFE cluster restore: VM/CT configs not found for current node %s in export; available nodes: %s", f.currentNode, strings.Join(exportNodes, ", ")) + if len(exportNodes) == 1 { + f.sourceNode = exportNodes[0] + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Auto-select source node: %s", f.sourceNode) + f.logger.Info("SAFE cluster restore: using exported node %s as VM/CT source, applying to current node %s", f.sourceNode, f.currentNode) + return nil + } + return f.promptSourceNode(exportNodes) +} + +func (f *safeClusterApplyUIFlow) promptSourceNode(exportNodes []string) error { + f.logExportNodeCandidates(exportNodes) + selected, err := f.ui.SelectExportNode(f.ctx, f.exportRoot, f.currentNode, exportNodes) + if err != nil { + return err + } + if strings.TrimSpace(selected) == "" { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User selected: skip VM/CT apply (no source node)") + f.logger.Info("Skipping VM/CT apply (no source node selected)") + f.sourceNode = "" + return nil + } + f.sourceNode = selected + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User selected source node: %s", f.sourceNode) + f.logger.Info("SAFE cluster restore: selected exported node %s as VM/CT source, applying to current node %s", f.sourceNode, f.currentNode) + return nil +} + +func (f *safeClusterApplyUIFlow) logExportNodeCandidates(exportNodes []string) { + for _, node := range exportNodes { + qemuCount, lxcCount := countVMConfigsForNode(f.exportRoot, node) + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Export node candidate: %s (qemu=%d, lxc=%d)", node, qemuCount, lxcCount) + } +} + +func (f *safeClusterApplyUIFlow) applyVMConfigsFromExport() error { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Selected VM/CT source node: %q (current_node=%q)", f.sourceNode, f.currentNode) + vmEntries := f.scanVMConfigs() + if len(vmEntries) == 0 { + f.logNoVMConfigs() + return nil + } + applyVMs, err := f.ui.ConfirmApplyVMConfigs(f.ctx, f.sourceNode, f.currentNode, len(vmEntries)) + if err != nil { + return err + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User choice: apply_vms=%v (entries=%d)", applyVMs, len(vmEntries)) + if applyVMs { + applied, failed := applyVMConfigs(f.ctx, vmEntries, f.logger) + f.logger.Info("VM/CT apply completed: ok=%d failed=%d", applied, failed) + } else { + f.logger.Info("Skipping VM/CT apply") + } + return nil +} + +func (f *safeClusterApplyUIFlow) scanVMConfigs() []vmEntry { + if strings.TrimSpace(f.sourceNode) == "" { + return nil + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Scan VM/CT configs in export (source_node=%s)", f.sourceNode) + vmEntries, err := scanVMConfigs(f.exportRoot, f.sourceNode) + if err != nil { + f.logger.Warning("Failed to scan VM configs: %v", err) + return nil + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "VM/CT configs found=%d (source_node=%s)", len(vmEntries), f.sourceNode) + return vmEntries +} + +func (f *safeClusterApplyUIFlow) logNoVMConfigs() { + if strings.TrimSpace(f.sourceNode) == "" { + f.logger.Info("No VM/CT configs applied (no source node selected)") + return + } + f.logger.Info("No VM/CT configs found for node %s in export", f.sourceNode) +} + +func (f *safeClusterApplyUIFlow) applyStorageAndDatacenter() error { + if f.plan != nil && f.plan.HasCategoryID("storage_pve") { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Skip storage/datacenter apply: handled by storage_pve staged restore") + f.logger.Info("Skipping storage/datacenter apply (handled by storage_pve staged restore)") + return nil + } + if err := f.maybeApplyStorageCfg(); err != nil { + return err + } + return f.maybeApplyDatacenterCfg() +} + +func (f *safeClusterApplyUIFlow) maybeApplyStorageCfg() error { + storageCfg := filepath.Join(f.exportRoot, "etc/pve/storage.cfg") + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Check export: storage.cfg (%s)", storageCfg) + info, err := restoreFS.Stat(storageCfg) + if err != nil || info.IsDir() { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "storage.cfg not found (err=%v)", err) + f.logger.Info("No storage.cfg found in export") + return nil + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "storage.cfg found (size=%d)", info.Size()) + applyStorage, err := f.ui.ConfirmApplyStorageCfg(f.ctx, storageCfg) + if err != nil { + return err + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User choice: apply_storage=%v", applyStorage) + if applyStorage { + applied, failed, applyErr := applyStorageCfg(f.ctx, storageCfg, f.logger) + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Storage apply result: ok=%d failed=%d err=%v", applied, failed, applyErr) + if applyErr != nil { + f.logger.Warning("Storage apply encountered errors: %v", applyErr) + } + f.logger.Info("Storage apply completed: ok=%d failed=%d", applied, failed) + } else { + f.logger.Info("Skipping storage.cfg apply") + } + return nil +} + +func (f *safeClusterApplyUIFlow) maybeApplyDatacenterCfg() error { + dcCfg := filepath.Join(f.exportRoot, "etc/pve/datacenter.cfg") + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Check export: datacenter.cfg (%s)", dcCfg) + info, err := restoreFS.Stat(dcCfg) + if err != nil || info.IsDir() { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "datacenter.cfg not found (err=%v)", err) + f.logger.Info("No datacenter.cfg found in export") + return nil + } + return f.confirmAndApplyDatacenterCfg(dcCfg, info.Size()) +} + +func (f *safeClusterApplyUIFlow) confirmAndApplyDatacenterCfg(dcCfg string, size int64) error { + logging.DebugStep(f.logger, "safe cluster apply (ui)", "datacenter.cfg found (size=%d)", size) + applyDC, err := f.ui.ConfirmApplyDatacenterCfg(f.ctx, dcCfg) + if err != nil { + return err + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "User choice: apply_datacenter=%v", applyDC) + if !applyDC { + f.logger.Info("Skipping datacenter.cfg apply") + return nil + } + logging.DebugStep(f.logger, "safe cluster apply (ui)", "Apply datacenter.cfg via pvesh") + if err := runPvesh(f.ctx, f.logger, []string{"set", "/cluster/config", "-conf", dcCfg}); err != nil { + f.logger.Warning("Failed to apply datacenter.cfg: %v", err) + } else { + f.logger.Info("datacenter.cfg applied successfully") + } + return nil +} + +func (f *safeClusterApplyUIFlow) applyPoolMembership() { + if !f.applyPools || len(f.pools) == 0 { + return + } + applied, failed, err := applyPVEPoolsMembership(f.ctx, f.logger, f.pools, f.allowPoolMove) + if err != nil { + f.logger.Warning("Pools apply (membership) encountered errors: %v", err) + } + f.logger.Info("Pools apply (membership) completed: ok=%d failed=%d", applied, failed) +} diff --git a/internal/orchestrator/restore_workflow_ui_extract.go b/internal/orchestrator/restore_workflow_ui_extract.go new file mode 100644 index 00000000..15b4ed0b --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_extract.go @@ -0,0 +1,355 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "errors" + "fmt" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +func (w *restoreUIWorkflowRun) interceptFilesystemCategory() { + if !w.plan.HasCategoryID("filesystem") { + return + } + w.needsFilesystemRestore = true + w.plan.NormalCategories = categoriesWithoutID(w.plan.NormalCategories, "filesystem") + logging.DebugStep(w.logger, "restore", "Filesystem category intercepted: enabling Smart Merge workflow (skipping generic extraction)") +} + +func categoriesWithoutID(categories []Category, id string) []Category { + var filtered []Category + for _, cat := range categories { + if cat.ID != id { + filtered = append(filtered, cat) + } + } + return filtered +} + +func (w *restoreUIWorkflowRun) extractNormalCategories() error { + if len(w.plan.NormalCategories) == 0 { + w.logger.Info("") + w.logger.Info("No system-path categories selected for restore (only export categories will be processed).") + return nil + } + + w.logger.Info("") + categories := w.systemExtractionCategories() + if len(categories) == 0 { + logging.DebugStep(w.logger, "restore", "Skip system-path extraction: no categories remain after shadow-guard") + w.logger.Info("No system-path categories remain after cluster shadow-guard; skipping system-path extraction.") + return nil + } + + detailedLogPath, err := extractSelectiveArchive(w.ctx, w.prepared.ArchivePath, w.destRoot, categories, w.mode, w.logger) + if err != nil { + w.logger.Error("Restore failed: %v", err) + if w.safetyBackup != nil { + w.logger.Info("You can rollback using the safety backup at: %s", w.safetyBackup.BackupPath) + } + return err + } + w.detailedLogPath = detailedLogPath + return nil +} + +func (w *restoreUIWorkflowRun) systemExtractionCategories() []Category { + categories := w.plan.NormalCategories + if !w.needsClusterRestore { + return categories + } + logging.DebugStep(w.logger, "restore", "Cluster RECOVERY shadow-guard: sanitize categories to avoid /etc/pve shadow writes") + sanitized, removed := sanitizeCategoriesForClusterRecovery(categories) + w.logClusterShadowGuardResult(categories, sanitized, removed) + return sanitized +} + +func (w *restoreUIWorkflowRun) logClusterShadowGuardResult(before, after []Category, removed map[string][]string) { + removedPaths := 0 + for _, paths := range removed { + removedPaths += len(paths) + } + logging.DebugStep(w.logger, "restore", "Cluster RECOVERY shadow-guard: categories_before=%d categories_after=%d removed_categories=%d removed_paths=%d", len(before), len(after), len(removed), removedPaths) + if len(removed) == 0 { + logging.DebugStep(w.logger, "restore", "Cluster RECOVERY shadow-guard: no /etc/pve paths detected in selected categories") + return + } + + w.logger.Warning("Cluster RECOVERY restore: skipping direct restore of /etc/pve paths to prevent shadowing while pmxcfs is stopped/unmounted") + for _, cat := range before { + if paths, ok := removed[cat.ID]; ok && len(paths) > 0 { + w.logger.Warning(" - %s (%s): %s", cat.Name, cat.ID, strings.Join(paths, ", ")) + } + } + w.logger.Info("These paths are expected to be restored from config.db and become visible after /etc/pve is remounted.") +} + +func (w *restoreUIWorkflowRun) smartMergeFilesystemCategory() error { + if !w.needsFilesystemRestore { + return nil + } + w.logger.Info("") + fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-") + if err != nil { + w.restoreHadWarnings = true + w.logger.Warning("Failed to create temp dir for fstab merge: %v", err) + return nil + } + defer func() { + if err := restoreFS.RemoveAll(fsTempDir); err != nil { + w.logger.Debug("Failed to remove temporary fstab merge directory %s: %v", fsTempDir, err) + } + }() + return w.extractAndMergeFstab(fsTempDir) +} + +func (w *restoreUIWorkflowRun) extractAndMergeFstab(fsTempDir string) error { + fsCat := GetCategoryByID("filesystem", w.availableCategories) + if fsCat == nil { + w.logger.Warning("Filesystem category not available in analyzed backup contents; skipping fstab merge") + return nil + } + if _, err := extractSelectiveArchive(w.ctx, w.prepared.ArchivePath, fsTempDir, []Category{*fsCat}, RestoreModeCustom, w.logger); err != nil { + return w.handleFstabExtractError(err) + } + w.extractFstabInventory(fsTempDir) + currentFstab := filepath.Join(w.destRoot, "etc", "fstab") + backupFstab := filepath.Join(fsTempDir, "etc", "fstab") + if err := smartMergeFstabWithUI(w.ctx, w.logger, w.ui, currentFstab, backupFstab, w.cfg.DryRun); err != nil { + return w.handleFstabMergeError(err) + } + return nil +} + +func (w *restoreUIWorkflowRun) handleFstabExtractError(err error) error { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Failed to extract filesystem config for merge: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) extractFstabInventory(fsTempDir string) { + inventoryCategory := []Category{{ + ID: "fstab_inventory", + Name: "Fstab inventory (device mapping)", + Paths: []string{ + "./var/lib/proxsave-info/commands/system/blkid.txt", + "./var/lib/proxsave-info/commands/system/lsblk_json.json", + "./var/lib/proxsave-info/commands/system/lsblk.txt", + "./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", + }, + }} + err := extractArchiveNative(w.ctx, restoreArchiveOptions{ + archivePath: w.prepared.ArchivePath, + destRoot: fsTempDir, + logger: w.logger, + categories: inventoryCategory, + mode: RestoreModeCustom, + }) + if err != nil { + w.logger.Debug("Failed to extract fstab inventory data (continuing): %v", err) + } +} + +func (w *restoreUIWorkflowRun) handleFstabMergeError(err error) error { + if restoreAbortOrInput(err) { + w.logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.") + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Smart Fstab Merge failed: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) exportCategories() error { + if len(w.plan.ExportCategories) == 0 { + return nil + } + w.exportRoot = exportDestRoot(w.cfg.BaseDir) + w.logger.Info("") + w.logger.Info("Exporting %d export-only category(ies) to: %s", len(w.plan.ExportCategories), w.exportRoot) + if err := restoreFS.MkdirAll(w.exportRoot, 0o700); err != nil { + return fmt.Errorf("failed to create export directory %s: %w", w.exportRoot, err) + } + + exportLog, err := extractSelectiveArchive(w.ctx, w.prepared.ArchivePath, w.exportRoot, w.plan.ExportCategories, RestoreModeCustom, w.logger) + if err != nil { + return w.handleExportError(err) + } + w.exportLogPath = exportLog + return nil +} + +func (w *restoreUIWorkflowRun) handleExportError(err error) error { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Export completed with errors: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) runClusterSafeApply() error { + if !w.plan.ClusterSafeMode { + return nil + } + if w.exportRoot == "" { + w.logger.Warning("Cluster SAFE mode selected but export directory not available; skipping automatic pvesh apply") + return nil + } + if w.exportLogPath == "" { + w.logger.Warning("Cluster SAFE mode selected but export extraction did not complete; skipping automatic pvesh apply") + return nil + } + w.extractSafeApplyInventory() + if err := runSafeClusterApplyWithUI(w.ctx, w.ui, w.exportRoot, w.logger, w.plan); err != nil { + return w.handleClusterSafeApplyError(err) + } + return nil +} + +func (w *restoreUIWorkflowRun) extractSafeApplyInventory() { + safeInvCategory := []Category{{ + ID: "safe_apply_inventory", + Name: "SAFE apply inventory (pools/mappings)", + Paths: []string{ + "./etc/pve/user.cfg", + "./var/lib/proxsave-info/commands/pve/mapping_pci.json", + "./var/lib/proxsave-info/commands/pve/mapping_usb.json", + "./var/lib/proxsave-info/commands/pve/mapping_dir.json", + }, + }} + err := extractArchiveNative(w.ctx, restoreArchiveOptions{ + archivePath: w.prepared.ArchivePath, + destRoot: w.exportRoot, + logger: w.logger, + categories: safeInvCategory, + mode: RestoreModeCustom, + }) + if err != nil { + w.logger.Debug("Failed to extract SAFE apply inventory (continuing): %v", err) + } +} + +func (w *restoreUIWorkflowRun) handleClusterSafeApplyError(err error) error { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Cluster SAFE apply completed with errors: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) stageAndApplySensitiveCategories() error { + if len(w.plan.StagedCategories) == 0 { + return nil + } + success, err := w.extractStagedCategories() + if err != nil { + return err + } + if !success { + w.logger.Warning("Skipping apply due to staged extraction errors") + return nil + } + return w.applyStagedCategories() +} + +func (w *restoreUIWorkflowRun) extractStagedCategories() (bool, error) { + w.stageRoot = stageDestRoot() + w.logger.Info("") + w.logger.Info("Staging %d sensitive category(ies) to: %s", len(w.plan.StagedCategories), w.stageRoot) + if err := restoreFS.MkdirAll(w.stageRoot, 0o700); err != nil { + return false, fmt.Errorf("failed to create staging directory %s: %w", w.stageRoot, err) + } + + stageLog, err := extractSelectiveArchive(w.ctx, w.prepared.ArchivePath, w.stageRoot, w.plan.StagedCategories, RestoreModeCustom, w.logger) + if err != nil { + if err := w.handleStageExtractError(err); err != nil { + return false, err + } + return false, nil + } + w.stageLogPath = stageLog + return true, nil +} + +func (w *restoreUIWorkflowRun) handleStageExtractError(err error) error { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("Staging completed with errors: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) applyStagedCategories() error { + if err := w.applyPBSMountGuards(); err != nil { + return err + } + w.logger.Info("") + steps := []restoreStageApplyStep{ + {name: "PBS staged config apply", run: func() error { return maybeApplyPBSConfigsFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) }}, + {name: "PVE staged config apply", run: func() error { + return maybeApplyPVEConfigsFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.destRoot, w.cfg.DryRun) + }}, + {name: "PVE SDN staged apply", run: func() error { return maybeApplyPVESDNFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) }}, + {name: "Access control staged apply", run: w.applyAccessControlFromStage}, + {name: "Notifications staged apply", run: func() error { + return maybeApplyNotificationsFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) + }}, + } + for _, step := range steps { + if err := w.runStageApplyStep(step); err != nil { + return err + } + } + return nil +} + +type restoreStageApplyStep struct { + name string + run func() error +} + +func (w *restoreUIWorkflowRun) applyPBSMountGuards() error { + err := maybeApplyPBSDatastoreMountGuards(w.ctx, w.logger, w.plan, w.stageRoot, w.destRoot, w.cfg.DryRun) + if err == nil { + return nil + } + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logger.Warning("PBS mount guard: %v", err) + return nil +} + +func (w *restoreUIWorkflowRun) runStageApplyStep(step restoreStageApplyStep) error { + if err := step.run(); err != nil { + if restoreAbortOrInput(err) { + return err + } + w.restoreHadWarnings = true + w.logStageApplyWarning(step.name, err) + } + return nil +} + +func (w *restoreUIWorkflowRun) logStageApplyWarning(name string, err error) { + if errors.Is(err, ErrAccessControlApplyNotCommitted) { + w.logAccessControlNotCommitted(err) + return + } + w.logger.Warning("%s: %v", name, err) +} + +func restoreAbortOrInput(err error) bool { + return errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) +} diff --git a/internal/orchestrator/restore_workflow_ui_extract_test.go b/internal/orchestrator/restore_workflow_ui_extract_test.go new file mode 100644 index 00000000..3e4141e3 --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_extract_test.go @@ -0,0 +1,100 @@ +package orchestrator + +import ( + "context" + "path/filepath" + "testing" + + "github.com/tis24dev/proxsave/internal/types" +) + +func TestRunClusterSafeApplySkipsWhenExportExtractionIncomplete(t *testing.T) { + logger := newTestLogger() + logger.SetLevel(types.LogLevelWarning) + w := &restoreUIWorkflowRun{ + ctx: context.Background(), + logger: logger, + ui: nil, + exportRoot: filepath.Join(t.TempDir(), "export"), + exportLogPath: "", + plan: &RestorePlan{ClusterSafeMode: true}, + prepared: &preparedBundle{ArchivePath: missingArchivePath(t)}, + } + + if err := w.runClusterSafeApply(); err != nil { + t.Fatalf("runClusterSafeApply error: %v", err) + } + if logger.WarningCount() != 1 { + t.Fatalf("expected skip warning for incomplete export extraction, got %d", logger.WarningCount()) + } +} + +func TestExtractStagedCategoriesReportsIncompleteOnNonAbortError(t *testing.T) { + origRestoreFS := restoreFS + fakeFS := NewFakeFS() + t.Cleanup(func() { + restoreFS = origRestoreFS + _ = fakeFS.Cleanup() + }) + restoreFS = fakeFS + + w := &restoreUIWorkflowRun{ + ctx: context.Background(), + logger: newTestLogger(), + plan: &RestorePlan{ + SystemType: SystemTypePBS, + StagedCategories: []Category{{ID: "pbs_notifications"}}, + }, + prepared: &preparedBundle{ArchivePath: missingArchivePath(t)}, + } + + success, err := w.extractStagedCategories() + if err != nil { + t.Fatalf("extractStagedCategories error: %v", err) + } + if success { + t.Fatalf("success=true; want false") + } + if !w.restoreHadWarnings { + t.Fatalf("restoreHadWarnings=false; want true") + } + if w.stageLogPath != "" { + t.Fatalf("stageLogPath=%q; want empty on incomplete staging", w.stageLogPath) + } +} + +func TestStageAndApplySensitiveCategoriesSkipsApplyWhenStagingIncomplete(t *testing.T) { + origRestoreFS := restoreFS + fakeFS := NewFakeFS() + t.Cleanup(func() { + restoreFS = origRestoreFS + _ = fakeFS.Cleanup() + }) + restoreFS = fakeFS + + w := &restoreUIWorkflowRun{ + ctx: context.Background(), + logger: newTestLogger(), + destRoot: "/", + plan: &RestorePlan{ + SystemType: SystemTypePBS, + StagedCategories: []Category{{ID: "pbs_notifications"}}, + }, + prepared: &preparedBundle{ArchivePath: missingArchivePath(t)}, + } + + if err := w.stageAndApplySensitiveCategories(); err != nil { + t.Fatalf("stageAndApplySensitiveCategories error: %v", err) + } + if !w.restoreHadWarnings { + t.Fatalf("restoreHadWarnings=false; want true") + } + if w.stageLogPath != "" { + t.Fatalf("stageLogPath=%q; want empty on incomplete staging", w.stageLogPath) + } +} + +func missingArchivePath(t *testing.T) string { + t.Helper() + return filepath.Join(t.TempDir(), "missing.tar") +} diff --git a/internal/orchestrator/restore_workflow_ui_fstab.go b/internal/orchestrator/restore_workflow_ui_fstab.go new file mode 100644 index 00000000..236fa1ae --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_fstab.go @@ -0,0 +1,124 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type fstabMergeUIPrompt struct { + analysis FstabAnalysisResult + remappedCount int + defaultYes bool +} + +// fstabMergeTimeout is the default time a UI has to confirm the Smart fstab merge. +const fstabMergeTimeout = 90 * time.Second + +func smartMergeFstabWithUI(ctx context.Context, logger *logging.Logger, ui RestoreWorkflowUI, currentFstabPath, backupFstabPath string, dryRun bool) error { + if logger == nil { + logger = logging.GetDefaultLogger() + } + logger.Info("") + logger.Step("Smart Filesystem Configuration Merge") + logger.Debug("[FSTAB_MERGE] Starting analysis of %s vs backup %s...", currentFstabPath, backupFstabPath) + + currentRaw, prompt, err := prepareFstabMergePrompt(logger, currentFstabPath, backupFstabPath) + if err != nil { + return err + } + if len(prompt.analysis.ProposedMounts) == 0 { + logger.Info("No new safe mounts found to restore. Keeping current fstab.") + return nil + } + + confirmed, err := ui.ConfirmFstabMerge(ctx, "Smart fstab merge", prompt.message(), fstabMergeTimeout, prompt.defaultYes) + if err != nil { + return err + } + if !confirmed { + logger.Info("Fstab merge skipped by user.") + return nil + } + return applyFstabMerge(ctx, logger, currentRaw, currentFstabPath, prompt.analysis.ProposedMounts, dryRun) +} + +func prepareFstabMergePrompt(logger *logging.Logger, currentFstabPath, backupFstabPath string) ([]string, fstabMergeUIPrompt, error) { + currentEntries, currentRaw, err := parseFstab(currentFstabPath) + if err != nil { + return nil, fstabMergeUIPrompt{}, fmt.Errorf("failed to parse current fstab: %w", err) + } + backupEntries, _, err := parseFstab(backupFstabPath) + if err != nil { + return nil, fstabMergeUIPrompt{}, fmt.Errorf("failed to parse backup fstab: %w", err) + } + + backupEntries, remappedCount := remapBackupFstabEntries(logger, backupEntries, backupFstabPath) + analysis := analyzeFstabMerge(logger, currentEntries, backupEntries) + prompt := fstabMergeUIPrompt{ + analysis: analysis, + remappedCount: remappedCount, + defaultYes: analysis.RootComparable && analysis.RootMatch && (!analysis.SwapComparable || analysis.SwapMatch), + } + return currentRaw, prompt, nil +} + +func remapBackupFstabEntries(logger *logging.Logger, entries []FstabEntry, backupFstabPath string) ([]FstabEntry, int) { + backupRoot := fstabBackupRootFromPath(backupFstabPath) + if backupRoot == "" { + return entries, 0 + } + remapped, count := remapFstabDevicesFromInventory(logger, entries, backupRoot) + if count > 0 { + logger.Info("Fstab device remap: converted %d entry(ies) from /dev/* to stable UUID/PARTUUID/LABEL based on ProxSave inventory", count) + } + return remapped, count +} + +func (p fstabMergeUIPrompt) message() string { + var msg strings.Builder + msg.WriteString("ProxSave found missing mounts in /etc/fstab.\n\n") + p.writeWarnings(&msg) + p.writeRemapSummary(&msg) + p.writeProposedMounts(&msg) + p.writeSkippedMounts(&msg) + msg.WriteString("\nDo you want to add the missing mounts (NFS/CIFS and data mounts with verified UUID/LABEL)?") + return msg.String() +} + +func (p fstabMergeUIPrompt) writeWarnings(msg *strings.Builder) { + if p.analysis.RootComparable && !p.analysis.RootMatch { + msg.WriteString("⚠ Root UUID mismatch: the backup appears to come from a different machine.\n") + } + if p.analysis.SwapComparable && !p.analysis.SwapMatch { + msg.WriteString("⚠ Swap mismatch: the current swap configuration will be kept.\n") + } +} + +func (p fstabMergeUIPrompt) writeRemapSummary(msg *strings.Builder) { + if p.remappedCount > 0 { + fmt.Fprintf(msg, "✓ Remapped %d fstab entry(ies) from /dev/* to stable UUID/PARTUUID/LABEL using ProxSave inventory.\n", p.remappedCount) + } +} + +func (p fstabMergeUIPrompt) writeProposedMounts(msg *strings.Builder) { + msg.WriteString("\nProposed mounts (safe):\n") + for _, mount := range p.analysis.ProposedMounts { + fmt.Fprintf(msg, " - %s -> %s (%s)\n", mount.Device, mount.MountPoint, mount.Type) + } +} + +func (p fstabMergeUIPrompt) writeSkippedMounts(msg *strings.Builder) { + if len(p.analysis.SkippedMounts) == 0 { + return + } + msg.WriteString("\nMounts found but not auto-proposed:\n") + for _, mount := range p.analysis.SkippedMounts { + fmt.Fprintf(msg, " - %s -> %s (%s)\n", mount.Device, mount.MountPoint, mount.Type) + } + msg.WriteString("\nHint: verify disks/UUIDs and options (nofail/_netdev) before adding them.\n") +} diff --git a/internal/orchestrator/restore_workflow_ui_full.go b/internal/orchestrator/restore_workflow_ui_full.go new file mode 100644 index 00000000..04a2137e --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_full.go @@ -0,0 +1,142 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +type fullRestoreUIFlow struct { + ctx context.Context + ui RestoreWorkflowUI + candidate *backupCandidate + prepared *preparedBundle + destRoot string + logger *logging.Logger + dryRun bool +} + +func runFullRestoreWithUI(ctx context.Context, ui RestoreWorkflowUI, candidate *backupCandidate, prepared *preparedBundle, destRoot string, logger *logging.Logger, dryRun bool) error { + flow := &fullRestoreUIFlow{ + ctx: ctx, + ui: ui, + candidate: candidate, + prepared: prepared, + destRoot: destRoot, + logger: logger, + dryRun: dryRun, + } + return flow.run() +} + +func (f *fullRestoreUIFlow) run() error { + if err := f.validate(); err != nil { + return err + } + if err := f.confirm(); err != nil { + return err + } + if f.safeFstabMerge() { + f.logger.Warning("Full restore safety: /etc/fstab will not be overwritten; Smart Merge will be applied after extraction.") + } + if err := extractPlainArchive(f.ctx, f.prepared.ArchivePath, f.destRoot, f.logger, f.skipPath); err != nil { + return err + } + if err := f.mergeFstabIfSafe(); err != nil { + return err + } + f.logger.Info("Restore completed successfully.") + return nil +} + +func (f *fullRestoreUIFlow) validate() error { + if f.candidate == nil || f.prepared == nil || f.prepared.ArchivePath == "" { + return fmt.Errorf("invalid restore candidate") + } + return nil +} + +func (f *fullRestoreUIFlow) confirm() error { + if err := f.ui.ShowMessage(f.ctx, "Full restore", "Backup category analysis failed; ProxSave will run a full restore (no selective modes)."); err != nil { + return err + } + confirmed, err := f.ui.ConfirmRestore(f.ctx) + if err != nil { + return err + } + if !confirmed { + return ErrRestoreAborted + } + return nil +} + +func (f *fullRestoreUIFlow) skipPath(name string) bool { + if !f.safeFstabMerge() { + return false + } + clean := strings.TrimPrefix(strings.TrimSpace(name), "./") + clean = strings.TrimPrefix(clean, "/") + return clean == "etc/fstab" +} + +func (f *fullRestoreUIFlow) safeFstabMerge() bool { + return f.destRoot == "/" && isRealRestoreFS(restoreFS) +} + +func (f *fullRestoreUIFlow) mergeFstabIfSafe() error { + if !f.safeFstabMerge() { + return nil + } + f.logger.Info("") + fsTempDir, err := restoreFS.MkdirTemp("", "proxsave-fstab-") + if err != nil { + f.logger.Warning("Failed to create temp dir for fstab merge: %v", err) + return nil + } + defer func() { + if err := restoreFS.RemoveAll(fsTempDir); err != nil { + f.logger.Debug("Failed to remove temporary fstab merge directory %s: %v", fsTempDir, err) + } + }() + return f.extractAndMergeFstab(fsTempDir) +} + +func (f *fullRestoreUIFlow) extractAndMergeFstab(fsTempDir string) error { + category := []Category{{ + ID: "filesystem", + Name: "Filesystem Configuration", + Paths: []string{"./etc/fstab"}, + }} + err := extractArchiveNative(f.ctx, restoreArchiveOptions{ + archivePath: f.prepared.ArchivePath, + destRoot: fsTempDir, + logger: f.logger, + categories: category, + mode: RestoreModeCustom, + }) + if err != nil { + f.logger.Warning("Failed to extract filesystem config for merge: %v", err) + return nil + } + currentFstab := filepath.Join(f.destRoot, "etc", "fstab") + backupFstab := filepath.Join(fsTempDir, "etc", "fstab") + if err := smartMergeFstabWithUI(f.ctx, f.logger, f.ui, currentFstab, backupFstab, f.dryRun); err != nil { + return f.handleFstabMergeError(err) + } + return nil +} + +func (f *fullRestoreUIFlow) handleFstabMergeError(err error) error { + if errors.Is(err, ErrRestoreAborted) || input.IsAborted(err) { + f.logger.Info("Restore aborted by user during Smart Filesystem Configuration Merge.") + return err + } + f.logger.Warning("Smart Fstab Merge failed: %v", err) + return nil +} diff --git a/internal/orchestrator/restore_workflow_ui_plan.go b/internal/orchestrator/restore_workflow_ui_plan.go new file mode 100644 index 00000000..6275d37d --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_plan.go @@ -0,0 +1,250 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "errors" + "fmt" + "os" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func (w *restoreUIWorkflowRun) prepareBundleAndPlan() (fallbackToFullRestore bool, err error) { + if err := w.prepareBundle(); err != nil { + return false, err + } + cleanupOnFailure := true + defer func() { + if cleanupOnFailure && w.prepared != nil { + w.prepared.Cleanup() + } + }() + + fallbackToFullRestore, err = w.planPreparedBundle() + if err != nil { + return fallbackToFullRestore, err + } + cleanupOnFailure = false + return fallbackToFullRestore, nil +} + +func (w *restoreUIWorkflowRun) planPreparedBundle() (bool, error) { + w.detectTargetSystem() + fallbackToFullRestore, err := w.analyzeArchive() + if err != nil { + return false, err + } + if err := w.confirmCompatibility(); err != nil || fallbackToFullRestore { + return fallbackToFullRestore, err + } + if err := w.selectRestorePlan(); err != nil { + return false, err + } + return false, w.configurePlanForRuntime() +} + +func (w *restoreUIWorkflowRun) prepareBundle() error { + candidate, prepared, err := prepareRestoreBundleFunc(w.ctx, w.cfg, w.logger, w.version, w.ui) + if err != nil { + return err + } + w.candidate = candidate + w.prepared = prepared + w.logger.Info("Restore target: system root (/) — files will be written back to their original paths") + return nil +} + +func (w *restoreUIWorkflowRun) detectTargetSystem() { + w.systemType = restoreSystem.DetectCurrentSystem() + w.logger.Info("Detected system type: %s", GetSystemTypeString(w.systemType)) +} + +func (w *restoreUIWorkflowRun) analyzeArchive() (bool, error) { + available, decisionInfo, err := analyzeRestoreArchiveFunc(w.prepared.ArchivePath, w.logger) + if err == nil { + w.availableCategories = available + w.decisionInfo = ensureRestoreDecisionInfo(decisionInfo) + return false, nil + } + + w.logger.Warning("Could not analyze categories: %v", err) + w.availableCategories = nil + w.decisionInfo = fallbackRestoreDecisionInfoFromManifest(w.candidate.Manifest) + w.logger.Info("Falling back to full restore mode") + return true, nil +} + +func ensureRestoreDecisionInfo(info *RestoreDecisionInfo) *RestoreDecisionInfo { + if info != nil { + return info + } + return &RestoreDecisionInfo{} +} + +func (w *restoreUIWorkflowRun) confirmCompatibility() error { + warn := ValidateCompatibility(w.systemType, w.decisionInfo.BackupType) + if warn == nil { + return nil + } + w.logger.Warning("Compatibility check: %v", warn) + proceed, err := w.ui.ConfirmCompatibility(w.ctx, warn) + if err != nil { + return err + } + if !proceed { + return ErrRestoreAborted + } + return nil +} + +func (w *restoreUIWorkflowRun) selectRestorePlan() error { + categories, mode, err := w.selectModeAndCategories() + if err != nil { + return err + } + if mode == RestoreModeCustom { + categories, err = maybeAddRecommendedCategoriesForTFA(w.ctx, w.ui, w.logger, categories, w.availableCategories) + if err != nil { + return err + } + } + w.mode = mode + w.plan = PlanRestore(w.decisionInfo.ClusterPayload, categories, w.systemType, mode) + return nil +} + +func (w *restoreUIWorkflowRun) selectModeAndCategories() ([]Category, RestoreMode, error) { + for { + mode, err := w.ui.SelectRestoreMode(w.ctx, w.systemType) + if err != nil { + return nil, mode, err + } + if mode != RestoreModeCustom { + return GetCategoriesForMode(mode, w.systemType, w.availableCategories), mode, nil + } + + categories, err := w.ui.SelectCategories(w.ctx, w.availableCategories, w.systemType) + if errors.Is(err, errRestoreBackToMode) { + continue + } + return categories, mode, err + } +} + +func (w *restoreUIWorkflowRun) configurePlanForRuntime() error { + if err := w.selectPBSRestoreBehavior(); err != nil { + return err + } + if err := w.selectClusterRestoreMode(); err != nil { + return err + } + w.warnAccessControlHostnameMismatch() + w.collapseStagingWhenUnavailable() + return nil +} + +func (w *restoreUIWorkflowRun) selectPBSRestoreBehavior() error { + if !w.planNeedsPBSBehavior() { + return nil + } + behavior, err := w.ui.SelectPBSRestoreBehavior(w.ctx) + if err != nil { + return err + } + w.plan.PBSRestoreBehavior = behavior + w.logger.Info("PBS restore behavior: %s", behavior.DisplayName()) + return nil +} + +func (w *restoreUIWorkflowRun) planNeedsPBSBehavior() bool { + return w.plan.SystemType.SupportsPBS() && + (w.plan.HasCategoryID("pbs_host") || + w.plan.HasCategoryID("datastore_pbs") || + w.plan.HasCategoryID("pbs_remotes") || + w.plan.HasCategoryID("pbs_jobs") || + w.plan.HasCategoryID("pbs_notifications") || + w.plan.HasCategoryID("pbs_access_control") || + w.plan.HasCategoryID("pbs_tape")) +} + +func (w *restoreUIWorkflowRun) selectClusterRestoreMode() error { + if !w.plan.NeedsClusterRestore || !w.plan.ClusterBackup { + return nil + } + w.logger.Info("Cluster payload detected in backup; enabling guarded restore options for pve_cluster") + choice, err := w.ui.SelectClusterRestoreMode(w.ctx) + if err != nil { + return err + } + return w.applyClusterRestoreChoice(choice) +} + +func (w *restoreUIWorkflowRun) applyClusterRestoreChoice(choice ClusterRestoreMode) error { + switch choice { + case ClusterRestoreAbort: + return ErrRestoreAborted + case ClusterRestoreSafe: + w.plan.ApplyClusterSafeMode(true) + w.logger.Info("Selected SAFE cluster restore: /var/lib/pve-cluster will be exported only, not written to system") + case ClusterRestoreRecovery: + w.plan.ApplyClusterSafeMode(false) + w.logger.Warning("Selected RECOVERY cluster restore: full cluster database will be restored; ensure other nodes are isolated") + default: + return fmt.Errorf("invalid cluster restore mode selected") + } + return nil +} + +func (w *restoreUIWorkflowRun) warnAccessControlHostnameMismatch() { + if !w.plan.HasCategoryID("pve_access_control") && !w.plan.HasCategoryID("pbs_access_control") { + return + } + currentHost, err := os.Hostname() + backupHost := strings.TrimSpace(w.decisionInfo.BackupHostname) + if err != nil || backupHost == "" || strings.TrimSpace(currentHost) == "" { + return + } + if !strings.EqualFold(strings.TrimSpace(currentHost), backupHost) { + w.logger.Warning("Access control/TFA: backup hostname=%s current hostname=%s; WebAuthn users may require re-enrollment if the UI origin (FQDN/port) changes", backupHost, currentHost) + } +} + +func (w *restoreUIWorkflowRun) collapseStagingWhenUnavailable() { + if w.destRoot == "/" && isRealRestoreFS(restoreFS) { + return + } + if len(w.plan.StagedCategories) == 0 { + return + } + logging.DebugStep(w.logger, "restore", "Staging disabled (destRoot=%s realFS=%v): extracting %d staged category(ies) directly", w.destRoot, isRealRestoreFS(restoreFS), len(w.plan.StagedCategories)) + w.plan.NormalCategories = append(w.plan.NormalCategories, w.plan.StagedCategories...) + w.plan.StagedCategories = nil +} + +func (w *restoreUIWorkflowRun) confirmRestorePlan() error { + if w.plan == nil { + return ErrRestoreAborted + } + restoreConfig := &SelectiveRestoreConfig{ + Mode: w.mode, + SystemType: w.systemType, + Metadata: w.candidate.Manifest, + } + restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, w.plan.NormalCategories...) + restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, w.plan.StagedCategories...) + restoreConfig.SelectedCategories = append(restoreConfig.SelectedCategories, w.plan.ExportCategories...) + + if err := w.ui.ShowRestorePlan(w.ctx, restoreConfig); err != nil { + return err + } + confirmed, err := w.ui.ConfirmRestore(w.ctx) + if err != nil { + return err + } + if !confirmed { + w.logger.Info("Restore operation cancelled by user") + return ErrRestoreAborted + } + return nil +} diff --git a/internal/orchestrator/restore_workflow_ui_run.go b/internal/orchestrator/restore_workflow_ui_run.go new file mode 100644 index 00000000..837bf198 --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_run.go @@ -0,0 +1,133 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" +) + +type restoreUIWorkflowRun struct { + ctx context.Context + cfg *config.Config + logger *logging.Logger + version string + ui RestoreWorkflowUI + candidate *backupCandidate + prepared *preparedBundle + destRoot string + systemType SystemType + availableCategories []Category + decisionInfo *RestoreDecisionInfo + mode RestoreMode + plan *RestorePlan + restoreHadWarnings bool + safetyBackup *SafetyBackupResult + networkRollbackBackup *SafetyBackupResult + firewallRollbackBackup *SafetyBackupResult + haRollbackBackup *SafetyBackupResult + accessControlRollbackBackup *SafetyBackupResult + stageLogPath string + stageRoot string + stageRootForNetworkApply string + detailedLogPath string + exportLogPath string + exportRoot string + needsClusterRestore bool + clusterServicesStopped bool + pbsServicesStopped bool + needsPBSServices bool + needsFilesystemRestore bool +} + +func newRestoreUIWorkflowRun(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) *restoreUIWorkflowRun { + return &restoreUIWorkflowRun{ + ctx: ctx, + cfg: cfg, + logger: logger, + version: version, + ui: ui, + destRoot: "/", + } +} + +func (w *restoreUIWorkflowRun) run() error { + fallbackToFullRestore, err := w.prepareBundleAndPlan() + if err != nil { + return err + } + if w.prepared != nil { + defer w.prepared.Cleanup() + } + if fallbackToFullRestore { + return runFullRestoreWithUI(w.ctx, w.ui, w.candidate, w.prepared, w.destRoot, w.logger, w.cfg.DryRun) + } + return w.runSelectiveRestore() +} + +func (w *restoreUIWorkflowRun) runSelectiveRestore() error { + if err := w.confirmRestorePlan(); err != nil { + return err + } + if err := w.createRollbackBackups(); err != nil { + return err + } + cleanupServices, err := w.prepareRestoreServices() + if err != nil { + return err + } + defer cleanupServices() + if err := w.prepareAndRestoreSelectedPayloads(); err != nil { + return err + } + if err := w.runPostRestoreApplyWorkflows(); err != nil { + return err + } + w.logRestoreCompletion() + w.logServiceRestartAdvice() + w.checkZFSPoolsAfterRestore() + w.logRebootRecommendation() + return nil +} + +func (w *restoreUIWorkflowRun) prepareAndRestoreSelectedPayloads() error { + w.interceptFilesystemCategory() + if err := w.extractNormalCategories(); err != nil { + return err + } + if err := w.smartMergeFilesystemCategory(); err != nil { + return err + } + if err := w.exportCategories(); err != nil { + return err + } + if err := w.runClusterSafeApply(); err != nil { + return err + } + if err := w.stageAndApplySensitiveCategories(); err != nil { + return err + } + return nil +} + +func (w *restoreUIWorkflowRun) runPostRestoreApplyWorkflows() error { + w.verifyPBSNotificationsAfterRestore() + if err := w.installNetworkConfigFromStage(); err != nil { + return err + } + w.recreateStorageDirectories() + if err := w.repairDNSAfterRestore(); err != nil { + return err + } + if err := w.applyNetworkConfig(); err != nil { + return err + } + if err := w.applyFirewallConfig(); err != nil { + return err + } + if err := w.applyHAConfig(); err != nil { + return err + } + return nil +} diff --git a/internal/orchestrator/restore_workflow_ui_tfa.go b/internal/orchestrator/restore_workflow_ui_tfa.go new file mode 100644 index 00000000..24c6fac6 --- /dev/null +++ b/internal/orchestrator/restore_workflow_ui_tfa.go @@ -0,0 +1,99 @@ +// Package orchestrator coordinates backup, restore, decrypt, and related workflows. +package orchestrator + +import ( + "context" + "fmt" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func maybeAddRecommendedCategoriesForTFA(ctx context.Context, ui RestoreWorkflowUI, logger *logging.Logger, selected []Category, available []Category) ([]Category, error) { + if !shouldPromptForTFARecommendations(ui, logger, selected) { + return selected, nil + } + addCategories, addNames := tfaRecommendedCategories(selected, available) + if len(addCategories) == 0 { + return selected, nil + } + + addNow, err := confirmTFARecommendedCategories(ctx, ui, addNames) + if err != nil { + return nil, err + } + if !addNow { + logger.Warning("Access control selected without %s; WebAuthn users may require re-enrollment if the UI origin changes", strings.Join(addNames, ", ")) + return selected, nil + } + return dedupeCategoriesByID(append(selected, addCategories...)), nil +} + +func shouldPromptForTFARecommendations(ui RestoreWorkflowUI, logger *logging.Logger, selected []Category) bool { + return ui != nil && + logger != nil && + (hasCategoryID(selected, "pve_access_control") || hasCategoryID(selected, "pbs_access_control")) +} + +func tfaRecommendedCategories(selected, available []Category) ([]Category, []string) { + var categories []Category + var names []string + for _, id := range missingTFARecommendedCategoryIDs(selected) { + cat := GetCategoryByID(id, available) + if cat == nil || !cat.IsAvailable || cat.ExportOnly { + continue + } + categories = append(categories, *cat) + names = append(names, cat.Name) + } + return categories, names +} + +func missingTFARecommendedCategoryIDs(selected []Category) []string { + var missing []string + if !hasCategoryID(selected, "network") { + missing = append(missing, "network") + } + if !hasCategoryID(selected, "ssl") { + missing = append(missing, "ssl") + } + return missing +} + +func confirmTFARecommendedCategories(ctx context.Context, ui RestoreWorkflowUI, addNames []string) (bool, error) { + message := fmt.Sprintf( + "You selected Access Control without restoring: %s\n\n"+ + "If TFA includes WebAuthn/FIDO2, changing the UI origin (FQDN/hostname or port) may require re-enrollment.\n\n"+ + "For maximum 1:1 compatibility, ProxSave recommends restoring these categories too.\n\n"+ + "Add recommended categories now?", + strings.Join(addNames, ", "), + ) + return ui.ConfirmAction(ctx, "TFA/WebAuthn compatibility", message, "Add recommended", "Keep current", 0, true) +} + +func dedupeCategoriesByID(categories []Category) []Category { + if len(categories) == 0 { + return categories + } + seen := make(map[string]struct{}, len(categories)) + out := make([]Category, 0, len(categories)) + for _, cat := range categories { + if dedupeCategorySeen(seen, cat.ID) { + continue + } + out = append(out, cat) + } + return out +} + +func dedupeCategorySeen(seen map[string]struct{}, id string) bool { + id = strings.TrimSpace(id) + if id == "" { + return false + } + if _, ok := seen[id]; ok { + return true + } + seen[id] = struct{}{} + return false +} diff --git a/internal/orchestrator/selective_additional_test.go b/internal/orchestrator/selective_additional_test.go index e6fac525..d661cbc8 100644 --- a/internal/orchestrator/selective_additional_test.go +++ b/internal/orchestrator/selective_additional_test.go @@ -84,7 +84,7 @@ func TestConfirmRestoreOperation(t *testing.T) { } _ = w.Close() os.Stdin = r - defer r.Close() + defer func() { _ = r.Close() }() got, err := ConfirmRestoreOperation(context.Background(), logger) if err != nil { diff --git a/internal/orchestrator/temp_registry.go b/internal/orchestrator/temp_registry.go index 915dc621..e3000bca 100644 --- a/internal/orchestrator/temp_registry.go +++ b/internal/orchestrator/temp_registry.go @@ -136,7 +136,7 @@ func (r *TempDirRegistry) updateEntries(mutator func([]tempDirRecord) ([]tempDir }) } -func (r *TempDirRegistry) withLock(mutator func([]tempDirRecord) ([]tempDirRecord, error)) error { +func (r *TempDirRegistry) withLock(mutator func([]tempDirRecord) ([]tempDirRecord, error)) (err error) { r.mu.Lock() defer r.mu.Unlock() @@ -144,12 +144,20 @@ func (r *TempDirRegistry) withLock(mutator func([]tempDirRecord) ([]tempDirRecor if err != nil { return fmt.Errorf("open registry lock: %w", err) } - defer lockFile.Close() + defer func() { + if closeErr := lockFile.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("close registry lock: %w", closeErr) + } + }() if err := syscall.Flock(int(lockFile.Fd()), syscall.LOCK_EX); err != nil { return fmt.Errorf("flock registry: %w", err) } - defer syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN) + defer func() { + if unlockErr := syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN); unlockErr != nil && err == nil { + err = fmt.Errorf("unlock registry: %w", unlockErr) + } + }() entries, err := r.loadEntries() if err != nil { diff --git a/internal/orchestrator/tui_simulation_test.go b/internal/orchestrator/tui_simulation_test.go index 20d2cfdd..7830d649 100644 --- a/internal/orchestrator/tui_simulation_test.go +++ b/internal/orchestrator/tui_simulation_test.go @@ -28,12 +28,6 @@ func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} { t.Helper() orig := newTUIApp - screen := tcell.NewSimulationScreen("UTF-8") - if err := screen.Init(); err != nil { - t.Fatalf("screen.Init: %v", err) - } - screen.SetSize(120, 40) - drawCh := make(chan struct{}, 8) done := make(chan struct{}) var injectOnce sync.Once @@ -51,6 +45,12 @@ func withSimAppSequence(t *testing.T, keys []simKey) <-chan struct{} { } newTUIApp = func() *tui.App { + screen := tcell.NewSimulationScreen("UTF-8") + if err := screen.Init(); err != nil { + t.Fatalf("screen.Init: %v", err) + } + screen.SetSize(120, 40) + app := tui.NewApp() appMu.Lock() currentApp = app diff --git a/internal/orchestrator/workflow_ui_tui_decrypt.go b/internal/orchestrator/workflow_ui_tui_decrypt.go index fa0483a5..ab77d275 100644 --- a/internal/orchestrator/workflow_ui_tui_decrypt.go +++ b/internal/orchestrator/workflow_ui_tui_decrypt.go @@ -218,12 +218,12 @@ func (u *tuiWorkflowUI) SelectBackupSource(ctx context.Context, options []decryp if listHeight > 14 { listHeight = 14 } - form.Form.AddFormItem( + form.AddFormItem( components.NewListFormItem(list). SetLabel("Available backup sources"). SetFieldHeight(listHeight), ) - form.Form.SetFocus(0) + form.SetFocus(0) form.SetOnCancel(func() { aborted = true }) @@ -332,12 +332,12 @@ func (u *tuiWorkflowUI) SelectBackupCandidate(ctx context.Context, candidates [] if listHeight > 14 { listHeight = 14 } - form.Form.AddFormItem( + form.AddFormItem( components.NewListFormItem(list). SetLabel("Available backups"). SetFieldHeight(listHeight), ) - form.Form.SetFocus(0) + form.SetFocus(0) form.SetOnCancel(func() { aborted = true }) diff --git a/internal/orchestrator/workflow_ui_tui_restore.go b/internal/orchestrator/workflow_ui_tui_restore.go index b9a18d47..75270d49 100644 --- a/internal/orchestrator/workflow_ui_tui_restore.go +++ b/internal/orchestrator/workflow_ui_tui_restore.go @@ -109,8 +109,8 @@ func (u *tuiWorkflowUI) SelectExportNode(ctx context.Context, exportRoot, curren listItem := components.NewListFormItem(list). SetLabel(fmt.Sprintf("Current node: %s", strings.TrimSpace(currentNode))). SetFieldHeight(8) - form.Form.AddFormItem(listItem) - form.Form.SetFocus(0) + form.AddFormItem(listItem) + form.SetFocus(0) form.SetOnCancel(func() { cancelled = true diff --git a/internal/orchestrator/workflow_ui_tui_shared.go b/internal/orchestrator/workflow_ui_tui_shared.go index 2f8cd448..c07ed9e1 100644 --- a/internal/orchestrator/workflow_ui_tui_shared.go +++ b/internal/orchestrator/workflow_ui_tui_shared.go @@ -10,7 +10,7 @@ func enableFormNavigation(form *components.Form, dropdownOpen *bool) { if form == nil || form.Form == nil { return } - form.Form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event == nil { return event } @@ -18,7 +18,7 @@ func enableFormNavigation(form *components.Form, dropdownOpen *bool) { return event } - formItemIndex, buttonIndex := form.Form.GetFocusedItemIndex() + formItemIndex, buttonIndex := form.GetFocusedItemIndex() isOnButton := formItemIndex < 0 && buttonIndex >= 0 isOnField := formItemIndex >= 0 @@ -31,7 +31,7 @@ func enableFormNavigation(form *components.Form, dropdownOpen *bool) { } } else if isOnField { // If focused item is a ListFormItem, let it handle navigation internally. - if _, ok := form.Form.GetFormItem(formItemIndex).(*components.ListFormItem); ok { + if _, ok := form.GetFormItem(formItemIndex).(*components.ListFormItem); ok { return event } // For other form fields, convert arrows to tab navigation. diff --git a/internal/pbs/namespaces_test.go b/internal/pbs/namespaces_test.go index 22e7a473..66824468 100644 --- a/internal/pbs/namespaces_test.go +++ b/internal/pbs/namespaces_test.go @@ -207,13 +207,13 @@ func TestHelperProcess(t *testing.T) { switch os.Getenv("PBS_HELPER_SCENARIO") { case "cli-success": - fmt.Fprint(os.Stdout, `{"data":[{"ns":"","path":"/mnt/datastore","comment":"root namespace"},{"ns":"prod","path":"/mnt/datastore/prod","parent":"","ctime":1700000000}]}`) + _, _ = fmt.Fprint(os.Stdout, `{"data":[{"ns":"","path":"/mnt/datastore","comment":"root namespace"},{"ns":"prod","path":"/mnt/datastore/prod","parent":"","ctime":1700000000}]}`) os.Exit(0) case "cli-error": - fmt.Fprint(os.Stderr, "CLI exploded") + _, _ = fmt.Fprint(os.Stderr, "CLI exploded") os.Exit(1) default: - fmt.Fprint(os.Stderr, "unknown scenario") + _, _ = fmt.Fprint(os.Stderr, "unknown scenario") os.Exit(2) } } diff --git a/internal/safeexec/safeexec.go b/internal/safeexec/safeexec.go index b4960255..7d2fd455 100644 --- a/internal/safeexec/safeexec.go +++ b/internal/safeexec/safeexec.go @@ -1,3 +1,4 @@ +// Package safeexec centralizes constrained process execution helpers. package safeexec import ( @@ -12,172 +13,170 @@ import ( "unicode" ) +// ErrCommandNotAllowed reports that a command name is outside the allowlist. var ErrCommandNotAllowed = errors.New("command not allowed") +type commandFactory func(context.Context, ...string) *exec.Cmd + +var allowedCommandFactories = map[string]commandFactory{ + "apt-cache": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "apt-cache", args...) + }, + "blkid": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "blkid", args...) }, + "bridge": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "bridge", args...) + }, + "bzip2": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "bzip2", args...) }, + "cat": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "cat", args...) }, + "ceph": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ceph", args...) }, + "chattr": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "chattr", args...) + }, + "crontab": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "crontab", args...) + }, + "df": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "df", args...) }, + "dmidecode": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "dmidecode", args...) + }, + "dpkg": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "dpkg", args...) }, + "dpkg-query": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "dpkg-query", args...) + }, + "echo": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "echo", args...) }, + "ethtool": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "ethtool", args...) + }, + "false": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "false", args...) }, + "firewall-cmd": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "firewall-cmd", args...) + }, + "free": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "free", args...) }, + "hostname": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "hostname", args...) + }, + "ifreload": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "ifreload", args...) + }, + "ifup": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ifup", args...) }, + "ip": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ip", args...) }, + "iptables": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "iptables", args...) + }, + "iptables-save": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "iptables-save", args...) + }, + "ip6tables": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "ip6tables", args...) + }, + "ip6tables-save": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "ip6tables-save", args...) + }, + "journalctl": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "journalctl", args...) + }, + "lsblk": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lsblk", args...) }, + "lspci": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lspci", args...) }, + "lscpu": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lscpu", args...) }, + "lsmod": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lsmod", args...) }, + "lsusb": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lsusb", args...) }, + "lvs": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lvs", args...) }, + "lzma": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "lzma", args...) }, + "mailq": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "mailq", args...) }, + "mount": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "mount", args...) }, + "mountpoint": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "mountpoint", args...) + }, + "nft": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "nft", args...) }, + "pbzip2": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "pbzip2", args...) + }, + "pgrep": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pgrep", args...) }, + "pigz": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pigz", args...) }, + "ping": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ping", args...) }, + "pvs": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pvs", args...) }, + "proxmox-backup-client": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "proxmox-backup-client", args...) + }, + "proxmox-backup-manager": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "proxmox-backup-manager", args...) + }, + "proxmox-mail-forward": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "proxmox-mail-forward", args...) + }, + "proxmox-tape": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "proxmox-tape", args...) + }, + "ps": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ps", args...) }, + "pvecm": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pvecm", args...) }, + "pve-firewall": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "pve-firewall", args...) + }, + "pvenode": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "pvenode", args...) + }, + "pvesh": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pvesh", args...) }, + "pvesm": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pvesm", args...) }, + "pveum": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "pveum", args...) }, + "pveversion": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "pveversion", args...) + }, + "rclone": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "rclone", args...) + }, + "sendmail": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "sendmail", args...) + }, + "sensors": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "sensors", args...) + }, + "sh": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "sh", args...) }, + "smartctl": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "smartctl", args...) + }, + "ss": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ss", args...) }, + "systemctl": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "systemctl", args...) + }, + "systemd-run": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "systemd-run", args...) + }, + "sysctl": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "sysctl", args...) + }, + "tail": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "tail", args...) }, + "tar": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "tar", args...) }, + "udevadm": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "udevadm", args...) + }, + "umount": func(ctx context.Context, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, "umount", args...) + }, + "uname": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "uname", args...) }, + "ufw": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "ufw", args...) }, + "vgs": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "vgs", args...) }, + "which": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "which", args...) }, + "xz": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "xz", args...) }, + "zfs": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "zfs", args...) }, + "zpool": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "zpool", args...) }, + "zstd": func(ctx context.Context, args ...string) *exec.Cmd { return exec.CommandContext(ctx, "zstd", args...) }, +} + // CommandContext creates commands only for binaries that are intentionally -// allowed by the application. Keep exec.CommandContext calls in the switch so +// allowed by the application. Keep exec.CommandContext calls in the factory map so // static analyzers can see literal command names. func CommandContext(ctx context.Context, name string, args ...string) (*exec.Cmd, error) { if strings.TrimSpace(name) != name || name == "" || strings.ContainsAny(name, `/\`) { return nil, fmt.Errorf("%w: %q", ErrCommandNotAllowed, name) } - switch name { - case "apt-cache": - return exec.CommandContext(ctx, "apt-cache", args...), nil - case "blkid": - return exec.CommandContext(ctx, "blkid", args...), nil - case "bridge": - return exec.CommandContext(ctx, "bridge", args...), nil - case "bzip2": - return exec.CommandContext(ctx, "bzip2", args...), nil - case "cat": - return exec.CommandContext(ctx, "cat", args...), nil - case "ceph": - return exec.CommandContext(ctx, "ceph", args...), nil - case "chattr": - return exec.CommandContext(ctx, "chattr", args...), nil - case "crontab": - return exec.CommandContext(ctx, "crontab", args...), nil - case "df": - return exec.CommandContext(ctx, "df", args...), nil - case "dmidecode": - return exec.CommandContext(ctx, "dmidecode", args...), nil - case "dpkg": - return exec.CommandContext(ctx, "dpkg", args...), nil - case "dpkg-query": - return exec.CommandContext(ctx, "dpkg-query", args...), nil - case "echo": - return exec.CommandContext(ctx, "echo", args...), nil - case "ethtool": - return exec.CommandContext(ctx, "ethtool", args...), nil - case "false": - return exec.CommandContext(ctx, "false", args...), nil - case "firewall-cmd": - return exec.CommandContext(ctx, "firewall-cmd", args...), nil - case "free": - return exec.CommandContext(ctx, "free", args...), nil - case "hostname": - return exec.CommandContext(ctx, "hostname", args...), nil - case "ifreload": - return exec.CommandContext(ctx, "ifreload", args...), nil - case "ifup": - return exec.CommandContext(ctx, "ifup", args...), nil - case "ip": - return exec.CommandContext(ctx, "ip", args...), nil - case "iptables": - return exec.CommandContext(ctx, "iptables", args...), nil - case "iptables-save": - return exec.CommandContext(ctx, "iptables-save", args...), nil - case "ip6tables": - return exec.CommandContext(ctx, "ip6tables", args...), nil - case "ip6tables-save": - return exec.CommandContext(ctx, "ip6tables-save", args...), nil - case "journalctl": - return exec.CommandContext(ctx, "journalctl", args...), nil - case "lsblk": - return exec.CommandContext(ctx, "lsblk", args...), nil - case "lspci": - return exec.CommandContext(ctx, "lspci", args...), nil - case "lscpu": - return exec.CommandContext(ctx, "lscpu", args...), nil - case "lsmod": - return exec.CommandContext(ctx, "lsmod", args...), nil - case "lsusb": - return exec.CommandContext(ctx, "lsusb", args...), nil - case "lvs": - return exec.CommandContext(ctx, "lvs", args...), nil - case "lzma": - return exec.CommandContext(ctx, "lzma", args...), nil - case "mailq": - return exec.CommandContext(ctx, "mailq", args...), nil - case "mount": - return exec.CommandContext(ctx, "mount", args...), nil - case "mountpoint": - return exec.CommandContext(ctx, "mountpoint", args...), nil - case "nft": - return exec.CommandContext(ctx, "nft", args...), nil - case "pbzip2": - return exec.CommandContext(ctx, "pbzip2", args...), nil - case "pgrep": - return exec.CommandContext(ctx, "pgrep", args...), nil - case "pigz": - return exec.CommandContext(ctx, "pigz", args...), nil - case "ping": - return exec.CommandContext(ctx, "ping", args...), nil - case "pvs": - return exec.CommandContext(ctx, "pvs", args...), nil - case "proxmox-backup-client": - return exec.CommandContext(ctx, "proxmox-backup-client", args...), nil - case "proxmox-backup-manager": - return exec.CommandContext(ctx, "proxmox-backup-manager", args...), nil - case "proxmox-mail-forward": - return exec.CommandContext(ctx, "proxmox-mail-forward", args...), nil - case "proxmox-tape": - return exec.CommandContext(ctx, "proxmox-tape", args...), nil - case "ps": - return exec.CommandContext(ctx, "ps", args...), nil - case "pvecm": - return exec.CommandContext(ctx, "pvecm", args...), nil - case "pve-firewall": - return exec.CommandContext(ctx, "pve-firewall", args...), nil - case "pvenode": - return exec.CommandContext(ctx, "pvenode", args...), nil - case "pvesh": - return exec.CommandContext(ctx, "pvesh", args...), nil - case "pvesm": - return exec.CommandContext(ctx, "pvesm", args...), nil - case "pveum": - return exec.CommandContext(ctx, "pveum", args...), nil - case "pveversion": - return exec.CommandContext(ctx, "pveversion", args...), nil - case "rclone": - return exec.CommandContext(ctx, "rclone", args...), nil - case "sendmail": - return exec.CommandContext(ctx, "sendmail", args...), nil - case "sensors": - return exec.CommandContext(ctx, "sensors", args...), nil - case "sh": - return exec.CommandContext(ctx, "sh", args...), nil - case "smartctl": - return exec.CommandContext(ctx, "smartctl", args...), nil - case "ss": - return exec.CommandContext(ctx, "ss", args...), nil - case "systemctl": - return exec.CommandContext(ctx, "systemctl", args...), nil - case "systemd-run": - return exec.CommandContext(ctx, "systemd-run", args...), nil - case "sysctl": - return exec.CommandContext(ctx, "sysctl", args...), nil - case "tail": - return exec.CommandContext(ctx, "tail", args...), nil - case "tar": - return exec.CommandContext(ctx, "tar", args...), nil - case "udevadm": - return exec.CommandContext(ctx, "udevadm", args...), nil - case "umount": - return exec.CommandContext(ctx, "umount", args...), nil - case "uname": - return exec.CommandContext(ctx, "uname", args...), nil - case "ufw": - return exec.CommandContext(ctx, "ufw", args...), nil - case "vgs": - return exec.CommandContext(ctx, "vgs", args...), nil - case "which": - return exec.CommandContext(ctx, "which", args...), nil - case "xz": - return exec.CommandContext(ctx, "xz", args...), nil - case "zfs": - return exec.CommandContext(ctx, "zfs", args...), nil - case "zpool": - return exec.CommandContext(ctx, "zpool", args...), nil - case "zstd": - return exec.CommandContext(ctx, "zstd", args...), nil - default: - return nil, fmt.Errorf("%w: %q", ErrCommandNotAllowed, name) + if factory, ok := allowedCommandFactories[name]; ok { + return factory(ctx, args...), nil } + return nil, fmt.Errorf("%w: %q", ErrCommandNotAllowed, name) } +// CombinedOutput runs an allowed command and returns its combined stdout/stderr. func CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) { cmd, err := CommandContext(ctx, name, args...) if err != nil { @@ -186,6 +185,7 @@ func CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, e return cmd.CombinedOutput() } +// Output runs an allowed command and returns stdout. func Output(ctx context.Context, name string, args ...string) ([]byte, error) { cmd, err := CommandContext(ctx, name, args...) if err != nil { @@ -194,6 +194,7 @@ func Output(ctx context.Context, name string, args ...string) ([]byte, error) { return cmd.Output() } +// TrustedCommandContext creates a command for a validated absolute executable path. func TrustedCommandContext(ctx context.Context, execPath string, args ...string) (*exec.Cmd, error) { if err := ValidateTrustedExecutablePath(execPath); err != nil { return nil, err @@ -202,6 +203,7 @@ func TrustedCommandContext(ctx context.Context, execPath string, args ...string) return exec.CommandContext(ctx, execPath, args...), nil // nosemgrep: go.lang.security.audit.dangerous-exec-command.dangerous-exec-command } +// ValidateTrustedExecutablePath verifies an executable path is absolute, regular, executable, and not world-writable. func ValidateTrustedExecutablePath(execPath string) error { clean := strings.TrimSpace(execPath) if clean == "" { @@ -226,6 +228,7 @@ func ValidateTrustedExecutablePath(execPath string) error { return nil } +// ValidateRcloneRemoteName validates a rclone remote name before it is used in command arguments. func ValidateRcloneRemoteName(remote string) error { if remote == "" { return fmt.Errorf("rclone remote name is empty") @@ -244,6 +247,7 @@ func ValidateRcloneRemoteName(remote string) error { return nil } +// ValidateRemoteRelativePath validates a remote-relative path segment for a named field. func ValidateRemoteRelativePath(value, field string) error { clean := strings.TrimSpace(value) if clean == "" { @@ -264,6 +268,7 @@ func ValidateRemoteRelativePath(value, field string) error { return nil } +// ProcPath returns a safe /proc path for a supported PID leaf. func ProcPath(pid int, leaf string) (string, error) { if pid <= 0 { return "", fmt.Errorf("pid must be positive") diff --git a/internal/security/security.go b/internal/security/security.go index 5104a469..5edf6f64 100644 --- a/internal/security/security.go +++ b/internal/security/security.go @@ -348,7 +348,7 @@ func (c *Checker) verifyBinaryIntegrity() { c.addError("Cannot open executable %s: %v", c.execPath, err) return } - defer f.Close() + defer func() { _ = f.Close() }() openedInfo, err := f.Stat() if err != nil { @@ -359,7 +359,7 @@ func (c *Checker) verifyBinaryIntegrity() { c.addError("Executable %s changed during integrity check; aborting", c.execPath) return } - openedInfo = c.ensureOwnershipAndPermFromFD(f, openedInfo, 0o700, fmt.Sprintf("Executable %s", c.execPath)) + c.ensureOwnershipAndPermFromFD(f, openedInfo, 0o700, fmt.Sprintf("Executable %s", c.execPath)) currentHash, err := checksumReader(f) if err != nil { @@ -569,7 +569,7 @@ func fileContainsMarker(path string, markers []string, limit int) (bool, error) if err != nil { return false, err } - defer f.Close() + defer func() { _ = f.Close() }() const bufSize = 4096 maxMarkerLen := 0 @@ -751,7 +751,11 @@ func (c *Checker) checkSuspiciousProcesses(ctx context.Context) { if args == "" { continue } - lowerArgs := strings.ToLower(args) + trimmed := strings.TrimSpace(args) + if isZombieProxmoxProcess(user, state, vsz, trimmed) { + continue + } + lowerArgs := strings.ToLower(trimmed) for _, signature := range c.cfg.SuspiciousProcesses { sig := strings.ToLower(strings.TrimSpace(signature)) @@ -759,12 +763,11 @@ func (c *Checker) checkSuspiciousProcesses(ctx context.Context) { continue } if strings.Contains(lowerArgs, sig) { - c.addWarning("Suspicious process detected: %s (PID %s, user %s)", strings.TrimSpace(args), pid, user) + c.addWarning("Suspicious process detected: %s (PID %s, user %s)", trimmed, pid, user) break } } - trimmed := strings.TrimSpace(args) if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") { name := strings.TrimSuffix(strings.TrimPrefix(trimmed, "["), "]") if !c.isSafeBracketProcess(name) { @@ -776,11 +779,6 @@ func (c *Checker) checkSuspiciousProcesses(ctx context.Context) { c.addWarning("Suspicious kernel-style process: %s (PID %s, user %s)", name, pid, user) } } - - //lint:ignore SA4017 isZombieProxmoxProcess is intentionally used only for control flow - if isZombieProxmoxProcess(user, state, vsz, trimmed) { - continue - } } } @@ -803,7 +801,7 @@ func checksumFile(path string) (string, error) { if err != nil { return "", err } - defer f.Close() + defer func() { _ = f.Close() }() return checksumReader(f) } diff --git a/internal/security/security_test.go b/internal/security/security_test.go index 8876062f..d67a537c 100644 --- a/internal/security/security_test.go +++ b/internal/security/security_test.go @@ -879,6 +879,43 @@ func TestCheckSuspiciousProcesses(t *testing.T) { } } +func TestCheckSuspiciousProcessesSkipsProxmoxBackupZombie(t *testing.T) { + writeFakePS(t, "root Z 0 123 proxmox-backup-proxy\n") + checker := newChecker(t, &config.Config{ + SuspiciousProcesses: []string{"proxmox-backup"}, + }) + + checker.checkSuspiciousProcesses(context.Background()) + + if containsIssue(checker.result, "Suspicious process detected") { + t.Fatalf("expected Proxmox Backup zombie to be skipped, issues=%+v", checker.result.Issues) + } +} + +func TestCheckSuspiciousProcessesWarnsForNonZombieProxmoxBackupMatch(t *testing.T) { + writeFakePS(t, "root S 1234 124 proxmox-backup-proxy\n") + checker := newChecker(t, &config.Config{ + SuspiciousProcesses: []string{"proxmox-backup"}, + }) + + checker.checkSuspiciousProcesses(context.Background()) + + if !containsIssue(checker.result, "Suspicious process detected") { + t.Fatalf("expected non-zombie Proxmox Backup process match warning, issues=%+v", checker.result.Issues) + } +} + +func writeFakePS(t *testing.T, output string) { + t.Helper() + dir := t.TempDir() + scriptPath := filepath.Join(dir, "ps") + script := fmt.Sprintf("#!/bin/sh\nprintf '%%b' %q\n", output) + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake ps: %v", err) + } + t.Setenv("PATH", dir) +} + // TestRunSecurityChecks tests the main Run function func TestRunSecurityChecks(t *testing.T) { tmpDir := t.TempDir() @@ -1262,7 +1299,7 @@ func TestEnsureOwnershipAndPermFromFDAutoFix(t *testing.T) { if err != nil { t.Fatal(err) } - defer f.Close() + defer func() { _ = f.Close() }() info, err := f.Stat() if err != nil { @@ -2148,7 +2185,7 @@ func TestRunWithMissingTarDependency(t *testing.T) { result, err := Run(context.Background(), logger, cfg, configPath, execPath, envInfo) if err != nil { - // Error is expected if tar is not found + t.Fatalf("Run() unexpected error: %v", err) } if result == nil { @@ -2172,7 +2209,7 @@ func TestDetectPrivateAgeKeysWithUnreadableFile(t *testing.T) { if err := os.WriteFile(unreadable, []byte("AGE-SECRET-KEY-TEST"), 0000); err != nil { t.Fatal(err) } - defer os.Chmod(unreadable, 0644) // Cleanup + defer func() { _ = os.Chmod(unreadable, 0644) }() checker := &Checker{ logger: newSecurityTestLogger(), @@ -2246,8 +2283,7 @@ func TestVerifyDirectoriesWithExistingDir(t *testing.T) { } } if !hasPermWarning { - // Permission or ownership warning depends on running context - // This is acceptable + t.Log("permission or ownership warning depends on the running context") } } @@ -2440,7 +2476,7 @@ func TestRunWithPBSEnvironment(t *testing.T) { result, err := Run(context.Background(), logger, cfg, configPath, execPath, envInfo) if err != nil { - // May get error if dependencies are missing + t.Fatalf("Run() unexpected error: %v", err) } if result == nil { @@ -2549,7 +2585,7 @@ func TestVerifyBinaryIntegrityCreateHashErrorReadOnly(t *testing.T) { if err := os.Chmod(tmpDir, 0555); err != nil { t.Fatal(err) } - defer os.Chmod(tmpDir, 0755) // Cleanup + defer func() { _ = os.Chmod(tmpDir, 0755) }() checker := &Checker{ logger: newSecurityTestLogger(), @@ -2589,7 +2625,7 @@ func TestVerifyBinaryIntegrityUpdateHashError(t *testing.T) { if err := os.Chmod(hashPath, 0444); err != nil { t.Fatal(err) } - defer os.Chmod(hashPath, 0644) // Cleanup + defer func() { _ = os.Chmod(hashPath, 0644) }() checker := &Checker{ logger: newSecurityTestLogger(), diff --git a/internal/storage/cloud.go b/internal/storage/cloud.go index b367ff50..5bf7c7f8 100644 --- a/internal/storage/cloud.go +++ b/internal/storage/cloud.go @@ -1241,7 +1241,7 @@ func (c *CloudStorage) isBackupEntry(filename string, snapshot map[string]struct // Only include backup files (legacy `proxmox-backup-*` or Go `*-backup-*`) isNewName := strings.Contains(filename, "-backup-") isLegacy := strings.HasPrefix(filename, "proxmox-backup-") - if !(isLegacy || isNewName) { + if !isLegacy && !isNewName { return false } diff --git a/internal/storage/filesystem.go b/internal/storage/filesystem.go index 228e6650..9fce63dd 100644 --- a/internal/storage/filesystem.go +++ b/internal/storage/filesystem.go @@ -202,8 +202,11 @@ func (d *FilesystemDetector) testOwnershipSupport(ctx context.Context, path stri d.logger.Debug("Cannot create test file for ownership check: %v", err) return false } - f.Close() - defer os.Remove(testFile) + if err := f.Close(); err != nil { + d.logger.Debug("Cannot close test file for ownership check: %v", err) + return false + } + defer func() { _ = os.Remove(testFile) }() // Try to change ownership to current user (should be safe) uid := os.Getuid() diff --git a/internal/storage/local.go b/internal/storage/local.go index b9cc03f9..8b79c3bf 100644 --- a/internal/storage/local.go +++ b/internal/storage/local.go @@ -287,7 +287,7 @@ func (l *LocalStorage) loadMetadataFromBundle(bundlePath string) (*types.BackupM l.logger.Debug("Local storage: failed to open bundle %s: %v", bundlePath, err) return nil, err } - defer file.Close() + defer func() { _ = file.Close() }() tr := tar.NewReader(file) expectedName := strings.TrimSuffix(filepath.Base(bundlePath), ".bundle.tar") + ".metadata" diff --git a/internal/storage/secondary.go b/internal/storage/secondary.go index b05143e2..6812de73 100644 --- a/internal/storage/secondary.go +++ b/internal/storage/secondary.go @@ -221,7 +221,7 @@ func (s *SecondaryStorage) countBackups(ctx context.Context) int { } // copyFile copies a file using Go's io.Copy -func (s *SecondaryStorage) copyFile(ctx context.Context, src, dest string) error { +func (s *SecondaryStorage) copyFile(ctx context.Context, src, dest string) (err error) { if err := ctx.Err(); err != nil { return err } @@ -242,9 +242,15 @@ func (s *SecondaryStorage) copyFile(ctx context.Context, src, dest string) error } tempName := tempFile.Name() defer func() { - tempFile.Close() + if tempFile != nil { + if closeErr := tempFile.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("failed to close temporary file %s: %w", tempName, closeErr) + } + } if tempName != "" { - os.Remove(tempName) + if removeErr := os.Remove(tempName); removeErr != nil && err == nil && !os.IsNotExist(removeErr) { + err = fmt.Errorf("failed to remove temporary file %s: %w", tempName, removeErr) + } } }() @@ -253,7 +259,11 @@ func (s *SecondaryStorage) copyFile(ctx context.Context, src, dest string) error if err != nil { return fmt.Errorf("failed to open source file %s: %w", src, err) } - defer sourceFile.Close() + defer func() { + if closeErr := sourceFile.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("failed to close source file %s: %w", src, closeErr) + } + }() buf := make([]byte, 1024*1024) // 1MB buffer var written int64 @@ -282,10 +292,11 @@ func (s *SecondaryStorage) copyFile(ctx context.Context, src, dest string) error if err := tempFile.Sync(); err != nil { return fmt.Errorf("failed to sync temporary file %s: %w", tempName, err) } - if err := tempFile.Close(); err != nil { - return fmt.Errorf("failed to close temporary file %s: %w", tempName, err) - } + closeErr := tempFile.Close() tempFile = nil + if closeErr != nil { + return fmt.Errorf("failed to close temporary file %s: %w", tempName, closeErr) + } if err := os.Chmod(tempName, sourceInfo.Mode()); err != nil { s.logger.Debug("Secondary storage: unable to mirror permissions on %s: %v", tempName, err) diff --git a/internal/tui/abort_context_test.go b/internal/tui/abort_context_test.go index 2654353f..e04aae33 100644 --- a/internal/tui/abort_context_test.go +++ b/internal/tui/abort_context_test.go @@ -31,8 +31,12 @@ func newSimulationApp(t *testing.T) (*App, tcell.SimulationScreen, <-chan struct return app, screen, started } +func clearAbortContextForTest() { + SetAbortContext(nil) //nolint:staticcheck // Verifies nil clears the process-wide abort context. +} + func TestSetAbortContext_GetAbortContextRoundTrip(t *testing.T) { - SetAbortContext(nil) + clearAbortContextForTest() if got := getAbortContext(); got != nil { t.Fatalf("expected nil abort context, got %v", got) } @@ -44,7 +48,7 @@ func TestSetAbortContext_GetAbortContextRoundTrip(t *testing.T) { t.Fatalf("expected stored context to match") } - SetAbortContext(nil) + clearAbortContextForTest() if got := getAbortContext(); got != nil { t.Fatalf("expected abort context to be cleared, got %v", got) } @@ -53,7 +57,7 @@ func TestSetAbortContext_GetAbortContextRoundTrip(t *testing.T) { func TestBindAbortContext_StopsAppOnCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) SetAbortContext(ctx) - t.Cleanup(func() { SetAbortContext(nil) }) + t.Cleanup(clearAbortContextForTest) stopped := make(chan struct{}) app := &App{ @@ -71,7 +75,7 @@ func TestBindAbortContext_StopsAppOnCancel(t *testing.T) { } func TestBindAbortContext_NoContextNoop(t *testing.T) { - SetAbortContext(nil) + clearAbortContextForTest() stopped := make(chan struct{}) app := &App{ @@ -91,7 +95,7 @@ func TestNewApp_SetsThemeAndReturnsApplication(t *testing.T) { oldTheme := tview.Styles t.Cleanup(func() { tview.Styles = oldTheme }) - SetAbortContext(nil) + clearAbortContextForTest() app := NewApp() if app == nil || app.Application == nil { @@ -138,7 +142,7 @@ func TestAppRunWithContext_NilContextRunsUntilStopped(t *testing.T) { done := make(chan error, 1) go func() { - done <- app.RunWithContext(nil) + done <- app.RunWithContext(nil) //nolint:staticcheck // Verifies nil context runs until the app stops. }() select { diff --git a/internal/tui/components/form.go b/internal/tui/components/form.go index aad3045d..395fd4c1 100644 --- a/internal/tui/components/form.go +++ b/internal/tui/components/form.go @@ -48,7 +48,7 @@ func NewForm(app *tui.App) *Form { // AddInputFieldWithValidation adds an input field with validation func (f *Form) AddInputFieldWithValidation(label, value string, fieldWidth int, validators ...ValidatorFunc) *Form { f.validators[label] = validators - f.Form.AddInputField(label, value, fieldWidth, nil, nil) + f.AddInputField(label, value, fieldWidth, nil, nil) return f } @@ -79,7 +79,7 @@ func (f *Form) SetParentView(parent tview.Primitive) *Form { // AddSubmitButton adds a styled submit button func (f *Form) AddSubmitButton(label string) *Form { - f.Form.AddButton(label, func() { + f.AddButton(label, func() { if f.onSubmit != nil { values := f.GetFormValues() if err := f.ValidateAll(values); err != nil { @@ -106,7 +106,7 @@ func (f *Form) AddSubmitButton(label string) *Form { // AddCancelButton adds a styled cancel button func (f *Form) AddCancelButton(label string) *Form { - f.Form.AddButton(label, func() { + f.AddButton(label, func() { if f.onCancel != nil { f.onCancel() } @@ -118,8 +118,8 @@ func (f *Form) AddCancelButton(label string) *Form { // GetFormValues extracts all form values func (f *Form) GetFormValues() map[string]string { values := make(map[string]string) - for i := 0; i < f.Form.GetFormItemCount(); i++ { - item := f.Form.GetFormItem(i) + for i := 0; i < f.GetFormItemCount(); i++ { + item := f.GetFormItem(i) if inputField, ok := item.(*tview.InputField); ok { label := inputField.GetLabel() value := inputField.GetText() diff --git a/internal/tui/components/form_test.go b/internal/tui/components/form_test.go index fb50bea0..ce459b40 100644 --- a/internal/tui/components/form_test.go +++ b/internal/tui/components/form_test.go @@ -35,14 +35,14 @@ func TestGetFormValuesCollectsWidgets(t *testing.T) { form := NewForm(tui.NewApp()) form.AddInputFieldWithValidation("Input", "", 10) - form.Form.AddCheckbox("Check", true, nil) - form.Form.AddDropDown("Drop", []string{"a", "b"}, 1, nil) + form.AddCheckbox("Check", true, nil) + form.AddDropDown("Drop", []string{"a", "b"}, 1, nil) // Set values - if input, ok := form.Form.GetFormItem(0).(*tview.InputField); ok { + if input, ok := form.GetFormItem(0).(*tview.InputField); ok { input.SetText("value") } - if dd, ok := form.Form.GetFormItem(2).(*tview.DropDown); ok { + if dd, ok := form.GetFormItem(2).(*tview.DropDown); ok { dd.SetCurrentOption(1) } @@ -71,8 +71,8 @@ func TestAddPasswordFieldRegistersValidators(t *testing.T) { if _, ok := form.validators["Password"]; !ok { t.Fatalf("expected validators to be registered for Password") } - if form.Form.GetFormItemCount() != 1 { - t.Fatalf("form item count=%d; want 1", form.Form.GetFormItemCount()) + if form.GetFormItemCount() != 1 { + t.Fatalf("form item count=%d; want 1", form.GetFormItemCount()) } if got := form.Form.GetFormItem(0).(*tview.InputField).GetLabel(); got != "Password" { t.Fatalf("label=%q; want %q", got, "Password") @@ -91,7 +91,7 @@ func TestAddSubmitButtonShowsValidationError(t *testing.T) { form.SetOnSubmit(func(values map[string]string) error { return nil }) form.AddSubmitButton("Continue") - btn := form.Form.GetButton(form.Form.GetButtonCount() - 1) + btn := form.GetButton(form.GetButtonCount() - 1) btn.InputHandler()(tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone), nil) }) @@ -110,7 +110,7 @@ func TestAddSubmitButtonShowsSubmitError(t *testing.T) { form.SetOnSubmit(func(values map[string]string) error { return errors.New("boom") }) form.AddSubmitButton("Continue") - btn := form.Form.GetButton(form.Form.GetButtonCount() - 1) + btn := form.GetButton(form.GetButtonCount() - 1) btn.InputHandler()(tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone), nil) }) @@ -133,7 +133,7 @@ func TestAddSubmitButtonUsesInlineErrorWhenParentViewSet(t *testing.T) { form.SetOnSubmit(func(values map[string]string) error { return nil }) form.AddSubmitButton("Continue") - btn := form.Form.GetButton(form.Form.GetButtonCount() - 1) + btn := form.GetButton(form.GetButtonCount() - 1) btn.InputHandler()(tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone), nil) }) @@ -149,7 +149,7 @@ func TestAddCancelButtonCallsHandler(t *testing.T) { form.SetOnCancel(func() { called = true }) form.AddCancelButton("Cancel") - btn := form.Form.GetButton(form.Form.GetButtonCount() - 1) + btn := form.GetButton(form.GetButtonCount() - 1) btn.InputHandler()(tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone), nil) if !called { @@ -160,7 +160,7 @@ func TestAddCancelButtonCallsHandler(t *testing.T) { func TestSetBorderWithTitleSetsTitle(t *testing.T) { form := NewForm(tui.NewApp()) form.SetBorderWithTitle("Wizard") - if form.Form.GetTitle() != " Wizard " { - t.Fatalf("title=%q; want %q", form.Form.GetTitle(), " Wizard ") + if form.GetTitle() != " Wizard " { + t.Fatalf("title=%q; want %q", form.GetTitle(), " Wizard ") } } diff --git a/internal/tui/components/list_form_item.go b/internal/tui/components/list_form_item.go index 660b56c6..ebb7d1d1 100644 --- a/internal/tui/components/list_form_item.go +++ b/internal/tui/components/list_form_item.go @@ -33,7 +33,7 @@ func NewListFormItem(list *tview.List) *ListFormItem { focusedSelectedBg: tui.ProxmoxOrange, blurredSelectedBg: tcell.ColorDarkSlateGray, } - item.List.SetInputCapture(item.inputCapture) + item.SetInputCapture(item.inputCapture) return item } @@ -67,7 +67,7 @@ func (i *ListFormItem) GetLabel() string { func (i *ListFormItem) SetFormAttributes(labelWidth int, labelColor, bgColor, fieldTextColor, fieldBgColor tcell.Color) tview.FormItem { i.bgColor = bgColor i.textColor = fieldTextColor - i.List. + i. SetMainTextColor(fieldTextColor). SetSecondaryTextColor(fieldTextColor). SetBackgroundColor(bgColor) @@ -111,13 +111,13 @@ func (i *ListFormItem) inputCapture(event *tcell.EventKey) *tcell.EventKey { } return nil case tcell.KeyUp: - if i.finished != nil && i.List.GetItemCount() > 0 && i.List.GetCurrentItem() == 0 { + if i.finished != nil && i.GetItemCount() > 0 && i.GetCurrentItem() == 0 { i.finished(tcell.KeyBacktab) return nil } case tcell.KeyDown: - count := i.List.GetItemCount() - if i.finished != nil && count > 0 && i.List.GetCurrentItem() == count-1 { + count := i.GetItemCount() + if i.finished != nil && count > 0 && i.GetCurrentItem() == count-1 { i.finished(tcell.KeyTab) return nil } @@ -129,12 +129,12 @@ func (i *ListFormItem) inputCapture(event *tcell.EventKey) *tcell.EventKey { // Focus is called when this primitive receives focus. func (i *ListFormItem) Focus(delegate func(p tview.Primitive)) { i.hasFocus = true - i.List.SetSelectedBackgroundColor(i.focusedSelectedBg) + i.SetSelectedBackgroundColor(i.focusedSelectedBg) col := i.textColor if col == 0 { col = tcell.ColorWhite } - i.List.SetSelectedTextColor(col) + i.SetSelectedTextColor(col) i.List.Focus(delegate) } @@ -145,11 +145,11 @@ func (i *ListFormItem) Blur() { if bg == 0 { bg = i.blurredSelectedBg } - i.List.SetSelectedBackgroundColor(bg) + i.SetSelectedBackgroundColor(bg) col := i.textColor if col == 0 { col = tcell.ColorWhite } - i.List.SetSelectedTextColor(col) + i.SetSelectedTextColor(col) i.List.Blur() } diff --git a/internal/tui/components/panel.go b/internal/tui/components/panel.go index adf3d52e..ece09591 100644 --- a/internal/tui/components/panel.go +++ b/internal/tui/components/panel.go @@ -33,7 +33,7 @@ func (p *Panel) SetTitle(title string) *Panel { func (p *Panel) SetStatus(status string) *Panel { symbol := tui.StatusSymbol(status) - title := p.Box.GetTitle() + title := p.GetTitle() p.Box.SetTitle(title + " " + symbol) return p } @@ -41,30 +41,30 @@ func (p *Panel) SetStatus(status string) *Panel { // InfoPanel creates a styled info panel func InfoPanel(title, message string) *Panel { panel := NewPanel().SetTitle(title) - panel.Box.SetBackgroundColor(tui.ProxmoxDark) + panel.SetBackgroundColor(tui.ProxmoxDark) return panel } // SuccessPanel creates a success-styled panel func SuccessPanel(title, message string) *Panel { panel := NewPanel().SetTitle(title) - panel.Box.SetBorderColor(tui.SuccessGreen). - SetTitleColor(tui.SuccessGreen) + panel.SetBorderColor(tui.SuccessGreen) + panel.SetTitleColor(tui.SuccessGreen) return panel } // ErrorPanel creates an error-styled panel func ErrorPanel(title, message string) *Panel { panel := NewPanel().SetTitle(title) - panel.Box.SetBorderColor(tui.ErrorRed). - SetTitleColor(tui.ErrorRed) + panel.SetBorderColor(tui.ErrorRed) + panel.SetTitleColor(tui.ErrorRed) return panel } // WarningPanel creates a warning-styled panel func WarningPanel(title, message string) *Panel { panel := NewPanel().SetTitle(title) - panel.Box.SetBorderColor(tui.WarningYellow). - SetTitleColor(tui.WarningYellow) + panel.SetBorderColor(tui.WarningYellow) + panel.SetTitleColor(tui.WarningYellow) return panel } diff --git a/internal/tui/wizard/age.go b/internal/tui/wizard/age.go index 356415fb..e43ec78b 100644 --- a/internal/tui/wizard/age.go +++ b/internal/tui/wizard/age.go @@ -244,14 +244,14 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig return event }) - form.Form.AddFormItem(setupTypeDropdown) + form.AddFormItem(setupTypeDropdown) // Public key field (for "existing" type) publicKeyField = tview.NewInputField(). SetLabel(" └─ AGE/SSH Recipient"). SetText(""). SetFieldWidth(70) - form.Form.AddFormItem(publicKeyField) + form.AddFormItem(publicKeyField) // Passphrase fields (for "passphrase" type) passphraseField = tview.NewInputField(). @@ -260,7 +260,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig SetFieldWidth(50). SetMaskCharacter('*') passphraseField.SetDisabled(true) - form.Form.AddFormItem(passphraseField) + form.AddFormItem(passphraseField) passphraseConfirmField = tview.NewInputField(). SetLabel(" └─ Confirm Passphrase"). @@ -268,7 +268,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig SetFieldWidth(50). SetMaskCharacter('*') passphraseConfirmField.SetDisabled(true) - form.Form.AddFormItem(passphraseConfirmField) + form.AddFormItem(passphraseConfirmField) // Private key field (for "privatekey" type) privateKeyField = tview.NewInputField(). @@ -277,7 +277,7 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig SetFieldWidth(70). SetMaskCharacter('*') privateKeyField.SetDisabled(true) - form.Form.AddFormItem(privateKeyField) + form.AddFormItem(privateKeyField) // Initialize with "existing" type selected setupType = ageSetupTypeExisting @@ -323,17 +323,17 @@ func RunAgeSetupWizard(ctx context.Context, recipientPath, configPath, buildSig // Style the form form.SetBorderWithTitle("AGE Encryption Setup") - form.Form.SetBackgroundColor(tcell.ColorBlack) + form.SetBackgroundColor(tcell.ColorBlack) // Add arrow key support for navigation - form.Form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { // If a dropdown is open, don't intercept arrow keys - let them work naturally if dropdownOpen { return event } // Check if focus is on a button (not on a form field) - formItemIndex, buttonIndex := form.Form.GetFocusedItemIndex() + formItemIndex, buttonIndex := form.GetFocusedItemIndex() isOnButton := (formItemIndex < 0 && buttonIndex >= 0) isOnFormField := (formItemIndex >= 0) diff --git a/internal/tui/wizard/install.go b/internal/tui/wizard/install.go index 19c661de..33dc9ae2 100644 --- a/internal/tui/wizard/install.go +++ b/internal/tui/wizard/install.go @@ -123,14 +123,14 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu return event }) - form.Form.AddFormItem(secondaryDropdown) + form.AddFormItem(secondaryDropdown) secondaryHint := tview.NewInputField(). SetLabel(" tip: SECONDARY_PATH needs a mounted path; for 192.168.0.10/folder use an rclone remote"). SetFieldWidth(0). SetText("") secondaryHint.SetDisabled(true) - form.Form.AddFormItem(secondaryHint) + form.AddFormItem(secondaryHint) secondaryPathField = tview.NewInputField(). SetLabel(" └─ Secondary Backup Path"). @@ -140,7 +140,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu secondaryPathField.SetText(prefill.SecondaryPath) } secondaryPathField.SetDisabled(!secondaryEnabled) - form.Form.AddFormItem(secondaryPathField) + form.AddFormItem(secondaryPathField) secondaryLogField = tview.NewInputField(). SetLabel(" └─ Secondary Log Path"). @@ -150,7 +150,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu secondaryLogField.SetText(prefill.SecondaryLogPath) } secondaryLogField.SetDisabled(!secondaryEnabled) - form.Form.AddFormItem(secondaryLogField) + form.AddFormItem(secondaryLogField) // Cloud Storage section cloudEnabled := prefill.CloudEnabled @@ -179,14 +179,14 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu return event }) - form.Form.AddFormItem(cloudDropdown) + form.AddFormItem(cloudDropdown) cloudHint := tview.NewInputField(). SetLabel(" Tip: remote name (via 'rclone config'), e.g. myremote (or myremote:path)"). SetFieldWidth(0). SetText("") cloudHint.SetDisabled(true) - form.Form.AddFormItem(cloudHint) + form.AddFormItem(cloudHint) rcloneBackupField = tview.NewInputField(). SetLabel(" └─ Rclone Backup Remote"). @@ -196,7 +196,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu rcloneBackupField.SetText(prefill.CloudRemote) } rcloneBackupField.SetDisabled(!cloudEnabled) - form.Form.AddFormItem(rcloneBackupField) + form.AddFormItem(rcloneBackupField) rcloneLogField = tview.NewInputField(). SetLabel(" └─ Rclone Log Path"). @@ -206,7 +206,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu rcloneLogField.SetText(prefill.CloudLogPath) } rcloneLogField.SetDisabled(!cloudEnabled) - form.Form.AddFormItem(rcloneLogField) + form.AddFormItem(rcloneLogField) // Firewall rules backup (system collection) firewallEnabled := prefill.FirewallEnabled @@ -227,7 +227,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu return event }) - form.Form.AddFormItem(firewallDropdown) + form.AddFormItem(firewallDropdown) // Notifications (header + two toggles) telegramEnabled := prefill.TelegramEnabled @@ -237,7 +237,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetFieldWidth(0). SetText(""). SetDisabled(true) - form.Form.AddFormItem(notificationHeader) + form.AddFormItem(notificationHeader) telegramDropdown := tview.NewDropDown(). SetLabel(" └─ Enable Telegram notifications"). @@ -254,7 +254,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } return event }) - form.Form.AddFormItem(telegramDropdown) + form.AddFormItem(telegramDropdown) emailDropdown := tview.NewDropDown(). SetLabel(" └─ Enable Email notifications"). @@ -271,7 +271,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } return event }) - form.Form.AddFormItem(emailDropdown) + form.AddFormItem(emailDropdown) // Encryption encryptionDropdown := tview.NewDropDown(). @@ -290,7 +290,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu return event }) - form.Form.AddFormItem(encryptionDropdown) + form.AddFormItem(encryptionDropdown) // Separator before scheduling cronSeparator := tview.NewInputField(). @@ -298,7 +298,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetFieldWidth(0). SetText(""). SetDisabled(true) - form.Form.AddFormItem(cronSeparator) + form.AddFormItem(cronSeparator) // Cron schedule (after encryption) cronField := tview.NewInputField(). @@ -306,7 +306,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetText(""). SetPlaceholder(data.CronTime). SetFieldWidth(7) - form.Form.AddFormItem(cronField) + form.AddFormItem(cronField) // Set up form submission form.SetOnSubmit(func(values map[string]string) error { @@ -371,17 +371,17 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // Style the form form.SetBorderWithTitle("ProxSave Installation") - form.Form.SetBackgroundColor(tcell.ColorBlack) + form.SetBackgroundColor(tcell.ColorBlack) // Add arrow key support for navigation - form.Form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { // If a dropdown is open, don't intercept arrow keys - let them work naturally if dropdownOpen { return event } // Check if focus is on a button (not on a form field) - formItemIndex, buttonIndex := form.Form.GetFocusedItemIndex() + formItemIndex, buttonIndex := form.GetFocusedItemIndex() isOnButton := (formItemIndex < 0 && buttonIndex >= 0) isOnFormField := (formItemIndex >= 0) diff --git a/internal/tui/wizard/post_install_audit_tui.go b/internal/tui/wizard/post_install_audit_tui.go index e81105fa..534bd76e 100644 --- a/internal/tui/wizard/post_install_audit_tui.go +++ b/internal/tui/wizard/post_install_audit_tui.go @@ -195,7 +195,7 @@ func showAuditReview(app *tui.App, pages *tview.Pages, configPath string, sugges b.WriteString("\n") } b.WriteString("\n") - b.WriteString(fmt.Sprintf("If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", tview.Escape(s.Key))) + fmt.Fprintf(&b, "If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", tview.Escape(s.Key)) details.SetText(b.String()) }