Spaces:
Configuration error
Configuration error
fix flow matching training for zero shot inference
Browse files- cosyvoice/flow/flow.py +6 -0
cosyvoice/flow/flow.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
import logging
|
|
|
|
| 15 |
from typing import Dict, Optional
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
|
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 77 |
|
| 78 |
# get conditions
|
| 79 |
conds = torch.zeros(feat.shape, device=token.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
conds = conds.transpose(1, 2)
|
| 81 |
|
| 82 |
mask = (~make_pad_mask(feat_len)).to(h)
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
import logging
|
| 15 |
+
import random
|
| 16 |
from typing import Dict, Optional
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
|
|
|
| 78 |
|
| 79 |
# get conditions
|
| 80 |
conds = torch.zeros(feat.shape, device=token.device)
|
| 81 |
+
for i, j in enumerate(feat_len):
|
| 82 |
+
if random.random() < 0.5:
|
| 83 |
+
continue
|
| 84 |
+
index = random.randint(0, int(0.3 * j))
|
| 85 |
+
conds[i, :index] = feat[i, :index]
|
| 86 |
conds = conds.transpose(1, 2)
|
| 87 |
|
| 88 |
mask = (~make_pad_mask(feat_len)).to(h)
|