diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index ad484c4d5..a3b3cd590 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1973,6 +1973,9 @@ def create_chat_completion(
logit_bias: Optional[Dict[int, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
+ reasoning_effort: Optional[
+ Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+ ] = None,
) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
@@ -2005,6 +2008,8 @@ def create_chat_completion(
logits_processor: A list of logits processors to use.
grammar: A grammar to use.
logit_bias: A logit bias to use.
+ reasoning_effort: Optional reasoning hint forwarded to chat handlers as a
+ chat-template keyword argument.
Returns:
Generated chat completion or a stream of chat completion chunks.
@@ -2044,6 +2049,7 @@ def create_chat_completion(
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
+ reasoning_effort=reasoning_effort,
)
def create_chat_completion_openai_v1(
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index d7910e984..1024fb85b 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -243,6 +243,7 @@ def raise_exception(message: str):
tools=tools,
tool_choice=tool_choice,
strftime_now=self.strftime_now,
+ **kwargs,
)
stopping_criteria = None
@@ -617,6 +618,7 @@ def chat_completion_handler(
function_call=function_call,
tools=tools,
tool_choice=tool_choice,
+ **kwargs,
)
prompt = llama.tokenize(
result.prompt.encode("utf-8"),
@@ -734,7 +736,9 @@ def format_autotokenizer(
**kwargs: Any,
) -> ChatFormatterResponse:
tokenizer.use_default_system_prompt = False # type: ignore
- prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
+ prompt: str = tokenizer.apply_chat_template( # type: ignore
+ messages, tokenize=False, **kwargs
+ )
assert isinstance(prompt, str)
# Return formatted prompt and eos token by default
return ChatFormatterResponse(
@@ -791,6 +795,7 @@ def format_tokenizer_config(
messages=messages,
bos_token=bos_token,
eos_token=eos_token,
+ **kwargs,
)
return ChatFormatterResponse(
prompt=prompt, stop=[eos_token, bos_token], added_special=True
diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py
index fdd164456..df5de1b4e 100644
--- a/llama_cpp/server/types.py
+++ b/llama_cpp/server/types.py
@@ -235,6 +235,12 @@ class CreateChatCompletionRequest(BaseModel):
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
+ reasoning_effort: Optional[
+ Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+ ] = Field(
+ default=None,
+ description="Optional reasoning-effort hint exposed to chat templates as the `reasoning_effort` keyword argument.",
+ )
# ignored or currently unsupported
model: Optional[str] = model_field
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
index 18c7279cf..b682c99c7 100644
--- a/tests/test_llama_chat_format.py
+++ b/tests/test_llama_chat_format.py
@@ -1,12 +1,13 @@
import json
+import inspect
import jinja2
-from llama_cpp import (
- ChatCompletionRequestUserMessage,
-)
+import llama_cpp
+from llama_cpp import ChatCompletionRequestUserMessage
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_chat_format as llama_chat_format
+import llama_cpp.server.types as server_types
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
@@ -92,3 +93,108 @@ def test_hf_tokenizer_config_str_to_chat_formatter():
)
assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! [/INST]")
+
+
+def test_jinja2_chat_formatter_passes_template_kwargs():
+ chat_formatter = llama_chat_format.Jinja2ChatFormatter(
+ template="{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
+ bos_token="",
+ eos_token="",
+ )
+ response = chat_formatter(
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+ ],
+ reasoning_effort="low",
+ )
+
+ assert response.prompt == "low Hello, world!"
+
+
+def test_hf_tokenizer_config_chat_formatter_passes_template_kwargs():
+ tokenizer_config = {
+ "chat_template": "{{ bos_token }}{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
+ "bos_token": "",
+ "eos_token": "",
+ }
+ chat_formatter = hf_tokenizer_config_to_chat_formatter(
+ tokenizer_config, add_generation_prompt=False
+ )
+ response = chat_formatter(
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+ ],
+ reasoning_effort="medium",
+ )
+
+ assert response.prompt == "medium Hello, world!"
+
+
+def test_chat_completion_handler_passes_template_kwargs():
+ captured = {}
+
+ def chat_formatter(*, messages, **kwargs):
+ captured["messages"] = messages
+ captured["kwargs"] = kwargs
+ return llama_chat_format.ChatFormatterResponse(prompt="Hello")
+
+ handler = llama_chat_format.chat_formatter_to_chat_completion_handler(
+ chat_formatter
+ )
+
+ class DummyLlama:
+ verbose = False
+
+ def tokenize(self, data, add_bos, special):
+ return [1]
+
+ def create_completion(self, **kwargs):
+ return {
+ "id": "cmpl-test",
+ "object": "text_completion",
+ "created": 0,
+ "model": "dummy",
+ "choices": [
+ {
+ "text": "world",
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 1,
+ "completion_tokens": 1,
+ "total_tokens": 2,
+ },
+ }
+
+ response = handler(
+ llama=DummyLlama(),
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+ ],
+ reasoning_effort="high",
+ )
+
+ assert response["choices"][0]["message"]["content"] == "world"
+ assert captured["kwargs"]["reasoning_effort"] == "high"
+
+
+def test_create_chat_completion_exposes_reasoning_effort_parameter():
+ parameter = inspect.signature(llama_cpp.Llama.create_chat_completion).parameters[
+ "reasoning_effort"
+ ]
+
+ assert parameter.default is None
+
+
+def test_server_chat_completion_request_accepts_reasoning_effort():
+ request = server_types.CreateChatCompletionRequest(
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!")
+ ],
+ reasoning_effort="minimal",
+ )
+
+ assert request.reasoning_effort == "minimal"