Commit adc20ca

Update kwargs handling for provider models (google#138)
* Fix kwargs passthrough for all providers

  Resolves initialization order bug where provider kwargs were lost.
  Runtime kwargs now properly override stored defaults.

  - Add merge_kwargs() to base class
  - Filter None values to prevent SDK errors
  - Support provider-specific params (stop, top_p, top_logprobs)
  - Add test coverage

* Update OpenAI JSON mode to not require fence output
1 parent 6375011 commit adc20ca
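
The initialization-order issue mentioned in the commit message can be read off the provider diffs below: the base __init__ now initializes _extra_kwargs, so a subclass that stored its kwargs before calling super().__init__() would have them wiped. A minimal sketch of that failure mode (class names are illustrative, not from the repo):

class Base:
  def __init__(self):
    self._extra_kwargs = {}  # base class resets the dict

class Provider(Base):
  def __init__(self, **kwargs):
    self._extra_kwargs = kwargs  # stored before super().__init__() ...
    super().__init__()  # ... and clobbered by the base reset

assert Provider(top_p=0.9)._extra_kwargs == {}  # kwargs silently lost

Both provider diffs below move the self._extra_kwargs assignment to after super().__init__() so the stored kwargs survive.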

File tree

6 files changed (+613, -112 lines)

langextract/inference.py

Lines changed: 19 additions & 1 deletion
@@ -20,7 +20,7 @@
 import enum
 import json
 import textwrap
-from typing import Any
+from typing import Any, Mapping

 from absl import logging
 from typing_extensions import deprecated
@@ -73,6 +73,7 @@ def __init__(self, constraint: schema.Constraint = schema.Constraint()):
     self._constraint = constraint
     self._schema: schema.BaseSchema | None = None
     self._fence_output_override: bool | None = None
+    self._extra_kwargs: dict[str, Any] = {}

   @classmethod
   def get_schema_class(cls) -> type[schema.BaseSchema] | None:
@@ -116,6 +117,23 @@ def requires_fence_output(self) -> bool:
       return True
     return not self._schema.supports_strict_mode

+  def merge_kwargs(
+      self, runtime_kwargs: Mapping[str, Any] | None = None
+  ) -> dict[str, Any]:
+    """Merge stored extra kwargs with runtime kwargs.
+
+    Runtime kwargs take precedence over stored kwargs.
+
+    Args:
+      runtime_kwargs: Kwargs provided at inference time, or None.
+
+    Returns:
+      Merged kwargs dictionary.
+    """
+    base = getattr(self, '_extra_kwargs', {}) or {}
+    incoming = dict(runtime_kwargs or {})
+    return {**base, **incoming}
+
   @abc.abstractmethod
   def infer(
       self, batch_prompts: Sequence[str], **kwargs
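
For reference, a plain-dict sketch of the precedence rule merge_kwargs implements (values are illustrative, not from the tests): kwargs stored on the model act as defaults, and kwargs passed at inference time win on conflict.

# Illustrative only: mirrors merge_kwargs() above with plain dicts.
stored = {'temperature': 0.5, 'top_p': 0.9}  # captured at construction
runtime = {'temperature': 0.1}  # passed to infer()
merged = {**stored, **runtime}  # runtime takes precedence
assert merged == {'temperature': 0.1, 'top_p': 0.9}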

langextract/providers/gemini.py

Lines changed: 41 additions & 26 deletions
@@ -19,17 +19,27 @@

 import concurrent.futures
 import dataclasses
-from typing import Any, Iterator, Sequence
+from typing import Any, Final, Iterator, Sequence

 from langextract import data
 from langextract import exceptions
 from langextract import inference
 from langextract import schema
 from langextract.providers import registry

+_API_CONFIG_KEYS: Final[set[str]] = {
+    'response_mime_type',
+    'response_schema',
+    'safety_settings',
+    'system_instruction',
+    'tools',
+    'stop_sequences',
+    'candidate_count',
+}
+

 @registry.register(
-    r'^gemini',  # gemini-2.5-flash, gemini-2.5-pro, etc.
+    r'^gemini',
     priority=10,
 )
 @dataclasses.dataclass(init=False)
@@ -109,37 +119,31 @@ def __init__(
     self.temperature = temperature
     self.max_workers = max_workers
     self.fence_output = fence_output
-    api_config_keys = {
-        'response_schema',
-        'response_mime_type',
-        'tools',
-        'safety_settings',
-        'stop_sequences',
-        'candidate_count',
-        'system_instruction',
-    }
-    self._extra_kwargs = {
-        k: v for k, v in (kwargs or {}).items() if k in api_config_keys
-    }

     if not self.api_key:
-      raise exceptions.InferenceConfigError('API key not provided for Gemini.')
+      raise exceptions.InferenceConfigError('API key not provided.')

     self._client = genai.Client(api_key=self.api_key)

     super().__init__(
         constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
     )
+    self._extra_kwargs = {
+        k: v for k, v in (kwargs or {}).items() if k in _API_CONFIG_KEYS
+    }

   def _process_single_prompt(
       self, prompt: str, config: dict
   ) -> inference.ScoredOutput:
     """Process a single prompt and return a ScoredOutput."""
     try:
-      if self._extra_kwargs:
-        config.update(self._extra_kwargs)
+      # Apply stored kwargs that weren't already set in config
+      for key, value in self._extra_kwargs.items():
+        if key not in config and value is not None:
+          config[key] = value
+
       if self.gemini_schema:
-        # Gemini structured output only supports JSON
+        # Structured output requires JSON format
         if self.format_type != data.FormatType.JSON:
           raise exceptions.InferenceConfigError(
               'Gemini structured output only supports JSON format. '
@@ -149,7 +153,7 @@ def _process_single_prompt(
         config.setdefault('response_schema', self.gemini_schema.schema_dict)

       response = self._client.models.generate_content(
-          model=self.model_id, contents=prompt, config=config  # type: ignore[arg-type]
+          model=self.model_id, contents=prompt, config=config
       )

       return inference.ScoredOutput(score=1.0, output=response.text)
@@ -171,15 +175,26 @@ def infer(
     Yields:
       Lists of ScoredOutputs.
     """
+    merged_kwargs = self.merge_kwargs(kwargs)
+
     config = {
-        'temperature': kwargs.get('temperature', self.temperature),
+        'temperature': merged_kwargs.get('temperature', self.temperature),
     }
-    if 'max_output_tokens' in kwargs:
-      config['max_output_tokens'] = kwargs['max_output_tokens']
-    if 'top_p' in kwargs:
-      config['top_p'] = kwargs['top_p']
-    if 'top_k' in kwargs:
-      config['top_k'] = kwargs['top_k']
+    if 'max_output_tokens' in merged_kwargs:
+      config['max_output_tokens'] = merged_kwargs['max_output_tokens']
+    if 'top_p' in merged_kwargs:
+      config['top_p'] = merged_kwargs['top_p']
+    if 'top_k' in merged_kwargs:
+      config['top_k'] = merged_kwargs['top_k']
+
+    handled_keys = {'temperature', 'max_output_tokens', 'top_p', 'top_k'}
+    for key, value in merged_kwargs.items():
+      if (
+          key not in handled_keys
+          and key in _API_CONFIG_KEYS
+          and value is not None
+      ):
+        config[key] = value

     # Use parallel processing for batches larger than 1
     if len(batch_prompts) > 1 and self.max_workers > 1:
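
A small sketch of the fill-in behavior that replaces config.update() in _process_single_prompt (values are illustrative): stored kwargs only fill keys missing from the request config, and None values are skipped so they never reach the SDK.

# Illustrative only: same filtering logic as the loop above.
stored = {'safety_settings': None, 'candidate_count': 2, 'tools': None}
config = {'temperature': 0.0, 'candidate_count': 1}
for key, value in stored.items():
  if key not in config and value is not None:
    config[key] = value
assert config == {'temperature': 0.0, 'candidate_count': 1}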

langextract/providers/ollama.py

Lines changed: 43 additions & 15 deletions
@@ -83,7 +83,12 @@
 from langextract import schema
 from langextract.providers import registry

+# Ollama defaults
 _OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434'
+_DEFAULT_TEMPERATURE = 0.8
+_DEFAULT_TIMEOUT = 30
+_DEFAULT_KEEP_ALIVE = 5 * 60  # 5 minutes
+_DEFAULT_NUM_CTX = 2048


 @registry.register(
@@ -193,8 +198,8 @@ def __init__(
     self._model_url = base_url or model_url or _OLLAMA_DEFAULT_MODEL_URL
     self.format_type = format_type
     self._constraint = constraint
-    self._extra_kwargs = kwargs or {}
     super().__init__(constraint=constraint)
+    self._extra_kwargs = kwargs or {}

   def infer(
       self, batch_prompts: Sequence[str], **kwargs
@@ -208,6 +213,8 @@ def infer(
     Yields:
       Lists of ScoredOutputs.
     """
+    combined_kwargs = self.merge_kwargs(kwargs)
+
     for prompt in batch_prompts:
       try:
         response = self._ollama_query(
@@ -217,7 +224,7 @@
             if self.format_type == data.FormatType.JSON
             else 'yaml',
             model_url=self._model_url,
-            **kwargs,
+            **combined_kwargs,
         )
         # No score for Ollama. Default to 1.0
         yield [inference.ScoredOutput(score=1.0, output=response['response'])]
@@ -230,18 +237,20 @@ def _ollama_query(
       self,
       prompt: str,
       model: str | None = None,
-      temperature: float = 0.8,
+      temperature: float | None = None,
       seed: int | None = None,
       top_k: int | None = None,
+      top_p: float | None = None,
       max_output_tokens: int | None = None,
       structured_output_format: str | None = None,
       system: str = '',
       raw: bool = False,
       model_url: str | None = None,
-      timeout: int = 30,
-      keep_alive: int = 5 * 60,
+      timeout: int | None = None,
+      keep_alive: int | None = None,
       num_threads: int | None = None,
-      num_ctx: int = 2048,
+      num_ctx: int | None = None,
+      stop: str | list[str] | None = None,
       **kwargs,  # pylint: disable=unused-argument
   ) -> Mapping[str, Any]:
     """Sends a prompt to an Ollama model and returns the generated response.
@@ -257,6 +266,7 @@ def _ollama_query(
         output.
       seed: Seed for reproducible generation. If None, random seed is used.
       top_k: The top-K parameter for sampling.
+      top_p: The top-P (nucleus) sampling parameter.
       max_output_tokens: Maximum tokens to generate. If None, the model's
         default is used.
       structured_output_format: If set to "json" or a JSON schema dict, requests
@@ -272,6 +282,7 @@ def _ollama_query(
         heuristic.
       num_ctx: Number of context tokens allowed. If None, uses model's default
         or config.
+      stop: Stop sequences to halt generation. Can be a string or list of strings.
       **kwargs: Additional parameters passed through.

     Returns:
@@ -291,19 +302,30 @@ def _ollama_query(
         'json' if self.format_type == data.FormatType.JSON else 'yaml'
     )

-    options: dict[str, Any] = {'keep_alive': keep_alive}
-    if seed:
+    options: dict[str, Any] = {}
+    if keep_alive is not None:
+      options['keep_alive'] = keep_alive
+    else:
+      options['keep_alive'] = _DEFAULT_KEEP_ALIVE
+
+    if seed is not None:
       options['seed'] = seed
-    if temperature:
+    if temperature is not None:
       options['temperature'] = temperature
-    if top_k:
+    else:
+      options['temperature'] = _DEFAULT_TEMPERATURE
+    if top_k is not None:
       options['top_k'] = top_k
-    if num_threads:
+    if top_p is not None:
+      options['top_p'] = top_p
+    if num_threads is not None:
       options['num_thread'] = num_threads
-    if max_output_tokens:
+    if max_output_tokens is not None:
       options['num_predict'] = max_output_tokens
-    if num_ctx:
+    if num_ctx is not None:
       options['num_ctx'] = num_ctx
+    else:
+      options['num_ctx'] = _DEFAULT_NUM_CTX

     api_url = model_url + '/api/generate'

@@ -317,6 +339,12 @@ def _ollama_query(
         'options': options,
     }

+    # Add stop sequences if provided (top-level in Ollama API)
+    if stop is not None:
+      payload['stop'] = stop
+
+    request_timeout = timeout if timeout is not None else _DEFAULT_TIMEOUT
+
     try:
       response = self._requests.post(
           api_url,
@@ -325,12 +353,12 @@ def _ollama_query(
               'Accept': 'application/json',
           },
           json=payload,
-          timeout=timeout,
+          timeout=request_timeout,
       )
     except self._requests.exceptions.RequestException as e:
       if isinstance(e, self._requests.exceptions.ReadTimeout):
         msg = (
-            f'Ollama Model timed out (timeout={timeout},'
+            f'Ollama Model timed out (timeout={request_timeout},'
             f' num_threads={num_threads})'
         )
         raise exceptions.InferenceRuntimeError(
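
A hedged usage sketch of the new pass-through parameters; the class name and constructor arguments below are assumptions for illustration and are not shown in this diff.

# Hypothetical usage; constructor arguments are assumed, not taken from this commit.
model = OllamaLanguageModel(model_id='gemma2:2b', temperature=0.2)
# stop and top_p are new pass-through parameters; kwargs given to infer()
# override any kwargs stored at construction via merge_kwargs().
outputs = list(model.infer(['List three colors.'], stop=['\n\n'], top_p=0.9))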
