""" | |
FastText embedding module for Tibetan text. | |
This module provides functions to train and use FastText models for Tibetan text. | |
""" | |
import os | |
from pathlib import Path | |
import math | |
import logging | |
import numpy as np | |
import fasttext | |
from collections import Counter | |
from typing import List, Optional, Tuple, Any, Set | |
from huggingface_hub import hf_hub_download | |
# Set up logging | |
logger = logging.getLogger(__name__) | |
logger.setLevel(logging.DEBUG) # Ensure this logger processes DEBUG messages | |
# Default parameters optimized for Tibetan
DEFAULT_DIM = 100
DEFAULT_EPOCH = 5
DEFAULT_MIN_COUNT = 5
DEFAULT_WINDOW = 5
DEFAULT_MINN = 3
DEFAULT_MAXN = 6
DEFAULT_NEG = 5

# Model version information
MODEL_VERSIONS = {
    "facebook-fasttext-pretrained": "v1.0",
}

# Define paths for model storage
DEFAULT_MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_MODEL_PATH = str(
    Path(__file__).resolve().parent.parent / "fasttext-modelling" / "tibetan_cbow_model.bin"
)  # Updated to custom model

# Facebook's official Tibetan FastText model
FACEBOOK_TIBETAN_MODEL_ID = "facebook/fasttext-bo-vectors"
FACEBOOK_TIBETAN_MODEL_FILE = "model.bin"

# Create the models directory if it doesn't exist
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)


def ensure_dir_exists(directory: str) -> None:
    """
    Ensure that a directory exists, creating it if necessary.

    Args:
        directory: Directory path to ensure exists
    """
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)


def train_fasttext_model(
    corpus_path: str,
    model_path: str = DEFAULT_MODEL_PATH,
    dim: int = DEFAULT_DIM,
    epoch: int = DEFAULT_EPOCH,
    min_count: int = DEFAULT_MIN_COUNT,
    window: int = DEFAULT_WINDOW,
    minn: int = DEFAULT_MINN,
    maxn: int = DEFAULT_MAXN,
    neg: int = DEFAULT_NEG,
    model_type: str = "skipgram"
) -> fasttext.FastText._FastText:
    """
    Train a FastText model on a Tibetan corpus using optimized parameters.

    Args:
        corpus_path: Path to the corpus file
        model_path: Path where the trained model is saved
        dim: Embedding dimension (default: 100)
        epoch: Number of training epochs (default: 5)
        min_count: Minimum count of words (default: 5)
        window: Size of the context window (default: 5)
        minn: Minimum length of char n-gram (default: 3)
        maxn: Maximum length of char n-gram (default: 6)
        neg: Number of negatives in negative sampling (default: 5)
        model_type: FastText model type ('skipgram' or 'cbow')

    Returns:
        Trained FastText model
    """
    output_dir = os.path.dirname(model_path)
    if output_dir:
        ensure_dir_exists(output_dir)
    logger.info("Training FastText model with %s, dim=%d, epoch=%d, window=%d, minn=%d, maxn=%d...",
                model_type, dim, epoch, window, minn, maxn)

    processed_corpus_path = corpus_path + ".processed"
    corpus_to_train = corpus_path
    model = None
    try:
        # Preprocess the corpus to a temporary file: insert a space after each tsheg (་)
        # so that FastText treats Tibetan syllables as tokens.
        with open(corpus_path, 'r', encoding='utf-8') as f_in, open(processed_corpus_path, 'w', encoding='utf-8') as f_out:
            content = f_in.read()
            processed_content = content.replace('་', '་ ')
            f_out.write(processed_content)
        logger.info("Corpus preprocessed to a temporary file for Tibetan syllable segmentation.")
        corpus_to_train = processed_corpus_path

        # Train the model with the requested architecture
        if model_type == "skipgram":
            model = fasttext.train_unsupervised(
                corpus_to_train,
                model="skipgram",
                dim=dim, epoch=epoch, minCount=min_count, wordNgrams=1,
                minn=minn, maxn=maxn, neg=neg, window=window
            )
        else:  # cbow
            model = fasttext.train_unsupervised(
                corpus_to_train,
                model="cbow",
                dim=dim, epoch=epoch, minCount=min_count, wordNgrams=1,
                minn=minn, maxn=maxn, neg=neg, window=window
            )

        model.save_model(model_path)
        logger.info("FastText model trained and saved to %s", model_path)
    except Exception as e:
        logger.error(f"An error occurred during model training: {e}", exc_info=True)
        # Re-raise the exception after logging and cleanup
        raise
    finally:
        # Clean up the temporary processed file
        if os.path.exists(processed_corpus_path):
            os.remove(processed_corpus_path)
            logger.info(f"Cleaned up temporary file: {processed_corpus_path}")
    return model
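

# Example usage (a minimal sketch; the corpus path is hypothetical and the corpus
# is assumed to be plain UTF-8 Tibetan text):
#
#     model = train_fasttext_model("data/tibetan_corpus.txt", model_type="cbow")
#     print(model.get_dimension())  # 100 with the default settings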


def _load_facebook_official_tibetan_model() -> Optional[fasttext.FastText._FastText]:
    """
    Download (if necessary) and load the official Facebook FastText Tibetan model.

    Returns:
        Loaded FastText model or None if loading fails.
    """
    try:
        logger.info("Attempting to download and load official Facebook FastText Tibetan model")
        facebook_model_path = hf_hub_download(
            repo_id=FACEBOOK_TIBETAN_MODEL_ID,
            filename=FACEBOOK_TIBETAN_MODEL_FILE,
            cache_dir=DEFAULT_MODEL_DIR
        )
        logger.info(f"Loading official Facebook FastText Tibetan model from {facebook_model_path}")
        model = fasttext.load_model(facebook_model_path)
        if model:
            logger.info(f"FastText model loaded in _load_facebook_official_tibetan_model. Type: {type(model)}")
            try:
                # Basic check: get model dimensions
                dims = model.get_dimension()
                logger.info(f"Model dimensions reported by fasttext_embedding: {dims}")
                # Check that get_word_vector is callable with a string,
                # using a common Tibetan particle that should be in the vocabulary.
                test_word = "ལ་"
                try:
                    vec = model.get_word_vector(test_word)
                    logger.info(f"Successfully retrieved vector for test word '{test_word}'. Vector shape: {vec.shape if vec is not None else 'None'}")
                except Exception as e_gwv:
                    logger.error(f"Error calling get_word_vector for test word '{test_word}' in fasttext_embedding: {e_gwv}", exc_info=True)
                    # Potentially re-raise or handle if this is critical for model validity
            except Exception as e_diag_load:
                logger.error(f"Error during diagnostic checks of loaded FastText model in fasttext_embedding: {e_diag_load}", exc_info=True)
                # If diagnostics fail, the model might be unusable; for now, return the
                # model anyway and let downstream calls surface the failure.
        else:
            logger.error("fasttext.load_model returned None in _load_facebook_official_tibetan_model.")
        return model
    except Exception as e_fb:
        logger.error(f"Could not load official Facebook FastText Tibetan model: {str(e_fb)}", exc_info=True)
        return None


def load_fasttext_model(model_path: str = DEFAULT_MODEL_PATH) -> Optional[fasttext.FastText._FastText]:
    """
    Load a custom FastText model from the specified file path.

    Args:
        model_path: Path to the custom model file.

    Returns:
        Loaded FastText model or None if loading fails.
    """
    try:
        if os.path.exists(model_path):
            logger.info(f"Attempting to load custom FastText model from {model_path}")
            return fasttext.load_model(model_path)
        else:
            logger.error(f"Custom FastText model path {model_path} does not exist.")
            return None
    except Exception as e:
        logger.error(f"Could not load custom FastText model from {model_path}: {str(e)}")
        return None


def _remove_stopwords_from_tokens(tokens: List[str], stopwords_set: Set[str]) -> List[str]:
    """
    Remove stopwords from a list of tokens using a list comprehension for efficiency.

    Handles Tibetan punctuation by checking both the token itself and the token
    after stripping a trailing '།' or '༔'.
    """
    if not stopwords_set:
        return tokens
    return [
        token for token in tokens
        if token not in stopwords_set and token.rstrip('།༔') not in stopwords_set
    ]
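
# Illustrative example (hypothetical tokens and stopword set): with stopwords_set
# = {"ནི"}, the token "ནི།" is also dropped because the trailing shad is stripped
# before the lookup:
#
#     _remove_stopwords_from_tokens(["བདག་", "ནི།", "སངས་རྒྱས་"], {"ནི"})
#     # -> ["བདག་", "སངས་རྒྱས་"]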


def get_text_embedding(
    text: str,
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,
    corpus_token_freq=None,  # Accepted for API compatibility; TF is computed locally and IDF uses doc_freq_map
    doc_freq_map=None,  # Document frequency map for IDF
    total_docs_in_corpus=0  # Total documents in corpus for IDF
) -> np.ndarray:
    """
    Get the embedding for a text using a FastText model, with optional TF-IDF weighting.

    Args:
        text: Input text
        model: FastText model
        tokenize_fn: Optional tokenization function or pre-tokenized list of tokens
        use_stopwords: Whether to filter out stopwords before computing embeddings
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True)
        use_tfidf_weighting: Whether to use TF-IDF weighting when averaging word vectors
        corpus_token_freq: Dictionary of token frequencies across the corpus (currently unused here)
        doc_freq_map: Document frequency map for IDF (required for TF-IDF)
        total_docs_in_corpus: Total number of documents in the corpus (required for TF-IDF)

    Returns:
        Text embedding vector
    """
    if not text.strip():
        return np.zeros(model.get_dimension())

    # Handle tokenization
    if callable(tokenize_fn):
        tokens = tokenize_fn(text)
        logger.debug(f"Tokens from callable tokenize_fn (first 20): {tokens[:20]}")
    elif isinstance(tokenize_fn, list):
        tokens = tokenize_fn  # Use the provided list directly
        logger.debug(f"Tokens provided as list (first 20): {tokens[:20]}")
    else:
        if tokenize_fn is not None:
            logger.warning(f"tokenize_fn is of unexpected type: {type(tokenize_fn)}. Defaulting to space-split.")
        else:
            # This case handles tokenize_fn being explicitly None
            logger.debug("tokenize_fn is None. Defaulting to space-split.")
        tokens = text.split()
        logger.debug(f"Tokens from space-split fallback (first 20): {tokens[:20]}")

    if use_stopwords and stopwords_set:
        logger.debug(f"Original tokens before stopword check (first 20): {tokens[:20]}")
        original_token_count = len(tokens)
        tokens = _remove_stopwords_from_tokens(tokens, stopwords_set)
        removed_count = original_token_count - len(tokens)
        logger.debug(f"Tokens after stopword removal (removed {removed_count}): {tokens[:20]}")

    if not tokens:
        logger.debug("Text became empty after tokenization/stopword removal, returning zero vector.")
        return np.zeros(model.get_dimension())

    if use_tfidf_weighting and doc_freq_map and total_docs_in_corpus is not None and total_docs_in_corpus > 0:
        logger.debug("Applying TF-IDF weighting.")
        N_docs = total_docs_in_corpus
        logger.debug(f"Total documents (N_docs) for IDF: {N_docs}")
        token_counts = Counter(tokens)
        logger.debug(f"Local token counts for this segment (top 5): {dict(token_counts.most_common(5))}")

        tf_idf_weights = []
        token_details_log = []
        for token in tokens:  # Iterate in original token order
            tf = token_counts.get(token, 0) / len(tokens) if len(tokens) > 0 else 0
            df = doc_freq_map.get(token, 0)
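            # Smoothed IDF: log((N + 1) / (df + 1)) + 1 keeps the weight finite and
            # positive even for tokens that appear in no document of the corpus (df = 0).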
            idf = math.log((N_docs + 1) / (df + 1)) + 1
            weight = tf * idf
            tf_idf_weights.append(weight)
            token_details_log.append(f"Token: '{token}', TF: {tf:.4f}, DF: {df}, IDF: {idf:.4f}, Raw_TFIDF: {weight:.4f}")

        logger.debug("Token TF-IDF details (first 10 tokens):")
        for i, log_entry in enumerate(token_details_log[:10]):
            logger.debug(f"  {i+1}. {log_entry}")

        total_weight = sum(tf_idf_weights)
        logger.debug(f"Sum of raw TF-IDF weights: {total_weight}")
        logger.debug(f"TF-IDF summary for text snippet (first 100 chars): '{text[:100]}'. Total_TFIDF_Weight: {total_weight:.8e}. Fallback_to_Uniform: {total_weight <= 1e-6}.")

        normalized_weights = []
        if total_weight > 1e-6:
            normalized_weights = [w / total_weight for w in tf_idf_weights]
            logger.debug(f"Normalized weights (first 10): {[f'{w:.4f}' for w in normalized_weights[:10]]}")
        else:
            logger.debug("Total TF-IDF weight is very small, falling back to uniform weights.")
            num_tokens = len(tokens)
            if num_tokens > 0:
                normalized_weights = [1 / num_tokens] * num_tokens
            logger.debug(f"Uniform weights (first 10): {[f'{w:.4f}' for w in normalized_weights[:10]]}")

        weighted_embeddings_sum = np.zeros(model.get_dimension())
        if len(normalized_weights) == len(tokens):
            for i, token in enumerate(tokens):
                word_vector = model.get_word_vector(token)
                vec_sum_for_log = np.sum(word_vector)
                logger.debug(f"  Token: '{token}', Word_Vec_Sum: {vec_sum_for_log:.4f}, Applied_Weight: {normalized_weights[i]:.4f}")
                weighted_embeddings_sum += word_vector * normalized_weights[i]
            final_embedding = weighted_embeddings_sum
        else:
            logger.error("Mismatch between token count and normalized_weights count (this indicates a bug); falling back to simple averaging.")
            embeddings = [model.get_word_vector(t) for t in tokens]
            if embeddings:
                final_embedding = np.mean(embeddings, axis=0)
            else:
                final_embedding = np.zeros(model.get_dimension())
    else:
        if use_tfidf_weighting:
            logger.debug("TF-IDF weighting was requested but doc_freq_map or total_docs_in_corpus is missing/invalid. Falling back to simple averaging.")
        else:
            logger.debug("Using simple averaging of word vectors (TF-IDF not requested).")
        embeddings = []
        for token in tokens:
            word_vector = model.get_word_vector(token)
            embeddings.append(word_vector)
            vec_sum_for_log = np.sum(word_vector)
            logger.debug(f"  Token: '{token}', Word_Vec_Sum: {vec_sum_for_log:.4f} (simple avg context)")
        if embeddings:
            final_embedding = np.mean(embeddings, axis=0)
        else:
            final_embedding = np.zeros(model.get_dimension())

    final_emb_sum_for_log = np.sum(final_embedding)
    logger.debug(f"Final aggregated embedding sum: {final_emb_sum_for_log:.4f}, shape: {final_embedding.shape}")
    logger.debug(f"--- get_text_embedding finished for text (first 50 chars): {text[:50]} ---")
    return final_embedding
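
# A minimal sketch of calling get_text_embedding directly (the `model` value,
# document-frequency map, and document count are assumed to come from the
# caller's own corpus preprocessing; the figures below are hypothetical):
#
#     vec = get_text_embedding(
#         "བདག་ ནི་ སངས་རྒྱས་",          # pre-segmented: syllables separated by spaces
#         model,
#         tokenize_fn=None,               # None falls back to whitespace splitting
#         use_tfidf_weighting=True,
#         doc_freq_map={"སངས་རྒྱས་": 3},   # hypothetical document frequencies
#         total_docs_in_corpus=10,
#     )
#     # vec is a numpy array of shape (model.get_dimension(),)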


def get_batch_embeddings(
    texts: List[str],
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,
    corpus_token_freq=None,  # Corpus-wide term frequencies
    doc_freq_map=None,  # Document frequency map for IDF
    total_docs_in_corpus=0  # Total documents in corpus for IDF
) -> np.ndarray:
    """
    Get embeddings for a batch of texts with optional TF-IDF weighting.

    Args:
        texts: List of input texts
        model: FastText model
        tokenize_fn: Optional tokenization function, or a list of pre-tokenized documents
            (one token list per text in `texts`)
        use_stopwords: Whether to filter out stopwords before computing embeddings
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True)
        use_tfidf_weighting: Whether to use TF-IDF weighting when averaging word vectors
        corpus_token_freq: Dictionary of token frequencies across the corpus
        doc_freq_map: Document frequency map for IDF (required for TF-IDF)
        total_docs_in_corpus: Total number of documents in the corpus (required for TF-IDF)

    Returns:
        Array of text embedding vectors
    """
    # Get embeddings for each text
    embeddings = []
    for i, text_content in enumerate(texts):
        tokens_or_tokenizer_for_current_text = None
        if callable(tokenize_fn):
            tokens_or_tokenizer_for_current_text = tokenize_fn  # Pass the function itself
        elif isinstance(tokenize_fn, list):
            # If tokenize_fn is a list, it is assumed to be a list of pre-tokenized documents
            if i < len(tokenize_fn):
                tokens_or_tokenizer_for_current_text = tokenize_fn[i]  # List[str] for the current text
            else:
                logger.warning(f"Pre-tokenized list `tokenize_fn` is shorter than the list of texts. Index {i} is out of bounds for `tokenize_fn` with length {len(tokenize_fn)}. Defaulting to None for this text.")
        # If tokenize_fn is None or another type, the value stays None and
        # get_text_embedding falls back to its default space-splitting.
        try:
            embedding = get_text_embedding(
                text_content,
                model,
                tokenize_fn=tokens_or_tokenizer_for_current_text,
                use_stopwords=use_stopwords,
                stopwords_set=stopwords_set,
                use_tfidf_weighting=use_tfidf_weighting,
                corpus_token_freq=corpus_token_freq,
                doc_freq_map=doc_freq_map,
                total_docs_in_corpus=total_docs_in_corpus
            )
            embeddings.append(embedding)
        except Exception as e:
            logger.error(f"Error generating FastText embeddings in fasttext_embedding.py: {e}", exc_info=True)
            # Append a zero vector so the output shape still matches the input texts
            embeddings.append(np.zeros(model.get_dimension()))
    return np.array(embeddings)


# Cache for loaded FastText models
_fasttext_model_cache = {}


def get_model(model_id: str) -> Tuple[Optional[Any], Optional[str]]:
    """
    Load a FastText model with version tracking.

    Args:
        model_id (str): The identifier for the model to load.

    Returns:
        Tuple[Optional[Any], Optional[str]]: A tuple of the loaded model and its type ('fasttext'),
        or (None, None) if loading fails.
    """
    # Include version information in the cache key
    model_version = MODEL_VERSIONS.get(model_id, "unknown")
    cache_key = f"{model_id}@{model_version}"
    if cache_key in _fasttext_model_cache:
        logger.info(f"Returning cached FastText model: {model_id} (version: {model_version})")
        return _fasttext_model_cache[cache_key], "fasttext"

    logger.info(f"Attempting to load FastText model: {model_id} (version: {model_version})")
    if model_id == "facebook-fasttext-pretrained":
        try:
            model = _load_facebook_official_tibetan_model()
            if model:
                _fasttext_model_cache[cache_key] = model
                logger.info(f"FastText model '{model_id}' (version: {model_version}) loaded successfully.")
                return model, "fasttext"
            else:
                logger.error(f"Model loading for '{model_id}' returned None.")
                return None, None
        except Exception as e:
            logger.error(f"Failed to load FastText model '{model_id}': {e}", exc_info=True)
            return None, None
    # Add logic for other custom models here if needed
    # elif model_id == "custom-model-name":
    #     ...
    else:
        logger.warning(f"Unsupported FastText model ID: {model_id}")
        return None, None
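
# Typical usage: the first successful load is cached, so subsequent calls with the
# same model_id return the cached instance.
#
#     model, model_type = get_model("facebook-fasttext-pretrained")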


def generate_embeddings(
    texts: List[str],
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False,
    corpus_token_freq=None,  # Corpus-wide term frequencies (for TF-IDF)
    doc_freq_map=None,  # Document frequencies per token (for TF-IDF)
    total_docs_in_corpus=0  # Total documents in the corpus (for TF-IDF)
) -> np.ndarray:
    """
    Generate embeddings for a list of texts using a FastText model.

    Args:
        texts: List of input texts
        model: FastText model
        tokenize_fn: Optional tokenization function or pre-tokenized list of tokens
        use_stopwords: Whether to filter out stopwords
        use_lite_stopwords: Whether to use a lighter set of stopwords
        corpus_token_freq: Precomputed term frequencies for the corpus (for TF-IDF).
        doc_freq_map: Precomputed document frequencies for tokens (for TF-IDF).
        total_docs_in_corpus: Total number of documents in the corpus (for TF-IDF).

    Returns:
        Array of text embedding vectors
    """
    # Generate embeddings using FastText
    try:
        # Load stopwords if needed
        stopwords_set = None
        if use_stopwords:
            from .tibetan_stopwords import get_stopwords
            stopwords_set = get_stopwords(use_lite=use_lite_stopwords)
            logger.info("Loaded %d Tibetan stopwords", len(stopwords_set))

        # Generate embeddings
        embeddings = get_batch_embeddings(
            texts,
            model,
            tokenize_fn=tokenize_fn,
            use_stopwords=use_stopwords,
            stopwords_set=stopwords_set,
            use_tfidf_weighting=True,  # TF-IDF weighting enabled
            corpus_token_freq=corpus_token_freq,
            doc_freq_map=doc_freq_map,
            total_docs_in_corpus=total_docs_in_corpus
        )
        logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
        return embeddings
    except Exception as e:
        logger.error("Error generating FastText embeddings: %s", str(e))
        # Return zero embeddings as a fallback
        return np.zeros((len(texts), model.get_dimension()))
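

if __name__ == "__main__":
    # A minimal end-to-end sketch, only run when this file is executed directly.
    # The Tibetan sample texts are placeholders; in practice doc_freq_map and
    # total_docs_in_corpus would be computed over the caller's segmented corpus
    # so that TF-IDF weighting can take effect.
    logging.basicConfig(level=logging.INFO)
    demo_model, demo_model_type = get_model("facebook-fasttext-pretrained")
    if demo_model is not None:
        sample_texts = ["བདག་གིས་ བསྒྲུབས་ པ།", "སངས་རྒྱས་ ཀྱི་ བསྟན་ པ།"]
        sample_embeddings = generate_embeddings(sample_texts, demo_model, use_stopwords=False)
        print(sample_embeddings.shape)  # (2, demo_model.get_dimension())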