Commits (25)
2ce2399
docs(pypi): Improve README display and badge reliability
aksg87 Jul 22, 2025
4fe7580
feat: add trusted publishing workflow and prepare v1.0.0 release
aksg87 Jul 22, 2025
e696a48
Fix: Resolve libmagic ImportError (#6)
aksg87 Aug 1, 2025
5447637
docs: clarify output_dir behavior in medication_examples.md
kleeena Aug 1, 2025
9c47b34
Merge pull request #11 from google/fix/libmagic-dependency-issue
aksg87 Aug 1, 2025
175e075
Removed inline comment in medication example
kleeena Aug 2, 2025
9472099
Merge pull request #15 from kleeena/docs/update-medication_examples.md
aksg87 Aug 2, 2025
e6c3dcd
docs: add output_dir="." to all save_annotated_documents examples
aksg87 Aug 2, 2025
1fb1f1d
Merge pull request #17 from google/fix/output-dir-consistency
aksg87 Aug 2, 2025
13fbd2c
build: add formatting & linting pipeline with pre-commit integration
aksg87 Aug 3, 2025
c8d2027
style: apply pyink, isort, and pre-commit formatting
aksg87 Aug 3, 2025
146a095
ci: enable format and lint checks in tox
aksg87 Aug 3, 2025
aa6da18
Merge pull request #24 from google/feat/code-formatting-pipeline
aksg87 Aug 3, 2025
ed65bca
Add LangExtractError base exception for centralized error handling
aksg87 Aug 3, 2025
6c4508b
Merge pull request #26 from google/feat/exception-hierarchy
aksg87 Aug 3, 2025
8b85225
fix: Remove LangFun and pylibmagic dependencies (v1.0.2)
aksg87 Aug 3, 2025
88520cc
Merge pull request #28 from google/fix/remove-breaking-dep-langfun
aksg87 Aug 3, 2025
75a6f12
Fix save_annotated_documents to handle string paths
aksg87 Aug 3, 2025
a415b94
Merge pull request #29 from google/fix-save-annotated-documents-mkdir
aksg87 Aug 3, 2025
8289b3a
feat: Add OpenAI language model support
aksg87 Aug 3, 2025
c8ef723
Merge pull request #31 from google/feature/add-oai-inference
aksg87 Aug 3, 2025
dfe8188
fix(ui): prevent current highlight border from being obscured. Chan…
tonebeta Aug 4, 2025
87c511e
feat: Add live API integration tests (#39)
aksg87 Aug 4, 2025
dc61372
Add PR template validation workflow (#45)
aksg87 Aug 4, 2025
ee86ae0
add model_url to openai model
aniketchopade Aug 4, 2025
feat: Add OpenAI language model support
aksg87 committed Aug 3, 2025
commit 8289b3a5425c13fc3af894b7ea8bb5601681e9e0
24 changes: 23 additions & 1 deletion README.md
@@ -190,7 +190,7 @@ docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_sc

## API Key Setup for Cloud Models

When using LangExtract with cloud-hosted models (like Gemini), you'll need to
When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to
set up an API key. On-device models don't require an API key. For developers
using local LLMs, LangExtract offers built-in support for Ollama and can be
extended to other third-party APIs by updating the inference endpoints.
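
For example, loading a key from the environment might look like this (a minimal sketch; `python-dotenv` is already a project dependency, and the variable names follow the conventions used elsewhere in this README):

```python
import os

from dotenv import load_dotenv

load_dotenv()  # Reads variables from a .env file in the working directory.
api_key = os.environ.get("LANGEXTRACT_API_KEY")  # Or OPENAI_API_KEY for OpenAI models.
```
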
@@ -201,6 +201,7 @@ Get API keys from:

* [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models
* [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use
* [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models

### Setting up API key in your environment

@@ -250,6 +251,27 @@ result = lx.extract(
)
```

## Using OpenAI Models

LangExtract also supports OpenAI models. An example OpenAI configuration, with `input_text`, `prompt`, and `examples` defined as in the earlier example:

```python
import os

import langextract as lx
from langextract.inference import OpenAILanguageModel

result = lx.extract(
    text_or_documents=input_text,
    prompt_description=prompt,
    examples=examples,
    language_model_type=OpenAILanguageModel,
    model_id="gpt-4o",
    api_key=os.environ.get('OPENAI_API_KEY'),
    fence_output=True,
    use_schema_constraints=False
)
```

Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet.
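
For intuition, here is a minimal sketch of what fenced output looks like and why it must be unwrapped before parsing. This is not LangExtract's actual resolver; `strip_fences` is a hypothetical helper shown only for illustration:

```python
import json
import re


def strip_fences(raw: str) -> str:
  """Hypothetical helper: extract the body of a JSON code fence."""
  match = re.search(r"```(?:json)?\s*(.*?)```", raw, re.DOTALL)
  return match.group(1).strip() if match else raw.strip()


# With fence_output=True, the model is expected to wrap its answer like this:
fenced = '```json\n{"name": "John", "age": 30}\n```'

# json.loads(fenced) would raise; stripping the fence first succeeds.
print(json.loads(strip_fences(fenced)))  # {'name': 'John', 'age': 30}
```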

## More Examples

Additional examples of LangExtract in action:
170 changes: 169 additions & 1 deletion langextract/inference.py
@@ -24,6 +24,7 @@
from typing import Any

from google import genai
import openai
import requests
from typing_extensions import override
import yaml
@@ -383,7 +384,174 @@ def infer(
      yield [result]

  def parse_output(self, output: str) -> Any:
    """Parses Gemini output as JSON or YAML."""
    """Parses Gemini output as JSON or YAML.

    Note: This expects raw JSON/YAML without code fences.
    Code fence extraction is handled by resolver.py.
    """
    try:
      if self.format_type == data.FormatType.JSON:
        return json.loads(output)
      else:
        return yaml.safe_load(output)
    except Exception as e:
      raise ValueError(
          f'Failed to parse output as {self.format_type.name}: {str(e)}'
      ) from e


@dataclasses.dataclass(init=False)
class OpenAILanguageModel(BaseLanguageModel):
  """Language model inference using OpenAI's API with structured output."""

  model_id: str = 'gpt-4o-mini'
  api_key: str | None = None
  organization: str | None = None
  format_type: data.FormatType = data.FormatType.JSON
  temperature: float = 0.0
  max_workers: int = 10
  _client: openai.OpenAI | None = dataclasses.field(
      default=None, repr=False, compare=False
  )
  _extra_kwargs: dict[str, Any] = dataclasses.field(
      default_factory=dict, repr=False, compare=False
  )

  def __init__(
      self,
      model_id: str = 'gpt-4o-mini',
      api_key: str | None = None,
      organization: str | None = None,
      format_type: data.FormatType = data.FormatType.JSON,
      temperature: float = 0.0,
      max_workers: int = 10,
      **kwargs,
  ) -> None:
    """Initialize the OpenAI language model.

    Args:
      model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').
      api_key: API key for OpenAI service.
      organization: Optional OpenAI organization ID.
      format_type: Output format (JSON or YAML).
      temperature: Sampling temperature.
      max_workers: Maximum number of parallel API calls.
      **kwargs: Ignored extra parameters so callers can pass a superset of
        arguments shared across back-ends without raising ``TypeError``.
    """
    self.model_id = model_id
    self.api_key = api_key
    self.organization = organization
    self.format_type = format_type
    self.temperature = temperature
    self.max_workers = max_workers
    self._extra_kwargs = kwargs or {}

    if not self.api_key:
      raise ValueError('API key not provided.')

    # Initialize the OpenAI client
    self._client = openai.OpenAI(
        api_key=self.api_key, organization=self.organization
    )

    super().__init__(
        constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
    )

  def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput:
    """Process a single prompt and return a ScoredOutput."""
    try:
      # Prepare the system message for structured output
      system_message = ''
      if self.format_type == data.FormatType.JSON:
        system_message = (
            'You are a helpful assistant that responds in JSON format.'
        )
      elif self.format_type == data.FormatType.YAML:
        system_message = (
            'You are a helpful assistant that responds in YAML format.'
        )

      # Create the chat completion using the v1.x client API
      response = self._client.chat.completions.create(
          model=self.model_id,
          messages=[
              {'role': 'system', 'content': system_message},
              {'role': 'user', 'content': prompt},
          ],
          temperature=config.get('temperature', self.temperature),
          max_tokens=config.get('max_output_tokens'),
          top_p=config.get('top_p'),
          n=1,
      )

      # Extract the response text using the v1.x response format
      output_text = response.choices[0].message.content

      return ScoredOutput(score=1.0, output=output_text)

    except Exception as e:
      raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e

  def infer(
      self, batch_prompts: Sequence[str], **kwargs
  ) -> Iterator[Sequence[ScoredOutput]]:
    """Runs inference on a list of prompts via OpenAI's API.

    Args:
      batch_prompts: A list of string prompts.
      **kwargs: Additional generation params (temperature, top_p, etc.)

    Yields:
      Lists of ScoredOutputs.
    """
    config = {
        'temperature': kwargs.get('temperature', self.temperature),
    }
    if 'max_output_tokens' in kwargs:
      config['max_output_tokens'] = kwargs['max_output_tokens']
    if 'top_p' in kwargs:
      config['top_p'] = kwargs['top_p']

    # Use parallel processing for batches larger than 1
    if len(batch_prompts) > 1 and self.max_workers > 1:
      with concurrent.futures.ThreadPoolExecutor(
          max_workers=min(self.max_workers, len(batch_prompts))
      ) as executor:
        future_to_index = {
            executor.submit(
                self._process_single_prompt, prompt, config.copy()
            ): i
            for i, prompt in enumerate(batch_prompts)
        }

        results: list[ScoredOutput | None] = [None] * len(batch_prompts)
        for future in concurrent.futures.as_completed(future_to_index):
          index = future_to_index[future]
          try:
            results[index] = future.result()
          except Exception as e:
            raise InferenceOutputError(
                f'Parallel inference error: {str(e)}'
            ) from e

      for result in results:
        if result is None:
          raise InferenceOutputError('Failed to process one or more prompts')
        yield [result]
    else:
      # Sequential processing for single prompt or worker
      for prompt in batch_prompts:
        result = self._process_single_prompt(prompt, config.copy())
        yield [result]

  def parse_output(self, output: str) -> Any:
    """Parses OpenAI output as JSON or YAML.

    Note: This expects raw JSON/YAML without code fences.
    Code fence extraction is handled by resolver.py.
    """
    try:
      if self.format_type == data.FormatType.JSON:
        return json.loads(output)
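
For reference, a minimal usage sketch of the class added above (assuming the package is installed and `OPENAI_API_KEY` is set; the prompts are illustrative):

```python
import os

from langextract.inference import OpenAILanguageModel

model = OpenAILanguageModel(
    model_id='gpt-4o-mini',
    api_key=os.environ['OPENAI_API_KEY'],
    temperature=0.0,
)

# More than one prompt exercises the ThreadPoolExecutor path in infer();
# a single prompt is processed sequentially.
prompts = [
    'Respond in JSON with keys "name" and "age": John is 30 years old.',
    'Respond in JSON with keys "name" and "age": Mary is 25 years old.',
]
for scored_outputs in model.infer(prompts):
  # parse_output expects raw JSON; in the full pipeline, resolver.py
  # strips any code fences before this point.
  print(model.parse_output(scored_outputs[0].output))
```
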
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -18,7 +18,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "langextract"
version = "1.0.2"
version = "1.0.3"
description = "LangExtract: A library for extracting structured data from language models"
readme = "README.md"
requires-python = ">=3.10"
@@ -35,10 +35,11 @@ dependencies = [
"ml-collections>=0.1.0",
"more-itertools>=8.0.0",
"numpy>=1.20.0",
"openai>=0.27.0",
"openai>=1.50.0",
"pandas>=1.3.0",
"pydantic>=1.8.0",
"python-dotenv>=0.19.0",
"PyYAML>=6.0",
"requests>=2.25.0",
"tqdm>=4.64.0",
"typing-extensions>=4.0.0"
@@ -55,9 +56,8 @@ dev = [
"pyink~=24.3.0",
"isort>=5.13.0",
"pylint>=3.0.0",
"pytest>=7.4.0",
"pytype>=2024.10.11",
"tox>=4.0.0",
"tox>=4.0.0"
]
test = [
"pytest>=7.4.0",
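
Note that the `openai` floor moves from `0.27.0` to `1.50.0`: the new `OpenAILanguageModel` is written against the v1.x client API (`openai.OpenAI(...)` and `client.chat.completions.create`), which does not exist in the 0.x series.
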
116 changes: 116 additions & 0 deletions tests/inference_test.py
@@ -16,6 +16,7 @@

from absl.testing import absltest

from langextract import data
from langextract import inference


@@ -93,5 +94,120 @@ def test_ollama_infer(self, mock_ollama_query):
    self.assertEqual(results, expected_results)


class TestOpenAILanguageModel(absltest.TestCase):

  @mock.patch("openai.OpenAI")
  def test_openai_infer(self, mock_openai_class):
    # Mock the OpenAI client and chat completion response
    mock_client = mock.Mock()
    mock_openai_class.return_value = mock_client

    # Mock response structure for v1.x API
    mock_response = mock.Mock()
    mock_response.choices = [
        mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}'))
    ]
    mock_client.chat.completions.create.return_value = mock_response

    # Create model instance
    model = inference.OpenAILanguageModel(
        model_id="gpt-4o-mini", api_key="test-api-key", temperature=0.5
    )

    # Test inference
    batch_prompts = ["Extract name and age from: John is 30 years old"]
    results = list(model.infer(batch_prompts))

    # Verify API was called correctly
    mock_client.chat.completions.create.assert_called_once_with(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant that responds in JSON format."
                ),
            },
            {
                "role": "user",
                "content": "Extract name and age from: John is 30 years old",
            },
        ],
        temperature=0.5,
        max_tokens=None,
        top_p=None,
        n=1,
    )

    # Check results
    expected_results = [[
        inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}')
    ]]
    self.assertEqual(results, expected_results)

  def test_openai_parse_output_json(self):
    model = inference.OpenAILanguageModel(
        api_key="test-key", format_type=data.FormatType.JSON
    )

    # Test valid JSON parsing
    output = '{"key": "value", "number": 42}'
    parsed = model.parse_output(output)
    self.assertEqual(parsed, {"key": "value", "number": 42})

    # Test invalid JSON
    with self.assertRaises(ValueError) as context:
      model.parse_output("invalid json")
    self.assertIn("Failed to parse output as JSON", str(context.exception))

  def test_openai_parse_output_yaml(self):
    model = inference.OpenAILanguageModel(
        api_key="test-key", format_type=data.FormatType.YAML
    )

    # Test valid YAML parsing
    output = "key: value\nnumber: 42"
    parsed = model.parse_output(output)
    self.assertEqual(parsed, {"key": "value", "number": 42})

    # Test invalid YAML
    with self.assertRaises(ValueError) as context:
      model.parse_output("invalid: yaml: bad")
    self.assertIn("Failed to parse output as YAML", str(context.exception))

  def test_openai_no_api_key_raises_error(self):
    with self.assertRaises(ValueError) as context:
      inference.OpenAILanguageModel(api_key=None)
    self.assertEqual(str(context.exception), "API key not provided.")

  @mock.patch("openai.OpenAI")
  def test_openai_temperature_zero(self, mock_openai_class):
    # Test that temperature=0.0 is properly passed through
    mock_client = mock.Mock()
    mock_openai_class.return_value = mock_client

    mock_response = mock.Mock()
    mock_response.choices = [
        mock.Mock(message=mock.Mock(content='{"result": "test"}'))
    ]
    mock_client.chat.completions.create.return_value = mock_response

    model = inference.OpenAILanguageModel(
        api_key="test-key", temperature=0.0  # Testing zero temperature
    )

    list(model.infer(["test prompt"]))

    # Verify temperature=0.0 was passed to the API
    mock_client.chat.completions.create.assert_called_with(
        model="gpt-4o-mini",
        messages=mock.ANY,
        temperature=0.0,
        max_tokens=None,
        top_p=None,
        n=1,
    )


if __name__ == "__main__":
  absltest.main()
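
Because `openai.OpenAI` is mocked, these tests run without network access or a real API key; they can be executed with the `test` extra installed, e.g. `pytest tests/inference_test.py`.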