We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 77b65bf commit ba1b5d0Copy full SHA for ba1b5d0
flash_rl/vllm_patch.py
@@ -145,8 +145,7 @@ def hacked_process_weights_after_loading(
145
146
for name, p in all_updated_params.items():
147
if name in updated_params:
148
- strided_data = torch.as_strided(
149
- p.data, hacked_data_dict[name].shape, hacked_data_dict[name].stride())
+ strided_data = p.data.t()
150
hacked_data_dict[name].copy_(strided_data)
151
else:
152
skipped_params.append(name)
0 commit comments