thunnai committed on
Commit
dfcd575
·
1 Parent(s): 3305254

replace the namedtuple with a regular class to avoid pickling issues

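Background for the change (an illustrative sketch, not part of the commit): pickle serializes a class by its import path, but a namedtuple type created inside __init__ has no importable path, so any instance holding it cannot be serialized. The hypothetical Before class below reproduces the failure mode of the old code:

import pickle
from collections import namedtuple

class Before:
    def __init__(self):
        # namedtuple() builds a new class at runtime that is reachable
        # only through this instance attribute, so pickle cannot look
        # it up by name when serializing
        self.config = namedtuple(
            "EfficientAttentionConfig",
            ["enable_flash", "enable_math", "enable_mem_efficient"],
        )
        self.cpu_config = self.config(True, True, True)

# Raises pickle.PicklingError: attribute lookup EfficientAttentionConfig
# on __main__ failed
pickle.dumps(Before())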
sparktts/modules/speaker/perceiver_encoder.py CHANGED
@@ -15,7 +15,6 @@
 
 # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
 
-from collections import namedtuple
 from functools import wraps
 
 import torch
@@ -63,12 +62,21 @@ class Attend(nn.Module):
             use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
         ), "in order to use flash attention, you must be using pytorch 2.0 or above"
 
-        # determine efficient attention configs for cuda and cpu
-        self.config = namedtuple(
-            "EfficientAttentionConfig",
-            ["enable_flash", "enable_math", "enable_mem_efficient"],
-        )
-        self.cpu_config = self.config(True, True, True)
+        # Define config as a regular class instead of namedtuple
+        class EfficientAttentionConfig:
+            def __init__(self, enable_flash, enable_math, enable_mem_efficient):
+                self.enable_flash = enable_flash
+                self.enable_math = enable_math
+                self.enable_mem_efficient = enable_mem_efficient
+
+            def _asdict(self):
+                return {
+                    'enable_flash': self.enable_flash,
+                    'enable_math': self.enable_math,
+                    'enable_mem_efficient': self.enable_mem_efficient
+                }
+
+        self.cpu_config = EfficientAttentionConfig(True, True, True)
         self.cuda_config = None
 
         if not torch.cuda.is_available() or not use_flash:
@@ -85,7 +93,7 @@ class Attend(nn.Module):
             print_once(
                 "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
             )
-            self.cuda_config = self.config(False, True, True)
+            self.cuda_config = EfficientAttentionConfig(False, True, True)
 
     def get_mask(self, n, device):
         if exists(self.mask) and self.mask.shape[-1] >= n:
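A note on why _asdict is reimplemented (the call pattern below is an assumption drawn from the upstream naturalspeech2 file cited in the header, not shown in this diff): namedtuple provides _asdict() for free, and the flag names match the keyword arguments of torch.backends.cuda.sdp_kernel, so keeping the method lets keyword-unpacking call sites work unchanged:

import torch
import torch.nn.functional as F

# Minimal stand-in for the EfficientAttentionConfig class added above
class EfficientAttentionConfig:
    def __init__(self, enable_flash, enable_math, enable_mem_efficient):
        self.enable_flash = enable_flash
        self.enable_math = enable_math
        self.enable_mem_efficient = enable_mem_efficient

    def _asdict(self):
        return {
            'enable_flash': self.enable_flash,
            'enable_math': self.enable_math,
            'enable_mem_efficient': self.enable_mem_efficient
        }

q = k = v = torch.randn(1, 4, 16, 8)  # illustrative (batch, heads, seq, dim)
config = EfficientAttentionConfig(False, True, True)

# The flag names line up with sdp_kernel's keyword arguments, which is
# what _asdict() preserves from the namedtuple API
with torch.backends.cuda.sdp_kernel(**config._asdict()):
    out = F.scaled_dot_product_attention(q, k, v)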