Spaces:
Running
Running
import os | |
import random | |
from statistics import mean | |
from typing import Iterator, Union, Any | |
import fasttext | |
import gradio as gr | |
from dotenv import load_dotenv | |
from httpx import Client, Timeout | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub.utils import logging | |
from toolz import concat, groupby, valmap | |
from fastapi import FastAPI | |
from httpx import AsyncClient | |
from pathlib import Path | |
# FastAPI application object created at import time.
app = FastAPI()

# Module-level logger obtained through huggingface_hub's logging helpers.
logger = logging.get_logger(__name__)

# Read variables from a .env file (if any) into the process environment.
load_dotenv()

# Hub repository id accepted by `load_model`; presumably the default fastText
# model for this app -- it is only referenced from a note further down.
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Fetch ``model.bin`` from a Hub repository and load it with fastText.

    Args:
        repo_id: Hugging Face Hub repository id that hosts a ``model.bin`` file.

    Returns:
        The fastText model loaded from the downloaded file.
    """
    return fasttext.load_model(hf_hub_download(repo_id, filename="model.bin"))
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield cleaned text lines from raw column values.

    Each element of ``rows`` is handled according to its type:

    * ``str``: split on newlines; every non-empty line is yielded
      (``min_length`` is not applied to these lines).
    * ``list``: items are joined with a single space and the result is yielded
      when it has at least ``min_length`` characters. Lists containing
      non-string items are skipped.
    * anything else (e.g. ``None``): ignored.

    Args:
        rows: An iterable of row values, or a single string which is treated
            as one row.
        min_length: Minimum length of a joined list row for it to be kept.

    Yields:
        Cleaned text lines.
    """
    if isinstance(rows, str):
        # A bare string is one row; iterating it directly would walk it
        # character by character and yield one-character "lines".
        rows = [rows]
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty lines.
            yield from (line for line in row.split("\n") if line)
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # The list holds non-string items; skip this row.
                continue
            if len(joined) >= min_length:
                yield joined
# fastText labels look like "__label__eng_Latn"; this is the length of the
# leading "__label__" part that gets sliced off.
FASTTEXT_PREFIX_LENGTH = 9

