|
5 | 5 | import logging
|
6 | 6 | import re
|
7 | 7 | from dataclasses import dataclass
|
8 |
| -from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints |
| 8 | +from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints |
9 | 9 |
|
10 | 10 | from griffe import Docstring, DocstringSectionKind
|
11 | 11 | from pydantic import BaseModel, Field, create_model
|
@@ -185,6 +185,31 @@ def generate_func_documentation(
|
185 | 185 | )
|
186 | 186 |
|
187 | 187 |
|
| 188 | +def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]: |
| 189 | + """Returns the underlying annotation and any metadata from typing.Annotated.""" |
| 190 | + |
| 191 | + metadata: tuple[Any, ...] = () |
| 192 | + ann = annotation |
| 193 | + |
| 194 | + while get_origin(ann) is Annotated: |
| 195 | + args = get_args(ann) |
| 196 | + if not args: |
| 197 | + break |
| 198 | + ann = args[0] |
| 199 | + metadata = (*metadata, *args[1:]) |
| 200 | + |
| 201 | + return ann, metadata |
| 202 | + |
| 203 | + |
| 204 | +def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None: |
| 205 | + """Extracts a human readable description from Annotated metadata if present.""" |
| 206 | + |
| 207 | + for item in metadata: |
| 208 | + if isinstance(item, str): |
| 209 | + return item |
| 210 | + return None |
| 211 | + |
| 212 | + |
188 | 213 | def function_schema(
|
189 | 214 | func: Callable[..., Any],
|
190 | 215 | docstring_style: DocstringStyle | None = None,
|
@@ -219,17 +244,34 @@ def function_schema(
|
219 | 244 | # 1. Grab docstring info
|
220 | 245 | if use_docstring_info:
|
221 | 246 | doc_info = generate_func_documentation(func, docstring_style)
|
222 |
| - param_descs = doc_info.param_descriptions or {} |
| 247 | + param_descs = dict(doc_info.param_descriptions or {}) |
223 | 248 | else:
|
224 | 249 | doc_info = None
|
225 | 250 | param_descs = {}
|
226 | 251 |
|
| 252 | + type_hints_with_extras = get_type_hints(func, include_extras=True) |
| 253 | + type_hints: dict[str, Any] = {} |
| 254 | + annotated_param_descs: dict[str, str] = {} |
| 255 | + |
| 256 | + for name, annotation in type_hints_with_extras.items(): |
| 257 | + if name == "return": |
| 258 | + continue |
| 259 | + |
| 260 | + stripped_ann, metadata = _strip_annotated(annotation) |
| 261 | + type_hints[name] = stripped_ann |
| 262 | + |
| 263 | + description = _extract_description_from_metadata(metadata) |
| 264 | + if description is not None: |
| 265 | + annotated_param_descs[name] = description |
| 266 | + |
| 267 | + for name, description in annotated_param_descs.items(): |
| 268 | + param_descs.setdefault(name, description) |
| 269 | + |
227 | 270 | # Ensure name_override takes precedence even if docstring info is disabled.
|
228 | 271 | func_name = name_override or (doc_info.name if doc_info else func.__name__)
|
229 | 272 |
|
230 | 273 | # 2. Inspect function signature and get type hints
|
231 | 274 | sig = inspect.signature(func)
|
232 |
| - type_hints = get_type_hints(func) |
233 | 275 | params = list(sig.parameters.items())
|
234 | 276 | takes_context = False
|
235 | 277 | filtered_params = []
|
|
0 commit comments