Commit b96f8b6

Author: stephenlei
Commit message: update new version
Parent: 360d737

File tree: 9 files changed, +346 −43 lines


README.md

Lines changed: 234 additions & 25 deletions (large diff not rendered by default)

generate.py

Lines changed: 26 additions & 18 deletions
@@ -14,7 +14,7 @@
 from codeclm.trainer.codec_song_pl import CodecLM_PL
 from codeclm.models import CodecLM
 from third_party.demucs.models.pretrained import get_model_from_yaml
-
+import re

 auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']

@@ -81,6 +81,7 @@ def parse_args():
     return parser.parse_args()

 def generate(args):
+    torch.set_num_threads(1)
     ckpt_path = args.ckpt_path
     input_jsonl = args.input_jsonl
     save_dir = args.save_dir
@@ -95,10 +96,9 @@


     separator = Separator()
-    auto_prompt = torch.load('ckpt/prompt.pt')
+    auto_prompt = torch.load('tools/new_prompt.pt')
     audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
     audio_tokenizer = audio_tokenizer.eval().cuda()
-    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
     with open(input_jsonl, "r") as fp:
         lines = fp.readlines()
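
Note: the hunks above and below imply that tools/new_prompt.pt holds a dict keyed by genre name, each value a list of prompt-token tensors whose second dimension stacks the mixed, vocal, and bgm tracks. A minimal sketch of that access pattern (the exact tensor shape is an assumption inferred from the [:,[i],:] slicing, not confirmed by the repo):

    import torch
    import numpy as np

    # Assumed layout: {genre_name: [tensor of shape [1, 3, T], ...]}
    auto_prompt = torch.load('tools/new_prompt.pt')
    candidates = auto_prompt['Pop']  # prompt tensors for one genre
    prompt_token = candidates[np.random.randint(0, len(candidates))]
    pmt_wav = prompt_token[:, [0], :]    # mixed track
    vocal_wav = prompt_token[:, [1], :]  # vocal track
    bgm_wav = prompt_token[:, [2], :]    # background track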

@@ -145,10 +145,7 @@ def generate(args):
             melody_is_wav = False
         elif "auto_prompt_audio_type" in item:
             assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
-            if item["auto_prompt_audio_type"] == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
+            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
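
For context, each line of the input JSONL is one JSON object; an item that takes the branch above only needs an auto_prompt_audio_type drawn from auto_prompt_type. A hypothetical line (the idx field is illustrative, not taken from the repo):

    {"idx": "song_001", "auto_prompt_audio_type": "Jazz"}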
@@ -280,6 +277,7 @@ def generate(args):
         fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")

 def generate_lowmem(args):
+    torch.set_num_threads(1)
     ckpt_path = args.ckpt_path
     input_jsonl = args.input_jsonl
     save_dir = args.save_dir
@@ -304,8 +302,7 @@ def generate_lowmem(args):
     separator = Separator()
     audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
     audio_tokenizer = audio_tokenizer.eval().cuda()
-    auto_prompt = torch.load('ckpt/prompt.pt')
-    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
+    auto_prompt = torch.load('tools/new_prompt.pt')
     new_items = []
     for line in lines:
         item = json.loads(line)
@@ -345,10 +342,7 @@ def generate_lowmem(args):
             melody_is_wav = False
         elif "auto_prompt_audio_type" in item:
             assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
-            if item["auto_prompt_audio_type"] == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
+            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
@@ -471,7 +465,8 @@ def generate_lowmem(args):
     seperate_tokenizer.model.model.device = torch.device(device)
     seperate_tokenizer = seperate_tokenizer.eval()

-    offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
+    # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
+    offload_wav_tokenizer_diffusion = False
     if offload_wav_tokenizer_diffusion:
         sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
         sep_offload_param.show()
@@ -533,6 +528,8 @@ def generate_lowmem(args):


 if __name__ == "__main__":
+    # Limit the fraction of GPU memory the model may use
+    torch.cuda.set_per_process_memory_fraction(0.55)
     torch.backends.cudnn.enabled = False
     OmegaConf.register_new_resolver("eval", lambda x: eval(x))
     OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
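
torch.cuda.set_per_process_memory_fraction caps PyTorch's caching allocator at a fraction of the device's total memory; allocations past the cap raise an out-of-memory error instead of growing further. A quick illustrative check of the resulting cap:

    import torch

    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.55)
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"allocator capped at ~{0.55 * total_gb:.1f} GB of {total_gb:.1f} GB")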
@@ -548,18 +545,29 @@ def generate_lowmem(args):
         res_mem = (total - reserved) / 1024 / 1024 / 1024
         print(f"reserved memory: {res_mem}GB")

-        model_name = args.ckpt_path.split("/")[-1]
-        assert model_name in ['songgeneration_base'], f'{model_name} is not supported, currently only songgeneration_base is supported'
-        if model_name == 'songgeneration_base':
+        model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
+        assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large'], f'{model_name} is not supported, currently only songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large are supported. Please download correct files and rename the folder to the corresponding version name.'
+        if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full':
             if res_mem > 24 and not args.low_mem:
                 print("use generate")
                 generate(args)
             else:
                 from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                 print("use generate_lowmem")
                 generate_lowmem(args)
+        elif model_name == 'songgeneration_large':
+            if res_mem > 36 and not args.low_mem:
+                print("use generate")
+                generate(args)
+            else:
+                print("use generate_lowmem")
+                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
+                generate_lowmem(args)
+
+
+        # elif model_name == 'songgeneration_base_full':

     else:
         print("CUDA is not available")
         exit()
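
The new .lower().replace('-', '_') normalization means the checkpoint folder's name, regardless of case or hyphenation, selects the code path. For example (hypothetical paths):

    >>> "ckpt/SongGeneration-base".split("/")[-1].lower().replace('-', '_')
    'songgeneration_base'
    >>> "ckpt/SongGeneration-large".split("/")[-1].lower().replace('-', '_')
    'songgeneration_large'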

generate.sh

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,9 @@ export PYTHONDONTWRITEBYTECODE=1
 export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
 export NCCL_HOME=/usr/local/tccl
 export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export CUDA_LAUNCH_BLOCKING=0

 CKPT_PATH=$1
 JSONL=$2
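
The script takes the checkpoint directory and the input JSONL as its first two positional arguments; a hypothetical invocation (paths are illustrative, and any further arguments are not shown in this hunk):

    bash generate.sh ckpt/songgeneration_base sample/input.jsonl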

img/contact.jpg

Binary file removed (−119 KB)

img/contact.png

Binary file added (+118 KB)
