Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add custom group_ids support to Chronos2Pipeline
  • Loading branch information
StatMixedML committed Dec 9, 2025
commit d5b5282f2a570e3b96a297dba39584edfac50be1
266 changes: 261 additions & 5 deletions notebooks/chronos-2-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,268 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c899976",
"metadata": {},
"outputs": [],
"source": []
"cell_type": "markdown",
"source": [
"## Custom Group IDs: Examples and Use Cases\n",
"\n",
"Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n",
"- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n",
"- cross_learning=True: all series in the batch are jointly predicted and share information.\n",
"- Custom group_ids: only series within the same group share information; groups remain independent.\n",
"\n",
"Use custom group IDs when you know meaningful clusters (e.g., geography, sector, etc.). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n"
],
"id": "e504694b8cfab326"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:45.846727Z",
"start_time": "2025-12-09T13:39:45.827756Z"
}
},
"cell_type": "code",
"source": [
"# Simulate ~30 weather stations with regional clustering (North, South, Coastal)\n",
"import numpy as np, pandas as pd\n",
"np.random.seed(123)\n",
"n = 200\n",
"prediction_length = 24\n",
"ts = pd.date_range('2020-01-01', periods=n, freq='D')\n",
"\n",
"north_ids = [f'station_north_{i+1}' for i in range(10)]\n",
"south_ids = [f'station_south_{i+1}' for i in range(10)]\n",
"coast_ids = [f'station_coast_{i+1}' for i in range(10)]\n",
"all_ids = north_ids + south_ids + coast_ids\n",
"regions = (['North']*len(north_ids) + ['South']*len(south_ids) + ['Coastal']*len(coast_ids))\n",
"\n",
"def synth_series(base, amp, noise, phase=0.0):\n",
" t = np.arange(n)\n",
" signal = base + amp*np.sin(2*np.pi*t/365.0 + phase)\n",
" return (signal + noise*np.random.randn(n)).astype('float32')\n",
"\n",
"data_frames = []\n",
"for sid, region in zip(all_ids, regions):\n",
" if region == 'North':\n",
" y = synth_series(base=5, amp=6, noise=0.8, phase=0.3)\n",
" elif region == 'South':\n",
" y = synth_series(base=18, amp=8, noise=0.8, phase=0.9)\n",
" else: # Coastal\n",
" y = synth_series(base=12, amp=3.5, noise=0.3, phase=0.5)\n",
" df_i = pd.DataFrame({'item_id': sid, 'timestamp': ts, 'target': y, 'region': region})\n",
" data_frames.append(df_i)\n",
"\n",
"weather_df = pd.concat(data_frames, ignore_index=True)\n",
"weather_df.head()"
],
"id": "326d90e1a05586e8",
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp target region\n",
"0 station_north_1 2020-01-01 5.904617 North\n",
"1 station_north_1 2020-01-02 7.669402 North\n",
"2 station_north_1 2020-01-03 7.195759 North\n",
"3 station_north_1 2020-01-04 5.861607 North\n",
"4 station_north_1 2020-01-05 6.700416 North"
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>item_id</th>\n",
" <th>timestamp</th>\n",
" <th>target</th>\n",
" <th>region</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-01</td>\n",
" <td>5.904617</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-02</td>\n",
" <td>7.669402</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-03</td>\n",
" <td>7.195759</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-04</td>\n",
" <td>5.861607</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-05</td>\n",
" <td>6.700416</td>\n",
" <td>North</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:49.624289Z",
"start_time": "2025-12-09T13:39:49.251047Z"
}
},
"cell_type": "code",
"source": [
"# 1) Using predict_df with create_group_ids_dict_from_category (group by 'region')\n",
"from chronos import (\n",
" create_group_ids_dict_from_category,\n",
" create_group_ids_from_category\n",
")\n",
"\n",
"group_ids_cat = create_group_ids_dict_from_category(\n",
" df=weather_df,\n",
" id_column='item_id',\n",
" category_column='region'\n",
")\n",
"\n",
"# Split into context (train) and future truth (test)\n",
"test_df = weather_df.groupby('item_id').tail(prediction_length)\n",
"future_df = test_df.drop(columns=['target', 'region']).copy()\n",
"train_df = weather_df.drop(test_df.index).drop(columns=[\"region\"])\n",
"\n",
"pred_df_ids = pipeline.predict_df(\n",
" df=train_df,\n",
" future_df=future_df,\n",
" id_column='item_id',\n",
" timestamp_column='timestamp',\n",
" target='target',\n",
" prediction_length=prediction_length,\n",
" group_ids=group_ids_cat,\n",
" quantile_levels=[0.5],\n",
")\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
" test_df[['item_id', 'timestamp', 'target']],\n",
" on=['item_id', 'timestamp'], how='inner',\n",
")\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "9113a0f64f257772",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:53.070369Z",
"start_time": "2025-12-09T13:39:52.884440Z"
}
},
"cell_type": "code",
"source": [
"# 2) Low-level API: predict_quantiles with group_ids list\n",
"\n",
"# Create group IDs as a list aligned with series_ids\n",
"group_ids = create_group_ids_from_category(\n",
"\t train=weather_df,\n",
"\t id_column=\"item_id\",\n",
"\t category_column=\"region\"\n",
"\t)\n",
"\n",
"# Build inputs in the same order as series_ids\n",
"series_ids = train_df['item_id'].unique().tolist()\n",
"inputs_list = [\n",
" {'target': train_df.loc[train_df['item_id']==sid, 'target'].to_numpy(dtype=np.float32) }\n",
" for sid in series_ids\n",
"]\n",
"\n",
"_, mean = pipeline.predict_quantiles(\n",
" inputs=inputs_list,\n",
" prediction_length=prediction_length,\n",
" quantile_levels=[0.5],\n",
" group_ids=group_ids,\n",
")\n",
"\n",
"pred_df_ids = []\n",
"for id in series_ids:\n",
" test_id_df = test_df[test_df['item_id']==id].copy()\n",
" test_id_df['predictions'] = mean[series_ids.index(id)].numpy().flatten()\n",
" test_id_df.drop(columns=['target'], inplace=True)\n",
" pred_df_ids.append(test_id_df)\n",
"pred_df_ids = pd.concat(pred_df_ids, ignore_index=True)\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
" test_df[['item_id', 'timestamp', 'target']],\n",
" on=['item_id', 'timestamp'], how='inner',\n",
")\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "629bab4c4ea6fe53",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 5
}
],
"metadata": {
Expand Down
9 changes: 9 additions & 0 deletions src/chronos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
)
from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
from .utils import (
create_group_ids_dict_from_category,
create_group_ids_dict_from_mapping,
create_manual_group_ids_dict,
create_group_ids_from_category
)

__all__ = [
"__version__",
Expand All @@ -27,4 +33,7 @@
"Chronos2ForecastingConfig",
"Chronos2Model",
"Chronos2Pipeline",
"create_group_ids_dict_from_category",
"create_group_ids_dict_from_mapping",
"create_manual_group_ids_dict",
]
Loading