Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update PR to use latest gen ai utils.
  • Loading branch information
DylanRussell committed Sep 5, 2025
commit 62e427415c5d4033b4fd7c6d2755fb792d5fdf41
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
dependencies = [
"opentelemetry-api ~= 1.28",
"opentelemetry-instrumentation == 0.58b0dev",
"opentelemetry-util-genai == 0.1b0.dev",
"opentelemetry-semantic-conventions == 0.58b0dev",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from opentelemetry.instrumentation._semconv import (
_OpenTelemetrySemanticConventionStability,
_OpenTelemetryStabilitySignalType,
_StabilityMode,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
Expand Down Expand Up @@ -125,9 +126,23 @@ def _instrument(self, **kwargs: Any):
sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.GEN_AI,
)
if sem_conv_opt_in_mode == _StabilityMode.DEFAULT:
# Type checker now knows sem_conv_opt_in_mode is a Literal[_StabilityMode.DEFAULT]
content_enabled = is_content_enabled(sem_conv_opt_in_mode)
elif sem_conv_opt_in_mode == _StabilityMode.GEN_AI_LATEST_EXPERIMENTAL:
# Type checker now knows it's the other literal
content_enabled = is_content_enabled(sem_conv_opt_in_mode)
else:
# Impossible to reach here, only 2 opt-in modes exist for GEN_AI.
raise ValueError(
f"Sem Conv opt in mode {sem_conv_opt_in_mode} not supported."
)

method_wrappers = MethodWrappers(
tracer, event_logger, is_content_enabled(), sem_conv_opt_in_mode
tracer,
event_logger,
content_enabled,
sem_conv_opt_in_mode,
)
for client_class, method_name, wrapper in _methods_to_wrap(
method_wrappers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
Any,
Awaitable,
Callable,
Literal,
MutableSequence,
Union,
cast,
overload,
)

from opentelemetry._events import EventLogger
Expand All @@ -38,6 +42,7 @@
response_to_events,
)
from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.util.genai.types import ContentCapturingMode

