"""
Modified from https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Literal, MutableSequence, Optional

import torch

from AR.models.t2s_model_abc import KVCacheABC, Sampler, T2SDecoderABC

Tensor = torch.Tensor


@dataclass
class T2SResult:
    result: List[Tensor] | None = None
    infer_speed: float = 0.0
    status: Literal["Success", "Error"] = "Success"
    exception: Optional[Exception] = None
    traceback: Optional[str] = None
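
# Consumers would typically branch on `status` before touching `result`;
# a sketch:
#
#   if t2s_result.status == "Success":
#       tokens = t2s_result.result
#   else:
#       raise t2s_result.exception or RuntimeError(t2s_result.traceback)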


@dataclass
class T2SRequest:
    x: List[Tensor]
    x_lens: Tensor
    prompts: Tensor
    bert_feature: List[Tensor]
    valid_length: int
    top_k: int = 5
    top_p: float = 1.0
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35
    use_cuda_graph: bool = False
    debug: bool = False


class T2SSession:
    def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
        with device:
            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            bsz = len(request.x)
            y_len = request.prompts.size(-1)
            self.bsz = bsz
            self.y_len = y_len

            # Cache
            self.kv_cache: MutableSequence[KVCacheABC]
            self.sampler = Sampler(bsz, decoder.vocab_size)

            # Forward args
            self.x = request.x
            self.x_lens = request.x_lens.to(torch.int32)
            self.y = request.prompts
            self.bert_feature = request.bert_feature

            # Total prefill length per sample: phoneme length plus semantic prompt length.
            self.prefill_len = self.x_lens + self.y.size(1)

            # Decoding resumes immediately after the prefill for every sample.
            self.input_pos = torch.zeros_like(self.prefill_len)
            self.input_pos.add_(self.prefill_len)

            # CUDA Graph
            self.graph: Optional[torch.cuda.CUDAGraph] = None
            self.xy_pos_: Tensor
            self.xy_dec_: Tensor

            # EOS
            self.completed = torch.zeros(bsz, dtype=torch.bool, device=device)
            self.y_results: List[Tensor] = [None] * bsz  # type: ignore

            self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)

            # Per-sample prefill attention mask: every position may attend to the
            # full phoneme prefix (length pos), while the semantic prompt region
            # (length y_len) is causal.
            attn_mask = []
            for bs in range(bsz):
                pos = int(self.x_lens[bs].item())
                mask = torch.zeros(pos + y_len, pos + y_len).bool()
                mask[:, :pos].fill_(True)
                if y_len > 0:
                    mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
                attn_mask.append(mask)
            self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
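
            # For example, with pos = 2 and y_len = 3 a single sample's mask is
            # (True = may attend):
            #
            #         x0    x1    y0    y1    y2
            #   x0 [ True  True False False False ]
            #   x1 [ True  True False False False ]
            #   y0 [ True  True  True False False ]
            #   y1 [ True  True  True  True False ]
            #   y2 [ True  True  True  True  True ]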

            self.id: int = -1
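

# Usage sketch (assumptions for illustration: `decoder` is a concrete
# T2SDecoderABC implementation and `request` a T2SRequest built as above;
# the decode loop itself lives outside this file):
#
#   session = T2SSession(decoder, request, torch.device("cuda"), torch.float16)
#   # After construction, session.xy_pos holds the embedded prefill inputs and
#   # session.attn_mask_nested the per-sample prefill masks; a decode loop
#   # advances session.input_pos and appends finished sequences to
#   # session.y_results.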