Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions packages/task-graph/src/checkpoint/InMemoryCheckpointSaver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,27 @@ import type { CheckpointData, CheckpointId, ThreadId } from "./CheckpointTypes";
export class InMemoryCheckpointSaver extends CheckpointSaver {
private checkpoints: Map<CheckpointId, CheckpointData> = new Map();
private threadIndex: Map<ThreadId, CheckpointId[]> = new Map();
private readonly maxCheckpointsPerThread: number;

constructor(maxCheckpointsPerThread: number = 1000) {
super();
this.maxCheckpointsPerThread = maxCheckpointsPerThread;
}

async saveCheckpoint(data: CheckpointData): Promise<void> {
this.checkpoints.set(data.checkpointId, data);

const threadCheckpoints = this.threadIndex.get(data.threadId) ?? [];
threadCheckpoints.push(data.checkpointId);

if (threadCheckpoints.length > this.maxCheckpointsPerThread) {
const excess = threadCheckpoints.length - this.maxCheckpointsPerThread;
const removedIds = threadCheckpoints.splice(0, excess);
for (const id of removedIds) {
this.checkpoints.delete(id);
}
}

this.threadIndex.set(data.threadId, threadCheckpoints);
}

Expand Down
5 changes: 4 additions & 1 deletion packages/task-graph/src/checkpoint/TabularCheckpointSaver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ export class TabularCheckpointSaver extends CheckpointSaver {
return {
checkpointId: row.checkpoint_id as string,
threadId: row.thread_id as string,
parentCheckpointId: (row.parent_checkpoint_id as string) || undefined,
parentCheckpointId:
(row.parent_checkpoint_id as string) === ""
? undefined
: (row.parent_checkpoint_id as string),
graphJson,
taskStates,
dataflowStates,
Expand Down
58 changes: 48 additions & 10 deletions packages/task-graph/src/task-graph/TaskGraphRunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,23 @@ export class TaskGraphRunner {
this.pushStatusFromNodeToEdges(this.graph, task);
this.pushErrorFromNodeToEdges(this.graph, task);

// Capture checkpoint after task completion
// Capture checkpoint after successful task completion
if (
this.checkpointSaver &&
this.checkpointGranularity === "every-task" &&
(task.status === TaskStatus.COMPLETED || task.status === TaskStatus.FAILED)
task.status === TaskStatus.COMPLETED
) {
await this.captureCheckpoint(task.config.id);
try {
await this.captureCheckpoint(task.config.id);
} catch (checkpointError) {
// Do not interrupt task completion tracking if checkpoint capture fails
// eslint-disable-next-line no-console
console.error(
"Failed to capture checkpoint for task",
task.config.id,
checkpointError
);
}
}

this.processScheduler.onTaskCompleted(task.config.id);
Expand Down Expand Up @@ -208,7 +218,14 @@ export class TaskGraphRunner {

// Capture a final checkpoint for top-level-only granularity
if (this.checkpointSaver && this.checkpointGranularity === "top-level-only") {
await this.captureCheckpoint();
try {
await this.captureCheckpoint();
} catch (checkpointError) {
// Log checkpoint errors without failing the entire graph execution
// so that handleComplete still runs and the graph can finalize cleanly.
// eslint-disable-next-line no-console
console.error("Failed to capture final checkpoint:", checkpointError);
}
}

await this.handleComplete();
Expand Down Expand Up @@ -751,7 +768,7 @@ export class TaskGraphRunner {
inputData: { ...task.runInputData },
outputData: { ...task.runOutputData },
progress: task.progress,
error: task.error?.message,
error: task.error ? `${task.error.name}: ${task.error.message}` : undefined,
startedAt: task.startedAt?.toISOString(),
completedAt: task.completedAt?.toISOString(),
}));
Expand All @@ -761,7 +778,7 @@ export class TaskGraphRunner {
sourceTaskId: df.sourceTaskId,
targetTaskId: df.targetTaskId,
status: df.status,
portData: df.value !== undefined ? { _value: df.value } : undefined,
portData: df.value !== undefined ? (df.value as TaskOutput) : undefined,
}));

const checkpointId = uuid4();
Expand Down Expand Up @@ -819,19 +836,40 @@ export class TaskGraphRunner {

task.emit("status", task.status);
this.processScheduler.onTaskCompleted(task.config.id);
} else if (taskState.status === TaskStatus.FAILED) {
// Ensure FAILED tasks are fully reset for re-execution.
// resetGraph() already set them to PENDING, but restore their input data
// so the task can be retried with the same data.
task.runInputData = taskState.inputData ?? {};
}
// Leave PENDING/FAILED tasks in PENDING state so they get re-run
// Leave PENDING tasks in PENDING state so they get re-run
}

