# NER / app.py
# Last commit: "Update app.py" (5bf38c0, verified) by Shenuki; file size 5.21 kB.
# app.py
import os
import requests
import wikipedia
import gradio as gr
import torch
from transformers import (
SeamlessM4TProcessor,
SeamlessM4TForTextToText,
pipeline as hf_pipeline
)
# ----------------------------------------------------------------------
# 1) SeamlessM4T text-to-text translation model.
#    Loaded once at import time and shared by every request.
MODEL_NAME = "facebook/hf-seamless-m4t-medium"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

processor = SeamlessM4TProcessor.from_pretrained(MODEL_NAME)
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME)
m4t_model = m4t_model.to(device)
m4t_model.eval()  # inference mode (e.g. no dropout)
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
    """Translate ``text`` with the SeamlessM4T text-to-text model.

    Args:
        text: Text to translate. Empty or whitespace-only input returns ``""``
            without invoking the model.
        src_iso3: Three-letter source language code (e.g. ``"eng"``,
            ``"fra"``). Ignored when ``auto_detect`` is true.
        tgt_iso3: Three-letter target language code.
        auto_detect: When true, ``src_lang=None`` is passed to the processor
            instead of ``src_iso3``.

    Returns:
        The decoded translation of the first generated sequence, with special
        tokens stripped.
    """
    # Nothing to translate: skip the expensive encode/generate/decode cycle.
    if not text or not text.strip():
        return ""
    # NOTE(review): SeamlessM4T performs no language identification itself;
    # with src_lang=None the processor presumably keeps whatever source
    # language its tokenizer was last configured with. Confirm this is the
    # intended "auto-detect" behaviour.
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
# ----------------------------------------------------------------------
# 2) BERT-based named-entity recognition.
# aggregation_strategy="simple" is the documented replacement for the
# deprecated grouped_entities=True. Word pieces belonging to one entity are
# merged, and each result carries an "entity_group" label plus the merged
# "word".
ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple",
)
# ----------------------------------------------------------------------
# 3) Geocoding & POIs via OpenStreetMap
def geocode(place: str, timeout: float = 10.0):
    """Resolve a place name to coordinates with the Nominatim search API.

    Args:
        place: Free-form place name (e.g. an NER ``LOC`` entity).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        ``(lat, lon)`` as floats for the top match, or ``None`` when ``place``
        is blank or Nominatim returns no match.

    Raises:
        requests.RequestException: on connection failures, timeouts, or a
            non-2xx HTTP status.
    """
    if not place or not place.strip():
        return None
    resp = requests.get(
        "https://nominatim.openstreetmap.org/search",
        params={"q": place, "format": "json", "limit": 1},
        headers={"User-Agent": "iVoiceContext/1.0"},
        # Without a timeout a stalled server blocks the request forever.
        timeout=timeout,
    )
    # Surface HTTP errors explicitly instead of failing later in .json().
    resp.raise_for_status()
    hits = resp.json()
    if not hits:
        return None
    top = hits[0]
    return float(top["lat"]), float(top["lon"])
def fetch_osm(lat, lon, osm_filter, limit=5, radius=1000, timeout=30):
    """Return named OpenStreetMap POIs near a point via the Overpass API.

    Args:
        lat: Latitude of the search centre.
        lon: Longitude of the search centre.
        osm_filter: Overpass tag filter appended to ``node``/``way``, e.g.
            ``'["amenity"="restaurant"]'``.
        limit: Maximum number of POIs to return; ``<= 0`` returns ``[]``
            without a network call.
        radius: Search radius passed to Overpass ``around`` (metres).
        timeout: Seconds to wait for the HTTP response. Kept above the
            server-side ``[timeout:25]`` so Overpass can answer first.

    Returns:
        A list of at most ``limit`` dicts of the form ``{"name": str}``.

    Raises:
        requests.RequestException: on connection failures, timeouts, or a
            non-2xx HTTP status (e.g. Overpass rate limiting).
    """
    if limit <= 0:
        return []
    # Require a "name" tag inside the query so the `out` limit counts only
    # POIs that can be displayed. Filtering afterwards could return fewer
    # than `limit` names even when more named POIs exist nearby.
    selector = f'{osm_filter}["name"]'
    around = f"(around:{radius},{lat},{lon})"
    query = (
        "[out:json][timeout:25];\n"
        "(\n"
        f"node{selector}{around};\n"
        f"way{selector}{around};\n"
        ");\n"
        f"out center {limit};\n"
    )
    resp = requests.post(
        "https://overpass-api.de/api/interpreter",
        data={"data": query},
        timeout=timeout,
    )
    resp.raise_for_status()
    elements = resp.json().get("elements", [])
    # Keep a post-filter as a safeguard against empty name values.
    return [
        {"name": e["tags"]["name"]}
        for e in elements
        if e.get("tags", {}).get("name")
    ]
