 from codeclm.trainer.codec_song_pl import CodecLM_PL
 from codeclm.models import CodecLM
 from third_party.demucs.models.pretrained import get_model_from_yaml
-
+import re
 
 auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
 
@@ -81,6 +81,7 @@ def parse_args():
     return parser.parse_args()
 
 def generate(args):
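+    # Cap PyTorch's intra-op CPU threads: generation here is GPU-bound, and extra
+    # CPU threads mostly add contention (an assumed rationale; raise this if
+    # CPU-side preprocessing dominates)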
+    torch.set_num_threads(1)
     ckpt_path = args.ckpt_path
     input_jsonl = args.input_jsonl
     save_dir = args.save_dir
@@ -95,10 +96,9 @@ def generate(args):
 
 
     separator = Separator()
-    auto_prompt = torch.load('ckpt/prompt.pt')
+    auto_prompt = torch.load('tools/new_prompt.pt')
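+    # Assumed layout of new_prompt.pt: a dict keyed by the genres in
+    # auto_prompt_type, each value a list of token tensors whose channel dim
+    # stacks [mix, vocal, bgm], matching the [:, [0]/[1]/[2], :] slicing below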
     audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
     audio_tokenizer = audio_tokenizer.eval().cuda()
-    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
     with open(input_jsonl, "r") as fp:
         lines = fp.readlines()
 
@@ -145,10 +145,7 @@ def generate(args):
             melody_is_wav = False
         elif "auto_prompt_audio_type" in item:
             assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
-            if item["auto_prompt_audio_type"] == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
+            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
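+            # Note: with the dedicated "Auto" branch removed, new_prompt.pt is
+            # assumed to ship an 'Auto' key of its own alongside the genre entries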
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
@@ -280,6 +277,7 @@ def generate(args):
             fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
 
 def generate_lowmem(args):
+    torch.set_num_threads(1)
     ckpt_path = args.ckpt_path
     input_jsonl = args.input_jsonl
     save_dir = args.save_dir
@@ -304,8 +302,7 @@ def generate_lowmem(args):
     separator = Separator()
     audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
     audio_tokenizer = audio_tokenizer.eval().cuda()
-    auto_prompt = torch.load('ckpt/prompt.pt')
-    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
+    auto_prompt = torch.load('tools/new_prompt.pt')
     new_items = []
     for line in lines:
         item = json.loads(line)
@@ -345,10 +342,7 @@ def generate_lowmem(args):
             melody_is_wav = False
         elif "auto_prompt_audio_type" in item:
             assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
-            if item["auto_prompt_audio_type"] == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
+            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
@@ -471,7 +465,8 @@ def generate_lowmem(args):
     seperate_tokenizer.model.model.device = torch.device(device)
     seperate_tokenizer = seperate_tokenizer.eval()
 
-    offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
+    # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
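+    # Hard-disable offloading of the separate tokenizer's diffusion weights,
+    # keeping them resident on GPU; the config-driven switch above is left
+    # commented out for reference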
+    offload_wav_tokenizer_diffusion = False
     if offload_wav_tokenizer_diffusion:
         sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
         sep_offload_param.show()
@@ -533,6 +528,8 @@ def generate_lowmem(args):
 
 
 if __name__ == "__main__":
+    # Cap this process's GPU memory usage at 0.55 of the device total
+    torch.cuda.set_per_process_memory_fraction(0.55)
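+    # Leaves headroom on the device for allocations outside this cap, e.g.
+    # other processes sharing the GPU (an assumed rationale, not stated here)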
     torch.backends.cudnn.enabled = False
     OmegaConf.register_new_resolver("eval", lambda x: eval(x))
     OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
@@ -548,18 +545,29 @@ def generate_lowmem(args):
         res_mem = (total - reserved) / 1024 / 1024 / 1024
         print(f"reserved memory: {res_mem} GB")
 
-        model_name = args.ckpt_path.split("/")[-1]
-        assert model_name in ['songgeneration_base'], f'{model_name} is not supported, currently only songgeneration_base is supported'
-        if model_name == 'songgeneration_base':
+        model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
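+        # Normalizes the checkpoint folder name, e.g. a dir named
+        # 'SongGeneration-Base' maps to 'songgeneration_base' (illustrative name)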
+        assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large'], f'{model_name} is not supported; only songgeneration_base, songgeneration_base_new, songgeneration_base_full and songgeneration_large are supported. Please download the correct files and rename the checkpoint folder to the matching version name.'
+        if model_name in ('songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full'):
             if res_mem > 24 and not args.low_mem:
                 print("use generate")
                 generate(args)
             else:
                 from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                 print("use generate_lowmem")
                 generate_lowmem(args)
+        elif model_name == 'songgeneration_large':
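+            # 36 GB of free memory is taken as the bar for running the large
+            # checkpoint without offloading, mirroring the 24 GB bar for the
+            # base variants (rationale assumed, not stated in the change)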
+            if res_mem > 36 and not args.low_mem:
+                print("use generate")
+                generate(args)
+            else:
+                print("use generate_lowmem")
+                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
+                generate_lowmem(args)
 
     else:
         print("CUDA is not available")
         exit()
-
+