Commit 8289b3a

feat: Add OpenAI language model support

1 parent 75a6f12 · commit 8289b3a

4 files changed (+312 −6): README.md, langextract/inference.py, pyproject.toml, tests/inference_test.py

README.md

Lines changed: 23 additions & 1 deletion
````diff
@@ -190,7 +190,7 @@ docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_sc
 
 ## API Key Setup for Cloud Models
 
-When using LangExtract with cloud-hosted models (like Gemini), you'll need to
+When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to
 set up an API key. On-device models don't require an API key. For developers
 using local LLMs, LangExtract offers built-in support for Ollama and can be
 extended to other third-party APIs by updating the inference endpoints.
@@ -201,6 +201,7 @@ Get API keys from:
 
 * [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models
 * [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use
+* [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models
 
 ### Setting up API key in your environment
 
@@ -250,6 +251,27 @@ result = lx.extract(
 )
 ```
 
+## Using OpenAI Models
+
+LangExtract also supports OpenAI models. Example OpenAI configuration:
+
+```python
+from langextract.inference import OpenAILanguageModel
+
+result = lx.extract(
+    text_or_documents=input_text,
+    prompt_description=prompt,
+    examples=examples,
+    language_model_type=OpenAILanguageModel,
+    model_id="gpt-4o",
+    api_key=os.environ.get('OPENAI_API_KEY'),
+    fence_output=True,
+    use_schema_constraints=False
+)
+```
+
+Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet.
+
 ## More Examples
 
 Additional examples of LangExtract in action:
````
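
A note on `fence_output=True`: since schema constraints are unavailable for OpenAI, the model's reply is expected to arrive wrapped in a Markdown code fence, which LangExtract's resolver strips before parsing. The helper below is an illustrative sketch of that idea only, not the actual resolver.py code:

```python
import json

def strip_fences(reply: str) -> str:
    """Removes a leading ```json fence and trailing ``` fence if present."""
    text = reply.strip()
    if text.startswith("```"):
        text = text.split("\n", 1)[1]    # drop the opening fence line
        text = text.rsplit("```", 1)[0]  # drop the closing fence
    return text.strip()

# A typical fenced reply from a chat model:
fenced_reply = '```json\n{"characters": ["Juliet"]}\n```'
print(json.loads(strip_fences(fenced_reply)))  # {'characters': ['Juliet']}
```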

langextract/inference.py

Lines changed: 169 additions & 1 deletion
```diff
@@ -24,6 +24,7 @@
 from typing import Any
 
 from google import genai
+import openai
 import requests
 from typing_extensions import override
 import yaml
```
```diff
@@ -383,7 +384,174 @@ def infer(
       yield [result]
 
   def parse_output(self, output: str) -> Any:
-    """Parses Gemini output as JSON or YAML."""
+    """Parses Gemini output as JSON or YAML.
+
+    Note: This expects raw JSON/YAML without code fences.
+    Code fence extraction is handled by resolver.py.
+    """
+    try:
+      if self.format_type == data.FormatType.JSON:
+        return json.loads(output)
+      else:
+        return yaml.safe_load(output)
+    except Exception as e:
+      raise ValueError(
+          f'Failed to parse output as {self.format_type.name}: {str(e)}'
+      ) from e
+
+
+@dataclasses.dataclass(init=False)
+class OpenAILanguageModel(BaseLanguageModel):
+  """Language model inference using OpenAI's API with structured output."""
+
+  model_id: str = 'gpt-4o-mini'
+  api_key: str | None = None
+  organization: str | None = None
+  format_type: data.FormatType = data.FormatType.JSON
+  temperature: float = 0.0
+  max_workers: int = 10
+  _client: openai.OpenAI | None = dataclasses.field(
+      default=None, repr=False, compare=False
+  )
+  _extra_kwargs: dict[str, Any] = dataclasses.field(
+      default_factory=dict, repr=False, compare=False
+  )
+
+  def __init__(
+      self,
+      model_id: str = 'gpt-4o-mini',
+      api_key: str | None = None,
+      organization: str | None = None,
+      format_type: data.FormatType = data.FormatType.JSON,
+      temperature: float = 0.0,
+      max_workers: int = 10,
+      **kwargs,
+  ) -> None:
+    """Initialize the OpenAI language model.
+
+    Args:
+      model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').
+      api_key: API key for OpenAI service.
+      organization: Optional OpenAI organization ID.
+      format_type: Output format (JSON or YAML).
+      temperature: Sampling temperature.
+      max_workers: Maximum number of parallel API calls.
+      **kwargs: Ignored extra parameters so callers can pass a superset of
+        arguments shared across back-ends without raising ``TypeError``.
+    """
+    self.model_id = model_id
+    self.api_key = api_key
+    self.organization = organization
+    self.format_type = format_type
+    self.temperature = temperature
+    self.max_workers = max_workers
+    self._extra_kwargs = kwargs or {}
+
+    if not self.api_key:
+      raise ValueError('API key not provided.')
+
+    # Initialize the OpenAI client
+    self._client = openai.OpenAI(
+        api_key=self.api_key, organization=self.organization
+    )
+
+    super().__init__(
+        constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
+    )
+
+  def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput:
+    """Process a single prompt and return a ScoredOutput."""
+    try:
+      # Prepare the system message for structured output
+      system_message = ''
+      if self.format_type == data.FormatType.JSON:
+        system_message = (
+            'You are a helpful assistant that responds in JSON format.'
+        )
+      elif self.format_type == data.FormatType.YAML:
+        system_message = (
+            'You are a helpful assistant that responds in YAML format.'
+        )
+
+      # Create the chat completion using the v1.x client API
+      response = self._client.chat.completions.create(
+          model=self.model_id,
+          messages=[
+              {'role': 'system', 'content': system_message},
+              {'role': 'user', 'content': prompt},
+          ],
+          temperature=config.get('temperature', self.temperature),
+          max_tokens=config.get('max_output_tokens'),
+          top_p=config.get('top_p'),
+          n=1,
+      )
+
+      # Extract the response text using the v1.x response format
+      output_text = response.choices[0].message.content
+
+      return ScoredOutput(score=1.0, output=output_text)
+
+    except Exception as e:
+      raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e
+
+  def infer(
+      self, batch_prompts: Sequence[str], **kwargs
+  ) -> Iterator[Sequence[ScoredOutput]]:
+    """Runs inference on a list of prompts via OpenAI's API.
+
+    Args:
+      batch_prompts: A list of string prompts.
+      **kwargs: Additional generation params (temperature, top_p, etc.)
+
+    Yields:
+      Lists of ScoredOutputs.
+    """
+    config = {
+        'temperature': kwargs.get('temperature', self.temperature),
+    }
+    if 'max_output_tokens' in kwargs:
+      config['max_output_tokens'] = kwargs['max_output_tokens']
+    if 'top_p' in kwargs:
+      config['top_p'] = kwargs['top_p']
+
+    # Use parallel processing for batches larger than 1
+    if len(batch_prompts) > 1 and self.max_workers > 1:
+      with concurrent.futures.ThreadPoolExecutor(
+          max_workers=min(self.max_workers, len(batch_prompts))
+      ) as executor:
+        future_to_index = {
+            executor.submit(
+                self._process_single_prompt, prompt, config.copy()
+            ): i
+            for i, prompt in enumerate(batch_prompts)
+        }
+
+        results: list[ScoredOutput | None] = [None] * len(batch_prompts)
+        for future in concurrent.futures.as_completed(future_to_index):
+          index = future_to_index[future]
+          try:
+            results[index] = future.result()
+          except Exception as e:
+            raise InferenceOutputError(
+                f'Parallel inference error: {str(e)}'
+            ) from e
+
+        for result in results:
+          if result is None:
+            raise InferenceOutputError('Failed to process one or more prompts')
+          yield [result]
+    else:
+      # Sequential processing for single prompt or worker
+      for prompt in batch_prompts:
+        result = self._process_single_prompt(prompt, config.copy())
+        yield [result]
+
+  def parse_output(self, output: str) -> Any:
+    """Parses OpenAI output as JSON or YAML.
+
+    Note: This expects raw JSON/YAML without code fences.
+    Code fence extraction is handled by resolver.py.
+    """
     try:
       if self.format_type == data.FormatType.JSON:
         return json.loads(output)
```
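
For readers of the diff above, here is a minimal, hypothetical usage sketch of the new `OpenAILanguageModel` outside of `lx.extract`. It assumes `OPENAI_API_KEY` is set and that the model replies with raw, unfenced JSON (fence stripping normally happens upstream in resolver.py):

```python
import os

from langextract.inference import OpenAILanguageModel

model = OpenAILanguageModel(
    model_id="gpt-4o-mini",
    api_key=os.environ["OPENAI_API_KEY"],
    temperature=0.0,
)

# infer() yields one list of ScoredOutputs per prompt, in input order.
for scored in model.infer(['Reply with {"status": "ok"} exactly.']):
    raw = scored[0].output
    # parse_output() expects raw JSON/YAML, so this assumes an unfenced reply.
    print(model.parse_output(raw))
```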

pyproject.toml

Lines changed: 4 additions & 4 deletions
```diff
@@ -18,7 +18,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "langextract"
-version = "1.0.2"
+version = "1.0.3"
 description = "LangExtract: A library for extracting structured data from language models"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -35,10 +35,11 @@ dependencies = [
     "ml-collections>=0.1.0",
     "more-itertools>=8.0.0",
     "numpy>=1.20.0",
-    "openai>=0.27.0",
+    "openai>=1.50.0",
     "pandas>=1.3.0",
     "pydantic>=1.8.0",
     "python-dotenv>=0.19.0",
+    "PyYAML>=6.0",
     "requests>=2.25.0",
     "tqdm>=4.64.0",
     "typing-extensions>=4.0.0"
@@ -55,9 +56,8 @@ dev = [
     "pyink~=24.3.0",
     "isort>=5.13.0",
     "pylint>=3.0.0",
-    "pytest>=7.4.0",
     "pytype>=2024.10.11",
-    "tox>=4.0.0",
+    "tox>=4.0.0"
 ]
 test = [
     "pytest>=7.4.0",
```

tests/inference_test.py

Lines changed: 116 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 
 from absl.testing import absltest
 
+from langextract import data
 from langextract import inference
 
 
@@ -93,5 +94,120 @@ def test_ollama_infer(self, mock_ollama_query):
     self.assertEqual(results, expected_results)
 
 
+class TestOpenAILanguageModel(absltest.TestCase):
+
+  @mock.patch("openai.OpenAI")
+  def test_openai_infer(self, mock_openai_class):
+    # Mock the OpenAI client and chat completion response
+    mock_client = mock.Mock()
+    mock_openai_class.return_value = mock_client
+
+    # Mock response structure for v1.x API
+    mock_response = mock.Mock()
+    mock_response.choices = [
+        mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}'))
+    ]
+    mock_client.chat.completions.create.return_value = mock_response
+
+    # Create model instance
+    model = inference.OpenAILanguageModel(
+        model_id="gpt-4o-mini", api_key="test-api-key", temperature=0.5
+    )
+
+    # Test inference
+    batch_prompts = ["Extract name and age from: John is 30 years old"]
+    results = list(model.infer(batch_prompts))
+
+    # Verify API was called correctly
+    mock_client.chat.completions.create.assert_called_once_with(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful assistant that responds in JSON format."
+                ),
+            },
+            {
+                "role": "user",
+                "content": "Extract name and age from: John is 30 years old",
+            },
+        ],
+        temperature=0.5,
+        max_tokens=None,
+        top_p=None,
+        n=1,
+    )
+
+    # Check results
+    expected_results = [[
+        inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}')
+    ]]
+    self.assertEqual(results, expected_results)
+
+  def test_openai_parse_output_json(self):
+    model = inference.OpenAILanguageModel(
+        api_key="test-key", format_type=data.FormatType.JSON
+    )
+
+    # Test valid JSON parsing
+    output = '{"key": "value", "number": 42}'
+    parsed = model.parse_output(output)
+    self.assertEqual(parsed, {"key": "value", "number": 42})
+
+    # Test invalid JSON
+    with self.assertRaises(ValueError) as context:
+      model.parse_output("invalid json")
+    self.assertIn("Failed to parse output as JSON", str(context.exception))
+
+  def test_openai_parse_output_yaml(self):
+    model = inference.OpenAILanguageModel(
+        api_key="test-key", format_type=data.FormatType.YAML
+    )
+
+    # Test valid YAML parsing
+    output = "key: value\nnumber: 42"
+    parsed = model.parse_output(output)
+    self.assertEqual(parsed, {"key": "value", "number": 42})
+
+    # Test invalid YAML
+    with self.assertRaises(ValueError) as context:
+      model.parse_output("invalid: yaml: bad")
+    self.assertIn("Failed to parse output as YAML", str(context.exception))
+
+  def test_openai_no_api_key_raises_error(self):
+    with self.assertRaises(ValueError) as context:
+      inference.OpenAILanguageModel(api_key=None)
+    self.assertEqual(str(context.exception), "API key not provided.")
+
+  @mock.patch("openai.OpenAI")
+  def test_openai_temperature_zero(self, mock_openai_class):
+    # Test that temperature=0.0 is properly passed through
+    mock_client = mock.Mock()
+    mock_openai_class.return_value = mock_client
+
+    mock_response = mock.Mock()
+    mock_response.choices = [
+        mock.Mock(message=mock.Mock(content='{"result": "test"}'))
+    ]
+    mock_client.chat.completions.create.return_value = mock_response
+
+    model = inference.OpenAILanguageModel(
+        api_key="test-key", temperature=0.0  # Testing zero temperature
+    )
+
+    list(model.infer(["test prompt"]))
+
+    # Verify temperature=0.0 was passed to the API
+    mock_client.chat.completions.create.assert_called_with(
+        model="gpt-4o-mini",
+        messages=mock.ANY,
+        temperature=0.0,
+        max_tokens=None,
+        top_p=None,
+        n=1,
+    )
+
+
 if __name__ == "__main__":
   absltest.main()
```
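
An aside on the mocks above: the tests only need an object graph matching `response.choices[0].message.content`, the one attribute path the backend reads. A hedged sketch showing the same shape with `types.SimpleNamespace` instead of `unittest.mock`:

```python
from types import SimpleNamespace

# Mirrors the v1.x response shape: response.choices[0].message.content
fake_response = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(content='{"ok": true}'))]
)
assert fake_response.choices[0].message.content == '{"ok": true}'
```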
