File size: 4,299 Bytes
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbfbc08
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a23964
91394e0
 
 
7a23964
 
91394e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import random
from typing import Iterator

from data_handlers import get_melody_handler

from .utils.g2p import preprocess_text


class MelodyController:
    def __init__(self, melody_source_id: str, cache_dir: str):
        self.melody_source_id = melody_source_id
        self.song_id = None

        # load song database if needed
        parts = self.melody_source_id.split("-")
        self.mode = parts[0]
        self.align_type = parts[1]
        dataset_name = parts[-1]
        if dataset_name == "none":
            self.database = None
        else:
            handler_cls = get_melody_handler(dataset_name)
            self.database = handler_cls(self.align_type, cache_dir)

    def get_melody_constraints(self, max_num_phrases: int = 5) -> str:
        """Return a lyric-format prompt based on melody structure."""
        if self.mode == "gen":
            return ""

        elif self.mode == "sample":
            assert self.database is not None, "Song database is not loaded."
            self.song_id = random.choice(self.database.get_song_ids())
            self.reference_song = self.database.iter_song_phrases(self.song_id)
            phrase_length = self.database.get_phrase_length(self.song_id)

            if not phrase_length:
                return ""

            prompt = (
                "\n请按照歌词格式回复,每句需遵循以下字数规则:"
                + "".join(
                    [
                        f"\n第{i}句:{c}个字"
                        for i, c in enumerate(phrase_length[:max_num_phrases], 1)
                    ]
                )
                + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
            )
            return prompt

        else:
            raise ValueError(f"Unsupported melody mode: {self.mode}")

    def generate_score(
        self, lyrics: str, language: str
    ) -> list[tuple[float, float, str, int]]:
        """
        lyrics: [lyric, ...]
        returns: [(start, end, lyric, pitch), ...]
        """
        text_list = preprocess_text(lyrics, language)
        if self.mode == "gen" and self.align_type == "random":
            return self._generate_random_score(text_list)

        elif self.mode == "sample":
            if not self.reference_song:
                raise RuntimeError(
                    "Must call get_melody_constraints() before generate_score() in sample mode."
                )
            return self._align_text_to_score(
                text_list, self.reference_song, self.align_type
            )

        else:
            raise ValueError(f"Unsupported melody_source_id: {self.melody_source_id}")

    def _generate_random_score(self, text_list: list[str]):
        st = 0
        score = []
        for lyric in text_list:
            pitch = random.randint(57, 69)
            duration = round(random.uniform(0.1, 0.5), 4)
            ed = st + duration
            score.append((st, ed, lyric, pitch))
            st = ed
        return score

    def _align_text_to_score(
        self,
        text_list: list[str],
        song_phrase_iterator: Iterator[dict],
        align_type: str,
    ):
        score = []
        text_idx = 0

        while text_idx < len(text_list):
            reference = next(song_phrase_iterator)
            for st, ed, ref_lyric, pitch in zip(
                reference["note_start_times"],
                reference["note_end_times"],
                reference["note_lyrics"],
                reference["note_midi"],
            ):
                assert ref_lyric not in [
                    "<AP>",
                    "<SP>",
                ], f"Proccessed {self.melody_source_id} score segments should not contain <AP> or <SP>."  # TODO: remove in PR, only for debug
                if pitch == 0:
                    score.append((st, ed, ref_lyric, pitch))
                elif ref_lyric in ["-", "——"] and align_type == "lyric":
                    score.append((st, ed, "-", pitch))
                else:
                    score.append((st, ed, text_list[text_idx], pitch))
                    text_idx += 1
                    if text_idx >= len(text_list):
                        break
        return score