Update app.py
app.py
CHANGED
@@ -56,14 +56,22 @@ class NERLabelEncoder:
 NER_CHECKPOINT = "microsoft/deberta-base"
 NER_N_TOKENS = 50
 NER_N_LABELS = 18
-ner_model = TFAutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT, num_labels=NER_N_LABELS, attention_probs_dropout_prob=0.4, hidden_dropout_prob=0.4)
-ner_model.load_weights(os.path.join("models", "general_ner_deberta_weights.h5"), by_name=True)
-ner_label_encoder = NERLabelEncoder()
-ner_label_encoder.fit()
-ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
-nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
 NER_COLOR_MAP = {'GEO': '#DFFF00', 'GPE': '#FFBF00', 'PER': '#9FE2BF',
                  'ORG': '#40E0D0', 'TIM': '#CCCCFF', 'ART': '#FFC0CB', 'NAT': '#FFE4B5', 'EVE': '#DCDCDC'}
+
+@st.cache_resource
+def load_ner_models():
+    ner_model = TFAutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT, num_labels=NER_N_LABELS, attention_probs_dropout_prob=0.4, hidden_dropout_prob=0.4)
+    ner_model.load_weights(os.path.join("models", "general_ner_deberta_weights.h5"), by_name=True)
+    ner_label_encoder = NERLabelEncoder()
+    ner_label_encoder.fit()
+    ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
+    nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
+    print('Loaded NER models')
+    return ner_model, ner_label_encoder, ner_tokenizer, nlp
+
+ner_model, ner_label_encoder, ner_tokenizer, nlp = load_ner_models()
+
 ############ NER MODEL & VARS INITIALIZATION END ####################
 
 ############ NER LOGIC START ####################
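Note: the point of this hunk is to move the heavy NER model/tokenizer construction behind @st.cache_resource, so Streamlit builds these objects once per server process instead of on every script rerun. A minimal, self-contained sketch of that behavior, assuming Streamlit >= 1.18 where st.cache_resource is available (load_heavy_resource and the sleep are hypothetical stand-ins, not part of app.py):

import time
import streamlit as st

@st.cache_resource
def load_heavy_resource():
    # Stand-in for expensive model loading: runs once per process;
    # later reruns reuse the cached return value instead of re-executing.
    time.sleep(2)
    return {"loaded_at": time.time()}

resource = load_heavy_resource()
st.write("Resource first loaded at:", resource["loaded_at"])

Because cache_resource hands back the same object rather than a copy, the cached ner_model, tokenizer, and spaCy pipeline are shared across sessions, which is the usual choice for read-only models.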
@@ -170,9 +178,16 @@ def get_ner_text(article_txt, ner_result):
 SUMM_CHECKPOINT = "facebook/bart-base"
 SUMM_INPUT_N_TOKENS = 400
 SUMM_TARGET_N_TOKENS = 100
-
-
-
+
+@st.cache_resource
+def load_summarizer_models():
+    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
+    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
+    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
+    print('Loaded summarizer models')
+    return summ_tokenizer, summ_model
+
+summ_tokenizer, summ_model = load_summarizer_models()
 
 def summ_preprocess(txt):
     txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard .
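Note: the summarizer gets the same treatment: the BART tokenizer and model are built once inside a cached loader, and the (summ_tokenizer, summ_model) pair is reused by the functions below. The app's own summ_inference() sits outside this diff, so as a hedged sketch only, inference with the cached pair would look roughly like this (summarize() is a hypothetical helper; the names it uses come from the loader and constants above):

def summarize(text: str) -> str:
    # Tokenize to TF tensors, bounded by the app's input token budget.
    inputs = summ_tokenizer(text, max_length=SUMM_INPUT_N_TOKENS,
                            truncation=True, return_tensors="tf")
    # Generate a summary no longer than the target token budget.
    output_ids = summ_model.generate(inputs["input_ids"], max_length=SUMM_TARGET_N_TOKENS)
    return summ_tokenizer.decode(output_ids[0], skip_special_tokens=True)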
@@ -190,7 +205,6 @@ def summ_preprocess(txt):
     return txt
 
 def summ_inference_tokenize(input_: list, n_tokens: int):
-    # tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
     tokenized_data = summ_tokenizer(text=input_, max_length=SUMM_TARGET_N_TOKENS, truncation=True, padding="max_length", return_tensors="tf")
     return summ_tokenizer, tokenized_data
 
@@ -207,7 +221,7 @@ def summ_inference(txt: str):
 ############## ENTRY POINT START #######################
 def main():
     st.title("News Summarizer & NER")
-    article_txt = st.text_area("Paste
+    article_txt = st.text_area("Paste few sentences of a news article:", "", height=200)
     if st.button("Submit"):
         ner_result = [[ent, label.upper(), np.round(prob, 3)]
                       for ent, label, prob in ner_inference_long_text(article_txt)]
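Note on the design choice: st.cache_resource keys the cache on the decorated function's code and arguments, so if the weight files under models/ are swapped while the app is running, the loaders will keep serving the previously cached objects. A hedged sketch of how they could be invalidated explicitly if that ever matters (Streamlit attaches .clear() to cached functions):

load_ner_models.clear()          # drop the cached NER bundle
load_summarizer_models.clear()   # drop the cached summarizer bundle
st.cache_resource.clear()        # or clear every cache_resource entry at once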