
Commit e569567

Add image generation support
Add experimental parameters for image generation models:

- width: width of generated image in pixels
- height: height of generated image in pixels
- steps: number of diffusion steps

Add response fields for image generation:

- image: base64-encoded generated image data
- completed/total: streaming progress indicators
1 parent 60e7b2f commit e569567
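
For orientation: the final response carries the generated image as a base64 string in the image field, and the bundled examples decode it straight to a PNG file. A minimal sketch of decoding with a sanity check (the PNG signature check is illustrative and not part of this commit; the examples assume PNG output):

import base64

def save_image(b64_data: str, path: str = 'output.png') -> None:
  raw = base64.b64decode(b64_data)
  # Illustrative check: PNG files begin with a fixed 8-byte signature.
  if raw[:8] != b'\x89PNG\r\n\x1a\n':
    raise ValueError('decoded data does not look like a PNG')
  with open(path, 'wb') as f:
    f.write(raw)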

6 files changed: 188 additions & 2 deletions


README.md

Lines changed: 0 additions & 1 deletion
@@ -250,7 +250,6 @@ ollama.embed(model='gemma3', input=['The sky is blue because of rayleigh scatter
 ollama.ps()
 ```
 
-
 ## Errors
 
 Errors are raised if requests return an error status or if an error is detected while streaming.

examples/generate-image-stream.py

Lines changed: 15 additions & 0 deletions
# Image generation is experimental and currently only available on macOS

import base64

from ollama import generate

for response in generate(model='x/z-image-turbo', prompt='a sunset over mountains', stream=True):
  if response.get('done'):
    # Final response contains the image
    with open('output.png', 'wb') as f:
      f.write(base64.b64decode(response['image']))
    print('\nImage saved to output.png')
  else:
    # Progress update
    print(f"Progress: {response.get('completed')}/{response.get('total')}", end='\r')

examples/generate-image.py

Lines changed: 18 additions & 0 deletions
# Image generation is experimental and currently only available on macOS

import base64

from ollama import generate

response = generate(
  model='x/z-image-turbo',
  prompt='a sunset over mountains',
  width=1024,
  height=768,
)

# Save the generated image
with open('output.png', 'wb') as f:
  f.write(base64.b64decode(response['image']))

print('Image saved to output.png')

ollama/_client.py

Lines changed: 24 additions & 0 deletions
@@ -217,6 +217,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> GenerateResponse: ...
 
   @overload
@@ -238,6 +241,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Iterator[GenerateResponse]: ...
 
   def generate(
@@ -258,6 +264,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
     """
     Create a response using the requested model.
@@ -289,6 +298,9 @@ def generate(
         images=list(_copy_images(images)) if images else None,
         options=options,
         keep_alive=keep_alive,
+        width=width,
+        height=height,
+        steps=steps,
       ).model_dump(exclude_none=True),
       stream=stream,
     )
@@ -838,6 +850,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> GenerateResponse: ...
 
   @overload
@@ -859,6 +874,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> AsyncIterator[GenerateResponse]: ...
 
   async def generate(
@@ -879,6 +897,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
     """
     Create a response using the requested model.
@@ -909,6 +930,9 @@ async def generate(
         images=list(_copy_images(images)) if images else None,
         options=options,
         keep_alive=keep_alive,
+        width=width,
+        height=height,
+        steps=steps,
       ).model_dump(exclude_none=True),
       stream=stream,
     )
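
Note: the new keyword arguments flow into the same GenerateRequest body, and because it is serialized with model_dump(exclude_none=True), requests that leave them unset produce an unchanged payload. A small sketch of that behavior (assuming only that GenerateRequest is importable from ollama._types, where this diff defines it):

from ollama._types import GenerateRequest

text_req = GenerateRequest(model='gemma3', prompt='hello').model_dump(exclude_none=True)
# Unset fields are dropped, so text-only requests are unaffected.
assert 'width' not in text_req and 'steps' not in text_req

image_req = GenerateRequest(
  model='x/z-image-turbo',
  prompt='a sunset over mountains',
  width=1024,
  height=768,
  steps=20,
).model_dump(exclude_none=True)
assert image_req['width'] == 1024 and image_req['steps'] == 20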

ollama/_types.py

Lines changed: 22 additions & 1 deletion
@@ -216,6 +216,16 @@ class GenerateRequest(BaseGenerateRequest):
   top_logprobs: Optional[int] = None
   'Number of alternative tokens and log probabilities to include per position (0-20).'
 
+  # Experimental image generation parameters
+  width: Optional[int] = None
+  'Width of the generated image in pixels (for image generation models).'
+
+  height: Optional[int] = None
+  'Height of the generated image in pixels (for image generation models).'
+
+  steps: Optional[int] = None
+  'Number of diffusion steps (for image generation models).'
+
 
 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -267,7 +277,7 @@ class GenerateResponse(BaseGenerateResponse):
   Response returned by generate requests.
   """
 
-  response: str
+  response: Optional[str] = None
   'Response content. When streaming, this contains a fragment of the response.'
 
   thinking: Optional[str] = None
@@ -279,6 +289,17 @@ class GenerateResponse(BaseGenerateResponse):
   logprobs: Optional[Sequence[Logprob]] = None
   'Log probabilities for generated tokens.'
 
+  # Image generation response fields
+  image: Optional[str] = None
+  'Base64-encoded generated image data (for image generation models).'
+
+  # Streaming progress fields (for image generation)
+  completed: Optional[int] = None
+  'Number of completed steps (for image generation streaming).'
+
+  total: Optional[int] = None
+  'Total number of steps (for image generation streaming).'
+
 
 class Message(SubscriptableBaseModel):
   """

tests/test_client.py

Lines changed: 109 additions & 0 deletions
@@ -568,6 +568,115 @@ class ResponseFormat(BaseModel):
   assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
 
 
+def test_client_generate_image(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a sunset over mountains',
+      'stream': False,
+      'width': 1024,
+      'height': 768,
+      'steps': 20,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy-image',
+      'image': PNG_BASE64,
+      'done': True,
+      'done_reason': 'stop',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy-image', 'a sunset over mountains', width=1024, height=768, steps=20)
+  assert response['model'] == 'dummy-image'
+  assert response['image'] == PNG_BASE64
+  assert response['done'] is True
+
+
+def test_client_generate_image_stream(httpserver: HTTPServer):
+  def stream_handler(_: Request):
+    def generate():
+      # Progress updates
+      for i in range(1, 4):
+        yield (
+          json.dumps(
+            {
+              'model': 'dummy-image',
+              'completed': i,
+              'total': 3,
+              'done': False,
+            }
+          )
+          + '\n'
+        )
+      # Final response with image
+      yield (
+        json.dumps(
+          {
+            'model': 'dummy-image',
+            'image': PNG_BASE64,
+            'done': True,
+            'done_reason': 'stop',
+          }
+        )
+        + '\n'
+      )
+
+    return Response(generate())
+
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a sunset over mountains',
+      'stream': True,
+      'width': 512,
+      'height': 512,
+    },
+  ).respond_with_handler(stream_handler)
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy-image', 'a sunset over mountains', stream=True, width=512, height=512)
+
+  parts = list(response)
+  # Check progress updates
+  assert parts[0]['completed'] == 1
+  assert parts[0]['total'] == 3
+  assert parts[0]['done'] is False
+  # Check final response
+  assert parts[-1]['image'] == PNG_BASE64
+  assert parts[-1]['done'] is True
+
+
+async def test_async_client_generate_image(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a robot painting',
+      'stream': False,
+      'width': 1024,
+      'height': 1024,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy-image',
+      'image': PNG_BASE64,
+      'done': True,
+    }
+  )
+
+  client = AsyncClient(httpserver.url_for('/'))
+  response = await client.generate('dummy-image', 'a robot painting', width=1024, height=1024)
+  assert response['model'] == 'dummy-image'
+  assert response['image'] == PNG_BASE64
+
+
 def test_client_pull(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/pull',
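
The new tests follow the suite's existing pytest-httpserver pattern; something like pytest tests/test_client.py -k image should select just the image generation tests, assuming the repository's usual test configuration (including its async plugin for the AsyncClient case).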
