replace the namedtupple for an object to avoid pickling issues
Browse files
sparktts/modules/speaker/perceiver_encoder.py
CHANGED
@@ -15,7 +15,6 @@
|
|
15 |
|
16 |
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
|
18 |
-
from collections import namedtuple
|
19 |
from functools import wraps
|
20 |
|
21 |
import torch
|
@@ -63,12 +62,21 @@ class Attend(nn.Module):
|
|
63 |
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
64 |
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
65 |
|
66 |
-
#
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
self.cuda_config = None
|
73 |
|
74 |
if not torch.cuda.is_available() or not use_flash:
|
@@ -85,7 +93,7 @@ class Attend(nn.Module):
|
|
85 |
print_once(
|
86 |
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
87 |
)
|
88 |
-
self.cuda_config =
|
89 |
|
90 |
def get_mask(self, n, device):
|
91 |
if exists(self.mask) and self.mask.shape[-1] >= n:
|
|
|
15 |
|
16 |
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
|
|
|
18 |
from functools import wraps
|
19 |
|
20 |
import torch
|
|
|
62 |
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
63 |
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
64 |
|
65 |
+
# Define config as a regular class instead of namedtuple
|
66 |
+
class EfficientAttentionConfig:
|
67 |
+
def __init__(self, enable_flash, enable_math, enable_mem_efficient):
|
68 |
+
self.enable_flash = enable_flash
|
69 |
+
self.enable_math = enable_math
|
70 |
+
self.enable_mem_efficient = enable_mem_efficient
|
71 |
+
|
72 |
+
def _asdict(self):
|
73 |
+
return {
|
74 |
+
'enable_flash': self.enable_flash,
|
75 |
+
'enable_math': self.enable_math,
|
76 |
+
'enable_mem_efficient': self.enable_mem_efficient
|
77 |
+
}
|
78 |
+
|
79 |
+
self.cpu_config = EfficientAttentionConfig(True, True, True)
|
80 |
self.cuda_config = None
|
81 |
|
82 |
if not torch.cuda.is_available() or not use_flash:
|
|
|
93 |
print_once(
|
94 |
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
95 |
)
|
96 |
+
self.cuda_config = EfficientAttentionConfig(False, True, True)
|
97 |
|
98 |
def get_mask(self, n, device):
|
99 |
if exists(self.mask) and self.mask.shape[-1] >= n:
|