Skip to content

chronos-2: Address Batch Size Scalability Issue in Group Attention for Faster Training/Inference#442

Open
li-jinpeng wants to merge 2 commits intoamazon-science:mainfrom
li-jinpeng:main
Open

chronos-2: Address Batch Size Scalability Issue in Group Attention for Faster Training/Inference#442
li-jinpeng wants to merge 2 commits intoamazon-science:mainfrom
li-jinpeng:main

Conversation

@li-jinpeng
Copy link

Background

Chronos-2 introduces group attention, enabling the foundational time series model to handle multivariate modeling. However, as indicated in the code at here, the computational complexity of group attention is proportional to the square of the batch size. This means that as the batch size increases, the training and inference efficiency of the model can significantly decrease.

Optimization

To address this, this PR optimizes the computation of group attention. Consider an input with group_ids = [0, 0, 1, 1, 1, 1, 2, 2, 2, 3]. The original group attention matrix would be structured as shown in the following diagram, where a substantial amount of computation is useless.
Clipboard_Screenshot_1766415155
The optimization strategy involves decomposing this large matrix into four smaller group attention matrices, as illustrated in the second diagram. This approach aims to minimize useless computations.
Clipboard_Screenshot_1766415248

Usage

This optimization can be enabled by setting the environment variable CHRONOS2_USE_FAST_GROUP_ATTENTION=1.

Experimentation & Results

Furthermore, experiments were conducted on a single NVIDIA H20 GPU using the following test code:

import os
import time
import torch
import chronos.chronos2
from transformers import AutoConfig
from chronos.chronos2 import Chronos2Model


