Shenuki committed on
Commit
0d7fa59
·
verified ·
1 Parent(s): e02d2af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -5,21 +5,31 @@ import requests
5
  import wikipedia
6
  import gradio as gr
7
  import torch
 
8
  from transformers import (
9
  SeamlessM4TProcessor,
10
  SeamlessM4TForTextToText,
 
11
  pipeline as hf_pipeline
12
  )
13
 
14
  # ————————————————————
15
- # 1) SeamlessM4T Text2Text
16
  MODEL_NAME = "facebook/hf-seamless-m4t-medium"
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- processor = SeamlessM4TProcessor.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
19
  m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
20
 
21
  def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
22
- # src_iso3: e.g. "eng", "fra", etc. If auto_detect=True, pass None
23
  src = None if auto_detect else src_iso3
24
  inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
25
  tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
@@ -41,7 +51,8 @@ def geocode(place: str):
41
  params={"q": place, "format": "json", "limit": 1},
42
  headers={"User-Agent":"iVoiceContext/1.0"}
43
  ).json()
44
- if not resp: return None
 
45
  return float(resp[0]["lat"]), float(resp[0]["lon"])
46
 
47
  def fetch_osm(lat, lon, osm_filter, limit=5):
@@ -63,16 +74,16 @@ def fetch_osm(lat, lon, osm_filter, limit=5):
63
 
64
  # ————————————————————
65
  def get_context(text: str,
66
- source_lang: str, # always 3-letter, e.g. "eng"
67
- output_lang: str, # always 3-letter, e.g. "fra"
68
  auto_detect: bool):
69
- # 1) Ensure English for NER
70
  if auto_detect or source_lang != "eng":
71
  en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
72
  else:
73
  en_text = text
74
 
75
- # 2) Extract entities
76
  ner_out = ner(en_text)
77
  ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
78
 
@@ -84,25 +95,22 @@ def get_context(text: str,
84
  results[ent_text] = {"type":"location","error":"could not geocode"}
85
  else:
86
  lat, lon = geo
87
- rest = fetch_osm(lat, lon, '["amenity"="restaurant"]')
88
- attr = fetch_osm(lat, lon, '["tourism"="attraction"]')
89
  results[ent_text] = {
90
  "type": "location",
91
- "restaurants": rest,
92
- "attractions": attr
93
  }
94
  else:
95
- # PERSON, ORG, MISC → Wikipedia
96
  try:
97
- summary = wikipedia.summary(ent_text, sentences=2)
98
  except Exception:
99
- summary = "No summary available."
100
- results[ent_text] = {"type":"wiki","summary": summary}
101
 
102
  if not results:
103
  return {"error":"no entities found"}
104
 
105
- # 3) Translate **all** text fields β†’ output_lang
106
  if output_lang != "eng":
107
  for info in results.values():
108
  if info["type"] == "wiki":
@@ -110,13 +118,11 @@ def get_context(text: str,
110
  info["summary"], "eng", output_lang, auto_detect=False
111
  )
112
  elif info["type"] == "location":
113
- for poi_list in ("restaurants","attractions"):
114
- translated = []
115
- for item in info[poi_list]:
116
- name = item["name"]
117
- tr = translate_m4t(name, "eng", output_lang, auto_detect=False)
118
- translated.append({"name": tr})
119
- info[poi_list] = translated
120
 
121
  return results
122
 
@@ -133,9 +139,9 @@ iface = gr.Interface(
133
  title="iVoice Translate + Context-Aware",
134
  description=(
135
  "1) Translate your text β†’ English (if needed)\n"
136
- "2) Run BERT-NER on English to find LOC/PERSON/ORG\n"
137
  "3) Geocode LOC β†’ fetch nearby restaurants & attractions\n"
138
- "4) Fetch Wikipedia summaries for PERSON/ORG\n"
139
  "5) Translate **all** results β†’ your target language"
140
  )
141
  ).queue()
 
5
  import wikipedia
6
  import gradio as gr
7
  import torch
8
+
9
  from transformers import (
10
  SeamlessM4TProcessor,
11
  SeamlessM4TForTextToText,
12
+ SeamlessM4TTokenizer, # <<< import the tokenizer class
13
  pipeline as hf_pipeline
14
  )
15
 
16
  # ————————————————————
17
+ # 1) Load SeamlessM4T tokenizer (slow) and processor
18
  MODEL_NAME = "facebook/hf-seamless-m4t-medium"
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # load the slow tokenizer (no conversion attempted)
22
+ tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
23
+
24
+ # pass it into the processor so it won't try to convert
25
+ processor = SeamlessM4TProcessor.from_pretrained(
26
+ MODEL_NAME,
27
+ tokenizer=tokenizer
28
+ )
29
+
30
  m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
31
 
32
  def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
 
33
  src = None if auto_detect else src_iso3
34
  inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
35
  tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
 
51
  params={"q": place, "format": "json", "limit": 1},
52
  headers={"User-Agent":"iVoiceContext/1.0"}
53
  ).json()
54
+ if not resp:
55
+ return None
56
  return float(resp[0]["lat"]), float(resp[0]["lon"])
57
 
58
  def fetch_osm(lat, lon, osm_filter, limit=5):
 
74
 
75
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
76
  def get_context(text: str,
77
+ source_lang: str, # always ISO639-3, e.g. "eng"
78
+ output_lang: str, # always ISO639-3, e.g. "fra"
79
  auto_detect: bool):
80
+ # 1) Ensure English text for NER
81
  if auto_detect or source_lang != "eng":
82
  en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
83
  else:
84
  en_text = text
85
 
86
+ # 2) Run NER
87
  ner_out = ner(en_text)
88
  ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
89
 
 
95
  results[ent_text] = {"type":"location","error":"could not geocode"}
96
  else:
97
  lat, lon = geo
 
 
98
  results[ent_text] = {
99
  "type": "location",
100
+ "restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
101
+ "attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
102
  }
103
  else:
 
104
  try:
105
+ summ = wikipedia.summary(ent_text, sentences=2)
106
  except Exception:
107
+ summ = "No summary available."
108
+ results[ent_text] = {"type":"wiki","summary": summ}
109
 
110
  if not results:
111
  return {"error":"no entities found"}
112
 
113
+ # 3) Translate all text fields β†’ output_lang
114
  if output_lang != "eng":
115
  for info in results.values():
116
  if info["type"] == "wiki":
 
118
  info["summary"], "eng", output_lang, auto_detect=False
119
  )
120
  elif info["type"] == "location":
121
+ for key in ("restaurants","attractions"):
122
+ info[key] = [
123
+ {"name": translate_m4t(item["name"], "eng", output_lang)}
124
+ for item in info[key]
125
+ ]
 
 
126
 
127
  return results
128
 
 
139
  title="iVoice Translate + Context-Aware",
140
  description=(
141
  "1) Translate your text β†’ English (if needed)\n"
142
+ "2) Extract LOC/PERSON/ORG via BERT-NER\n"
143
  "3) Geocode LOC β†’ fetch nearby restaurants & attractions\n"
144
+ "4) Fetch Wikipedia summaries\n"
145
  "5) Translate **all** results β†’ your target language"
146
  )
147
  ).queue()