Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 261 additions & 5 deletions notebooks/chronos-2-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,268 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c899976",
"metadata": {},
"outputs": [],
"source": []
"cell_type": "markdown",
"source": [
"## Custom Group IDs: Examples and Use Cases\n",
"\n",
"Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n",
"- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n",
"- cross_learning=True: all series in the batch are jointly predicted and share information.\n",
"- Custom group_ids: only series within the same group share information; groups remain independent.\n",
"\n",
"Use custom group IDs when you know of meaningful clusters in your data (e.g., geography or sector). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n"
],
"id": "e504694b8cfab326"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:45.846727Z",
"start_time": "2025-12-09T13:39:45.827756Z"
}
},
"cell_type": "code",
"source": [
"# Simulate 30 weather stations in three regional clusters (North, South, Coastal).\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"np.random.seed(123)  # fixed seed so the synthetic data is reproducible\n",
"n = 200  # number of daily observations per station\n",
"prediction_length = 24  # number of trailing steps held out for evaluation\n",
"ts = pd.date_range('2020-01-01', periods=n, freq='D')\n",
"\n",
"north_ids = [f'station_north_{i+1}' for i in range(10)]\n",
"south_ids = [f'station_south_{i+1}' for i in range(10)]\n",
"coast_ids = [f'station_coast_{i+1}' for i in range(10)]\n",
"all_ids = north_ids + south_ids + coast_ids\n",
"regions = ['North'] * len(north_ids) + ['South'] * len(south_ids) + ['Coastal'] * len(coast_ids)\n",
"\n",
"# Seasonal-signal parameters shared by every station of a region.\n",
"region_params = {\n",
"    'North': dict(base=5, amp=6, noise=0.8, phase=0.3),\n",
"    'South': dict(base=18, amp=8, noise=0.8, phase=0.9),\n",
"    'Coastal': dict(base=12, amp=3.5, noise=0.3, phase=0.5),\n",
"}\n",
"\n",
"\n",
"def synth_series(base, amp, noise, phase=0.0):\n",
"    '''Return a float32 array of length n: a yearly sine wave plus Gaussian noise.\n",
"\n",
"    base  -- constant offset of the signal\n",
"    amp   -- amplitude of the sine component (period of 365 steps)\n",
"    noise -- standard deviation of the additive Gaussian noise\n",
"    phase -- phase shift of the sine component, in radians\n",
"\n",
"    Reads the notebook-level n and draws from the global NumPy random state.\n",
"    '''\n",
"    t = np.arange(n)\n",
"    signal = base + amp * np.sin(2 * np.pi * t / 365.0 + phase)\n",
"    return (signal + noise * np.random.randn(n)).astype('float32')\n",
"\n",
"\n",
"# Stations are generated in all_ids order, one synth_series call each,\n",
"# so the sequence of random draws (and therefore the data) is deterministic.\n",
"data_frames = [\n",
"    pd.DataFrame({\n",
"        'item_id': sid,\n",
"        'timestamp': ts,\n",
"        'target': synth_series(**region_params[region]),\n",
"        'region': region,\n",
"    })\n",
"    for sid, region in zip(all_ids, regions)\n",
"]\n",
"\n",
"weather_df = pd.concat(data_frames, ignore_index=True)\n",
"weather_df.head()"
],
"id": "326d90e1a05586e8",
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp target region\n",
"0 station_north_1 2020-01-01 5.904617 North\n",
"1 station_north_1 2020-01-02 7.669402 North\n",
"2 station_north_1 2020-01-03 7.195759 North\n",
"3 station_north_1 2020-01-04 5.861607 North\n",
"4 station_north_1 2020-01-05 6.700416 North"
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>item_id</th>\n",
" <th>timestamp</th>\n",
" <th>target</th>\n",
" <th>region</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-01</td>\n",
" <td>5.904617</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-02</td>\n",
" <td>7.669402</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-03</td>\n",
" <td>7.195759</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-04</td>\n",
" <td>5.861607</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-05</td>\n",
" <td>6.700416</td>\n",
" <td>North</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:49.624289Z",
"start_time": "2025-12-09T13:39:49.251047Z"
}
},
"cell_type": "code",
"source": [
"# 1) High-level API: predict_df with a group mapping built from the 'region' column\n",
"from chronos import (\n",
"    create_group_ids_dict_from_category,\n",
"    create_group_ids_from_category,  # used by the low-level example in the next cell\n",
")\n",
"\n",
"group_ids_cat = create_group_ids_dict_from_category(\n",
"    df=weather_df,\n",
"    id_column='item_id',\n",
"    category_column='region',\n",
")\n",
"\n",
"# Hold out the last prediction_length rows of every series as ground truth (test);\n",
"# the remaining rows are the context (train) passed to the model.\n",
"test_df = weather_df.groupby('item_id').tail(prediction_length)\n",
"future_df = test_df.drop(columns=['target', 'region']).copy()\n",
"train_df = weather_df.drop(test_df.index).drop(columns=['region'])\n",
"\n",
"pred_df_ids = pipeline.predict_df(\n",
"    df=train_df,\n",
"    future_df=future_df,\n",
"    id_column='item_id',\n",
"    timestamp_column='timestamp',\n",
"    target='target',\n",
"    prediction_length=prediction_length,\n",
"    group_ids=group_ids_cat,\n",
"    quantile_levels=[0.5],\n",
")\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
"    test_df[['item_id', 'timestamp', 'target']],\n",
"    on=['item_id', 'timestamp'],\n",
"    how='inner',\n",
")\n",
"# An inner merge silently drops unmatched rows, which would bias the metrics.\n",
"assert len(eval_df) == len(test_df), 'forecast rows do not line up with the held-out rows'\n",
"\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "9113a0f64f257772",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:53.070369Z",
"start_time": "2025-12-09T13:39:52.884440Z"
}
},
"cell_type": "code",
"source": [
"# 2) Low-level API: predict_quantiles with a list of group IDs\n",
"\n",
"# One group ID per series, derived from the 'region' column.\n",
"# NOTE(review): the list is expected to follow the same series order as series_ids below\n",
"# (order of first appearance in weather_df) - confirm against the helper's documentation.\n",
"group_ids = create_group_ids_from_category(\n",
"    train=weather_df,\n",
"    id_column='item_id',\n",
"    category_column='region',\n",
")\n",
"\n",
"# Build one input dict per series, in series_ids order\n",
"series_ids = train_df['item_id'].unique().tolist()\n",
"inputs_list = [\n",
"    {'target': train_df.loc[train_df['item_id'] == sid, 'target'].to_numpy(dtype=np.float32)}\n",
"    for sid in series_ids\n",
"]\n",
"\n",
"# The first return value (the quantile forecasts) is not needed here;\n",
"# the second one is used as the point forecast.\n",
"_, mean = pipeline.predict_quantiles(\n",
"    inputs=inputs_list,\n",
"    prediction_length=prediction_length,\n",
"    quantile_levels=[0.5],\n",
"    group_ids=group_ids,\n",
")\n",
"\n",
"# Attach each series' forecast to its held-out timestamps.\n",
"# enumerate() gives the position directly and avoids shadowing the builtin id().\n",
"per_series_preds = [\n",
"    test_df[test_df['item_id'] == sid]\n",
"    .drop(columns=['target'])\n",
"    .assign(predictions=mean[pos].numpy().flatten())\n",
"    for pos, sid in enumerate(series_ids)\n",
"]\n",
"pred_df_ids = pd.concat(per_series_preds, ignore_index=True)\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
"    test_df[['item_id', 'timestamp', 'target']],\n",
"    on=['item_id', 'timestamp'],\n",
"    how='inner',\n",
")\n",
"# An inner merge silently drops unmatched rows, which would bias the metrics.\n",
"assert len(eval_df) == len(test_df), 'forecast rows do not line up with the held-out rows'\n",
"\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "629bab4c4ea6fe53",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 5
}
],
"metadata": {
Expand Down
9 changes: 9 additions & 0 deletions src/chronos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
)
from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
from .utils import (
create_group_ids_dict_from_category,
create_group_ids_dict_from_mapping,
create_manual_group_ids_dict,
create_group_ids_from_category
)

__all__ = [
"__version__",
Expand All @@ -27,4 +33,7 @@
"Chronos2ForecastingConfig",
"Chronos2Model",
"Chronos2Pipeline",
# Group-ID helpers re-exported from .utils; keep in sync with the import above.
"create_group_ids_dict_from_category",
"create_group_ids_dict_from_mapping",
"create_manual_group_ids_dict",
"create_group_ids_from_category",
]
Loading