ttm-webapp-hf / pipeline /fasttext_embedding.py
daniel-wojahn's picture
maintenance and alignment prototype
bda2b5b
"""
FastText embedding module for Tibetan text.
This module provides functions to train and use FastText models for Tibetan text.
"""
import logging
import math
import os
import tempfile
from collections import Counter
from pathlib import Path
from typing import List, Optional, Tuple, Any, Set

import fasttext
import numpy as np
from huggingface_hub import hf_hub_download
# Set up logging
logger = logging.getLogger(__name__)
# NOTE(review): forcing DEBUG on a library logger makes this module emit DEBUG
# records regardless of the root logger's level; consider leaving the level to
# the application.
logger.setLevel(logging.DEBUG) # Ensure this logger processes DEBUG messages
# Default training hyperparameters used by train_fasttext_model (optimized for Tibetan)
DEFAULT_DIM = 100  # embedding dimension
DEFAULT_EPOCH = 5  # number of training epochs
DEFAULT_MIN_COUNT = 5  # minimum word count to enter the vocabulary
DEFAULT_WINDOW = 5  # context window size
DEFAULT_MINN = 3  # minimum character n-gram length
DEFAULT_MAXN = 6  # maximum character n-gram length
DEFAULT_NEG = 5  # number of negatives in negative sampling
# Model version information. get_model builds its cache key as
# "<model_id>@<version>", so cached models are kept separate per version.
MODEL_VERSIONS = {
    "facebook-fasttext-pretrained": "v1.0",
}
# Define paths for model storage
# "models" directory next to this file; used as the Hugging Face download cache
# for the official pretrained model.
DEFAULT_MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
# Default custom model location: <parent of this file's directory>/fasttext-modelling/tibetan_cbow_model.bin.
# Default model_path for train_fasttext_model and load_fasttext_model.
# NOTE(review): the filename suggests a CBOW model, but train_fasttext_model
# defaults to model_type="skipgram" — confirm which is intended.
DEFAULT_MODEL_PATH = str(Path(__file__).resolve().parent.parent / "fasttext-modelling" / "tibetan_cbow_model.bin")
# Facebook's official Tibetan FastText model on the Hugging Face Hub
FACEBOOK_TIBETAN_MODEL_ID = "facebook/fasttext-bo-vectors"
FACEBOOK_TIBETAN_MODEL_FILE = "model.bin"
# Create models directory if it doesn't exist (runs once, at import time)
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
def ensure_dir_exists(directory: str) -> None:
    """
    Ensure that a directory exists, creating it (and any parents) if necessary.

    An empty string is treated as "the current working directory" and is a
    no-op. This is what ``os.path.dirname`` returns for a bare filename such
    as ``"model.bin"``, and ``os.makedirs("")`` would raise FileNotFoundError.

    Args:
        directory: Directory path to ensure exists.
    """
    if not directory:
        return
    # exist_ok=True already covers the "already exists" case, so no separate
    # (race-prone) os.path.exists check is needed.
    os.makedirs(directory, exist_ok=True)
