File size: 4,462 Bytes
2a3aa81
 
5bf38c0
a70a295
 
e08081f
5bf38c0
0d7fa59
5bf38c0
948cffb
5bf38c0
 
 
 
 
948cffb
 
 
 
 
 
5bf38c0
 
 
 
 
 
2a3aa81
948cffb
 
e08081f
948cffb
 
 
2a3aa81
 
5bf38c0
 
948cffb
 
a70a295
 
948cffb
 
 
 
 
2a3aa81
948cffb
 
 
 
 
5bf38c0
948cffb
 
5bf38c0
948cffb
 
5bf38c0
948cffb
5bf38c0
948cffb
 
 
 
 
 
 
 
 
 
 
 
a70a295
948cffb
a70a295
948cffb
 
 
 
 
 
 
a70a295
948cffb
2a3aa81
948cffb
a70a295
948cffb
 
0d7fa59
948cffb
a70a295
948cffb
5bf38c0
948cffb
5bf38c0
948cffb
 
 
 
 
 
 
 
0d7fa59
5bf38c0
948cffb
 
5bf38c0
948cffb
a70a295
 
5bf38c0
948cffb
 
 
 
5bf38c0
a70a295
948cffb
 
5bf38c0
e08081f
a70a295
948cffb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# app.py

import os
import requests
import wikipedia
import gradio as gr
import torch

from transformers import (
    SeamlessM4TTokenizer,
    SeamlessM4TProcessor,
    SeamlessM4TForTextToText,
    pipeline as hf_pipeline
)

# 1) Load SeamlessM4T (slow tokenizer)
# Hugging Face Hub checkpoint used for all text-to-text translation below.
MODEL = "facebook/hf-seamless-m4t-medium"
# Prefer GPU when available; all inference tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# use_fast=False forces the slow (Python) tokenizer — presumably the fast
# tokenizer is incompatible with this checkpoint; TODO confirm.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
# .eval() puts the model in inference mode (disables dropout etc.).
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device).eval()

def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
    """Translate *text* with SeamlessM4T.

    ``src_iso3``/``tgt_iso3`` are ISO-639-3 codes (e.g. "eng", "fra").
    When ``auto_detect`` is True the source language is left unset so the
    model infers it from the input.
    """
    if auto_detect:
        source = None
    else:
        source = src_iso3
    model_inputs = processor(text=text, src_lang=source, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    generated = m4t_model.generate(**model_inputs, tgt_lang=tgt_iso3)
    token_ids = generated[0].tolist()
    return processor.decode(token_ids, skip_special_tokens=True)

# 2) NER pipeline
# NOTE: `grouped_entities=True` is deprecated in transformers;
# `aggregation_strategy="simple"` is its documented replacement and merges
# sub-word tokens into whole entity spans the same way.
ner = hf_pipeline("ner", model="dslim/bert-base-NER-uncased", aggregation_strategy="simple")

# 3) Geocode & POIs
def geocode(place):
    """Resolve *place* to coordinates via the Nominatim search API.

    Returns ``{"lat": float, "lon": float}`` for the best match, or
    ``None`` when the place is unknown or the request fails.
    """
    try:
        resp = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "iVoiceContext/1.0"},
            timeout=10,  # never hang the whole request on a slow geocoder
        )
        resp.raise_for_status()
        results = resp.json()
    except (requests.RequestException, ValueError):
        # Network/HTTP/JSON failure is treated the same as "not found" so
        # callers keep their single None-check.
        return None
    if not results:
        return None
    return {"lat": float(results[0]["lat"]), "lon": float(results[0]["lon"])}

def fetch_osm(lat, lon, osm_filter, limit=5):
    """Fetch named OSM nodes/ways matching *osm_filter* near (lat, lon).

    Queries the Overpass API within a 1 km radius and returns up to
    *limit* results as ``[{"name": ...}, ...]``. Returns ``[]`` on any
    network, HTTP, or JSON failure so callers always get a list.
    """
    payload = f"""
      [out:json][timeout:25];
      ( node{osm_filter}(around:1000,{lat},{lon});
        way{osm_filter}(around:1000,{lat},{lon}); );
      out center {limit};
    """
    try:
        resp = requests.post(
            "https://overpass-api.de/api/interpreter",
            data={"data": payload},
            timeout=30,  # Overpass gives itself 25s; allow a little headroom
        )
        resp.raise_for_status()
        elems = resp.json().get("elements", [])
    except (requests.RequestException, ValueError):
        return []
    # Keep only elements that actually carry a name tag.
    return [{"name": e["tags"]["name"]} for e in elems if e.get("tags", {}).get("name")]

# 4) Main function
def get_context(text: str,
                source_lang: str,  # ISO-639-3 e.g. "eng"
                output_lang: str,  # ISO-639-3 e.g. "fra"
                auto_detect: bool):
    """Extract named entities from *text* and attach contextual info.

    Locations get coordinates plus nearby restaurants/attractions from
    OSM; everything else gets a two-sentence Wikipedia summary. Text
    fields are translated into *output_lang* when it is not English.

    Returns ``{"entities": [...]}``.
    """
    # a) Ensure English text for NER (the NER model is English).
    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    # b) Extract entities, de-duplicated on their surface form.
    ner_out = ner(en)
    seen, entities = set(), []
    for ent in ner_out:
        w, lbl = ent["word"], ent["entity_group"]
        if w in seen:
            continue
        seen.add(w)

        if lbl == "LOC":
            geo = geocode(w)
            if not geo:
                # Geocoding failed: record the entity with an error marker
                # and (deliberately) no restaurants/attractions keys.
                obj = {"text": w, "label": lbl, "type": "location", "error": "could not geocode"}
            else:
                obj = {
                    "text": w,
                    "label": lbl,
                    "type": "location",
                    "geo": geo,
                    "restaurants": fetch_osm(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
                    "attractions": fetch_osm(geo["lat"], geo["lon"], '["tourism"="attraction"]')
                }

        else:
            # PERSON/ORG/MISC → Wikipedia summary.
            try:
                summ = wikipedia.summary(w, sentences=2)
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Any lookup failure
                # (disambiguation, missing page, network) falls back here.
                summ = "No summary available."
            obj = {"text": w, "label": lbl, "type": "wiki", "summary": summ}

        entities.append(obj)

    # c) Translate all text fields → output_lang.
    if output_lang != "eng":
        for e in entities:
            if e["type"] == "wiki":
                e["summary"] = translate_m4t(e["summary"], "eng", output_lang)
            elif e["type"] == "location":
                for field in ("restaurants", "attractions"):
                    # FIX: location entities that failed geocoding carry no
                    # restaurants/attractions keys — indexing them raised
                    # KeyError here. Skip fields that are absent.
                    if field in e:
                        e[field] = [
                            {"name": translate_m4t(item["name"], "eng", output_lang)}
                            for item in e[field]
                        ]

    # d) Return only entities.
    return {"entities": entities}

# 5) Gradio interface
# Thin UI wrapper around get_context: free-text input, two ISO-639-3
# language codes, and an auto-detect toggle; output rendered as JSON.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language")
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info."
).queue()  # queue() serializes requests — the model is heavy, so avoid overlap

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a
    # container; the PORT env var (if set) overrides Gradio's default 7860.
    # share=True additionally requests a public Gradio share link.
    iface.launch(server_name="0.0.0.0",
                 server_port=int(os.environ.get("PORT", 7860)),
                 share=True)