DeepLearning101 commited on
Commit
e781acd
·
verified ·
1 Parent(s): 4fc2b57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -14,22 +14,24 @@ import torch
14
  from loguru import logger
15
  import nltk
16
 
17
- # ✅ 加入 monkey patch 讓 torch.load 自動允許 utils.HParams
18
- import builtins
19
  import torch.serialization
 
20
  _original_torch_load = torch.load
21
 
22
- def patched_torch_load(*args, **kwargs):
23
- import types
 
24
 
25
- # 建立假的 utils 模組和 HParams 類別(只需要匹配名稱)
26
- class DummyHParams:
27
- pass
28
 
29
- dummy_utils = types.ModuleType("utils")
30
- dummy_utils.HParams = DummyHParams
31
 
32
- # safe_globals 臨時允許這個類別被 pickle 載入
33
  with torch.serialization.safe_globals({"utils.HParams": DummyHParams}):
34
  return _original_torch_load(*args, **kwargs)
35
 
 
14
  from loguru import logger
15
  import nltk
16
 
17
+ # ✅ 修正後的 torch.load monkey patch,允許 utils.HParams 被載入
18
+ import types
19
  import torch.serialization
20
+
21
  _original_torch_load = torch.load
22
 
23
+ # 模擬 parrots 模型裡用到的 utils.HParams 類別
24
+ class DummyHParams:
25
+ pass
26
 
27
+ # 建立假的 utils 模組,讓 torch 找得到 utils.HParams
28
+ dummy_utils = types.ModuleType("utils")
29
+ dummy_utils.HParams = DummyHParams
30
 
31
+ import sys
32
+ sys.modules["utils"] = dummy_utils # <-- 🔑 核心!告訴 Python:有一個叫 utils 的模組
33
 
34
+ def patched_torch_load(*args, **kwargs):
35
  with torch.serialization.safe_globals({"utils.HParams": DummyHParams}):
36
  return _original_torch_load(*args, **kwargs)
37