diff --git a/notebooks/chronos-2-quickstart.ipynb b/notebooks/chronos-2-quickstart.ipynb
index d2489c17..89baed5a 100644
--- a/notebooks/chronos-2-quickstart.ipynb
+++ b/notebooks/chronos-2-quickstart.ipynb
@@ -1730,12 +1730,268 @@
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "id": "7c899976",
"metadata": {},
- "outputs": [],
- "source": []
+ "cell_type": "markdown",
+ "source": [
+ "## Custom Group IDs: Examples and Use Cases\n",
+ "\n",
+ "Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n",
+ "- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n",
+ "- cross_learning=True: all series in the batch are jointly predicted and share information.\n",
+ "- Custom group_ids: only series within the same group share information; groups remain independent. Cannot be combined with cross_learning=True.\n",
+ "\n",
+ "Use custom group IDs when you know meaningful clusters (e.g., geography or sector). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n"
+ ],
+ "id": "e504694b8cfab326"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-12-09T13:39:45.846727Z",
+ "start_time": "2025-12-09T13:39:45.827756Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Synthetic daily series for 30 weather stations, 10 per region (North, South, Coastal)\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "n = 200\n",
+ "prediction_length = 24\n",
+ "ts = pd.date_range('2020-01-01', periods=n, freq='D')\n",
+ "\n",
+ "# Per-region (item_id prefix, seasonal-curve parameters); every station of a region shares them\n",
+ "stations_per_region = 10\n",
+ "region_params = {\n",
+ "    'North': ('north', dict(base=5, amp=6, noise=0.8, phase=0.3)),\n",
+ "    'South': ('south', dict(base=18, amp=8, noise=0.8, phase=0.9)),\n",
+ "    'Coastal': ('coast', dict(base=12, amp=3.5, noise=0.3, phase=0.5)),\n",
+ "}\n",
+ "\n",
+ "def synth_series(base, amp, noise, phase=0.0):\n",
+ "    # Yearly sinusoid around `base` plus Gaussian noise, returned as float32 of length n\n",
+ "    day = np.arange(n)\n",
+ "    seasonal = base + amp * np.sin(2 * np.pi * day / 365.0 + phase)\n",
+ "    return (seasonal + noise * np.random.randn(n)).astype('float32')\n",
+ "\n",
+ "# Regions are visited in the order North, South, Coastal so the random draws stay reproducible\n",
+ "station_frames = []\n",
+ "for region, (prefix, curve) in region_params.items():\n",
+ "    for k in range(1, stations_per_region + 1):\n",
+ "        station_frames.append(pd.DataFrame({\n",
+ "            'item_id': f'station_{prefix}_{k}',\n",
+ "            'timestamp': ts,\n",
+ "            'target': synth_series(**curve),\n",
+ "            'region': region,\n",
+ "        }))\n",
+ "\n",
+ "weather_df = pd.concat(station_frames, ignore_index=True)\n",
+ "weather_df.head()"
+ ],
+ "id": "326d90e1a05586e8",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " item_id timestamp target region\n",
+ "0 station_north_1 2020-01-01 5.904617 North\n",
+ "1 station_north_1 2020-01-02 7.669402 North\n",
+ "2 station_north_1 2020-01-03 7.195759 North\n",
+ "3 station_north_1 2020-01-04 5.861607 North\n",
+ "4 station_north_1 2020-01-05 6.700416 North"
+ ],
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " timestamp | \n",
+ " target | \n",
+ " region | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " station_north_1 | \n",
+ " 2020-01-01 | \n",
+ " 5.904617 | \n",
+ " North | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " station_north_1 | \n",
+ " 2020-01-02 | \n",
+ " 7.669402 | \n",
+ " North | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " station_north_1 | \n",
+ " 2020-01-03 | \n",
+ " 7.195759 | \n",
+ " North | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " station_north_1 | \n",
+ " 2020-01-04 | \n",
+ " 5.861607 | \n",
+ " North | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " station_north_1 | \n",
+ " 2020-01-05 | \n",
+ " 6.700416 | \n",
+ " North | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 3
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-12-09T13:39:49.624289Z",
+ "start_time": "2025-12-09T13:39:49.251047Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# 1) High-level API: predict_df with a {series_id: group_id} dict derived from 'region'\n",
+ "from chronos import create_group_ids_dict_from_category, create_group_ids_from_category\n",
+ "\n",
+ "group_ids_cat = create_group_ids_dict_from_category(\n",
+ "    df=weather_df, id_column='item_id', category_column='region'\n",
+ ")\n",
+ "\n",
+ "# Hold out the last `prediction_length` steps of every station as ground truth\n",
+ "test_df = weather_df.groupby('item_id').tail(prediction_length)\n",
+ "train_df = weather_df.drop(index=test_df.index).drop(columns=['region'])\n",
+ "future_df = test_df[['item_id', 'timestamp']].copy()\n",
+ "\n",
+ "pred_df_ids = pipeline.predict_df(\n",
+ "    df=train_df,\n",
+ "    future_df=future_df,\n",
+ "    id_column='item_id',\n",
+ "    timestamp_column='timestamp',\n",
+ "    target='target',\n",
+ "    prediction_length=prediction_length,\n",
+ "    group_ids=group_ids_cat,\n",
+ "    quantile_levels=[0.5],\n",
+ ")\n",
+ "\n",
+ "# Score the forecasts against the held-out values (MSE / MAE)\n",
+ "eval_df = pred_df_ids.merge(\n",
+ "    test_df[['item_id', 'timestamp', 'target']],\n",
+ "    on=['item_id', 'timestamp'],\n",
+ "    how='inner',\n",
+ ")\n",
+ "errors = eval_df['predictions'].to_numpy() - eval_df['target'].to_numpy()\n",
+ "mse_ids = float(np.mean(errors ** 2))\n",
+ "mae_ids = float(np.mean(np.abs(errors)))\n",
+ "print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
+ ],
+ "id": "9113a0f64f257772",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MSE=0.523, MAE=0.537\n"
+ ]
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-12-09T13:39:53.070369Z",
+ "start_time": "2025-12-09T13:39:52.884440Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# 2) Low-level API: predict_quantiles with a group_ids list aligned to the inputs\n",
+ "\n",
+ "# One group ID per series, in order of first appearance in weather_df\n",
+ "# (train_df no longer has the 'region' column, so the full frame is used here)\n",
+ "group_ids = create_group_ids_from_category(\n",
+ "    train=weather_df,\n",
+ "    id_column='item_id',\n",
+ "    category_column='region',\n",
+ ")\n",
+ "\n",
+ "# Build the inputs in that same order so that group_ids[i] belongs to inputs_list[i]\n",
+ "series_ids = train_df['item_id'].unique().tolist()\n",
+ "assert series_ids == weather_df['item_id'].unique().tolist(), 'group_ids and inputs are misaligned'\n",
+ "inputs_list = [\n",
+ "    {'target': train_df.loc[train_df['item_id'] == sid, 'target'].to_numpy(dtype=np.float32)}\n",
+ "    for sid in series_ids\n",
+ "]\n",
+ "\n",
+ "_, mean = pipeline.predict_quantiles(\n",
+ "    inputs=inputs_list,\n",
+ "    prediction_length=prediction_length,\n",
+ "    quantile_levels=[0.5],\n",
+ "    group_ids=group_ids,\n",
+ ")\n",
+ "\n",
+ "# Attach each forecast to the held-out timestamps of its series\n",
+ "pred_frames = [\n",
+ "    test_df.loc[test_df['item_id'] == sid, ['item_id', 'timestamp', 'region']].assign(\n",
+ "        predictions=mean[idx].numpy().flatten()\n",
+ "    )\n",
+ "    for idx, sid in enumerate(series_ids)\n",
+ "]\n",
+ "pred_df_ids = pd.concat(pred_frames, ignore_index=True)\n",
+ "\n",
+ "# Score the forecasts against the held-out values (MSE / MAE)\n",
+ "eval_df = pred_df_ids.merge(\n",
+ "    test_df[['item_id', 'timestamp', 'target']],\n",
+ "    on=['item_id', 'timestamp'], how='inner',\n",
+ ")\n",
+ "y_true = eval_df['target'].to_numpy()\n",
+ "y_pred = eval_df['predictions'].to_numpy()\n",
+ "mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
+ "mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
+ "print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
+ ],
+ "id": "629bab4c4ea6fe53",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MSE=0.523, MAE=0.537\n"
+ ]
+ }
+ ],
+ "execution_count": 5
}
],
"metadata": {
diff --git a/src/chronos/__init__.py b/src/chronos/__init__.py
index 75d0bfbd..68019988 100644
--- a/src/chronos/__init__.py
+++ b/src/chronos/__init__.py
@@ -12,6 +12,12 @@
)
from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
+from .utils import (
+ create_group_ids_dict_from_category,
+ create_group_ids_dict_from_mapping,
+ create_manual_group_ids_dict,
+ create_group_ids_from_category
+)
__all__ = [
"__version__",
@@ -27,4 +33,7 @@
"Chronos2ForecastingConfig",
"Chronos2Model",
"Chronos2Pipeline",
+    "create_group_ids_dict_from_category",
+    "create_group_ids_dict_from_mapping",
+    "create_manual_group_ids_dict",
+    "create_group_ids_from_category",
]
diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py
index 3eddcd7f..67b8a0be 100644
--- a/src/chronos/chronos2/pipeline.py
+++ b/src/chronos/chronos2/pipeline.py
@@ -448,6 +448,7 @@ def predict(
batch_size: int = 256,
context_length: int | None = None,
cross_learning: bool = False,
+ group_ids: list[int] | torch.Tensor | None = None,
limit_prediction_length: bool = False,
**kwargs,
) -> list[torch.Tensor]:
@@ -540,6 +541,12 @@ def predict(
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
+ group_ids
+ Optional custom group IDs to control information sharing between time series.
+ If provided, must be a list or tensor of integers with length equal to the number of tasks in `inputs`.
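+ Tasks with the same group ID will share information during prediction via cross-attention.
+ Sharing only happens within a batch: tasks with the same group ID that end up in different
+ batches (see `batch_size`) are predicted without seeing each other.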
+ Cannot be used together with `cross_learning=True`. By default None (each task gets unique group ID).
+ Example: [0, 0, 1] means first two tasks share information, third is separate.
limit_prediction_length
If True, an error is raised when prediction_length is greater than model's default prediction length, by default False
@@ -561,6 +568,40 @@ def predict(
stacklevel=2,
)
cross_learning = kwargs.pop("predict_batches_jointly")
+
+        # `group_ids` and `cross_learning` are two mutually exclusive ways of defining groups
+        if group_ids is not None and cross_learning:
+            raise ValueError(
+                "Cannot specify both `group_ids` and `cross_learning=True`. "
+                "Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
+            )
+
+        # Normalize group_ids into a 1-D long tensor holding one group ID per task
+        custom_group_ids_tensor = None
+        if group_ids is not None:
+            if isinstance(group_ids, list):
+                # Only integers are allowed; bool is a subclass of int, so it is excluded
+                # explicitly to match the tensor branch, which rejects bool tensors
+                if not all(isinstance(x, (int, np.integer)) and not isinstance(x, bool) for x in group_ids):
+                    raise TypeError("`group_ids` list must contain only integers")
+                if any(x < 0 for x in group_ids):
+                    raise ValueError("`group_ids` must contain only non-negative integers")
+                custom_group_ids_tensor = torch.tensor(group_ids, dtype=torch.long)
+            elif isinstance(group_ids, torch.Tensor):
+                # Enforce integer dtype for tensor inputs
+                if torch.is_floating_point(group_ids) or group_ids.dtype == torch.bool:
+                    raise TypeError("`group_ids` tensor must have an integer dtype")
+                # One scalar ID per task is looked up later, so anything but a 1-D tensor
+                # would only fail deep inside the batch loop with an obscure error
+                if group_ids.dim() != 1:
+                    raise ValueError(f"`group_ids` tensor must be 1-dimensional, got {group_ids.dim()} dimensions")
+                if (group_ids < 0).any():
+                    raise ValueError("`group_ids` must contain only non-negative integers")
+                # Copy so that the caller's tensor is never aliased
+                custom_group_ids_tensor = group_ids.to(dtype=torch.long).clone()
+            else:
+                raise TypeError(f"`group_ids` must be a list or torch.Tensor, got {type(group_ids)}")
+
+            # Validate length matches number of tasks
+            if len(custom_group_ids_tensor) != len(inputs):
+                raise ValueError(
+                    f"`group_ids` length ({len(custom_group_ids_tensor)}) must match number of tasks in inputs ({len(inputs)})"
+                )
+
# The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger
# than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value
max_output_patches = kwargs.pop("max_output_patches", self.max_output_patches)
@@ -613,6 +654,9 @@ def predict(
)
all_predictions: list[torch.Tensor] = []
+ # Track the current task index for custom group ID mapping
+ current_task_idx = 0
+
for batch in test_loader:
assert batch["future_target"] is None
batch_context = batch["context"]
@@ -620,7 +664,38 @@ def predict(
batch_future_covariates = batch["future_covariates"]
batch_target_idx_ranges = batch["target_idx_ranges"]
- if cross_learning:
+            # Custom grouping: swap the loader-assigned per-task IDs for the user-supplied ones
+            if custom_group_ids_tensor is not None:
+                # IDs assigned by the loader, one entry per row of this batch
+                loader_ids = batch_group_ids.cpu().tolist()
+
+                # Walk the IDs in order of first appearance: the k-th distinct loader ID is treated
+                # as the k-th task of this batch, i.e. global task `current_task_idx + k`
+                replacement: dict[int, int] = {}
+                for loader_id in loader_ids:
+                    if loader_id not in replacement:
+                        task_idx = current_task_idx + len(replacement)
+                        replacement[loader_id] = custom_group_ids_tensor[task_idx].item()
+
+                batch_group_ids = torch.tensor([replacement[i] for i in loader_ids], dtype=torch.long)
+                # Move the global task cursor past the tasks contained in this batch
+                current_task_idx += len(batch_target_idx_ranges)
+
+
+ elif cross_learning:
batch_group_ids = torch.zeros_like(batch_group_ids)
batch_prediction = self._predict_batch(
@@ -813,6 +888,7 @@ def predict_df(
batch_size: int = 256,
context_length: int | None = None,
cross_learning: bool = False,
+ group_ids: dict[str, int] | None = None,
validate_inputs: bool = True,
**predict_kwargs,
) -> "pd.DataFrame":
@@ -852,6 +928,12 @@ def predict_df(
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
+ group_ids
+ Optional dictionary mapping series IDs (from id_column) to group IDs.
+ Series with the same group ID will share information during prediction.
+ Cannot be used together with `cross_learning=True`. By default None.
+ Example: {'series_A': 0, 'series_B': 0, 'series_C': 1} means series_A and series_B share info, series_C is separate.
+ If a series ID is not in the dictionary, it will be assigned a unique group ID.
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
@@ -888,6 +970,48 @@ def predict_df(
validate_inputs=validate_inputs,
)
+        # Translate the {series_id: group_id} mapping into a list aligned with `inputs`
+        group_ids_list = None
+        if group_ids is not None:
+            if not isinstance(group_ids, dict):
+                raise TypeError(f"`group_ids` must be a dictionary, got {type(group_ids)}")
+
+            # Keys may be any series identifier found in `id_column` (not only strings).
+            # Values must be integers: Python or NumPy ints are accepted, as in `predict`,
+            # while bool (a subclass of int) is rejected.
+            if not all(
+                isinstance(v, (int, np.integer)) and not isinstance(v, bool) for v in group_ids.values()
+            ):
+                raise TypeError("`group_ids` dictionary values must be integers")
+
+            if any(v < 0 for v in group_ids.values()):
+                raise ValueError("`group_ids` values must be non-negative integers")
+
+            if cross_learning:
+                raise ValueError(
+                    "Cannot specify both `group_ids` and `cross_learning=True`. "
+                    "Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
+                )
+
+            # Warn about mapping entries that do not correspond to any series in the dataframe
+            series_in_df = set(df[id_column].unique())
+            unknown_series = set(group_ids.keys()) - series_in_df
+            if unknown_series:
+                warnings.warn(
+                    f"The following series IDs in `group_ids` were not found in the dataframe: {unknown_series}. "
+                    f"They will be ignored.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+
+            # Series missing from the mapping get fresh IDs above the largest user-supplied one,
+            # so each of them ends up alone in its own group
+            next_auto_group_id = int(max(group_ids.values())) + 1 if group_ids else 0
+
+            # `original_order` holds the series IDs in the order they appear in `inputs`
+            group_ids_list = []
+            for series_id in original_order:
+                if series_id in group_ids:
+                    group_ids_list.append(int(group_ids[series_id]))
+                else:
+                    group_ids_list.append(next_auto_group_id)
+                    next_auto_group_id += 1
+
+
# Generate forecasts
quantiles, mean = self.predict_quantiles(
inputs=inputs,
@@ -897,6 +1021,7 @@ def predict_df(
batch_size=batch_size,
context_length=context_length,
cross_learning=cross_learning,
+ group_ids=group_ids_list,
**predict_kwargs,
)
# since predict_df tasks are homogenous by input design, we can safely stack the list of tensors into a single tensor
diff --git a/src/chronos/utils.py b/src/chronos/utils.py
index 1c318167..4e591ccf 100644
--- a/src/chronos/utils.py
+++ b/src/chronos/utils.py
@@ -7,6 +7,12 @@
import torch
from einops import repeat
+try:
+ import pandas as pd
+ _PANDAS_AVAILABLE = True
+except ImportError:
+ _PANDAS_AVAILABLE = False
+
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
max_len = max(len(c) for c in tensors)
@@ -210,3 +216,228 @@ def weighted_quantile(
# Reshape to original shape
final_shape = (*orig_samples_shape[:-1], len(query_quantile_levels))
return interpolated_quantiles.reshape(final_shape).to(dtype=orig_dtype)
+
+
+def create_group_ids_dict_from_category(
+    df: "pd.DataFrame",
+    id_column: str,
+    category_column: str
+) -> dict[str, int]:
+    """
+    Build the ``{series_id: group_id}`` dictionary accepted by ``Chronos2Pipeline.predict_df``
+    from a categorical column.
+
+    Every series is assigned the group of its category, so series that share a category
+    share a group ID.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame containing id_column and category_column.
+    id_column : str
+        Name of the column containing time series identifiers.
+    category_column : str
+        Name of the categorical column to group by (e.g., "region", "product_type", "industry").
+
+    Returns
+    -------
+    dict[str, int]
+        Dictionary mapping series IDs to group IDs, with series listed in order of
+        first appearance in ``df``.
+
+    Raises
+    ------
+    ImportError
+        If pandas is not installed.
+
+    Examples
+    --------
+    >>> import pandas as pd
+    >>> df = pd.DataFrame({
+    ...     'item_id': ['A', 'A', 'B', 'B', 'C', 'C'],
+    ...     'region': ['North', 'North', 'North', 'North', 'South', 'South'],
+    ...     'value': [1, 2, 3, 4, 5, 6]
+    ... })
+    >>> create_group_ids_dict_from_category(df, 'item_id', 'region')
+    {'A': 0, 'B': 0, 'C': 1}  # A and B are North (group 0), C is South (group 1)
+
+    Notes
+    -----
+    - Use this for Chronos2Pipeline.predict_df() which expects dict[str, int]
+    - A series is represented by the value ``GroupBy.first`` returns for its rows
+    - Categories are numbered 0, 1, 2, ... in order of first appearance
+    """
+    if not _PANDAS_AVAILABLE:
+        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")
+
+    # One category per series, keeping the order in which series first appear
+    category_of_series = df.groupby(id_column, sort=False)[category_column].first()
+
+    # Number the distinct categories consecutively, in order of first appearance
+    group_of_category = dict(
+        (category, position) for position, category in enumerate(category_of_series.unique())
+    )
+
+    # Look up the group of every series' category
+    group_ids_dict = {}
+    for series_id, category in category_of_series.items():
+        group_ids_dict[series_id] = group_of_category[category]
+    return group_ids_dict
+
+
+def create_group_ids_dict_from_mapping(
+    df: "pd.DataFrame",
+    id_column: str,
+    category_to_group_map: dict[str, int],
+    category_column: str | None = None,
+) -> dict[str, int]:
+    """
+    Create group_ids dictionary using a custom category-to-group mapping.
+
+    This allows you to manually specify which categories belong to which groups,
+    useful when several categories should end up in the same group.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame containing id_column and a category column.
+    id_column : str
+        Name of the column containing time series identifiers.
+    category_to_group_map : dict[str, int]
+        Dictionary mapping category names to group IDs.
+        Example: {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
+    category_column : str, optional
+        Name of the column holding the categories. When omitted, the column is inferred:
+        the first label-like column (object, string or categorical dtype, other than
+        id_column) that contains at least one key of category_to_group_map is used.
+        Pass it explicitly when more than one column could match.
+
+    Returns
+    -------
+    dict[str, int]
+        Dictionary mapping series IDs to group IDs.
+
+    Raises
+    ------
+    ImportError
+        If pandas is not installed.
+    ValueError
+        If category_column is given but missing from df, or if it is omitted and
+        no suitable column can be inferred.
+    KeyError
+        If a series has a category that is not a key of category_to_group_map.
+
+    Examples
+    --------
+    >>> import pandas as pd
+    >>> df = pd.DataFrame({
+    ...     'item_id': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
+    ...     'industry': ['Retail', 'Retail', 'Wholesale', 'Wholesale',
+    ...                  'Food', 'Food', 'Services', 'Services'],
+    ...     'value': [1, 2, 3, 4, 5, 6, 7, 8]
+    ... })
+    >>> mapping = {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
+    >>> create_group_ids_dict_from_mapping(df, 'item_id', mapping)
+    {'A': 0, 'B': 0, 'C': 1, 'D': 2}
+
+    Notes
+    -----
+    - This is more flexible than create_group_ids_dict_from_category()
+    - Useful when multiple categories should map to the same group
+    - All categories in the data must be in the mapping, or will raise KeyError
+    """
+    if not _PANDAS_AVAILABLE:
+        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")
+
+    if category_column is None:
+        # Plain object columns are tried first (the only kind considered historically, so calls
+        # that already worked pick the same column); string- and category-typed columns follow.
+        candidates = [col for col in df.columns if col != id_column]
+        object_cols = [col for col in candidates if df[col].dtype == 'object']
+        other_label_cols = [
+            col
+            for col in candidates
+            if col not in object_cols
+            and (
+                pd.api.types.is_string_dtype(df[col].dtype)
+                or isinstance(df[col].dtype, pd.CategoricalDtype)
+            )
+        ]
+        for col in object_cols + other_label_cols:
+            if any(cat in category_to_group_map for cat in df[col].unique()):
+                category_column = col
+                break
+
+        if category_column is None:
+            raise ValueError(
+                f"Could not infer category column. Available columns: {df.columns.tolist()}. "
+                f"Make sure one of them contains categories from the mapping: {list(category_to_group_map.keys())}"
+            )
+    elif category_column not in df.columns:
+        raise ValueError(
+            f"Column '{category_column}' not found in dataframe. Available columns: {df.columns.tolist()}"
+        )
+
+    # One category per series, keeping the order in which series first appear
+    series_categories = df.groupby(id_column, sort=False)[category_column].first()
+
+    # Translate each series' category through the user-supplied mapping
+    group_ids_dict = {}
+    for series_id, category in series_categories.items():
+        if category not in category_to_group_map:
+            raise KeyError(
+                f"Category '{category}' for series '{series_id}' not found in mapping. "
+                f"Available categories in mapping: {list(category_to_group_map.keys())}"
+            )
+        group_ids_dict[series_id] = category_to_group_map[category]
+
+    return group_ids_dict
+
+
+def create_manual_group_ids_dict(
+    series_ids: List[str],
+    group_assignments: List[int]
+) -> dict[str, int]:
+    """
+    Pair up series identifiers with explicitly chosen group IDs.
+
+    Parameters
+    ----------
+    series_ids : list of str
+        List of series identifiers.
+    group_assignments : list of int
+        Group ID for each entry of series_ids, in the same order.
+        Must have the same length as series_ids.
+
+    Returns
+    -------
+    dict[str, int]
+        Dictionary mapping series IDs to group IDs.
+
+    Examples
+    --------
+    >>> series_ids = ['store_1', 'store_2', 'store_3', 'store_4']
+    >>> groups = [0, 0, 1, 1]  # First two together, last two together
+    >>> create_manual_group_ids_dict(series_ids, groups)
+    {'store_1': 0, 'store_2': 0, 'store_3': 1, 'store_4': 1}
+
+    Raises
+    ------
+    ValueError
+        If series_ids and group_assignments have different lengths.
+    """
+    n_series, n_assignments = len(series_ids), len(group_assignments)
+    if n_series != n_assignments:
+        raise ValueError(
+            f"Length mismatch: series_ids has {n_series} elements, "
+            f"but group_assignments has {n_assignments} elements."
+        )
+
+    # If a series ID is listed more than once, its last assignment is the one kept
+    return {series_id: group_id for series_id, group_id in zip(series_ids, group_assignments)}
+
+def create_group_ids_from_category(
+    train: "pd.DataFrame",
+    id_column: str,
+    category_column: str
+) -> list[int]:
+    """
+    Create group_ids list (for predict / predict_quantiles) from a categorical column.
+
+    Parameters
+    ----------
+    train : pd.DataFrame
+        Training data containing id_column and category_column.
+    id_column : str
+        Name of the column containing time series identifiers.
+    category_column : str
+        Name of the categorical column to group by (e.g., "region", "product_type").
+
+    Returns
+    -------
+    list[int]
+        List of group IDs, one per series, ordered by the first appearance of each
+        series in ``train``. The inputs passed to the pipeline must follow that same order.
+
+    Raises
+    ------
+    ImportError
+        If pandas is not installed.
+
+    Examples
+    --------
+    >>> import pandas as pd
+    >>> train = pd.DataFrame({
+    ...     'series_id': ['A', 'A', 'B', 'B', 'C', 'C'],
+    ...     'region': ['North', 'North', 'North', 'North', 'South', 'South'],
+    ...     'value': [1, 2, 3, 4, 5, 6]
+    ... })
+    >>> create_group_ids_from_category(train, 'series_id', 'region')
+    [0, 0, 1]  # A and B are North (group 0), C is South (group 1)
+    """
+    # The annotation above is quoted on purpose: pandas is an optional import in this module,
+    # and an unquoted `pd.DataFrame` would be evaluated when the function is defined
+    if not _PANDAS_AVAILABLE:
+        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")
+
+    # One category per series, keeping the order in which series first appear
+    series_categories = train.groupby(id_column, sort=False)[category_column].first()
+
+    # Number the distinct categories consecutively, in order of first appearance
+    unique_categories = series_categories.unique()
+    category_to_group = {cat: i for i, cat in enumerate(unique_categories)}
+
+    # One group ID per series, in series order
+    return [category_to_group[cat] for cat in series_categories.values]
diff --git a/test/test_chronos2.py b/test/test_chronos2.py
index 3fac7261..dca9f0a2 100644
--- a/test/test_chronos2.py
+++ b/test/test_chronos2.py
@@ -1132,3 +1132,283 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
+
+
+# ============================================================================
+# Tests for custom group_ids functionality
+# ============================================================================
+
+
+@pytest.mark.parametrize("group_ids", [[0, 0, 1], torch.tensor([0, 0, 1])])
+def test_predict_with_custom_group_ids_list_and_tensor(pipeline, group_ids):
+ """Test basic functionality with custom group_ids as both list and tensor."""
+ inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
+ outputs = pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
+
+ assert isinstance(outputs, list) and len(outputs) == 3
+ for out in outputs:
+ validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 24), dtype=torch.float32)
+
+
+def test_predict_with_group_ids_univariate_batch(pipeline):
+ """Test group_ids with homogeneous univariate batch."""
+ inputs = torch.rand(5, 1, 100)
+ group_ids = [0, 0, 1, 1, 2] # First two together, next two together, last one alone
+
+ outputs = pipeline.predict(inputs, prediction_length=12, group_ids=group_ids)
+
+ assert len(outputs) == 5
+ for out in outputs:
+ validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 12), dtype=torch.float32)
+
+
+def test_predict_with_group_ids_multivariate(pipeline):
+ """Test group_ids with multivariate inputs."""
+ inputs = [torch.rand(2, 100), torch.rand(2, 110), torch.rand(2, 90)]
+ group_ids = [0, 0, 1] # First two share info, third is separate
+
+ outputs = pipeline.predict(inputs, prediction_length=16, group_ids=group_ids)
+
+ assert len(outputs) == 3
+ for out in outputs:
+ validate_tensor(out, (2, DEFAULT_MODEL_NUM_QUANTILES, 16), dtype=torch.float32)
+
+
+def test_predict_with_group_ids_and_covariates(pipeline):
+ """Test group_ids with covariates."""
+ prediction_length = 24
+ inputs = [
+ {
+ "target": torch.rand(100),
+ "past_covariates": {"temperature": torch.rand(100)},
+ "future_covariates": {"temperature": torch.rand(prediction_length)},
+ },
+ {
+ "target": torch.rand(110),
+ "past_covariates": {"temperature": torch.rand(110)},
+ "future_covariates": {"temperature": torch.rand(prediction_length)},
+ },
+ {
+ "target": torch.rand(90),
+ "past_covariates": {"temperature": torch.rand(90)},
+ "future_covariates": {"temperature": torch.rand(prediction_length)},
+ },
+ ]
+ group_ids = [0, 0, 1]
+
+ outputs = pipeline.predict(inputs, prediction_length=prediction_length, group_ids=group_ids)
+
+ assert len(outputs) == 3
+ for out in outputs:
+ validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length), dtype=torch.float32)
+
+
+def test_predict_df_with_group_ids_dict(pipeline):
+ """Test predict_df with dictionary group_ids."""
+ df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10])
+ group_ids = {"A": 0, "B": 0, "C": 1} # A and B share info, C is separate
+
+ pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
+
+ assert isinstance(pred_df, pd.DataFrame)
+ assert len(pred_df) == 15 # 3 series * 5 predictions
+ assert set(pred_df["item_id"].unique()) == {"A", "B", "C"}
+
+
+def test_predict_df_with_partial_group_ids(pipeline):
+ """Test predict_df when only some series have group_ids assigned."""
+ df = create_df(series_ids=["A", "B", "C", "D"], n_points=[10, 10, 10, 10])
+ group_ids = {"A": 0, "B": 0} # Only A and B specified, C and D should get unique IDs
+
+ pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
+
+ assert isinstance(pred_df, pd.DataFrame)
+ assert len(pred_df) == 20 # 4 series * 5 predictions
+ assert set(pred_df["item_id"].unique()) == {"A", "B", "C", "D"}
+
+
+def test_predict_df_with_group_ids_and_covariates(pipeline):
+ """Test predict_df with both group_ids and covariates."""
+ df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10], covariates=["temp"])
+ future_df = create_future_df(
+ get_forecast_start_times(df), series_ids=["A", "B", "C"], n_points=[5, 5, 5], covariates=["temp"]
+ )
+ group_ids = {"A": 0, "B": 0, "C": 1}
+
+ pred_df = pipeline.predict_df(df, future_df=future_df, prediction_length=5, group_ids=group_ids)
+
+ assert isinstance(pred_df, pd.DataFrame)
+ assert len(pred_df) == 15
+
+
+def test_group_ids_cross_learning_mutual_exclusion(pipeline):
+ """Test that error is raised when both group_ids and cross_learning are specified."""
+ inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
+ group_ids = [0, 0, 1]
+
+ with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"):
+ pipeline.predict(inputs, prediction_length=24, group_ids=group_ids, cross_learning=True)
+
+
+def test_predict_df_group_ids_cross_learning_mutual_exclusion(pipeline):
+ """Test that predict_df raises error when both group_ids and cross_learning are specified."""
+ df = create_df(series_ids=["A", "B"], n_points=[10, 10])
+ group_ids = {"A": 0, "B": 0}
+
+ with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"):
+ pipeline.predict_df(df, prediction_length=5, group_ids=group_ids, cross_learning=True)
+
+
+def test_group_ids_length_mismatch_raises_error(pipeline):
+ """Test that error is raised when group_ids length doesn't match inputs."""
+ inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
+ group_ids = [0, 0] # Only 2 IDs for 3 inputs
+
+ with pytest.raises(ValueError, match="length .* must match number of tasks"):
+ pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
+
+
+def test_group_ids_negative_values_raises_error(pipeline):
+ """Test that error is raised when group_ids contain negative values."""
+ inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
+ group_ids = [0, -1, 1] # Negative ID not allowed
+
+ with pytest.raises(ValueError, match="must contain only non-negative integers"):
+ pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
+
+
+def test_group_ids_invalid_type_raises_error(pipeline):
+ """Test that error is raised when group_ids is not list or tensor."""
+ inputs = [torch.rand(100), torch.rand(110)]
+ group_ids = "invalid" # String not allowed
+
+ with pytest.raises(TypeError, match="must be a list or torch.Tensor"):
+ pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
+
+
+def test_predict_df_group_ids_invalid_type_raises_error(pipeline):
+ """Test that predict_df raises error when group_ids is not a dict."""
+ df = create_df(series_ids=["A", "B"], n_points=[10, 10])
+ group_ids = [0, 0] # List not allowed for predict_df (needs dict)
+
+ with pytest.raises(TypeError, match="must be a dictionary"):
+ pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
+
+
+def test_predict_df_group_ids_invalid_dict_values_raises_error(pipeline):
+ """Test that predict_df raises error when group_ids dict has negative values."""
+ df = create_df(series_ids=["A", "B"], n_points=[10, 10])
+ group_ids = {"A": 0, "B": -1} # Negative value not allowed
+
+ with pytest.raises(ValueError, match="must be non-negative integers"):
+ pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
+
+
+def test_predict_df_group_ids_warns_unknown_series(pipeline):
+ """Test that predict_df warns when group_ids contains unknown series IDs."""
+ df = create_df(series_ids=["A", "B"], n_points=[10, 10])
+ group_ids = {"A": 0, "B": 0, "X": 1, "Y": 1} # X and Y don't exist
+
+ with pytest.warns(UserWarning, match="not found in the dataframe"):
+ pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
+
+
+# ============================================================================
+# Tests for group_ids helper functions
+# ============================================================================
+
+
+def test_create_group_ids_dict_from_category():
+ """Test create_group_ids_dict_from_category helper function."""
+ from chronos import create_group_ids_dict_from_category
+
+ df = pd.DataFrame(
+ {
+ "item_id": ["A", "A", "B", "B", "C", "C"],
+ "region": ["North", "North", "North", "North", "South", "South"],
+ "value": [1, 2, 3, 4, 5, 6],
+ }
+ )
+
+ result = create_group_ids_dict_from_category(df, "item_id", "region")
+
+ assert isinstance(result, dict)
+ assert result == {"A": 0, "B": 0, "C": 1}
+
+
+def test_create_group_ids_dict_from_mapping():
+ """Test create_group_ids_dict_from_mapping helper function."""
+ from chronos import create_group_ids_dict_from_mapping
+
+ df = pd.DataFrame(
+ {
+ "item_id": ["A", "A", "B", "B", "C", "C", "D", "D"],
+ "industry": ["Retail", "Retail", "Wholesale", "Wholesale", "Food", "Food", "Services", "Services"],
+ "value": [1, 2, 3, 4, 5, 6, 7, 8],
+ }
+ )
+ mapping = {"Retail": 0, "Wholesale": 0, "Food": 1, "Services": 2}
+
+ result = create_group_ids_dict_from_mapping(df, "item_id", mapping)
+
+ assert isinstance(result, dict)
+ assert result == {"A": 0, "B": 0, "C": 1, "D": 2}
+
+
+def test_create_group_ids_dict_from_mapping_missing_category_raises_error():
+ """Test that create_group_ids_dict_from_mapping raises error for unmapped categories."""
+ from chronos import create_group_ids_dict_from_mapping
+
+ df = pd.DataFrame(
+ {
+ "item_id": ["A", "A", "B", "B"],
+ "industry": ["Retail", "Retail", "Tech", "Tech"],
+ "value": [1, 2, 3, 4],
+ }
+ )
+ mapping = {"Retail": 0} # Missing "Tech"
+
+ with pytest.raises(KeyError, match="not found in mapping"):
+ create_group_ids_dict_from_mapping(df, "item_id", mapping)
+
+
+def test_create_manual_group_ids_dict():
+ """Test create_manual_group_ids_dict helper function."""
+ from chronos import create_manual_group_ids_dict
+
+ series_ids = ["store_1", "store_2", "store_3", "store_4"]
+ groups = [0, 0, 1, 1]
+
+ result = create_manual_group_ids_dict(series_ids, groups)
+
+ assert isinstance(result, dict)
+ assert result == {"store_1": 0, "store_2": 0, "store_3": 1, "store_4": 1}
+
+
+def test_create_manual_group_ids_dict_length_mismatch_raises_error():
+ """Test that create_manual_group_ids_dict raises error on length mismatch."""
+ from chronos import create_manual_group_ids_dict
+
+ series_ids = ["A", "B", "C"]
+ groups = [0, 0] # Mismatched length
+
+ with pytest.raises(ValueError, match="Length mismatch"):
+ create_manual_group_ids_dict(series_ids, groups)
+
+
+def test_create_group_ids_from_category():
+ """Test create_group_ids_from_category helper function for list output."""
+ from chronos import create_group_ids_from_category
+
+ df = pd.DataFrame(
+ {
+ "item_id": ["A", "A", "B", "B", "C", "C"],
+ "region": ["North", "North", "North", "North", "South", "South"],
+ "value": [1, 2, 3, 4, 5, 6],
+ }
+ )
+
+ result = create_group_ids_from_category(df, "item_id", "region")
+
+ assert isinstance(result, list)
+ assert result == [0, 0, 1] # A and B are North (0), C is South (1)