# detect-language / app.py
import re
from pathlib import Path
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap
logger = logging.get_logger(__name__)
load_dotenv()
# Alternative LID model repo used by the load_model() helper below; the app
# itself loads facebook/fasttext-language-identification further down.
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
# Language code mapping - feel free to expand this
LANGUAGE_MAPPING = {
"spa_Latn": {"name": "Spanish", "iso_639_1": "es", "full_code": "es_ES"},
"eng_Latn": {"name": "English", "iso_639_1": "en", "full_code": "en_US"},
"fra_Latn": {"name": "French", "iso_639_1": "fr", "full_code": "fr_FR"},
"deu_Latn": {"name": "German", "iso_639_1": "de", "full_code": "de_DE"},
"ita_Latn": {"name": "Italian", "iso_639_1": "it", "full_code": "it_IT"},
"por_Latn": {"name": "Portuguese", "iso_639_1": "pt", "full_code": "pt_PT"},
"rus_Cyrl": {"name": "Russian", "iso_639_1": "ru", "full_code": "ru_RU"},
"zho_Hans": {"name": "Chinese (Simplified)", "iso_639_1": "zh", "full_code": "zh_CN"},
"zho_Hant": {"name": "Chinese (Traditional)", "iso_639_1": "zh", "full_code": "zh_TW"},
"jpn_Jpan": {"name": "Japanese", "iso_639_1": "ja", "full_code": "ja_JP"},
"kor_Hang": {"name": "Korean", "iso_639_1": "ko", "full_code": "ko_KR"},
"ara_Arab": {"name": "Arabic", "iso_639_1": "ar", "full_code": "ar_SA"},
"hin_Deva": {"name": "Hindi", "iso_639_1": "hi", "full_code": "hi_IN"},
"cat_Latn": {"name": "Catalan", "iso_639_1": "ca", "full_code": "ca_ES"},
"glg_Latn": {"name": "Galician", "iso_639_1": "gl", "full_code": "gl_ES"},
"nld_Latn": {"name": "Dutch", "iso_639_1": "nl", "full_code": "nl_NL"},
"swe_Latn": {"name": "Swedish", "iso_639_1": "sv", "full_code": "sv_SE"},
"nor_Latn": {"name": "Norwegian", "iso_639_1": "no", "full_code": "no_NO"},
"dan_Latn": {"name": "Danish", "iso_639_1": "da", "full_code": "da_DK"},
"fin_Latn": {"name": "Finnish", "iso_639_1": "fi", "full_code": "fi_FI"},
"pol_Latn": {"name": "Polish", "iso_639_1": "pl", "full_code": "pl_PL"},
"ces_Latn": {"name": "Czech", "iso_639_1": "cs", "full_code": "cs_CZ"},
"hun_Latn": {"name": "Hungarian", "iso_639_1": "hu", "full_code": "hu_HU"},
"tur_Latn": {"name": "Turkish", "iso_639_1": "tr", "full_code": "tr_TR"},
"heb_Hebr": {"name": "Hebrew", "iso_639_1": "he", "full_code": "he_IL"},
"tha_Thai": {"name": "Thai", "iso_639_1": "th", "full_code": "th_TH"},
"vie_Latn": {"name": "Vietnamese", "iso_639_1": "vi", "full_code": "vi_VN"},
"ukr_Cyrl": {"name": "Ukrainian", "iso_639_1": "uk", "full_code": "uk_UA"},
"ell_Grek": {"name": "Greek", "iso_639_1": "el", "full_code": "el_GR"},
"bul_Cyrl": {"name": "Bulgarian", "iso_639_1": "bg", "full_code": "bg_BG"},
"ron_Latn": {"name": "Romanian", "iso_639_1": "ro", "full_code": "ro_RO"},
"hrv_Latn": {"name": "Croatian", "iso_639_1": "hr", "full_code": "hr_HR"},
"srp_Cyrl": {"name": "Serbian", "iso_639_1": "sr", "full_code": "sr_RS"},
"slv_Latn": {"name": "Slovenian", "iso_639_1": "sl", "full_code": "sl_SI"},
"slk_Latn": {"name": "Slovak", "iso_639_1": "sk", "full_code": "sk_SK"},
"est_Latn": {"name": "Estonian", "iso_639_1": "et", "full_code": "et_EE"},
"lav_Latn": {"name": "Latvian", "iso_639_1": "lv", "full_code": "lv_LV"},
"lit_Latn": {"name": "Lithuanian", "iso_639_1": "lt", "full_code": "lt_LT"},
"msa_Latn": {"name": "Malay", "iso_639_1": "ms", "full_code": "ms_MY"},
"ind_Latn": {"name": "Indonesian", "iso_639_1": "id", "full_code": "id_ID"},
"tgl_Latn": {"name": "Filipino", "iso_639_1": "tl", "full_code": "tl_PH"},
}
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download a fastText model binary from the Hub and load it."""
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)
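
# Illustrative usage of the helper above (documentation only; the app itself
# loads a different model further down):
#   openlid_model = load_model(DEFAULT_FAST_TEXT_MODEL)
#   openlid_model.predict("Hola, ¿cómo estás?", k=1)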
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield non-empty lines from `rows`, skipping anything shorter than `min_length`."""
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty or too-short lines
            for line in row.split("\n"):
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            # Join list items into a single line before the length check
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue
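
# Example (illustrative):
#   list(yield_clean_rows(["first line\n\nok", ["a", "b", "c"]]))
#   -> ["first line", "a b c"]   ("ok" is dropped as shorter than min_length)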
FASTTEXT_PREFIX_LENGTH = len("__label__")  # fasttext labels look like "__label__eng_Latn"
def format_language_info(fasttext_code: str) -> dict:
    """Convert a fastText language code to a human-readable record."""
if fasttext_code in LANGUAGE_MAPPING:
lang_info = LANGUAGE_MAPPING[fasttext_code]
return {
"name": lang_info["name"],
"iso_code": lang_info["iso_639_1"],
"full_code": lang_info["full_code"],
"fasttext_code": fasttext_code
}
else:
# Graceful fallback for unmapped languages
return {
"name": fasttext_code,
"iso_code": "unknown",
"full_code": "unknown",
"fasttext_code": fasttext_code
}
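
# Example:
#   format_language_info("fra_Latn")
#   -> {"name": "French", "iso_code": "fr", "full_code": "fr_FR", "fasttext_code": "fra_Latn"}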
def detect_language_segments(text: str, confidence_threshold: float = 0.3):
    """Detect language changes across segments of `text`; return None unless multilingual."""
    # Split on sentence punctuation and on common separators such as "/" and "|"
    segments = re.split(r'[.!?;/|]\s+|\s+/\s+|\s+\|\s+', text.strip())
segments = [seg.strip() for seg in segments if seg.strip() and len(seg.strip()) > 10]
if len(segments) < 2:
return None
segment_results = []
for i, segment in enumerate(segments):
predictions = model_predict(segment, k=1)
if predictions and predictions[0]['score'] > confidence_threshold:
lang_info = format_language_info(predictions[0]['label'])
segment_results.append({
"segment_number": i + 1,
"text": segment,
"language": lang_info,
"confidence": predictions[0]['score']
})
# Check if we found different languages
languages_found = set(result['language']['fasttext_code'] for result in segment_results)
if len(languages_found) > 1:
return {
"is_multilingual": True,
"languages_detected": list(languages_found),
"segments": segment_results
}
return None
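
# Example (illustrative; confidences vary by model version):
#   detect_language_segments("The weather is lovely today. Il fait très beau aujourd'hui.")
#   -> {"is_multilingual": True,
#       "languages_detected": ["eng_Latn", "fra_Latn"],
#       "segments": [...]}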
# Load the model used by the app (facebook's LID model, not the OpenLID repo above)
Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,  # deprecated in recent huggingface_hub and ignored there
    )
)
def model_predict(inputs: str, k: int = 1) -> list[dict[str, Union[str, float]]]:
    # fasttext expects single-line input; strings containing "\n" raise a ValueError
predictions = model.predict(inputs, k=k)
return [
{"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
for label, prob in zip(predictions[0], predictions[1])
]
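
# Example (illustrative; exact scores depend on the model version):
#   model_predict("Guten Morgen, wie geht es dir?", k=2)
#   -> [{"label": "deu_Latn", "score": 0.99}, {"label": "ltz_Latn", "score": 0.003}]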
def get_label(x):
return x.get("label")
def get_mean_score(preds):
return mean([pred.get("score") for pred in preds])
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the set of keys whose count is at least `threshold_percent` of the total."""
total = sum(counts_dict.values())
threshold = total * threshold_percent
return {k for k, v in counts_dict.items() if v >= threshold}
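
# Example:
#   filter_by_frequency({"eng_Latn": 8, "fra_Latn": 1}, threshold_percent=0.2)
#   -> {"eng_Latn"}   (total is 9, so the cutoff is 1.8 and fra_Latn is dropped)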
def simple_predict(text, num_predictions=3):
"""Simple language detection function for Gradio interface"""
if not text or not text.strip():
return {"error": "Please enter some text for language detection."}
try:
# Clean the text
cleaned_lines = list(yield_clean_rows([text]))
if not cleaned_lines:
return {"error": "No valid text found after cleaning."}
# Get predictions for each line
all_predictions = []
for line in cleaned_lines:
            predictions = model_predict(line, k=int(num_predictions))  # Gradio sliders may pass floats
all_predictions.extend(predictions)
if not all_predictions:
return {"error": "No predictions could be made."}
# Group predictions by language
predictions_by_lang = groupby(get_label, all_predictions)
language_counts = valmap(len, predictions_by_lang)
# Calculate average scores for each language
language_scores = valmap(get_mean_score, predictions_by_lang)
        # Format with human-readable language info
formatted_languages = {}
for fasttext_code, score in language_scores.items():
lang_info = format_language_info(fasttext_code)
formatted_languages[fasttext_code] = {
"score": score,
"language_info": lang_info
}
# Check for multilingual segments
segment_analysis = detect_language_segments(text)
# Format results
results = {
"detected_languages": formatted_languages,
"language_counts": dict(language_counts),
"total_predictions": len(all_predictions),
"text_lines_analyzed": len(cleaned_lines)
}
# Add segment analysis if multilingual
if segment_analysis:
results["segment_analysis"] = segment_analysis
return results
except Exception as e:
return {"error": f"Error during prediction: {str(e)}"}
def batch_predict(text, threshold_percent=0.2):
    """Advanced detection: aggregate per-line predictions, then drop infrequent languages"""
if not text or not text.strip():
return {"error": "Please enter some text for language detection."}
try:
# Clean the text
cleaned_lines = list(yield_clean_rows([text]))
if not cleaned_lines:
return {"error": "No valid text found after cleaning."}
# Get predictions
predictions = [model_predict(line) for line in cleaned_lines]
predictions = [pred for pred in predictions if pred is not None]
predictions = list(concat(predictions))
if not predictions:
return {"error": "No predictions could be made."}
# Group and filter
predictions_by_lang = groupby(get_label, predictions)
language_counts = valmap(len, predictions_by_lang)
keys_to_keep = filter_by_frequency(language_counts, threshold_percent=threshold_percent)
filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
# Format with human-readable language info
formatted_predictions = {}
for fasttext_code, score in valmap(get_mean_score, filtered_dict).items():
lang_info = format_language_info(fasttext_code)
formatted_predictions[fasttext_code] = {
"score": score,
"language_info": lang_info
}
# Check for multilingual segments
segment_analysis = detect_language_segments(text)
results = {
"predictions": formatted_predictions,
"all_language_counts": dict(language_counts),
"filtered_languages": list(keys_to_keep),
"threshold_used": threshold_percent
}
# Add segment analysis if multilingual
if segment_analysis:
results["segment_analysis"] = segment_analysis
return results
except Exception as e:
return {"error": f"Error during prediction: {str(e)}"}
def build_demo_interface():
app_title = "Language Detection Tool"
with gr.Blocks(title=app_title) as demo:
gr.Markdown(f"# {app_title}")
gr.Markdown("Enter text below to detect the language(s) it contains.")
with gr.Tab("Simple Detection"):
with gr.Row():
with gr.Column():
text_input1 = gr.Textbox(
label="Enter text for language detection",
placeholder="Type or paste your text here...",
lines=5
)
num_predictions = gr.Slider(
minimum=1,
maximum=10,
value=3,
step=1,
label="Number of top predictions per line"
)
predict_btn1 = gr.Button("Detect Language")
with gr.Column():
output1 = gr.JSON(label="Detection Results")
predict_btn1.click(
simple_predict,
inputs=[text_input1, num_predictions],
outputs=output1
)
with gr.Tab("Advanced Detection"):
with gr.Row():
with gr.Column():
text_input2 = gr.Textbox(
label="Enter text for advanced language detection",
placeholder="Type or paste your text here...",
lines=5
)
threshold = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.2,
step=0.1,
label="Threshold percentage for filtering"
)
predict_btn2 = gr.Button("Advanced Detect")
with gr.Column():
output2 = gr.JSON(label="Advanced Detection Results")
predict_btn2.click(
batch_predict,
inputs=[text_input2, threshold],
outputs=output2
)
gr.Markdown("### About")
gr.Markdown("This tool uses Facebook's FastText language identification model to detect languages in text.")
return demo
if __name__ == "__main__":
demo = build_demo_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)