-
Notifications
You must be signed in to change notification settings - Fork 536
Add custom group_ids support to Chronos2Pipeline #429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@StatMixedML Thanks for the PR. We are currently working towards AutoGluon 1.5, so I will take a careful look after the release. Before that, one small feedback I have after a quick skim is that a 900 line PR sounds a bit too much to enable this capability. Could you please check if the size of the PR can be reduced? |
| batch_future_covariates = batch["future_covariates"] | ||
| batch_target_idx_ranges = batch["target_idx_ranges"] | ||
|
|
||
| if cross_learning: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could the user instead just run something like the following code?
prediction_per_group = []
for _, group_df in df.groupby("group_id"):
prediction_per_group.append(pipeline.predict(group_df.drop(columns=["group_id"], cross_learning=True, ...))
predictions = pd.concat(prediction_per_group)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shchur I suppose your suggestion is an elegant way of doing it :-) It works the same way the PR suggests
Most of the additions come from the notebook examples and the unit tests, so the actual changes in |
|
@StatMixedML Thanks! In that case, maybe users can just go with the idea that @shchur suggested. That said, it would be cool to cover non-trivial grouping in the "advanced" section of the tutorials. Do you happen to have good (preferably not synthetic) examples of such grouping helping accuracy? |
Shall I then close the PR?
I can keep the PR open and add some examples to the notebook using publicly available data? Sth. like M5 or some monthly seasonal data with geographic grouping? |
Summary
This PR adds support for custom group IDs in Chronos2Pipeline, enabling fine-grained control over which time series share information during prediction through cross-attention. Users can now specify meaningful groupings (e.g., by geography, sector, ...) to improve forecast accuracy while preventing information leakage between unrelated series.
Motivation
Currently, users can either:
This PR adds a middle ground: selective information sharing where only series within the same group exchange information, while different groups remain independent.
Changes
Core API Changes
Added group_ids parameter to predict_df() and predict_quantiles()
Added helper functions in src/chronos/utils.py
Backward Compatibility
✅ Fully backward compatible
Documentation