bwingenroth committed
Commit 29cd9c0 · verified · 1 Parent(s): 4ddeebe

Create app.py

Files changed (1):
  app.py  +167  -0
app.py ADDED
@@ -0,0 +1,167 @@
import os
import random
from pathlib import Path
from statistics import mean
from typing import Any, Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from httpx import AsyncClient, Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

app = FastAPI()
logger = logging.get_logger(__name__)
load_dotenv()

DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
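# TARGET_COLUMN_NAMES is referenced below but is not defined anywhere in this file.
# A minimal placeholder, assuming the usual free-text column names found on the Hub;
# the exact set used by the original Space is not in this commit.
TARGET_COLUMN_NAMES = ("text", "content", "prompt", "instruction", "question", "response")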

def load_model(repo_id: str) -> fasttext.FastText._FastText:
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    for row in rows:
        if isinstance(row, str):
            # split strings on newlines and drop empty lines
            lines = row.split("\n")
            for line in lines:
                if line:
                    yield line
        elif isinstance(row, list):
            try:
                # join token lists into a single string and skip very short ones
                line = " ".join(row)
                if len(line) < min_length:
                    continue
                else:
                    yield line
            except TypeError:
                continue

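# Illustrative behaviour of yield_clean_rows (hypothetical inputs):
#   list(yield_clean_rows(["a\n\nb", ["c", "d"]]))  ->  ["a", "b", "c d"]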

FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"

# model = load_model(DEFAULT_FAST_TEXT_MODEL)
Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,
    )
)

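# The helpers below (datasets_server_valid_rows here, plus the async helpers further
# down) are called by the /items endpoint but were not included in this file. They are
# minimal sketches against the public datasets-server API
# (https://datasets-server.huggingface.co); the original implementations may differ.
DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
HF_TOKEN = os.getenv("HF_TOKEN")
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
client = Client(headers=HEADERS, timeout=Timeout(30.0))
async_client = AsyncClient(headers=HEADERS, timeout=Timeout(30.0))


def datasets_server_valid_rows(hub_id: str) -> bool:
    # True if the datasets server reports the dataset as viewable/previewable.
    try:
        resp = client.get(f"{DATASETS_SERVER_URL}/is-valid", params={"dataset": hub_id})
        resp.raise_for_status()
        data = resp.json()
        return bool(data.get("viewer") or data.get("preview") or data.get("valid"))
    except Exception as e:
        logger.error(f"Failed to check dataset {hub_id} via the datasets server: {e}")
        return False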

def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}


def predict_rows(rows, target_column, language_threshold_percent=0.2):
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }

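# Sketches of the remaining async helpers used by the /items endpoint below, again
# assuming the datasets-server /splits, /info and /rows routes; they are not the
# original implementations from this Space.
async def get_first_config_and_split_name(hub_id: str) -> tuple[str, str]:
    # Return the first (config, split) pair reported for the dataset.
    resp = await async_client.get(f"{DATASETS_SERVER_URL}/splits", params={"dataset": hub_id})
    resp.raise_for_status()
    first = resp.json()["splits"][0]
    return first["config"], first["split"]


async def get_dataset_info(hub_id: str, config: str | None = None) -> dict[Any, Any] | None:
    # Fetch dataset metadata (features, split sizes) for one config, or None on failure.
    params = {"dataset": hub_id}
    if config:
        params["config"] = config
    resp = await async_client.get(f"{DATASETS_SERVER_URL}/info", params=params)
    return resp.json() if resp.status_code == 200 else None


async def get_random_rows(
    hub_id: str,
    total_rows_for_split: int,
    number_of_rows: int,
    max_request_calls: int,
    config: str = "default",
    split: str = "train",
) -> list[dict[str, Any]]:
    # Sample rows at random offsets, using at most `max_request_calls` requests
    # of up to 100 rows each (the /rows page-size limit).
    rows: list[dict[str, Any]] = []
    rows_per_call = min(100, max(1, number_of_rows // max_request_calls))
    for _ in range(max_request_calls):
        if len(rows) >= number_of_rows:
            break
        offset = random.randint(0, max(0, total_rows_for_split - rows_per_call))
        resp = await async_client.get(
            f"{DATASETS_SERVER_URL}/rows",
            params={
                "dataset": hub_id,
                "config": config,
                "split": split,
                "offset": offset,
                "length": rows_per_call,
            },
        )
        if resp.status_code == 200:
            rows.extend(item["row"] for item in resp.json()["rows"])
    return rows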

@app.get("/items/{hub_id}")
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any]:
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = await get_first_config_and_split_name(hub_id)
    info = await get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if dataset_info := info.get("dataset_info"):
        total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
        features = dataset_info.get("features")
        column_names = set(features.keys())
        logger.info(f"Column names: {column_names}")
        if not column_names.intersection(TARGET_COLUMN_NAMES):
            raise gr.Error(
                f"Dataset {hub_id} columns {column_names} do not include any of the target columns {TARGET_COLUMN_NAMES}"
            )
        for column in TARGET_COLUMN_NAMES:
            if column in column_names:
                target_column = column
                logger.info(f"Using column {target_column} for language detection")
                break
        random_rows = await get_random_rows(
            hub_id,
            total_rows_for_split,
            number_of_rows,
            max_request_calls,
            config,
            split,
        )
        logger.info(f"Predicting language for {len(random_rows)} rows")
        predictions = predict_rows(random_rows, target_column)
        predictions["hub_id"] = hub_id
        predictions["config"] = config
        predictions["split"] = split
        return predictions
    raise gr.Error(f"Dataset {hub_id} did not return dataset_info from the datasets server.")

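# Illustrative shape of the JSON returned by /items/{hub_id} (values made up):
# {
#   "predictions": {"eng_Latn": 0.97, "fra_Latn": 0.41},
#   "pred": [{"label": "eng_Latn", "score": 0.98}, ...],
#   "hub_id": "imdb",
#   "config": "plain_text",
#   "split": "train",
# }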

# Gradio UI for interactive use; `predict_language` is also exposed above as a FastAPI route.
app_title = "Language Detection"
inputs = [
    gr.Textbox(
        None,
        label="hub_id",
    ),
    gr.Textbox(None, label="config"),
]
interface = gr.Interface(
    predict_language,
    inputs=inputs,
    outputs="json",
    title=app_title,
    # article=app_description,
)
interface.queue()
interface.launch()
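# Running locally (usage sketch, not part of the original commit):
#   python app.py   # launches the Gradio UI defined above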