Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import { launchNotifyScript } from "./notify";
import { buildThinkingRequestOptions } from "./openai-thinking";
import { DEEPSEEK_V4_MODELS } from "./model-capabilities";
import { getCompactPrompt, getSystemPrompt, getTools, AGENT_DRIFT_GUARD_SKILL } from "./prompt";
import { ToolExecutor, type CreateOpenAIClient } from "./tools/executor";
import { ToolExecutor, type CreateOpenAIClient, type ToolCall } from "./tools/executor";
import {
getSafetyApprovalLabels,
recordProjectAllowedApproval,
type SafetyApprovalRequest,
} from "./tools/safety-hooks";
import { logApiError } from "./error-logger";
import { logOpenAIChatCompletionDebug, normalizeDebugError } from "./debug-logger";

Expand Down Expand Up @@ -169,6 +174,11 @@ export type LlmStreamProgress = {
phase: "start" | "update" | "end";
};

type PendingSafetyApproval = {
request: SafetyApprovalRequest;
toolCall: ToolCall;
};

export class SessionManager {
private readonly projectRoot: string;
private readonly createOpenAIClient: CreateOpenAIClient;
Expand All @@ -179,6 +189,7 @@ export class SessionManager {
private activeSessionId: string | null = null;
private activePromptController: AbortController | null = null;
private readonly sessionControllers = new Map<string, AbortController>();
private readonly pendingSafetyApprovals = new Map<string, PendingSafetyApproval>();
private readonly toolExecutor: ToolExecutor;

constructor(options: SessionManagerOptions) {
Expand Down Expand Up @@ -982,6 +993,10 @@ ${skillMd}
return;
}

if (await this.executeApprovedSafetyToolCall(sessionId)) {
continue;
}

const compactPromptTokenThreshold = getCompactPromptTokenThreshold(model);
if (session.activeTokens > compactPromptTokenThreshold) {
const message = this.buildAssistantMessage(
Expand Down Expand Up @@ -1624,6 +1639,12 @@ ${skillMd}
onProcessStart: (pid, command) => this.addSessionProcess(sessionId, pid, command),
onProcessExit: (pid) => this.removeSessionProcess(sessionId, pid),
shouldStop: () => this.isInterrupted(sessionId),
onSafetyApprovalRequested: (request, toolCall) =>
this.pendingSafetyApprovals.set(sessionId, {
request,
toolCall,
}),
consumeSafetyApproval: (request) => this.consumeSafetyApproval(sessionId, request),
});
if (this.isInterrupted(sessionId)) {
return { waitingForUser: false };
Expand Down Expand Up @@ -1655,6 +1676,121 @@ ${skillMd}
return { waitingForUser };
}

private consumeSafetyApproval(sessionId: string, request: SafetyApprovalRequest): "approved" | "denied" | "missing" {
const pending = this.pendingSafetyApprovals.get(sessionId);
if (!pending || !this.safetyApprovalRequestsMatch(pending.request, request)) {
return "missing";
}

const answer = this.findLatestUserAnswerForQuestion(sessionId, pending.request.question);
if (answer === "missing") {
return "missing";
}

this.pendingSafetyApprovals.delete(sessionId);
if (answer === "always_approved") {
recordProjectAllowedApproval(this.projectRoot, pending.request);
return "approved";
}
return answer;
}

private async executeApprovedSafetyToolCall(sessionId: string): Promise<boolean> {
const pending = this.pendingSafetyApprovals.get(sessionId);
if (!pending) {
return false;
}

const answer = this.findLatestUserAnswerForQuestion(sessionId, pending.request.question);
if (answer === "missing") {
return false;
}

const executions = await this.toolExecutor.executeToolCalls(sessionId, [pending.toolCall], {
onProcessStart: (pid, command) => this.addSessionProcess(sessionId, pid, command),
onProcessExit: (pid) => this.removeSessionProcess(sessionId, pid),
shouldStop: () => this.isInterrupted(sessionId),
consumeSafetyApproval: (request) => this.consumeSafetyApproval(sessionId, request),
});

if (this.isInterrupted(sessionId)) {
return true;
}

const followUpMessages: SessionMessage[] = [];
for (const execution of executions) {
const toolFunction = this.findToolFunction([pending.toolCall], execution.toolCallId);
const toolMessage = this.buildToolMessage(sessionId, execution.toolCallId, execution.content, toolFunction);
this.appendSessionMessage(sessionId, toolMessage);
this.onAssistantMessage(toolMessage, true);

for (const followUpMessage of execution.result.followUpMessages ?? []) {
if (followUpMessage.role !== "system") {
continue;
}
followUpMessages.push(
this.buildSystemMessage(sessionId, followUpMessage.content, followUpMessage.contentParams ?? null)
);
}
}

for (const followUpMessage of followUpMessages) {
this.appendSessionMessage(sessionId, followUpMessage);
}

return executions.length > 0;
}

private safetyApprovalRequestsMatch(current: SafetyApprovalRequest, incoming: SafetyApprovalRequest): boolean {
return (
current.id === incoming.id ||
(current.toolName === incoming.toolName &&
current.reason === incoming.reason &&
current.command === incoming.command &&
current.filePath === incoming.filePath)
);
}

private findLatestUserAnswerForQuestion(
sessionId: string,
question: string
): "approved" | "always_approved" | "denied" | "missing" {
const labels = getSafetyApprovalLabels();
const escapedQuestion = this.escapeAskUserQuestionAnswerPart(question);
const messages = this.listSessionMessages(sessionId);
for (let index = messages.length - 1; index >= 0; index -= 1) {
const message = messages[index];
if (message.role !== "user" || typeof message.content !== "string") {
continue;
}

const content = message.content;
if (!content.includes(escapedQuestion)) {
continue;
}

if (content.includes(`="${this.escapeAskUserQuestionAnswerPart(labels.allow)}"`)) {
return "approved";
}

if (content.includes(`="${this.escapeAskUserQuestionAnswerPart(labels.alwaysAllow)}"`)) {
return "always_approved";
}

if (content.includes(`="${this.escapeAskUserQuestionAnswerPart(labels.deny)}"`)) {
return "denied";
}

return "missing";
}

return "missing";
}

private escapeAskUserQuestionAnswerPart(value: string): string {
return value.replace(/\\/g, "\\\\").replace(/"/g, '\\"').replace(/\s+/g, " ").trim();
}

private buildOpenAIMessages(messages: SessionMessage[], thinkingEnabled: boolean): ChatCompletionMessageParam[] {
const activeMessages = messages.filter((message) => !message.compacted);
const toolPairings = this.pairToolMessages(activeMessages);
Expand Down Expand Up @@ -1786,13 +1922,25 @@ ${skillMd}
if (firstMatchingIndex == null) {
firstMatchingIndex = index;
}
if (!this.isInterruptedToolMessage(message)) {
if (!this.isInterruptedToolMessage(message) && !this.isSafetyApprovalToolMessage(message)) {
return index;
}
}
return firstMatchingIndex;
}

private isSafetyApprovalToolMessage(message: SessionMessage): boolean {
if (typeof message.content !== "string" || !message.content.trim()) {
return false;
}
try {
const parsed = JSON.parse(message.content) as { name?: unknown; awaitUserResponse?: unknown };
return parsed.name === "SafetyApproval" && parsed.awaitUserResponse === true;
} catch {
return false;
}
}

private getAssistantToolCalls(message: SessionMessage): unknown[] {
if (message.role !== "assistant") {
return [];
Expand Down
8 changes: 8 additions & 0 deletions src/tools/ask-user-question-handler.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ToolExecutionContext, ToolExecutionResult } from "./executor";
import { evaluateGenericToolSafety, type PermissionContext, type SafetyDecision } from "./safety-hooks";

type AskUserQuestionOption = {
label: string;
Expand All @@ -16,6 +17,13 @@ type AskUserQuestionMetadata = {
questions: AskUserQuestionItem[];
};

export function canExecuteAskUserQuestionTool(
args: Record<string, unknown>,
context: PermissionContext
): SafetyDecision {
return evaluateGenericToolSafety("AskUserQuestion", args, context);
}

export async function handleAskUserQuestionTool(
args: Record<string, unknown>,
_context: ToolExecutionContext
Expand Down
5 changes: 5 additions & 0 deletions src/tools/bash-handler.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { spawn } from "child_process";
import type { ToolExecutionContext, ToolExecutionResult } from "./executor";
import { evaluateBashToolSafety, type PermissionContext, type SafetyDecision } from "./safety-hooks";
import {
buildDisableExtglobCommand,
buildShellEnv,
Expand All @@ -24,6 +25,10 @@ type ToolCommandResult = {
startCwd?: string;
};

export function canExecuteBashTool(args: Record<string, unknown>, context: PermissionContext): SafetyDecision {
return evaluateBashToolSafety(args, context);
}

export async function handleBashTool(
args: Record<string, unknown>,
context: ToolExecutionContext
Expand Down
5 changes: 5 additions & 0 deletions src/tools/edit-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import * as fs from "fs";
import { z } from "zod";
import { buildThinkingRequestOptions } from "../openai-thinking";
import type { ToolExecutionContext, ToolExecutionResult } from "./executor";
import { evaluateEditToolSafety, type PermissionContext, type SafetyDecision } from "./safety-hooks";
import { buildDiffPreview, hasFileChangedSinceState, readTextFileWithMetadata, writeTextFile } from "./file-utils";
import { executeValidatedTool, semanticBoolean } from "./runtime";
import {
Expand Down Expand Up @@ -75,6 +76,10 @@ const editSchema = z.strictObject({
}, z.number().int().min(1, "expected_occurrences must be >= 1.").optional()),
});

export function canExecuteEditTool(args: Record<string, unknown>, context: PermissionContext): SafetyDecision {
return evaluateEditToolSafety(args, context);
}

export async function handleEditTool(
args: Record<string, unknown>,
context: ToolExecutionContext
Expand Down
Loading