File size: 4,462 Bytes
2a3aa81 5bf38c0 a70a295 e08081f 5bf38c0 0d7fa59 5bf38c0 948cffb 5bf38c0 948cffb 5bf38c0 2a3aa81 948cffb e08081f 948cffb 2a3aa81 5bf38c0 948cffb a70a295 948cffb 2a3aa81 948cffb 5bf38c0 948cffb 5bf38c0 948cffb 5bf38c0 948cffb 5bf38c0 948cffb a70a295 948cffb a70a295 948cffb a70a295 948cffb 2a3aa81 948cffb a70a295 948cffb 0d7fa59 948cffb a70a295 948cffb 5bf38c0 948cffb 5bf38c0 948cffb 0d7fa59 5bf38c0 948cffb 5bf38c0 948cffb a70a295 5bf38c0 948cffb 5bf38c0 a70a295 948cffb 5bf38c0 e08081f a70a295 948cffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# app.py
import os
import requests
import wikipedia
import gradio as gr
import torch
from transformers import (
SeamlessM4TTokenizer,
SeamlessM4TProcessor,
SeamlessM4TForTextToText,
pipeline as hf_pipeline
)
# 1) Load SeamlessM4T (slow tokenizer) once, at import time, and reuse it for
#    every request.
MODEL = "facebook/hf-seamless-m4t-medium"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL)
m4t_model = m4t_model.to(device)
m4t_model.eval()  # inference only: disable dropout etc.
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
    """Translate *text* with the module-level SeamlessM4T text-to-text model.

    Args:
        text: input string.
        src_iso3: source language code (ISO 639-3, e.g. "eng"); ignored when
            *auto_detect* is true.
        tgt_iso3: target language code, passed to ``generate`` as ``tgt_lang``.
        auto_detect: when true, pass ``src_lang=None`` to the processor
            instead of *src_iso3*.

    Returns:
        The decoded translation with special tokens removed. Blank input
        (None, empty or whitespace-only) is returned unchanged without
        running the model, since there is nothing to translate.

    NOTE(review): SeamlessM4T offers no text language identification. With
    ``src_lang=None`` the processor appears to keep whichever source language
    its tokenizer was last configured with, so "auto_detect" is not real
    detection -- confirm against the SeamlessM4TProcessor documentation.
    """
    if not text or not text.strip():
        return text
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    # generate() returns a batch of token-id sequences; decode the only one.
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
# 2) NER pipeline
# aggregation_strategy="simple" is the non-deprecated equivalent of the old
# grouped_entities=True flag: sub-word tokens are merged into whole entities,
# and each result carries an "entity_group" label (used by get_context).
ner = hf_pipeline("ner", model="dslim/bert-base-NER-uncased", aggregation_strategy="simple")
# 3) Geocode & POIs
def geocode(place):
r = requests.get(
"https://nominatim.openstreetmap.org/search",
params={"q": place, "format": "json", "limit": 1},
headers={"User-Agent":"iVoiceContext/1.0"}
).json()
if not r: return None
return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}
def fetch_osm(lat, lon, osm_filter, limit=5):
payload = f"""
[out:json][timeout:25];
( node{osm_filter}(around:1000,{lat},{lon});
way{osm_filter}(around:1000,{lat},{lon}); );
out center {limit};
"""
resp = requests.post("https://overpass-api.de/api/interpreter", data={"data": payload})
elems = resp.json().get("elements", [])
return [{"name": e["tags"]["name"]} for e in elems if e.get("tags",{}).get("name")]
# 4) Main function
def get_context(text: str,
source_lang: str, # ISO-639-3 e.g. "eng"
output_lang: str, # ISO-639-3 e.g. "fra"
auto_detect: bool):
# a) Ensure English for NER
if auto_detect or source_lang != "eng":
en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
else:
en = text
# b) Extract unique entities
ner_out = ner(en)
seen, entities = set(), []
for ent in ner_out:
w, lbl = ent["word"], ent["entity_group"]
if w in seen: continue
seen.add(w)
if lbl == "LOC":
geo = geocode(w)
if not geo:
obj = {"text": w, "label": lbl, "type": "location", "error": "could not geocode"}
else:
obj = {
"text": w,
"label": lbl,
"type": "location",
"geo": geo,
"restaurants": fetch_osm(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
"attractions": fetch_osm(geo["lat"], geo["lon"], '["tourism"="attraction"]')
}
else:
# PERSON/ORG/MISC → Wikipedia
try:
summ = wikipedia.summary(w, sentences=2)
except:
summ = "No summary available."
obj = {"text": w, "label": lbl, "type": "wiki", "summary": summ}
entities.append(obj)
# c) Translate all fields → output_lang
if output_lang != "eng":
for e in entities:
if e["type"] == "wiki":
e["summary"] = translate_m4t(e["summary"], "eng", output_lang)
elif e["type"] == "location":
for field in ("restaurants","attractions"):
e[field] = [
{"name": translate_m4t(item["name"], "eng", output_lang)}
for item in e[field]
]
# d) Return only entities
return {"entities": entities}
# 5) Gradio interface
# Inputs map positionally onto get_context(text, source_lang, output_lang,
# auto_detect). The language boxes default to "eng" so a user who leaves
# them untouched does not send an empty language code to the model.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, label="Text", placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)", value="eng"),
        gr.Textbox(label="Target Language (ISO 639-3)", value="eng"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info.",
).queue()
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    iface.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        share=True,  # also request a public Gradio share link
    )
|