24 | 24 | from typing import Any |
25 | 25 | |
26 | 26 | from google import genai |
| 27 | +import openai |
27 | 28 | import requests |
28 | 29 | from typing_extensions import override |
29 | 30 | import yaml |
@@ -383,7 +384,174 @@ def infer( |
383 | 384 | yield [result] |
384 | 385 | |
385 | 386 | def parse_output(self, output: str) -> Any: |
386 | | - """Parses Gemini output as JSON or YAML.""" |
| 387 | + """Parses Gemini output as JSON or YAML. |
| 388 | + |
| 389 | + Note: This expects raw JSON/YAML without code fences. |
| 390 | + Code fence extraction is handled by resolver.py. |
| 391 | + """ |
| 392 | + try: |
| 393 | + if self.format_type == data.FormatType.JSON: |
| 394 | + return json.loads(output) |
| 395 | + else: |
| 396 | + return yaml.safe_load(output) |
| 397 | + except Exception as e: |
| 398 | + raise ValueError( |
| 399 | + f'Failed to parse output as {self.format_type.name}: {e}' |
| 400 | + ) from e |
| 401 | + |
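For reference, a minimal sketch of the `parse_output` contract (illustrative only: `model` stands for a configured instance of the Gemini backend above with `format_type=data.FormatType.JSON`):

```python
# Raw JSON parses directly; code-fence stripping is NOT done here,
# since that is handled upstream by resolver.py.
model.parse_output('{"name": "Ada"}')  # -> {'name': 'Ada'}

# Fenced output therefore fails to parse and is re-raised as ValueError.
fenced = '`' * 3 + 'json\n{"name": "Ada"}\n' + '`' * 3
model.parse_output(fenced)  # raises ValueError
```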
| 402 | + |
| 403 | +@dataclasses.dataclass(init=False) |
| 404 | +class OpenAILanguageModel(BaseLanguageModel): |
| 405 | + """Language model inference using OpenAI's API with structured output.""" |
| 406 | + |
| 407 | + model_id: str = 'gpt-4o-mini' |
| 408 | + api_key: str | None = None |
| 409 | + organization: str | None = None |
| 410 | + format_type: data.FormatType = data.FormatType.JSON |
| 411 | + temperature: float = 0.0 |
| 412 | + max_workers: int = 10 |
| 413 | + _client: openai.OpenAI | None = dataclasses.field( |
| 414 | + default=None, repr=False, compare=False |
| 415 | + ) |
| 416 | + _extra_kwargs: dict[str, Any] = dataclasses.field( |
| 417 | + default_factory=dict, repr=False, compare=False |
| 418 | + ) |
| 419 | + |
| 420 | + def __init__( |
| 421 | + self, |
| 422 | + model_id: str = 'gpt-4o-mini', |
| 423 | + api_key: str | None = None, |
| 424 | + organization: str | None = None, |
| 425 | + format_type: data.FormatType = data.FormatType.JSON, |
| 426 | + temperature: float = 0.0, |
| 427 | + max_workers: int = 10, |
| 428 | + **kwargs, |
| 429 | + ) -> None: |
| 430 | + """Initialize the OpenAI language model. |
| 431 | + |
| 432 | + Args: |
| 433 | + model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o'). |
| 434 | + api_key: API key for OpenAI service. |
| 435 | + organization: Optional OpenAI organization ID. |
| 436 | + format_type: Output format (JSON or YAML). |
| 437 | + temperature: Sampling temperature. |
| 438 | + max_workers: Maximum number of parallel API calls. |
| 439 | + **kwargs: Extra parameters accepted (and stored in ``_extra_kwargs``) so |
| 440 | + callers can pass a superset of back-end arguments without ``TypeError``. |
| 441 | + """ |
| 442 | + self.model_id = model_id |
| 443 | + self.api_key = api_key |
| 444 | + self.organization = organization |
| 445 | + self.format_type = format_type |
| 446 | + self.temperature = temperature |
| 447 | + self.max_workers = max_workers |
| 448 | + self._extra_kwargs = dict(kwargs) |
| 449 | + |
| 450 | + if not self.api_key: |
| 451 | + raise ValueError('API key not provided.') |
| 452 | + |
| 453 | + # Initialize the OpenAI client |
| 454 | + self._client = openai.OpenAI( |
| 455 | + api_key=self.api_key, organization=self.organization |
| 456 | + ) |
| 457 | + |
| 458 | + super().__init__( |
| 459 | + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) |
| 460 | + ) |
| 461 | + |
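A hedged construction sketch (the environment variable is a placeholder source for the key, not something this module reads itself; the constructor raises `ValueError` if no key is passed):

```python
import os

# Illustrative only: build the OpenAI backend with an explicit key.
model = OpenAILanguageModel(
    model_id='gpt-4o-mini',
    api_key=os.environ['OPENAI_API_KEY'],  # placeholder key source
    format_type=data.FormatType.JSON,
    temperature=0.0,
    max_workers=4,
)
```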
| 462 | + def _process_single_prompt(self, prompt: str, config: dict[str, Any]) -> ScoredOutput: |
| 463 | + """Process a single prompt and return a ScoredOutput.""" |
| 464 | + try: |
| 465 | + # Prepare the system message for structured output |
| 466 | + system_message = '' |
| 467 | + if self.format_type == data.FormatType.JSON: |
| 468 | + system_message = ( |
| 469 | + 'You are a helpful assistant that responds in JSON format.' |
| 470 | + ) |
| 471 | + elif self.format_type == data.FormatType.YAML: |
| 472 | + system_message = ( |
| 473 | + 'You are a helpful assistant that responds in YAML format.' |
| 474 | + ) |
| 475 | + |
| 476 | + # Create the chat completion using the v1.x client API |
| 477 | + response = self._client.chat.completions.create( |
| 478 | + model=self.model_id, |
| 479 | + messages=[ |
| 480 | + {'role': 'system', 'content': system_message}, |
| 481 | + {'role': 'user', 'content': prompt}, |
| 482 | + ], |
| 483 | + temperature=config.get('temperature', self.temperature), |
| 484 | + max_tokens=config.get('max_output_tokens'), |
| 485 | + top_p=config.get('top_p'), |
| 486 | + n=1, |
| 487 | + ) |
| 488 | + |
| 489 | + # Extract the response text using the v1.x response format |
| 490 | + output_text = response.choices[0].message.content |
| 491 | + |
| 492 | + return ScoredOutput(score=1.0, output=output_text) |
| 493 | + |
| 494 | + except Exception as e: |
| 495 | + raise InferenceOutputError(f'OpenAI API error: {e}') from e |
| 496 | + |
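Condensed to a standalone call, the request this method issues looks roughly like the following (real `openai` v1.x client API; the key and user prompt are placeholders, and the system message is the JSON-mode one built above):

```python
import openai

client = openai.OpenAI(api_key='sk-...')  # placeholder key
resp = client.chat.completions.create(
    model='gpt-4o-mini',
    messages=[
        {'role': 'system',
         'content': 'You are a helpful assistant that responds in JSON format.'},
        {'role': 'user', 'content': 'Extract the city from: Paris, France.'},
    ],
    temperature=0.0,
    n=1,
)
print(resp.choices[0].message.content)
```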
| 497 | + def infer( |
| 498 | + self, batch_prompts: Sequence[str], **kwargs |
| 499 | + ) -> Iterator[Sequence[ScoredOutput]]: |
| 500 | + """Runs inference on a list of prompts via OpenAI's API. |
| 501 | + |
| 502 | + Args: |
| 503 | + batch_prompts: A list of string prompts. |
| 504 | + **kwargs: Additional generation parameters (temperature, top_p, max_output_tokens). |
| 505 | + |
| 506 | + Yields: |
| 507 | + Single-element lists of ScoredOutput, one per prompt, in input order. |
| 508 | + """ |
| 509 | + config = { |
| 510 | + 'temperature': kwargs.get('temperature', self.temperature), |
| 511 | + } |
| 512 | + if 'max_output_tokens' in kwargs: |
| 513 | + config['max_output_tokens'] = kwargs['max_output_tokens'] |
| 514 | + if 'top_p' in kwargs: |
| 515 | + config['top_p'] = kwargs['top_p'] |
| 516 | + |
| 517 | + # Parallelize when there are multiple prompts and multiple workers |
| 518 | + if len(batch_prompts) > 1 and self.max_workers > 1: |
| 519 | + with concurrent.futures.ThreadPoolExecutor( |
| 520 | + max_workers=min(self.max_workers, len(batch_prompts)) |
| 521 | + ) as executor: |
| 522 | + future_to_index = { |
| 523 | + executor.submit( |
| 524 | + self._process_single_prompt, prompt, config.copy() |
| 525 | + ): i |
| 526 | + for i, prompt in enumerate(batch_prompts) |
| 527 | + } |
| 528 | + |
| 529 | + results: list[ScoredOutput | None] = [None] * len(batch_prompts) |
| 530 | + for future in concurrent.futures.as_completed(future_to_index): |
| 531 | + index = future_to_index[future] |
| 532 | + try: |
| 533 | + results[index] = future.result() |
| 534 | + except Exception as e: |
| 535 | + raise InferenceOutputError( |
| 536 | + f'Parallel inference error: {e}' |
| 537 | + ) from e |
| 538 | + |
| 539 | + for result in results: |
| 540 | + if result is None: |
| 541 | + raise InferenceOutputError('Failed to process one or more prompts') |
| 542 | + yield [result] |
| 543 | + else: |
| 544 | + # Fall back to sequential processing (single prompt or single worker) |
| 545 | + for prompt in batch_prompts: |
| 546 | + result = self._process_single_prompt(prompt, config.copy()) |
| 547 | + yield [result] |
| 548 | + |
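A usage sketch for the batch path, reusing the `model` built earlier (prompt strings are illustrative; each yielded item is a single-element list of `ScoredOutput`, and input order is preserved even on the thread-pool path):

```python
prompts = [
    'Return {"city": ...} for: Paris is the capital of France.',
    'Return {"city": ...} for: Tokyo hosted the 2020 Olympics.',
]
# score is fixed at 1.0 by _process_single_prompt.
for scored in model.infer(prompts, temperature=0.2, top_p=0.95):
    print(scored[0].score, model.parse_output(scored[0].output))
```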
| 549 | + def parse_output(self, output: str) -> Any: |
| 550 | + """Parses OpenAI output as JSON or YAML. |
| 551 | + |
| 552 | + Note: This expects raw JSON/YAML without code fences. |
| 553 | + Code fence extraction is handled by resolver.py. |
| 554 | + """ |
387 | 555 | try: |
388 | 556 | if self.format_type == data.FormatType.JSON: |
389 | 557 | return json.loads(output) |