MoraxCheng committed on
Commit b55bd43 · 1 Parent(s): 95230fb

Add patch for transformers URL handling and enhance model loading with manual config download

Files changed (1): app.py +64 -22
app.py CHANGED
@@ -13,8 +13,30 @@ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
 os.environ['HF_ENDPOINT'] = 'https://huggingface.co'
 # Disable offline mode to allow downloads
 os.environ['TRANSFORMERS_OFFLINE'] = '0'
+
+# Patch for transformers 4.17.0 URL issue in HF Spaces
+import urllib.parse
+
+def patch_transformers_url():
+    """Fix URL scheme issue in transformers 4.17.0"""
+    try:
+        import transformers.file_utils
+        original_get_from_cache = transformers.file_utils.get_from_cache
+
+        def patched_get_from_cache(url, *args, **kwargs):
+            # Fix URLs that start with /api/ by prepending https://huggingface.co
+            if isinstance(url, str) and url.startswith('/api/'):
+                url = 'https://huggingface.co' + url
+            return original_get_from_cache(url, *args, **kwargs)
+
+        transformers.file_utils.get_from_cache = patched_get_from_cache
+        print("Applied URL patch for transformers")
+    except Exception as e:
+        print(f"Warning: Could not patch transformers URL handling: {e}")
+
 import torch
 import transformers
+patch_transformers_url()
 from transformers import PreTrainedTokenizerFast
 import numpy as np
 import pandas as pd
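The hunk above wraps transformers' internal get_from_cache rather than editing the library: any scheme-less '/api/...' path is rewritten to an absolute Hub URL before the request goes out, and the wrapper is installed after `import transformers` but before any `from_pretrained` call can trigger a download. (The added `import urllib.parse` appears unused.) A minimal standalone sketch of the same wrapping technique, where fetch() is a hypothetical stand-in for the real downloader:

def fetch(url):
    # Hypothetical stand-in for transformers.file_utils.get_from_cache
    return url

def with_absolute_urls(downloader, host="https://huggingface.co"):
    # Wrap a downloader so scheme-less "/api/..." paths become full URLs
    def wrapper(url, *args, **kwargs):
        if isinstance(url, str) and url.startswith("/api/"):
            url = host + url
        return downloader(url, *args, **kwargs)
    return wrapper

fetch = with_absolute_urls(fetch)
print(fetch("/api/models/PascalNotin/Tranception_Medium"))
# https://huggingface.co/api/models/PascalNotin/Tranception_Medium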
 
@@ -107,21 +129,10 @@ def load_model_cached(model_type):
         cache_dir = "/tmp/huggingface/transformers"
         os.makedirs(cache_dir, exist_ok=True)
 
-        # Clear any potential proxy issues
-        import requests
-        session = requests.Session()
-        session.trust_env = False
-
-        # Try loading with explicit parameters
+        # Try loading with minimal parameters first
         model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
-            pretrained_model_name_or_path=model_path,
-            cache_dir=cache_dir,
-            local_files_only=False,  # Allow downloading if not cached
-            resume_download=True,  # Resume incomplete downloads
-            force_download=False,  # Don't force re-download if cached
-            proxies=None,  # Explicitly set no proxies
-            use_auth_token=None,  # No auth token needed for public models
-            revision="main"  # Use main branch
+            model_path,
+            cache_dir=cache_dir
         )
         MODEL_CACHE[model_type] = model
         print(f"{model_type} model loaded and cached")
 
@@ -130,21 +141,52 @@ def load_model_cached(model_type):
         print(f"Error loading {model_type} model: {e}")
         print(f"Attempting alternative loading method...")
 
-        # Try alternative loading approach
+        # Try alternative loading approach with full URL
         try:
-            # Manually specify the full model ID
-            full_model_id = f"PascalNotin/Tranception_{model_type}"
+            # Use full URL to bypass any path resolution issues
+            full_url = f"https://huggingface.co/PascalNotin/Tranception_{model_type}"
             model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
-                full_model_id,
-                cache_dir=cache_dir,
-                local_files_only=False,
-                trust_remote_code=True  # Allow custom model code
+                full_url,
+                cache_dir=cache_dir
             )
             MODEL_CACHE[model_type] = model
-            print(f"{model_type} model loaded successfully with alternative method")
+            print(f"{model_type} model loaded successfully with full URL")
             return model
         except Exception as e2:
             print(f"Alternative loading also failed: {e2}")
+
+            # Final attempt: manually download config first
+            try:
+                import json
+                import requests
+
+                # Download config.json manually
+                config_url = f"https://huggingface.co/PascalNotin/Tranception_{model_type}/raw/main/config.json"
+                print(f"Manually downloading config from: {config_url}")
+
+                response = requests.get(config_url)
+                if response.status_code == 200:
+                    # Save config locally
+                    local_model_dir = f"/tmp/Tranception_{model_type}"
+                    os.makedirs(local_model_dir, exist_ok=True)
+
+                    with open(f"{local_model_dir}/config.json", "w") as f:
+                        json.dump(response.json(), f)
+
+                    # Now try loading from the HF model ID again
+                    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
+                        f"PascalNotin/Tranception_{model_type}",
+                        cache_dir=cache_dir,
+                        local_files_only=False
+                    )
+                    MODEL_CACHE[model_type] = model
+                    print(f"{model_type} model loaded successfully after manual config download")
+                    return model
+                else:
+                    print(f"Failed to download config: {response.status_code}")
+            except Exception as e3:
+                print(f"Manual download also failed: {e3}")
+
         # Fallback to Medium if requested model fails
         if model_type != "Medium":
             print("Falling back to Medium model...")