Made Google Gemini Image node functional
Kosinkadink committed Aug 26, 2025
commit eabff3bed71e3e667950ca04de31d78cdb9e31ea
19 changes: 19 additions & 0 deletions comfy_api_nodes/apis/gemini_api.py
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import List, Optional

from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel


class GeminiImageGenerationConfig(GeminiGenerationConfig):
    responseModalities: Optional[List[str]] = None


class GeminiImageGenerateContentRequest(BaseModel):
    contents: List[GeminiContent]
    generationConfig: Optional[GeminiImageGenerationConfig] = None
    safetySettings: Optional[List[GeminiSafetySetting]] = None
    systemInstruction: Optional[GeminiSystemInstructionContent] = None
    tools: Optional[List[GeminiTool]] = None
    videoMetadata: Optional[GeminiVideoMetadata] = None
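
As a quick illustration (not part of the diff), the new request model can be built and serialized on its own. GeminiPart(text=...) is an assumption here, mirroring what create_text_part in nodes_gemini.py appears to produce; model_dump_json assumes pydantic v2 (on v1, .json() is the equivalent):

from comfy_api_nodes.apis import GeminiContent, GeminiPart
from comfy_api_nodes.apis.gemini_api import (
    GeminiImageGenerationConfig,
    GeminiImageGenerateContentRequest,
)

# One user turn plus a generation config asking Gemini to return both
# text and image parts; this matches the body the node below sends.
request = GeminiImageGenerateContentRequest(
    contents=[GeminiContent(role="user", parts=[GeminiPart(text="A watercolor fox")])],
    generationConfig=GeminiImageGenerationConfig(responseModalities=["TEXT", "IMAGE"]),
)
print(request.model_dump_json(exclude_none=True))  # pydantic v2; use .json() on v1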
178 changes: 122 additions & 56 deletions comfy_api_nodes/nodes_gemini.py
@@ -8,6 +8,8 @@
import time
import os
import uuid
import base64
from io import BytesIO
from enum import Enum
from typing import Optional, Literal

@@ -24,6 +26,7 @@
    GeminiPart,
    GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
@@ -34,6 +37,7 @@
    audio_to_base64_string,
    video_to_base64_string,
    tensor_to_base64_string,
    bytesio_to_image_tensor,
)


@@ -52,6 +56,14 @@ class GeminiModel(str, Enum):
    gemini_2_5_flash = "gemini-2.5-flash"


class GeminiImageModel(str, Enum):
    """
    Gemini Image Model Names allowed by comfy-api
    """

    gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"


def get_gemini_endpoint(
    model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
@@ -74,6 +86,28 @@ def get_gemini_endpoint(
    )


def get_gemini_image_endpoint(
    model: GeminiImageModel,
) -> ApiEndpoint[GeminiImageGenerateContentRequest, GeminiGenerateContentResponse]:
    """
    Get the API endpoint for a given Gemini image model.

    Args:
        model: The Gemini image model to use, either as enum or string value.

    Returns:
        ApiEndpoint configured for the specific Gemini image model.
    """
    if isinstance(model, str):
        model = GeminiImageModel(model)
    return ApiEndpoint(
        path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
        method=HttpMethod.POST,
        request_model=GeminiImageGenerateContentRequest,
        response_model=GeminiGenerateContentResponse,
    )


def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
@@ -171,32 +205,14 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:

def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
    image_tensors: list[torch.Tensor] = []
    # TODO:
    """
    TODO something like this but without download but getting it from response:

    # Process each image in the data array
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
        for img_data in data:
            img_bytes: bytes
            if img_data.b64_json:
                img_bytes = base64.b64decode(img_data.b64_json)
            elif img_data.url:
                if node_id:
                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
                async with session.get(img_data.url) as resp:
                    if resp.status != 200:
                        raise ValueError("Failed to download generated image")
                    img_bytes = await resp.read()
            else:
                raise ValueError("Invalid image payload – neither URL nor base64 data present.")

            pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
            arr = np.asarray(pil_img).astype(np.float32) / 255.0
            image_tensors.append(torch.from_numpy(arr))

    """
    return torch.stack(image_tensors, dim=0)
    parts = get_parts_by_type(response, "image/png")
    for part in parts:
        image_data = base64.b64decode(part.inlineData.data)
        returned_image = bytesio_to_image_tensor(BytesIO(image_data))
        image_tensors.append(returned_image)
    if len(image_tensors) == 0:
        return torch.zeros((1, 1024, 1024, 4))

Contributor: That's pretty large

Member Author: It's the resolution of the image that Gemini normally returns; I tried to keep it the same so that previews don't get resized too much.

    return torch.cat(image_tensors, dim=0)
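
For context on the thread above: bytesio_to_image_tensor is expected to return a (1, H, W, C) float tensor, so torch.cat along dim 0 batches multiple returned images, and the all-zero (1, 1024, 1024, 4) fallback matches the resolution Gemini usually returns. A standalone sketch of the same decode step, using PIL and numpy directly in place of bytesio_to_image_tensor (assumed behavior, modeled on the removed TODO block above):

import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image

def decode_inline_image(b64_data: str) -> torch.Tensor:
    """Decode a base64 PNG from a Gemini inlineData part into a BHWC tensor."""
    pil_img = Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGBA")
    arr = np.asarray(pil_img).astype(np.float32) / 255.0  # HWC, values in [0, 1]
    return torch.from_numpy(arr).unsqueeze(0)  # add batch dim: (1, H, W, 4)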


class GeminiNode(ComfyNodeABC):
@@ -497,7 +513,14 @@ def prepare_files(


class GeminiImage(ComfyNodeABC):
"""
Node to generate text and image responses from a Gemini model.

This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, files) to generate coherent
text and image responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
@@ -510,18 +533,38 @@ def INPUT_TYPES(cls) -> InputTypeDict:
"tooltip": "Text prompt for generation",
},
),
            },
            "optional": {
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
},
),
"seed": (
IO.INT,
{
"default": 0,
"default": 42,
"min": 0,
"max": 2**31 - 1,
"step": 1,
"display": "number",
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "not implemented yet in backend",
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
            },
            "optional": {
                "images": (
                    IO.IMAGE,
                    {
                        "default": None,
                        "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
                    },
                ),
                "files": (
                    "GEMINI_INPUT_FILES",
                    {
                        "default": None,
                        "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
                    },
                ),
                # TODO: we can add this parameter later
@@ -536,13 +579,6 @@ def INPUT_TYPES(cls) -> InputTypeDict:
# "tooltip": "How many images to generate",
# },
# ),
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference images (Max 3).",
},
),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
@@ -559,49 +595,79 @@

    async def api_call(
        self,
        prompt,
        prompt: str,
        model: GeminiImageModel,
        images: Optional[IO.IMAGE] = None,
        files: Optional[list[GeminiPart]] = None,
        n=1,
        image=None,
        unique_id: Optional[str] = None,
        **kwargs,
    ):
        # Validate inputs
        validate_string(prompt, strip_whitespace=True, min_length=1)
        # Create parts list with text prompt as the first part
        parts: list[GeminiPart] = [create_text_part(prompt)]

        if image is not None:
            image_parts = create_image_parts(image)
        # Add other modal parts
        if images is not None:
            image_parts = create_image_parts(images)
            parts.extend(image_parts)
        if files is not None:
            parts.extend(files)

        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path=f"{GEMINI_BASE_ENDPOINT}/gemini-2.5-flash-image-preview",
                method=HttpMethod.POST,
                request_model=GeminiGenerateContentRequest,
                response_model=GeminiGenerateContentResponse,
            ),
            request=GeminiGenerateContentRequest(
            endpoint=get_gemini_image_endpoint(model),
            request=GeminiImageGenerateContentRequest(
                contents=[
                    GeminiContent(
                        role="user",
                        parts=parts,
                    )
                ]
            ),
                ],
                generationConfig=GeminiImageGenerationConfig(
                    responseModalities=["TEXT", "IMAGE"]
                )
            ),
            auth_kwargs=kwargs,
        ).execute()

        output_image = get_image_from_response(response)
        output_text = get_text_from_response(response)
        return output_image, output_text
        if unique_id and output_text:
            # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
            render_spec = {
                "node_id": unique_id,
                "component": "ChatHistoryWidget",
                "props": {
                    "history": json.dumps(
                        [
                            {
                                "prompt": prompt,
                                "response": output_text,
                                "response_id": str(uuid.uuid4()),
                                "timestamp": time.time(),
                            }
                        ]
                    ),
                },
            }
            PromptServer.instance.send_sync(
                "display_component",
                render_spec,
            )

        output_text = output_text or "Empty response from Gemini model..."
        return (output_image, output_text,)
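
Putting the pieces together, the request path can be exercised outside the node machinery roughly as follows (a sketch: the placeholder token stands in for credentials that ComfyUI injects through the hidden AUTH_TOKEN_COMFY_ORG input, and execute() must be awaited from an async context):

# Hypothetical standalone driver, for illustration only.
endpoint = get_gemini_image_endpoint(GeminiImageModel.gemini_2_5_flash_image_preview)
operation = SynchronousOperation(
    endpoint=endpoint,
    request=GeminiImageGenerateContentRequest(
        contents=[
            GeminiContent(role="user", parts=[create_text_part("a red panda, studio lighting")]),
        ],
        generationConfig=GeminiImageGenerationConfig(responseModalities=["TEXT", "IMAGE"]),
    ),
    auth_kwargs={"auth_token": "<token>"},  # placeholder credential
)
# response = await operation.execute()
# image_tensor = get_image_from_response(response)
# caption = get_text_from_response(response)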


NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiImageNode": GeminiImage,
"GeminiInputFiles": GeminiInputFiles,
"GeminiImage": GeminiImage,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiImageNode": "Google Gemini Image",
"GeminiInputFiles": "Gemini Input Files",
"GeminiImage": "Gemini Image",
}