# ----------------------------------------------------------------------
def _normalize_lang(code, default="eng"):
    """Return ``code`` trimmed and lower-cased, or ``default`` when blank."""
    code = (code or "").strip().lower()
    return code or default


def _location_context(place):
    """Geocode ``place`` and collect nearby restaurants and attractions."""
    geo = geocode(place)
    if not geo:
        return {"type": "location", "error": "could not geocode"}
    lat, lon = geo
    return {
        "type": "location",
        "restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
        "attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
    }


def _wiki_context(topic):
    """Fetch a two-sentence Wikipedia summary for ``topic`` (best effort)."""
    try:
        summary = wikipedia.summary(topic, sentences=2)
    except Exception:
        # Deliberately broad: any lookup failure degrades to a placeholder
        # instead of failing the whole request.
        summary = "No summary available."
    return {"type": "wiki", "summary": summary}


def _translate_results(results, output_lang):
    """Translate summaries and POI names in ``results`` in place."""
    def tr(s):
        return translate_m4t(s, "eng", output_lang, auto_detect=False)

    for info in results.values():
        if info["type"] == "wiki":
            info["summary"] = tr(info["summary"])
        elif info["type"] == "location":
            # A location that failed to geocode has only an "error" field
            # and no POI lists; skip it instead of raising KeyError.
            for key in ("restaurants", "attractions"):
                if key in info:
                    info[key] = [{"name": tr(poi["name"])} for poi in info[key]]


def get_context(text: str,
                source_lang: str,   # 3-letter code, e.g. "eng"
                output_lang: str,   # 3-letter code, e.g. "fra"
                auto_detect: bool) -> dict:
    """Find named entities in ``text`` and attach context to each one.

    Args:
        text: User text written in ``source_lang``.
        source_lang: ISO 639-3 code of ``text``. Surrounding whitespace and
            case are ignored; blank means ``"eng"``.
        output_lang: ISO 639-3 code for the returned text fields. Blank means
            ``"eng"`` (no translation).
        auto_detect: Forwarded to ``translate_m4t``. It also forces the
            translate-to-English step even when ``source_lang`` is ``"eng"``.

    Returns:
        ``{"error": "no entities found"}`` when ``text`` is blank or NER finds
        nothing. Otherwise, a dict keyed by entity text. Each value is either
        ``{"type": "wiki", "summary": ...}`` or ``{"type": "location", ...}``
        holding ``restaurants``/``attractions`` lists or an ``error``.
    """
    if not text or not text.strip():
        return {"error": "no entities found"}
    source_lang = _normalize_lang(source_lang)
    output_lang = _normalize_lang(output_lang)

    # 1) The NER model works on English, so translate first when needed.
    if auto_detect or source_lang != "eng":
        en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en_text = text

    # 2) Extract entities. A repeated surface form keeps its last label.
    entities = {ent["word"]: ent["entity_group"] for ent in ner(en_text)}
    if not entities:
        return {"error": "no entities found"}

    results = {}
    for word, label in entities.items():
        if label == "LOC":
            results[word] = _location_context(word)
        else:
            # Every non-LOC label (person, org, misc) -> Wikipedia
            results[word] = _wiki_context(word)

    # 3) Translate all user-facing text fields into output_lang.
    if output_lang != "eng":
        _translate_results(results, output_lang)
    return results
# ----------------------------------------------------------------------
# Gradio UI: one JSON output produced by get_context.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Translate + Context-Aware",
    description=(
        "1) Translate your text → English (if needed)\n"
        "2) Run BERT-NER on English to find LOC/PERSON/ORG/MISC\n"
        "3) Geocode LOC → fetch nearby restaurants & attractions\n"
        "4) Fetch Wikipedia summaries for PERSON/ORG/MISC\n"
        "5) Translate **all** results → your target language"
    ),
).queue()

if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        # NOTE(review): share=True asks Gradio for an additional public
        # share link. Confirm this exposure is intended for deployments.
        share=True,
    )