DeepLearning101 commited on
Commit
19940d0
·
verified ·
1 Parent(s): e781acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -14,26 +14,19 @@ import torch
14
  from loguru import logger
15
  import nltk
16
 
17
- # ✅ 修正後的 torch.load monkey patch,允許 utils.HParams 被載入
18
- import types
19
  import torch.serialization
 
20
 
21
- _original_torch_load = torch.load
22
-
23
- # 模擬 parrots 模型裡用到的 utils.HParams 類別
24
  class DummyHParams:
25
  pass
26
 
27
- # 建立假的 utils 模組,讓 torch 找得到 utils.HParams
28
- dummy_utils = types.ModuleType("utils")
29
- dummy_utils.HParams = DummyHParams
30
-
31
- import sys
32
- sys.modules["utils"] = dummy_utils # <-- 🔑 核心!告訴 Python:有一個叫 utils 的模組
33
 
 
 
34
  def patched_torch_load(*args, **kwargs):
35
- with torch.serialization.safe_globals({"utils.HParams": DummyHParams}):
36
- return _original_torch_load(*args, **kwargs)
37
 
38
  torch.load = patched_torch_load
39
 
 
14
  from loguru import logger
15
  import nltk
16
 
 
 
17
  import torch.serialization
18
+ from torch.serialization import add_safe_globals
19
 
 
 
 
20
  class DummyHParams:
21
  pass
22
 
23
+ # 加入允許的類別進入 PyTorch 的安全清單中
24
+ add_safe_globals([DummyHParams])
 
 
 
 
25
 
26
+ # ✅ 繼續 monkey patch torch.load
27
+ _original_torch_load = torch.load
28
  def patched_torch_load(*args, **kwargs):
29
+ return _original_torch_load(*args, **kwargs)
 
30
 
31
  torch.load = patched_torch_load
32