|
|
|
|
|
import logging
import os

import gradio as gr
import requests
import torch
import wikipedia
from transformers import (
    SeamlessM4TForTextToText,
    SeamlessM4TProcessor,
    SeamlessM4TTokenizer,
    pipeline as hf_pipeline,
)
|
|
|
|
|
# Checkpoint name shared by the tokenizer, the processor and the model below.
MODEL = "facebook/hf-seamless-m4t-medium"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The tokenizer is loaded explicitly and then handed to the processor.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)

# Load the text-to-text model, move it to the chosen device and switch it to
# evaluation mode.
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL)
m4t_model = m4t_model.to(device)
m4t_model = m4t_model.eval()
|
|
|
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
    """Translate *text* into ``tgt_iso3`` with the SeamlessM4T text-to-text model.

    Args:
        text: Text to translate.
        src_iso3: Source language code, passed to the processor as
            ``src_lang``. Ignored when ``auto_detect`` is true.
        tgt_iso3: Target language code, passed to ``generate`` as ``tgt_lang``.
        auto_detect: When true, no source language is given to the processor.

    Returns:
        The decoded translation with special tokens removed. Empty or
        whitespace-only input is returned as-is (``None`` becomes ``""``)
        without running the model.
    """
    # Nothing to translate: skip tokenisation and generation entirely.
    if not text or not text.strip():
        return text or ""

    # NOTE(review): with src_lang=None the processor is not told the source
    # language; which language the tokenizer then assumes should be confirmed
    # against the transformers documentation.
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
|
|
|
|
|
# Named-entity recognition pipeline. aggregation_strategy="simple" groups
# tokens into whole entities (results carry "word" and "entity_group") and
# replaces the deprecated grouped_entities=True flag.
ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple",
)
|
|
|
|
|
def geocode(place):
    """Look up the coordinates of *place* with the Nominatim search API.

    Args:
        place: Free-form place name, sent as the ``q`` query parameter.

    Returns:
        ``{"lat": float, "lon": float}`` for the first search hit, or ``None``
        when *place* is blank, the request fails, or no usable result comes
        back.
    """
    # A blank query cannot match anything; avoid the network round trip.
    if not place or not place.strip():
        return None

    try:
        resp = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "iVoiceContext/1.0"},
            # Without a timeout a stalled connection would hang the request.
            timeout=10,
        )
        resp.raise_for_status()
        results = resp.json()
    except (requests.RequestException, ValueError) as exc:
        # Network errors, HTTP error statuses and non-JSON bodies are all
        # reported to the caller as "no result".
        logging.getLogger(__name__).warning("Geocoding %r failed: %s", place, exc)
        return None

    if not results:
        return None

    try:
        top = results[0]
        return {"lat": float(top["lat"]), "lon": float(top["lon"])}
    except (KeyError, IndexError, TypeError, ValueError):
        # The payload did not have the expected list-of-hits shape.
        return None
|
|
|
def fetch_osm(lat, lon, osm_filter, limit=5, radius=1000):
    """Query the Overpass API for named nodes and ways around a coordinate.

    Args:
        lat: Latitude of the search centre.
        lon: Longitude of the search centre.
        osm_filter: Overpass tag filter appended to ``node`` and ``way``,
            e.g. ``'["amenity"="restaurant"]'``.
        limit: Maximum number of elements requested via ``out center``.
        radius: Value used in the ``around`` filter (Overpass interprets it
            as metres). Defaults to the previously hard-coded 1000.

    Returns:
        A list of ``{"name": ...}`` dicts, one per returned element that has a
        non-empty ``name`` tag. An empty list is returned when the request
        fails or the response is not valid JSON.
    """
    query = (
        "[out:json][timeout:25];"
        f"(node{osm_filter}(around:{radius},{lat},{lon});"
        f"way{osm_filter}(around:{radius},{lat},{lon}););"
        f"out center {limit};"
    )

    try:
        resp = requests.post(
            "https://overpass-api.de/api/interpreter",
            data={"data": query},
            # Slightly above the 25 s server-side timeout set in the query.
            timeout=30,
        )
        resp.raise_for_status()
        elements = resp.json().get("elements", [])
    except (requests.RequestException, ValueError) as exc:
        # Nearby places are supplementary data; log the failure and return
        # nothing so the surrounding request can still complete.
        logging.getLogger(__name__).warning("Overpass query failed: %s", exc)
        return []

    return [
        {"name": element["tags"]["name"]}
        for element in elements
        if element.get("tags", {}).get("name")
    ]
|
|
|
|
|
def _describe_entity(word, label):
    """Build the context record for one entity reported by the NER pipeline."""
    if label != "LOC":
        # Wikipedia lookups are best-effort: any lookup failure falls back to
        # a placeholder instead of failing the whole request.
        try:
            summary = wikipedia.summary(word, sentences=2)
        except Exception:
            summary = "No summary available."
        return {"text": word, "label": label, "type": "wiki", "summary": summary}

    geo = geocode(word)
    if not geo:
        return {
            "text": word,
            "label": label,
            "type": "location",
            "error": "could not geocode",
        }
    return {
        "text": word,
        "label": label,
        "type": "location",
        "geo": geo,
        "restaurants": fetch_osm(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
        "attractions": fetch_osm(geo["lat"], geo["lon"], '["tourism"="attraction"]'),
    }


def get_context(text: str,
                source_lang: str,
                output_lang: str,
                auto_detect: bool):
    """Extract named entities from *text* and attach related information.

    The text is translated to English first unless it is already English,
    then run through the NER pipeline. ``LOC`` entities are geocoded and
    enriched with nearby restaurants and attractions; every other entity gets
    a two-sentence Wikipedia summary. When ``output_lang`` is not ``"eng"``
    the summaries and place names are translated into it.

    Args:
        text: Input text.
        source_lang: Language code of *text*.
        output_lang: Language code for the returned summaries and place names.
        auto_detect: When true, *text* is always translated to English and no
            source language is passed to the translator.

    Returns:
        ``{"entities": [...]}`` with one record per distinct entity string.
        Empty or whitespace-only *text* yields an empty list.
    """
    # Nothing to analyse: skip translation, NER and every network lookup.
    if not text or not text.strip():
        return {"entities": []}

    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    seen, entities = set(), []
    for ent in ner(en):
        word, label = ent["word"], ent["entity_group"]
        # Report each distinct entity string only once.
        if word in seen:
            continue
        seen.add(word)
        entities.append(_describe_entity(word, label))

    if output_lang != "eng":
        for entity in entities:
            if entity["type"] == "wiki":
                entity["summary"] = translate_m4t(entity["summary"], "eng", output_lang)
            elif entity["type"] == "location":
                for field in ("restaurants", "attractions"):
                    # Locations that could not be geocoded carry an "error"
                    # key and no place lists; indexing them raised KeyError.
                    if field not in entity:
                        continue
                    entity[field] = [
                        {"name": translate_m4t(item["name"], "eng", output_lang)}
                        for item in entity[field]
                    ]

    return {"entities": entities}
|
|
|
|
|
# Gradio front end: a text box, two language-code boxes and an auto-detect
# checkbox feed get_context, whose result is rendered as JSON.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info.",
)
# Enable request queuing on the interface.
iface = iface.queue()
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces; the port is read from $PORT and defaults to 7860.
    iface.launch(
        share=True,
        server_port=int(os.environ.get("PORT", 7860)),
        server_name="0.0.0.0",
    )
|
|