def _write_syllable_segmented_corpus(src_path: str, dst_path: str) -> None:
    """Copy ``src_path`` to ``dst_path``, adding a space after every tsheg ('་'), streaming line by line."""
    with open(src_path, 'r', encoding='utf-8') as f_in, open(dst_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            f_out.write(line.replace('་', '་ '))


def train_fasttext_model(
    corpus_path: str,
    model_path: str = DEFAULT_MODEL_PATH,
    dim: int = DEFAULT_DIM,
    epoch: int = DEFAULT_EPOCH,
    min_count: int = DEFAULT_MIN_COUNT,
    window: int = DEFAULT_WINDOW,
    minn: int = DEFAULT_MINN,
    maxn: int = DEFAULT_MAXN,
    neg: int = DEFAULT_NEG,
    model_type: str = "skipgram"
) -> fasttext.FastText._FastText:
    """
    Train an unsupervised FastText model on a Tibetan corpus and save it.

    Before training, the corpus is copied to a uniquely named temporary file
    (created next to the corpus). In that copy, every tsheg ('་') is followed
    by a space, so FastText's whitespace tokenization yields syllable-level
    tokens. The temporary file is always removed afterwards. The original
    corpus, and any other existing file, is never modified.

    Args:
        corpus_path: Path to the UTF-8 corpus file.
        model_path: Path where the trained model is saved. Missing parent
            directories are created.
        dim: Embedding dimension (default: 100)
        epoch: Number of training epochs (default: 5)
        min_count: Minimum count of words (default: 5)
        window: Size of context window (default: 5)
        minn: Minimum length of char n-gram (default: 3)
        maxn: Maximum length of char n-gram (default: 6)
        neg: Number of negatives in negative sampling (default: 5)
        model_type: FastText model type, 'skipgram' or 'cbow'. Any other value
            falls back to 'cbow' (as before) and logs a warning.

    Returns:
        The trained FastText model (also saved to ``model_path``).

    Raises:
        Exception: Any error while preparing the corpus, training, or saving
            is logged with its traceback and re-raised unchanged.
    """
    if model_type not in ("skipgram", "cbow"):
        # Historical behaviour: anything that is not 'skipgram' trained CBOW.
        # Keep that, but make the silent fallback visible.
        logger.warning("Unknown model_type %r; training a 'cbow' model instead.", model_type)
        model_type = "cbow"

    model_dir = os.path.dirname(model_path)
    if model_dir:  # a bare filename is saved into the CWD, which already exists
        ensure_dir_exists(model_dir)

    logger.info("Training FastText model with %s, dim=%d, epoch=%d, window=%d, minn=%d, maxn=%d...",
                model_type, dim, epoch, window, minn, maxn)

    processed_corpus_path = None
    try:
        # A unique temp name (instead of corpus_path + ".processed") guarantees
        # we never overwrite, and later delete, a pre-existing user file.
        fd, processed_corpus_path = tempfile.mkstemp(
            prefix=os.path.basename(corpus_path) + ".",
            suffix=".processed",
            dir=os.path.dirname(os.path.abspath(corpus_path)),
        )
        os.close(fd)
        _write_syllable_segmented_corpus(corpus_path, processed_corpus_path)
        logger.info("Corpus preprocessed to temporary file for Tibetan syllable segmentation.")

        model = fasttext.train_unsupervised(
            processed_corpus_path,
            model=model_type,
            dim=dim, epoch=epoch, minCount=min_count, wordNgrams=1,
            minn=minn, maxn=maxn, neg=neg, window=window,
        )
        model.save_model(model_path)
        logger.info("FastText model trained and saved to %s", model_path)
    except Exception as e:
        logger.error("An error occurred during model training: %s", e, exc_info=True)
        # Re-raise after logging; the finally clause still cleans up.
        raise
    finally:
        if processed_corpus_path and os.path.exists(processed_corpus_path):
            os.remove(processed_corpus_path)
            logger.info("Cleaned up temporary file: %s", processed_corpus_path)
    return model
def _load_facebook_official_tibetan_model() -> Optional[fasttext.FastText._FastText]:
    """
    Download (if not already cached under DEFAULT_MODEL_DIR) and load the
    official Facebook FastText Tibetan model from the Hugging Face Hub.

    After a successful load, some diagnostics are logged: the model
    dimension, and a vector lookup for the particle "ལ་". A diagnostic
    failure is only logged, and the model is still returned.

    Returns:
        The loaded FastText model (as returned by ``fasttext.load_model``),
        or None if downloading or loading raised an exception.
    """

    def _log_diagnostics(loaded):
        # Best-effort sanity checks; failures are logged, never propagated.
        try:
            logger.info(f"Model dimensions reported by fasttext_embedding: {loaded.get_dimension()}")
            # A common Tibetan particle expected to be in the vocabulary.
            probe = "ལ་"
            try:
                probe_vec = loaded.get_word_vector(probe)
                shape_desc = probe_vec.shape if probe_vec is not None else 'None'
                logger.info(f"Successfully retrieved vector for test word '{probe}'. Vector shape: {shape_desc}")
            except Exception as e_gwv:
                logger.error(f"Error calling get_word_vector for test word '{probe}' in fasttext_embedding: {e_gwv}", exc_info=True)
        except Exception as e_diag_load:
            logger.error(f"Error during diagnostic checks of loaded FastText model in fasttext_embedding: {e_diag_load}", exc_info=True)

    try:
        logger.info("Attempting to download and load official Facebook FastText Tibetan model")
        local_path = hf_hub_download(
            repo_id=FACEBOOK_TIBETAN_MODEL_ID,
            filename=FACEBOOK_TIBETAN_MODEL_FILE,
            cache_dir=DEFAULT_MODEL_DIR,
        )
        logger.info(f"Loading official Facebook FastText Tibetan model from {local_path}")
        model = fasttext.load_model(local_path)
        if not model:
            logger.error("fasttext.load_model returned None in load_facebook_official_tibetan_model.")
            return model
        logger.info(f"FastText model loaded in load_facebook_official_tibetan_model. Type: {type(model)}")
        _log_diagnostics(model)
        return model
    except Exception as e_fb:
        logger.error(f"Could not load official Facebook FastText Tibetan model (outer try-except): {str(e_fb)}", exc_info=True)
        return None
def load_fasttext_model(model_path: str = DEFAULT_MODEL_PATH) -> Optional[fasttext.FastText._FastText]:
    """
    Load a locally stored (custom) FastText model.

    Args:
        model_path: Filesystem path of the model file; defaults to DEFAULT_MODEL_PATH.

    Returns:
        The loaded model. Returns None when the path does not exist or when
        loading raises; both cases are logged at ERROR level.
    """
    try:
        if not os.path.exists(model_path):
            logger.error(f"Custom FastText model path {model_path} does not exist.")
            return None
        logger.info(f"Attempting to load custom FastText model from {model_path}")
        return fasttext.load_model(model_path)
    except Exception as load_err:
        logger.error(f"Could not load custom FastText model from {model_path}: {str(load_err)}")
        return None
def _remove_stopwords_from_tokens(tokens: List[str], stopwords_set: Set[str]) -> List[str]:
"""
Removes stopwords from a list of tokens using a list comprehension for efficiency.
Handles Tibetan punctuation by checking both the token itself and the token after
stripping trailing '།' or '༔'.
"""
if not stopwords_set:
return tokens
return [token for token in tokens if token not in stopwords_set and token.rstrip('།༔') not in stopwords_set]
def get_text_embedding(
    text: str,
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,
    corpus_token_freq=None,  # Not used; kept so existing callers keep working
    doc_freq_map=None,  # Document frequency map for IDF
    total_docs_in_corpus=0  # Total documents in corpus for IDF
) -> np.ndarray:
    """
    Embed a text as an (optionally TF-IDF-weighted) average of FastText word vectors.

    Args:
        text: Input text. A blank text yields a zero vector.
        model: FastText model (``get_dimension`` / ``get_word_vector`` are used).
        tokenize_fn: A callable ``text -> list of tokens``, a pre-tokenized list
            of tokens, or None (whitespace split).
        use_stopwords: Whether to filter out stopwords before computing embeddings.
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True).
        use_tfidf_weighting: Whether to weight word vectors by TF-IDF. Only
            applied when ``doc_freq_map`` is non-empty and
            ``total_docs_in_corpus > 0``; otherwise a plain mean is used.
        corpus_token_freq: Accepted for backward compatibility; not used.
        doc_freq_map: Mapping token -> number of documents containing it.
        total_docs_in_corpus: Number of documents ``doc_freq_map`` was built from.

    Returns:
        A 1-D vector of length ``model.get_dimension()``. It is all zeros when
        the text is blank or no tokens survive stopword removal.
    """
    dim = model.get_dimension()
    if not text.strip():
        return np.zeros(dim)

    # Resolve tokens: callable tokenizer, pre-tokenized list, or whitespace split.
    if callable(tokenize_fn):
        tokens = tokenize_fn(text)
    elif isinstance(tokenize_fn, list):
        tokens = tokenize_fn
    else:
        if tokenize_fn is not None:
            logger.warning("tokenize_fn is of unexpected type: %s. Defaulting to space-split.", type(tokenize_fn))
        tokens = text.split()
    logger.debug("Tokens (first 20): %s", tokens[:20])

    if use_stopwords and stopwords_set:
        n_before = len(tokens)
        tokens = _remove_stopwords_from_tokens(tokens, stopwords_set)
        logger.debug("Removed %d stopword tokens; remaining (first 20): %s", n_before - len(tokens), tokens[:20])

    if not tokens:
        logger.debug("Text became empty after tokenization/stopwords, returning zero vector.")
        return np.zeros(dim)

    apply_tfidf = (
        use_tfidf_weighting
        and bool(doc_freq_map)
        and total_docs_in_corpus is not None
        and total_docs_in_corpus > 0
    )
    if not apply_tfidf:
        if use_tfidf_weighting:
            logger.debug("TF-IDF weighting was requested but doc_freq_map or total_docs_in_corpus is missing/invalid. Falling back to simple averaging.")
        return np.mean([model.get_word_vector(t) for t in tokens], axis=0)

    # TF-IDF weighted average over *unique* tokens. Weighting every occurrence
    # by tf would count repetition twice (effective weight ~ count**2 * idf).
    counts = Counter(tokens)
    unique_tokens = list(counts)  # first-occurrence order
    tf = np.array([counts[t] for t in unique_tokens], dtype=float) / len(tokens)
    df = np.array([doc_freq_map.get(t, 0) for t in unique_tokens], dtype=float)
    # Smoothed IDF: log((N + 1) / (df + 1)) + 1.
    idf = np.log((total_docs_in_corpus + 1) / (df + 1)) + 1
    weights = tf * idf
    total_weight = weights.sum()
    if total_weight > 1e-6:
        weights = weights / total_weight
    else:
        # Degenerate weights (e.g. inconsistent df > N inputs): use the plain
        # per-occurrence mean, which is exactly the tf vector over unique tokens.
        logger.debug("Total TF-IDF weight is very small, falling back to uniform weights.")
        weights = tf

    vectors = np.array([model.get_word_vector(t) for t in unique_tokens])
    embedding = weights @ vectors
    logger.debug("TF-IDF weighted embedding from %d unique tokens, shape: %s", len(unique_tokens), embedding.shape)
    return embedding
def get_batch_embeddings(
    texts: List[str],
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,
    corpus_token_freq=None,  # Corpus-wide term frequencies
    doc_freq_map=None,  # Document frequency map for IDF
    total_docs_in_corpus=0  # Total documents in corpus for IDF
) -> np.ndarray:
    """
    Get embeddings for a batch of texts with optional TF-IDF weighting.

    Args:
        texts: List of input texts.
        model: FastText model.
        tokenize_fn: A callable tokenizer applied to every text, or a list of
            pre-tokenized documents aligned with ``texts`` (``tokenize_fn[i]``
            belongs to ``texts[i]``). When the list is shorter than ``texts``,
            the remaining texts fall back to whitespace splitting.
        use_stopwords: Whether to filter out stopwords before computing embeddings.
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True).
        use_tfidf_weighting: Whether to use TF-IDF weighting for averaging word vectors.
        corpus_token_freq: Passed through to get_text_embedding.
        doc_freq_map: Document frequency map for IDF.
        total_docs_in_corpus: Total number of documents in the corpus for IDF.

    Returns:
        Array of shape ``(len(texts), model.get_dimension())``. A text whose
        embedding fails is logged and gets a zero row (best effort).
    """
    pretokenized = isinstance(tokenize_fn, list)
    if pretokenized and len(tokenize_fn) < len(texts):
        logger.warning(
            "Pre-tokenized list `tokenize_fn` (length %d) is shorter than the list of texts (length %d); "
            "texts without tokens fall back to space-split.",
            len(tokenize_fn), len(texts),
        )

    embeddings = []
    for i, text_content in enumerate(texts):
        if callable(tokenize_fn):
            per_text_tokenizer = tokenize_fn  # pass the function itself
        elif pretokenized and i < len(tokenize_fn):
            per_text_tokenizer = tokenize_fn[i]  # this text's token list
        else:
            per_text_tokenizer = None  # get_text_embedding falls back to space-split

        try:
            embedding = get_text_embedding(
                text_content,
                model,
                tokenize_fn=per_text_tokenizer,
                use_stopwords=use_stopwords,
                stopwords_set=stopwords_set,
                use_tfidf_weighting=use_tfidf_weighting,
                corpus_token_freq=corpus_token_freq,
                doc_freq_map=doc_freq_map,
                total_docs_in_corpus=total_docs_in_corpus,
            )
        except Exception as e:
            logger.error("Error generating FastText embeddings in fasttext_embedding.py for text %d: %s", i, e, exc_info=True)
            embedding = np.zeros(model.get_dimension())
        embeddings.append(embedding)

    if not embeddings:
        # Keep the result 2-D so downstream matrix code works on empty input.
        return np.zeros((0, model.get_dimension()))
    return np.array(embeddings)
# Cache for loaded FastText models, keyed by "<model_id>@<version>"
_fasttext_model_cache = {}


def get_model(model_id: str) -> Tuple[Optional[Any], Optional[str]]:
    """
    Return a FastText model by identifier, loading it on first use.

    Successfully loaded models are cached in-process under
    ``"<model_id>@<version>"`` (the version comes from MODEL_VERSIONS, or
    "unknown"). Only "facebook-fasttext-pretrained" is currently supported.

    Args:
        model_id (str): The identifier for the model to load.

    Returns:
        Tuple[Optional[Any], Optional[str]]: ``(model, "fasttext")`` on success,
        or ``(None, None)`` if the ID is unsupported or loading fails.
    """
    version = MODEL_VERSIONS.get(model_id, "unknown")
    key = f"{model_id}@{version}"

    if key in _fasttext_model_cache:
        logger.info(f"Returning cached FastText model: {model_id} (version: {version})")
        return _fasttext_model_cache[key], "fasttext"

    logger.info(f"Attempting to load FastText model: {model_id} (version: {version})")

    if model_id != "facebook-fasttext-pretrained":
        # Further model IDs (e.g. custom models) would be dispatched here.
        logger.warning(f"Unsupported FastText model ID: {model_id}")
        return None, None

    try:
        model = _load_facebook_official_tibetan_model()
    except Exception as e:
        logger.error(f"Failed to load FastText model '{model_id}': {e}", exc_info=True)
        return None, None

    if not model:
        logger.error(f"Model loading for '{model_id}' returned None.")
        return None, None

    _fasttext_model_cache[key] = model
    logger.info(f"FastText model '{model_id}' (version: {version}) loaded successfully.")
    return model, "fasttext"
def generate_embeddings(
    texts: List[str],
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False,
    corpus_token_freq=None,  # Existing: For TF-IDF
    doc_freq_map=None,  # Added: For TF-IDF document frequency
    total_docs_in_corpus=0,  # Added: For TF-IDF total documents in corpus
    use_tfidf_weighting: bool = True,
) -> np.ndarray:
    """
    Generate embeddings for a list of texts using a FastText model.

    Args:
        texts: List of input texts.
        model: FastText model.
        tokenize_fn: Optional tokenization function or pre-tokenized list of tokens.
        use_stopwords: Whether to filter out Tibetan stopwords.
        use_lite_stopwords: Whether to use a lighter set of stopwords.
        corpus_token_freq: Precomputed term frequencies for the corpus (passed through).
        doc_freq_map: Precomputed document frequencies for tokens (for TF-IDF).
        total_docs_in_corpus: Total number of documents in the corpus (for TF-IDF).
        use_tfidf_weighting: Whether to TF-IDF-weight word vectors (default True,
            matching the previous hard-coded behaviour).

    Returns:
        Array of text embedding vectors. If anything fails, the error is logged
        with its traceback and an all-zero array of shape
        ``(len(texts), model.get_dimension())`` is returned.
    """
    try:
        stopwords_set = None
        if use_stopwords:
            from .tibetan_stopwords import get_stopwords
            stopwords_set = get_stopwords(use_lite=use_lite_stopwords)
            # Guard len(): a None result just means "no stopword filtering".
            logger.info("Loaded %d Tibetan stopwords", len(stopwords_set) if stopwords_set else 0)

        embeddings = get_batch_embeddings(
            texts,
            model,
            tokenize_fn=tokenize_fn,
            use_stopwords=use_stopwords,
            stopwords_set=stopwords_set,
            use_tfidf_weighting=use_tfidf_weighting,
            corpus_token_freq=corpus_token_freq,
            doc_freq_map=doc_freq_map,
            total_docs_in_corpus=total_docs_in_corpus,
        )
        logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
        return embeddings
    except Exception as e:
        logger.error("Error generating FastText embeddings: %s", str(e), exc_info=True)
        # Best-effort fallback: one zero row per input text.
        return np.zeros((len(texts), model.get_dimension()))