Commit · b55bd43
Parent(s): 95230fb

Add patch for transformers URL handling and enhance model loading with manual config download

app.py CHANGED
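In summary: the commit adds a module-level monkey patch that rewrites bare /api/ paths into absolute https://huggingface.co URLs before transformers downloads anything, and reworks load_model_cached into a cascade that tries a minimal from_pretrained call, then the full Hub URL, then a manual config.json download, before falling back to the Medium model.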
@@ -13,8 +13,30 @@ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
 os.environ['HF_ENDPOINT'] = 'https://huggingface.co'
 # Disable offline mode to allow downloads
 os.environ['TRANSFORMERS_OFFLINE'] = '0'
+
+# Patch for transformers 4.17.0 URL issue in HF Spaces
+import urllib.parse
+
+def patch_transformers_url():
+    """Fix URL scheme issue in transformers 4.17.0"""
+    try:
+        import transformers.file_utils
+        original_get_from_cache = transformers.file_utils.get_from_cache
+
+        def patched_get_from_cache(url, *args, **kwargs):
+            # Fix URLs that start with /api/ by prepending https://huggingface.co
+            if isinstance(url, str) and url.startswith('/api/'):
+                url = 'https://huggingface.co' + url
+            return original_get_from_cache(url, *args, **kwargs)
+
+        transformers.file_utils.get_from_cache = patched_get_from_cache
+        print("Applied URL patch for transformers")
+    except Exception as e:
+        print(f"Warning: Could not patch transformers URL handling: {e}")
+
 import torch
 import transformers
+patch_transformers_url()
 from transformers import PreTrainedTokenizerFast
 import numpy as np
 import pandas as pd
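A note on the hunk above: rebinding transformers.file_utils.get_from_cache also redirects transformers' own internal callers, because functions in the same module look the name up in the module namespace at call time, so one reassignment is enough. The rewrite rule itself can be checked in isolation; the helper name and the sample URLs below are illustrative, not part of the commit:

def _absolutize(url):
    # Same rule as patched_get_from_cache: make bare /api/ paths absolute.
    if isinstance(url, str) and url.startswith('/api/'):
        return 'https://huggingface.co' + url
    return url

assert _absolutize('/api/models/PascalNotin/Tranception_Small') == \
    'https://huggingface.co/api/models/PascalNotin/Tranception_Small'
assert _absolutize('https://huggingface.co/api/models') == \
    'https://huggingface.co/api/models'  # absolute URLs pass through unchanged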
@@ -107,21 +129,10 @@ def load_model_cached(model_type):
         cache_dir = "/tmp/huggingface/transformers"
         os.makedirs(cache_dir, exist_ok=True)
 
-        #
-        import requests
-        session = requests.Session()
-        session.trust_env = False
-
-        # Try loading with explicit parameters
+        # Try loading with minimal parameters first
         model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
-            model_path,
-            cache_dir=cache_dir,
-            local_files_only=False,  # Allow downloading if not cached
-            resume_download=True,  # Resume incomplete downloads
-            force_download=False,  # Don't force re-download if cached
-            proxies=None,  # Explicitly set no proxies
-            use_auth_token=None,  # No auth token needed for public models
-            revision="main"  # Use main branch
+            model_path,
+            cache_dir=cache_dir
         )
         MODEL_CACHE[model_type] = model
         print(f"{model_type} model loaded and cached")
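The first attempt now passes only the positional model path and cache_dir, leaving every other from_pretrained option at its library default, which sidesteps keyword arguments the pinned transformers release might not accept. Stripped of the fallback chain, the memoization around it reduces to the sketch below; MODEL_CACHE as a module-level dict and the PascalNotin/Tranception_<size> naming are assumptions based on the rest of the diff:

import tranception.model_pytorch  # local package bundled with the Space

MODEL_CACHE = {}  # module-level cache, as in app.py

def load_cached(model_type):
    # Return a cached model when available; otherwise load once and store it.
    if model_type in MODEL_CACHE:
        return MODEL_CACHE[model_type]
    model_path = f"PascalNotin/Tranception_{model_type}"  # naming per the diff
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
        model_path,
        cache_dir="/tmp/huggingface/transformers",
    )
    MODEL_CACHE[model_type] = model
    return model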
@@ -130,21 +141,52 @@ def load_model_cached(model_type):
         print(f"Error loading {model_type} model: {e}")
         print(f"Attempting alternative loading method...")
 
-        # Try alternative loading approach
+        # Try alternative loading approach with full URL
         try:
-            #
-
+            # Use full URL to bypass any path resolution issues
+            full_url = f"https://huggingface.co/PascalNotin/Tranception_{model_type}"
             model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
-                model_path,
-                cache_dir=cache_dir,
-                local_files_only=False,
-                trust_remote_code=True  # Allow custom model code
+                full_url,
+                cache_dir=cache_dir
             )
             MODEL_CACHE[model_type] = model
-            print(f"{model_type} model loaded successfully with alternative method")
+            print(f"{model_type} model loaded successfully with full URL")
             return model
         except Exception as e2:
             print(f"Alternative loading also failed: {e2}")
+
+            # Final attempt: manually download config first
+            try:
+                import json
+                import requests
+
+                # Download config.json manually
+                config_url = f"https://huggingface.co/PascalNotin/Tranception_{model_type}/raw/main/config.json"
+                print(f"Manually downloading config from: {config_url}")
+
+                response = requests.get(config_url)
+                if response.status_code == 200:
+                    # Save config locally
+                    local_model_dir = f"/tmp/Tranception_{model_type}"
+                    os.makedirs(local_model_dir, exist_ok=True)
+
+                    with open(f"{local_model_dir}/config.json", "w") as f:
+                        json.dump(response.json(), f)
+
+                    # Now try loading from the HF model ID again
+                    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
+                        f"PascalNotin/Tranception_{model_type}",
+                        cache_dir=cache_dir,
+                        local_files_only=False
+                    )
+                    MODEL_CACHE[model_type] = model
+                    print(f"{model_type} model loaded successfully after manual config download")
+                    return model
+                else:
+                    print(f"Failed to download config: {response.status_code}")
+            except Exception as e3:
+                print(f"Manual download also failed: {e3}")
+
         # Fallback to Medium if requested model fails
         if model_type != "Medium":
             print("Falling back to Medium model...")
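The final fallback fetches config.json with plain requests before retrying from_pretrained with the bare model ID, which both confirms the Hub is reachable and primes a local copy of the config. The probe can be run on its own; the timeout is an addition for safety and the model size is an example value:

import json
import os
import requests

model_type = "Medium"  # example size
config_url = f"https://huggingface.co/PascalNotin/Tranception_{model_type}/raw/main/config.json"
resp = requests.get(config_url, timeout=30)  # timeout not in app.py; added here for safety
if resp.status_code == 200:
    local_model_dir = f"/tmp/Tranception_{model_type}"
    os.makedirs(local_model_dir, exist_ok=True)
    with open(f"{local_model_dir}/config.json", "w") as f:
        json.dump(resp.json(), f)  # re-serialize the parsed JSON, as app.py does
    print("config fetched and saved")
else:
    print(f"config fetch failed with HTTP {resp.status_code}")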