// Restore dataflow states
for (const dfState of checkpointData.dataflowStates) {
if (typeof dfState.id !== "string") {
// eslint-disable-next-line no-console
console.warn(
"TaskGraphRunner.restoreFromCheckpoint: Skipping dataflow with non-string id",
{ id: dfState.id }
);
continue;
}

const df = this.graph.getDataflow(dfState.id as any);
if (!df) continue;
if (!df) {
// eslint-disable-next-line no-console
console.warn(
"TaskGraphRunner.restoreFromCheckpoint: Dataflow not found for id",
{ id: dfState.id }
);
continue;
}

if (dfState.status === TaskStatus.COMPLETED || dfState.status === TaskStatus.DISABLED) {
df.setStatus(dfState.status);
if (dfState.portData?._value !== undefined) {
df.value = dfState.portData._value;
if (dfState.portData !== undefined) {
df.value = dfState.portData as any;
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions packages/task-graph/src/task-graph/TaskGraphScheduler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ export class DependencyBasedScheduler implements ITaskGraphScheduler {
onTaskCompleted(taskId: unknown): void {
this.completedTasks.add(taskId);

// Remove any disabled tasks from pending
// Remove the completed task and any disabled tasks from pending.
// This handles both normal completion (task was already removed when picked up,
// so this is a no-op) and checkpoint-restore completion (task is still pending
// and must be removed so it isn't re-scheduled).
for (const task of Array.from(this.pendingTasks)) {
if (task.status === TaskStatus.DISABLED) {
if (task.config.id === taskId || task.status === TaskStatus.DISABLED) {
this.pendingTasks.delete(task);
}
}
Expand Down
21 changes: 16 additions & 5 deletions packages/task-graph/src/task/IteratorTaskRunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,23 @@ export class IteratorTaskRunner<
this.task.compoundMerge
) as TaskOutput;

// Capture iteration checkpoint if checkpoint saver is available
// Capture iteration checkpoint if checkpoint saver is available.
// This is best-effort: failures here should not break iteration processing.
if (this.checkpointSaver && this.threadId && iterationIndex !== undefined) {
await this.task.subGraph.runner.captureCheckpoint(this.task.config.id, {
iterationIndex,
iterationParentTaskId: this.task.config.id,
});
try {
await this.task.subGraph.runner.captureCheckpoint(this.task.config.id, {
iterationIndex,
iterationParentTaskId: this.task.config.id,
});
} catch (error) {
// Checkpointing is best-effort; log the error but do not interrupt iteration execution.
// eslint-disable-next-line no-console
console.error("Failed to capture iterator-task iteration checkpoint", {
taskId: this.task.config.id,
iterationIndex,
error,
});
}
}

return merged;
Expand Down
20 changes: 16 additions & 4 deletions packages/task-graph/src/task/WhileTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ export class WhileTask<
parentSignal: context.signal,
checkpointSaver: context.checkpointSaver,
threadId: context.threadId,
// Disable subgraph top-level checkpoints; iteration checkpoints are handled separately.
checkpointGranularity: "none",
});

// Merge results
Expand All @@ -388,10 +390,20 @@ export class WhileTask<

// Capture iteration checkpoint if checkpoint saver is available
if (context.checkpointSaver && context.threadId) {
await this.subGraph.runner.captureCheckpoint(this.config.id, {
iterationIndex: this._currentIteration,
iterationParentTaskId: this.config.id,
});
try {
await this.subGraph.runner.captureCheckpoint(this.config.id, {
iterationIndex: this._currentIteration,
iterationParentTaskId: this.config.id,
});
} catch (error) {
// Checkpointing is best-effort; log the error but do not interrupt the loop.
// eslint-disable-next-line no-console
console.error("Failed to capture while-task iteration checkpoint", {
taskId: this.config.id,
iterationIndex: this._currentIteration,
error,
});
}
}

// Check condition
Expand Down
37 changes: 32 additions & 5 deletions packages/test/src/test/task/Checkpoint.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import { CheckpointData, Dataflow, InMemoryCheckpointSaver, TaskGraph } from "@workglow/task-graph";
import { beforeEach, describe, expect, it } from "vitest";
import { FailingTask, NumberTask, TestSimpleTask } from "./TestTasks";
import { FailingTask, NumberTask, TestSimpleTask, TrackingTask } from "./TestTasks";

describe("Checkpoint", () => {
let saver: InMemoryCheckpointSaver;
Expand Down Expand Up @@ -306,10 +306,11 @@ describe("Checkpoint", () => {
const history = await saver.getCheckpointHistory("resume-thread");
expect(history.length).toBeGreaterThanOrEqual(1);

// Now create a new graph with the same structure and resume
// Now create a new graph with the same structure and resume.
// Use TrackingTask to verify that task-1 is actually skipped (not re-executed).
const graph2 = new TaskGraph();
const task1b = new TestSimpleTask({ input: "first" }, { id: "task-1" });
const task2b = new TestSimpleTask({ input: "second" }, { id: "task-2" });
const task1b = new TrackingTask({ input: "first" }, { id: "task-1" });
const task2b = new TrackingTask({ input: "second" }, { id: "task-2" });

graph2.addTask(task1b);
graph2.addTask(task2b);
Expand All @@ -326,7 +327,10 @@ describe("Checkpoint", () => {
}
);

// Should complete successfully
// All tasks were already COMPLETED in the checkpoint, so both should be skipped
expect(task1b.executed).toBe(false);
expect(task2b.executed).toBe(false);
// Should complete successfully with no new leaf results (all skipped)
expect(results.length).toBeGreaterThanOrEqual(0);
});

Expand Down Expand Up @@ -361,6 +365,29 @@ describe("Checkpoint", () => {
cp.taskStates.some((ts) => ts.taskId === "task-1" && ts.status === "COMPLETED")
);
expect(resumeCheckpoint).toBeDefined();

// Now resume from this checkpoint with a non-failing version of task-2
const resumeGraph = new TaskGraph();
const task1Resume = new TrackingTask({ input: 42 }, { id: "task-1" });
const task2Resume = new TrackingTask({}, { id: "task-2" });

resumeGraph.addTask(task1Resume);
resumeGraph.addTask(task2Resume);
resumeGraph.addDataflow(new Dataflow("task-1", "output", "task-2", "in"));

await resumeGraph.run(
{},
{
checkpointSaver: saver,
threadId: "fail-thread-resumed",
resumeFromCheckpoint: resumeCheckpoint!.checkpointId,
}
);

// task-1 was COMPLETED in the checkpoint, so it should be skipped (not re-executed)
expect(task1Resume.executed).toBe(false);
// task-2 was not completed in the checkpoint, so it should have been re-run
expect(task2Resume.executed).toBe(true);
});
});

Expand Down