wsj1995 commited on
Commit
76b64a1
·
1 Parent(s): 31a4b67

feat: spaces

Browse files
Files changed (1) hide show
  1. GPT_SoVITS/inference_webui.py +29 -4
GPT_SoVITS/inference_webui.py CHANGED
@@ -34,11 +34,39 @@ import re
34
  import sys
35
  import traceback
36
  import warnings
 
37
 
38
  import torch
39
  import torchaudio
40
  from text.LangSegmenter import LangSegmenter
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
43
  logging.getLogger("urllib3").setLevel(logging.ERROR)
44
  logging.getLogger("httpcore").setLevel(logging.ERROR)
@@ -123,10 +151,7 @@ i18n = I18nAuto(language=language)
123
 
124
  # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
125
 
126
- if torch.cuda.is_available():
127
- device = "cuda"
128
- else:
129
- device = "cpu"
130
 
131
  dict_language_v1 = {
132
  i18n("中文"): "all_zh", # 全部按中文识别
 
34
  import sys
35
  import traceback
36
  import warnings
37
+ import spaces
38
 
39
  import torch
40
  import torchaudio
41
  from text.LangSegmenter import LangSegmenter
42
 
43
+ # 保存原始构造器
44
+ original_storage_new = torch.UntypedStorage.__new__
45
+
46
+
47
+ def _untyped_storage_new_register(cls, *args, **kwargs):
48
+ cuda = False
49
+ device = kwargs.get('device')
50
+
51
+ # 先判断类型是否为 torch.device 再访问 type 属性
52
+ if isinstance(device, torch.device) and device.type == 'cuda':
53
+ cuda = True
54
+ del kwargs['device']
55
+
56
+ # 正确调用 __new__
57
+ storage = torch._C.StorageBase.__new__(cls, *args, **kwargs)
58
+
59
+ # 标记是否是 ZeroGPU 模式
60
+ if cuda:
61
+ storage._zerogpu = True
62
+
63
+ return storage
64
+
65
+
66
+ # 替换 __new__ 方法
67
+ torch.UntypedStorage.__new__ = _untyped_storage_new_register
68
+
69
+
70
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
71
  logging.getLogger("urllib3").setLevel(logging.ERROR)
72
  logging.getLogger("httpcore").setLevel(logging.ERROR)
 
151
 
152
  # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
153
 
154
+ device = "cuda"
 
 
 
155
 
156
  dict_language_v1 = {
157
  i18n("中文"): "all_zh", # 全部按中文识别