cyberandy committed on
Commit
25d948c
·
verified ·
1 Parent(s): fdaf3eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -75,52 +75,46 @@ def load_model(selected_language, model_name=None, entity_set=None):
75
  # Suppress warnings during model loading
76
  with warnings.catch_warnings():
77
  warnings.simplefilter("ignore")
78
-
79
  try:
80
- if selected_language == "German":
81
- # Download and load the German-specific model
82
- try:
83
- nlp_model_de = spacy.load("de_core_news_lg")
84
- except OSError:
85
- st.info("Downloading German language model... This may take a moment.")
86
- spacy.cli.download("de_core_news_lg")
87
- nlp_model_de = spacy.load("de_core_news_lg")
88
-
89
- # Check if entityfishing component is available
90
- if "entityfishing" not in nlp_model_de.pipe_names:
91
- try:
92
- nlp_model_de.add_pipe("entityfishing")
93
- except Exception as e:
94
- st.warning(f"Entity-fishing not available, using basic NER only: {e}")
95
- # Return model without entityfishing for basic NER
96
- return nlp_model_de
97
-
98
- return nlp_model_de
99
-
100
- elif selected_language == "English - spaCy":
101
- # Download and load English-specific model
102
- try:
103
- nlp_model_en = spacy.load("en_core_web_sm")
104
- except OSError:
105
- st.info("Downloading English language model... This may take a moment.")
106
- spacy.cli.download("en_core_web_sm")
107
- nlp_model_en = spacy.load("en_core_web_sm")
108
-
109
- # Check if entityfishing component is available
110
- if "entityfishing" not in nlp_model_en.pipe_names:
111
- try:
112
- nlp_model_en.add_pipe("entityfishing")
113
- except Exception as e:
114
- st.warning(f"Entity-fishing not available, using basic NER only: {e}")
115
- # Return model without entityfishing for basic NER
116
- return nlp_model_en
117
-
118
- return nlp_model_en
119
  else:
120
- # Load the pretrained model for other languages
121
- refined_model = Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
122
- return refined_model
123
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  except Exception as e:
125
  st.error(f"Error loading model: {e}")
126
  return None
 
75
  # Suppress warnings during model loading
76
  with warnings.catch_warnings():
77
  warnings.simplefilter("ignore")
78
+
79
  try:
80
+ if selected_language == "German" or selected_language == "English - spaCy":
81
+ # ... (your existing spaCy loading logic)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  else:
83
+ try:
84
+ # Attempt to load the pretrained model directly
85
+ refined_model = Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
86
+ return refined_model
87
+ except AttributeError as e:
88
+ if "add_special_tokens" in str(e):
89
+ st.warning("Encountered 'add_special_tokens' conflict. Attempting to fix by modifying tokenizer config...")
90
+ # Define a local directory to save the model
91
+ local_model_dir = f"./{model_name}_{entity_set}"
92
+
93
+ # Download and save the tokenizer, then modify its config
94
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
95
+ tokenizer.save_pretrained(local_model_dir)
96
+
97
+ # Load the tokenizer_config.json and remove the conflicting key
98
+ tokenizer_config_path = f"{local_model_dir}/tokenizer_config.json"
99
+ with open(tokenizer_config_path, 'r') as f:
100
+ config = json.load(f)
101
+
102
+ if "add_special_tokens" in config:
103
+ del config["add_special_tokens"]
104
+
105
+ with open(tokenizer_config_path, 'w') as f:
106
+ json.dump(config, f, indent=2)
107
+
108
+ # Download and save the model
109
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
110
+ model.save_pretrained(local_model_dir)
111
+
112
+ # Load the model from the modified local directory
113
+ refined_model = Refined.from_pretrained(model_name=local_model_dir, entity_set=entity_set)
114
+ st.success("Successfully loaded model after applying fix.")
115
+ return refined_model
116
+ else:
117
+ raise e # Re-raise other AttributeError exceptions
118
  except Exception as e:
119
  st.error(f"Error loading model: {e}")
120
  return None