Update app.py
Browse files
app.py
CHANGED
@@ -5,21 +5,31 @@ import requests
|
|
5 |
import wikipedia
|
6 |
import gradio as gr
|
7 |
import torch
|
|
|
8 |
from transformers import (
|
9 |
SeamlessM4TProcessor,
|
10 |
SeamlessM4TForTextToText,
|
|
|
11 |
pipeline as hf_pipeline
|
12 |
)
|
13 |
|
14 |
# ────────────────────
|
15 |
-
# 1) SeamlessM4T
|
16 |
MODEL_NAME = "facebook/hf-seamless-m4t-medium"
|
17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
|
20 |
|
21 |
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
|
22 |
-
# src_iso3: e.g. "eng", "fra", etc. If auto_detect=True, pass None
|
23 |
src = None if auto_detect else src_iso3
|
24 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
25 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
@@ -41,7 +51,8 @@ def geocode(place: str):
|
|
41 |
params={"q": place, "format": "json", "limit": 1},
|
42 |
headers={"User-Agent":"iVoiceContext/1.0"}
|
43 |
).json()
|
44 |
-
if not resp:
|
|
|
45 |
return float(resp[0]["lat"]), float(resp[0]["lon"])
|
46 |
|
47 |
def fetch_osm(lat, lon, osm_filter, limit=5):
|
@@ -63,16 +74,16 @@ def fetch_osm(lat, lon, osm_filter, limit=5):
|
|
63 |
|
64 |
# ────────────────────
|
65 |
def get_context(text: str,
|
66 |
-
source_lang: str, # always 3
|
67 |
-
output_lang: str, # always 3
|
68 |
auto_detect: bool):
|
69 |
-
# 1) Ensure English for NER
|
70 |
if auto_detect or source_lang != "eng":
|
71 |
en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
72 |
else:
|
73 |
en_text = text
|
74 |
|
75 |
-
# 2)
|
76 |
ner_out = ner(en_text)
|
77 |
ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
|
78 |
|
@@ -84,25 +95,22 @@ def get_context(text: str,
|
|
84 |
results[ent_text] = {"type":"location","error":"could not geocode"}
|
85 |
else:
|
86 |
lat, lon = geo
|
87 |
-
rest = fetch_osm(lat, lon, '["amenity"="restaurant"]')
|
88 |
-
attr = fetch_osm(lat, lon, '["tourism"="attraction"]')
|
89 |
results[ent_text] = {
|
90 |
"type": "location",
|
91 |
-
"restaurants":
|
92 |
-
"attractions":
|
93 |
}
|
94 |
else:
|
95 |
-
# PERSON, ORG, MISC → Wikipedia
|
96 |
try:
|
97 |
-
|
98 |
except Exception:
|
99 |
-
|
100 |
-
results[ent_text] = {"type":"wiki","summary":
|
101 |
|
102 |
if not results:
|
103 |
return {"error":"no entities found"}
|
104 |
|
105 |
-
# 3) Translate
|
106 |
if output_lang != "eng":
|
107 |
for info in results.values():
|
108 |
if info["type"] == "wiki":
|
@@ -110,13 +118,11 @@ def get_context(text: str,
|
|
110 |
info["summary"], "eng", output_lang, auto_detect=False
|
111 |
)
|
112 |
elif info["type"] == "location":
|
113 |
-
for
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
translated.append({"name": tr})
|
119 |
-
info[poi_list] = translated
|
120 |
|
121 |
return results
|
122 |
|
@@ -133,9 +139,9 @@ iface = gr.Interface(
|
|
133 |
title="iVoice Translate + Context-Aware",
|
134 |
description=(
|
135 |
"1) Translate your text → English (if needed)\n"
|
136 |
-
"2)
|
137 |
"3) Geocode LOC → fetch nearby restaurants & attractions\n"
|
138 |
-
"4) Fetch Wikipedia summaries
|
139 |
"5) Translate **all** results → your target language"
|
140 |
)
|
141 |
).queue()
|
|
|
5 |
import wikipedia
|
6 |
import gradio as gr
|
7 |
import torch
|
8 |
+
|
9 |
from transformers import (
|
10 |
SeamlessM4TProcessor,
|
11 |
SeamlessM4TForTextToText,
|
12 |
+
SeamlessM4TTokenizer, # <<< import the tokenizer class
|
13 |
pipeline as hf_pipeline
|
14 |
)
|
15 |
|
16 |
# ────────────────────
|
17 |
+
# 1) Load SeamlessM4T tokenizer (slow) and processor
|
18 |
MODEL_NAME = "facebook/hf-seamless-m4t-medium"
|
19 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
|
21 |
+
# load the slow tokenizer (no conversion attempted)
|
22 |
+
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
23 |
+
|
24 |
+
# pass it into the processor so it won't try to convert
|
25 |
+
processor = SeamlessM4TProcessor.from_pretrained(
|
26 |
+
MODEL_NAME,
|
27 |
+
tokenizer=tokenizer
|
28 |
+
)
|
29 |
+
|
30 |
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
|
31 |
|
32 |
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
|
|
|
33 |
src = None if auto_detect else src_iso3
|
34 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
35 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
|
|
51 |
params={"q": place, "format": "json", "limit": 1},
|
52 |
headers={"User-Agent":"iVoiceContext/1.0"}
|
53 |
).json()
|
54 |
+
if not resp:
|
55 |
+
return None
|
56 |
return float(resp[0]["lat"]), float(resp[0]["lon"])
|
57 |
|
58 |
def fetch_osm(lat, lon, osm_filter, limit=5):
|
|
|
74 |
|
75 |
# ────────────────────
|
76 |
def get_context(text: str,
|
77 |
+
source_lang: str, # always ISO639-3, e.g. "eng"
|
78 |
+
output_lang: str, # always ISO639-3, e.g. "fra"
|
79 |
auto_detect: bool):
|
80 |
+
# 1) Ensure English text for NER
|
81 |
if auto_detect or source_lang != "eng":
|
82 |
en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
83 |
else:
|
84 |
en_text = text
|
85 |
|
86 |
+
# 2) Run NER
|
87 |
ner_out = ner(en_text)
|
88 |
ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
|
89 |
|
|
|
95 |
results[ent_text] = {"type":"location","error":"could not geocode"}
|
96 |
else:
|
97 |
lat, lon = geo
|
|
|
|
|
98 |
results[ent_text] = {
|
99 |
"type": "location",
|
100 |
+
"restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
|
101 |
+
"attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
|
102 |
}
|
103 |
else:
|
|
|
104 |
try:
|
105 |
+
summ = wikipedia.summary(ent_text, sentences=2)
|
106 |
except Exception:
|
107 |
+
summ = "No summary available."
|
108 |
+
results[ent_text] = {"type":"wiki","summary": summ}
|
109 |
|
110 |
if not results:
|
111 |
return {"error":"no entities found"}
|
112 |
|
113 |
+
# 3) Translate all text fields → output_lang
|
114 |
if output_lang != "eng":
|
115 |
for info in results.values():
|
116 |
if info["type"] == "wiki":
|
|
|
118 |
info["summary"], "eng", output_lang, auto_detect=False
|
119 |
)
|
120 |
elif info["type"] == "location":
|
121 |
+
for key in ("restaurants","attractions"):
|
122 |
+
info[key] = [
|
123 |
+
{"name": translate_m4t(item["name"], "eng", output_lang)}
|
124 |
+
for item in info[key]
|
125 |
+
]
|
|
|
|
|
126 |
|
127 |
return results
|
128 |
|
|
|
139 |
title="iVoice Translate + Context-Aware",
|
140 |
description=(
|
141 |
"1) Translate your text → English (if needed)\n"
|
142 |
+
"2) Extract LOC/PERSON/ORG via BERT-NER\n"
|
143 |
"3) Geocode LOC → fetch nearby restaurants & attractions\n"
|
144 |
+
"4) Fetch Wikipedia summaries\n"
|
145 |
"5) Translate **all** results → your target language"
|
146 |
)
|
147 |
).queue()
|