MoraxCheng commited on
Commit
3c96d15
·
1 Parent(s): 634f0ae

Enhance URL handling in transformers with comprehensive validation and retry mechanism; implement cache file validation and cleaning process

Browse files
Files changed (1) hide show
  1. app.py +108 -22
app.py CHANGED
@@ -16,22 +16,59 @@ os.environ['TRANSFORMERS_OFFLINE'] = '0'
16
 
17
  # Patch for transformers 4.17.0 URL issue in HF Spaces
18
  import urllib.parse
 
 
19
 
20
def patch_transformers_url():
    """Fix URL scheme issue in transformers 4.17.0.

    Wraps ``transformers.file_utils.get_from_cache`` so that relative
    ``/api/...`` URLs are resolved against https://huggingface.co before
    the download is attempted.  Any failure while installing the patch is
    reported as a warning and otherwise ignored.
    """
    try:
        import transformers.file_utils

        download_fn = transformers.file_utils.get_from_cache

        def get_from_cache_with_fixed_url(url, *args, **kwargs):
            # Only relative '/api/...' strings need fixing; everything
            # else is passed through untouched.
            fixed = url
            if isinstance(url, str) and url.startswith('/api/'):
                # urljoin keeps the path/query intact while adding the host.
                fixed = urllib.parse.urljoin('https://huggingface.co', url)
            return download_fn(fixed, *args, **kwargs)

        transformers.file_utils.get_from_cache = get_from_cache_with_fixed_url
        print("Applied URL patch for transformers")
    except Exception as e:
        print(f"Warning: Could not patch transformers URL handling: {e}")
37
 
@@ -106,6 +143,56 @@ from tranception import config, model_pytorch
106
  # Model loading configuration
107
  MODEL_CACHE = {}
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def get_model_path(model_name):
110
  """Get model path - always use HF Hub for Zero GPU spaces"""
111
  # In HF Spaces, models are cached automatically by the transformers library
@@ -187,26 +274,20 @@ def load_model_cached(model_type):
187
  model_path = get_model_path(model_name)
188
 
189
  try:
190
- # Clear any corrupted cache files
191
  import shutil
192
  cache_dir = "/tmp/huggingface/transformers"
193
- if os.path.exists(cache_dir):
194
- # Remove corrupted tranception cache files
195
- for file in os.listdir(cache_dir):
196
- if "tranception" in file.lower():
197
- try:
198
- filepath = os.path.join(cache_dir, file)
199
- if os.path.isfile(filepath) and os.path.getsize(filepath) < 1000:
200
- os.remove(filepath)
201
- print(f"Removed corrupted cache file: {file}")
202
- except:
203
- pass
204
-
205
  os.makedirs(cache_dir, exist_ok=True)
206
 
207
- # Try loading with force_download to avoid corrupted cache
208
- # Use HF_ENDPOINT environment variable to ensure proper URL
 
 
 
209
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
 
 
 
210
 
211
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
212
  model_path,
@@ -220,6 +301,11 @@ def load_model_cached(model_type):
220
  return model
221
  except Exception as e:
222
  print(f"Error loading {model_type} model: {e}")
 
 
 
 
 
223
  print(f"Attempting alternative loading method...")
224
 
225
  # Try alternative loading approach with full URL
 
16
 
17
  # Patch for transformers 4.17.0 URL issue in HF Spaces
18
  import urllib.parse
19
+ import json
20
+ import time
21
 
22
def patch_transformers_url():
    """Fix URL scheme issue in transformers 4.17.0 with comprehensive URL handling.

    Wraps ``transformers.file_utils.get_from_cache`` so that every URL it
    receives is first normalized to an absolute URL (see
    ``normalize_hf_url``) and transient download failures are retried with
    exponential backoff.  Failure to install the patch is non-fatal: a
    warning is printed and the library keeps its original behavior.
    """
    try:
        import transformers.file_utils
        original_get_from_cache = transformers.file_utils.get_from_cache

        def patched_get_from_cache(url, *args, **kwargs):
            if isinstance(url, str):
                url = normalize_hf_url(url)

            # Retry mechanism for flaky network requests.  NOTE(review):
            # this retries on *any* Exception, as the original did — a 404
            # will also be retried; narrowing to network errors would need
            # knowledge of which exceptions get_from_cache raises.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    return original_get_from_cache(url, *args, **kwargs)
                except Exception as e:
                    if attempt < max_retries - 1:
                        print(f"Download attempt {attempt + 1} failed for {url}: {e}. Retrying...")
                        time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s
                    else:
                        print(f"All download attempts failed for {url}: {e}")
                        raise

        transformers.file_utils.get_from_cache = patched_get_from_cache
        print("Applied enhanced URL patch for transformers")
    except Exception as e:
        print(f"Warning: Could not patch transformers URL handling: {e}")


