bigcat88's progress on adding Google Gemini Image node
Kosinkadink committed Aug 26, 2025
commit 50ccac9b28aaed319d2be1eb22a0b2fb861ef5c2
322 changes: 228 additions & 94 deletions comfy_api_nodes/nodes_gemini.py
@@ -4,7 +4,6 @@
"""
from __future__ import annotations


import base64
import json
import time
import os
from io import BytesIO

# Third-party imports used by get_image_from_response below.
import numpy as np
from PIL import Image
@@ -75,6 +74,131 @@ def get_gemini_endpoint(
)


def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.

Args:
image_input: Batch of image tensors from ComfyUI.

Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts


def create_text_part(text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.

Args:
text: The text content to include in the request.

Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
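
# A minimal usage sketch for the two helpers above (assumes `images` is a
# ComfyUI IMAGE tensor of shape [B, H, W, C] with values in [0, 1]):
#
#   parts = [create_text_part("Describe these images")]
#   parts.extend(create_image_parts(images))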


def get_parts_from_response(
response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.

Args:
response: The API response from Gemini.

Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts


def get_parts_by_type(
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.

Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).

Returns:
List of response parts matching the requested type.
"""
parts = []
for part in get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
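
# Example: selecting response parts by type (a sketch; `response` is a
# completed GeminiGenerateContentResponse):
#
#   text_parts = get_parts_by_type(response, "text")
#   image_parts = get_parts_by_type(response, GeminiMimeType.image_png)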


def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.

Args:
response: The API response from Gemini.

Returns:
Combined text from all text parts in the response.
"""
parts = get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])


def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
    """
    Decode all inline image parts from the response into a batched tensor.

    Unlike URL-based image APIs, Gemini returns images as base64-encoded
    inline data, so no download step is needed.

    Args:
        response: The API response from Gemini.

    Returns:
        Batched image tensor of shape [B, H, W, C] with values in [0, 1].
    """
    image_tensors: list[torch.Tensor] = []
    for part in get_parts_by_type(response, GeminiMimeType.image_png):
        img_bytes = base64.b64decode(part.inlineData.data)
        pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
        arr = np.asarray(pil_img).astype(np.float32) / 255.0
        image_tensors.append(torch.from_numpy(arr))
    if not image_tensors:
        raise ValueError("Gemini response contained no image data.")
    return torch.stack(image_tensors, dim=0)


class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
@@ -159,59 +283,6 @@ def INPUT_TYPES(cls) -> InputTypeDict:
CATEGORY = "api node/text/Gemini"
API_NODE = True

def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.

Args:
response: The API response from Gemini.

Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts

def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.

Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).

Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts

def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.

Args:
response: The API response from Gemini.

Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])

def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
@@ -271,43 +342,6 @@ def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
)
return audio_parts

def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.

Args:
image_input: Batch of image tensors from ComfyUI.

Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts

def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.

Args:
text: The text content to include in the request.

Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)

async def api_call(
self,
prompt: str,
@@ -323,11 +357,11 @@ async def api_call(
validate_string(prompt, strip_whitespace=False)

# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
parts: list[GeminiPart] = [create_text_part(prompt)]

# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
image_parts = create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
@@ -351,7 +385,7 @@ async def api_call(
).execute()

# Get result output
output_text = self.get_text_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
@@ -462,12 +496,112 @@ def prepare_files(
return (files,)


class GeminiImage(ComfyNodeABC):
    """
    Node to generate or edit images with a Gemini model.
    """

@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2**31 - 1,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "not implemented yet in backend",
},
),
            # TODO: we can add this parameter later
# "n": (
# IO.INT,
# {
# "default": 1,
# "min": 1,
# "max": 8,
# "step": 1,
# "display": "number",
# "tooltip": "How many images to generate",
# },
# ),
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference images (Max 3).",
},
),
},
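        # Hidden inputs are supplied by the ComfyUI runtime (API credentials
        # and the node's unique id); they are not wired by the user.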
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

RETURN_TYPES = (IO.IMAGE, IO.STRING)
FUNCTION = "api_call"
CATEGORY = "api node/image/Gemini"
DESCRIPTION = "Edit images synchronously via Google API."
API_NODE = True

async def api_call(
self,
prompt,
        n=1,  # reserved for the future "n" input (see the TODO in INPUT_TYPES)
image=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [create_text_part(prompt)]

if image is not None:
image_parts = create_image_parts(image)
parts.extend(image_parts)
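
        # For reference, the request built below serializes to a JSON body
        # along these lines (a sketch; the pydantic models used here mirror
        # the Gemini REST generateContent schema):
        #
        #   {"contents": [{"role": "user", "parts": [
        #       {"text": "<prompt>"},
        #       {"inlineData": {"mimeType": "image/png", "data": "<base64>"}}
        #   ]}]}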

response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/gemini-2.5-flash-image-preview",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
),
request=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
)
]
),
auth_kwargs=kwargs,
).execute()

output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
return output_image, output_text


NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiInputFiles": GeminiInputFiles,
"GeminiImage": GeminiImage,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiInputFiles": "Gemini Input Files",
"GeminiImage": "Gemini Image",
}