diff --git a/TASK_QUEUE.json b/TASK_QUEUE.json new file mode 100644 index 000000000..05f17eca6 --- /dev/null +++ b/TASK_QUEUE.json @@ -0,0 +1,151 @@ +{ + "project": "temporal-spring-ai", + "tasks": [ + { + "id": "T1", + "title": "Add unit tests for type conversion", + "description": "Test ChatModelTypes <-> Spring AI types round-trip in ActivityChatModel and ChatModelActivityImpl. Cover messages (all roles), tool calls, media content, model options, embeddings, vector store types.", + "severity": "high", + "category": "tests", + "depends_on": [], + "status": "completed" + }, + { + "id": "T2", + "title": "Add unit tests for tool detection and conversion", + "description": "Test TemporalToolUtil.convertTools() with activity stubs, local activity stubs, @DeterministicTool, @SideEffectTool, Nexus stubs, and rejection of unknown types. Test TemporalStubUtil detection methods.", + "severity": "high", + "category": "tests", + "depends_on": [], + "status": "completed" + }, + { + "id": "T3", + "title": "Add replay test for determinism", + "description": "Create a workflow that uses ActivityChatModel with tools, run it once to produce history, then replay from that history to verify determinism. Cover activity tools, @DeterministicTool, and @SideEffectTool.", + "severity": "high", + "category": "tests", + "depends_on": [], + "status": "completed" + }, + { + "id": "T4", + "title": "Add unit tests for plugin registration", + "description": "Test SpringAiPlugin.initializeWorker() registers correct activities based on available beans. Test single model, multi-model, with/without VectorStore, with/without EmbeddingModel.", + "severity": "medium", + "category": "tests", + "depends_on": [], + "status": "completed" + }, + { + "id": "T5", + "title": "Fix UUID.randomUUID() in workflow context", + "description": "Replace UUID.randomUUID() with Workflow.randomUUID() in LocalActivityToolCallbackWrapper.call(). One-line fix.", + "severity": "high", + "category": "bugfix", + "depends_on": [ + "T3" + ], + "status": "completed", + "notes": "Do after replay test exists so we can verify the fix." + }, + { + "id": "T6", + "title": "Split SpringAiPlugin for optional deps", + "description": "Refactor so VectorStore, EmbeddingModel, and MCP are handled by separate @ConditionalOnClass auto-configuration classes. Core SpringAiPlugin only references ChatModel. compileOnly scope stays correct.", + "severity": "high", + "category": "refactor", + "depends_on": [ + "T4" + ], + "status": "completed", + "notes": "Do after plugin registration tests exist so we can verify the refactor doesn't break registration. Also resolves T10 (unnecessary MCP reflection)." + }, + { + "id": "T14", + "title": "Fix NPE when ChatResponse metadata is null", + "description": "ActivityChatModel.toResponse() passes null metadata to ChatResponse.builder().metadata(null), which causes an NPE in Spring AI's builder. Fix: skip .metadata() call when metadata is null, or pass an empty ChatResponseMetadata.", + "severity": "high", + "category": "bugfix", + "depends_on": [], + "status": "completed" + }, + { + "id": "T7", + "title": "Add max iteration limit to ActivityChatModel tool loop", + "description": "Add a configurable max iteration count (default ~10) to the recursive call() loop in ActivityChatModel. Throw after limit to prevent infinite recursion from misbehaving models.", + "severity": "medium", + "category": "bugfix", + "depends_on": [ + "T1" + ], + "status": "reverted", + "notes": "Reverted: Spring AI does not limit tool call iterations either. Temporal activity timeouts and workflow execution timeout provide the safety net." + }, + { + "id": "T8", + "title": "Replace fragile stub detection with SDK internals", + "description": "TemporalStubUtil string-matches on internal handler class names. Since the plugin is in the SDK repo, use internal APIs or instanceof checks. Add tests to catch breakage.", + "severity": "medium", + "category": "refactor", + "depends_on": [ + "T2" + ], + "status": "completed", + "notes": "Do after tool detection tests exist so we can verify the refactor." + }, + { + "id": "T9", + "title": "Document static CALLBACK_REGISTRY lifecycle", + "description": "Add javadoc to LocalActivityToolCallbackWrapper explaining the leak risk when workflows are evicted mid-execution. Consider adding a size metric or periodic cleanup.", + "severity": "medium", + "category": "improvement", + "depends_on": [], + "status": "completed" + }, + { + "id": "T10", + "title": "Remove unnecessary MCP reflection", + "description": "SpringAiPlugin uses Class.forName() for McpClientActivityImpl which is in the same module. Will be resolved by T6 (split into conditional configs).", + "severity": "low", + "category": "refactor", + "depends_on": [ + "T6" + ], + "status": "completed", + "notes": "Likely resolved automatically by T6." + }, + { + "id": "T11", + "title": "Add UnsupportedOperationException for stream()", + "description": "Override stream() in ActivityChatModel to throw UnsupportedOperationException with a clear message that streaming is not supported through activities.", + "severity": "low", + "category": "improvement", + "depends_on": [], + "status": "completed" + }, + { + "id": "T12", + "title": "Verify all 5 samples run end-to-end", + "description": "Run chat, MCP, multi-model, RAG, and sandboxing samples interactively against a dev server. Verify tool calling works for each.", + "severity": "medium", + "category": "testing", + "depends_on": [ + "T6" + ], + "status": "completed", + "notes": "All 5 samples boot successfully. MCP requires Node.js/npx for MCP server (environment prereq, not a code issue)." + }, + { + "id": "T13", + "title": "Remove includeBuild from samples-java", + "description": "Once temporal-spring-ai is published to Maven Central, remove the includeBuild('../sdk-java') block from samples-java/settings.gradle and the grpc-util workaround from core/build.gradle.", + "severity": "low", + "category": "cleanup", + "depends_on": [], + "status": "blocked", + "notes": "Blocked on SDK release. Not actionable yet." + } + ], + "execution_order_rationale": "Tests first (T1-T4) in parallel since they're independent. Then fixes that benefit from test coverage: T5 (UUID fix, verified by T3), T6 (plugin split, verified by T4), T7 (loop limit, verified by T1), T8 (stub detection, verified by T2). Then downstream: T10 (resolved by T6), T9/T11 (independent improvements). T12 after T6. T13 blocked on release." +} diff --git a/settings.gradle b/settings.gradle index 918ceaa28..9d3905698 100644 --- a/settings.gradle +++ b/settings.gradle @@ -6,6 +6,7 @@ include 'temporal-testing' include 'temporal-test-server' include 'temporal-opentracing' include 'temporal-kotlin' +include 'temporal-spring-ai' include 'temporal-spring-boot-autoconfigure' include 'temporal-spring-boot-starter' include 'temporal-remote-data-encoder' diff --git a/temporal-bom/build.gradle b/temporal-bom/build.gradle index 8f5a8971d..e73d0d300 100644 --- a/temporal-bom/build.gradle +++ b/temporal-bom/build.gradle @@ -12,6 +12,7 @@ dependencies { api project(':temporal-sdk') api project(':temporal-serviceclient') api project(':temporal-shaded') + api project(':temporal-spring-ai') api project(':temporal-spring-boot-autoconfigure') api project(':temporal-spring-boot-starter') api project(':temporal-test-server') diff --git a/temporal-spring-ai/build.gradle b/temporal-spring-ai/build.gradle new file mode 100644 index 000000000..c8593011b --- /dev/null +++ b/temporal-spring-ai/build.gradle @@ -0,0 +1,56 @@ +description = '''Temporal Java SDK Spring AI Plugin''' + +ext { + springAiVersion = '1.1.0' + // Spring AI requires Spring Boot 3.x / Java 17+ + springBootVersionForSpringAi = "$springBoot3Version" +} + +// Spring AI requires Java 17+, override the default Java 8 target from java.gradle +java { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 +} + +compileJava { + options.compilerArgs.removeAll(['--release', '8']) + options.compilerArgs.addAll(['--release', '17']) +} + +compileTestJava { + options.compilerArgs.removeAll(['--release', '8']) + options.compilerArgs.addAll(['--release', '17']) +} + +dependencies { + api(platform("org.springframework.boot:spring-boot-dependencies:$springBootVersionForSpringAi")) + api(platform("org.springframework.ai:spring-ai-bom:$springAiVersion")) + + // this module shouldn't carry temporal-sdk with it, especially for situations when users may be using a shaded artifact + compileOnly project(':temporal-sdk') + compileOnly project(':temporal-spring-boot-autoconfigure') + + api 'org.springframework.boot:spring-boot-autoconfigure' + api 'org.springframework.ai:spring-ai-client-chat' + + implementation 'org.springframework.boot:spring-boot-starter' + + // Optional: Vector store support + compileOnly 'org.springframework.ai:spring-ai-rag' + + // Optional: MCP (Model Context Protocol) support + compileOnly 'org.springframework.ai:spring-ai-mcp' + + testImplementation project(':temporal-sdk') + testImplementation project(':temporal-testing') + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'org.springframework.ai:spring-ai-rag' + + testRuntimeOnly group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}" + testRuntimeOnly "org.junit.platform:junit-platform-launcher" +} + +tasks.test { + useJUnitPlatform() +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java new file mode 100644 index 000000000..19caf9a54 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java @@ -0,0 +1,25 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.ChatModelTypes; + +/** + * Temporal activity interface for calling Spring AI chat models. + * + *

This activity wraps a Spring AI {@link org.springframework.ai.chat.model.ChatModel} and makes + * it callable from within Temporal workflows. The activity handles serialization of prompts and + * responses, enabling durable AI conversations with automatic retries and timeout handling. + */ +@ActivityInterface +public interface ChatModelActivity { + + /** + * Calls the chat model with the given input. + * + * @param input the chat model input containing messages, options, and tool definitions + * @return the chat model output containing generated responses and metadata + */ + @ActivityMethod + ChatModelTypes.ChatModelActivityOutput callChatModel(ChatModelTypes.ChatModelActivityInput input); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java new file mode 100644 index 000000000..71e5b6e99 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java @@ -0,0 +1,276 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.ChatModelTypes; +import io.temporal.springai.model.ChatModelTypes.Message; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; + +/** + * Implementation of {@link ChatModelActivity} that delegates to a Spring AI {@link ChatModel}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * ChatModelTypes}) and Spring AI types. + * + *

Supports multiple chat models. The model to use is determined by the {@code modelName} field + * in the input. If no model name is specified, the default model is used. + */ +public class ChatModelActivityImpl implements ChatModelActivity { + + private final Map chatModels; + private final String defaultModelName; + + /** + * Creates an activity implementation with a single chat model. + * + * @param chatModel the chat model to use + */ + public ChatModelActivityImpl(ChatModel chatModel) { + this.chatModels = Map.of("default", chatModel); + this.defaultModelName = "default"; + } + + /** + * Creates an activity implementation with multiple chat models. + * + * @param chatModels map of model names to chat models + * @param defaultModelName the name of the default model to use when none is specified + */ + public ChatModelActivityImpl(Map chatModels, String defaultModelName) { + this.chatModels = chatModels; + this.defaultModelName = defaultModelName; + } + + @Override + public ChatModelTypes.ChatModelActivityOutput callChatModel( + ChatModelTypes.ChatModelActivityInput input) { + ChatModel chatModel = resolveChatModel(input.modelName()); + Prompt prompt = createPrompt(input); + ChatResponse response = chatModel.call(prompt); + return toOutput(response); + } + + private ChatModel resolveChatModel(String modelName) { + String name = (modelName != null && !modelName.isEmpty()) ? modelName : defaultModelName; + ChatModel model = chatModels.get(name); + if (model == null) { + throw new IllegalArgumentException( + "No chat model with name '" + name + "'. Available models: " + chatModels.keySet()); + } + return model; + } + + private Prompt createPrompt(ChatModelTypes.ChatModelActivityInput input) { + List messages = + input.messages().stream().map(this::toSpringMessage).collect(Collectors.toList()); + + ToolCallingChatOptions.Builder optionsBuilder = + ToolCallingChatOptions.builder() + .internalToolExecutionEnabled(false); // Let workflow handle tool execution + + if (input.modelOptions() != null) { + ChatModelTypes.ModelOptions opts = input.modelOptions(); + if (opts.model() != null) optionsBuilder.model(opts.model()); + if (opts.temperature() != null) optionsBuilder.temperature(opts.temperature()); + if (opts.maxTokens() != null) optionsBuilder.maxTokens(opts.maxTokens()); + if (opts.topP() != null) optionsBuilder.topP(opts.topP()); + if (opts.topK() != null) optionsBuilder.topK(opts.topK()); + if (opts.frequencyPenalty() != null) optionsBuilder.frequencyPenalty(opts.frequencyPenalty()); + if (opts.presencePenalty() != null) optionsBuilder.presencePenalty(opts.presencePenalty()); + if (opts.stopSequences() != null) optionsBuilder.stopSequences(opts.stopSequences()); + } + + // Add tool callbacks (stubs that provide definitions but won't be executed + // since internalToolExecutionEnabled is false) + if (!CollectionUtils.isEmpty(input.tools())) { + List toolCallbacks = + input.tools().stream() + .map( + tool -> + createStubToolCallback( + tool.function().name(), + tool.function().description(), + tool.function().jsonSchema())) + .collect(Collectors.toList()); + optionsBuilder.toolCallbacks(toolCallbacks); + } + + ToolCallingChatOptions chatOptions = optionsBuilder.build(); + + return Prompt.builder().messages(messages).chatOptions(chatOptions).build(); + } + + private org.springframework.ai.chat.messages.Message toSpringMessage(Message message) { + return switch (message.role()) { + case SYSTEM -> new SystemMessage((String) message.rawContent()); + case USER -> { + UserMessage.Builder builder = UserMessage.builder().text((String) message.rawContent()); + if (!CollectionUtils.isEmpty(message.mediaContents())) { + builder.media( + message.mediaContents().stream().map(this::toMedia).collect(Collectors.toList())); + } + yield builder.build(); + } + case ASSISTANT -> + AssistantMessage.builder() + .content((String) message.rawContent()) + .properties(Map.of()) + .toolCalls( + message.toolCalls() != null + ? message.toolCalls().stream() + .map( + tc -> + new AssistantMessage.ToolCall( + tc.id(), + tc.type(), + tc.function().name(), + tc.function().arguments())) + .collect(Collectors.toList()) + : List.of()) + .media( + message.mediaContents() != null + ? message.mediaContents().stream() + .map(this::toMedia) + .collect(Collectors.toList()) + : List.of()) + .build(); + case TOOL -> + ToolResponseMessage.builder() + .responses( + List.of( + new ToolResponseMessage.ToolResponse( + message.toolCallId(), message.name(), (String) message.rawContent()))) + .build(); + }; + } + + private Media toMedia(ChatModelTypes.MediaContent mediaContent) { + MimeType mimeType = MimeType.valueOf(mediaContent.mimeType()); + if (mediaContent.uri() != null) { + try { + return new Media(mimeType, new URI(mediaContent.uri())); + } catch (URISyntaxException e) { + throw new RuntimeException("Invalid media URI: " + mediaContent.uri(), e); + } + } else if (mediaContent.data() != null) { + return new Media(mimeType, new ByteArrayResource(mediaContent.data())); + } + throw new IllegalArgumentException("Media content must have either uri or data"); + } + + private ChatModelTypes.ChatModelActivityOutput toOutput(ChatResponse response) { + List generations = + response.getResults().stream() + .map( + gen -> + new ChatModelTypes.ChatModelActivityOutput.Generation( + fromAssistantMessage(gen.getOutput()))) + .collect(Collectors.toList()); + + ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata metadata = null; + if (response.getMetadata() != null) { + var rateLimit = response.getMetadata().getRateLimit(); + var usage = response.getMetadata().getUsage(); + + metadata = + new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata( + response.getMetadata().getModel(), + rateLimit != null + ? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.RateLimit( + rateLimit.getRequestsLimit(), + rateLimit.getRequestsRemaining(), + rateLimit.getRequestsReset(), + rateLimit.getTokensLimit(), + rateLimit.getTokensRemaining(), + rateLimit.getTokensReset()) + : null, + usage != null + ? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.Usage( + usage.getPromptTokens() != null ? usage.getPromptTokens().intValue() : null, + usage.getCompletionTokens() != null + ? usage.getCompletionTokens().intValue() + : null, + usage.getTotalTokens() != null ? usage.getTotalTokens().intValue() : null) + : null); + } + + return new ChatModelTypes.ChatModelActivityOutput(generations, metadata); + } + + private Message fromAssistantMessage(AssistantMessage assistantMessage) { + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = + assistantMessage.getToolCalls().stream() + .map( + tc -> + new Message.ToolCall( + tc.id(), + tc.type(), + new Message.ChatCompletionFunction(tc.name(), tc.arguments()))) + .collect(Collectors.toList()); + } + + List mediaContents = null; + if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) { + mediaContents = + assistantMessage.getMedia().stream().map(this::fromMedia).collect(Collectors.toList()); + } + + return new Message( + assistantMessage.getText(), Message.Role.ASSISTANT, null, null, toolCalls, mediaContents); + } + + private ChatModelTypes.MediaContent fromMedia(Media media) { + String mimeType = media.getMimeType().toString(); + if (media.getData() instanceof String uri) { + return new ChatModelTypes.MediaContent(mimeType, uri); + } else if (media.getData() instanceof byte[] data) { + return new ChatModelTypes.MediaContent(mimeType, data); + } + throw new IllegalArgumentException( + "Unsupported media data type: " + media.getData().getClass()); + } + + /** + * Creates a stub ToolCallback that provides a tool definition but throws if called. This is used + * because Spring AI's ChatModel API requires ToolCallbacks, but we only need to inform the model + * about available tools - actual execution happens in the workflow (since + * internalToolExecutionEnabled is false). + */ + private ToolCallback createStubToolCallback(String name, String description, String inputSchema) { + ToolDefinition toolDefinition = + ToolDefinition.builder() + .name(name) + .description(description) + .inputSchema(inputSchema) + .build(); + + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + throw new UnsupportedOperationException( + "Tool execution should be handled by the workflow, not the activity. " + + "Ensure internalToolExecutionEnabled is set to false."); + } + }; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java new file mode 100644 index 000000000..8deed81f2 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java @@ -0,0 +1,62 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.EmbeddingModelTypes; + +/** + * Temporal activity interface for Spring AI EmbeddingModel operations. + * + *

This activity wraps Spring AI's {@link org.springframework.ai.embedding.EmbeddingModel}, + * making embedding generation durable and retriable within Temporal workflows. + * + *

Example usage in a workflow: + * + *

{@code
+ * EmbeddingModelActivity embeddingModel = Workflow.newActivityStub(
+ *     EmbeddingModelActivity.class,
+ *     ActivityOptions.newBuilder()
+ *         .setStartToCloseTimeout(Duration.ofMinutes(2))
+ *         .build());
+ *
+ * // Embed single text
+ * EmbedOutput result = embeddingModel.embed(new EmbedTextInput("Hello world"));
+ * List vector = result.embedding();
+ *
+ * // Embed batch
+ * EmbedBatchOutput batchResult = embeddingModel.embedBatch(
+ *     new EmbedBatchInput(List.of("text1", "text2", "text3")));
+ * }
+ */ +@ActivityInterface +public interface EmbeddingModelActivity { + + /** + * Generates an embedding for a single text. + * + * @param input the text to embed + * @return the embedding vector + */ + @ActivityMethod + EmbeddingModelTypes.EmbedOutput embed(EmbeddingModelTypes.EmbedTextInput input); + + /** + * Generates embeddings for multiple texts in a single request. + * + *

This is more efficient than calling {@link #embed} multiple times when you have multiple + * texts to embed. + * + * @param input the texts to embed + * @return the embedding vectors with metadata + */ + @ActivityMethod + EmbeddingModelTypes.EmbedBatchOutput embedBatch(EmbeddingModelTypes.EmbedBatchInput input); + + /** + * Returns the dimensionality of the embedding vectors produced by this model. + * + * @return the number of dimensions + */ + @ActivityMethod + EmbeddingModelTypes.DimensionsOutput dimensions(); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java new file mode 100644 index 000000000..b9c6d8266 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java @@ -0,0 +1,72 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.EmbeddingModelTypes; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingResponse; + +/** + * Implementation of {@link EmbeddingModelActivity} that delegates to a Spring AI {@link + * EmbeddingModel}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * EmbeddingModelTypes}) and Spring AI types. + */ +public class EmbeddingModelActivityImpl implements EmbeddingModelActivity { + + private final EmbeddingModel embeddingModel; + + public EmbeddingModelActivityImpl(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @Override + public EmbeddingModelTypes.EmbedOutput embed(EmbeddingModelTypes.EmbedTextInput input) { + float[] embedding = embeddingModel.embed(input.text()); + return new EmbeddingModelTypes.EmbedOutput(toDoubleList(embedding)); + } + + @Override + public EmbeddingModelTypes.EmbedBatchOutput embedBatch( + EmbeddingModelTypes.EmbedBatchInput input) { + EmbeddingResponse response = embeddingModel.embedForResponse(input.texts()); + + List results = + IntStream.range(0, response.getResults().size()) + .mapToObj( + i -> { + var embedding = response.getResults().get(i); + return new EmbeddingModelTypes.EmbeddingResult( + i, toDoubleList(embedding.getOutput())); + }) + .collect(Collectors.toList()); + + EmbeddingModelTypes.EmbeddingMetadata metadata = null; + if (response.getMetadata() != null) { + var usage = response.getMetadata().getUsage(); + metadata = + new EmbeddingModelTypes.EmbeddingMetadata( + response.getMetadata().getModel(), + usage != null && usage.getTotalTokens() != null + ? usage.getTotalTokens().intValue() + : null, + embeddingModel.dimensions()); + } + + return new EmbeddingModelTypes.EmbedBatchOutput(results, metadata); + } + + @Override + public EmbeddingModelTypes.DimensionsOutput dimensions() { + return new EmbeddingModelTypes.DimensionsOutput(embeddingModel.dimensions()); + } + + private List toDoubleList(float[] floats) { + return IntStream.range(0, floats.length) + .mapToDouble(i -> floats[i]) + .boxed() + .collect(Collectors.toList()); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java new file mode 100644 index 000000000..51747e645 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java @@ -0,0 +1,59 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.VectorStoreTypes; + +/** + * Temporal activity interface for Spring AI VectorStore operations. + * + *

This activity wraps Spring AI's {@link org.springframework.ai.vectorstore.VectorStore}, making + * vector database operations durable and retriable within Temporal workflows. + * + *

Example usage in a workflow: + * + *

{@code
+ * VectorStoreActivity vectorStore = Workflow.newActivityStub(
+ *     VectorStoreActivity.class,
+ *     ActivityOptions.newBuilder()
+ *         .setStartToCloseTimeout(Duration.ofMinutes(5))
+ *         .build());
+ *
+ * // Add documents
+ * vectorStore.addDocuments(new AddDocumentsInput(documents));
+ *
+ * // Search
+ * SearchOutput results = vectorStore.similaritySearch(new SearchInput("query", 10));
+ * }
+ */ +@ActivityInterface +public interface VectorStoreActivity { + + /** + * Adds documents to the vector store. + * + *

If the documents don't have pre-computed embeddings, the vector store will use its + * configured EmbeddingModel to generate them. + * + * @param input the documents to add + */ + @ActivityMethod + void addDocuments(VectorStoreTypes.AddDocumentsInput input); + + /** + * Deletes documents from the vector store by their IDs. + * + * @param input the IDs of documents to delete + */ + @ActivityMethod + void deleteByIds(VectorStoreTypes.DeleteByIdsInput input); + + /** + * Performs a similarity search in the vector store. + * + * @param input the search parameters + * @return the search results with similarity scores + */ + @ActivityMethod + VectorStoreTypes.SearchOutput similaritySearch(VectorStoreTypes.SearchInput input); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java new file mode 100644 index 000000000..80ce75518 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java @@ -0,0 +1,98 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.VectorStoreTypes; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; + +/** + * Implementation of {@link VectorStoreActivity} that delegates to a Spring AI {@link VectorStore}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * VectorStoreTypes}) and Spring AI types. + */ +public class VectorStoreActivityImpl implements VectorStoreActivity { + + private final VectorStore vectorStore; + private final FilterExpressionTextParser filterParser = new FilterExpressionTextParser(); + + public VectorStoreActivityImpl(VectorStore vectorStore) { + this.vectorStore = vectorStore; + } + + @Override + public void addDocuments(VectorStoreTypes.AddDocumentsInput input) { + List documents = + input.documents().stream().map(this::toSpringDocument).collect(Collectors.toList()); + vectorStore.add(documents); + } + + @Override + public void deleteByIds(VectorStoreTypes.DeleteByIdsInput input) { + vectorStore.delete(input.ids()); + } + + @Override + public VectorStoreTypes.SearchOutput similaritySearch(VectorStoreTypes.SearchInput input) { + SearchRequest.Builder requestBuilder = + SearchRequest.builder().query(input.query()).topK(input.topK()); + + if (input.similarityThreshold() != null) { + requestBuilder.similarityThreshold(input.similarityThreshold()); + } + + if (input.filterExpression() != null && !input.filterExpression().isBlank()) { + requestBuilder.filterExpression(filterParser.parse(input.filterExpression())); + } + + List results = vectorStore.similaritySearch(requestBuilder.build()); + + List searchResults = + results.stream() + .map(doc -> new VectorStoreTypes.SearchResult(fromSpringDocument(doc), doc.getScore())) + .collect(Collectors.toList()); + + return new VectorStoreTypes.SearchOutput(searchResults); + } + + private Document toSpringDocument(VectorStoreTypes.Document doc) { + Document.Builder builder = Document.builder().id(doc.id()).text(doc.text()); + + if (doc.metadata() != null && !doc.metadata().isEmpty()) { + builder.metadata(new HashMap<>(doc.metadata())); + } + + return builder.build(); + } + + private VectorStoreTypes.Document fromSpringDocument(Document doc) { + // Convert metadata, handling potential non-serializable values + Map metadata = new HashMap<>(); + if (doc.getMetadata() != null) { + for (Map.Entry entry : doc.getMetadata().entrySet()) { + Object value = entry.getValue(); + // Only include serializable primitive types + if (value == null + || value instanceof String + || value instanceof Number + || value instanceof Boolean) { + metadata.put(entry.getKey(), value); + } else { + metadata.put(entry.getKey(), value.toString()); + } + } + } + + return new VectorStoreTypes.Document( + doc.getId(), + doc.getText(), + metadata, + null // Don't include embedding in results to reduce payload size + ); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/advisor/SandboxingAdvisor.java b/temporal-spring-ai/src/main/java/io/temporal/springai/advisor/SandboxingAdvisor.java new file mode 100644 index 000000000..d6042afeb --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/advisor/SandboxingAdvisor.java @@ -0,0 +1,119 @@ +package io.temporal.springai.advisor; + +import io.temporal.springai.tool.ActivityToolCallback; +import io.temporal.springai.tool.LocalActivityToolCallbackWrapper; +import io.temporal.springai.tool.NexusToolCallback; +import io.temporal.springai.tool.SideEffectToolCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +/** + * An advisor that automatically wraps unsafe tool callbacks in local activities. + * + *

This advisor inspects all tool callbacks in a chat request and ensures they are safe for + * workflow execution: + * + *

+ * + *

This provides a safety net for users who pass arbitrary Spring AI tools that may not be + * workflow-safe. A warning is logged for each wrapped tool to help users understand how to properly + * annotate their tools. + * + *

Usage

+ * + *
{@code
+ * this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *         .defaultAdvisors(new SandboxingAdvisor())
+ *         .defaultTools(new UnsafeTools())  // Will be wrapped with warning
+ *         .build();
+ * }
+ * + *

When to Use

+ * + * + * + *

Performance Considerations

+ * + *

Wrapping tools in local activities adds overhead compared to properly annotated tools. For + * production, annotate your tools with {@code @DeterministicTool} or {@code @SideEffectTool}, or + * use activity stubs. + * + * @see io.temporal.springai.tool.DeterministicTool + * @see io.temporal.springai.tool.SideEffectTool + * @see LocalActivityToolCallbackWrapper + */ +public class SandboxingAdvisor implements CallAdvisor { + + private static final Logger logger = LoggerFactory.getLogger(SandboxingAdvisor.class); + + @Override + public ChatClientResponse adviseCall( + ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + var prompt = chatClientRequest.prompt(); + + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + var toolCallbacks = toolCallingChatOptions.getToolCallbacks(); + + if (toolCallbacks != null && !toolCallbacks.isEmpty()) { + var wrappedCallbacks = + toolCallbacks.stream() + .map( + tc -> { + if (tc instanceof ActivityToolCallback + || tc instanceof NexusToolCallback + || tc instanceof SideEffectToolCallback) { + // Already safe for workflow execution + return tc; + } else if (tc instanceof LocalActivityToolCallbackWrapper) { + // Already wrapped + return tc; + } else { + // Wrap in local activity for safety + String toolName = + tc.getToolDefinition() != null + ? tc.getToolDefinition().name() + : tc.getClass().getSimpleName(); + logger.warn( + "Tool '{}' ({}) is not guaranteed to be deterministic. " + + "Wrapping in local activity for workflow safety. " + + "Consider using @DeterministicTool, @SideEffectTool, or an activity stub.", + toolName, + tc.getClass().getName()); + return new LocalActivityToolCallbackWrapper(tc); + } + }) + .toList(); + + toolCallingChatOptions.setToolCallbacks(wrappedCallbacks); + } + } + + return callAdvisorChain.nextCall(chatClientRequest); + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + // Run early to wrap tools before other advisors see them + return Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java new file mode 100644 index 000000000..286392ed7 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java @@ -0,0 +1,25 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.EmbeddingModelPlugin; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for EmbeddingModel integration with Temporal. + * + *

Conditionally creates an {@link EmbeddingModelPlugin} when {@code spring-ai-rag} is on the + * classpath and an {@link EmbeddingModel} bean is available. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "org.springframework.ai.embedding.EmbeddingModel") +@ConditionalOnBean(EmbeddingModel.class) +public class SpringAiEmbeddingAutoConfiguration { + + @Bean + public EmbeddingModelPlugin embeddingModelPlugin(EmbeddingModel embeddingModel) { + return new EmbeddingModelPlugin(embeddingModel); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java new file mode 100644 index 000000000..0fa299f85 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java @@ -0,0 +1,22 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.McpPlugin; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for MCP (Model Context Protocol) integration with Temporal. + * + *

Conditionally creates a {@link McpPlugin} when {@code spring-ai-mcp} and the MCP client + * library are on the classpath. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "io.modelcontextprotocol.client.McpSyncClient") +public class SpringAiMcpAutoConfiguration { + + @Bean + public McpPlugin mcpPlugin() { + return new McpPlugin(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java new file mode 100644 index 000000000..f403208d9 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java @@ -0,0 +1,38 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.SpringAiPlugin; +import java.util.Map; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; +import org.springframework.lang.Nullable; + +/** + * Core auto-configuration for the Spring AI Temporal plugin. + * + *

Creates the {@link SpringAiPlugin} bean which registers {@link + * io.temporal.springai.activity.ChatModelActivity} and {@link + * io.temporal.springai.tool.ExecuteToolLocalActivity} with all Temporal workers. + * + *

Optional integrations are handled by separate auto-configuration classes: + * + *

+ */ +@AutoConfiguration +@ConditionalOnClass( + name = {"org.springframework.ai.chat.model.ChatModel", "io.temporal.worker.Worker"}) +public class SpringAiTemporalAutoConfiguration { + + @Bean + public SpringAiPlugin springAiPlugin( + @Autowired Map chatModels, + @Autowired(required = false) @Nullable ChatModel primaryChatModel) { + return new SpringAiPlugin(chatModels, primaryChatModel); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java new file mode 100644 index 000000000..bf2cf1ff8 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java @@ -0,0 +1,25 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.VectorStorePlugin; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for VectorStore integration with Temporal. + * + *

Conditionally creates a {@link VectorStorePlugin} when {@code spring-ai-rag} is on the + * classpath and a {@link VectorStore} bean is available. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "org.springframework.ai.vectorstore.VectorStore") +@ConditionalOnBean(VectorStore.class) +public class SpringAiVectorStoreAutoConfiguration { + + @Bean + public VectorStorePlugin vectorStorePlugin(VectorStore vectorStore) { + return new VectorStorePlugin(vectorStore); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java b/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java new file mode 100644 index 000000000..847a10053 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java @@ -0,0 +1,186 @@ +package io.temporal.springai.chat; + +import io.micrometer.observation.ObservationRegistry; +import io.temporal.springai.util.TemporalToolUtil; +import java.util.Map; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.DefaultChatClient; +import org.springframework.ai.chat.client.DefaultChatClientBuilder; +import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A Temporal-aware implementation of Spring AI's {@link ChatClient} that understands Temporal + * primitives like activity stubs and deterministic tools. + * + *

This client extends Spring AI's {@link DefaultChatClient} to add support for Temporal-specific + * features: + * + *

+ * + *

Example usage in a workflow: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create the activity-backed chat model
+ *     ChatModelActivity chatModelActivity = Workflow.newActivityStub(
+ *             ChatModelActivity.class, activityOptions);
+ *     ActivityChatModel activityChatModel = new ActivityChatModel(chatModelActivity);
+ *
+ *     // Create tools
+ *     WeatherActivity weatherTool = Workflow.newActivityStub(WeatherActivity.class, opts);
+ *     MathTools mathTools = new MathTools(); // @DeterministicTool
+ *
+ *     // Build the Temporal-aware chat client
+ *     this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *             .defaultSystem("You are a helpful assistant.")
+ *             .defaultTools(weatherTool, mathTools)
+ *             .build();
+ * }
+ *
+ * @Override
+ * public String chat(String message) {
+ *     return chatClient.prompt()
+ *             .user(message)
+ *             .call()
+ *             .content();
+ * }
+ * }
+ * + * @see Builder + * @see io.temporal.springai.model.ActivityChatModel + */ +public class TemporalChatClient extends DefaultChatClient { + + /** + * Creates a new TemporalChatClient with the given request specification. + * + * @param defaultChatClientRequest the default request specification + */ + public TemporalChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { + super(defaultChatClientRequest); + } + + /** + * Creates a builder for constructing a TemporalChatClient. + * + * @param chatModel the chat model to use (typically an {@code ActivityChatModel}) + * @return a new builder + */ + public static Builder builder(ChatModel chatModel) { + return builder(chatModel, ObservationRegistry.NOOP, null); + } + + /** + * Creates a builder with observation support. + * + * @param chatModel the chat model to use + * @param observationRegistry the observation registry for metrics + * @param customObservationConvention optional custom observation convention + * @return a new builder + */ + public static Builder builder( + ChatModel chatModel, + ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention customObservationConvention) { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + return new Builder(chatModel, observationRegistry, customObservationConvention); + } + + /** + * A builder for creating {@link TemporalChatClient} instances that understand Temporal + * primitives. + * + *

This builder extends Spring AI's {@link DefaultChatClientBuilder} to add support for + * Temporal-specific tool types. When you call {@link #defaultTools(Object...)}, the builder + * automatically detects and converts: + * + *

+ * + * @see TemporalToolUtil + */ + public static class Builder extends DefaultChatClientBuilder { + + /** + * Creates a new builder for the given chat model. + * + * @param chatModel the chat model to use + */ + public Builder(ChatModel chatModel) { + super(chatModel, ObservationRegistry.NOOP, null, null); + } + + /** + * Creates a new builder with observation support. + * + * @param chatModel the chat model to use + * @param observationRegistry the observation registry for metrics + * @param customObservationConvention optional custom observation convention + */ + public Builder( + ChatModel chatModel, + ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention customObservationConvention) { + super(chatModel, observationRegistry, customObservationConvention, null); + } + + /** + * Sets the default tools for all requests. + * + *

This method automatically detects and converts Temporal-specific tool types: + * + *

+ * + *

Unrecognized tool types will throw an {@link IllegalArgumentException}. For tools that + * aren't properly annotated, use {@code defaultToolCallbacks()} with {@link + * io.temporal.springai.advisor.SandboxingAdvisor} to wrap them safely. + * + * @param toolObjects the tool objects (activity stubs, deterministic tool instances, etc.) + * @return this builder + * @throws IllegalArgumentException if a tool object is not a recognized type + */ + @Override + public ChatClient.Builder defaultTools(Object... toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); + this.defaultRequest.toolCallbacks(TemporalToolUtil.convertTools(toolObjects)); + return this; + } + + /** + * Tool context is not supported in Temporal workflows. + * + *

Tool context requires mutable state that cannot be safely passed through Temporal's + * serialization boundaries. Use activity parameters or workflow state instead. + * + * @param toolContext ignored + * @return never returns + * @throws UnsupportedOperationException always + */ + @Override + public ChatClient.Builder defaultToolContext(Map toolContext) { + throw new UnsupportedOperationException( + "defaultToolContext is not supported in TemporalChatClient. " + + "Tool context cannot be safely serialized through Temporal activities. " + + "Consider passing required context as activity parameters or workflow state."); + } + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java new file mode 100644 index 000000000..360412a83 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java @@ -0,0 +1,141 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.activity.ActivityOptions; +import io.temporal.common.RetryOptions; +import io.temporal.workflow.Workflow; +import java.time.Duration; +import java.util.Map; + +/** + * A workflow-safe wrapper for MCP (Model Context Protocol) client operations. + * + *

This class provides access to MCP tools within Temporal workflows. All MCP operations are + * executed as activities, providing durability, automatic retries, and timeout handling. + * + *

Usage in Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create an MCP client with default options
+ *     ActivityMcpClient mcpClient = ActivityMcpClient.create();
+ *
+ *     // Get tools from all connected MCP servers
+ *     List mcpTools = McpToolCallback.fromMcpClient(mcpClient);
+ *
+ *     // Use with TemporalChatClient
+ *     this.chatClient = TemporalChatClient.builder(chatModel)
+ *             .defaultToolCallbacks(mcpTools)
+ *             .build();
+ * }
+ * }
+ * + *

MCP Server Configuration

+ * + *

MCP servers are configured in the worker's Spring context using Spring AI's MCP client + * configuration. See the Spring AI MCP documentation for details. + * + * @see McpClientActivity + * @see McpToolCallback + */ +public class ActivityMcpClient { + + /** Default timeout for MCP activity calls (30 seconds). */ + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(30); + + /** Default maximum retry attempts for MCP activity calls. */ + public static final int DEFAULT_MAX_ATTEMPTS = 3; + + private final McpClientActivity activity; + private Map serverCapabilities; + private Map clientInfo; + + /** + * Creates a new ActivityMcpClient with the given activity stub. + * + * @param activity the activity stub for MCP operations + */ + public ActivityMcpClient(McpClientActivity activity) { + this.activity = activity; + } + + /** + * Creates an ActivityMcpClient with default options. + * + *

Must be called from workflow code. + * + * @return a new ActivityMcpClient + */ + public static ActivityMcpClient create() { + return create(DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityMcpClient with custom options. + * + *

Must be called from workflow code. + * + * @param timeout the activity start-to-close timeout + * @param maxAttempts the maximum number of retry attempts + * @return a new ActivityMcpClient + */ + public static ActivityMcpClient create(Duration timeout, int maxAttempts) { + McpClientActivity activity = + Workflow.newActivityStub( + McpClientActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(timeout) + .setRetryOptions(RetryOptions.newBuilder().setMaximumAttempts(maxAttempts).build()) + .build()); + return new ActivityMcpClient(activity); + } + + /** + * Gets the server capabilities for all connected MCP clients. + * + *

Results are cached after the first call. + * + * @return map of client name to server capabilities + */ + public Map getServerCapabilities() { + if (serverCapabilities == null) { + serverCapabilities = activity.getServerCapabilities(); + } + return serverCapabilities; + } + + /** + * Gets client info for all connected MCP clients. + * + *

Results are cached after the first call. + * + * @return map of client name to client implementation info + */ + public Map getClientInfo() { + if (clientInfo == null) { + clientInfo = activity.getClientInfo(); + } + return clientInfo; + } + + /** + * Calls a tool on a specific MCP client. + * + * @param clientName the name of the MCP client + * @param request the tool call request + * @return the tool call result + */ + public McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request) { + return activity.callTool(clientName, request); + } + + /** + * Lists all available tools from all connected MCP clients. + * + * @return map of client name to list of tools + */ + public Map listTools() { + return activity.listTools(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java new file mode 100644 index 000000000..5c17ce7d3 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java @@ -0,0 +1,56 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import java.util.Map; + +/** + * Activity interface for interacting with MCP (Model Context Protocol) clients. + * + *

This activity provides durable access to MCP servers, allowing workflows to discover and call + * MCP tools as Temporal activities with full retry and timeout support. + * + *

The activity implementation ({@link McpClientActivityImpl}) is automatically registered by the + * plugin when MCP clients are available in the Spring context. + * + * @see ActivityMcpClient + * @see McpToolCallback + */ +@ActivityInterface(namePrefix = "MCP-Client-") +public interface McpClientActivity { + + /** + * Gets the server capabilities for all connected MCP clients. + * + * @return map of client name to server capabilities + */ + @ActivityMethod + Map getServerCapabilities(); + + /** + * Gets client info for all connected MCP clients. + * + * @return map of client name to client implementation info + */ + @ActivityMethod + Map getClientInfo(); + + /** + * Calls a tool on a specific MCP client. + * + * @param clientName the name of the MCP client + * @param request the tool call request + * @return the tool call result + */ + @ActivityMethod + McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request); + + /** + * Lists all available tools from all connected MCP clients. + * + * @return map of client name to list of tools + */ + @ActivityMethod + Map listTools(); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java new file mode 100644 index 000000000..b7f031759 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java @@ -0,0 +1,64 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.failure.ApplicationFailure; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Implementation of {@link McpClientActivity} that delegates to Spring AI MCP clients. + * + *

This activity provides durable access to MCP servers. It is automatically registered by the + * plugin when MCP clients are available in the Spring context. + */ +public class McpClientActivityImpl implements McpClientActivity { + + private final Map mcpClients; + + /** + * Creates an activity implementation with the given MCP clients. + * + * @param mcpClients list of MCP sync clients from Spring context + */ + public McpClientActivityImpl(List mcpClients) { + this.mcpClients = + mcpClients.stream().collect(Collectors.toMap(c -> c.getClientInfo().name(), c -> c)); + } + + @Override + public Map getServerCapabilities() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getServerCapabilities())); + } + + @Override + public Map getClientInfo() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getClientInfo())); + } + + @Override + public McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request) { + McpSyncClient client = mcpClients.get(clientName); + if (client == null) { + throw ApplicationFailure.newBuilder() + .setType("ClientNotFound") + .setMessage( + "MCP client '" + + clientName + + "' not found. Available clients: " + + mcpClients.keySet()) + .setNonRetryable(true) + .build(); + } + return client.callTool(request); + } + + @Override + public Map listTools() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().listTools())); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java new file mode 100644 index 000000000..9cf821aae --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java @@ -0,0 +1,133 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import java.util.List; +import java.util.Map; +import org.springframework.ai.mcp.McpToolUtils; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +/** + * A {@link ToolCallback} implementation that executes MCP tools via Temporal activities. + * + *

This class bridges MCP tools with Spring AI's tool calling system, allowing AI models to call + * MCP server tools through durable Temporal activities. + * + *

Usage in Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create an MCP client
+ *     ActivityMcpClient mcpClient = ActivityMcpClient.create();
+ *
+ *     // Convert MCP tools to ToolCallbacks
+ *     List mcpTools = McpToolCallback.fromMcpClient(mcpClient);
+ *
+ *     // Use with TemporalChatClient
+ *     this.chatClient = TemporalChatClient.builder(chatModel)
+ *             .defaultToolCallbacks(mcpTools)
+ *             .build();
+ * }
+ * }
+ * + * @see ActivityMcpClient + * @see McpClientActivity + */ +public class McpToolCallback implements ToolCallback { + + private final ActivityMcpClient client; + private final String clientName; + private final McpSchema.Tool tool; + private final ToolDefinition toolDefinition; + + /** + * Creates a new McpToolCallback for a specific MCP tool. + * + * @param client the MCP client to use for tool calls + * @param clientName the name of the MCP client that provides this tool + * @param tool the tool definition + * @param toolNamePrefix the prefix to use for the tool name (usually the MCP server name) + */ + public McpToolCallback( + ActivityMcpClient client, String clientName, McpSchema.Tool tool, String toolNamePrefix) { + this.client = client; + this.clientName = clientName; + this.tool = tool; + + // Cache the tool definition at construction time to avoid activity calls in queries + String prefixedName = McpToolUtils.prefixedToolName(toolNamePrefix, tool.name()); + this.toolDefinition = + DefaultToolDefinition.builder() + .name(prefixedName) + .description(tool.description()) + .inputSchema(ModelOptionsUtils.toJsonString(tool.inputSchema())) + .build(); + } + + /** + * Creates ToolCallbacks for all tools from all MCP clients. + * + *

This method discovers all available tools from the MCP clients and wraps them as + * ToolCallbacks that execute through Temporal activities. + * + * @param client the MCP client + * @return list of ToolCallbacks for all discovered tools + */ + public static List fromMcpClient(ActivityMcpClient client) { + // Get client info upfront for tool name prefixes + Map clientInfo = client.getClientInfo(); + + Map toolsMap = client.listTools(); + return toolsMap.entrySet().stream() + .flatMap( + entry -> { + String clientName = entry.getKey(); + McpSchema.Implementation impl = clientInfo.get(clientName); + String prefix = impl != null ? impl.name() : clientName; + + return entry.getValue().tools().stream() + .map( + tool -> (ToolCallback) new McpToolCallback(client, clientName, tool, prefix)); + }) + .toList(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + Map arguments = ModelOptionsUtils.jsonToMap(toolInput); + + // Use the original tool name (not prefixed) when calling the MCP server + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(tool.name(), arguments); + McpSchema.CallToolResult result = client.callTool(clientName, request); + + // Return the result as-is (including errors) so the AI can handle them. + // For example, an "access denied" error lets the AI suggest a valid path. + return ModelOptionsUtils.toJsonString(result.content()); + } + + /** + * Returns the name of the MCP client that provides this tool. + * + * @return the client name + */ + public String getClientName() { + return clientName; + } + + /** + * Returns the original tool definition from the MCP server. + * + * @return the tool definition + */ + public McpSchema.Tool getMcpTool() { + return tool; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java new file mode 100644 index 000000000..9a16d4db1 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java @@ -0,0 +1,389 @@ +package io.temporal.springai.model; + +import io.temporal.activity.ActivityOptions; +import io.temporal.common.RetryOptions; +import io.temporal.springai.activity.ChatModelActivity; +import io.temporal.workflow.Workflow; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.*; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; + +/** + * A {@link ChatModel} implementation that delegates to a Temporal activity. + * + *

This class enables Spring AI chat clients to be used within Temporal workflows. AI model calls + * are executed as activities, providing durability, automatic retries, and timeout handling. + * + *

Tool execution is handled locally in the workflow (not in the activity), allowing tools to be + * implemented as activities, local activities, or other Temporal primitives. + * + *

Usage

+ * + *

For a single chat model, use the constructor directly: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     ChatModelActivity chatModelActivity = Workflow.newActivityStub(
+ *         ChatModelActivity.class,
+ *         ActivityOptions.newBuilder()
+ *             .setStartToCloseTimeout(Duration.ofMinutes(2))
+ *             .build());
+ *
+ *     ActivityChatModel chatModel = new ActivityChatModel(chatModelActivity);
+ *     this.chatClient = ChatClient.builder(chatModel).build();
+ * }
+ * }
+ * + *

Multiple Chat Models

+ * + *

For applications with multiple chat models, use the static factory methods: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Use the default model (first or @Primary bean)
+ *     ActivityChatModel defaultModel = ActivityChatModel.forDefault();
+ *
+ *     // Use a specific model by bean name
+ *     ActivityChatModel openAiModel = ActivityChatModel.forModel("openAiChatModel");
+ *     ActivityChatModel anthropicModel = ActivityChatModel.forModel("anthropicChatModel");
+ *
+ *     // Use different models for different purposes
+ *     this.fastClient = TemporalChatClient.builder(openAiModel).build();
+ *     this.smartClient = TemporalChatClient.builder(anthropicModel).build();
+ * }
+ * }
+ * + * @see #forDefault() + * @see #forModel(String) + */ +public class ActivityChatModel implements ChatModel { + + /** Default timeout for chat model activity calls (2 minutes). */ + public static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(2); + + /** Default maximum retry attempts for chat model activity calls. */ + public static final int DEFAULT_MAX_ATTEMPTS = 3; + + private final ChatModelActivity chatModelActivity; + private final String modelName; + private final ToolCallingManager toolCallingManager; + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + /** + * Creates a new ActivityChatModel that uses the default chat model. + * + * @param chatModelActivity the activity stub for calling the chat model + */ + public ActivityChatModel(ChatModelActivity chatModelActivity) { + this(chatModelActivity, null); + } + + /** + * Creates a new ActivityChatModel that uses a specific chat model. + * + * @param chatModelActivity the activity stub for calling the chat model + * @param modelName the name of the chat model to use, or null for default + */ + public ActivityChatModel(ChatModelActivity chatModelActivity, String modelName) { + this.chatModelActivity = chatModelActivity; + this.modelName = modelName; + this.toolCallingManager = ToolCallingManager.builder().build(); + this.toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + } + + /** + * Creates an ActivityChatModel for the default chat model. + * + *

This factory method creates the activity stub internally with default timeout and retry + * options. + * + *

Must be called from workflow code. + * + * @return an ActivityChatModel for the default chat model + */ + public static ActivityChatModel forDefault() { + return forModel(null, DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityChatModel for a specific chat model by bean name. + * + *

This factory method creates the activity stub internally with default timeout and retry + * options. + * + *

Must be called from workflow code. + * + * @param modelName the bean name of the chat model + * @return an ActivityChatModel for the specified chat model + * @throws IllegalArgumentException if no model with that name exists (at activity runtime) + */ + public static ActivityChatModel forModel(String modelName) { + return forModel(modelName, DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityChatModel for a specific chat model with custom options. + * + *

Must be called from workflow code. + * + * @param modelName the bean name of the chat model, or null for default + * @param timeout the activity start-to-close timeout + * @param maxAttempts the maximum number of retry attempts + * @return an ActivityChatModel for the specified chat model + */ + public static ActivityChatModel forModel(String modelName, Duration timeout, int maxAttempts) { + ChatModelActivity activity = + Workflow.newActivityStub( + ChatModelActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(timeout) + .setRetryOptions(RetryOptions.newBuilder().setMaximumAttempts(maxAttempts).build()) + .build()); + return new ActivityChatModel(activity, modelName); + } + + /** + * Returns the name of the chat model this instance uses. + * + * @return the model name, or null if using the default model + */ + public String getModelName() { + return modelName; + } + + /** + * Streaming is not supported through Temporal activities. + * + * @throws UnsupportedOperationException always + */ + @Override + public Flux stream(Prompt prompt) { + throw new UnsupportedOperationException("Streaming is not supported in ActivityChatModel."); + } + + @Override + public ChatOptions getDefaultOptions() { + return ToolCallingChatOptions.builder().build(); + } + + @Override + public ChatResponse call(Prompt prompt) { + return internalCall(prompt); + } + + private ChatResponse internalCall(Prompt prompt) { + // Convert prompt to activity input and call the activity + ChatModelTypes.ChatModelActivityInput input = createActivityInput(prompt); + ChatModelTypes.ChatModelActivityOutput output = chatModelActivity.callChatModel(input); + + // Convert activity output to ChatResponse + ChatResponse response = toResponse(output); + + // Handle tool calls if the model requested them + if (prompt.getOptions() != null + && toolExecutionEligibilityPredicate.isToolExecutionRequired( + prompt.getOptions(), response)) { + var toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); + + if (toolExecutionResult.returnDirect()) { + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + + // Send tool results back to the model + return internalCall( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } + + return response; + } + + private ChatModelTypes.ChatModelActivityInput createActivityInput(Prompt prompt) { + // Convert messages + List messages = + prompt.getInstructions().stream() + .flatMap(msg -> toActivityMessages(msg).stream()) + .collect(Collectors.toList()); + + // Convert options + ChatModelTypes.ModelOptions modelOptions = null; + if (prompt.getOptions() != null) { + ChatOptions opts = prompt.getOptions(); + modelOptions = + new ChatModelTypes.ModelOptions( + opts.getModel(), + opts.getFrequencyPenalty(), + opts.getMaxTokens(), + opts.getPresencePenalty(), + opts.getStopSequences(), + opts.getTemperature(), + opts.getTopK(), + opts.getTopP()); + } + + // Convert tool definitions + List tools = List.of(); + if (prompt.getOptions() instanceof ToolCallingChatOptions toolOptions) { + List toolDefinitions = toolCallingManager.resolveToolDefinitions(toolOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + tools = + toolDefinitions.stream() + .map( + td -> + new ChatModelTypes.FunctionTool( + new ChatModelTypes.FunctionTool.Function( + td.name(), td.description(), td.inputSchema()))) + .collect(Collectors.toList()); + } + } + + return new ChatModelTypes.ChatModelActivityInput(modelName, messages, modelOptions, tools); + } + + private List toActivityMessages(Message message) { + return switch (message.getMessageType()) { + case SYSTEM -> + List.of( + new ChatModelTypes.Message(message.getText(), ChatModelTypes.Message.Role.SYSTEM)); + case USER -> { + List mediaContents = null; + if (message instanceof UserMessage userMessage + && !CollectionUtils.isEmpty(userMessage.getMedia())) { + mediaContents = + userMessage.getMedia().stream() + .map(this::toMediaContent) + .collect(Collectors.toList()); + } + yield List.of( + new ChatModelTypes.Message( + message.getText(), mediaContents, ChatModelTypes.Message.Role.USER)); + } + case ASSISTANT -> { + AssistantMessage assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = + assistantMessage.getToolCalls().stream() + .map( + tc -> + new ChatModelTypes.Message.ToolCall( + tc.id(), + tc.type(), + new ChatModelTypes.Message.ChatCompletionFunction( + tc.name(), tc.arguments()))) + .collect(Collectors.toList()); + } + List mediaContents = + assistantMessage.getMedia().stream() + .map(this::toMediaContent) + .collect(Collectors.toList()); + yield List.of( + new ChatModelTypes.Message( + assistantMessage.getText(), + ChatModelTypes.Message.Role.ASSISTANT, + null, + null, + toolCalls, + mediaContents.isEmpty() ? null : mediaContents)); + } + case TOOL -> { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + yield toolMessage.getResponses().stream() + .map( + tr -> + new ChatModelTypes.Message( + tr.responseData(), + ChatModelTypes.Message.Role.TOOL, + tr.name(), + tr.id(), + null, + null)) + .collect(Collectors.toList()); + } + }; + } + + private ChatModelTypes.MediaContent toMediaContent(Media media) { + String mimeType = media.getMimeType().toString(); + if (media.getData() instanceof String uri) { + return new ChatModelTypes.MediaContent(mimeType, uri); + } else if (media.getData() instanceof byte[] data) { + return new ChatModelTypes.MediaContent(mimeType, data); + } + throw new IllegalArgumentException( + "Unsupported media data type: " + media.getData().getClass()); + } + + private ChatResponse toResponse(ChatModelTypes.ChatModelActivityOutput output) { + List generations = + output.generations().stream() + .map(gen -> new Generation(toAssistantMessage(gen.message()))) + .collect(Collectors.toList()); + + var builder = ChatResponse.builder().generations(generations); + if (output.metadata() != null) { + builder.metadata(ChatResponseMetadata.builder().model(output.metadata().model()).build()); + } + return builder.build(); + } + + private AssistantMessage toAssistantMessage(ChatModelTypes.Message message) { + List toolCalls = List.of(); + if (!CollectionUtils.isEmpty(message.toolCalls())) { + toolCalls = + message.toolCalls().stream() + .map( + tc -> + new AssistantMessage.ToolCall( + tc.id(), tc.type(), tc.function().name(), tc.function().arguments())) + .collect(Collectors.toList()); + } + + List media = List.of(); + if (!CollectionUtils.isEmpty(message.mediaContents())) { + media = message.mediaContents().stream().map(this::toMedia).collect(Collectors.toList()); + } + + return AssistantMessage.builder() + .content((String) message.rawContent()) + .properties(Map.of()) + .toolCalls(toolCalls) + .media(media) + .build(); + } + + private Media toMedia(ChatModelTypes.MediaContent mediaContent) { + MimeType mimeType = MimeType.valueOf(mediaContent.mimeType()); + if (mediaContent.uri() != null) { + try { + return new Media(mimeType, new URI(mediaContent.uri())); + } catch (URISyntaxException e) { + throw new RuntimeException("Invalid media URI: " + mediaContent.uri(), e); + } + } else if (mediaContent.data() != null) { + return new Media(mimeType, new ByteArrayResource(mediaContent.data())); + } + throw new IllegalArgumentException("Media content must have either uri or data"); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java new file mode 100644 index 000000000..f929e2cb2 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java @@ -0,0 +1,192 @@ +package io.temporal.springai.model; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; +import java.util.List; + +/** + * Serializable types for chat model activity requests and responses. + * + *

These records are designed to be serialized by Temporal's data converter and passed between + * workflows and activities. + */ +public final class ChatModelTypes { + + private ChatModelTypes() {} + + /** + * Input to the chat model activity. + * + * @param modelName the name of the chat model bean to use (null for default) + * @param messages the conversation messages + * @param modelOptions options for the chat model (temperature, max tokens, etc.) + * @param tools tool definitions the model may call + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatModelActivityInput( + @JsonProperty("model_name") String modelName, + @JsonProperty("messages") List messages, + @JsonProperty("model_options") ModelOptions modelOptions, + @JsonProperty("tools") List tools) { + /** Creates input for the default chat model. */ + public ChatModelActivityInput( + List messages, ModelOptions modelOptions, List tools) { + this(null, messages, modelOptions, tools); + } + } + + /** + * Output from the chat model activity. + * + * @param generations the generated responses + * @param metadata response metadata (model, usage, rate limits) + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatModelActivityOutput( + @JsonProperty("generations") List generations, + @JsonProperty("metadata") ChatResponseMetadata metadata) { + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Generation(@JsonProperty("message") Message message) {} + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatResponseMetadata( + @JsonProperty("model") String model, + @JsonProperty("rate_limit") RateLimit rateLimit, + @JsonProperty("usage") Usage usage) { + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record RateLimit( + @JsonProperty("request_limit") Long requestLimit, + @JsonProperty("request_remaining") Long requestRemaining, + @JsonProperty("request_reset") Duration requestReset, + @JsonProperty("token_limit") Long tokenLimit, + @JsonProperty("token_remaining") Long tokenRemaining, + @JsonProperty("token_reset") Duration tokenReset) {} + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("total_tokens") Integer totalTokens) {} + } + } + + /** + * A message in the conversation. + * + * @param rawContent the message content (typically a String) + * @param role the role of the message author + * @param name optional name for the participant + * @param toolCallId tool call ID this message responds to (for TOOL role) + * @param toolCalls tool calls requested by the model (for ASSISTANT role) + * @param mediaContents optional media attachments + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Message( + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") + @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + List toolCalls, + @JsonProperty("media") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + List mediaContents) { + public Message(Object content, Role role) { + this(content, role, null, null, null, null); + } + + public Message(Object content, List mediaContents, Role role) { + this(content, role, null, null, null, mediaContents); + } + + public enum Role { + @JsonProperty("system") + SYSTEM, + @JsonProperty("user") + USER, + @JsonProperty("assistant") + ASSISTANT, + @JsonProperty("tool") + TOOL + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ToolCall( + @JsonProperty("index") Integer index, + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function) { + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatCompletionFunction( + @JsonProperty("name") String name, @JsonProperty("arguments") String arguments) {} + } + + /** + * Media content within a message. + * + * @param mimeType the MIME type (e.g., "image/png") + * @param uri optional URI to the content + * @param data optional raw data bytes + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record MediaContent( + @JsonProperty("mime_type") String mimeType, + @JsonProperty("uri") String uri, + @JsonProperty("data") byte[] data) { + public MediaContent(String mimeType, String uri) { + this(mimeType, uri, null); + } + + public MediaContent(String mimeType, byte[] data) { + this(mimeType, null, data); + } + } + + /** A tool the model may call. */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record FunctionTool( + @JsonProperty("type") String type, @JsonProperty("function") Function function) { + public FunctionTool(Function function) { + this("function", function); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Function( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("json_schema") String jsonSchema) {} + } + + /** Model options for the chat request. */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelOptions( + @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_k") Integer topK, + @JsonProperty("top_p") Double topP) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java new file mode 100644 index 000000000..c24c4f95e --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java @@ -0,0 +1,67 @@ +package io.temporal.springai.model; + +import java.util.List; + +/** + * Serializable types for EmbeddingModel activity communication. + * + *

These records are used to pass data between workflows and the EmbeddingModelActivity, ensuring + * all data can be serialized by Temporal's data converter. + */ +public final class EmbeddingModelTypes { + + private EmbeddingModelTypes() {} + + /** + * Input for embedding a single text. + * + * @param text the text to embed + */ + public record EmbedTextInput(String text) {} + + /** + * Input for embedding multiple texts. + * + * @param texts the texts to embed + */ + public record EmbedBatchInput(List texts) {} + + /** + * Output containing a single embedding vector. + * + * @param embedding the embedding vector + */ + public record EmbedOutput(List embedding) {} + + /** + * Output containing multiple embedding vectors. + * + * @param embeddings the embedding vectors, one per input text + * @param metadata additional metadata about the embeddings + */ + public record EmbedBatchOutput(List embeddings, EmbeddingMetadata metadata) {} + + /** + * A single embedding result. + * + * @param index the index in the original input list + * @param embedding the embedding vector + */ + public record EmbeddingResult(int index, List embedding) {} + + /** + * Metadata about the embedding operation. + * + * @param model the model used for embedding + * @param totalTokens total tokens processed + * @param dimensions the dimensionality of the embeddings + */ + public record EmbeddingMetadata(String model, Integer totalTokens, Integer dimensions) {} + + /** + * Output containing embedding model dimensions. + * + * @param dimensions the number of dimensions in the embedding vectors + */ + public record DimensionsOutput(int dimensions) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java new file mode 100644 index 000000000..0eadd932e --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java @@ -0,0 +1,82 @@ +package io.temporal.springai.model; + +import java.util.List; +import java.util.Map; + +/** + * Serializable types for VectorStore activity communication. + * + *

These records are used to pass data between workflows and the VectorStoreActivity, ensuring + * all data can be serialized by Temporal's data converter. + */ +public final class VectorStoreTypes { + + private VectorStoreTypes() {} + + /** + * Serializable representation of a document for vector storage. + * + * @param id unique identifier for the document + * @param text the text content of the document + * @param metadata additional metadata associated with the document + * @param embedding pre-computed embedding vector (optional, may be computed by the store) + */ + public record Document( + String id, String text, Map metadata, List embedding) { + public Document(String id, String text, Map metadata) { + this(id, text, metadata, null); + } + + public Document(String id, String text) { + this(id, text, Map.of(), null); + } + } + + /** + * Input for adding documents to the vector store. + * + * @param documents the documents to add + */ + public record AddDocumentsInput(List documents) {} + + /** + * Input for deleting documents by ID. + * + * @param ids the document IDs to delete + */ + public record DeleteByIdsInput(List ids) {} + + /** + * Input for similarity search. + * + * @param query the search query text + * @param topK maximum number of results to return + * @param similarityThreshold minimum similarity score (0.0 to 1.0) + * @param filterExpression optional filter expression for metadata filtering + */ + public record SearchInput( + String query, int topK, Double similarityThreshold, String filterExpression) { + public SearchInput(String query, int topK) { + this(query, topK, null, null); + } + + public SearchInput(String query) { + this(query, 4, null, null); + } + } + + /** + * Output from similarity search. + * + * @param documents the matching documents with their similarity scores + */ + public record SearchOutput(List documents) {} + + /** + * A single search result with similarity score. + * + * @param document the matched document + * @param score the similarity score + */ + public record SearchResult(Document document, Double score) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java new file mode 100644 index 000000000..d2993b36f --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java @@ -0,0 +1,34 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.EmbeddingModelActivityImpl; +import io.temporal.worker.Worker; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.embedding.EmbeddingModel; + +/** + * Temporal plugin that registers {@link io.temporal.springai.activity.EmbeddingModelActivity} with + * workers. + * + *

This plugin is conditionally created by auto-configuration when Spring AI's {@link + * EmbeddingModel} is on the classpath and an EmbeddingModel bean is available. + */ +public class EmbeddingModelPlugin extends SimplePlugin { + + private static final Logger log = LoggerFactory.getLogger(EmbeddingModelPlugin.class); + + private final EmbeddingModel embeddingModel; + + public EmbeddingModelPlugin(EmbeddingModel embeddingModel) { + super("io.temporal.spring-ai-embedding"); + this.embeddingModel = embeddingModel; + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + worker.registerActivitiesImplementations(new EmbeddingModelActivityImpl(embeddingModel)); + log.info("Registered EmbeddingModelActivity for task queue {}", taskQueue); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java new file mode 100644 index 000000000..2f3635cfd --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java @@ -0,0 +1,95 @@ +package io.temporal.springai.plugin; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.temporal.common.SimplePlugin; +import io.temporal.springai.mcp.McpClientActivityImpl; +import io.temporal.worker.Worker; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; + +/** + * Temporal plugin that registers {@link io.temporal.springai.mcp.McpClientActivity} with workers. + * + *

This plugin is conditionally created by auto-configuration when MCP classes are on the + * classpath. MCP clients may be created late by Spring AI's auto-configuration, so this plugin + * supports deferred registration via {@link SmartInitializingSingleton}. + */ +public class McpPlugin extends SimplePlugin + implements ApplicationContextAware, SmartInitializingSingleton { + + private static final Logger log = LoggerFactory.getLogger(McpPlugin.class); + + private List mcpClients = List.of(); + private ApplicationContext applicationContext; + private final List pendingWorkers = new ArrayList<>(); + + public McpPlugin() { + super("io.temporal.spring-ai-mcp"); + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + } + + @SuppressWarnings("unchecked") + private List getMcpClients() { + if (!mcpClients.isEmpty()) { + return mcpClients; + } + + if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) { + try { + Object bean = applicationContext.getBean("mcpSyncClients"); + if (bean instanceof List clientList && !clientList.isEmpty()) { + mcpClients = (List) clientList; + log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size()); + } + } catch (Exception e) { + log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage()); + } + } + + return mcpClients; + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + List clients = getMcpClients(); + if (!clients.isEmpty()) { + worker.registerActivitiesImplementations(new McpClientActivityImpl(clients)); + log.info( + "Registered McpClientActivity ({} clients) for task queue {}", clients.size(), taskQueue); + } else { + pendingWorkers.add(worker); + log.debug("MCP clients not yet available; will attempt registration after initialization"); + } + } + + @Override + public void afterSingletonsInstantiated() { + if (pendingWorkers.isEmpty()) { + return; + } + + List clients = getMcpClients(); + if (clients.isEmpty()) { + log.debug("No MCP clients found after all beans initialized"); + pendingWorkers.clear(); + return; + } + + for (Worker worker : pendingWorkers) { + worker.registerActivitiesImplementations(new McpClientActivityImpl(clients)); + log.info("Registered deferred McpClientActivity ({} clients)", clients.size()); + } + pendingWorkers.clear(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java new file mode 100644 index 000000000..552fa0ba4 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java @@ -0,0 +1,160 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.tool.ExecuteToolLocalActivityImpl; +import io.temporal.worker.Worker; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.lang.Nullable; + +/** + * Core Temporal plugin that registers {@link io.temporal.springai.activity.ChatModelActivity} and + * {@link io.temporal.springai.tool.ExecuteToolLocalActivity} with Temporal workers. + * + *

This plugin handles the required ChatModel integration. Optional integrations (VectorStore, + * EmbeddingModel, MCP) are handled by separate plugins that are conditionally created by + * auto-configuration: + * + *

+ * + *

In Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     ActivityChatModel chatModel = ActivityChatModel.forDefault();
+ *     this.chatClient = TemporalChatClient.builder(chatModel).build();
+ * }
+ * }
+ * + * @see io.temporal.springai.activity.ChatModelActivity + * @see io.temporal.springai.model.ActivityChatModel + */ +public class SpringAiPlugin extends SimplePlugin { + + private static final Logger log = LoggerFactory.getLogger(SpringAiPlugin.class); + + /** The name used for the default chat model when none is specified. */ + public static final String DEFAULT_MODEL_NAME = "default"; + + private final Map chatModels; + private final String defaultModelName; + + /** + * Creates a new SpringAiPlugin with the given ChatModel. + * + * @param chatModel the Spring AI chat model to wrap as an activity + */ + public SpringAiPlugin(ChatModel chatModel) { + super("io.temporal.spring-ai"); + this.chatModels = Map.of(DEFAULT_MODEL_NAME, chatModel); + this.defaultModelName = DEFAULT_MODEL_NAME; + } + + /** + * Creates a new SpringAiPlugin with multiple ChatModels. + * + * @param chatModels map of bean names to ChatModel instances + * @param primaryChatModel the primary chat model (used to determine default), or null + */ + public SpringAiPlugin(Map chatModels, @Nullable ChatModel primaryChatModel) { + super("io.temporal.spring-ai"); + + if (chatModels == null || chatModels.isEmpty()) { + throw new IllegalArgumentException("At least one ChatModel bean is required"); + } + + this.chatModels = new LinkedHashMap<>(chatModels); + + if (primaryChatModel != null) { + String primaryName = + chatModels.entrySet().stream() + .filter(e -> e.getValue() == primaryChatModel) + .map(Map.Entry::getKey) + .findFirst() + .orElse(chatModels.keySet().iterator().next()); + this.defaultModelName = primaryName; + } else { + this.defaultModelName = chatModels.keySet().iterator().next(); + } + + if (chatModels.size() > 1) { + log.info( + "Registered {} chat models: {} (default: {})", + chatModels.size(), + chatModels.keySet(), + defaultModelName); + } + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + // Register the ChatModelActivity implementation with all chat models + ChatModelActivityImpl chatModelActivityImpl = + new ChatModelActivityImpl(chatModels, defaultModelName); + worker.registerActivitiesImplementations(chatModelActivityImpl); + + // Register ExecuteToolLocalActivity for LocalActivityToolCallbackWrapper support + ExecuteToolLocalActivityImpl executeToolLocalActivity = new ExecuteToolLocalActivityImpl(); + worker.registerActivitiesImplementations(executeToolLocalActivity); + + String modelInfo = chatModels.size() > 1 ? " (" + chatModels.size() + " models)" : ""; + log.info( + "Registered ChatModelActivity{} and ExecuteToolLocalActivity for task queue {}", + modelInfo, + taskQueue); + } + + /** + * Returns the default ChatModel wrapped by this plugin. + * + * @return the default chat model + */ + public ChatModel getChatModel() { + return chatModels.get(defaultModelName); + } + + /** + * Returns a specific ChatModel by bean name. + * + * @param modelName the bean name of the chat model + * @return the chat model + * @throws IllegalArgumentException if no model with that name exists + */ + public ChatModel getChatModel(String modelName) { + ChatModel model = chatModels.get(modelName); + if (model == null) { + throw new IllegalArgumentException( + "No chat model with name '" + modelName + "'. Available models: " + chatModels.keySet()); + } + return model; + } + + /** + * Returns all ChatModels wrapped by this plugin, keyed by bean name. + * + * @return unmodifiable map of chat models + */ + public Map getChatModels() { + return Collections.unmodifiableMap(chatModels); + } + + /** + * Returns the name of the default chat model. + * + * @return the default model name + */ + public String getDefaultModelName() { + return defaultModelName; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java new file mode 100644 index 000000000..e454e9d60 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java @@ -0,0 +1,34 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.VectorStoreActivityImpl; +import io.temporal.worker.Worker; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.vectorstore.VectorStore; + +/** + * Temporal plugin that registers {@link io.temporal.springai.activity.VectorStoreActivity} with + * workers. + * + *

This plugin is conditionally created by auto-configuration when Spring AI's {@link + * VectorStore} is on the classpath and a VectorStore bean is available. + */ +public class VectorStorePlugin extends SimplePlugin { + + private static final Logger log = LoggerFactory.getLogger(VectorStorePlugin.class); + + private final VectorStore vectorStore; + + public VectorStorePlugin(VectorStore vectorStore) { + super("io.temporal.spring-ai-vectorstore"); + this.vectorStore = vectorStore; + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + worker.registerActivitiesImplementations(new VectorStoreActivityImpl(vectorStore)); + log.info("Registered VectorStoreActivity for task queue {}", taskQueue); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java new file mode 100644 index 000000000..6f2dfe21b --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java @@ -0,0 +1,61 @@ +package io.temporal.springai.tool; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that indicates the underlying tool is backed by a Temporal + * activity stub. + * + *

This wrapper delegates all operations to the underlying callback while serving as a marker to + * indicate that tool invocations will execute as Temporal activities, providing durability, + * automatic retries, and timeout handling. + * + *

This class is primarily used internally by {@link ActivityToolUtil} when converting activity + * stubs to tool callbacks. Users typically don't need to create instances directly. + * + * @see ActivityToolUtil#fromActivityStub(Object...) + */ +public class ActivityToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new ActivityToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public ActivityToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return delegate.call(toolInput); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return delegate.call(toolInput, toolContext); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java new file mode 100644 index 000000000..e168bcd86 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java @@ -0,0 +1,135 @@ +package io.temporal.springai.tool; + +import io.temporal.activity.ActivityInterface; +import io.temporal.common.metadata.POJOActivityInterfaceMetadata; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.ai.tool.support.ToolDefinitions; +import org.springframework.ai.tool.support.ToolUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Utility class for extracting tool definitions from Temporal activity interfaces. + * + *

This class bridges Spring AI's {@link Tool} annotation with Temporal's {@link + * ActivityInterface} annotation, allowing activity methods to be used as AI tools within workflows. + * + *

Example: + * + *

{@code
+ * @ActivityInterface
+ * public interface WeatherActivity {
+ *     @Tool(description = "Get the current weather for a city")
+ *     String getWeather(String city);
+ * }
+ *
+ * // In workflow:
+ * WeatherActivity weatherTool = Workflow.newActivityStub(WeatherActivity.class, opts);
+ * ToolCallback[] callbacks = ActivityToolUtil.fromActivityStub(weatherTool);
+ * }
+ */ +public final class ActivityToolUtil { + + private ActivityToolUtil() { + // Utility class + } + + /** + * Extracts {@link Tool} annotations from the given activity stub object. + * + *

Scans all interfaces implemented by the stub that are annotated with {@link + * ActivityInterface}, and returns a map of activity type names to their {@link Tool} annotations. + * + * @param activityStub the activity stub to extract annotations from + * @return a map of activity type names to Tool annotations + */ + public static Map getToolAnnotations(Object activityStub) { + return Stream.of(activityStub.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .map(POJOActivityInterfaceMetadata::newInstance) + .flatMap(metadata -> metadata.getMethodsMetadata().stream()) + .filter(methodMetadata -> methodMetadata.getMethod().isAnnotationPresent(Tool.class)) + .collect( + Collectors.toMap( + methodMetadata -> methodMetadata.getActivityTypeName(), + methodMetadata -> methodMetadata.getMethod().getAnnotation(Tool.class))); + } + + /** + * Creates {@link ToolCallback} instances from activity stub objects. + * + *

For each activity stub, this method: + * + *

    + *
  1. Finds all interfaces annotated with {@link ActivityInterface} + *
  2. Extracts methods annotated with {@link Tool} + *
  3. Creates {@link MethodToolCallback} instances for each method + *
  4. Wraps them in {@link ActivityToolCallback} to mark their origin + *
+ * + *

Methods that return functional types (Function, Supplier, Consumer) are excluded as they are + * not supported as tools. + * + * @param toolObjects the activity stub objects to convert + * @return an array of ToolCallback instances + */ + public static ToolCallback[] fromActivityStub(Object... toolObjects) { + List callbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + Stream.of(toolObject.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .filter(method -> method.isAnnotationPresent(Tool.class)) + .filter(method -> !isFunctionalType(method)) + .map(method -> createToolCallback(method, toolObject)) + .map(ActivityToolCallback::new) + .forEach(callbacks::add); + } + + return callbacks.toArray(new ToolCallback[0]); + } + + /** + * Checks if any interfaces implemented by the object are annotated with {@link ActivityInterface} + * and contain methods annotated with {@link Tool}. + * + * @param object the object to check + * @return true if the object has tool-annotated activity methods + */ + public static boolean hasToolAnnotations(Object object) { + return Stream.of(object.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .anyMatch(method -> method.isAnnotationPresent(Tool.class)); + } + + private static MethodToolCallback createToolCallback(Method method, Object toolObject) { + return MethodToolCallback.builder() + .toolDefinition(ToolDefinitions.from(method)) + .toolMetadata(ToolMetadata.from(method)) + .toolMethod(method) + .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(method)) + .build(); + } + + private static boolean isFunctionalType(Method method) { + Class returnType = method.getReturnType(); + return ClassUtils.isAssignable(returnType, Function.class) + || ClassUtils.isAssignable(returnType, Supplier.class) + || ClassUtils.isAssignable(returnType, Consumer.class); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/DeterministicTool.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/DeterministicTool.java new file mode 100644 index 000000000..04a52c88c --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/DeterministicTool.java @@ -0,0 +1,49 @@ +package io.temporal.springai.tool; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a tool class as deterministic, meaning it is safe to execute directly in a Temporal + * workflow without wrapping in an activity or side effect. + * + *

Deterministic tools must: + * + *

    + *
  • Always produce the same output for the same input + *
  • Have no side effects (no I/O, no random numbers, no system time) + *
  • Not call any non-deterministic APIs + *
+ * + *

Example usage: + * + *

{@code
+ * @DeterministicTool
+ * public class MathTools {
+ *     @Tool(description = "Add two numbers")
+ *     public int add(int a, int b) {
+ *         return a + b;
+ *     }
+ *
+ *     @Tool(description = "Multiply two numbers")
+ *     public int multiply(int a, int b) {
+ *         return a * b;
+ *     }
+ * }
+ *
+ * // In workflow:
+ * this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *         .defaultTools(new MathTools())  // Safe to use directly
+ *         .build();
+ * }
+ * + *

Warning: Using this annotation on a class that performs non-deterministic operations + * will break workflow replay. Only use this for truly deterministic computations. + * + * @see org.springframework.ai.tool.annotation.Tool + */ +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface DeterministicTool {} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivity.java new file mode 100644 index 000000000..3fef94e4e --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivity.java @@ -0,0 +1,29 @@ +package io.temporal.springai.tool; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; + +/** + * Activity interface for executing tool callbacks via local activities. + * + *

This activity is used internally by {@link LocalActivityToolCallbackWrapper} to execute + * arbitrary {@link org.springframework.ai.tool.ToolCallback}s in a deterministic manner. Since + * callbacks cannot be serialized, they are stored in a static map and referenced by a unique ID. + * + *

This activity is automatically registered by the Spring AI plugin. + * + * @see LocalActivityToolCallbackWrapper + */ +@ActivityInterface +public interface ExecuteToolLocalActivity { + + /** + * Executes a tool callback identified by the given ID. + * + * @param toolCallbackId the unique ID of the tool callback in the static map + * @param toolInput the JSON input for the tool + * @return the tool's output as a string + */ + @ActivityMethod + String call(String toolCallbackId, String toolInput); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivityImpl.java new file mode 100644 index 000000000..5f9e76b8c --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ExecuteToolLocalActivityImpl.java @@ -0,0 +1,27 @@ +package io.temporal.springai.tool; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.stereotype.Component; + +/** + * Implementation of {@link ExecuteToolLocalActivity} that executes tool callbacks stored in the + * {@link LocalActivityToolCallbackWrapper#getCallback(String)} registry. + * + *

This activity is automatically registered by the Spring AI plugin. + */ +@Component +public class ExecuteToolLocalActivityImpl implements ExecuteToolLocalActivity { + + @Override + public String call(String toolCallbackId, String toolInput) { + ToolCallback callback = LocalActivityToolCallbackWrapper.getCallback(toolCallbackId); + if (callback == null) { + throw new IllegalStateException( + "Tool callback not found for ID: " + + toolCallbackId + + ". " + + "This may indicate the callback was not properly registered or was already cleaned up."); + } + return callback.call(toolInput); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/LocalActivityToolCallbackWrapper.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/LocalActivityToolCallbackWrapper.java new file mode 100644 index 000000000..6fdca60bd --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/LocalActivityToolCallbackWrapper.java @@ -0,0 +1,132 @@ +package io.temporal.springai.tool; + +import io.temporal.activity.LocalActivityOptions; +import io.temporal.workflow.Workflow; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper that executes a {@link ToolCallback} via a local activity for deterministic replay. + * + *

This wrapper is used to make arbitrary (potentially non-deterministic) tool callbacks safe for + * workflow execution. The actual callback execution happens in a local activity, ensuring the + * result is recorded in workflow history. + * + *

Since {@link ToolCallback}s cannot be serialized, they are stored in a static map and + * referenced by a unique ID. The ID is passed to the local activity, which looks up the callback + * and executes it. + * + *

Memory Management: Callbacks are automatically removed from the map after execution to + * prevent memory leaks. However, if a workflow is evicted from the worker's cache mid-execution + * (between registering a callback and the {@code finally} block that removes it), the callback + * reference will leak until the worker is restarted. This is bounded by the number of concurrent + * in-flight tool calls and is unlikely to be a practical issue, but callers should be aware that + * the registry size ({@link #getRegisteredCallbackCount()}) may drift above zero under heavy + * eviction pressure. + * + *

This class is primarily used by {@code SandboxingAdvisor} to wrap unsafe tools. + * + * @see ExecuteToolLocalActivity + */ +public class LocalActivityToolCallbackWrapper implements ToolCallback { + + private static final Map CALLBACK_REGISTRY = new ConcurrentHashMap<>(); + + private final ToolCallback delegate; + private final ExecuteToolLocalActivity stub; + private final LocalActivityOptions options; + + /** + * Creates a new wrapper with default local activity options. + * + *

Default options: + * + *

    + *
  • Start-to-close timeout: 30 seconds + *
  • Arguments not included in marker (for smaller history) + *
+ * + * @param delegate the tool callback to wrap + */ + public LocalActivityToolCallbackWrapper(ToolCallback delegate) { + this( + delegate, + LocalActivityOptions.newBuilder() + .setStartToCloseTimeout(Duration.ofSeconds(30)) + .setDoNotIncludeArgumentsIntoMarker(true) + .build()); + } + + /** + * Creates a new wrapper with custom local activity options. + * + * @param delegate the tool callback to wrap + * @param options the local activity options to use + */ + public LocalActivityToolCallbackWrapper(ToolCallback delegate, LocalActivityOptions options) { + this.delegate = delegate; + this.options = options; + this.stub = Workflow.newLocalActivityStub(ExecuteToolLocalActivity.class, options); + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + String callbackId = Workflow.randomUUID().toString(); + try { + CALLBACK_REGISTRY.put(callbackId, delegate); + return stub.call(callbackId, toolInput); + } finally { + CALLBACK_REGISTRY.remove(callbackId); + } + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + // Note: ToolContext cannot be passed through the activity, so we ignore it here. + // If context is needed, consider using activity parameters or workflow state. + return call(toolInput); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } + + /** + * Looks up a callback by its ID. Used by {@link ExecuteToolLocalActivityImpl}. + * + * @param callbackId the callback ID + * @return the callback, or null if not found + */ + public static ToolCallback getCallback(String callbackId) { + return CALLBACK_REGISTRY.get(callbackId); + } + + /** + * Returns the number of currently registered callbacks. Useful for testing and monitoring. + * + * @return the number of registered callbacks + */ + public static int getRegisteredCallbackCount() { + return CALLBACK_REGISTRY.size(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java new file mode 100644 index 000000000..a010dcd2d --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java @@ -0,0 +1,61 @@ +package io.temporal.springai.tool; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that indicates the underlying tool is backed by a Temporal + * Nexus service stub. + * + *

This wrapper delegates all operations to the underlying callback while serving as a marker to + * indicate that tool invocations will execute as Nexus operations, providing cross-namespace + * communication and durability. + * + *

This class is primarily used internally by {@link NexusToolUtil} when converting Nexus service + * stubs to tool callbacks. Users typically don't need to create instances directly. + * + * @see NexusToolUtil#fromNexusServiceStub(Object...) + */ +public class NexusToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new NexusToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public NexusToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return delegate.call(toolInput); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return delegate.call(toolInput, toolContext); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java new file mode 100644 index 000000000..b2aa4a6a2 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java @@ -0,0 +1,111 @@ +package io.temporal.springai.tool; + +import io.nexusrpc.Service; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.ai.tool.support.ToolDefinitions; +import org.springframework.ai.tool.support.ToolUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Utility class for extracting tool definitions from Temporal Nexus service interfaces. + * + *

This class bridges Spring AI's {@link Tool} annotation with Nexus RPC's {@link Service} + * annotation, allowing Nexus service methods to be used as AI tools within workflows. + * + *

Example: + * + *

{@code
+ * @Service
+ * public interface WeatherService {
+ *     @Tool(description = "Get the current weather for a city")
+ *     String getWeather(String city);
+ * }
+ *
+ * // In workflow:
+ * WeatherService weatherTool = Workflow.newNexusServiceStub(WeatherService.class, opts);
+ * ToolCallback[] callbacks = NexusToolUtil.fromNexusServiceStub(weatherTool);
+ * }
+ */ +public final class NexusToolUtil { + + private NexusToolUtil() { + // Utility class + } + + /** + * Creates {@link ToolCallback} instances from Nexus service stub objects. + * + *

For each Nexus service stub, this method: + * + *

    + *
  1. Finds all interfaces annotated with {@link Service} + *
  2. Extracts methods annotated with {@link Tool} + *
  3. Creates {@link MethodToolCallback} instances for each method + *
  4. Wraps them in {@link NexusToolCallback} to mark their origin + *
+ * + *

Methods that return functional types (Function, Supplier, Consumer) are excluded as they are + * not supported as tools. + * + * @param toolObjects the Nexus service stub objects to convert + * @return an array of ToolCallback instances + */ + public static ToolCallback[] fromNexusServiceStub(Object... toolObjects) { + List callbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + Stream.of(toolObject.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(Service.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .filter(method -> method.isAnnotationPresent(Tool.class)) + .filter(method -> !isFunctionalType(method)) + .map(method -> createToolCallback(method, toolObject)) + .map(NexusToolCallback::new) + .forEach(callbacks::add); + } + + return callbacks.toArray(new ToolCallback[0]); + } + + /** + * Checks if any interfaces implemented by the object are annotated with {@link Service} and + * contain methods annotated with {@link Tool}. + * + * @param object the object to check + * @return true if the object has tool-annotated Nexus service methods + */ + public static boolean hasToolAnnotations(Object object) { + return Stream.of(object.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(Service.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .anyMatch(method -> method.isAnnotationPresent(Tool.class)); + } + + private static MethodToolCallback createToolCallback(Method method, Object toolObject) { + return MethodToolCallback.builder() + .toolDefinition(ToolDefinitions.from(method)) + .toolMetadata(ToolMetadata.from(method)) + .toolMethod(method) + .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(method)) + .build(); + } + + private static boolean isFunctionalType(Method method) { + Class returnType = method.getReturnType(); + return ClassUtils.isAssignable(returnType, Function.class) + || ClassUtils.isAssignable(returnType, Supplier.class) + || ClassUtils.isAssignable(returnType, Consumer.class); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java new file mode 100644 index 000000000..f0ae6c5a0 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java @@ -0,0 +1,59 @@ +package io.temporal.springai.tool; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a tool class as a side-effect tool, meaning its methods will be wrapped in {@code + * Workflow.sideEffect()} for safe execution in a Temporal workflow. + * + *

Side-effect tools are useful for operations that: + * + *

    + *
  • Are non-deterministic (e.g., reading current time, generating UUIDs) + *
  • Are cheap and don't need the full durability of an activity + *
  • Don't have external side effects that need to be retried on failure + *
+ * + *

The result of a side-effect tool is recorded in the workflow history, so on replay the same + * result is returned without re-executing the tool. + * + *

Example usage: + * + *

{@code
+ * @SideEffectTool
+ * public class TimestampTools {
+ *     @Tool(description = "Get the current timestamp")
+ *     public long currentTimeMillis() {
+ *         return System.currentTimeMillis();  // Non-deterministic, but recorded
+ *     }
+ *
+ *     @Tool(description = "Generate a random UUID")
+ *     public String randomUuid() {
+ *         return UUID.randomUUID().toString();
+ *     }
+ * }
+ *
+ * // In workflow:
+ * this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *         .defaultTools(new TimestampTools())  // Wrapped in sideEffect()
+ *         .build();
+ * }
+ * + *

When to use which annotation: + * + *

    + *
  • {@link DeterministicTool} - Pure functions with no side effects (math, string manipulation) + *
  • {@code @SideEffectTool} - Non-deterministic but cheap operations (timestamps, random + * values) + *
  • Activity stub - Operations with external side effects or that need retry/durability + *
+ * + * @see DeterministicTool + * @see io.temporal.workflow.Workflow#sideEffect(Class, io.temporal.workflow.Functions.Func) + */ +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface SideEffectTool {} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java new file mode 100644 index 000000000..561b5b057 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java @@ -0,0 +1,66 @@ +package io.temporal.springai.tool; + +import io.temporal.workflow.Workflow; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that executes the tool within {@code Workflow.sideEffect()}, + * making it safe for non-deterministic operations. + * + *

When a tool is wrapped in this callback: + * + *

    + *
  • The first execution records the result in workflow history + *
  • On replay, the recorded result is returned without re-execution + *
  • This ensures deterministic replay even for non-deterministic tools + *
+ * + *

This is used internally when processing tools marked with {@link SideEffectTool}. + * + * @see SideEffectTool + * @see io.temporal.workflow.Workflow#sideEffect(Class, io.temporal.workflow.Functions.Func) + */ +public class SideEffectToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new SideEffectToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public SideEffectToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return Workflow.sideEffect(String.class, () -> delegate.call(toolInput)); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return Workflow.sideEffect(String.class, () -> delegate.call(toolInput, toolContext)); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java new file mode 100644 index 000000000..573cccb3b --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java @@ -0,0 +1,87 @@ +package io.temporal.springai.util; + +import io.temporal.internal.sync.ActivityInvocationHandler; +import io.temporal.internal.sync.LocalActivityInvocationHandler; +import io.temporal.internal.sync.NexusServiceInvocationHandler; +import java.lang.reflect.Proxy; + +/** + * Utility class for detecting Temporal stub types. + * + *

Temporal creates dynamic proxies for various stub types (activities, local activities, child + * workflows, Nexus services). This utility provides methods to detect what type of stub an object + * is, which is useful for determining how to handle tool calls. + * + *

This class uses direct {@code instanceof} checks against the SDK's internal invocation handler + * classes. Since the {@code temporal-spring-ai} module lives in the SDK repo, this coupling is + * intentional and will be caught by compilation if the handler classes are renamed or moved. + */ +public final class TemporalStubUtil { + + private TemporalStubUtil() {} + + /** + * Checks if the given object is an activity stub created by {@code Workflow.newActivityStub()}. + * + * @param object the object to check + * @return true if the object is an activity stub (but not a local activity stub) + */ + public static boolean isActivityStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof ActivityInvocationHandler; + } + + /** + * Checks if the given object is a local activity stub created by {@code + * Workflow.newLocalActivityStub()}. + * + * @param object the object to check + * @return true if the object is a local activity stub + */ + public static boolean isLocalActivityStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof LocalActivityInvocationHandler; + } + + /** + * Checks if the given object is a child workflow stub created by {@code + * Workflow.newChildWorkflowStub()}. + * + *

Note: {@code ChildWorkflowInvocationHandler} is package-private in the SDK, so we check via + * the class name. This is safe because the module lives in the SDK repo — any rename would break + * compilation of this module's tests. + * + * @param object the object to check + * @return true if the object is a child workflow stub + */ + public static boolean isChildWorkflowStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + // ChildWorkflowInvocationHandler is package-private, so we use class name check. + // This is the only handler where instanceof is not possible. + return handler.getClass().getName().endsWith("ChildWorkflowInvocationHandler"); + } + + /** + * Checks if the given object is a Nexus service stub created by {@code + * Workflow.newNexusServiceStub()}. + * + * @param object the object to check + * @return true if the object is a Nexus service stub + */ + public static boolean isNexusServiceStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof NexusServiceInvocationHandler; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java new file mode 100644 index 000000000..b725af6bf --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java @@ -0,0 +1,159 @@ +package io.temporal.springai.util; + +import io.temporal.springai.tool.ActivityToolCallback; +import io.temporal.springai.tool.ActivityToolUtil; +import io.temporal.springai.tool.DeterministicTool; +import io.temporal.springai.tool.NexusToolCallback; +import io.temporal.springai.tool.NexusToolUtil; +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.springai.tool.SideEffectToolCallback; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.springframework.ai.support.ToolCallbacks; +import org.springframework.ai.tool.ToolCallback; + +/** + * Utility class for converting tool objects to appropriate {@link ToolCallback} instances based on + * their type. + * + *

This class detects the type of each tool object and converts it appropriately: + * + *

    + *
  • Activity stubs - Converted to {@link ActivityToolCallback} for durable execution + *
  • Local activity stubs - Converted to tool callbacks for fast, local execution + *
  • Nexus service stubs - Converted to {@link NexusToolCallback} for cross-namespace + * operations + *
  • {@link DeterministicTool} classes - Converted to standard tool callbacks for direct + * execution + *
  • {@link SideEffectTool} classes - Wrapped in {@code Workflow.sideEffect()} for + * recorded execution + *
  • Child workflow stubs - Not supported + *
+ * + *

Example usage: + * + *

{@code
+ * WeatherActivity weatherTool = Workflow.newActivityStub(WeatherActivity.class, opts);
+ * MathTools mathTools = new MathTools(); // @DeterministicTool annotated
+ * TimestampTools timestamps = new TimestampTools(); // @SideEffectTool annotated
+ *
+ * List callbacks = TemporalToolUtil.convertTools(weatherTool, mathTools, timestamps);
+ * }
+ * + * @see DeterministicTool + * @see SideEffectTool + * @see ActivityToolCallback + * @see SideEffectToolCallback + */ +public final class TemporalToolUtil { + + private TemporalToolUtil() { + // Utility class + } + + /** + * Converts an array of tool objects to appropriate {@link ToolCallback} instances. + * + *

Each tool object is inspected to determine its type: + * + *

    + *
  • Activity stubs are converted using {@link ActivityToolUtil#fromActivityStub(Object...)} + *
  • Local activity stubs are converted the same way (both execute as activities) + *
  • Nexus service stubs are converted using {@link + * NexusToolUtil#fromNexusServiceStub(Object...)} + *
  • Child workflow stubs throw {@link UnsupportedOperationException} + *
  • Classes annotated with {@link DeterministicTool} are converted using Spring AI's standard + * {@code ToolCallbacks.from(Object)} + *
  • Classes annotated with {@link SideEffectTool} are wrapped in {@code + * Workflow.sideEffect()} + *
  • Other objects throw {@link IllegalArgumentException} + *
+ * + *

For tools that aren't properly annotated, use {@code defaultToolCallbacks()} with {@link + * io.temporal.springai.advisor.SandboxingAdvisor} to wrap them safely at call time. + * + * @param toolObjects the tool objects to convert + * @return a list of ToolCallback instances + * @throws IllegalArgumentException if a tool object is not a recognized type + * @throws UnsupportedOperationException if a tool type is not supported (child workflow) + */ + public static List convertTools(Object... toolObjects) { + List toolCallbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + if (toolObject == null) { + throw new IllegalArgumentException("Tool object cannot be null"); + } + + if (TemporalStubUtil.isActivityStub(toolObject)) { + // Activity stub - execute as durable activity + ToolCallback[] callbacks = ActivityToolUtil.fromActivityStub(toolObject); + toolCallbacks.addAll(List.of(callbacks)); + + } else if (TemporalStubUtil.isLocalActivityStub(toolObject)) { + // Local activity stub - execute as local activity (faster, less durable) + ToolCallback[] callbacks = ActivityToolUtil.fromActivityStub(toolObject); + toolCallbacks.addAll(List.of(callbacks)); + + } else if (TemporalStubUtil.isNexusServiceStub(toolObject)) { + // Nexus service stub - execute as Nexus operation + ToolCallback[] callbacks = NexusToolUtil.fromNexusServiceStub(toolObject); + toolCallbacks.addAll(List.of(callbacks)); + + } else if (TemporalStubUtil.isChildWorkflowStub(toolObject)) { + // Child workflow stubs are not supported + throw new UnsupportedOperationException( + "Child workflow stubs are not supported as tools. " + + "Consider using an activity to wrap the child workflow call."); + + } else if (toolObject.getClass().isAnnotationPresent(DeterministicTool.class)) { + // Deterministic tool - safe to execute directly in workflow + toolCallbacks.addAll(List.of(ToolCallbacks.from(toolObject))); + + } else if (toolObject.getClass().isAnnotationPresent(SideEffectTool.class)) { + // Side-effect tool - wrap in Workflow.sideEffect() for recorded execution + ToolCallback[] rawCallbacks = ToolCallbacks.from(toolObject); + List wrappedCallbacks = + Arrays.stream(rawCallbacks) + .map(SideEffectToolCallback::new) + .map(tc -> (ToolCallback) tc) + .toList(); + toolCallbacks.addAll(wrappedCallbacks); + + } else { + // Unknown type - reject to prevent non-deterministic behavior + throw new IllegalArgumentException( + "Tool object of type '" + + toolObject.getClass().getName() + + "' is not a " + + "recognized Temporal primitive (activity stub, local activity stub) or " + + "a class annotated with @DeterministicTool or @SideEffectTool. " + + "To use a plain object as a tool, either: " + + "(1) annotate its class with @DeterministicTool if it's truly deterministic, " + + "(2) annotate with @SideEffectTool if it's non-deterministic but cheap, " + + "(3) wrap it in an activity for durable execution, or " + + "(4) use defaultToolCallbacks() with SandboxingAdvisor to wrap unsafe tools."); + } + } + + return toolCallbacks; + } + + /** + * Checks if the given object is a recognized tool type that can be converted. + * + * @param toolObject the object to check + * @return true if the object can be converted to tool callbacks + */ + public static boolean isRecognizedToolType(Object toolObject) { + if (toolObject == null) { + return false; + } + return TemporalStubUtil.isActivityStub(toolObject) + || TemporalStubUtil.isLocalActivityStub(toolObject) + || TemporalStubUtil.isNexusServiceStub(toolObject) + || toolObject.getClass().isAnnotationPresent(DeterministicTool.class) + || toolObject.getClass().isAnnotationPresent(SideEffectTool.class); + } +} diff --git a/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 000000000..7f86436f4 --- /dev/null +++ b/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,4 @@ +io.temporal.springai.autoconfigure.SpringAiTemporalAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiVectorStoreAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiEmbeddingAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiMcpAutoConfiguration diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java new file mode 100644 index 000000000..562058ad0 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java @@ -0,0 +1,167 @@ +package io.temporal.springai; + +import static org.junit.jupiter.api.Assertions.*; + +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowOptions; +import io.temporal.client.WorkflowStub; +import io.temporal.common.WorkflowExecutionHistory; +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.chat.TemporalChatClient; +import io.temporal.springai.model.ActivityChatModel; +import io.temporal.springai.tool.DeterministicTool; +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.testing.WorkflowReplayer; +import io.temporal.worker.Worker; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.annotation.Tool; + +/** + * Verifies that workflows using ActivityChatModel with tools are deterministic by running them to + * completion and then replaying from the captured history. + */ +class WorkflowDeterminismTest { + + private static final String TASK_QUEUE = "test-spring-ai"; + + private TestWorkflowEnvironment testEnv; + private WorkflowClient client; + + @BeforeEach + void setUp() { + testEnv = TestWorkflowEnvironment.newInstance(); + client = testEnv.getWorkflowClient(); + } + + @AfterEach + void tearDown() { + testEnv.close(); + } + + @Test + void workflowWithChatModel_replaysDeterministically() throws Exception { + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ChatWorkflowImpl.class); + worker.registerActivitiesImplementations( + new ChatModelActivityImpl(new StubChatModel("Hello from the model!"))); + + testEnv.start(); + + TestChatWorkflow workflow = + client.newWorkflowStub( + TestChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + + String result = workflow.chat("Hi"); + assertEquals("Hello from the model!", result); + + // Capture history and replay — any non-determinism throws here + WorkflowExecutionHistory history = + client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId()); + WorkflowReplayer.replayWorkflowExecution(history, ChatWorkflowImpl.class); + } + + @Test + void workflowWithTools_replaysDeterministically() throws Exception { + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ChatWithToolsWorkflowImpl.class); + worker.registerActivitiesImplementations( + new ChatModelActivityImpl(new StubChatModel("I used the tools!"))); + + testEnv.start(); + + TestChatWorkflow workflow = + client.newWorkflowStub( + TestChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + + String result = workflow.chat("Use tools"); + assertEquals("I used the tools!", result); + + // Capture history and replay + WorkflowExecutionHistory history = + client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId()); + WorkflowReplayer.replayWorkflowExecution(history, ChatWithToolsWorkflowImpl.class); + } + + // --- Workflow interfaces and implementations --- + + @WorkflowInterface + public interface TestChatWorkflow { + @WorkflowMethod + String chat(String message); + } + + public static class ChatWorkflowImpl implements TestChatWorkflow { + @Override + public String chat(String message) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + ChatClient chatClient = TemporalChatClient.builder(chatModel).build(); + return chatClient.prompt().user(message).call().content(); + } + } + + public static class ChatWithToolsWorkflowImpl implements TestChatWorkflow { + @Override + public String chat(String message) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + TestDeterministicTools deterministicTools = new TestDeterministicTools(); + TestSideEffectTools sideEffectTools = new TestSideEffectTools(); + ChatClient chatClient = + TemporalChatClient.builder(chatModel) + .defaultTools(deterministicTools, sideEffectTools) + .build(); + return chatClient.prompt().user(message).call().content(); + } + } + + // --- Test tool classes --- + + @DeterministicTool + public static class TestDeterministicTools { + @Tool(description = "Add two numbers") + public int add(int a, int b) { + return a + b; + } + } + + @SideEffectTool + public static class TestSideEffectTools { + @Tool(description = "Get a timestamp") + public String timestamp() { + return "2025-01-01T00:00:00Z"; + } + } + + // --- Stub ChatModel that returns a canned response --- + + private static class StubChatModel implements ChatModel { + private final String response; + + StubChatModel(String response) { + this.response = response; + } + + @Override + public ChatResponse call(Prompt prompt) { + return ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage(response)))) + .build(); + } + + @Override + public reactor.core.publisher.Flux stream(Prompt prompt) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java new file mode 100644 index 000000000..300fe7dd7 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java @@ -0,0 +1,297 @@ +package io.temporal.springai.activity; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import io.temporal.springai.model.ChatModelTypes.*; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +class ChatModelActivityImplTest { + + @Test + void systemMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("reply")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("You are helpful", Message.Role.SYSTEM)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + assertNotNull(output); + assertEquals(1, output.generations().size()); + assertEquals("reply", output.generations().get(0).message().rawContent()); + assertEquals(Message.Role.ASSISTANT, output.generations().get(0).message().role()); + + // Verify the prompt was constructed with a SystemMessage + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertEquals(1, prompt.getInstructions().size()); + assertInstanceOf( + org.springframework.ai.chat.messages.SystemMessage.class, prompt.getInstructions().get(0)); + assertEquals("You are helpful", prompt.getInstructions().get(0).getText()); + } + + @Test + void userMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("hi")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hello", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertInstanceOf( + org.springframework.ai.chat.messages.UserMessage.class, prompt.getInstructions().get(0)); + } + + @Test + void assistantMessageWithToolCalls_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + + // Model returns a response with tool calls + AssistantMessage assistantWithTools = + AssistantMessage.builder() + .content("I'll check the weather") + .toolCalls( + List.of( + new AssistantMessage.ToolCall( + "call_123", "function", "getWeather", "{\"city\":\"Seattle\"}"))) + .build(); + + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(assistantWithTools))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("What's the weather?", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + // Verify tool calls are preserved in output + Message outputMsg = output.generations().get(0).message(); + assertNotNull(outputMsg.toolCalls()); + assertEquals(1, outputMsg.toolCalls().size()); + assertEquals("call_123", outputMsg.toolCalls().get(0).id()); + assertEquals("function", outputMsg.toolCalls().get(0).type()); + assertEquals("getWeather", outputMsg.toolCalls().get(0).function().name()); + assertEquals("{\"city\":\"Seattle\"}", outputMsg.toolCalls().get(0).function().arguments()); + } + + @Test + void toolResponseMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("It's 55F")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, + List.of( + new Message( + "Weather: 55F", Message.Role.TOOL, "getWeather", "call_123", null, null)), + null, + List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + // Verify tool response was passed to model + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertInstanceOf( + org.springframework.ai.chat.messages.ToolResponseMessage.class, + prompt.getInstructions().get(0)); + } + + @Test + void modelOptions_passedThrough() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ModelOptions opts = new ModelOptions("gpt-4", null, 100, null, null, 0.5, null, 0.9); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), opts, List.of()); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertNotNull(prompt.getOptions()); + assertEquals("gpt-4", prompt.getOptions().getModel()); + assertEquals(0.5, prompt.getOptions().getTemperature()); + assertEquals(0.9, prompt.getOptions().getTopP()); + assertEquals(100, prompt.getOptions().getMaxTokens()); + } + + @Test + void toolDefinitions_passedAsStubs() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + FunctionTool tool = + new FunctionTool( + new FunctionTool.Function( + "getWeather", "Get weather for a city", "{\"type\":\"object\"}")); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), null, List.of(tool)); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + + // Verify tool execution is disabled (workflow handles it) + assertInstanceOf(ToolCallingChatOptions.class, prompt.getOptions()); + assertFalse(ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())); + } + + @Test + void multipleModels_resolvedByName() { + ChatModel openAi = mock(ChatModel.class); + ChatModel anthropic = mock(ChatModel.class); + when(openAi.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("openai")))) + .build()); + when(anthropic.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("anthropic")))) + .build()); + + ChatModelActivityImpl impl = + new ChatModelActivityImpl(Map.of("openai", openAi, "anthropic", anthropic), "openai"); + + // Call with specific model + ChatModelActivityInput input = + new ChatModelActivityInput( + "anthropic", List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + assertEquals("anthropic", output.generations().get(0).message().rawContent()); + verify(anthropic).call(any(Prompt.class)); + verify(openAi, never()).call(any(Prompt.class)); + } + + @Test + void multipleModels_defaultUsedWhenNameNull() { + ChatModel openAi = mock(ChatModel.class); + ChatModel anthropic = mock(ChatModel.class); + when(openAi.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("openai")))) + .build()); + + ChatModelActivityImpl impl = + new ChatModelActivityImpl(Map.of("openai", openAi, "anthropic", anthropic), "openai"); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + assertEquals("openai", output.generations().get(0).message().rawContent()); + verify(openAi).call(any(Prompt.class)); + } + + @Test + void unknownModelName_throwsIllegalArgument() { + ChatModel model = mock(ChatModel.class); + ChatModelActivityImpl impl = new ChatModelActivityImpl(model); + + ChatModelActivityInput input = + new ChatModelActivityInput( + "nonexistent", List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + assertThrows(IllegalArgumentException.class, () -> impl.callChatModel(input)); + } + + @Test + void multipleMessages_allConverted() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, + List.of( + new Message("You are helpful", Message.Role.SYSTEM), + new Message("Hello", Message.Role.USER), + new Message("Hi there", Message.Role.ASSISTANT, null, null, null, null), + new Message("What's up?", Message.Role.USER)), + null, + List.of()); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + assertEquals(4, captor.getValue().getInstructions().size()); + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java new file mode 100644 index 000000000..3d1a20f35 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java @@ -0,0 +1,139 @@ +package io.temporal.springai.plugin; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.activity.EmbeddingModelActivityImpl; +import io.temporal.springai.activity.VectorStoreActivityImpl; +import io.temporal.springai.tool.ExecuteToolLocalActivityImpl; +import io.temporal.worker.Worker; +import java.util.*; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; + +class SpringAiPluginTest { + + private List captureRegisteredActivities(Worker worker) { + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture()); + return captor.getAllValues(); + } + + private Set> activityTypes(List activities) { + return activities.stream().map(Object::getClass).collect(Collectors.toSet()); + } + + // --- Core SpringAiPlugin tests --- + + @Test + void singleModel_registersChatModelAndExecuteToolLocal() { + ChatModel chatModel = mock(ChatModel.class); + Worker worker = mock(Worker.class); + + SpringAiPlugin plugin = new SpringAiPlugin(chatModel); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(ChatModelActivityImpl.class)); + assertTrue(types.contains(ExecuteToolLocalActivityImpl.class)); + // No VectorStore or EmbeddingModel — those are separate plugins now + assertFalse(types.contains(VectorStoreActivityImpl.class)); + assertFalse(types.contains(EmbeddingModelActivityImpl.class)); + } + + @Test + void multipleModels_allExposed() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, model1); + + assertEquals(2, plugin.getChatModels().size()); + assertSame(model1, plugin.getChatModel("openai")); + assertSame(model2, plugin.getChatModel("anthropic")); + } + + @Test + void primaryModel_usedAsDefault() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, model2); + + assertEquals("anthropic", plugin.getDefaultModelName()); + assertSame(model2, plugin.getChatModel()); + } + + @Test + void noPrimaryModel_firstEntryIsDefault() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, null); + + assertEquals("openai", plugin.getDefaultModelName()); + assertSame(model1, plugin.getChatModel()); + } + + @Test + void singleModelConstructor_usesDefaultModelName() { + ChatModel chatModel = mock(ChatModel.class); + SpringAiPlugin plugin = new SpringAiPlugin(chatModel); + + assertEquals(SpringAiPlugin.DEFAULT_MODEL_NAME, plugin.getDefaultModelName()); + assertSame(chatModel, plugin.getChatModel()); + } + + @Test + void nullChatModelsMap_throwsIllegalArgument() { + assertThrows(IllegalArgumentException.class, () -> new SpringAiPlugin(null, null)); + } + + @Test + void emptyChatModelsMap_throwsIllegalArgument() { + assertThrows( + IllegalArgumentException.class, () -> new SpringAiPlugin(new LinkedHashMap<>(), null)); + } + + // --- VectorStorePlugin tests --- + + @Test + void vectorStorePlugin_registersActivity() { + VectorStore vectorStore = mock(VectorStore.class); + Worker worker = mock(Worker.class); + + VectorStorePlugin plugin = new VectorStorePlugin(vectorStore); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(VectorStoreActivityImpl.class)); + } + + // --- EmbeddingModelPlugin tests --- + + @Test + void embeddingModelPlugin_registersActivity() { + EmbeddingModel embeddingModel = mock(EmbeddingModel.class); + Worker worker = mock(Worker.class); + + EmbeddingModelPlugin plugin = new EmbeddingModelPlugin(embeddingModel); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(EmbeddingModelActivityImpl.class)); + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java new file mode 100644 index 000000000..3742d2355 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java @@ -0,0 +1,254 @@ +package io.temporal.springai.util; + +import static org.junit.jupiter.api.Assertions.*; + +import io.temporal.springai.tool.DeterministicTool; +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.springai.tool.SideEffectToolCallback; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; + +class TemporalToolUtilTest { + + // --- Test fixture classes --- + + @DeterministicTool + static class MathTools { + @Tool(description = "Add two numbers") + public int add(int a, int b) { + return a + b; + } + + @Tool(description = "Multiply two numbers") + public int multiply(int a, int b) { + return a * b; + } + } + + @SideEffectTool + static class TimestampTools { + @Tool(description = "Get the current timestamp") + public long currentTimeMillis() { + return System.currentTimeMillis(); + } + } + + @SideEffectTool + static class RandomTools { + @Tool(description = "Generate a random number") + public double random() { + return Math.random(); + } + } + + // No annotation + static class UnannotatedTools { + @Tool(description = "Some tool") + public String doSomething() { + return "result"; + } + } + + // --- Tests for convertTools with @DeterministicTool --- + + @Test + void convertTools_deterministicTool_producesStandardCallbacks() { + List callbacks = TemporalToolUtil.convertTools(new MathTools()); + + assertEquals(2, callbacks.size()); + // DeterministicTool callbacks should NOT be wrapped in SideEffectToolCallback + for (ToolCallback cb : callbacks) { + assertFalse( + cb instanceof SideEffectToolCallback, + "DeterministicTool should not produce SideEffectToolCallback"); + } + } + + @Test + void convertTools_deterministicTool_hasCorrectToolNames() { + List callbacks = TemporalToolUtil.convertTools(new MathTools()); + + List toolNames = + callbacks.stream().map(cb -> cb.getToolDefinition().name()).sorted().toList(); + assertEquals(List.of("add", "multiply"), toolNames); + } + + // --- Tests for convertTools with @SideEffectTool --- + + @Test + void convertTools_sideEffectTool_producesSideEffectCallbackWrappers() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + assertEquals(1, callbacks.size()); + assertInstanceOf(SideEffectToolCallback.class, callbacks.get(0)); + } + + @Test + void convertTools_sideEffectTool_hasCorrectToolName() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + assertEquals("currentTimeMillis", callbacks.get(0).getToolDefinition().name()); + } + + @Test + void convertTools_sideEffectTool_delegateIsPreserved() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + SideEffectToolCallback wrapper = (SideEffectToolCallback) callbacks.get(0); + assertNotNull(wrapper.getDelegate()); + assertEquals("currentTimeMillis", wrapper.getDelegate().getToolDefinition().name()); + } + + // --- Tests for unknown/unannotated objects --- + + @Test + void convertTools_unannotatedObject_throwsIllegalArgumentException() { + UnannotatedTools unannotated = new UnannotatedTools(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> TemporalToolUtil.convertTools(unannotated)); + assertTrue(ex.getMessage().contains("not a recognized Temporal primitive")); + assertTrue(ex.getMessage().contains("@DeterministicTool")); + assertTrue(ex.getMessage().contains("@SideEffectTool")); + assertTrue(ex.getMessage().contains(UnannotatedTools.class.getName())); + } + + @Test + void convertTools_plainString_throwsIllegalArgumentException() { + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> TemporalToolUtil.convertTools("not a tool")); + assertTrue(ex.getMessage().contains("java.lang.String")); + } + + // --- Tests for null handling --- + + @Test + void convertTools_nullObject_throwsIllegalArgumentException() { + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> TemporalToolUtil.convertTools((Object) null)); + assertTrue(ex.getMessage().contains("null")); + } + + @Test + void convertTools_nullInArray_throwsIllegalArgumentException() { + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, + () -> TemporalToolUtil.convertTools(new MathTools(), null)); + assertTrue(ex.getMessage().contains("null")); + } + + // --- Tests for empty input --- + + @Test + void convertTools_emptyArray_returnsEmptyList() { + List callbacks = TemporalToolUtil.convertTools(); + assertTrue(callbacks.isEmpty()); + } + + // --- Tests for mixed tool types --- + + @Test + void convertTools_mixedDeterministicAndSideEffect_allConvertCorrectly() { + List callbacks = + TemporalToolUtil.convertTools(new MathTools(), new TimestampTools(), new RandomTools()); + + // MathTools has 2 methods, TimestampTools has 1, RandomTools has 1 + assertEquals(4, callbacks.size()); + + long sideEffectCount = + callbacks.stream().filter(cb -> cb instanceof SideEffectToolCallback).count(); + long standardCount = + callbacks.stream().filter(cb -> !(cb instanceof SideEffectToolCallback)).count(); + + // 2 from TimestampTools + RandomTools are SideEffectToolCallback + assertEquals(2, sideEffectCount); + // 2 from MathTools are standard + assertEquals(2, standardCount); + } + + @Test + void convertTools_mixedWithUnannotated_throwsOnFirstUnannotated() { + assertThrows( + IllegalArgumentException.class, + () -> TemporalToolUtil.convertTools(new MathTools(), new UnannotatedTools())); + } + + // --- Tests for isRecognizedToolType --- + + @Test + void isRecognizedToolType_deterministicTool_returnsTrue() { + assertTrue(TemporalToolUtil.isRecognizedToolType(new MathTools())); + } + + @Test + void isRecognizedToolType_sideEffectTool_returnsTrue() { + assertTrue(TemporalToolUtil.isRecognizedToolType(new TimestampTools())); + } + + @Test + void isRecognizedToolType_unannotatedObject_returnsFalse() { + assertFalse(TemporalToolUtil.isRecognizedToolType(new UnannotatedTools())); + } + + @Test + void isRecognizedToolType_plainObject_returnsFalse() { + assertFalse(TemporalToolUtil.isRecognizedToolType("a string")); + assertFalse(TemporalToolUtil.isRecognizedToolType(42)); + } + + @Test + void isRecognizedToolType_null_returnsFalse() { + assertFalse(TemporalToolUtil.isRecognizedToolType(null)); + } + + // --- Tests for TemporalStubUtil negative cases --- + + @Test + void stubUtil_isActivityStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isActivityStub(new MathTools())); + assertFalse(TemporalStubUtil.isActivityStub("not a stub")); + assertFalse(TemporalStubUtil.isActivityStub(null)); + } + + @Test + void stubUtil_isLocalActivityStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isLocalActivityStub(new MathTools())); + assertFalse(TemporalStubUtil.isLocalActivityStub("not a stub")); + assertFalse(TemporalStubUtil.isLocalActivityStub(null)); + } + + @Test + void stubUtil_isChildWorkflowStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isChildWorkflowStub(new MathTools())); + assertFalse(TemporalStubUtil.isChildWorkflowStub("not a stub")); + assertFalse(TemporalStubUtil.isChildWorkflowStub(null)); + } + + @Test + void stubUtil_isNexusServiceStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isNexusServiceStub(new MathTools())); + assertFalse(TemporalStubUtil.isNexusServiceStub("not a stub")); + assertFalse(TemporalStubUtil.isNexusServiceStub(null)); + } + + @Test + void stubUtil_nonTemporalProxy_returnsFalse() { + // A JDK dynamic proxy that is NOT a Temporal stub should return false for all checks + Object proxy = + java.lang.reflect.Proxy.newProxyInstance( + getClass().getClassLoader(), + new Class[] {Runnable.class}, + (p, method, args) -> null); + + assertFalse(TemporalStubUtil.isActivityStub(proxy)); + assertFalse(TemporalStubUtil.isLocalActivityStub(proxy)); + assertFalse(TemporalStubUtil.isChildWorkflowStub(proxy)); + assertFalse(TemporalStubUtil.isNexusServiceStub(proxy)); + } +}