def normalize_hf_url(url):
    """Return *url* as an absolute URL, anchored at huggingface.co when relative.

    Handles the malformed URL shapes seen from transformers 4.17.0 in HF
    Spaces: relative API paths ('/api/...'), protocol-relative URLs
    ('//host/...'), and bare relative paths.  Well-formed http(s) URLs are
    returned unchanged.

    Bug fix vs. the previous inline version: the urlparse-failure fallback
    used ``if not url.startswith('https://')`` and therefore prepended the
    HF host even to valid ``http://`` URLs, producing garbage like
    ``https://huggingface.cohttp://...``.
    """
    if url.startswith('/api/'):
        # Relative API URL - anchor at the HF endpoint.
        return 'https://huggingface.co' + url
    if url.startswith('//'):
        # Protocol-relative URL - assume https.
        return 'https:' + url
    if not url.startswith(('http://', 'https://')):
        # Any other relative path - anchor at the HF endpoint.
        return 'https://huggingface.co' + (url if url.startswith('/') else '/' + url)
    # Already absolute: verify it actually carries a host component.
    try:
        if not urllib.parse.urlparse(url).netloc:
            return 'https://huggingface.co' + (url if url.startswith('/') else '/' + url)
    except ValueError:
        # urlparse rejected it; leave the absolute-looking URL as-is rather
        # than mangling a scheme-bearing string.
        pass
    return url
74
 
 
143
  # Model loading configuration
144
  MODEL_CACHE = {}
145
 
146
def validate_cache_file(file_path, min_size=1000):
    """Validate cache file integrity and content.

    Returns a ``(valid, reason)`` tuple.  JSON/config files are judged by
    whether they actually parse as JSON; all other files are judged by a
    minimum-size heuristic (*min_size* bytes), since truncated downloads
    are typically tiny error pages.

    Bug fixes vs. the previous version:
    - the config check used ``'config' in file_path.lower()`` on the FULL
      path, so any file under a directory containing "config" was
      JSON-validated (and then deleted by the caller); now only the
      basename is inspected.
    - the size check ran before the JSON check, so a small-but-valid
      ``config.json`` (often well under 1000 bytes) was misclassified as
      corrupted; JSON files are now validated by parseability alone.
    """
    if not os.path.exists(file_path):
        return False, "File does not exist"

    name = os.path.basename(file_path).lower()
    looks_like_json = name.endswith('.json') or 'config' in name

    if looks_like_json:
        # Parseability, not size, decides validity for JSON-ish files.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
            if not content:
                return False, "Empty JSON file"
            json.loads(content)  # Validate JSON syntax; raises on garbage
            return True, "Valid JSON file"
        except json.JSONDecodeError:
            return False, "Invalid JSON content"
        except Exception as e:
            return False, f"Cannot read JSON file: {e}"

    # Binary/other files: a truncated download is usually tiny.
    try:
        file_size = os.path.getsize(file_path)
    except Exception as e:
        return False, f"Cannot get file size: {e}"
    if file_size < min_size:
        return False, f"File too small ({file_size} bytes < {min_size})"
    return True, "File appears valid"
+
175
def clean_corrupted_cache_files(cache_dir):
    """Delete cache files in *cache_dir* that fail integrity validation.

    Every regular file is checked with ``validate_cache_file``; invalid
    ones are removed best-effort (removal errors are only reported).  A
    summary line is printed when at least one file was deleted.
    """
    if not os.path.exists(cache_dir):
        return

    cleaned_count = 0
    for file in os.listdir(cache_dir):
        filepath = os.path.join(cache_dir, file)
        if not os.path.isfile(filepath):
            continue
        valid, reason = validate_cache_file(filepath)
        if valid:
            continue
        try:
            os.remove(filepath)
        except Exception as e:
            print(f"Could not remove {file}: {e}")
        else:
            print(f"Removed corrupted cache file: {file} ({reason})")
            cleaned_count += 1

    if cleaned_count > 0:
        print(f"Cleaned {cleaned_count} corrupted cache files")
+
196
  def get_model_path(model_name):
197
  """Get model path - always use HF Hub for Zero GPU spaces"""
198
  # In HF Spaces, models are cached automatically by the transformers library
 
274
  model_path = get_model_path(model_name)
275
 
276
  try:
277
+ # Enhanced cache cleaning with validation
278
  import shutil
279
  cache_dir = "/tmp/huggingface/transformers"
 
 
 
 
 
 
 
 
 
 
 
 
280
  os.makedirs(cache_dir, exist_ok=True)
281
 
282
+ # Clean corrupted cache files using the new validation system
283
+ print("Validating and cleaning cache files...")
284
+ clean_corrupted_cache_files(cache_dir)
285
+
286
+ # Enhanced environment setup for robust model loading
287
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
288
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
289
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
290
+ os.environ["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = "1"
291
 
292
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
293
  model_path,
 
301
  return model
302
  except Exception as e:
303
  print(f"Error loading {model_type} model: {e}")
304
+ print(f"Error type: {type(e).__name__}")
305
+ if hasattr(e, '__cause__') and e.__cause__:
306
+ print(f"Root cause: {e.__cause__}")
307
+ print(f"Model path used: {model_path}")
308
+ print(f"Cache directory: {cache_dir}")
309
  print(f"Attempting alternative loading method...")
310
 
311
  # Try alternative loading approach with full URL