def load_model(model_path, device):
    """
    Load Chronos2 model from pretrained weights.
    
    Args:
        model_path (str): Path to the pretrained model
        device (str): Device to load model on ('cuda' or 'cpu')
    
    Returns:
        Chronos2Model: Loaded model instance
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    assert hasattr(config, "chronos_config"), "Not a Chronos config file"
    
    # Dynamically get model class based on architecture
    architecture = config.architectures[0]
    model_class = getattr(chronos.chronos2, architecture, Chronos2Model)
    
    if model_class is None:
        print(f"Unknown architecture: {architecture}, defaulting to Chronos2Model")
        model_class = Chronos2Model
    
    model = model_class.from_pretrained(model_path, trust_remote_code=True)
    model.to(device)
    model.eval()
    
    return model, config


def calculate_num_output_patches(forecast_horizon, config):
    """
    Calculate the number of output patches based on forecast horizon and patch size.
    
    Args:
        forecast_horizon (int): Number of time steps to forecast
        config: Model configuration object
    
    Returns:
        int: Number of output patches
    """
    patch_size = config.chronos_config["output_patch_size"]
    
    if forecast_horizon % patch_size == 0:
        return forecast_horizon // patch_size
    else:
        return forecast_horizon // patch_size + 1


def run_warmup_and_validation(model, config, device, seq_len, num_features, forecast_horizon):
    """
    Run warmup inference and validate consistency between baseline and fast implementations.
    
    Args:
        model: The Chronos2 model
        config: Model configuration
        device: Device to run on
        seq_len (int): Input sequence length
        num_features (int): Number of input features
        forecast_horizon (int): Forecast horizon
    """
    num_output_patches = calculate_num_output_patches(forecast_horizon, config)
    batch_size = 50
    
    # Create test data
    input_data = torch.randn(batch_size, seq_len, num_features).to(device)
    group_ids = torch.arange(batch_size).repeat_interleave(num_features).to(device)
    
    # Reshape input for Chronos2 (batch_size * num_features, seq_len)
    reshaped_input = input_data.permute(0, 2, 1).reshape(batch_size * num_features, seq_len)
    
    with torch.no_grad():
        # Test baseline implementation
        os.environ["CHRONOS2_USE_FAST_GROUP_ATTENTION"] = "0"
        baseline_output = model(
            context=reshaped_input,
            group_ids=group_ids,
            num_output_patches=num_output_patches,
            future_target=None,
        )
        
        # Test fast implementation
        os.environ["CHRONOS2_USE_FAST_GROUP_ATTENTION"] = "1"
        fast_output = model(
            context=reshaped_input,
            group_ids=group_ids,
            num_output_patches=num_output_patches,
            future_target=None,
        )
        
        # Validate numerical equivalence
        torch.testing.assert_close(baseline_output.quantile_preds, fast_output.quantile_preds)
        print("✓ Baseline and fast implementations produce identical results")


def benchmark_inference(model, config, device, seq_len, num_features, forecast_horizon, 
                       batch_sizes, num_inferences=5, use_fast_attention=False):
    """
    Benchmark inference speed for different batch sizes.
    
    Args:
        model: The Chronos2 model
        config: Model configuration
        device: Device to run on
        seq_len (int): Input sequence length
        num_features (int): Number of input features
        forecast_horizon (int): Forecast horizon
        batch_sizes (list): List of batch sizes to test
        num_inferences (int): Number of inference runs per batch size
        use_fast_attention (bool): Whether to use fast group attention
    
    Returns:
        dict: Inference times per sample for each batch size
    """
    # Set environment variable for fast attention
    os.environ["CHRONOS2_USE_FAST_GROUP_ATTENTION"] = "1" if use_fast_attention else "0"
    mode = "Fast" if use_fast_attention else "Baseline"
    
    print(f"\n=== {mode} Implementation Benchmark ===")
    
    num_output_patches = calculate_num_output_patches(forecast_horizon, config)
    results = {}
    
    for batch_size in batch_sizes:
        total_time = 0
        
        for i in range(num_inferences):
            # Prepare input data
            input_data = torch.randn(batch_size, seq_len, num_features).to(device)
            group_ids = torch.arange(batch_size).repeat_interleave(num_features).to(device)
            reshaped_input = input_data.permute(0, 2, 1).reshape(batch_size * num_features, seq_len)
            
            # Time inference
            torch.cuda.synchronize() if device == "cuda" else None
            start_time = time.time()
            
            with torch.no_grad():
                _ = model(
                    context=reshaped_input,
                    group_ids=group_ids,
                    num_output_patches=num_output_patches,
                    future_target=None,
                )
            
            torch.cuda.synchronize() if device == "cuda" else None
            end_time = time.time()
            total_time += end_time - start_time
        
        # Calculate average time per sample
        avg_time_per_sample = total_time / (num_inferences * batch_size)
        results[batch_size] = avg_time_per_sample
        print(f"Batch size {batch_size:3d}: {avg_time_per_sample:.4f} s/sample")
    
    return results


def main():
    """Main function to run the inference speed test."""
    # Configuration
    PRETRAINED_MODEL_PATH = (
        "/path/to/chronos-2"
    )
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Inference parameters
    SEQ_LENGTH = 1440
    NUM_FEATURES = 20
    FORECAST_HORIZON = 96
    BATCH_SIZES = [20, 40, 60, 80, 100]
    NUM_INFERENCES = 3
    
    print("Chronos2 Inference Speed Test")
    print("=" * 40)
    print(f"Device: {DEVICE}")
    print(f"Sequence length: {SEQ_LENGTH}")
    print(f"Number of features: {NUM_FEATURES}")
    print(f"Forecast horizon: {FORECAST_HORIZON}")
    
    try:
        # Load model
        model, config = load_model(PRETRAINED_MODEL_PATH, DEVICE)
        print(f"✓ Model loaded successfully from {PRETRAINED_MODEL_PATH}")
        
        # Warmup and validation
        run_warmup_and_validation(model, config, DEVICE, SEQ_LENGTH, NUM_FEATURES, FORECAST_HORIZON)
        
        # Benchmark both implementations
        baseline_results = benchmark_inference(
            model, config, DEVICE, SEQ_LENGTH, NUM_FEATURES, FORECAST_HORIZON,
            BATCH_SIZES, NUM_INFERENCES, use_fast_attention=False
        )
        
        fast_results = benchmark_inference(
            model, config, DEVICE, SEQ_LENGTH, NUM_FEATURES, FORECAST_HORIZON,
            BATCH_SIZES, NUM_INFERENCES, use_fast_attention=True
        )
        
        # Print summary
        print("\n=== Benchmark Summary ===")
        print("Batch Size | Baseline (s/sample) | Fast (s/sample) | Speedup")
        print("-" * 60)
        for batch_size in BATCH_SIZES:
            baseline_time = baseline_results[batch_size]
            fast_time = fast_results[batch_size]
            speedup = baseline_time / fast_time if fast_time > 0 else 0
            print(f"{batch_size:10d} | {baseline_time:18.4f} | {fast_time:16.4f} | {speedup:7.2f}x")
            
    except Exception as e:
        print(f"Error during execution: {e}")
        raise


if __name__ == "__main__":
    main()

The final experimental results are summarized in the table below:

Chronos2 Inference Speed Test

Device: cuda
Sequence length: 1440
Number of features: 20
Forecast horizon: 96
✓ Model loaded successfully from /apdcephfs_fsgm/share_304079515/hunyuan_test/DualWeaver-1214/hf_ltm/chronos-2
✓ Baseline and fast implementations produce identical results

=== Baseline Implementation Benchmark ===
Batch size 20: 0.0228 s/sample
Batch size 40: 0.0303 s/sample
Batch size 60: 0.0358 s/sample
Batch size 80: 0.0415 s/sample
Batch size 100: 0.0471 s/sample

=== Fast Implementation Benchmark ===
Batch size 20: 0.0203 s/sample
Batch size 40: 0.0202 s/sample
Batch size 60: 0.0198 s/sample
Batch size 80: 0.0199 s/sample
Batch size 100: 0.0200 s/sample

=== Benchmark Summary ===

Batch Size Baseline (s/sample) Fast (s/sample) Speedup
20 0.0228 0.0203 1.12x
40 0.0303 0.0202 1.50x
60 0.0358 0.0198 1.80x
80 0.0415 0.0199 2.09x
100 0.0471 0.0200 2.35x

The results demonstrate a significant improvement in end-to-end inference speed, particularly for larger batch sizes.

Summary

This PR optimizes the computation of group attention to accelerate the end-to-end inference speed of Chronos-2 (training is similarly affected). The optimization effect is particularly significant when the batch size is large.

For the computation of group attention, certain flash_attn operators, such as flash_attn_varlen_func, are quite suitable and can further enhance its computational efficiency. However, it is noteworthy that flash_attn_varlen_func currently only supports bf16 and fp16 data types. I look forward to integrating more efficient operators into Chronos-2 in the future, enabling it to be applied more broadly.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

xymli added 2 commits December 22, 2025 21:03
Signed-off-by: xymli <xymli@tencent.com>
Signed-off-by: xymli <xymli@tencent.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant