Spaces:
Sleeping
Sleeping
# app.py | |
import os | |
import requests | |
import wikipedia | |
import gradio as gr | |
import torch | |
from functools import lru_cache | |
from concurrent.futures import ThreadPoolExecutor | |
from typing import List | |
from transformers import ( | |
SeamlessM4TTokenizer, | |
SeamlessM4TProcessor, | |
SeamlessM4TForTextToText, | |
pipeline as hf_pipeline | |
) | |
# ── 1) Model setup ────────────────────────────────────────────────────────────
# Checkpoint shared by the tokenizer, the processor and the text-to-text model.
MODEL = "facebook/hf-seamless-m4t-medium"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The slow tokenizer is requested explicitly and then handed to the processor.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)

m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL)
m4t_model = m4t_model.to(device)
if device == "cuda":
    # FP16 for faster inference on GPU
    m4t_model = m4t_model.half()
# Inference only: switch the model to evaluation mode.
m4t_model.eval()
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
    """Translate one string with the module-level SeamlessM4T processor and model.

    Args:
        text: The string to translate.
        src_iso3: Source language code, passed to the processor as ``src_lang``.
        tgt_iso3: Target language code, passed to ``generate`` as ``tgt_lang``.
        auto_detect: When true, ``src_lang=None`` is passed to the processor
            and ``src_iso3`` is ignored.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    source_lang = src_iso3 if not auto_detect else None
    encoded = processor(text=text, src_lang=source_lang, return_tensors="pt")
    encoded = encoded.to(device)
    generated = m4t_model.generate(**encoded, tgt_lang=tgt_iso3)
    first_sequence = generated[0].tolist()
    return processor.decode(first_sequence, skip_special_tokens=True)
def translate_m4t_batch(
    texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
) -> List[str]:
    """Translate a batch of strings with the module-level SeamlessM4T model.

    Args:
        texts: Strings to translate. May be empty.
        src_iso3: Source language code, passed to the processor as ``src_lang``.
        tgt_iso3: Target language code, passed to ``generate`` as ``tgt_lang``.
        auto_detect: When true, ``src_lang=None`` is passed to the processor
            and ``src_iso3`` is ignored.

    Returns:
        The strings produced by ``processor.batch_decode`` for the generated
        sequences, or an empty list when ``texts`` is empty.
    """
    # An empty batch has nothing to tokenize or generate; return early instead
    # of handing an empty list to the processor and the model.
    if not texts:
        return []

    src = None if auto_detect else src_iso3
    inputs = processor(
        text=texts, src_lang=src, return_tensors="pt", padding=True
    ).to(device)
    # Greedy decoding (num_beams=1) capped at 60 new tokens per sequence, so
    # long inputs may come back truncated.
    tokens = m4t_model.generate(
        **inputs,
        tgt_lang=tgt_iso3,
        max_new_tokens=60,
        num_beams=1,
    )
    return processor.batch_decode(tokens, skip_special_tokens=True)
# ── 2) NER pipeline (updated for deprecation) ────────────────────────────────
# Token-classification pipeline used to find entities in English text.
# aggregation_strategy="simple" makes the pipeline return grouped entities
# (the rest of the file reads their "word" and "entity_group" fields).
ner = hf_pipeline(
    task="ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple",
)
# ── 3) CACHING helpers ──────────────────────────────────────────────────────
def geocode_cache(place: str, timeout: float = 10.0):
    """Look up coordinates for *place* with the Nominatim search API.

    Args:
        place: Free-text place name sent as the ``q`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``{"lat": float, "lon": float}`` for the first result, or ``None``
        when there is no result or the lookup itself fails.
    """
    try:
        resp = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "iVoiceContext/1.0"},
            # Without a timeout a stalled connection would block the calling
            # worker thread indefinitely.
            timeout=timeout,
        )
        resp.raise_for_status()
        hits = resp.json()
    except (requests.RequestException, ValueError):
        # Network error, HTTP error status, or a non-JSON body: report
        # "not found" so the caller can emit its "could not geocode" entry.
        return None
    if not hits:
        return None
    first = hits[0]
    return {"lat": float(first["lat"]), "lon": float(first["lon"])}
def fetch_osm_cache(
    lat: float,
    lon: float,
    osm_filter: str,
    limit: int = 5,
    timeout: float = 30.0,
):
    """Query the Overpass API for named nodes/ways within 1000 m of a point.

    Args:
        lat: Latitude of the search centre.
        lon: Longitude of the search centre.
        osm_filter: Overpass tag filter inserted verbatim after ``node`` and
            ``way`` (for example ``'["amenity"="restaurant"]'``). It is not
            escaped, so it must come from trusted code.
        limit: Maximum number of elements requested via ``out center``.
        timeout: Seconds to wait for the HTTP request. The default sits just
            above the 25 s timeout declared inside the query itself.

    Returns:
        A list of ``{"name": ...}`` dicts, one per returned element that has
        a non-empty ``name`` tag. Empty when nothing matches or the request
        fails.
    """
    payload = f"""
[out:json][timeout:25];
(
node{osm_filter}(around:1000,{lat},{lon});
way{osm_filter}(around:1000,{lat},{lon});
);
out center {limit};
"""
    try:
        resp = requests.post(
            "https://overpass-api.de/api/interpreter",
            data={"data": payload},
            timeout=timeout,
        )
        resp.raise_for_status()
        elems = resp.json().get("elements", [])
    except (requests.RequestException, ValueError):
        # Best effort: a failed or rate-limited lookup yields no places
        # instead of aborting the whole entity.
        return []
    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]
def wiki_summary_cache(name: str) -> str:
    """Return a two-sentence Wikipedia summary for *name*, best effort.

    Args:
        name: Title or search term passed to ``wikipedia.summary``.

    Returns:
        The summary text, or the literal ``"No summary available."`` when the
        lookup raises for any reason.
    """
    try:
        return wikipedia.summary(name, sentences=2)
    except Exception:
        # Deliberately broad: any lookup failure degrades to the placeholder.
        # ``Exception`` (rather than a bare ``except``) still lets
        # KeyboardInterrupt and SystemExit propagate.
        return "No summary available."
# ── 4) Per-entity worker ────────────────────────────────────────────────────
def process_entity(ent) -> dict:
    """Build the context record for one NER entity.

    Args:
        ent: Mapping with at least ``"word"`` and ``"entity_group"`` keys.

    Returns:
        For entities whose group is not ``"LOC"``: a ``"wiki"`` record with
        a ``"summary"``. For ``"LOC"`` entities: a ``"location"`` record with
        ``"geo"``, ``"restaurants"`` and ``"attractions"``, or, when
        geocoding returns nothing, a ``"location"`` record carrying only an
        ``"error"`` message.
    """
    surface = ent["word"]
    group = ent["entity_group"]

    # Everything that is not a location → Wikipedia
    if group != "LOC":
        blurb = wiki_summary_cache(surface)
        return {"text": surface, "label": group, "type": "wiki", "summary": blurb}

    coords = geocode_cache(surface)
    if not coords:
        return {
            "text": surface,
            "label": group,
            "type": "location",
            "error": "could not geocode",
        }

    lat, lon = coords["lat"], coords["lon"]
    record = {"text": surface, "label": group, "type": "location", "geo": coords}
    record["restaurants"] = fetch_osm_cache(lat, lon, '["amenity"="restaurant"]')
    record["attractions"] = fetch_osm_cache(lat, lon, '["tourism"="attraction"]')
    return record
# ── 5) Main function ────────────────────────────────────────────────────────
def get_context( | |
text: str, | |
source_lang: str, | |
output_lang: str, | |
auto_detect: bool | |
): | |
# a) Ensure English for NER | |
if auto_detect or source_lang != "eng": | |
en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect) | |
else: | |
en = text | |
# b) Run NER + dedupe | |
ner_out = ner(en) | |
seen = set() | |
unique_ents = [] | |
for ent in ner_out: | |
w = ent["word"] | |
if w in seen: | |
continue | |
seen.add(w) | |
unique_ents.append(ent) | |
# c) Parallel I/O | |
entities = [] | |
with ThreadPoolExecutor(max_workers=8) as exe: | |
futures = [exe.submit(process_entity, ent) for ent in unique_ents] | |
for fut in futures: | |
entities.append(fut.result()) | |
# d) Batch-translate non-English fields | |
if output_lang != "eng": | |
to_translate = [] | |
translations_info = [] | |
for i, e in enumerate(entities): | |
if e["type"] == "wiki": | |
translations_info.append(("summary", i)) | |
to_translate.append(e["summary"]) | |
elif e["type"] == "location": | |
for j, r in enumerate(e["restaurants"]): | |
translations_info.append(("restaurant", i, j)) | |
to_translate.append(r["name"]) | |
for j, a in enumerate(e["attractions"]): | |
translations_info.append(("attraction", i, j)) | |
to_translate.append(a["name"]) | |
translated = translate_m4t_batch(to_translate, "eng", output_lang) | |
for txt, info in zip(translated, translations_info): | |
kind = info[0] | |
if kind == "summary": | |
_, ei = info | |
entities[ei]["summary"] = txt | |
elif kind == "restaurant": | |
_, ei, ri = info | |
entities[ei]["restaurants"][ri]["name"] = txt | |
elif kind == "attraction": | |
_, ei, ai = info | |
entities[ei]["attractions"][ai]["name"] = txt | |
return {"entities": entities} | |
# ── 6) Gradio interface ─────────────────────────────────────────────────────
# Gradio UI: four inputs mapped positionally onto get_context's parameters
# (text, source_lang, output_lang, auto_detect); the result is shown as JSON.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        # "\u2026" is the horizontal ellipsis; written as an escape so the
        # placeholder cannot be mis-encoded again.
        gr.Textbox(lines=3, placeholder="Enter text\u2026"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info.",
).queue()  # queue() is called without arguments: unsupported kwargs were removed
if __name__ == "__main__":
    # Listen on every interface; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", server_port=listen_port, share=True)