Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 109 additions & 66 deletions scripts/loss_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,6 @@ def validate_arguments(
import_result: str | None,
) -> None:
"""Validate command line arguments."""
# Validate commit arguments - if one is ".", both must be "."
if (baseline_commit == "." and test_commit != ".") or (
baseline_commit != "." and test_commit == "."
):
log_print("Error: If one commit is '.', both commits must be '.'")
log_print(f" Got baseline: '{baseline_commit}', test: '{test_commit}'")
log_print(
" Use '.' for both commits to compare different "
"configurations on current working directory"
)
sys.exit(1)

# Validate that we are comparing different settings
commits_differ = baseline_commit != test_commit
configs_differ = baseline_config != test_config
Expand Down Expand Up @@ -335,7 +323,8 @@ def print_configuration(
def check_git_clean_state() -> None:
"""Check if git working directory is clean before switching commits.

Raises SystemExit if there are uncommitted changes or untracked files.
Raises SystemExit if there are uncommitted changes to tracked files.
Untracked files are ignored.
"""
result = subprocess.run(
["git", "status", "--porcelain"],
Expand All @@ -344,12 +333,20 @@ def check_git_clean_state() -> None:
check=True,
)

if result.stdout.strip():
log_print("Error: Git working directory is not clean")
# Filter out untracked files (lines starting with "??")
modified_tracked_files = []
for line in result.stdout.strip().split("\n"):
if line and not line.startswith("??"):
modified_tracked_files.append(line)

if modified_tracked_files:
log_print(
"Error: Git working directory has uncommitted changes to tracked files"
)
log_print(" Cannot switch commits with uncommitted changes")
log_print("")
log_print("Modified/untracked files:")
for line in result.stdout.strip().split("\n"):
log_print("Modified tracked files:")
for line in modified_tracked_files:
log_print(f" {line}")
log_print("")
log_print(
Expand All @@ -370,6 +367,39 @@ def checkout_commit(commit: str, commit_name: str) -> None:
log_print(f"Using current working directory for {commit_name} (commit: '.')")


def get_current_commit() -> str:
"""Get the current git commit hash or branch name.

Returns the current branch name if on a branch, otherwise returns the commit hash.
"""
# Try to get current branch name
result = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
capture_output=True,
text=True,
check=True,
)
ref = result.stdout.strip()

# If in detached HEAD state, ref will be "HEAD", so get the commit hash instead
if ref == "HEAD":
result = subprocess.run(
["git", "rev-parse", "HEAD"],
capture_output=True,
text=True,
check=True,
)
ref = result.stdout.strip()

return ref


def restore_original_commit(original_commit: str) -> None:
"""Restore the original git commit/branch."""
log_print(f"Restoring original commit/branch: {original_commit}")
subprocess.run(["git", "checkout", original_commit], check=True)


# =============================================================================
# TRAINING OPERATIONS
# =============================================================================
Expand Down Expand Up @@ -1002,61 +1032,74 @@ def main() -> None:
)

# Check if git working directory is clean before switching commits
# Skip check if both commits are "." (comparing configs on same commit)
# Skip check only if both commits are "." (comparing configs on same commit)
needs_git_checkout = args.baseline_commit != "." or args.test_commit != "."
if needs_git_checkout:
check_git_clean_state()

create_seed_checkpoint(
enable_seed_checkpoint,
args.baseline_config,
args.baseline_train_file,
args.output_folder,
args.job_dump_folder,
)
# Run baseline and test training
baseline_log = run_scenario(
"baseline",
args.baseline_commit,
args.baseline_config,
args.baseline_train_file,
args.baseline_options,
args.steps,
enable_seed_checkpoint,
args.output_folder,
args.job_dump_folder,
args.baseline_ngpus,
)
# Save original commit if we're going to do checkouts
original_commit = None
if needs_git_checkout:
original_commit = get_current_commit()
log_print(f"Saving original commit/branch: {original_commit}")
log_print()

test_log = run_scenario(
"test",
args.test_commit,
args.test_config,
args.test_train_file,
args.test_options,
args.steps,
enable_seed_checkpoint,
args.output_folder,
args.job_dump_folder,
args.test_ngpus,
)
log_print()
try:
create_seed_checkpoint(
enable_seed_checkpoint,
args.baseline_config,
args.baseline_train_file,
args.output_folder,
args.job_dump_folder,
)
# Run baseline and test training
baseline_log = run_scenario(
"baseline",
args.baseline_commit,
args.baseline_config,
args.baseline_train_file,
args.baseline_options,
args.steps,
enable_seed_checkpoint,
args.output_folder,
args.job_dump_folder,
args.baseline_ngpus,
)

test_log = run_scenario(
"test",
args.test_commit,
args.test_config,
args.test_train_file,
args.test_options,
args.steps,
enable_seed_checkpoint,
args.output_folder,
args.job_dump_folder,
args.test_ngpus,
)
log_print()

# Assert losses are equal if requested
if args.assert_equal:
# Pass import_result if provided for 3-way comparison
assert_losses_equal(baseline_log, test_log, args.import_result)

# Export losses if requested (only after assertion passes)
if args.export_result:
# Extract baseline losses (they equal test losses since assertion passed)
baseline_losses = extract_losses_from_log(baseline_log)
export_losses_to_file(baseline_losses, args.export_result)

# Analysis and reporting
perform_loss_analysis(baseline_log, test_log, stats_file)
cleanup_temp_files(args.output_folder)
print_completion_summary(args.output_folder, enable_seed_checkpoint)
# Assert losses are equal if requested
if args.assert_equal:
# Pass import_result if provided for 3-way comparison
assert_losses_equal(baseline_log, test_log, args.import_result)

# Export losses if requested (only after assertion passes)
if args.export_result:
# Extract baseline losses (they equal test losses since assertion passed)
baseline_losses = extract_losses_from_log(baseline_log)
export_losses_to_file(baseline_losses, args.export_result)

# Analysis and reporting
perform_loss_analysis(baseline_log, test_log, stats_file)
cleanup_temp_files(args.output_folder)
print_completion_summary(args.output_folder, enable_seed_checkpoint)
finally:
# Restore original commit if we did checkouts
if original_commit is not None:
log_print()
restore_original_commit(original_commit)


if __name__ == "__main__":
Expand Down
Loading