|
|
|
|
|
import os |
|
import requests |
|
import wikipedia |
|
import gradio as gr |
|
import torch |
|
from transformers import ( |
|
SeamlessM4TProcessor, |
|
SeamlessM4TForTextToText, |
|
pipeline as hf_pipeline |
|
) |
|
|
|
|
|
|
|
# Checkpoint shared by the processor and the text-to-text translation model.
MODEL_NAME = "facebook/hf-seamless-m4t-medium"

# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both objects are loaded once at import time and reused by translate_m4t().
processor = SeamlessM4TProcessor.from_pretrained(MODEL_NAME)
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME)
m4t_model = m4t_model.to(device)
m4t_model = m4t_model.eval()  # inference only: switch off training-mode layers
|
|
|
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
    """Translate *text* into ``tgt_iso3`` with the SeamlessM4T text model.

    Args:
        text: The text to translate.
        src_iso3: Source language code (e.g. "eng"); ignored when
            ``auto_detect`` is true.
        tgt_iso3: Target language code, forwarded to ``generate`` as
            ``tgt_lang``.
        auto_detect: When true, ``src_lang=None`` is passed to the processor
            instead of ``src_iso3``.

    Returns:
        The decoded translation with special tokens stripped, or "" when
        *text* is empty/whitespace (no model call is made in that case).
    """
    # Blank input has nothing to translate; skip the model call entirely.
    if not text or not text.strip():
        return ""

    # NOTE(review): SeamlessM4T text input has no language-ID step. With
    # src_lang=None the tokenizer presumably keeps whichever src_lang it was
    # last given, so "auto-detect" may not actually detect anything -- confirm.
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
|
|
|
|
|
|
|
# Named-entity recogniser consumed by get_context(). "simple" aggregation
# merges sub-word tokens into whole entities, each exposing "word" and
# "entity_group"; it is the exact replacement for the deprecated
# grouped_entities=True flag. Runs on the same device as the translator.
ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple",
    device=device,
)
|
|
|
|
|
|
|
def geocode(place: str, timeout: float = 10):
    """Resolve a free-text place name to coordinates via OSM Nominatim.

    Args:
        place: Place name to search for (e.g. an NER "LOC" entity).
        timeout: Seconds to wait for Nominatim before giving up.

    Returns:
        ``(lat, lon)`` as floats for the top match, or ``None`` if *place* is
        blank, nothing matched, or the request/JSON decoding failed. Callers
        already treat ``None`` as "could not geocode".
    """
    place = (place or "").strip()
    if not place:
        return None

    try:
        resp = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "iVoiceContext/1.0"},
            # Without a timeout a stalled connection blocks the request forever.
            timeout=timeout,
        )
        # Rate-limit/ban responses are HTML; fail here instead of in .json().
        resp.raise_for_status()
        hits = resp.json()
    except (requests.RequestException, ValueError):
        return None

    if not hits:
        return None
    return float(hits[0]["lat"]), float(hits[0]["lon"])
|
|
|
def fetch_osm(lat, lon, osm_filter, limit=5, radius=1000, timeout=30):
    """Return up to *limit* named OSM features near a point via Overpass.

    Args:
        lat, lon: Centre of the search.
        osm_filter: Overpass tag filter, e.g. '["amenity"="restaurant"]'.
            Interpolated verbatim into the query, so pass trusted values only.
        limit: Maximum number of features requested from Overpass.
        radius: Search radius in metres (default 1000, as before).
        timeout: Client-side HTTP timeout in seconds (kept above the
            server-side ``[timeout:25]``).

    Returns:
        A list of ``{"name": ...}`` dicts; empty if nothing was found or the
        request/JSON decoding failed.
    """
    # Filter on ["name"] server-side so `limit` counts only usable results;
    # previously unnamed elements could consume the whole limit.
    query = f"""
[out:json][timeout:25];
(
  node{osm_filter}["name"](around:{radius},{lat},{lon});
  way{osm_filter}["name"](around:{radius},{lat},{lon});
);
out center {limit};
"""
    try:
        r = requests.post(
            "https://overpass-api.de/api/interpreter",
            data={"data": query},
            timeout=timeout,
        )
        # Overpass answers overload/timeout with non-JSON error pages.
        r.raise_for_status()
        elems = r.json().get("elements", [])
    except (requests.RequestException, ValueError):
        # POIs are best-effort enrichment; an outage yields an empty list.
        return []

    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]