# NOTE: load_model(DEFAULT_FAST_TEXT_MODEL) is not used here; the model below
# is downloaded and loaded at import time instead.
Path("code/models").mkdir(parents=True, exist_ok=True)
_model_path = hf_hub_download(
    "facebook/fasttext-language-identification",
    "model.bin",
    cache_dir="code/models",
    local_dir="code/models",
    local_dir_use_symlinks=False,
)
model = fasttext.load_model(_model_path)
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Classify ``inputs`` with the module-level fastText ``model``.

    Args:
        inputs: Text passed to ``model.predict``.
        k: Number of predictions requested from the model.

    Returns:
        One dict per prediction with a ``"label"`` key (the model label with
        its first ``FASTTEXT_PREFIX_LENGTH`` characters removed) and a
        ``"score"`` key (the value the model paired with that label).
    """
    raw = model.predict(inputs, k=k)
    labels, scores = raw[0], raw[1]
    return [
        {"label": raw_label[FASTTEXT_PREFIX_LENGTH:], "score": score}
        for raw_label, score in zip(labels, scores)
    ]
def get_label(x):
    """Return the ``"label"`` entry of a prediction mapping, or ``None`` if absent."""
    label = x.get("label")
    return label
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` entries of ``preds``.

    Args:
        preds: Iterable of prediction mappings, each read via ``.get("score")``.

    Returns:
        The mean of the collected scores, as computed by ``statistics.mean``.
    """
    scores = (pred.get("score") for pred in preds)
    return mean(scores)
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count reaches a share of the total count.

    Args:
        counts_dict: Mapping of key to numeric count.
        threshold_percent: Fraction of the summed counts that a key's count
            must be greater than or equal to in order to be kept.

    Returns:
        A set of the keys that met the threshold.
    """
    cutoff = sum(counts_dict.values()) * threshold_percent
    kept = set()
    for key, count in counts_dict.items():
        if count >= cutoff:
            kept.add(key)
    return kept
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict the language of each row's text and summarise the results.

    Args:
        rows: Iterable of row mappings; ``target_column`` is read from each
            with ``.get``.
        target_column: Key of the text value inside each row.
        language_threshold_percent: Share of all predictions a language must
            reach to appear in the summary (see ``filter_by_frequency``).

    Returns:
        A dict with two keys:

        * ``"predictions"``: language label -> mean score, restricted to the
          languages that passed the frequency filter.
        * ``"pred"``: the flat list of every individual prediction, unfiltered.
    """
    raw_values = (row.get(target_column) for row in rows)
    texts = list(yield_clean_rows(value for value in raw_values if value is not None))

    per_text = [model_predict(text) for text in texts]
    predictions = list(concat(result for result in per_text if result is not None))

    by_language = groupby(get_label, predictions)
    language_counts = valmap(len, by_language)
    kept_languages = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )

    mean_scores = {
        language: get_mean_score(language_preds)
        for language, language_preds in by_language.items()
        if language in kept_languages
    }
    return {"predictions": mean_scores, "pred": predictions}
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any]:
    """Predict the languages present in a sample of rows from a Hub dataset.

    Args:
        hub_id: Dataset repository id on the Hugging Face Hub.
        config: Dataset config name. When falsy, both ``config`` and ``split``
            are replaced by the values returned from
            ``get_first_config_and_split_name``.
        split: Split name used to look up the row count and to sample rows.
        max_request_calls: Forwarded to ``get_random_rows``.
        number_of_rows: Forwarded to ``get_random_rows``.

    Returns:
        The dict produced by ``predict_rows`` with ``"hub_id"``, ``"config"``
        and ``"split"`` added.

    Raises:
        gr.Error: If the dataset is not reported as valid, its info or the
            requested split cannot be found, or none of its columns is in
            ``TARGET_COLUMN_NAMES``.
    """
    # NOTE(review): the other helpers used here are awaited but this one is
    # not. If it is a coroutine function the result is always truthy and this
    # check never fires -- confirm against its definition.
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        # gr.Error is an exception: it has to be raised to stop processing.
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")

    if not config:
        config, split = await get_first_config_and_split_name(hub_id)

    info = await get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")

    dataset_info = info.get("dataset_info")
    if not dataset_info:
        # Without this, the row count and target column below would be unbound.
        raise gr.Error(f"No dataset info is available for dataset {hub_id}.")

    split_info = (dataset_info.get("splits") or {}).get(split)
    if split_info is None:
        raise gr.Error(f"Split {split} was not found for dataset {hub_id}.")
    total_rows_for_split = split_info.get("num_examples")

    features = dataset_info.get("features")
    column_names = set(features.keys())
    logger.info(f"Column names: {column_names}")

    # Use the first target column (in TARGET_COLUMN_NAMES order) that exists.
    target_column = next(
        (column for column in TARGET_COLUMN_NAMES if column in column_names),
        None,
    )
    if target_column is None:
        raise gr.Error(
            f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
        )
    logger.info(f"Using column {target_column} for language detection")

    random_rows = await get_random_rows(
        hub_id,
        total_rows_for_split,
        number_of_rows,
        max_request_calls,
        config,
        split,
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")

    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions
def main():
    """Build the Gradio interface around ``predict_language`` and launch it."""
    # NOTE(review): Interface inputs appear to be passed positionally, so the
    # second textbox (labelled "split") would feed predict_language's `config`
    # parameter -- confirm this is intended.
    content_box = gr.Textbox(None, label="enter content")
    split_box = gr.Textbox(None, label="split")

    demo = gr.Interface(
        predict_language,
        inputs=[content_box, split_box],
        outputs="json",
        title="Language Detection",
    )
    demo.queue()
    demo.launch()