Update InferenceJobOutput schema type and tweak run_inference metrics calc code
peteski22 committed Mar 27, 2025
commit acd8bbbf46d2d1346be1f6ed5aaa3b608b963701
16 changes: 11 additions & 5 deletions lumigator/jobs/inference/inference.py
@@ -119,12 +119,18 @@ def run_inference(config: InferenceJobConfig, api_key: str | None = None) -> Pat
     output["inference_time"] = inference_time

     artifacts = InferenceJobOutput.model_validate(output)
+
+    # Only attempt to calculate metric averages if we have metrics for EVERY prediction result.
     if all(p.metrics is not None for p in prediction_results):
-        avg_prompt_tokens = sum([p.metrics.prompt_tokens for p in prediction_results]) / len(prediction_results)
-        avg_total_tokens = sum([p.metrics.total_tokens for p in prediction_results]) / len(prediction_results)
-        avg_completion_tokens = sum([p.metrics.completion_tokens for p in prediction_results]) / len(prediction_results)
-        avg_reasoning_tokens = sum([p.metrics.reasoning_tokens for p in prediction_results]) / len(prediction_results)
-        avg_answer_tokens = sum([p.metrics.answer_tokens for p in prediction_results]) / len(prediction_results)
+        total_results = len(prediction_results)
+
+        avg_prompt_tokens = sum(p.metrics.prompt_tokens for p in prediction_results) / total_results
+        avg_total_tokens = sum(p.metrics.total_tokens for p in prediction_results) / total_results
+        avg_completion_tokens = sum(p.metrics.completion_tokens for p in prediction_results) / total_results
+        # Optional fields may be None; treat a missing value as 0 so the sum still works.
+        avg_reasoning_tokens = sum((p.metrics.reasoning_tokens or 0) for p in prediction_results) / total_results
+        avg_answer_tokens = sum((p.metrics.answer_tokens or 0) for p in prediction_results) / total_results
+
         metrics = AverageInferenceMetrics(
             avg_prompt_tokens=avg_prompt_tokens,
             avg_total_tokens=avg_total_tokens,
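For context, the `(value or 0)` pattern in this hunk is what lets the averages tolerate optional per-prediction token counts. A minimal, self-contained sketch of that behaviour (the `Metrics` shape here is assumed to mirror `InferenceMetrics`; it is illustrative, not the job's actual code):

from dataclasses import dataclass

@dataclass
class Metrics:
    prompt_tokens: int
    total_tokens: int
    completion_tokens: int
    reasoning_tokens: int | None = None  # optional: not every backend reports these
    answer_tokens: int | None = None

results = [
    Metrics(prompt_tokens=10, total_tokens=15, completion_tokens=5, reasoning_tokens=2),
    Metrics(prompt_tokens=12, total_tokens=20, completion_tokens=8),  # reasoning_tokens is None
]
total_results = len(results)

# `or 0` coerces a missing optional value to zero so the sum doesn't raise
# TypeError, while the divisor stays the full number of results.
avg_reasoning_tokens = sum((m.reasoning_tokens or 0) for m in results) / total_results
print(avg_reasoning_tokens)  # 1.0

The same change also swaps sum([...]) for a generator expression, avoiding an intermediate list.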
2 changes: 1 addition & 1 deletion lumigator/jobs/inference/schemas.py
@@ -91,7 +91,7 @@ class InferenceJobOutput(BaseModel):
     ground_truth: list | None = None
     model: str
     inference_time: float
-    inference_metrics: list[InferenceMetrics] | list[None] = None
+    inference_metrics: list[InferenceMetrics | None] = []


 class PredictionResult(BaseModel):
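The annotation change is worth spelling out: the old type `list[InferenceMetrics] | list[None]` only admits a list that is entirely metrics or entirely Nones, and its `None` default matches neither branch; the new type `list[InferenceMetrics | None]` admits a per-item `None` and defaults to an empty list. A minimal sketch, assuming Pydantic v2 and a cut-down `InferenceMetrics` (illustrative field names, not the real schema):

from pydantic import BaseModel

class InferenceMetrics(BaseModel):
    prompt_tokens: int
    total_tokens: int

class InferenceJobOutput(BaseModel):
    # Pydantic copies the mutable [] default per instance, so sharing is not a concern.
    inference_metrics: list[InferenceMetrics | None] = []

# A mixed list (some predictions produced metrics, some did not) now validates:
out = InferenceJobOutput(
    inference_metrics=[InferenceMetrics(prompt_tokens=1, total_tokens=2), None]
)
print(out.inference_metrics)

Under the old annotation, that mixed list would fail validation, since it is neither a pure list[InferenceMetrics] nor a pure list[None].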