DeepLearning101 commited on
Commit
4fc2b57
·
verified ·
1 Parent(s): 6114804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -14,13 +14,26 @@ import torch
14
  from loguru import logger
15
  import nltk
16
 
 
 
17
  import torch.serialization
18
- torch.serialization.register_package("utils")
19
  _original_torch_load = torch.load
20
- def safe_torch_load(*args, **kwargs):
21
- kwargs.setdefault("weights_only", False)
22
- return _original_torch_load(*args, **kwargs)
23
- torch.load = safe_torch_load
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # 設定 HTTPS context 避免證書錯誤
26
  ssl._create_default_https_context = ssl._create_unverified_context
 
14
  from loguru import logger
15
  import nltk
16
 
17
+ # ✅ 加入 monkey patch 讓 torch.load 自動允許 utils.HParams
18
+ import builtins
19
  import torch.serialization
 
20
  _original_torch_load = torch.load
21
+
22
+ def patched_torch_load(*args, **kwargs):
23
+ import types
24
+
25
+ # 建立假的 utils 模組和 HParams 類別(只需要匹配名稱)
26
+ class DummyHParams:
27
+ pass
28
+
29
+ dummy_utils = types.ModuleType("utils")
30
+ dummy_utils.HParams = DummyHParams
31
+
32
+ # 用 safe_globals 臨時允許這個類別被 pickle 載入
33
+ with torch.serialization.safe_globals({"utils.HParams": DummyHParams}):
34
+ return _original_torch_load(*args, **kwargs)
35
+
36
+ torch.load = patched_torch_load
37
 
38
  # 設定 HTTPS context 避免證書錯誤
39
  ssl._create_default_https_context = ssl._create_unverified_context