|
|
|
|
|
def _to_english(text: str, source_lang: str, auto_detect: bool) -> str:
    """Return *text* in English, translating only when it may not be English."""
    if auto_detect or source_lang != "eng":
        return translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    return text


def _lookup_entity(ent_text: str, label: str) -> dict:
    """Build the context record for one named entity (location or wiki)."""
    if label == "LOC":
        geo = geocode(ent_text)
        if not geo:
            return {"type": "location", "error": "could not geocode"}
        lat, lon = geo
        return {
            "type": "location",
            "restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
            "attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
        }

    # Every non-LOC label gets a Wikipedia summary. The wikipedia package can
    # raise several error types (disambiguation, missing page, network), so
    # this stays a deliberate best-effort catch-all.
    try:
        summary = wikipedia.summary(ent_text, sentences=2)
    except Exception:
        summary = "No summary available."
    return {"type": "wiki", "summary": summary}


def _translate_results(results: dict, output_lang: str) -> None:
    """Translate every text field in *results* from English, in place."""
    for info in results.values():
        if info["type"] == "wiki":
            info["summary"] = translate_m4t(
                info["summary"], "eng", output_lang, auto_detect=False
            )
        elif info["type"] == "location":
            for poi_list in ("restaurants", "attractions"):
                # Geocode-failure records carry only "error"; indexing them
                # directly used to raise KeyError.
                if poi_list not in info:
                    continue
                info[poi_list] = [
                    {"name": translate_m4t(item["name"], "eng", output_lang,
                                           auto_detect=False)}
                    for item in info[poi_list]
                ]


def get_context(text: str,
                source_lang: str,
                output_lang: str,
                auto_detect: bool):
    """Translate *text* to English, find entities, and gather context for them.

    Args:
        text: User input in ``source_lang`` (or any language if auto-detecting).
        source_lang: Language code of *text*; blank falls back to "eng".
        output_lang: Language code for the returned text; blank falls back
            to "eng".
        auto_detect: Forwarded to ``translate_m4t`` for the input translation.

    Returns:
        A dict keyed by entity text. Location entries hold "restaurants" and
        "attractions" lists (or an "error"), and other entries hold a
        Wikipedia "summary". Returns ``{"error": "no entities found"}`` when
        the text is blank or NER finds nothing.
    """
    # Empty Gradio textboxes arrive as "", which is not a valid language code.
    source_lang = (source_lang or "").strip() or "eng"
    output_lang = (output_lang or "").strip() or "eng"

    if not text or not text.strip():
        return {"error": "no entities found"}

    en_text = _to_english(text, source_lang, auto_detect)

    # Duplicate surface forms collapse to one key (the last label wins).
    ents = {ent["word"]: ent["entity_group"] for ent in ner(en_text)}
    results = {
        ent_text: _lookup_entity(ent_text, label)
        for ent_text, label in ents.items()
    }
    if not results:
        return {"error": "no entities found"}

    if output_lang != "eng":
        _translate_results(results, output_lang)
    return results
|
|
|
|
|
# Gradio front-end: the four inputs map positionally onto get_context's
# (text, source_lang, output_lang, auto_detect) parameters.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…", label="Text"),
        # Default to English so an untouched box is still a valid language code.
        gr.Textbox(label="Source Language (ISO 639-3)", value="eng"),
        gr.Textbox(label="Target Language (ISO 639-3)", value="eng"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Translate + Context-Aware",
    description=(
        "1) Translate your text → English (if needed)\n"
        "2) Run BERT-NER on English to find named entities\n"
        "3) Geocode LOC → fetch nearby restaurants & attractions\n"
        "4) Fetch Wikipedia summaries for all other entities\n"
        "5) Translate **all** results → your target language"
    ),
).queue()
|
|
|
if __name__ == "__main__":
    # Honour a PORT override from the environment, else use 7860.
    port = int(os.environ.get("PORT", 7860))
    # Listen on all interfaces and also request a public Gradio share link.
    iface.launch(server_name="0.0.0.0", server_port=port, share=True)
|
|