MoraxCheng committed
Commit 95230fb · 1 Parent(s): 5b1db8f

Enhance model loading with improved error handling and an alternative loading method; configure the Hugging Face endpoint and disable offline mode

Files changed (1)
app.py +37 -6
app.py CHANGED
@@ -9,6 +9,10 @@ import sys
 os.environ['HF_HOME'] = '/tmp/huggingface'
 os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface/transformers'
 os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
+# Ensure proper Hugging Face endpoint
+os.environ['HF_ENDPOINT'] = 'https://huggingface.co'
+# Disable offline mode to allow downloads
+os.environ['TRANSFORMERS_OFFLINE'] = '0'
 import torch
 import transformers
 from transformers import PreTrainedTokenizerFast
@@ -103,22 +107,49 @@ def load_model_cached(model_type):
         cache_dir = "/tmp/huggingface/transformers"
         os.makedirs(cache_dir, exist_ok=True)
 
+        # Clear any potential proxy issues
+        import requests
+        session = requests.Session()
+        session.trust_env = False
+
+        # Try loading with explicit parameters
         model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
             pretrained_model_name_or_path=model_path,
             cache_dir=cache_dir,
             local_files_only=False, # Allow downloading if not cached
-            resume_download=True # Resume incomplete downloads
+            resume_download=True, # Resume incomplete downloads
+            force_download=False, # Don't force re-download if cached
+            proxies=None, # Explicitly set no proxies
+            use_auth_token=None, # No auth token needed for public models
+            revision="main" # Use main branch
         )
         MODEL_CACHE[model_type] = model
         print(f"{model_type} model loaded and cached")
         return model
     except Exception as e:
         print(f"Error loading {model_type} model: {e}")
-        # Fallback to Medium if requested model fails
-        if model_type != "Medium":
-            print("Falling back to Medium model...")
-            return load_model_cached("Medium")
-        raise
+        print(f"Attempting alternative loading method...")
+
+        # Try alternative loading approach
+        try:
+            # Manually specify the full model ID
+            full_model_id = f"PascalNotin/Tranception_{model_type}"
+            model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
+                full_model_id,
+                cache_dir=cache_dir,
+                local_files_only=False,
+                trust_remote_code=True # Allow custom model code
+            )
+            MODEL_CACHE[model_type] = model
+            print(f"{model_type} model loaded successfully with alternative method")
+            return model
+        except Exception as e2:
+            print(f"Alternative loading also failed: {e2}")
+            # Fallback to Medium if requested model fails
+            if model_type != "Medium":
+                print("Falling back to Medium model...")
+                return load_model_cached("Medium")
+            raise
 
 AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
 tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
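
A note on the first hunk: the new environment variables only take effect because they are set before `import transformers` — `huggingface_hub` and `transformers` read `HF_ENDPOINT`, the cache paths, and `TRANSFORMERS_OFFLINE` into module-level constants at import time. A minimal standalone sketch of that ordering constraint (not part of the commit):

# Sketch: Hub-related environment variables must be exported before the
# libraries that read them are imported, which is why app.py sets them
# at the top of the file.
import os

os.environ['HF_HOME'] = '/tmp/huggingface'            # cache root for Hub downloads
os.environ['HF_ENDPOINT'] = 'https://huggingface.co'  # default Hub endpoint
os.environ['TRANSFORMERS_OFFLINE'] = '0'              # '1' would block all downloads

import transformers  # reads the variables above while importing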
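
A note on the second hunk: the added error handling forms a three-tier strategy — a direct load from `model_path`, a retry with the fully qualified Hub ID `PascalNotin/Tranception_<size>`, then a recursive fallback to the Medium checkpoint. A self-contained sketch of that control flow with the loader abstracted out (`load_with_fallback` and `loader` are illustrative names, not the Space's real code):

# Sketch: cache -> direct load -> qualified-ID retry -> Medium fallback.
MODEL_CACHE = {}

def load_with_fallback(model_type, loader):
    if model_type in MODEL_CACHE:          # serve the cached instance first
        return MODEL_CACHE[model_type]
    try:
        model = loader(model_type)         # tier 1: direct load
    except Exception:
        try:
            # tier 2: retry with the fully qualified Hub ID
            model = loader(f"PascalNotin/Tranception_{model_type}")
        except Exception:
            # tier 3: fall back to the Medium checkpoint, or give up
            if model_type != "Medium":
                return load_with_fallback("Medium", loader)
            raise                          # Medium itself failed; re-raise
    MODEL_CACHE[model_type] = model
    return model

With this shape, a call like `load_with_fallback("Large", loader)` degrades to Medium only after both Large attempts fail, and a repeat call for the same size is served from `MODEL_CACHE` without touching the network. One caveat worth flagging in review: on recent `transformers`/`huggingface_hub` releases, `resume_download` is deprecated (interrupted downloads always resume) and `use_auth_token` is superseded by `token`, so the first `from_pretrained` call may log deprecation warnings.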