Gengze Zhou1*, Chongjian Ge2, Hao Tan2, Feng Liu2, Yicong Hong2
1Australian Institute for Machine Learning, The University of Adelaide 2Adobe Research
Recent advances in autoregressive (AR) generative models have produced increasingly powerful systems for media synthesis. Among them, next-scale prediction has emerged as a popular paradigm, where models generate images in a coarse-to-fine manner. However, scale-wise AR models suffer from exposure bias, which undermines generation quality. We identify two primary causes of this issue:
- Train–test mismatch: The model relies on imperfect predictions during inference but ground truth during training.
- Imbalanced learning difficulty: Coarse scales must generate global structure from scratch, while fine scales only perform easier reconstruction.
To address this, we propose Self-Autoregressive Refinement (SAR). SAR introduces a Stagger-Scale Rollout (SSR) mechanism to expose the model to its own intermediate predictions and a Contrastive Student-Forcing Loss (CSFL) to ensure stable training. Experimental results show that applying SAR to pretrained AR models consistently improves generation quality with minimal computational overhead (e.g., 5.2% FID reduction on FlexVAR-d16 within 10 epochs).
SAR is a lightweight post-training algorithm that bridges the train-test gap. It consists of two key components:
SSR is a two-step rollout strategy that is computationally efficient (requiring only one extra forward pass):
- Step 1 (Teacher Forcing): The model performs teacher forcing and predicts at all scales using ground-truth conditioning.
- Step 2 (Student Forcing): These predictions are upsampled to form scale-shifted inputs, enabling a second forward pass that produces student-forced predictions.
Teacher-forcing loss provides ground-truth supervision, while the contrastive student-forcing loss aligns student-forced outputs with their teacher-forced counterparts.
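The two-step rollout above can be sketched with toy tensors. This is an illustrative sketch only: `f`, the `scales` list, and the shapes are assumptions standing in for the actual scale-wise transformer and tokenizer in this repo, but the conditioning pattern (scale k+1 predicted from an upsampled scale-k map, once from ground truth and once from the model's own step-1 output) follows the SSR description.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scales = [1, 2, 4, 8]          # toy token-map resolutions (illustrative)
f = torch.nn.Conv2d(3, 3, 1)   # stand-in for the scale-wise AR model

# Ground-truth multi-scale maps (in practice, from the VAE tokenizer)
gt = [torch.randn(1, 3, s, s) for s in scales]

# Step 1 (teacher forcing): predict scale k+1 from the ground-truth
# map at scale k, upsampled to the target resolution.
teacher_preds = [f(F.interpolate(gt[k], size=scales[k + 1]))
                 for k in range(len(scales) - 1)]

# Step 2 (student forcing): one extra forward pass in which scale k+1
# is instead conditioned on the model's own scale-k prediction.
student_preds = [f(F.interpolate(teacher_preds[k - 1], size=scales[k + 1]))
                 for k in range(1, len(scales) - 1)]
```

Note that both steps batch over all scales, which is why SSR costs only one extra forward pass rather than a full sequential rollout.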
Naive student forcing often causes the model to drift away from the ground truth. To fix this, CSFL:
- Aligns the student prediction with the stable teacher prediction instead of forcing the student prediction to match the ground truth (which causes conflicts).
- Teaches the model to remain consistent with the "expert" trajectory even when conditioned on imperfect inputs.
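A minimal sketch of this alignment idea, using toy logits. The exact loss form in the paper may differ; here the student distribution is pulled toward the detached teacher distribution via a KL term, which is one standard way to "align with the expert trajectory" rather than with the ground-truth tokens.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, n_tokens = 16, 8
# Assumed outputs of the two SSR forward passes at the same positions
teacher_logits = torch.randn(n_tokens, vocab)
student_logits = torch.randn(n_tokens, vocab, requires_grad=True)
gt_tokens = torch.randint(vocab, (n_tokens,))

# Naive student forcing: pull student outputs toward ground truth,
# which conflicts with the imperfect inputs they were conditioned on.
naive_loss = F.cross_entropy(student_logits, gt_tokens)

# CSFL-style alternative: align the student distribution with the
# *stable* teacher distribution (detached, so only the student moves).
teacher_probs = teacher_logits.detach().softmax(dim=-1)
csfl = F.kl_div(student_logits.log_softmax(dim=-1), teacher_probs,
                reduction='batchmean')
csfl.backward()
```

Because the teacher branch is detached, gradients flow only through the student-forced pass, keeping the teacher-forcing supervision untouched.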
```shell
# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create virtual environment and install dependencies
uv venv --python 3.10
source .venv/bin/activate
uv pip install -r requirements.txt

# (Optional) Install flash-attn and xformers for faster attention
uv pip install flash-attn xformers
```

Alternatively, with conda:

```shell
conda create -n sar python=3.10 -y
conda activate sar
pip install -r requirements.txt

# (Optional) Install flash-attn and xformers for faster attention
pip install flash-attn xformers
```

Note: Always activate your environment (`source .venv/bin/activate` for uv or `conda activate sar` for conda) before running training/evaluation scripts.
Download the ImageNet dataset. Assuming ImageNet is at `/path/to/imagenet`, it should look like this:

```
/path/to/imagenet/:
  train/:
    n01440764:
      many_images.JPEG ...
    n01443537:
      many_images.JPEG ...
  val/:
    n01440764:
      ILSVRC2012_val_00000293.JPEG ...
    n01443537:
      ILSVRC2012_val_00000236.JPEG ...
```

Note: the argument `--data_path=/path/to/imagenet` should be passed to the training script.
Download `FlexVAE.pth` first and place it at `pretrained/FlexVAE.pth`.
| Model | Params | FID ↓ | IS ↑ | Weights |
|---|---|---|---|---|
| FlexVAR-d16 | 310M | 3.05 | 291.3 | FlexVARd16-epo179.pth |
| FlexVAR-d16 + SAR | 310M | 2.89 | 266.6 | SARd16-epo179.pth |
| FlexVAR-d20 | 600M | 2.41 | 299.3 | FlexVARd20-epo249.pth |
| FlexVAR-d20 + SAR | 600M | 2.35 | 293.3 | SARd20-epo249.pth |
| FlexVAR-d24 | 1.0B | 2.21 | 299.1 | FlexVARd24-epo349.pth |
| FlexVAR-d24 + SAR | 1.0B | 2.14 | 315.5 | SARd24-epo349.pth |
Train SAR with Contrastive Student-Forcing Loss (CSFL):

```shell
# SAR d16 (depth 16)
bash scripts/train_SAR_d16.sh

# SAR d20 (depth 20)
bash scripts/train_SAR_d20.sh

# SAR d24 (depth 24)
bash scripts/train_SAR_d24.sh
```

Note: (1) For simplicity, instead of training a FlexVAR model for 170 epochs and then applying SAR for 10 epochs, the examples above directly download a pretrained FlexVAR checkpoint (e.g., 180 epochs for d16) and train SAR for 10 additional epochs. (2) `scalear_trainer.py` and `scalear.py` implement the hybrid model described in Section 3.3 of our paper, but are not used in the main experiments. The code is extensively commented for clarity; readers are encouraged to review it to gain a deeper understanding of the design choices and preliminary experiments underlying this work.
To enable Weights & Biases logging during training:

```shell
# Option 1: Interactive login (one-time setup)
wandb login

# Option 2: Using an API token (useful for servers/automation)
export WANDB_API_KEY=your-api-key-here
# Or pass it directly: --wandb_api_key=your-api-key-here

# Then add these flags to your training command
torchrun ... train.py \
  --logger_type=wandb \
  --wandb_project=sar \
  --wandb_entity=your-team-name \   # optional
  --wandb_run_name=my-experiment \  # optional, auto-generated if not provided
  --wandb_tags=d16,csfl \           # optional, comma-separated
  ...
```

Available wandb arguments:
| Argument | Default | Description |
|---|---|---|
| `--logger_type` | `tensorboard` | Set to `wandb` to enable wandb logging |
| `--wandb_project` | `sar` | Wandb project name |
| `--wandb_entity` | `None` | Wandb team/entity name (optional) |
| `--wandb_run_name` | `None` | Custom run name (auto-generated if not set) |
| `--wandb_tags` | `None` | Comma-separated tags for the run |
| `--wandb_notes` | `None` | Notes for the run |
First, set up the evaluation environment:

```shell
bash scripts/setup_eval.sh
```

For FID evaluation, images are sampled and saved as PNG files. Use OpenAI's FID evaluation toolkit with the 256×256 reference ground-truth npz file to evaluate FID, IS, precision, and recall. See Evaluation for details.
Run evaluation:

```shell
# Evaluate SAR d16
bash scripts/eval_SAR_d16.sh

# Evaluate SAR d20
bash scripts/eval_SAR_d20.sh

# Evaluate SAR d24
bash scripts/eval_SAR_d24.sh
```

Generate sample images with a trained model:
```python
import torch
from models import build_vae_var

# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae, model = build_vae_var(
    V=8912, Cvae=32, device=device,
    num_classes=1000, depth=16,
    vae_ckpt='pretrained/FlexVAE.pth',
)

# Load checkpoint
ckpt = torch.load('pretrained/SARd16-epo179.pth', map_location='cpu')
if 'trainer' in ckpt:
    ckpt = ckpt['trainer']['var_wo_ddp']
model.load_state_dict(ckpt, strict=False)
model.eval()

# Generate images
with torch.no_grad():
    # Class labels (e.g., 207=golden retriever, 88=macaw)
    labels = torch.tensor([207, 88, 360, 387], device=device)
    images = model.autoregressive_infer_cfg(
        vqvae=vae,
        B=4,
        label_B=labels,
        cfg=2.5,     # classifier-free guidance scale
        top_k=900,   # top-k sampling
        top_p=0.95,  # nucleus sampling
    )

# Save images
from torchvision.utils import save_image
save_image(images, 'samples.png', normalize=True, value_range=(-1, 1), nrow=4)
```

A complete inference demo is provided in `demo_inference.ipynb`.
This codebase is built upon VAR and FlexVAR. Our thanks go out to the creators of these outstanding projects.
```bibtex
@article{zhou2025rethinking,
  title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
  author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
  journal={arXiv preprint arXiv:2512.06421},
  year={2025}
}
```
