Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions toolkit/alpha_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def __init__(self, config: Dict[str, Any], rank: int):
logger.info("Alpha scheduling disabled")
return

# Merge user config with defaults to fill in missing alpha values
config = self._merge_with_defaults(config, rank)
self.config = config

# Parse phase definitions
self.phases = self._parse_phases(config.get('conv_alpha_phases', {}))
self.linear_alpha = config.get('linear_alpha', 16)
Expand Down Expand Up @@ -174,6 +178,77 @@ def __init__(self, config: Dict[str, Any], rank: int):
# Validate alpha/rank ratios and warn if high
self._validate_alpha_ratios()

def _merge_with_defaults(self, config: Dict[str, Any], rank: int) -> Dict[str, Any]:
    """
    Build a copy of the user configuration in which every phase has an alpha.

    The input mapping is deep-copied and never modified. On the copy:

    * ``linear_alpha`` is taken from the default configuration (falling back
      to 16) when the user did not provide the key.
    * Every phase under ``conv_alpha_phases`` and under each
      ``per_expert[...]['phases']`` whose ``alpha`` is absent or ``None``
      receives the alpha of the same-named default phase if one exists,
      otherwise a rank-based estimate of ``max(4, int(rank * 0.22))``.

    Args:
        config: User-provided configuration.
        rank: LoRA rank, used to build the default configuration and the
            fallback alpha estimate.

    Returns:
        The merged deep copy of ``config``.
    """
    import copy

    def fill_missing_alpha(phase_map, fallback_map, label_prefix, report_scale):
        # Shared by the global and the per-expert phase tables; only the
        # label used in log messages and the scale suffix differ.
        for phase_name, phase_config in phase_map.items():
            if 'alpha' in phase_config and phase_config['alpha'] is not None:
                continue  # user supplied a usable alpha, leave it alone

            label = f"{label_prefix}'{phase_name}' phase"

            if phase_name in fallback_map and 'alpha' in fallback_map[phase_name]:
                phase_config['alpha'] = fallback_map[phase_name]['alpha']
                logger.info(
                    f"Filled in missing alpha for {label} with default: {phase_config['alpha']}"
                )
                continue

            # No default phase with this name: estimate from rank (~0.22 scale)
            estimated_alpha = max(4, int(rank * 0.22))
            phase_config['alpha'] = estimated_alpha
            message = f"Alpha not found for {label}, using estimated value: {estimated_alpha}"
            if report_scale:
                message += f" (estimated scale: {estimated_alpha/rank:.3f})"
            logger.warning(message)

    merged = copy.deepcopy(config)

    # Defaults for this rank act as the template for anything missing
    defaults = create_default_config(rank)

    if 'linear_alpha' not in merged:
        merged['linear_alpha'] = defaults.get('linear_alpha', 16)

    # Global conv phases
    if 'conv_alpha_phases' in merged:
        fill_missing_alpha(
            merged['conv_alpha_phases'],
            defaults.get('conv_alpha_phases', {}),
            "",
            True,
        )

    # Per-expert phase overrides
    if 'per_expert' in merged:
        expert_defaults = defaults.get('per_expert', {})
        for expert_name, expert_config in merged['per_expert'].items():
            if 'phases' not in expert_config:
                continue
            fill_missing_alpha(
                expert_config['phases'],
                expert_defaults.get(expert_name, {}).get('phases', {}),
                f"'{expert_name}' expert ",
                False,
            )

    return merged

def _validate_alpha_ratios(self):
"""Validate alpha/rank ratios and warn if unusually high."""
# Check linear alpha
Expand All @@ -187,6 +262,12 @@ def _validate_alpha_ratios(self):

# Check conv alpha in all phases
for phase in self.phases:
if phase.alpha is None:
logger.warning(
f"⚠️ Conv alpha is not set for '{phase.name}' phase. "
f"Alpha value must be specified in the phase configuration."
)
continue
conv_scale = phase.alpha / self.rank
if conv_scale > 0.5:
logger.warning(
Expand All @@ -199,6 +280,12 @@ def _validate_alpha_ratios(self):
if self.per_expert_phases:
for expert_name, expert_phases in self.per_expert_phases.items():
for phase in expert_phases:
if phase.alpha is None:
logger.warning(
f"⚠️ Conv alpha is not set for '{expert_name}' expert in '{phase.name}' phase. "
f"Alpha value must be specified in the phase configuration."
)
continue
conv_scale = phase.alpha / self.rank
if conv_scale > 0.5:
logger.warning(
Expand Down Expand Up @@ -262,10 +349,18 @@ def get_current_alpha(self, module_name: str, is_conv: bool) -> float:

# Get current phase alpha
if self.current_phase_idx < len(phases):
return phases[self.current_phase_idx].alpha
alpha = phases[self.current_phase_idx].alpha
if alpha is None:
logger.warning(f"Alpha value not set for phase '{phases[self.current_phase_idx].name}', using default")
return self.config.get('conv_alpha', 14)
return alpha
else:
# Staying in final phase
return phases[-1].alpha
alpha = phases[-1].alpha
if alpha is None:
logger.warning(f"Alpha value not set for final phase '{phases[-1].name}', using default")
return self.config.get('conv_alpha', 14)
return alpha

def get_current_scale(self, module_name: str, is_conv: bool) -> float:
"""
Expand All @@ -274,6 +369,9 @@ def get_current_scale(self, module_name: str, is_conv: bool) -> float:
This is the actual effective scaling factor applied in forward pass.
"""
alpha = self.get_current_alpha(module_name, is_conv)
if alpha is None:
# Fallback to default if alpha is somehow still None
alpha = self.config.get('conv_alpha', 14) if is_conv else self.linear_alpha
return alpha / self.rank

def update(self, step: int, loss: Optional[float] = None,
Expand Down
2 changes: 1 addition & 1 deletion ui/cron/actions/startJob.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ const startAndWatchJob = (job: Job) => {
try {
let subprocess;

const devNull = fs.openSync('/dev/null', 'a');
const devNull = fs.openSync(isWindows ? 'nul' : '/dev/null', 'a');

if (isWindows) {
// Spawn Python directly on Windows so the process can survive parent exit
Expand Down