#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    : layers.py
@Time    : 2024/4/22 2:40 PM
@Author  : waytan
@Contact : [email protected]
@License : (C)Copyright 2024, Tencent
"""
import os
import json
import time
import logging
import argparse
from datetime import datetime


import torch

from models.apply import BagOfModels
from models.pretrained import get_model_from_yaml


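class Separator:
    """Wraps a Demucs model (built from a YAML config and checkpoint) that splits songs into stems."""

    def __init__(self, dm_model_path, dm_config_path, gpu_id=0) -> None:
        # Use the requested GPU if it exists; otherwise fall back to CPU.
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():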
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path) -> BagOfModels:
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model
    
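    def run(self, audio_path, output_dir, ext=".flac"):
        """Separate one audio file into stems and return the stem paths.

        If all four stem files "<name>_<stem><ext>" already exist in output_dir,
        they are reused instead of running the model again.
        """
        name, _ = os.path.splitext(os.path.basename(audio_path))
        output_paths = []
        for stem in self.demucs_model.sources:
            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            if os.path.exists(output_path):
                output_paths.append(output_path)
        # The unpacking below assumes the 4-stem htdemucs order: drums, bass, other, vocals.
        if len(output_paths) == 4:
            drums_path, bass_path, other_path, vocal_path = output_paths
        else:
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(
                audio_path, output_dir, device=self.device
            )
        data_dict = {
            "vocal_path": vocal_path,
            # Every non-vocal stem is treated as background music.
            "bgm_path": [drums_path, bass_path, other_path],
        }
        return data_dict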
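

# Sketch of a hypothetical helper (not part of the original script and not called
# by Separator.run). It reproduces the "<name>_<stem><ext>" naming that run() checks
# in output_dir, so a downstream step could locate stems without loading the model.
# `sources` is assumed to be the model's stem list, e.g.
# ["drums", "bass", "other", "vocals"] for 4-stem htdemucs.
def expected_stem_paths(audio_path, output_dir, sources, ext=".flac"):
    # Strip the directory and extension the same way Separator.run does.
    name, _ = os.path.splitext(os.path.basename(audio_path))
    # One path per stem, in the same order as `sources`.
    return [os.path.join(output_dir, f"{name}_{stem}{ext}") for stem in sources]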


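def json_io(input_json, output_json, model_dir, dst_dir, gpu_id=0):
    """Separate every song listed in a JSON Lines manifest.

    Each input line is a JSON object with at least "path" (the audio file) and
    "idx" (an identifier used in log messages). Successfully processed items are
    written to output_json with added "vocal_path" and "bgm_path" fields;
    failures are logged and skipped.
    """
    current_datetime_str = datetime.now().strftime('%Y-%m-%d-%H:%M')
    log_name = f'logger-separate-{os.path.basename(input_json)}-{current_datetime_str}.log'
    logging.basicConfig(
        filename=os.path.join(dst_dir, log_name),
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    sp = Separator(
        os.path.join(model_dir, "htdemucs.pth"),
        os.path.join(model_dir, "htdemucs.yaml"),
        gpu_id=gpu_id,
    )
    with open(input_json, "r", encoding="utf-8") as fp:
        lines = fp.readlines()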
    # Ignore blank lines (e.g. a trailing newline) so json.loads does not fail on them.
    lines = [line for line in lines if line.strip()]
    t1 = time.time()
    success_num = 0
    fail_num = 0
    total_num = len(lines)
    sep_items = []
    for line in lines:
        item = json.loads(line)
        flac_file = item["path"]
        try:
            fix_data = sp.run(flac_file, dst_dir)
        except Exception as e:
            # Log the failure and move on to the next item.
            fail_num += 1
            logging.error(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process fail for {str(e)}")
            continue

        item["vocal_path"] = fix_data["vocal_path"]
        item["bgm_path"] = fix_data["bgm_path"]
        sep_items.append(item)
        success_num += 1
        logging.debug(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process success")

    with open(output_json, "w", encoding='utf-8') as fw:
        for item in sep_items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")

    t2 = time.time()
    logging.debug(f"total cost {round(t2-t1, 3)}s")


if __name__ == "__main__":
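    parser = argparse.ArgumentParser(
        description="Separate songs listed in a JSONL manifest into vocal and accompaniment stems with Demucs."
    )
    parser.add_argument("-m", dest="model_dir", required=True,
                        help="directory containing htdemucs.pth and htdemucs.yaml")
    parser.add_argument("-d", dest="dst_dir",
                        help="output root; results go to <dst_dir>/separate_result (default: current directory)")
    parser.add_argument("-j", dest="input_json", required=True,
                        help="input JSONL manifest, one object with 'path' and 'idx' per line")
    parser.add_argument("-o", dest="output_json", required=True,
                        help="output JSONL manifest with added 'vocal_path' and 'bgm_path' fields")
    parser.add_argument("-gid", dest="gpu_id", default=0, type=int,
                        help="CUDA device index (falls back to CPU if unavailable)")
    args = parser.parse_args()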

    # Stems and the log file are written to <dst_dir or cwd>/separate_result.
    dst_dir = os.path.join(args.dst_dir or os.getcwd(), "separate_result")
    os.makedirs(dst_dir, exist_ok=True)

    json_io(args.input_json, args.output_json, args.model_dir, dst_dir, gpu_id=args.gpu_id)