import re
from pathlib import Path
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

logger = logging.get_logger(__name__)

load_dotenv()

# Alternative language-identification model; can be loaded with load_model() below.
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

# Language code mapping - feel free to expand this
LANGUAGE_MAPPING = {
    "spa_Latn": {"name": "Spanish", "iso_639_1": "es", "full_code": "es_ES"},
    "eng_Latn": {"name": "English", "iso_639_1": "en", "full_code": "en_US"},
    "fra_Latn": {"name": "French", "iso_639_1": "fr", "full_code": "fr_FR"},
    "deu_Latn": {"name": "German", "iso_639_1": "de", "full_code": "de_DE"},
    "ita_Latn": {"name": "Italian", "iso_639_1": "it", "full_code": "it_IT"},
    "por_Latn": {"name": "Portuguese", "iso_639_1": "pt", "full_code": "pt_PT"},
    "rus_Cyrl": {"name": "Russian", "iso_639_1": "ru", "full_code": "ru_RU"},
    "zho_Hans": {"name": "Chinese (Simplified)", "iso_639_1": "zh", "full_code": "zh_CN"},
    "zho_Hant": {"name": "Chinese (Traditional)", "iso_639_1": "zh", "full_code": "zh_TW"},
    "jpn_Jpan": {"name": "Japanese", "iso_639_1": "ja", "full_code": "ja_JP"},
    "kor_Hang": {"name": "Korean", "iso_639_1": "ko", "full_code": "ko_KR"},
    "ara_Arab": {"name": "Arabic", "iso_639_1": "ar", "full_code": "ar_SA"},
    "hin_Deva": {"name": "Hindi", "iso_639_1": "hi", "full_code": "hi_IN"},
    "cat_Latn": {"name": "Catalan", "iso_639_1": "ca", "full_code": "ca_ES"},
    "glg_Latn": {"name": "Galician", "iso_639_1": "gl", "full_code": "gl_ES"},
    "nld_Latn": {"name": "Dutch", "iso_639_1": "nl", "full_code": "nl_NL"},
    "swe_Latn": {"name": "Swedish", "iso_639_1": "sv", "full_code": "sv_SE"},
    "nor_Latn": {"name": "Norwegian", "iso_639_1": "no", "full_code": "no_NO"},
    "dan_Latn": {"name": "Danish", "iso_639_1": "da", "full_code": "da_DK"},
    "fin_Latn": {"name": "Finnish", "iso_639_1": "fi", "full_code": "fi_FI"},
    "pol_Latn": {"name": "Polish", "iso_639_1": "pl", "full_code": "pl_PL"},
    "ces_Latn": {"name": "Czech", "iso_639_1": "cs", "full_code": "cs_CZ"},
    "hun_Latn": {"name": "Hungarian", "iso_639_1": "hu", "full_code": "hu_HU"},
    "tur_Latn": {"name": "Turkish", "iso_639_1": "tr", "full_code": "tr_TR"},
    "heb_Hebr": {"name": "Hebrew", "iso_639_1": "he", "full_code": "he_IL"},
    "tha_Thai": {"name": "Thai", "iso_639_1": "th", "full_code": "th_TH"},
    "vie_Latn": {"name": "Vietnamese", "iso_639_1": "vi", "full_code": "vi_VN"},
    "ukr_Cyrl": {"name": "Ukrainian", "iso_639_1": "uk", "full_code": "uk_UA"},
    "ell_Grek": {"name": "Greek", "iso_639_1": "el", "full_code": "el_GR"},
    "bul_Cyrl": {"name": "Bulgarian", "iso_639_1": "bg", "full_code": "bg_BG"},
    "ron_Latn": {"name": "Romanian", "iso_639_1": "ro", "full_code": "ro_RO"},
    "hrv_Latn": {"name": "Croatian", "iso_639_1": "hr", "full_code": "hr_HR"},
    "srp_Cyrl": {"name": "Serbian", "iso_639_1": "sr", "full_code": "sr_RS"},
    "slv_Latn": {"name": "Slovenian", "iso_639_1": "sl", "full_code": "sl_SI"},
    "slk_Latn": {"name": "Slovak", "iso_639_1": "sk", "full_code": "sk_SK"},
    "est_Latn": {"name": "Estonian", "iso_639_1": "et", "full_code": "et_EE"},
    "lav_Latn": {"name": "Latvian", "iso_639_1": "lv", "full_code": "lv_LV"},
    "lit_Latn": {"name": "Lithuanian", "iso_639_1": "lt", "full_code": "lt_LT"},
    "msa_Latn": {"name": "Malay", "iso_639_1": "ms", "full_code": "ms_MY"},
    "ind_Latn": {"name": "Indonesian", "iso_639_1": "id", "full_code": "id_ID"},
    "tgl_Latn": {"name": "Filipino", "iso_639_1": "tl", "full_code": "tl_PH"},
}
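# Illustrative lookups against the table above (these two entries exist in the mapping):
#   LANGUAGE_MAPPING["spa_Latn"]["name"]      -> "Spanish"
#   LANGUAGE_MAPPING["spa_Latn"]["full_code"] -> "es_ES"
# Codes absent from the mapping are handled gracefully by format_language_info() below.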
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download a fastText model from the Hugging Face Hub and load it."""
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


def yield_clean_rows(rows: list[Union[str, list[str]]], min_length: int = 3) -> Iterator[str]:
    """Yield non-empty lines that are at least `min_length` characters long."""
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty or too-short lines.
            for line in row.split("\n"):
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                # Row contained non-string items; skip it.
                continue


FASTTEXT_PREFIX_LENGTH = len("__label__")  # fasttext labels look like "__label__eng_Latn"


def format_language_info(fasttext_code: str) -> dict:
    """Convert a fastText language code into a human-readable record."""
    if fasttext_code in LANGUAGE_MAPPING:
        lang_info = LANGUAGE_MAPPING[fasttext_code]
        return {
            "name": lang_info["name"],
            "iso_code": lang_info["iso_639_1"],
            "full_code": lang_info["full_code"],
            "fasttext_code": fasttext_code,
        }
    # Graceful fallback for unmapped languages.
    return {
        "name": fasttext_code,
        "iso_code": "unknown",
        "full_code": "unknown",
        "fasttext_code": fasttext_code,
    }


def detect_language_segments(text: str, confidence_threshold: float = 0.3):
    """Detect language changes across segments of the input text."""
    # Split on sentence-final punctuation and common separators (/ and |).
    segments = re.split(r"[.!?;/|]\s+|\s+/\s+|\s+\|\s+", text.strip())
    segments = [seg.strip() for seg in segments if seg.strip() and len(seg.strip()) > 10]

    if len(segments) < 2:
        return None

    segment_results = []
    for i, segment in enumerate(segments):
        predictions = model_predict(segment, k=1)
        if predictions and predictions[0]["score"] > confidence_threshold:
            lang_info = format_language_info(predictions[0]["label"])
            segment_results.append(
                {
                    "segment_number": i + 1,
                    "text": segment,
                    "language": lang_info,
                    "confidence": predictions[0]["score"],
                }
            )

    # Report only if more than one language was found.
    languages_found = {result["language"]["fasttext_code"] for result in segment_results}
    if len(languages_found) > 1:
        return {
            "is_multilingual": True,
            "languages_detected": list(languages_found),
            "segments": segment_results,
        }
    return None


# Load the model once at import time.
Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,  # deprecated in newer huggingface_hub releases, which may ignore it
    )
)


def model_predict(inputs: str, k: int = 1) -> list[dict]:
    """Return the top-k predictions as dicts, with the __label__ prefix stripped."""
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean(pred.get("score") for pred in preds)


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2) -> set:
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
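# A minimal sketch of the helpers above. It is never called by the app; the scores
# in the comment are hypothetical examples of the output shape, not real results.
def _example_usage() -> None:
    preds = model_predict("Hello, how are you today?", k=2)
    # preds is a list like [{"label": "eng_Latn", "score": 0.97}, ...] (scores will vary)
    for pred in preds:
        info = format_language_info(pred["label"])
        print(f"{info['name']} ({info['full_code']}): {pred['score']:.2f}")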
def simple_predict(text: str, num_predictions: int = 3):
    """Simple language detection function for the Gradio interface."""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}

    try:
        # Clean the text.
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}

        # Get predictions for each line.
        all_predictions = []
        for line in cleaned_lines:
            all_predictions.extend(model_predict(line, k=num_predictions))

        if not all_predictions:
            return {"error": "No predictions could be made."}

        # Group predictions by language and count them.
        predictions_by_lang = groupby(get_label, all_predictions)
        language_counts = valmap(len, predictions_by_lang)

        # Average the scores for each language.
        language_scores = valmap(get_mean_score, predictions_by_lang)

        # Format with human-readable language info.
        formatted_languages = {}
        for fasttext_code, score in language_scores.items():
            formatted_languages[fasttext_code] = {
                "score": score,
                "language_info": format_language_info(fasttext_code),
            }

        # Check for multilingual segments.
        segment_analysis = detect_language_segments(text)

        results = {
            "detected_languages": formatted_languages,
            "language_counts": dict(language_counts),
            "total_predictions": len(all_predictions),
            "text_lines_analyzed": len(cleaned_lines),
        }
        if segment_analysis:
            results["segment_analysis"] = segment_analysis
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {e}"}


def batch_predict(text: str, threshold_percent: float = 0.2):
    """More advanced prediction that filters out infrequently predicted languages."""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}

    try:
        # Clean the text.
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}

        # One top-1 prediction per line, flattened into a single list.
        predictions = list(concat(model_predict(line) for line in cleaned_lines))
        if not predictions:
            return {"error": "No predictions could be made."}

        # Group by language and keep only languages above the frequency threshold.
        predictions_by_lang = groupby(get_label, predictions)
        language_counts = valmap(len, predictions_by_lang)
        keys_to_keep = filter_by_frequency(language_counts, threshold_percent=threshold_percent)
        filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}

        # Format with human-readable language info.
        formatted_predictions = {}
        for fasttext_code, score in valmap(get_mean_score, filtered_dict).items():
            formatted_predictions[fasttext_code] = {
                "score": score,
                "language_info": format_language_info(fasttext_code),
            }

        # Check for multilingual segments.
        segment_analysis = detect_language_segments(text)

        results = {
            "predictions": formatted_predictions,
            "all_language_counts": dict(language_counts),
            "filtered_languages": list(keys_to_keep),
            "threshold_used": threshold_percent,
        }
        if segment_analysis:
            results["segment_analysis"] = segment_analysis
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {e}"}
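# Hypothetical comparison of the two entry points (top-level keys match the dicts built above):
#   simple_predict("Bonjour le monde")          -> {"detected_languages": {...}, "language_counts": {...}, ...}
#   batch_predict("Hello there.\nHola amigos.") -> {"predictions": {...}, "filtered_languages": [...], ...}
# batch_predict additionally drops languages that account for less than
# `threshold_percent` of all per-line predictions.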
def build_demo_interface():
    app_title = "Language Detection Tool"
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown("Enter text below to detect the language(s) it contains.")

        with gr.Tab("Simple Detection"):
            with gr.Row():
                with gr.Column():
                    text_input1 = gr.Textbox(
                        label="Enter text for language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    num_predictions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of top predictions per line",
                    )
                    predict_btn1 = gr.Button("Detect Language")
                with gr.Column():
                    output1 = gr.JSON(label="Detection Results")

            predict_btn1.click(
                simple_predict,
                inputs=[text_input1, num_predictions],
                outputs=output1,
            )

        with gr.Tab("Advanced Detection"):
            with gr.Row():
                with gr.Column():
                    text_input2 = gr.Textbox(
                        label="Enter text for advanced language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        label="Threshold percentage for filtering",
                    )
                    predict_btn2 = gr.Button("Advanced Detect")
                with gr.Column():
                    output2 = gr.JSON(label="Advanced Detection Results")

            predict_btn2.click(
                batch_predict,
                inputs=[text_input2, threshold],
                outputs=output2,
            )

        gr.Markdown("### About")
        gr.Markdown(
            "This tool uses Facebook's FastText language identification model "
            "to detect languages in text."
        )

    return demo


if __name__ == "__main__":
    demo = build_demo_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
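# Usage sketch (assumes this file is saved as app.py):
#   python app.py
# then open http://localhost:7860 in a browser. server_name="0.0.0.0" binds all
# interfaces, which suits container deployments such as a Hugging Face Space.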