Christopher Román Jaimes
committed on
Commit
·
7f254c6
1
Parent(s):
2c6872f
feat: add bert model to predict house levels.
Browse files
app.py
CHANGED
@@ -10,11 +10,17 @@ import pandas as pd
|
|
10 |
import gradio as gr
|
11 |
# GLiNER Model
|
12 |
from gliner import GLiNER
|
|
|
|
|
13 |
|
14 |
-
# Load Model
|
15 |
model = GLiNER.from_pretrained("chris32/gliner_multi_pii_real_state-v2")
|
16 |
model.eval()
|
17 |
|
|
|
|
|
|
|
|
|
18 |
# Global Variables: For Post Cleaning Inferences
|
19 |
YEAR_OF_REMODELING_LIMIT = 100
|
20 |
CURRENT_YEAR = int(datetime.date.today().year)
|
@@ -189,6 +195,40 @@ threshols_dict = {
|
|
189 |
'NOMBRE_DESARROLLO': 0.9,
|
190 |
}
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
def generate_answer(text):
|
193 |
labels = [
|
194 |
'SUPERFICIE_JARDIN',
|
@@ -216,7 +256,16 @@ def generate_answer(text):
|
|
216 |
entity_prediction_cleaned = clean_prediction(entities_formatted, feature_name, threshols_dict, clean_functions_dict)
|
217 |
if isinstance(entity_prediction_cleaned, str) or isinstance(entity_prediction_cleaned, int):
|
218 |
entities_cleaned[feature_name] = entity_prediction_cleaned
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
result_json = json.dumps(entities_cleaned, indent = 4, ensure_ascii = False)
|
221 |
|
222 |
return result_json + "\n \n" + json.dumps(entities_formatted, indent = 4, ensure_ascii = False)
|
|
|
10 |
import gradio as gr
|
11 |
# GLiNER Model
|
12 |
from gliner import GLiNER
|
13 |
# Transformers: text-classification pipeline for the house-levels model.
from transformers import pipeline

# Load GLiNER Model (entity extraction for real-estate listings).
model = GLiNER.from_pretrained("chris32/gliner_multi_pii_real_state-v2")
model.eval()

# BERT Model: fine-tuned DistilBERT classifier, run on CPU.
model_name = "chris32/distilbert-base-spanish-uncased-finetuned-text-intelligence"
pipe = pipeline(model = model_name, device = "cpu")

# Global Variables: For Post Cleaning Inferences
YEAR_OF_REMODELING_LIMIT = 100
CURRENT_YEAR = int(datetime.date.today().year)
|
|
|
195 |
'NOMBRE_DESARROLLO': 0.9,
|
196 |
}
|
197 |
|
198 |
# Map the classifier's raw labels to a house-level count.
# LABEL_0 carries no usable level, hence None.
label_names_dict = dict(
    LABEL_0 = None,
    LABEL_1 = 1,
    LABEL_2 = 2,
    LABEL_3 = 3,
)

# Minimum confidence the BERT classifier must reach before its
# prediction is accepted downstream.
BERT_SCORE_LIMIT = 0.98
|
205 |
+
|
206 |
+
def extract_max_label_score(probabilities):
    """Return the (label, score) pair of the highest-scoring entry.

    `probabilities` is a non-empty list of dicts, each with a 'label'
    and a 'score' key (the per-class output of the HF pipeline).
    """
    top = max(probabilities, key=lambda entry: entry['score'])
    return top['label'], top['score']
|
214 |
+
|
215 |
+
def clean_prediction_bert(label, score):
    """Translate a raw BERT label into its numeric value.

    Returns None when the score does not exceed BERT_SCORE_LIMIT or
    when the label is unknown; otherwise the mapped value from
    label_names_dict (which itself may be None for LABEL_0).
    """
    if score <= BERT_SCORE_LIMIT:
        return None
    return label_names_dict.get(label)
|
221 |
+
|
222 |
# BERT Inference Config — passed as **kwargs to the pipeline call.
# NOTE(review): `return_all_scores` is deprecated in recent transformers
# releases in favor of `top_k=None` — confirm the installed version.
pipe_config = dict(
    batch_size = 8,
    truncation = True,
    max_length = 250,
    add_special_tokens = True,
    return_all_scores = True,
    padding = True,
)
|
231 |
+
|
232 |
def generate_answer(text):
|
233 |
labels = [
|
234 |
'SUPERFICIE_JARDIN',
|
|
|
256 |
entity_prediction_cleaned = clean_prediction(entities_formatted, feature_name, threshols_dict, clean_functions_dict)
|
257 |
if isinstance(entity_prediction_cleaned, str) or isinstance(entity_prediction_cleaned, int):
|
258 |
entities_cleaned[feature_name] = entity_prediction_cleaned
|
259 |
+
|
260 |
+
# BERT Inference
|
261 |
+
predictions = pipe([text], **pipe_config)
|
262 |
+
|
263 |
+
# Format Prediction
|
264 |
+
label, score = extract_max_label_score(predictions[0])
|
265 |
+
prediction_cleaned = clean_prediction_bert(label, score)
|
266 |
+
if isinstance(prediction_cleaned, int):
|
267 |
+
entities_cleaned["NIVELES_CASA"] = prediction_cleaned
|
268 |
+
|
269 |
result_json = json.dumps(entities_cleaned, indent = 4, ensure_ascii = False)
|
270 |
|
271 |
return result_json + "\n \n" + json.dumps(entities_formatted, indent = 4, ensure_ascii = False)
|