bwingenroth committed on
Commit
2c69943
·
verified ·
1 Parent(s): 68581ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -95
app.py CHANGED
@@ -5,15 +5,11 @@ from typing import Iterator, Union, Any
5
  import fasttext
6
  import gradio as gr
7
  from dotenv import load_dotenv
8
- from httpx import Client, Timeout
9
  from huggingface_hub import hf_hub_download
10
  from huggingface_hub.utils import logging
11
  from toolz import concat, groupby, valmap
12
- from fastapi import FastAPI
13
- from httpx import AsyncClient
14
  from pathlib import Path
15
 
16
- app = FastAPI()
17
  logger = logging.get_logger(__name__)
18
  load_dotenv()
19
 
@@ -23,7 +19,6 @@ def load_model(repo_id: str) -> fasttext.FastText._FastText:
23
  model_path = hf_hub_download(repo_id, filename="model.bin")
24
  return fasttext.load_model(model_path)
25
 
26
-
27
  def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
28
  for row in rows:
29
  if isinstance(row, str):
@@ -42,10 +37,9 @@ def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterat
42
  except TypeError:
43
  continue
44
 
45
-
46
  FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
47
 
48
- # model = load_model(DEFAULT_FAST_TEXT_MODEL)
49
  Path("code/models").mkdir(parents=True, exist_ok=True)
