diff --git a/notebooks/chronos-2-quickstart.ipynb b/notebooks/chronos-2-quickstart.ipynb index d2489c17..89baed5a 100644 --- a/notebooks/chronos-2-quickstart.ipynb +++ b/notebooks/chronos-2-quickstart.ipynb @@ -1730,12 +1730,268 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "7c899976", "metadata": {}, - "outputs": [], - "source": [] + "cell_type": "markdown", + "source": [ + "## Custom Group IDs: Examples and Use Cases\n", + "\n", + "Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n", + "- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n", + "- cross_learning=True: all series in the batch are jointly predicted and share information.\n", + "- Custom group_ids: only series within the same group share information; groups remain independent.\n", + "\n", + "Use custom group IDs when you know meaningful clusters (e.g., geography, sector, etc.). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n" + ], + "id": "e504694b8cfab326" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-12-09T13:39:45.846727Z", + "start_time": "2025-12-09T13:39:45.827756Z" + } + }, + "cell_type": "code", + "source": [ + "# Simulate ~30 weather stations with regional clustering (North, South, Coastal)\n", + "import numpy as np, pandas as pd\n", + "np.random.seed(123)\n", + "n = 200\n", + "prediction_length = 24\n", + "ts = pd.date_range('2020-01-01', periods=n, freq='D')\n", + "\n", + "north_ids = [f'station_north_{i+1}' for i in range(10)]\n", + "south_ids = [f'station_south_{i+1}' for i in range(10)]\n", + "coast_ids = [f'station_coast_{i+1}' for i in range(10)]\n", + "all_ids = north_ids + south_ids + coast_ids\n", + "regions = (['North']*len(north_ids) + ['South']*len(south_ids) + ['Coastal']*len(coast_ids))\n", + "\n", + "def synth_series(base, amp, noise, phase=0.0):\n", + " t = np.arange(n)\n", + " signal = base + amp*np.sin(2*np.pi*t/365.0 + phase)\n", + " return (signal + noise*np.random.randn(n)).astype('float32')\n", + "\n", + "data_frames = []\n", + "for sid, region in zip(all_ids, regions):\n", + " if region == 'North':\n", + " y = synth_series(base=5, amp=6, noise=0.8, phase=0.3)\n", + " elif region == 'South':\n", + " y = synth_series(base=18, amp=8, noise=0.8, phase=0.9)\n", + " else: # Coastal\n", + " y = synth_series(base=12, amp=3.5, noise=0.3, phase=0.5)\n", + " df_i = pd.DataFrame({'item_id': sid, 'timestamp': ts, 'target': y, 'region': region})\n", + " data_frames.append(df_i)\n", + "\n", + "weather_df = pd.concat(data_frames, ignore_index=True)\n", + "weather_df.head()" + ], + "id": "326d90e1a05586e8", + "outputs": [ + { + "data": { + "text/plain": [ + " item_id timestamp target region\n", + "0 station_north_1 2020-01-01 5.904617 North\n", + "1 station_north_1 2020-01-02 7.669402 North\n", + "2 station_north_1 2020-01-03 7.195759 North\n", + "3 station_north_1 2020-01-04 5.861607 North\n", + "4 station_north_1 2020-01-05 6.700416 North" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idtimestamptargetregion
0station_north_12020-01-015.904617North
1station_north_12020-01-027.669402North
2station_north_12020-01-037.195759North
3station_north_12020-01-045.861607North
4station_north_12020-01-056.700416North
\n", + "
" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-12-09T13:39:49.624289Z", + "start_time": "2025-12-09T13:39:49.251047Z" + } + }, + "cell_type": "code", + "source": [ + "# 1) Using predict_df with create_group_ids_dict_from_category (group by 'region')\n", + "from chronos import (\n", + " create_group_ids_dict_from_category,\n", + " create_group_ids_from_category\n", + ")\n", + "\n", + "group_ids_cat = create_group_ids_dict_from_category(\n", + " df=weather_df,\n", + " id_column='item_id',\n", + " category_column='region'\n", + ")\n", + "\n", + "# Split into context (train) and future truth (test)\n", + "test_df = weather_df.groupby('item_id').tail(prediction_length)\n", + "future_df = test_df.drop(columns=['target', 'region']).copy()\n", + "train_df = weather_df.drop(test_df.index).drop(columns=[\"region\"])\n", + "\n", + "pred_df_ids = pipeline.predict_df(\n", + " df=train_df,\n", + " future_df=future_df,\n", + " id_column='item_id',\n", + " timestamp_column='timestamp',\n", + " target='target',\n", + " prediction_length=prediction_length,\n", + " group_ids=group_ids_cat,\n", + " quantile_levels=[0.5],\n", + ")\n", + "\n", + "# Compute MSE/MAE against ground truth\n", + "eval_df = pred_df_ids.merge(\n", + " test_df[['item_id', 'timestamp', 'target']],\n", + " on=['item_id', 'timestamp'], how='inner',\n", + ")\n", + "y_true = eval_df['target'].to_numpy()\n", + "y_pred = eval_df['predictions'].to_numpy()\n", + "mse_ids = float(np.mean((y_pred - y_true) ** 2))\n", + "mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n", + "print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')" + ], + "id": "9113a0f64f257772", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE=0.523, MAE=0.537\n" + ] + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-12-09T13:39:53.070369Z", + "start_time": "2025-12-09T13:39:52.884440Z" + } + }, + "cell_type": "code", + "source": [ + "# 2) Low-level API: predict_quantiles with group_ids list\n", + "\n", + "# Create group IDs as a list aligned with series_ids\n", + "group_ids = create_group_ids_from_category(\n", + "\t train=weather_df,\n", + "\t id_column=\"item_id\",\n", + "\t category_column=\"region\"\n", + "\t)\n", + "\n", + "# Build inputs in the same order as series_ids\n", + "series_ids = train_df['item_id'].unique().tolist()\n", + "inputs_list = [\n", + " {'target': train_df.loc[train_df['item_id']==sid, 'target'].to_numpy(dtype=np.float32) }\n", + " for sid in series_ids\n", + "]\n", + "\n", + "_, mean = pipeline.predict_quantiles(\n", + " inputs=inputs_list,\n", + " prediction_length=prediction_length,\n", + " quantile_levels=[0.5],\n", + " group_ids=group_ids,\n", + ")\n", + "\n", + "pred_df_ids = []\n", + "for id in series_ids:\n", + " test_id_df = test_df[test_df['item_id']==id].copy()\n", + " test_id_df['predictions'] = mean[series_ids.index(id)].numpy().flatten()\n", + " test_id_df.drop(columns=['target'], inplace=True)\n", + " pred_df_ids.append(test_id_df)\n", + "pred_df_ids = pd.concat(pred_df_ids, ignore_index=True)\n", + "\n", + "# Compute MSE/MAE against ground truth\n", + "eval_df = pred_df_ids.merge(\n", + " test_df[['item_id', 'timestamp', 'target']],\n", + " on=['item_id', 'timestamp'], how='inner',\n", + ")\n", + "y_true = eval_df['target'].to_numpy()\n", + "y_pred = eval_df['predictions'].to_numpy()\n", + "mse_ids = float(np.mean((y_pred - y_true) ** 2))\n", + "mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n", + "print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')" + ], + "id": "629bab4c4ea6fe53", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE=0.523, MAE=0.537\n" + ] + } + ], + "execution_count": 5 } ], "metadata": { diff --git a/src/chronos/__init__.py b/src/chronos/__init__.py index 75d0bfbd..68019988 100644 --- a/src/chronos/__init__.py +++ b/src/chronos/__init__.py @@ -12,6 +12,12 @@ ) from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline +from .utils import ( + create_group_ids_dict_from_category, + create_group_ids_dict_from_mapping, + create_manual_group_ids_dict, + create_group_ids_from_category +) __all__ = [ "__version__", @@ -27,4 +33,7 @@ "Chronos2ForecastingConfig", "Chronos2Model", "Chronos2Pipeline", + "create_group_ids_dict_from_category", + "create_group_ids_dict_from_mapping", + "create_manual_group_ids_dict", ] diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 3eddcd7f..67b8a0be 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -448,6 +448,7 @@ def predict( batch_size: int = 256, context_length: int | None = None, cross_learning: bool = False, + group_ids: list[int] | torch.Tensor | None = None, limit_prediction_length: bool = False, **kwargs, ) -> list[torch.Tensor]: @@ -540,6 +541,12 @@ def predict( - Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining. For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report). - Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch. + group_ids + Optional custom group IDs to control information sharing between time series. + If provided, must be a list or tensor of integers with length equal to the number of tasks in `inputs`. + Tasks with the same group ID will share information during prediction via cross-attention. + Cannot be used together with `cross_learning=True`. By default None (each task gets unique group ID). + Example: [0, 0, 1] means first two tasks share information, third is separate. limit_prediction_length If True, an error is raised when prediction_length is greater than model's default prediction length, by default False @@ -561,6 +568,40 @@ def predict( stacklevel=2, ) cross_learning = kwargs.pop("predict_batches_jointly") + + # Validate group_ids and cross_learning interaction + if group_ids is not None and cross_learning: + raise ValueError( + "Cannot specify both `group_ids` and `cross_learning=True`. " + "Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning." + ) + + # Convert group_ids to tensor if provided + custom_group_ids_tensor = None + if group_ids is not None: + if isinstance(group_ids, list): + # Strict type check: only integers are allowed in the list + if not all(isinstance(x, (int, np.integer)) for x in group_ids): + raise TypeError("`group_ids` list must contain only integers") + if any(x < 0 for x in group_ids): + raise ValueError("`group_ids` must contain only non-negative integers") + custom_group_ids_tensor = torch.tensor(group_ids, dtype=torch.long) + elif isinstance(group_ids, torch.Tensor): + # Enforce integer dtype for tensor inputs + if torch.is_floating_point(group_ids) or group_ids.dtype == torch.bool: + raise TypeError("`group_ids` tensor must have an integer dtype") + if (group_ids < 0).any(): + raise ValueError("`group_ids` must contain only non-negative integers") + custom_group_ids_tensor = group_ids.to(dtype=torch.long).clone() + else: + raise TypeError(f"`group_ids` must be a list or torch.Tensor, got {type(group_ids)}") + + # Validate length matches number of tasks + if len(custom_group_ids_tensor) != len(inputs): + raise ValueError( + f"`group_ids` length ({len(custom_group_ids_tensor)}) must match number of tasks in inputs ({len(inputs)})" + ) + # The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger # than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value max_output_patches = kwargs.pop("max_output_patches", self.max_output_patches) @@ -613,6 +654,9 @@ def predict( ) all_predictions: list[torch.Tensor] = [] + # Track the current task index for custom group ID mapping + current_task_idx = 0 + for batch in test_loader: assert batch["future_target"] is None batch_context = batch["context"] @@ -620,7 +664,38 @@ def predict( batch_future_covariates = batch["future_covariates"] batch_target_idx_ranges = batch["target_idx_ranges"] - if cross_learning: + # Apply custom group IDs if provided + if custom_group_ids_tensor is not None: + # Determine how many tasks are in this batch + num_tasks_in_batch = len(batch_target_idx_ranges) + + # The key insight: batch_group_ids already maps variates to tasks + # We just need to replace the task IDs with our custom ones + + # Create a mapping from old task IDs to new task IDs + old_group_ids = batch_group_ids.cpu().numpy() + # Preserve first-appearance order of group IDs (robust to non-consecutive IDs) + seen = set() + unique_old_ids_list: list[int] = [] + for gid in old_group_ids.tolist(): + if gid not in seen: + seen.add(int(gid)) + unique_old_ids_list.append(int(gid)) + + # Map old group IDs to task indices (0, 1, 2, ..., num_tasks_in_batch-1) + # Then map those to custom group IDs + new_group_ids = old_group_ids.copy() + + for task_offset, old_group_id in enumerate(unique_old_ids_list): + task_idx = current_task_idx + task_offset + custom_group_id = custom_group_ids_tensor[task_idx].item() + # Replace all occurrences of old_group_id with custom_group_id + new_group_ids[old_group_ids == old_group_id] = custom_group_id + + batch_group_ids = torch.tensor(new_group_ids, dtype=torch.long) + current_task_idx += num_tasks_in_batch + + elif cross_learning: batch_group_ids = torch.zeros_like(batch_group_ids) batch_prediction = self._predict_batch( @@ -813,6 +888,7 @@ def predict_df( batch_size: int = 256, context_length: int | None = None, cross_learning: bool = False, + group_ids: dict[str, int] | None = None, validate_inputs: bool = True, **predict_kwargs, ) -> "pd.DataFrame": @@ -852,6 +928,12 @@ def predict_df( - Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining. For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report). - Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch. + group_ids + Optional dictionary mapping series IDs (from id_column) to group IDs. + Series with the same group ID will share information during prediction. + Cannot be used together with `cross_learning=True`. By default None. + Example: {'series_A': 0, 'series_B': 0, 'series_C': 1} means series_A and series_B share info, series_C is separate. + If a series ID is not in the dictionary, it will be assigned a unique group ID. validate_inputs When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a regular frequency, and item IDs match between past and future data. Setting to False disables these checks. @@ -888,6 +970,48 @@ def predict_df( validate_inputs=validate_inputs, ) + # Convert dictionary group_ids to list format matching inputs order + group_ids_list = None + if group_ids is not None: + # Validate group_ids format + if not isinstance(group_ids, dict): + raise TypeError(f"`group_ids` must be a dictionary, got {type(group_ids)}") + + if not all(isinstance(k, str) and isinstance(v, int) for k, v in group_ids.items()): + raise TypeError("`group_ids` dictionary must have string keys and integer values") + + if any(v < 0 for v in group_ids.values()): + raise ValueError("`group_ids` values must be non-negative integers") + + if cross_learning: + raise ValueError( + "Cannot specify both `group_ids` and `cross_learning=True`. " + "Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning." + ) + + # Warn if series IDs in group_ids don't exist in dataframe + series_in_df = set(df[id_column].unique()) + unknown_series = set(group_ids.keys()) - series_in_df + if unknown_series: + warnings.warn( + f"The following series IDs in `group_ids` were not found in the dataframe: {unknown_series}. " + f"They will be ignored.", + category=UserWarning, + ) + + # Create list of group IDs matching the order of inputs + # original_order contains the series IDs in the order they appear in inputs + group_ids_list = [] + next_auto_group_id = max(group_ids.values()) + 1 if group_ids else 0 + + for series_id in original_order: + if series_id in group_ids: + group_ids_list.append(group_ids[series_id]) + else: + # Assign unique group ID to series not in the mapping + group_ids_list.append(next_auto_group_id) + next_auto_group_id += 1 + # Generate forecasts quantiles, mean = self.predict_quantiles( inputs=inputs, @@ -897,6 +1021,7 @@ def predict_df( batch_size=batch_size, context_length=context_length, cross_learning=cross_learning, + group_ids=group_ids_list, **predict_kwargs, ) # since predict_df tasks are homogenous by input design, we can safely stack the list of tensors into a single tensor diff --git a/src/chronos/utils.py b/src/chronos/utils.py index 1c318167..4e591ccf 100644 --- a/src/chronos/utils.py +++ b/src/chronos/utils.py @@ -7,6 +7,12 @@ import torch from einops import repeat +try: + import pandas as pd + _PANDAS_AVAILABLE = True +except ImportError: + _PANDAS_AVAILABLE = False + def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: max_len = max(len(c) for c in tensors) @@ -210,3 +216,228 @@ def weighted_quantile( # Reshape to original shape final_shape = (*orig_samples_shape[:-1], len(query_quantile_levels)) return interpolated_quantiles.reshape(final_shape).to(dtype=orig_dtype) + + +def create_group_ids_dict_from_category( + df: "pd.DataFrame", + id_column: str, + category_column: str +) -> dict[str, int]: + """ + Create group_ids dictionary (for predict_df) from a categorical column. + + This function is specifically designed for use with Chronos2Pipeline.predict_df(), + which accepts a dictionary mapping series IDs to group IDs. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing id_column and category_column. + id_column : str + Name of the column containing time series identifiers. + category_column : str + Name of the categorical column to group by (e.g., "region", "product_type", "industry"). + + Returns + ------- + dict[str, int] + Dictionary mapping series IDs to group IDs. + Series with the same category will have the same group ID. + + Examples + -------- + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... 'item_id': ['A', 'A', 'B', 'B', 'C', 'C'], + ... 'region': ['North', 'North', 'North', 'North', 'South', 'South'], + ... 'value': [1, 2, 3, 4, 5, 6] + ... }) + >>> create_group_ids_dict_from_category(df, 'item_id', 'region') + {'A': 0, 'B': 0, 'C': 1} # A and B are North (group 0), C is South (group 1) + + Notes + ----- + - Use this for Chronos2Pipeline.predict_df() which expects dict[str, int] + - The function automatically handles the order of series IDs + - Categories are mapped to consecutive integers starting from 0 + """ + if not _PANDAS_AVAILABLE: + raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.") + + # Get unique series and their category + series_categories = df.groupby(id_column, sort=False)[category_column].first() + + # Map categories to group IDs + unique_categories = series_categories.unique() + category_to_group = {cat: i for i, cat in enumerate(unique_categories)} + + # Create dictionary mapping series_id -> group_id + group_ids_dict = { + series_id: category_to_group[category] + for series_id, category in series_categories.items() + } + + return group_ids_dict + + +def create_group_ids_dict_from_mapping( + df: "pd.DataFrame", + id_column: str, + category_to_group_map: dict[str, int] +) -> dict[str, int]: + """ + Create group_ids dictionary using a custom category-to-group mapping. + + This allows you to manually specify which categories belong to which groups, + useful when you want custom grouping logic beyond simple one-to-one category mapping. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing id_column and a category column. + id_column : str + Name of the column containing time series identifiers. + category_to_group_map : dict[str, int] + Dictionary mapping category names to group IDs. + Example: {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2} + + Returns + ------- + dict[str, int] + Dictionary mapping series IDs to group IDs. + + Examples + -------- + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... 'item_id': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'], + ... 'industry': ['Retail', 'Retail', 'Wholesale', 'Wholesale', + ... 'Food', 'Food', 'Services', 'Services'], + ... 'value': [1, 2, 3, 4, 5, 6, 7, 8] + ... }) + >>> mapping = {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2} + >>> create_group_ids_dict_from_mapping(df, 'item_id', mapping) + {'A': 0, 'B': 0, 'C': 1, 'D': 2} + + Notes + ----- + - This is more flexible than create_group_ids_dict_from_category() + - Useful when multiple categories should map to the same group + - All categories in the data must be in the mapping, or will raise KeyError + """ + if not _PANDAS_AVAILABLE: + raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.") + + # Infer category column by checking which column contains the mapping keys + category_column = None + for col in df.columns: + if col != id_column and df[col].dtype == 'object': + if any(cat in category_to_group_map for cat in df[col].unique()): + category_column = col + break + + if category_column is None: + raise ValueError( + f"Could not infer category column. Available columns: {df.columns.tolist()}. " + f"Make sure one of them contains categories from the mapping: {list(category_to_group_map.keys())}" + ) + + # Get unique series and their category + series_categories = df.groupby(id_column, sort=False)[category_column].first() + + # Create dictionary mapping series_id -> group_id using the custom mapping + group_ids_dict = {} + for series_id, category in series_categories.items(): + if category not in category_to_group_map: + raise KeyError( + f"Category '{category}' for series '{series_id}' not found in mapping. " + f"Available categories in mapping: {list(category_to_group_map.keys())}" + ) + group_ids_dict[series_id] = category_to_group_map[category] + + return group_ids_dict + + +def create_manual_group_ids_dict( + series_ids: List[str], + group_assignments: List[int] +) -> dict[str, int]: + """ + Create group_ids dictionary from manual series ID and group assignment lists. + + Parameters + ---------- + series_ids : list of str + List of series identifiers. + group_assignments : list of int + List of group IDs corresponding to each series. + Must have the same length as series_ids. + + Returns + ------- + dict[str, int] + Dictionary mapping series IDs to group IDs. + + Examples + -------- + >>> series_ids = ['store_1', 'store_2', 'store_3', 'store_4'] + >>> groups = [0, 0, 1, 1] # First two together, last two together + >>> create_manual_group_ids_dict(series_ids, groups) + {'store_1': 0, 'store_2': 0, 'store_3': 1, 'store_4': 1} + + Raises + ------ + ValueError + If series_ids and group_assignments have different lengths. + """ + if len(series_ids) != len(group_assignments): + raise ValueError( + f"Length mismatch: series_ids has {len(series_ids)} elements, " + f"but group_assignments has {len(group_assignments)} elements." + ) + + return dict(zip(series_ids, group_assignments)) + +def create_group_ids_from_category( + train: pd.DataFrame, + id_column: str, + category_column: str +) -> list[int]: + """ + Create group_ids list from a categorical column. + + Parameters + ---------- + train : pd.DataFrame + Training data containing id_column and category_column. + id_column : str + Name of the column containing time series identifiers. + category_column : str + Name of the categorical column to group by (e.g., "region", "product_type"). + + Returns + ------- + list[int] + List of group IDs, one per series, matching the order of series_ids. + + Examples + -------- + >>> train = pd.DataFrame({ + ... 'series_id': ['A', 'A', 'B', 'B', 'C', 'C'], + ... 'region': ['North', 'North', 'North', 'North', 'South', 'South'], + ... 'value': [1, 2, 3, 4, 5, 6] + ... }) + >>> create_group_ids_from_category(train, 'series_id', 'region') + [0, 0, 1] # A and B are North (group 0), C is South (group 1) + """ + # Get unique series and their category + series_categories = train.groupby(id_column, sort=False)[category_column].first() + + # Map categories to group IDs + unique_categories = series_categories.unique() + category_to_group = {cat: i for i, cat in enumerate(unique_categories)} + + # Create group_ids list + group_ids = [category_to_group[cat] for cat in series_categories.values] + + return group_ids diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 3fac7261..dca9f0a2 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1132,3 +1132,283 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline): for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): # Should match exactly or very close (numerical precision) assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) + + +# ============================================================================ +# Tests for custom group_ids functionality +# ============================================================================ + + +@pytest.mark.parametrize("group_ids", [[0, 0, 1], torch.tensor([0, 0, 1])]) +def test_predict_with_custom_group_ids_list_and_tensor(pipeline, group_ids): + """Test basic functionality with custom group_ids as both list and tensor.""" + inputs = [torch.rand(100), torch.rand(110), torch.rand(120)] + outputs = pipeline.predict(inputs, prediction_length=24, group_ids=group_ids) + + assert isinstance(outputs, list) and len(outputs) == 3 + for out in outputs: + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 24), dtype=torch.float32) + + +def test_predict_with_group_ids_univariate_batch(pipeline): + """Test group_ids with homogeneous univariate batch.""" + inputs = torch.rand(5, 1, 100) + group_ids = [0, 0, 1, 1, 2] # First two together, next two together, last one alone + + outputs = pipeline.predict(inputs, prediction_length=12, group_ids=group_ids) + + assert len(outputs) == 5 + for out in outputs: + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 12), dtype=torch.float32) + + +def test_predict_with_group_ids_multivariate(pipeline): + """Test group_ids with multivariate inputs.""" + inputs = [torch.rand(2, 100), torch.rand(2, 110), torch.rand(2, 90)] + group_ids = [0, 0, 1] # First two share info, third is separate + + outputs = pipeline.predict(inputs, prediction_length=16, group_ids=group_ids) + + assert len(outputs) == 3 + for out in outputs: + validate_tensor(out, (2, DEFAULT_MODEL_NUM_QUANTILES, 16), dtype=torch.float32) + + +def test_predict_with_group_ids_and_covariates(pipeline): + """Test group_ids with covariates.""" + prediction_length = 24 + inputs = [ + { + "target": torch.rand(100), + "past_covariates": {"temperature": torch.rand(100)}, + "future_covariates": {"temperature": torch.rand(prediction_length)}, + }, + { + "target": torch.rand(110), + "past_covariates": {"temperature": torch.rand(110)}, + "future_covariates": {"temperature": torch.rand(prediction_length)}, + }, + { + "target": torch.rand(90), + "past_covariates": {"temperature": torch.rand(90)}, + "future_covariates": {"temperature": torch.rand(prediction_length)}, + }, + ] + group_ids = [0, 0, 1] + + outputs = pipeline.predict(inputs, prediction_length=prediction_length, group_ids=group_ids) + + assert len(outputs) == 3 + for out in outputs: + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length), dtype=torch.float32) + + +def test_predict_df_with_group_ids_dict(pipeline): + """Test predict_df with dictionary group_ids.""" + df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10]) + group_ids = {"A": 0, "B": 0, "C": 1} # A and B share info, C is separate + + pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids) + + assert isinstance(pred_df, pd.DataFrame) + assert len(pred_df) == 15 # 3 series * 5 predictions + assert set(pred_df["item_id"].unique()) == {"A", "B", "C"} + + +def test_predict_df_with_partial_group_ids(pipeline): + """Test predict_df when only some series have group_ids assigned.""" + df = create_df(series_ids=["A", "B", "C", "D"], n_points=[10, 10, 10, 10]) + group_ids = {"A": 0, "B": 0} # Only A and B specified, C and D should get unique IDs + + pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids) + + assert isinstance(pred_df, pd.DataFrame) + assert len(pred_df) == 20 # 4 series * 5 predictions + assert set(pred_df["item_id"].unique()) == {"A", "B", "C", "D"} + + +def test_predict_df_with_group_ids_and_covariates(pipeline): + """Test predict_df with both group_ids and covariates.""" + df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10], covariates=["temp"]) + future_df = create_future_df( + get_forecast_start_times(df), series_ids=["A", "B", "C"], n_points=[5, 5, 5], covariates=["temp"] + ) + group_ids = {"A": 0, "B": 0, "C": 1} + + pred_df = pipeline.predict_df(df, future_df=future_df, prediction_length=5, group_ids=group_ids) + + assert isinstance(pred_df, pd.DataFrame) + assert len(pred_df) == 15 + + +def test_group_ids_cross_learning_mutual_exclusion(pipeline): + """Test that error is raised when both group_ids and cross_learning are specified.""" + inputs = [torch.rand(100), torch.rand(110), torch.rand(120)] + group_ids = [0, 0, 1] + + with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"): + pipeline.predict(inputs, prediction_length=24, group_ids=group_ids, cross_learning=True) + + +def test_predict_df_group_ids_cross_learning_mutual_exclusion(pipeline): + """Test that predict_df raises error when both group_ids and cross_learning are specified.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 10]) + group_ids = {"A": 0, "B": 0} + + with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"): + pipeline.predict_df(df, prediction_length=5, group_ids=group_ids, cross_learning=True) + + +def test_group_ids_length_mismatch_raises_error(pipeline): + """Test that error is raised when group_ids length doesn't match inputs.""" + inputs = [torch.rand(100), torch.rand(110), torch.rand(120)] + group_ids = [0, 0] # Only 2 IDs for 3 inputs + + with pytest.raises(ValueError, match="length .* must match number of tasks"): + pipeline.predict(inputs, prediction_length=24, group_ids=group_ids) + + +def test_group_ids_negative_values_raises_error(pipeline): + """Test that error is raised when group_ids contain negative values.""" + inputs = [torch.rand(100), torch.rand(110), torch.rand(120)] + group_ids = [0, -1, 1] # Negative ID not allowed + + with pytest.raises(ValueError, match="must contain only non-negative integers"): + pipeline.predict(inputs, prediction_length=24, group_ids=group_ids) + + +def test_group_ids_invalid_type_raises_error(pipeline): + """Test that error is raised when group_ids is not list or tensor.""" + inputs = [torch.rand(100), torch.rand(110)] + group_ids = "invalid" # String not allowed + + with pytest.raises(TypeError, match="must be a list or torch.Tensor"): + pipeline.predict(inputs, prediction_length=24, group_ids=group_ids) + + +def test_predict_df_group_ids_invalid_type_raises_error(pipeline): + """Test that predict_df raises error when group_ids is not a dict.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 10]) + group_ids = [0, 0] # List not allowed for predict_df (needs dict) + + with pytest.raises(TypeError, match="must be a dictionary"): + pipeline.predict_df(df, prediction_length=5, group_ids=group_ids) + + +def test_predict_df_group_ids_invalid_dict_values_raises_error(pipeline): + """Test that predict_df raises error when group_ids dict has negative values.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 10]) + group_ids = {"A": 0, "B": -1} # Negative value not allowed + + with pytest.raises(ValueError, match="must be non-negative integers"): + pipeline.predict_df(df, prediction_length=5, group_ids=group_ids) + + +def test_predict_df_group_ids_warns_unknown_series(pipeline): + """Test that predict_df warns when group_ids contains unknown series IDs.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 10]) + group_ids = {"A": 0, "B": 0, "X": 1, "Y": 1} # X and Y don't exist + + with pytest.warns(UserWarning, match="not found in the dataframe"): + pipeline.predict_df(df, prediction_length=5, group_ids=group_ids) + + +# ============================================================================ +# Tests for group_ids helper functions +# ============================================================================ + + +def test_create_group_ids_dict_from_category(): + """Test create_group_ids_dict_from_category helper function.""" + from chronos import create_group_ids_dict_from_category + + df = pd.DataFrame( + { + "item_id": ["A", "A", "B", "B", "C", "C"], + "region": ["North", "North", "North", "North", "South", "South"], + "value": [1, 2, 3, 4, 5, 6], + } + ) + + result = create_group_ids_dict_from_category(df, "item_id", "region") + + assert isinstance(result, dict) + assert result == {"A": 0, "B": 0, "C": 1} + + +def test_create_group_ids_dict_from_mapping(): + """Test create_group_ids_dict_from_mapping helper function.""" + from chronos import create_group_ids_dict_from_mapping + + df = pd.DataFrame( + { + "item_id": ["A", "A", "B", "B", "C", "C", "D", "D"], + "industry": ["Retail", "Retail", "Wholesale", "Wholesale", "Food", "Food", "Services", "Services"], + "value": [1, 2, 3, 4, 5, 6, 7, 8], + } + ) + mapping = {"Retail": 0, "Wholesale": 0, "Food": 1, "Services": 2} + + result = create_group_ids_dict_from_mapping(df, "item_id", mapping) + + assert isinstance(result, dict) + assert result == {"A": 0, "B": 0, "C": 1, "D": 2} + + +def test_create_group_ids_dict_from_mapping_missing_category_raises_error(): + """Test that create_group_ids_dict_from_mapping raises error for unmapped categories.""" + from chronos import create_group_ids_dict_from_mapping + + df = pd.DataFrame( + { + "item_id": ["A", "A", "B", "B"], + "industry": ["Retail", "Retail", "Tech", "Tech"], + "value": [1, 2, 3, 4], + } + ) + mapping = {"Retail": 0} # Missing "Tech" + + with pytest.raises(KeyError, match="not found in mapping"): + create_group_ids_dict_from_mapping(df, "item_id", mapping) + + +def test_create_manual_group_ids_dict(): + """Test create_manual_group_ids_dict helper function.""" + from chronos import create_manual_group_ids_dict + + series_ids = ["store_1", "store_2", "store_3", "store_4"] + groups = [0, 0, 1, 1] + + result = create_manual_group_ids_dict(series_ids, groups) + + assert isinstance(result, dict) + assert result == {"store_1": 0, "store_2": 0, "store_3": 1, "store_4": 1} + + +def test_create_manual_group_ids_dict_length_mismatch_raises_error(): + """Test that create_manual_group_ids_dict raises error on length mismatch.""" + from chronos import create_manual_group_ids_dict + + series_ids = ["A", "B", "C"] + groups = [0, 0] # Mismatched length + + with pytest.raises(ValueError, match="Length mismatch"): + create_manual_group_ids_dict(series_ids, groups) + + +def test_create_group_ids_from_category(): + """Test create_group_ids_from_category helper function for list output.""" + from chronos import create_group_ids_from_category + + df = pd.DataFrame( + { + "item_id": ["A", "A", "B", "B", "C", "C"], + "region": ["North", "North", "North", "North", "South", "South"], + "value": [1, 2, 3, 4, 5, 6], + } + ) + + result = create_group_ids_from_category(df, "item_id", "region") + + assert isinstance(result, list) + assert result == [0, 0, 1] # A and B are North (0), C is South (1)