if TYPE_CHECKING:
from google.cloud.aiplatform_v1.services.prediction_service import client
Expand Down Expand Up @@ -94,12 +99,35 @@ def _extract_params(


class MethodWrappers:
@overload
def __init__(
self,
tracer: Tracer,
event_logger: EventLogger,
capture_content: ContentCapturingMode,
sem_conv_opt_in_mode: Literal[
_StabilityMode.GEN_AI_LATEST_EXPERIMENTAL
],
) -> None: ...

@overload
def __init__(
self,
tracer: Tracer,
event_logger: EventLogger,
capture_content: bool,
sem_conv_opt_in_mode: _StabilityMode,
sem_conv_opt_in_mode: Literal[_StabilityMode.DEFAULT],
) -> None: ...

def __init__(
self,
tracer: Tracer,
event_logger: EventLogger,
capture_content: Union[bool, ContentCapturingMode],
sem_conv_opt_in_mode: Union[
Literal[_StabilityMode.DEFAULT],
Literal[_StabilityMode.GEN_AI_LATEST_EXPERIMENTAL],
],
) -> None:
self.tracer = tracer
self.event_logger = event_logger
Expand All @@ -116,6 +144,7 @@ def __init__(
@contextmanager
def _with_new_instrumentation(
self,
capture_content: ContentCapturingMode,
instance: client.PredictionServiceClient
| client_v1beta1.PredictionServiceClient,
args: Any,
Expand Down Expand Up @@ -152,7 +181,7 @@ def handle_response(
create_operation_details_event(
api_endpoint=api_endpoint,
params=params,
capture_content=self.capture_content,
capture_content=capture_content,
response=response,
)
)
Expand All @@ -162,6 +191,7 @@ def handle_response(
@contextmanager
def _with_default_instrumentation(
self,
capture_content: bool,
instance: client.PredictionServiceClient
| client_v1beta1.PredictionServiceClient,
args: Any,
Expand All @@ -182,7 +212,7 @@ def _with_default_instrumentation(
attributes=span_attributes,
) as span:
for event in request_to_events(
params=params, capture_content=self.capture_content
params=params, capture_content=capture_content
):
self.event_logger.emit(event)

Expand All @@ -203,7 +233,7 @@ def handle_response(
)

for event in response_to_events(
response=response, capture_content=self.capture_content
response=response, capture_content=capture_content
):
self.event_logger.emit(event)

Expand All @@ -225,15 +255,17 @@ def generate_content(
| prediction_service_v1beta1.GenerateContentResponse
):
if self.sem_conv_opt_in_mode == _StabilityMode.DEFAULT:
capture_content_bool = cast(bool, self.capture_content)
with self._with_default_instrumentation(
instance, args, kwargs
capture_content_bool, instance, args, kwargs
) as handle_response:
response = wrapped(*args, **kwargs)
handle_response(response)
return response
else:
capture_content = cast(ContentCapturingMode, self.capture_content)
with self._with_new_instrumentation(
instance, args, kwargs
capture_content, instance, args, kwargs
) as handle_response:
try:
response = wrapped(*args, **kwargs)
Expand All @@ -243,7 +275,7 @@ def generate_content(
create_operation_details_event(
params=_extract_params(*args, **kwargs),
response=None,
capture_content=self.capture_content,
capture_content=capture_content,
api_endpoint=api_endpoint,
)
)
Expand All @@ -269,15 +301,17 @@ async def agenerate_content(
| prediction_service_v1beta1.GenerateContentResponse
):
if self.sem_conv_opt_in_mode == _StabilityMode.DEFAULT:
capture_content_bool = cast(bool, self.capture_content)
with self._with_default_instrumentation(
instance, args, kwargs
capture_content_bool, instance, args, kwargs
) as handle_response:
response = await wrapped(*args, **kwargs)
handle_response(response)
return response
else:
capture_content = cast(ContentCapturingMode, self.capture_content)
with self._with_new_instrumentation(
instance, args, kwargs
capture_content, instance, args, kwargs
) as handle_response:
try:
response = await wrapped(*args, **kwargs)
Expand All @@ -287,7 +321,7 @@ async def agenerate_content(
create_operation_details_event(
params=_extract_params(*args, **kwargs),
response=None,
capture_content=self.capture_content,
capture_content=capture_content,
api_endpoint=api_endpoint,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,25 @@
from os import environ
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Union,
cast,
overload,
)
from urllib.parse import urlparse

from google.protobuf import json_format

from opentelemetry._events import Event
from opentelemetry.instrumentation._semconv import (
_StabilityMode,
)
from opentelemetry.instrumentation.vertexai.events import (
ChoiceMessage,
ChoiceToolCall,
FinishReason,
assistant_event,
choice_event,
system_event,
Expand All @@ -49,6 +50,17 @@
gen_ai_attributes as GenAIAttributes,
)
from opentelemetry.semconv.attributes import server_attributes
from opentelemetry.util.genai.types import (
ContentCapturingMode,
FinishReason,
InputMessage,
MessagePart,
OutputMessage,
Text,
ToolCall,
ToolCallResponse,
)
from opentelemetry.util.genai.utils import get_content_capturing_mode
from opentelemetry.util.types import AnyValue, AttributeValue

if TYPE_CHECKING:
Expand All @@ -71,43 +83,6 @@
_MODEL = "model"


@dataclass(frozen=True)
class ToolCall:
type: Literal["tool_call"]
arguments: Any
name: str
id: Optional[str]


@dataclass(frozen=True)
class ToolCallResponse:
type: Literal["tool_call_response"]
response: Any
id: Optional[str]


@dataclass(frozen=True)
class TextPart:
type: Literal["text"]
content: str


MessagePart = Union[TextPart, ToolCall, ToolCallResponse, Any]


@dataclass()
class InputMessage:
role: str
parts: list[MessagePart]


@dataclass()
class OutputMessage:
role: str
parts: list[MessagePart]
finish_reason: Union[str, FinishReason]


@dataclass(frozen=True)
class GenerateContentParams:
model: str
Expand Down Expand Up @@ -245,12 +220,29 @@ def _get_model_name(model: str) -> str:
)


def is_content_enabled() -> bool:
capture_content = environ.get(
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
)
@overload
def is_content_enabled(
mode: Literal[_StabilityMode.GEN_AI_LATEST_EXPERIMENTAL],
) -> ContentCapturingMode: ...


return capture_content.lower() == "true"
@overload
def is_content_enabled(mode: Literal[_StabilityMode.DEFAULT]) -> bool: ...


def is_content_enabled(
    mode: Union[
        Literal[_StabilityMode.DEFAULT],
        Literal[_StabilityMode.GEN_AI_LATEST_EXPERIMENTAL],
    ],
) -> Union[bool, ContentCapturingMode]:
    """Resolve the content-capture setting for the given semconv mode.

    For the experimental GEN_AI semconv mode, defers to the shared
    genai-utils helper and returns a :class:`ContentCapturingMode`.
    For the DEFAULT (legacy) mode, returns a plain bool driven by the
    OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT environment
    variable ("true", case-insensitive, enables capture).
    """
    if mode != _StabilityMode.DEFAULT:
        # Experimental semconv path: the richer capture mode comes from
        # the shared opentelemetry-util-genai configuration helper.
        return get_content_capturing_mode()
    # Legacy path: env-var opt-in; any value other than "true"
    # (case-insensitive) leaves content capture disabled.
    raw_value = environ.get(
        OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
    )
    return raw_value.lower() == "true"


def get_span_name(span_attributes: Mapping[str, AttributeValue]) -> str:
Expand Down Expand Up @@ -322,7 +314,7 @@ def create_operation_details_event(
| prediction_service_v1beta1.GenerateContentResponse
| None,
params: GenerateContentParams,
capture_content: bool,
capture_content: ContentCapturingMode,
) -> Event:
event = Event(name="gen_ai.client.inference.operation.details")
attributes: dict[str, AnyValue] = {
Expand All @@ -331,7 +323,10 @@ def create_operation_details_event(
**(get_genai_response_attributes(response) if response else {}),
}
event.attributes = attributes
if not capture_content:
if capture_content in {
ContentCapturingMode.NO_CONTENT,
ContentCapturingMode.SPAN_ONLY,
}:
return event
if params.system_instruction:
attributes["gen_ai.system_instructions"] = [
Expand Down Expand Up @@ -381,7 +376,6 @@ def _convert_content_to_message(
part = part.function_response
parts.append(
ToolCallResponse(
type="tool_call_response",
id=f"{part.name}_{idx}",
response=json_format.MessageToDict(part._pb.response), # type: ignore[reportUnknownMemberType]
)
Expand All @@ -390,7 +384,6 @@ def _convert_content_to_message(
part = part.function_call
parts.append(
ToolCall(
type="tool_call",
id=f"{part.name}_{idx}",
name=part.name,
arguments=json_format.MessageToDict(
Expand All @@ -399,7 +392,7 @@ def _convert_content_to_message(
)
)
elif "text" in part:
parts.append(TextPart(type="text", content=part.text))
parts.append(Text(content=part.text))
else:
dict_part = type(part).to_dict( # type: ignore[reportUnknownMemberType]
part, always_print_fields_with_no_presence=False
Expand Down
Loading
Loading