Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix typecheck
  • Loading branch information
DylanRussell committed Aug 22, 2025
commit c54e74ea9f499f7cff20ad300b902e35a853f64d
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pylint==3.0.2
httpretty==1.1.4
pyright==v1.1.396
pyright==v1.1.404
sphinx==7.1.2
sphinx-rtd-theme==2.0.0rc4
sphinx-autodoc-typehints==1.25.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,13 @@ def generate_content(
try:
response = wrapped(*args, **kwargs)
except Exception as e:
api_endpoint: str = instance.api_endpoint # type: ignore[reportUnknownMemberType]
self.event_logger.emit(
create_operation_details_event(
params=_extract_params(*args, **kwargs),
response=None,
capture_content=self.capture_content,
api_endpoint=instance.api_endpoint,
api_endpoint=api_endpoint,
)
)
raise e
Expand Down Expand Up @@ -281,12 +282,13 @@ async def agenerate_content(
try:
response = await wrapped(*args, **kwargs)
except Exception as e:
api_endpoint: str = instance.api_endpoint # type: ignore[reportUnknownMemberType]
self.event_logger.emit(
create_operation_details_event(
params=_extract_params(*args, **kwargs),
response=None,
capture_content=self.capture_content,
api_endpoint=instance.api_endpoint,
api_endpoint=api_endpoint,
)
)
raise e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
from os import environ
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Union,
cast,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -65,6 +69,43 @@
_MODEL = "model"


@dataclass(frozen=True)
class ToolCall:
type: Literal["tool_call"]
arguments: Any
name: str
id: Optional[str]


@dataclass(frozen=True)
class ToolCallResponse:
type: Literal["tool_call_response"]
response: Any
id: Optional[str]


@dataclass(frozen=True)
class TextPart:
type: Literal["text"]
content: str


MessagePart = Union[TextPart, ToolCall, ToolCallResponse, Any]


@dataclass()
class InputMessage(Any):
role: str
parts: list[MessagePart]


@dataclass()
class OutputMessage(Any):
role: str
parts: list[MessagePart]
finish_reason: Union[str, FinishReason]


@dataclass(frozen=True)
class GenerateContentParams:
model: str
Expand Down Expand Up @@ -256,7 +297,7 @@ def request_to_events(
id_=f"{function_response.name}_{idx}",
role=content.role,
content=json_format.MessageToDict(
function_response._pb.response
function_response._pb.response # type: ignore[reportUnknownMemberType]
)
if capture_content
else None,
Expand Down Expand Up @@ -290,15 +331,15 @@ def create_operation_details_event(
event.attributes = attributes
if not capture_content:
return event

attributes["gen_ai.system_instructions"] = [
{
"type": "text",
"content": "\n".join(
part.text for part in params.system_instruction.parts
),
}
]
if params.system_instruction:
attributes["gen_ai.system_instructions"] = [
{
"type": "text",
"content": "\n".join(
part.text for part in params.system_instruction.parts
),
}
]
if params.contents:
attributes["gen_ai.input.messages"] = [
_convert_content_to_message(content) for content in params.contents
Expand All @@ -313,47 +354,50 @@ def create_operation_details_event(
def _convert_response_to_output_messages(
response: prediction_service.GenerateContentResponse
| prediction_service_v1beta1.GenerateContentResponse,
) -> list:
output_messages = []
) -> list[OutputMessage]:
output_messages: list[OutputMessage] = []
for candidate in response.candidates:
message = _convert_content_to_message(candidate.content)
message["finish_reason"] = _map_finish_reason(candidate.finish_reason)
message.finish_reason = _map_finish_reason(candidate.finish_reason)
output_messages.append(message)
return output_messages


def _convert_content_to_message(content: content.Content) -> dict:
message = {"role": content.role, "parts": []}
def _convert_content_to_message(
content: content.Content | content_v1beta1.Content,
) -> InputMessage:
parts: MessagePart = []
message = InputMessage(role=content.role, parts=parts)
for idx, part in enumerate(content.parts):
if "function_response" in part:
part = part.function_response
message["parts"].append(
{
"type": "tool_call_response",
"id": f"{part.name}_{idx}",
"response": json_format.MessageToDict(part._pb.response),
}
parts.append(
ToolCallResponse(
type="tool_call_response",
id=f"{part.name}_{idx}",
response=json_format.MessageToDict(part._pb.response), # type: ignore[reportUnknownMemberType]
)
)
elif "function_call" in part:
part = part.function_call
message["parts"].append(
{
"type": "tool_call",
"id": f"{part.name}_{idx}",
"name": part.name,
"response": json_format.MessageToDict(
part._pb.args,
parts.append(
ToolCall(
type="tool_call",
id=f"{part.name}_{idx}",
name=part.name,
arguments=json_format.MessageToDict(
part._pb.args, # type: ignore[reportUnknownMemberType]
),
}
)
)
elif "text" in part:
message["parts"].append({"type": "text", "content": part.text})
part = part.text
parts.append(TextPart(type="text", content=part.text))
else:
message["parts"].append(
type(part).to_dict(part, always_print_fields_with_no_presence=False)
dict_part = type(part).to_dict( # type: ignore[reportUnknownMemberType]
part, always_print_fields_with_no_presence=False
)
message["parts"][-1]["type"] = type(part)
dict_part["type"] = type(part)
parts.append(dict_part)
return message


Expand Down Expand Up @@ -401,7 +445,7 @@ def _extract_tool_calls(
function=ChoiceToolCall.Function(
name=part.function_call.name,
arguments=json_format.MessageToDict(
part.function_call._pb.args
part.function_call._pb.args # type: ignore[reportUnknownMemberType]
)
if capture_content
else None,
Expand All @@ -420,7 +464,9 @@ def _parts_to_any_value(
return [
cast(
"dict[str, AnyValue]",
type(part).to_dict(part, always_print_fields_with_no_presence=False), # type: ignore[reportUnknownMemberType]
type(part).to_dict( # type: ignore[reportUnknownMemberType]
part, always_print_fields_with_no_presence=False
),
)
for part in parts
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vertexai.preview.generative_models import (
GenerativeModel as PreviewGenerativeModel,
)

from opentelemetry.instrumentation.vertexai import VertexAIInstrumentor
from opentelemetry.sdk._logs._internal.export.in_memory_log_exporter import (
InMemoryLogExporter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
OTEL_SEMCONV_STABILITY_OPT_IN = "OTEL_SEMCONV_STABILITY_OPT_IN"


class _OpenTelemetryStabilitySignalType:
class _OpenTelemetryStabilitySignalType(Enum):
HTTP = "http"
DATABASE = "database"
GEN_AI = "gen_ai"
Expand Down
Loading