import re
from collections import Counter
from typing import Dict, List, Tuple

import plotly
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot


class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.
    """

    def __init__(self):
        # The vectoriser and topic model are built in fit(), once stopwords are known
        self.vectoriser_model = None
        self.model = None

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.
        """
        token_counter = Counter()
        for text in texts:
            token_ids = tokeniser.encode(text, add_special_tokens=False)
            token_counter.update(token_ids)

        most_common = token_counter.most_common(top_k)
        stop_tokens = [tokeniser.decode([tok_id]).strip() for tok_id, _ in most_common]
        return stop_tokens

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts, strip digits, and collapse runs of whitespace.
        """
        return [
            re.sub(r"\d+|\s+", " ", t.lower()).strip()
            for t in texts
        ]

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]]
    ) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
        """
        Fit BERTopic on preprocessed texts and the given embeddings.

        Returns the per-document topic labels, a topic bar chart, a mapping
        from topic id to label, and a UMAP scatter plot.
        """
        clean_texts = self._preprocess_texts(texts)

        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k=75)

        # Define vectoriser and model
        self.vectoriser_model = CountVectorizer(
            stop_words=stopwords,
            token_pattern=r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
        )
        self.model = BERTopic(
            language="multilingual",
            vectorizer_model=self.vectoriser_model,
            embedding_model=DalaEmbedder().get_model()
        )
        topics, _ = self.model.fit_transform(clean_texts, embeddings)

        # Generate human-readable labels from each topic's top words
        topic_info = self.model.get_topic_info()
        topic_labels = {}
        for topic_id in topic_info.Topic.values:
            if topic_id == -1:
                # -1 is BERTopic's outlier topic; give it a placeholder label
                topic_labels[topic_id] = "-"
                continue
            words = [word for word, _ in self.model.get_topic(topic_id)[:4]]
            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"

        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        labeled_topics = [topic_labels[t] for t in topics]
        return labeled_topics, fig, topic_labels, umap_fig
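

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how TopicModeller is expected to be
# driven end to end. It assumes DalaEmbedder().get_model() returns a
# sentence-transformers style model exposing encode(); if the embedding API in
# this codebase differs, adapt the embedding step accordingly. A realistically
# sized corpus (hundreds of documents) is needed for UMAP/HDBSCAN to fit; the
# tiny list below is only a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    corpus = [
        "first example document about one theme",
        "second example document about another theme",
        # ... supply a full corpus here
    ]

    # Embed the raw texts with the same model BERTopic will reuse internally
    embedding_model = DalaEmbedder().get_model()
    corpus_embeddings = embedding_model.encode(corpus)  # assumed encode() API

    modeller = TopicModeller()
    labeled_topics, bar_fig, topic_labels, umap_fig = modeller.fit(corpus, corpus_embeddings)

    print(labeled_topics)   # one label per document, e.g. "3_word1_word2_word3_word4"
    bar_fig.show()          # interactive topic bar chart
    umap_fig.show()         # UMAP projection coloured by topic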