import re
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

logger = logging.get_logger(__name__)

load_dotenv()

DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
# Language code mapping - feel free to expand this
LANGUAGE_MAPPING = {
    "spa_Latn": {"name": "Spanish", "iso_639_1": "es", "full_code": "es_ES"},
    "eng_Latn": {"name": "English", "iso_639_1": "en", "full_code": "en_US"},
    "fra_Latn": {"name": "French", "iso_639_1": "fr", "full_code": "fr_FR"},
    "deu_Latn": {"name": "German", "iso_639_1": "de", "full_code": "de_DE"},
    "ita_Latn": {"name": "Italian", "iso_639_1": "it", "full_code": "it_IT"},
    "por_Latn": {"name": "Portuguese", "iso_639_1": "pt", "full_code": "pt_PT"},
    "rus_Cyrl": {"name": "Russian", "iso_639_1": "ru", "full_code": "ru_RU"},
    "zho_Hans": {"name": "Chinese (Simplified)", "iso_639_1": "zh", "full_code": "zh_CN"},
    "zho_Hant": {"name": "Chinese (Traditional)", "iso_639_1": "zh", "full_code": "zh_TW"},
    "jpn_Jpan": {"name": "Japanese", "iso_639_1": "ja", "full_code": "ja_JP"},
    "kor_Hang": {"name": "Korean", "iso_639_1": "ko", "full_code": "ko_KR"},
    "ara_Arab": {"name": "Arabic", "iso_639_1": "ar", "full_code": "ar_SA"},
    "hin_Deva": {"name": "Hindi", "iso_639_1": "hi", "full_code": "hi_IN"},
    "cat_Latn": {"name": "Catalan", "iso_639_1": "ca", "full_code": "ca_ES"},
    "glg_Latn": {"name": "Galician", "iso_639_1": "gl", "full_code": "gl_ES"},
    "nld_Latn": {"name": "Dutch", "iso_639_1": "nl", "full_code": "nl_NL"},
    "swe_Latn": {"name": "Swedish", "iso_639_1": "sv", "full_code": "sv_SE"},
    "nor_Latn": {"name": "Norwegian", "iso_639_1": "no", "full_code": "no_NO"},
    "dan_Latn": {"name": "Danish", "iso_639_1": "da", "full_code": "da_DK"},
    "fin_Latn": {"name": "Finnish", "iso_639_1": "fi", "full_code": "fi_FI"},
    "pol_Latn": {"name": "Polish", "iso_639_1": "pl", "full_code": "pl_PL"},
    "ces_Latn": {"name": "Czech", "iso_639_1": "cs", "full_code": "cs_CZ"},
    "hun_Latn": {"name": "Hungarian", "iso_639_1": "hu", "full_code": "hu_HU"},
    "tur_Latn": {"name": "Turkish", "iso_639_1": "tr", "full_code": "tr_TR"},
    "heb_Hebr": {"name": "Hebrew", "iso_639_1": "he", "full_code": "he_IL"},
    "tha_Thai": {"name": "Thai", "iso_639_1": "th", "full_code": "th_TH"},
    "vie_Latn": {"name": "Vietnamese", "iso_639_1": "vi", "full_code": "vi_VN"},
    "ukr_Cyrl": {"name": "Ukrainian", "iso_639_1": "uk", "full_code": "uk_UA"},
    "ell_Grek": {"name": "Greek", "iso_639_1": "el", "full_code": "el_GR"},
    "bul_Cyrl": {"name": "Bulgarian", "iso_639_1": "bg", "full_code": "bg_BG"},
    "ron_Latn": {"name": "Romanian", "iso_639_1": "ro", "full_code": "ro_RO"},
    "hrv_Latn": {"name": "Croatian", "iso_639_1": "hr", "full_code": "hr_HR"},
    "srp_Cyrl": {"name": "Serbian", "iso_639_1": "sr", "full_code": "sr_RS"},
    "slv_Latn": {"name": "Slovenian", "iso_639_1": "sl", "full_code": "sl_SI"},
    "slk_Latn": {"name": "Slovak", "iso_639_1": "sk", "full_code": "sk_SK"},
    "est_Latn": {"name": "Estonian", "iso_639_1": "et", "full_code": "et_EE"},
    "lav_Latn": {"name": "Latvian", "iso_639_1": "lv", "full_code": "lv_LV"},
    "lit_Latn": {"name": "Lithuanian", "iso_639_1": "lt", "full_code": "lt_LT"},
    "msa_Latn": {"name": "Malay", "iso_639_1": "ms", "full_code": "ms_MY"},
    "ind_Latn": {"name": "Indonesian", "iso_639_1": "id", "full_code": "id_ID"},
    "tgl_Latn": {"name": "Filipino", "iso_639_1": "tl", "full_code": "tl_PH"},
}
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download a fastText model from the Hugging Face Hub and load it."""
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)
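# A minimal usage sketch for load_model with the OpenLID checkpoint named
# above (an assumption: that the repo ships its weights as "model.bin"):
#
#     lid_model = load_model(DEFAULT_FAST_TEXT_MODEL)
#     lid_model.predict("Bonjour tout le monde", k=1)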
def yield_clean_rows(rows: list[Union[str, list[str]]], min_length: int = 3) -> Iterator[str]:
    """Yield non-empty lines of at least `min_length` characters from the input rows."""
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty or too-short lines
            for line in row.split("\n"):
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue
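# Illustrative example of the cleaning behaviour: strings are split on
# newlines, lists are joined, and empty or too-short lines are dropped:
#
#     list(yield_clean_rows(["Hello\n\nWorld", ["a", "b", "c"]]))
#     # -> ["Hello", "World", "a b c"]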
FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"
def format_language_info(fasttext_code):
    """Convert a fastText language code to a human-readable format"""
    if fasttext_code in LANGUAGE_MAPPING:
        lang_info = LANGUAGE_MAPPING[fasttext_code]
        return {
            "name": lang_info["name"],
            "iso_code": lang_info["iso_639_1"],
            "full_code": lang_info["full_code"],
            "fasttext_code": fasttext_code,
        }
    else:
        # Graceful fallback for unmapped languages
        return {
            "name": fasttext_code,
            "iso_code": "unknown",
            "full_code": "unknown",
            "fasttext_code": fasttext_code,
        }
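# Example outputs for a mapped and an unmapped code:
#
#     format_language_info("spa_Latn")
#     # -> {"name": "Spanish", "iso_code": "es", "full_code": "es_ES", "fasttext_code": "spa_Latn"}
#     format_language_info("xxx_Latn")
#     # -> {"name": "xxx_Latn", "iso_code": "unknown", "full_code": "unknown", "fasttext_code": "xxx_Latn"}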
def detect_language_segments(text, confidence_threshold=0.3):
    """Detect language changes across segments of the input text."""
    # Split text into logical segments on sentence-ending punctuation
    # and common separators such as "/" and "|"
    segments = re.split(r'[.!?;/|]\s+|\s+/\s+|\s+\|\s+', text.strip())
    segments = [seg.strip() for seg in segments if seg.strip() and len(seg.strip()) > 10]
    if len(segments) < 2:
        return None
    segment_results = []
    for i, segment in enumerate(segments):
        predictions = model_predict(segment, k=1)
        if predictions and predictions[0]["score"] > confidence_threshold:
            lang_info = format_language_info(predictions[0]["label"])
            segment_results.append({
                "segment_number": i + 1,
                "text": segment,
                "language": lang_info,
                "confidence": predictions[0]["score"],
            })
    # Report only if more than one language was found across segments
    languages_found = {result["language"]["fasttext_code"] for result in segment_results}
    if len(languages_found) > 1:
        return {
            "is_multilingual": True,
            "languages_detected": list(languages_found),
            "segments": segment_results,
        }
    return None
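# Illustrative example (exact confidences depend on the model; the codes shown
# assume the model labels these segments as English and Spanish):
#
#     detect_language_segments("This is clearly English text. Esto es claramente un texto en español.")
#     # -> {"is_multilingual": True, "languages_detected": ["eng_Latn", "spa_Latn"], "segments": [...]}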
# Load the model (downloaded from the Hub and cached under code/models;
# hf_hub_download creates the cache directory if it does not exist)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
    )
)
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    # fastText rejects strings containing newlines, so collapse them first;
    # cast scores to plain floats so results serialize cleanly to JSON
    predictions = model.predict(inputs.replace("\n", " "), k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": float(prob)}
        for label, prob in zip(predictions[0], predictions[1])
    ]
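# Illustrative example of the return shape (scores are plain floats and the
# "__label__" prefix is stripped; actual labels and scores depend on the model):
#
#     model_predict("Hello world, how are you?", k=2)
#     # -> [{"label": "eng_Latn", "score": 0.97}, {"label": "sco_Latn", "score": 0.01}]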
def get_label(x):
    """Key function for grouping predictions by their language label."""
    return x.get("label")

def get_mean_score(preds):
    """Average the confidence scores of a group of predictions."""
    return mean([pred.get("score") for pred in preds])
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
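# Worked example: with 10 predictions total and the default 20% threshold,
# the cutoff is 10 * 0.2 = 2, so languages seen only once are dropped:
#
#     filter_by_frequency({"eng_Latn": 8, "fra_Latn": 1, "deu_Latn": 1})
#     # -> {"eng_Latn"}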
def simple_predict(text, num_predictions=3):
    """Simple language detection function for the Gradio interface"""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}
    try:
        # Clean the text
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}
        # Get predictions for each line
        all_predictions = []
        for line in cleaned_lines:
            predictions = model_predict(line, k=num_predictions)
            all_predictions.extend(predictions)
        if not all_predictions:
            return {"error": "No predictions could be made."}
        # Group predictions by language
        predictions_by_lang = groupby(get_label, all_predictions)
        language_counts = valmap(len, predictions_by_lang)
        # Calculate average scores for each language
        language_scores = valmap(get_mean_score, predictions_by_lang)
        # Format with human-readable language info
        formatted_languages = {}
        for fasttext_code, score in language_scores.items():
            lang_info = format_language_info(fasttext_code)
            formatted_languages[fasttext_code] = {
                "score": score,
                "language_info": lang_info,
            }
        # Check for multilingual segments
        segment_analysis = detect_language_segments(text)
        # Format results
        results = {
            "detected_languages": formatted_languages,
            "language_counts": dict(language_counts),
            "total_predictions": len(all_predictions),
            "text_lines_analyzed": len(cleaned_lines),
        }
        # Add segment analysis if multilingual
        if segment_analysis:
            results["segment_analysis"] = segment_analysis
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {str(e)}"}
def batch_predict(text, threshold_percent=0.2):
    """More advanced prediction with frequency filtering"""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}
    try:
        # Clean the text
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}
        # Get per-line predictions and flatten into a single list
        predictions = list(concat(model_predict(line) for line in cleaned_lines))
        if not predictions:
            return {"error": "No predictions could be made."}
        # Group by language and drop languages below the frequency threshold
        predictions_by_lang = groupby(get_label, predictions)
        language_counts = valmap(len, predictions_by_lang)
        keys_to_keep = filter_by_frequency(language_counts, threshold_percent=threshold_percent)
        filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
        # Format with human-readable language info
        formatted_predictions = {}
        for fasttext_code, score in valmap(get_mean_score, filtered_dict).items():
            lang_info = format_language_info(fasttext_code)
            formatted_predictions[fasttext_code] = {
                "score": score,
                "language_info": lang_info,
            }
        # Check for multilingual segments
        segment_analysis = detect_language_segments(text)
        results = {
            "predictions": formatted_predictions,
            "all_language_counts": dict(language_counts),
            "filtered_languages": list(keys_to_keep),
            "threshold_used": threshold_percent,
        }
        # Add segment analysis if multilingual
        if segment_analysis:
            results["segment_analysis"] = segment_analysis
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {str(e)}"}
def build_demo_interface():
    app_title = "Language Detection Tool"
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown("Enter text below to detect the language(s) it contains.")
        with gr.Tab("Simple Detection"):
            with gr.Row():
                with gr.Column():
                    text_input1 = gr.Textbox(
                        label="Enter text for language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    num_predictions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of top predictions per line",
                    )
                    predict_btn1 = gr.Button("Detect Language")
                with gr.Column():
                    output1 = gr.JSON(label="Detection Results")
            predict_btn1.click(
                simple_predict,
                inputs=[text_input1, num_predictions],
                outputs=output1,
            )
        with gr.Tab("Advanced Detection"):
            with gr.Row():
                with gr.Column():
                    text_input2 = gr.Textbox(
                        label="Enter text for advanced language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        label="Threshold percentage for filtering",
                    )
                    predict_btn2 = gr.Button("Advanced Detect")
                with gr.Column():
                    output2 = gr.JSON(label="Advanced Detection Results")
            predict_btn2.click(
                batch_predict,
                inputs=[text_input2, threshold],
                outputs=output2,
            )
        gr.Markdown("### About")
        gr.Markdown("This tool uses Facebook's FastText language identification model to detect languages in text.")
    return demo
if __name__ == "__main__":
    demo = build_demo_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )