Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix lint
  • Loading branch information
wwwjn committed Mar 10, 2026
commit 6224df7e039533c83b2c5e13080871debe2cf391
2 changes: 1 addition & 1 deletion tests/integration_tests/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
dump_folder_arg = f"--dump_folder {output_dir}/{test_name}"

# Random init encoder for offline testing
random_init_arg = "--tokenizer.test_mode --encoder.random_init"
random_init_arg = "--tokenizer.test_mode --encoder._random_init"
clip_encoder_version_arg = (
"--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_dataset_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_load_dataset(self):
"tests/assets/tokenizer",
"--tokenizer.clip_tokenizer_path",
"tests/assets/tokenizer",
"--encoder.random_init",
"--encoder._random_init",
"--encoder.t5_encoder",
"tests/assets/flux_test_encoders/t5-v1_1-xxl",
"--encoder.clip_encoder",
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/flux/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def encode(self, text: str | list[str]) -> dict[str, torch.Tensor]:
return {
"clip_tokens": self.clip_tokenizer.encode(text),
"t5_tokens": self.t5_tokenizer.encode(text),
Comment on lines +62 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we can just call them

Suggested change
"clip_tokens": self.clip_tokenizer.encode(text),
"t5_tokens": self.t5_tokenizer.encode(text),
"clip": self.clip_tokenizer.encode(text),
"t5": self.t5_tokenizer.encode(text),

}
} # pyrefly: ignore [bad-return]

# pyrefly: ignore [bad-override]
def decode(self, tokens: dict[str, list[int]]) -> dict[str, str]:
Expand Down
6 changes: 5 additions & 1 deletion torchtitan/models/flux/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def validate(
assert isinstance(p, str), f"prompt must be a string, got {type(p)}"
if save_img_count != -1 and save_img_count <= 0:
break
img_size = self.config.dataloader.img_size
img_size = (
self.config.dataloader.img_size
) # pyrefly: ignore [missing-attribute]
image = generate_image(
device=self.device,
dtype=self._dtype,
Expand All @@ -179,9 +181,11 @@ def validate(
enable_classifier_free_guidance=self.config.sampling.enable_classifier_free_guidance,
denoising_steps=self.config.sampling.denoising_steps,
classifier_free_guidance_scale=self.config.sampling.classifier_free_guidance_scale,
# pyrefly: ignore [bad-argument-type]
model=model,
prompt=p,
autoencoder=self.autoencoder,
# pyrefly: ignore [bad-argument-type]
tokenizer=self.tokenizer,
t5_encoder=self.t5_encoder,
clip_encoder=self.clip_encoder,
Expand Down
Loading