Skip to content

Commit f5c2671

Browse files
committed
sync load_prompt and aload_prompt
1 parent 16e28a6 commit f5c2671

1 file changed

Lines changed: 91 additions & 35 deletions

File tree

py/src/braintrust/logger.py

Lines changed: 91 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def post_json(self, object_type: str, args: Mapping[str, Any] | None = None) ->
786786

787787
async def aget_json(
788788
self, object_type: str, args: Optional[Mapping[str, Any]] = None, retries: int = 0
789-
) -> Mapping[str, Any]:
789+
) -> Mapping[str, Any] | None:
790790
"""
791791
Async version of get_json. Makes a true async HTTP GET request and returns JSON response.
792792
"""
@@ -833,15 +833,16 @@ async def _make_aiohttp_request(self, url: str) -> Mapping[str, Any]:
833833

834834
async def _make_asyncio_request(self, url: str) -> Mapping[str, Any]:
835835
"""Make async HTTP request using asyncio and urllib (fallback)"""
836-
loop = asyncio.get_event_loop()
836+
loop = asyncio.get_running_loop()
837+
timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
837838

838839
def sync_request():
839840
request = Request(url)
840841
if self.token:
841842
request.add_header("Authorization", f"Bearer {self.token}")
842843

843844
try:
844-
response_obj = urlopen(request)
845+
response_obj = urlopen(request, timeout=timeout_secs)
845846
response_data = response_obj.read()
846847
return json.loads(response_data.decode("utf-8"))
847848
except HTTPError as e:
@@ -2079,8 +2080,10 @@ async def aload_prompt(
20792080
slug: Optional[str] = None,
20802081
version: Optional[Union[str, int]] = None,
20812082
project_id: Optional[str] = None,
2083+
prompt_id: str | None = None,
20822084
defaults: Optional[Mapping[str, Any]] = None,
20832085
no_trace: bool = False,
2086+
environment: str | None = None,
20842087
app_url: Optional[str] = None,
20852088
api_key: Optional[str] = None,
20862089
org_name: Optional[str] = None,
@@ -2092,81 +2095,134 @@ async def aload_prompt(
20922095
:param slug: The slug of the prompt to load.
20932096
:param version: An optional version of the prompt (to read). If not specified, the latest version will be used.
20942097
:param project_id: The id of the project to load the prompt from. This takes precedence over `project` if specified.
2098+
:param prompt_id: The id of a specific prompt to load. If specified, this takes precedence over all other parameters (project, slug, version).
20952099
:param defaults: (Optional) A dictionary of default values to use when rendering the prompt. Prompt values will override these defaults.
20962100
:param no_trace: If true, do not include logging metadata for this prompt when build() is called.
2101+
:param environment: The environment to load the prompt from. Cannot be used together with version.
20972102
:param app_url: The URL of the Braintrust App. Defaults to https://www.braintrust.dev.
20982103
:param api_key: The API key to use. If the parameter is not specified, will try to use the `BRAINTRUST_API_KEY` environment variable. If no API
20992104
key is specified, will prompt the user to login.
21002105
:param org_name: (Optional) The name of a specific organization to connect to. This is useful if you belong to multiple.
21012106
:returns: The prompt object.
21022107
"""
21032108

2104-
if not project and not project_id:
2109+
if version is not None and environment is not None:
2110+
raise ValueError(
2111+
"Cannot specify both 'version' and 'environment' parameters. Please use only one (remove the other)."
2112+
)
2113+
2114+
if prompt_id:
2115+
pass
2116+
elif not project and not project_id:
21052117
raise ValueError("Must specify at least one of project or project_id")
2106-
if not slug:
2118+
elif not slug:
21072119
raise ValueError("Must specify slug")
21082120

2109-
loop = asyncio.get_event_loop()
2121+
loop = asyncio.get_running_loop()
2122+
response = None
21102123

21112124
try:
21122125
# Run login in thread pool since it's synchronous
21132126
await loop.run_in_executor(HTTP_REQUEST_THREAD_POOL, login, app_url, api_key, org_name)
2114-
2115-
# Make async HTTP request
2116-
args = _populate_args(
2117-
{
2118-
"project_name": project,
2119-
"project_id": project_id,
2120-
"slug": slug,
2127+
if prompt_id:
2128+
args = _populate_args({
21212129
"version": version,
2122-
},
2123-
)
2130+
"environment": environment
2131+
})
2132+
2133+
response = await _state.api_conn().aget_json(f"/v1/prompt/{prompt_id}", args)
2134+
2135+
if response:
2136+
response = {"objects": [response]}
2137+
2138+
else:
2139+
args = _populate_args(
2140+
{
2141+
"project_name": project,
2142+
"project_id": project_id,
2143+
"slug": slug,
2144+
"version": version,
2145+
"environment": environment
2146+
},
2147+
)
21242148

2125-
response = await _state.api_conn().aget_json("/v1/prompt", args)
2149+
response = await _state.api_conn().aget_json("/v1/prompt", args)
21262150

21272151
except Exception as server_error:
2152+
# If environment was specified, don't fall back to cache
2153+
if environment is not None:
2154+
raise ValueError(f"Prompt not found for specified environment {environment}") from server_error
2155+
21282156
eprint(f"Failed to load prompt, attempting to fall back to cache: {server_error}")
21292157
try:
2130-
cache_result = await loop.run_in_executor(
2131-
HTTP_REQUEST_THREAD_POOL,
2132-
lambda: _state._prompt_cache.get(
2133-
slug,
2134-
version=str(version) if version else "latest",
2135-
project_id=project_id,
2136-
project_name=project,
2137-
),
2138-
)
2158+
if prompt_id:
2159+
cache_result = await loop.run_in_executor(
2160+
HTTP_REQUEST_THREAD_POOL,
2161+
lambda: _state._prompt_cache.get(
2162+
id=prompt_id
2163+
),
2164+
)
2165+
else:
2166+
cache_result = await loop.run_in_executor(
2167+
HTTP_REQUEST_THREAD_POOL,
2168+
lambda: _state._prompt_cache.get(
2169+
slug,
2170+
version=str(version) if version else "latest",
2171+
project_id=project_id,
2172+
project_name=project,
2173+
),
2174+
)
21392175
# Return Prompt with pre-computed metadata from cache
21402176
return Prompt(
21412177
lazy_metadata=LazyValue(lambda: cache_result, use_mutex=True),
21422178
defaults=defaults or {},
21432179
no_trace=no_trace,
21442180
)
21452181
except Exception as cache_error:
2182+
if prompt_id:
2183+
raise ValueError(
2184+
f"Prompt with id {prompt_id} not found (not found on server or in local cache): {cache_error}"
2185+
) from server_error
21462186
raise ValueError(
21472187
f"Prompt {slug} (version {version or 'latest'}) not found in {project or project_id} (not found on server or in local cache): {cache_error}"
21482188
) from server_error
21492189

21502190
if response is None or "objects" not in response or len(response["objects"]) == 0:
2191+
if prompt_id:
2192+
raise ValueError(f"Prompt with id {prompt_id} not found.")
2193+
21512194
raise ValueError(f"Prompt {slug} not found in project {project or project_id}.")
21522195
elif len(response["objects"]) > 1:
2196+
if prompt_id:
2197+
raise ValueError(f"Multiple prompts found with id {prompt_id}. This should never happen.")
2198+
21532199
raise ValueError(
21542200
f"Multiple prompts found with slug {slug} in project {project or project_id}. This should never happen."
21552201
)
21562202

21572203
resp_prompt = response["objects"][0]
21582204
prompt_metadata = PromptSchema.from_dict_deep(resp_prompt)
21592205
try:
2160-
await loop.run_in_executor(
2161-
HTTP_REQUEST_THREAD_POOL,
2162-
lambda: _state._prompt_cache.set(
2163-
slug,
2164-
str(version) if version else "latest",
2165-
prompt_metadata,
2166-
project_id=project_id,
2167-
project_name=project,
2168-
),
2169-
)
2206+
# save prompt to cache
2207+
if prompt_id:
2208+
await loop.run_in_executor(
2209+
HTTP_REQUEST_THREAD_POOL,
2210+
lambda: _state._prompt_cache.set(
2211+
prompt_metadata,
2212+
id=prompt_id
2213+
),
2214+
)
2215+
else:
2216+
await loop.run_in_executor(
2217+
HTTP_REQUEST_THREAD_POOL,
2218+
lambda: _state._prompt_cache.set(
2219+
prompt_metadata,
2220+
slug=slug,
2221+
version=str(version) if version else "latest",
2222+
project_id=project_id,
2223+
project_name=project,
2224+
),
2225+
)
21702226
except Exception as e:
21712227
eprint(f"Failed to store prompt in cache: {e}")
21722228

0 commit comments

Comments
 (0)