cyberandy committed (verified)
Commit ad1e1d2 · 1 Parent(s): 25d948c

Update app.py

Files changed (1)
  1. app.py +63 -34
app.py CHANGED
@@ -75,46 +75,75 @@ def load_model(selected_language, model_name=None, entity_set=None):
     # Suppress warnings during model loading
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-
+
         try:
-            if selected_language == "German" or selected_language == "English - spaCy":
-                # ... (your existing spaCy loading logic)
+            # This block handles the spaCy models for German and English
+            if selected_language == "German":
+                try:
+                    nlp_model_de = spacy.load("de_core_news_lg")
+                except OSError:
+                    st.info("Downloading German language model... This may take a moment.")
+                    spacy.cli.download("de_core_news_lg")
+                    nlp_model_de = spacy.load("de_core_news_lg")
+
+                if "entityfishing" not in nlp_model_de.pipe_names:
+                    try:
+                        nlp_model_de.add_pipe("entityfishing")
+                    except Exception as e:
+                        st.warning(f"Entity-fishing not available, using basic NER only: {e}")
+                return nlp_model_de
+
+            elif selected_language == "English - spaCy":
+                try:
+                    nlp_model_en = spacy.load("en_core_web_sm")
+                except OSError:
+                    st.info("Downloading English language model... This may take a moment.")
+                    spacy.cli.download("en_core_web_sm")
+                    nlp_model_en = spacy.load("en_core_web_sm")
+
+                if "entityfishing" not in nlp_model_en.pipe_names:
+                    try:
+                        nlp_model_en.add_pipe("entityfishing")
+                    except Exception as e:
+                        st.warning(f"Entity-fishing not available, using basic NER only: {e}")
+                return nlp_model_en
+
+            # This block handles the ReFinED model and the "add_special_tokens" error
             else:
                 try:
-                    # Attempt to load the pretrained model directly
-                    refined_model = Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
-                    return refined_model
-                except AttributeError as e:
+                    # First, attempt to load the model as usual
+                    return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
+
+                except Exception as e:
+                    # If the specific "add_special_tokens" error occurs, apply the fix
                     if "add_special_tokens" in str(e):
-                        st.warning("Encountered 'add_special_tokens' conflict. Attempting to fix by modifying tokenizer config...")
-                        # Define a local directory to save the model
-                        local_model_dir = f"./{model_name}_{entity_set}"
-
-                        # Download and save the tokenizer, then modify its config
+                        st.warning("Conflict detected. Applying fix by modifying tokenizer config...")
+
+                        # Define a local path to save/load the fixed model
+                        local_model_path = f"./{model_name}-{entity_set}-fixed"
+
+                        # Download tokenizer, modify config, and save locally
                         tokenizer = AutoTokenizer.from_pretrained(model_name)
-                        tokenizer.save_pretrained(local_model_dir)
-
-                        # Load the tokenizer_config.json and remove the conflicting key
-                        tokenizer_config_path = f"{local_model_dir}/tokenizer_config.json"
-                        with open(tokenizer_config_path, 'r') as f:
-                            config = json.load(f)
-
-                        if "add_special_tokens" in config:
-                            del config["add_special_tokens"]
-
-                        with open(tokenizer_config_path, 'w') as f:
-                            json.dump(config, f, indent=2)
-
-                        # Download and save the model
-                        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-                        model.save_pretrained(local_model_dir)
-
-                        # Load the model from the modified local directory
-                        refined_model = Refined.from_pretrained(model_name=local_model_dir, entity_set=entity_set)
-                        st.success("Successfully loaded model after applying fix.")
-                        return refined_model
+                        tokenizer.save_pretrained(local_model_path)
+
+                        config_path = os.path.join(local_model_path, "tokenizer_config.json")
+                        with open(config_path, "r") as f:
+                            config_data = json.load(f)
+
+                        # Remove the conflicting parameter
+                        config_data.pop("add_special_tokens", None)
+
+                        with open(config_path, "w") as f:
+                            json.dump(config_data, f, indent=2)
+
+                        # Now, load the model from the local, fixed path
+                        st.success("Fix applied. Loading model from local cache.")
+                        return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set)
+
                     else:
-                        raise e # Re-raise other AttributeError exceptions
+                        # If it's a different error, raise it
+                        raise e
+
         except Exception as e:
             st.error(f"Error loading model: {e}")
             return None
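
For the two spaCy branches, the commit also guards the entityfishing pipe so the app degrades to plain NER when the component (provided by the spacyfishing extension) cannot be added. Below is a minimal sketch of how such a pipeline is typically consumed outside Streamlit; the sample sentence and the Wikidata extension attributes (kb_qid, url_wikidata) are assumptions, not part of this commit.

import spacy

# Mirrors the German branch of load_model: load the model, then try to attach entity-fishing.
nlp = spacy.load("de_core_news_lg")
try:
    nlp.add_pipe("entityfishing")  # registered by the spacyfishing extension, if installed
except Exception as e:
    print(f"entityfishing unavailable, falling back to plain NER: {e}")

doc = nlp("Angela Merkel war Bundeskanzlerin der Bundesrepublik Deutschland.")
for ent in doc.ents:
    # kb_qid / url_wikidata are only set when the entityfishing component is present.
    qid = getattr(ent._, "kb_qid", None)
    url = getattr(ent._, "url_wikidata", None)
    print(ent.text, ent.label_, qid, url)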
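
The core of the ReFinED change is the tokenizer workaround: when loading fails with an "add_special_tokens" conflict, the commit saves the tokenizer locally, deletes that key from tokenizer_config.json, and retries loading from the patched directory. A minimal sketch of the patching step on its own, assuming a hypothetical helper name (strip_add_special_tokens) that does not appear in this commit:

import json
import os

def strip_add_special_tokens(tokenizer_dir: str) -> bool:
    # Hypothetical helper: remove the conflicting "add_special_tokens" entry from a
    # saved tokenizer_config.json so ReFinED can re-load the tokenizer cleanly.
    # Returns True if the key was present and removed.
    config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
    with open(config_path, "r") as f:
        config_data = json.load(f)

    removed = config_data.pop("add_special_tokens", None) is not None

    with open(config_path, "w") as f:
        json.dump(config_data, f, indent=2)
    return removed

With a helper like this, the except branch reduces to AutoTokenizer.from_pretrained(model_name).save_pretrained(local_model_path), a call to strip_add_special_tokens(local_model_path), and a retry of Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set), which is exactly the sequence the diff inlines.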