Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
)
from ldai_langchain.langchain_model_runner import LangChainModelRunner
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner

__version__ = "0.1.0"

__all__ = [
'__version__',
'LangChainRunnerFactory',
'LangGraphAgentGraphRunner',
'LangChainModelRunner',
'convert_messages_to_langchain',
'create_langchain_model',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from ldai.models import AIConfigKind
from ldai.providers import AIProvider
from ldai.providers import AIProvider, ToolRegistry

from ldai_langchain.langchain_helper import create_langchain_model
from ldai_langchain.langchain_model_runner import LangChainModelRunner
Expand All @@ -8,6 +10,19 @@
class LangChainRunnerFactory(AIProvider):
"""LangChain ``AIProvider`` implementation for the LaunchDarkly AI SDK."""

def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any:
    """
    Build a LangGraphAgentGraphRunner that will execute ``graph_def``.

    :param graph_def: The AgentGraphDefinition to execute
    :param tools: Registry mapping tool names to callables (langchain-compatible)
    :return: LangGraphAgentGraphRunner ready to execute the graph
    """
    # Imported lazily so that constructing the factory never requires
    # langgraph; the dependency is only needed once a graph is created.
    from ldai_langchain.langgraph_agent_graph_runner import (
        LangGraphAgentGraphRunner,
    )
    runner_cls = LangGraphAgentGraphRunner
    return runner_cls(graph_def, tools)

def create_model(self, config: AIConfigKind) -> LangChainModelRunner:
"""
Create a configured LangChainModelRunner for the given AI config.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""

import operator
import time
from typing import Annotated, Any, List

from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
from ldai.providers.types import LDAIMetrics

from ldai_langchain.langchain_helper import (
create_langchain_model,
get_ai_metrics_from_response,
get_ai_usage_from_response,
get_tool_calls_from_response,
sum_token_usage_from_messages,
)


class LangGraphAgentGraphRunner(AgentGraphRunner):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is a little ridiculous.

"""
AgentGraphRunner implementation for LangGraph.

Compiles and runs the agent graph with LangGraph and automatically records
graph- and node-level AI metric data to the LaunchDarkly trackers on the
graph definition and each node.

Requires ``langgraph`` to be installed.
"""

def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
    """
    Store the graph definition and tool registry for later execution.

    :param graph: The AgentGraphDefinition to execute
    :param tools: Registry mapping tool names to callables (langchain-compatible)
    """
    self._graph, self._tools = graph, tools

async def run(self, input: Any) -> AgentGraphResult:
    """
    Run the agent graph with the given input.

    Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
    it, and invokes it. Records the execution path, latency, total token
    usage, and invocation success/failure on the graph-level tracker;
    per-node metrics and tool calls are recorded on each node's tracker.

    Never raises: any exception (including a missing ``langgraph``
    dependency) is logged and reported via a failed AgentGraphResult.

    :param input: The string prompt to send to the agent graph
    :return: AgentGraphResult with the final output and metrics
    """
    tracker = self._graph.get_tracker()
    start_ns = time.perf_counter_ns()
    try:
        # Imported lazily so this module can be imported without langgraph
        # installed; a missing dependency is handled in the except branch.
        from langchain_core.messages import HumanMessage
        from langgraph.graph import END, START, StateGraph
        from typing_extensions import TypedDict

        class WorkflowState(TypedDict):
            # operator.add concatenates each node's returned message list
            # onto the accumulated state.
            messages: Annotated[List[Any], operator.add]

        agent_builder: StateGraph = StateGraph(WorkflowState)
        root_node = self._graph.root()
        root_key = root_node.get_key() if root_node else None
        tools_ref = self._tools
        exec_path: List[str] = []  # node keys in the order they executed

        def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
            # Registers one node (and its outgoing edges) on the builder.
            node_config = node.get_config()
            node_key = node.get_key()
            node_tracker = node_config.tracker

            model = None
            if node_config.model:
                lc_model = create_langchain_model(node_config)
                tool_defs = node_config.model.get_parameter('tools') or []
                # Only bind tools that are actually present in the registry.
                tool_fns = [
                    tools_ref[t.get('name', '')]
                    for t in tool_defs
                    if t.get('name', '') in tools_ref
                ]
                model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model

            def invoke(state: WorkflowState) -> WorkflowState:
                exec_path.append(node_key)
                if not model:
                    # Model-less nodes contribute nothing to the state.
                    return {'messages': []}
                gk = tracker.graph_key if tracker is not None else None
                if node_tracker:
                    response = node_tracker.track_metrics_of(
                        lambda: model.invoke(state['messages']),
                        get_ai_metrics_from_response,
                        graph_key=gk,
                    )
                    node_tracker.track_tool_calls(
                        get_tool_calls_from_response(response),
                        graph_key=gk,
                    )
                else:
                    response = model.invoke(state['messages'])

                return {'messages': [response]}

            # LangGraph uses the callable's __name__ in diagnostics; make
            # it match the node key.
            invoke.__name__ = node_key

            agent_builder.add_node(node_key, invoke)

            if node_key == root_key:
                agent_builder.add_edge(START, node_key)

            if node.is_terminal():
                agent_builder.add_edge(node_key, END)

            for edge in node.get_edges():
                agent_builder.add_edge(node_key, edge.target_config)

            return None

        self._graph.traverse(fn=handle_traversal)
        compiled = agent_builder.compile()

        result = await compiled.ainvoke(  # type: ignore[call-overload]
            {'messages': [HumanMessage(content=str(input))]}
        )
        duration = (time.perf_counter_ns() - start_ns) // 1_000_000

        # The final message's content is treated as the graph output.
        output = ''
        messages = result.get('messages', [])
        if messages:
            last = messages[-1]
            if hasattr(last, 'content'):
                output = str(last.content)

        if tracker:
            tracker.track_path(exec_path)
            tracker.track_latency(duration)
            tracker.track_invocation_success()
            tracker.track_total_tokens(
                sum_token_usage_from_messages(messages)
            )

        return AgentGraphResult(
            output=output,
            raw=result,
            metrics=LDAIMetrics(success=True),
        )
    except Exception as exc:
        # An ImportError here almost certainly means langgraph (or one of
        # its langchain dependencies) is not installed; give an actionable
        # hint rather than the generic failure message.
        if isinstance(exc, ImportError):
            log.warning(
                "langgraph is required for LangGraphAgentGraphRunner. "
                "Install it with: pip install langgraph"
            )
        else:
            log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
        duration = (time.perf_counter_ns() - start_ns) // 1_000_000
        if tracker:
            tracker.track_latency(duration)
            tracker.track_invocation_failure()
        return AgentGraphResult(
            output='',
            raw=None,
            metrics=LDAIMetrics(success=False),
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from ldai.agent_graph import AgentGraphDefinition
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
from ldai.providers import AgentGraphResult, ToolRegistry
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory


def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
    """Build a minimal single-agent graph definition with mocked trackers."""
    agent = AIAgentConfig(
        key='root-agent',
        enabled=enabled,
        model=ModelConfig(name='gpt-4'),
        provider=ProviderConfig(name='openai'),
        instructions='You are a helpful assistant.',
        tracker=MagicMock(),
    )
    graph_cfg = AIAgentGraphConfig(
        key='test-graph',
        root_config_key='root-agent',
        edges=[],
        enabled=enabled,
    )
    built_nodes = AgentGraphDefinition.build_nodes(graph_cfg, {'root-agent': agent})
    return AgentGraphDefinition(
        agent_graph=graph_cfg,
        nodes=built_nodes,
        context=MagicMock(),
        enabled=enabled,
        tracker=MagicMock(),
    )


# --- Factory ---

def test_langchain_runner_factory_create_agent_graph_returns_runner():
    """The factory hands back a LangGraphAgentGraphRunner instance."""
    registry: ToolRegistry = {'fetch_weather': lambda loc: f'weather in {loc}'}
    runner = LangChainRunnerFactory().create_agent_graph(_make_graph(), registry)
    assert isinstance(runner, LangGraphAgentGraphRunner)


def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools():
    """The runner holds exactly the graph and registry it was given."""
    graph_def = _make_graph()
    registry: ToolRegistry = {}
    runner = LangChainRunnerFactory().create_agent_graph(graph_def, registry)
    assert runner._graph is graph_def
    assert runner._tools is registry


# --- LangGraphAgentGraphRunner ---

def test_langgraph_runner_stores_graph_and_tools():
    """The constructor keeps references to the exact graph and registry given."""
    graph_def = _make_graph()
    registry: ToolRegistry = {}
    instance = LangGraphAgentGraphRunner(graph_def, registry)
    assert instance._graph is graph_def
    assert instance._tools is registry


@pytest.mark.asyncio
async def test_langgraph_runner_run_raises_when_langgraph_not_installed():
    """run() returns a failed result when langgraph cannot be imported.

    NOTE(review): despite the test name, the runner swallows the
    ImportError and reports failure via the result's metrics.
    """
    runner = LangGraphAgentGraphRunner(_make_graph(), {})
    blocked = {'langgraph': None, 'langgraph.graph': None}
    with patch.dict('sys.modules', blocked):
        outcome = await runner.run("test")
        assert isinstance(outcome, AgentGraphResult)
        assert outcome.metrics.success is False


@pytest.mark.asyncio
async def test_langgraph_runner_run_tracks_failure_on_exception():
    """A failed run records latency and an invocation failure on the tracker."""
    graph_def = _make_graph()
    graph_tracker = graph_def.get_tracker()
    runner = LangGraphAgentGraphRunner(graph_def, {})

    blocked = {'langgraph': None, 'langgraph.graph': None}
    with patch.dict('sys.modules', blocked):
        outcome = await runner.run("fail")

    assert outcome.metrics.success is False
    graph_tracker.track_invocation_failure.assert_called_once()
    graph_tracker.track_latency.assert_called_once()


@pytest.mark.asyncio
async def test_langgraph_runner_run_success():
    """A successful run returns the last message's content and tracks success.

    The langgraph/langchain modules are replaced wholesale in sys.modules so
    the runner's lazy imports resolve to mocks; no real LLM or graph runs.
    """
    graph = _make_graph()
    tracker = graph.get_tracker()

    # Final message returned by the (mocked) compiled graph; its content
    # becomes the runner's output.
    mock_message = MagicMock()
    mock_message.content = "langgraph answer"
    mock_message.usage_metadata = None
    mock_message.response_metadata = None

    # Compiled graph whose async invoke yields a single-message state.
    mock_compiled = MagicMock()
    mock_compiled.ainvoke = AsyncMock(return_value={'messages': [mock_message]})

    # StateGraph builder stub: records node/edge registration, returns the
    # compiled-graph stub from compile().
    mock_state_graph_instance = MagicMock()
    mock_state_graph_instance.add_node = MagicMock()
    mock_state_graph_instance.add_edge = MagicMock()
    mock_state_graph_instance.compile = MagicMock(return_value=mock_compiled)

    # Fake langgraph.graph module exposing END/START/StateGraph, matching
    # the names the runner imports lazily.
    mock_langgraph_graph = MagicMock()
    mock_langgraph_graph.END = 'END'
    mock_langgraph_graph.START = 'START'
    mock_langgraph_graph.StateGraph = MagicMock(return_value=mock_state_graph_instance)

    # Fake langchain_core.messages module for the HumanMessage import.
    mock_human_message = MagicMock()
    mock_lc_core_messages = MagicMock()
    mock_lc_core_messages.HumanMessage = MagicMock(return_value=mock_human_message)
    mock_lc_core_messages.AnyMessage = MagicMock()

    # Response a node's model would produce (not reached here since the
    # mocked compiled graph never calls the node functions).
    mock_model_response = MagicMock()
    mock_model_response.content = 'langgraph answer'
    mock_model_response.usage_metadata = None
    mock_model_response.response_metadata = None
    mock_model_response.tool_calls = None

    mock_llm = MagicMock()
    mock_llm.invoke = MagicMock(return_value=mock_model_response)

    # Fake langchain.chat_models.init_chat_model used by create_langchain_model.
    mock_init_model = MagicMock()
    mock_init_model.return_value = mock_llm
    mock_langchain_chat = MagicMock()
    mock_langchain_chat.init_chat_model = mock_init_model

    # typing_extensions stays real: the runner defines a TypedDict with it.
    with patch.dict('sys.modules', {
        'langgraph': MagicMock(),
        'langgraph.graph': mock_langgraph_graph,
        'langchain_core': MagicMock(),
        'langchain_core.messages': mock_lc_core_messages,
        'langchain': MagicMock(),
        'langchain.chat_models': mock_langchain_chat,
        'typing_extensions': __import__('typing_extensions'),
    }):
        runner = LangGraphAgentGraphRunner(graph, {})
        result = await runner.run("find restaurants")

    assert isinstance(result, AgentGraphResult)
    assert result.output == "langgraph answer"
    assert result.metrics.success is True
    # Path is empty: the mocked compiled graph never executes the node
    # functions, so exec_path is never appended to.
    tracker.track_path.assert_called_once_with([])
    tracker.track_invocation_success.assert_called_once()
    tracker.track_latency.assert_called_once()
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner
from ldai_openai.openai_helper import (
convert_messages_to_openai,
get_ai_metrics_from_response,
Expand All @@ -8,6 +9,7 @@

__all__ = [
'OpenAIRunnerFactory',
'OpenAIAgentGraphRunner',
'OpenAIModelRunner',
'convert_messages_to_openai',
'get_ai_metrics_from_response',
Expand Down
Loading
Loading