50
  model = fasttext.load_model(
51
  hf_hub_download(
@@ -57,7 +51,6 @@ model = fasttext.load_model(
57
  )
58
  )
59
 
60
-
61
  def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
62
  predictions = model.predict(inputs, k=k)
63
  return [
@@ -65,103 +58,163 @@ def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
65
  for label, prob in zip(predictions[0], predictions[1])
66
  ]
67
 
68
-
69
  def get_label(x):
70
  return x.get("label")
71
 
72
-
73
  def get_mean_score(preds):
74
  return mean([pred.get("score") for pred in preds])
75
 
76
-
77
  def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
78
  """Filter a dict to include items whose value is above `threshold_percent`"""
79
  total = sum(counts_dict.values())
80
  threshold = total * threshold_percent
81
  return {k for k, v in counts_dict.items() if v >= threshold}
82
 
83
-
84
- def predict_rows(rows, target_column, language_threshold_percent=0.2):
85
- rows = (row.get(target_column) for row in rows)
86
- rows = (row for row in rows if row is not None)
87
- rows = list(yield_clean_rows(rows))
88
- predictions = [model_predict(row) for row in rows]
89
- predictions = [pred for pred in predictions if pred is not None]
90
- predictions = list(concat(predictions))
91
- predictions_by_lang = groupby(get_label, predictions)
92
- langues_counts = valmap(len, predictions_by_lang)
93
- keys_to_keep = filter_by_frequency(
94
- langues_counts, threshold_percent=language_threshold_percent
95
- )
96
- filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
97
- return {
98
- "predictions": dict(valmap(get_mean_score, filtered_dict)),
99
- "pred": predictions,
100
- }
101
-
102
-
103
- @app.get("/items/{hub_id}")
104
- async def predict_language(
105
- hub_id: str,
106
- config: str | None = None,
107
- split: str | None = None,
108
- max_request_calls: int = 10,
109
- number_of_rows: int = 1000,
110
- ) -> dict[Any, Any]:
111
- is_valid = datasets_server_valid_rows(hub_id)
112
- if not is_valid:
113
- gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
114
- if not config:
115
- config, split = await get_first_config_and_split_name(hub_id)
116
- info = await get_dataset_info(hub_id, config)
117
- if info is None:
118
- gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
119
- if dataset_info := info.get("dataset_info"):
120
- total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
121
- features = dataset_info.get("features")
122
- column_names = set(features.keys())
123
- logger.info(f"Column names: {column_names}")
124
- if not set(column_names).intersection(TARGET_COLUMN_NAMES):
125
- raise gr.Error(
126
- f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
- for column in TARGET_COLUMN_NAMES:
129
- if column in column_names:
130
- target_column = column
131
- logger.info(f"Using column {target_column} for language detection")
132
- break
133
- random_rows = await get_random_rows(
134
- hub_id,
135
- total_rows_for_split,
136
- number_of_rows,
137
- max_request_calls,
138
- config,
139
- split,
140
- )
141
- logger.info(f"Predicting language for {len(random_rows)} rows")
142
- predictions = predict_rows(random_rows, target_column)
143
- predictions["hub_id"] = hub_id
144
- predictions["config"] = config
145
- predictions["split"] = split
146
- return predictions
147
-
148
-
149
- @app.get("/")
150
- def main():
151
- app_title = "Language Detection"
152
- inputs = [
153
- gr.Textbox(
154
- None,
155
- label="enter content",
156
- ),
157
- gr.Textbox(None, label="split"),
158
- ]
159
- interface = gr.Interface(
160
- predict_language,
161
- inputs=inputs,
162
- outputs="json",
163
- title=app_title,
164
- # article=app_description,
165
- )
166
- interface.queue()
167
- interface.launch()
 
5
  import fasttext
6
  import gradio as gr
7
  from dotenv import load_dotenv
 
8
  from huggingface_hub import hf_hub_download
9
  from huggingface_hub.utils import logging
10
  from toolz import concat, groupby, valmap
 
 
11
  from pathlib import Path
12
 
 
13
  logger = logging.get_logger(__name__)
14
  load_dotenv()
15
 
 
19
  model_path = hf_hub_download(repo_id, filename="model.bin")
20
  return fasttext.load_model(model_path)
21
 
 
22
  def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
23
  for row in rows:
24
  if isinstance(row, str):
 
37
  except TypeError:
38
  continue
39
 
 
40
  FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
41
 
42
+ # Load the model
43
  Path("code/models").mkdir(parents=True, exist_ok=True)
44
  model = fasttext.load_model(
45
  hf_hub_download(
 
51
  )
52
  )
53
 
 
54
  def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
55
  predictions = model.predict(inputs, k=k)
56
  return [
 
58
  for label, prob in zip(predictions[0], predictions[1])
59
  ]
60
 
 
61
def get_label(x):
    """Return the "label" entry of a prediction dict (None when missing)."""
    label = x.get("label")
    return label
63
 
 
64
def get_mean_score(preds):
    """Average the "score" values over a list of prediction dicts."""
    scores = [entry.get("score") for entry in preds]
    return mean(scores)
66
 
 
67
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2) -> set:
    """Return the set of keys whose count is at least `threshold_percent` of the total.

    The comparison is inclusive (`>=`), so keys sitting exactly on the
    threshold are kept.  An empty dict yields an empty set (total is 0).
    The original docstring said "above", which contradicted the code.
    """
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
72
 
73
def simple_predict(text, num_predictions=3):
    """Simple language detection function for Gradio interface"""
    if not text or not text.strip():
        return "Please enter some text for language detection."

    try:
        lines = list(yield_clean_rows([text]))
        if not lines:
            return "No valid text found after cleaning."

        # Flatten the per-line top-k predictions into one list.
        per_line = [model_predict(line, k=num_predictions) for line in lines]
        flattened = [pred for preds in per_line for pred in preds]

        if not flattened:
            return "No predictions could be made."

        # Bucket predictions by language label, then summarise each bucket.
        by_language = groupby(get_label, flattened)
        counts = valmap(len, by_language)
        avg_scores = valmap(get_mean_score, by_language)

        return {
            "detected_languages": dict(avg_scores),
            "language_counts": dict(counts),
            "total_predictions": len(flattened),
            "text_lines_analyzed": len(lines),
        }

    except Exception as e:
        return f"Error during prediction: {str(e)}"
112
+
113
def batch_predict(text, threshold_percent=0.2):
    """More advanced prediction with filtering"""
    if not text or not text.strip():
        return "Please enter some text for language detection."

    try:
        rows = list(yield_clean_rows([text]))
        if not rows:
            return "No valid text found after cleaning."

        # One top-1 prediction list per cleaned line, flattened via toolz.concat.
        per_row = [model_predict(row) for row in rows]
        per_row = [p for p in per_row if p is not None]
        flat_preds = list(concat(per_row))

        if not flat_preds:
            return "No predictions could be made."

        # Keep only languages whose frequency clears the threshold.
        grouped = groupby(get_label, flat_preds)
        counts = valmap(len, grouped)
        frequent = filter_by_frequency(counts, threshold_percent=threshold_percent)
        kept = {lang: preds for lang, preds in grouped.items() if lang in frequent}

        return {
            "predictions": dict(valmap(get_mean_score, kept)),
            "all_language_counts": dict(counts),
            "filtered_languages": list(frequent),
            "threshold_used": threshold_percent,
        }

    except Exception as e:
        return f"Error during prediction: {str(e)}"
149
+
150
def build_demo_interface():
    """Assemble the two-tab Gradio Blocks UI and return the (unlaunched) demo."""
    app_title = "Language Detection Tool"
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown("Enter text below to detect the language(s) it contains.")

        # Tab 1: per-line top-k detection backed by simple_predict.
        with gr.Tab("Simple Detection"):
            with gr.Row():
                with gr.Column():
                    text_input1 = gr.Textbox(
                        label="Enter text for language detection",
                        placeholder="Type or paste your text here...",
                        lines=5
                    )
                    num_predictions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of top predictions per line"
                    )
                    predict_btn1 = gr.Button("Detect Language")

                with gr.Column():
                    output1 = gr.JSON(label="Detection Results")

            # Button click feeds the textbox + slider values into simple_predict;
            # the returned dict (or error string) is rendered by the JSON widget.
            predict_btn1.click(
                simple_predict,
                inputs=[text_input1, num_predictions],
                outputs=output1
            )

        # Tab 2: frequency-filtered detection backed by batch_predict.
        with gr.Tab("Advanced Detection"):
            with gr.Row():
                with gr.Column():
                    text_input2 = gr.Textbox(
                        label="Enter text for advanced language detection",
                        placeholder="Type or paste your text here...",
                        lines=5
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        label="Threshold percentage for filtering"
                    )
                    predict_btn2 = gr.Button("Advanced Detect")

                with gr.Column():
                    output2 = gr.JSON(label="Advanced Detection Results")

            predict_btn2.click(
                batch_predict,
                inputs=[text_input2, threshold],
                outputs=output2
            )

        gr.Markdown("### About")
        gr.Markdown("This tool uses Facebook's FastText language identification model to detect languages in text.")

    return demo
212
+
213
+
214
if __name__ == "__main__":
    # Build the Gradio UI and serve it on all interfaces at port 7860,
    # without creating a public share link.
    ui = build_demo_interface()
    ui.launch(server_name="0.0.0.0", server_port=7860, share=False)