Pracheethaa commited on
Commit
a468bc7
Β·
verified Β·
1 Parent(s): aab1e16
Files changed (1) hide show
  1. Data +225 -0
Data ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import requests
5
+ import wikipedia
6
+ import gradio as gr
7
+ import torch
8
+
9
+ from functools import lru_cache
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from typing import List
12
+
13
+ from transformers import (
14
+ SeamlessM4TTokenizer,
15
+ SeamlessM4TProcessor,
16
+ SeamlessM4TForTextToText,
17
+ pipeline as hf_pipeline
18
+ )
19
+
20
+ # ── 1) Model setup ────────────────────────────────────────────────────────────
21
+
22
+ MODEL = "facebook/hf-seamless-m4t-medium"
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
26
+ processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
27
+
28
+ m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
29
+ if device == "cuda":
30
+ m4t_model = m4t_model.half() # FP16 for faster inference on GPU
31
+ m4t_model.eval()
32
+
33
+ def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
34
+ src = None if auto_detect else src_iso3
35
+ inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
36
+ tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
37
+ return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
38
+
39
+ def translate_m4t_batch(
40
+ texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
41
+ ) -> List[str]:
42
+ src = None if auto_detect else src_iso3
43
+ inputs = processor(
44
+ text=texts, src_lang=src, return_tensors="pt", padding=True
45
+ ).to(device)
46
+ tokens = m4t_model.generate(
47
+ **inputs,
48
+ tgt_lang=tgt_iso3,
49
+ max_new_tokens=60,
50
+ num_beams=1
51
+ )
52
+ return processor.batch_decode(tokens, skip_special_tokens=True)
53
+
54
+
55
+ # ── 2) NER pipeline (updated for deprecation) ────────────────────────────────
56
+
57
+ ner = hf_pipeline(
58
+ "ner",
59
+ model="dslim/bert-base-NER-uncased",
60
+ aggregation_strategy="simple"
61
+ )
62
+
63
+
64
+ # ── 3) CACHING helpers ──────────────────────────────────────────────────────
65
+
66
+ @lru_cache(maxsize=256)
67
+ def geocode_cache(place: str):
68
+ r = requests.get(
69
+ "https://nominatim.openstreetmap.org/search",
70
+ params={"q": place, "format": "json", "limit": 1},
71
+ headers={"User-Agent": "iVoiceContext/1.0"}
72
+ ).json()
73
+ if not r:
74
+ return None
75
+ return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}
76
+
77
+ @lru_cache(maxsize=256)
78
+ def fetch_osm_cache(lat: float, lon: float, osm_filter: str, limit: int = 5):
79
+ payload = f"""
80
+ [out:json][timeout:25];
81
+ (
82
+ node{osm_filter}(around:1000,{lat},{lon});
83
+ way{osm_filter}(around:1000,{lat},{lon});
84
+ );
85
+ out center {limit};
86
+ """
87
+ resp = requests.post(
88
+ "https://overpass-api.de/api/interpreter",
89
+ data={"data": payload}
90
+ )
91
+ elems = resp.json().get("elements", [])
92
+ return [
93
+ {"name": e["tags"]["name"]}
94
+ for e in elems
95
+ if e.get("tags", {}).get("name")
96
+ ]
97
+
98
+ @lru_cache(maxsize=256)
99
+ def wiki_summary_cache(name: str) -> str:
100
+ try:
101
+ return wikipedia.summary(name, sentences=2)
102
+ except:
103
+ return "No summary available."
104
+
105
+
106
+ # ── 4) Per-entity worker ────────────────────────────────────────────────────
107
+
108
+ def process_entity(ent) -> dict:
109
+ w = ent["word"]
110
+ lbl = ent["entity_group"]
111
+
112
+ if lbl == "LOC":
113
+ geo = geocode_cache(w)
114
+ if not geo:
115
+ return {
116
+ "text": w,
117
+ "label": lbl,
118
+ "type": "location",
119
+ "error": "could not geocode"
120
+ }
121
+
122
+ restaurants = fetch_osm_cache(geo["lat"], geo["lon"], '["amenity"="restaurant"]')
123
+ attractions = fetch_osm_cache(geo["lat"], geo["lon"], '["tourism"="attraction"]')
124
+
125
+ return {
126
+ "text": w,
127
+ "label": lbl,
128
+ "type": "location",
129
+ "geo": geo,
130
+ "restaurants": restaurants,
131
+ "attractions": attractions
132
+ }
133
+
134
+ # PERSON / ORG / MISC β†’ Wikipedia
135
+ summary = wiki_summary_cache(w)
136
+ return {"text": w, "label": lbl, "type": "wiki", "summary": summary}
137
+
138
+
139
+ # ── 5) Main function ────────────────────────────────────────────────────────
140
+
141
+ def get_context(
142
+ text: str,
143
+ source_lang: str,
144
+ output_lang: str,
145
+ auto_detect: bool
146
+ ):
147
+ # a) Ensure English for NER
148
+ if auto_detect or source_lang != "eng":
149
+ en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
150
+ else:
151
+ en = text
152
+
153
+ # b) Run NER + dedupe
154
+ ner_out = ner(en)
155
+ seen = set()
156
+ unique_ents = []
157
+ for ent in ner_out:
158
+ w = ent["word"]
159
+ if w in seen:
160
+ continue
161
+ seen.add(w)
162
+ unique_ents.append(ent)
163
+
164
+ # c) Parallel I/O
165
+ entities = []
166
+ with ThreadPoolExecutor(max_workers=8) as exe:
167
+ futures = [exe.submit(process_entity, ent) for ent in unique_ents]
168
+ for fut in futures:
169
+ entities.append(fut.result())
170
+
171
+ # d) Batch-translate non-English fields
172
+ if output_lang != "eng":
173
+ to_translate = []
174
+ translations_info = []
175
+
176
+ for i, e in enumerate(entities):
177
+ if e["type"] == "wiki":
178
+ translations_info.append(("summary", i))
179
+ to_translate.append(e["summary"])
180
+ elif e["type"] == "location":
181
+ for j, r in enumerate(e["restaurants"]):
182
+ translations_info.append(("restaurant", i, j))
183
+ to_translate.append(r["name"])
184
+ for j, a in enumerate(e["attractions"]):
185
+ translations_info.append(("attraction", i, j))
186
+ to_translate.append(a["name"])
187
+
188
+ translated = translate_m4t_batch(to_translate, "eng", output_lang)
189
+
190
+ for txt, info in zip(translated, translations_info):
191
+ kind = info[0]
192
+ if kind == "summary":
193
+ _, ei = info
194
+ entities[ei]["summary"] = txt
195
+ elif kind == "restaurant":
196
+ _, ei, ri = info
197
+ entities[ei]["restaurants"][ri]["name"] = txt
198
+ elif kind == "attraction":
199
+ _, ei, ai = info
200
+ entities[ei]["attractions"][ai]["name"] = txt
201
+
202
+ return {"entities": entities}
203
+
204
+
205
+ # ── 6) Gradio interface ─────────────────────────────────────────────────────
206
+
207
+ iface = gr.Interface(
208
+ fn=get_context,
209
+ inputs=[
210
+ gr.Textbox(lines=3, placeholder="Enter text…"),
211
+ gr.Textbox(label="Source Language (ISO 639-3)"),
212
+ gr.Textbox(label="Target Language (ISO 639-3)"),
213
+ gr.Checkbox(label="Auto-detect source language")
214
+ ],
215
+ outputs="json",
216
+ title="iVoice Context-Aware",
217
+ description="Returns only the detected entities and their related info."
218
+ ).queue() # ← removed unsupported kwargs
219
+
220
+ if __name__ == "__main__":
221
+ iface.launch(
222
+ server_name="0.0.0.0",
223
+ server_port=int(os.environ.get("PORT", 7860)),
224
+ share=True
225
+ )