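"""Synonym suggestion service.

Looks up WordNet synonyms for each content word (noun, verb, adjective,
adverb), substitutes each candidate into the original sentence, and keeps
only candidates whose substituted sentence stays semantically close to the
original under a sentence-transformer cosine-similarity check.
"""
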
import logging
import asyncio
from typing import List, Dict
from functools import lru_cache

from app.services.base import (
    load_spacy_model,
    load_sentence_transformer_model,
    ensure_nltk_resource,
)
from app.core.config import (
    settings,
    APP_NAME,
    SPACY_MODEL_ID,
    WORDNET_NLTK_ID,
    SENTENCE_TRANSFORMER_MODEL_ID,
)
from app.core.exceptions import ServiceError, ModelNotDownloadedError
from nltk.corpus import wordnet
from sentence_transformers.util import cos_sim

logger = logging.getLogger(f"{APP_NAME}.services.synonyms")
# Map spaCy coarse POS tags to WordNet POS constants.
SPACY_TO_WORDNET_POS = {
    "NOUN": wordnet.NOUN,
    "VERB": wordnet.VERB,
    "ADJ": wordnet.ADJ,
    "ADV": wordnet.ADV,
}

# Only content words are considered for synonym suggestions.
CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
class SynonymSuggester:
    """Suggests context-aware synonyms for content words in a text."""

    def __init__(self):
        # Models are loaded lazily on first use.
        self._sentence_model = None
        self._nlp = None

    def _get_sentence_model(self):
        if self._sentence_model is None:
            self._sentence_model = load_sentence_transformer_model(
                SENTENCE_TRANSFORMER_MODEL_ID
            )
        return self._sentence_model

    def _get_nlp(self):
        if self._nlp is None:
            self._nlp = load_spacy_model(SPACY_MODEL_ID)
        return self._nlp
    async def suggest(self, text: str) -> dict:
        text = text.strip()
        if not text:
            # Validate outside the try-block so this 400 is not remapped to a 500 below.
            raise ServiceError(status_code=400, detail="Input text is empty for synonym suggestion.")
        try:
            sentence_model = self._get_sentence_model()
            nlp = self._get_nlp()
            await asyncio.to_thread(ensure_nltk_resource, WORDNET_NLTK_ID)
            doc = await asyncio.to_thread(nlp, text)

            # Embed the original sentence once; every candidate is compared against it.
            original_text_embedding = await asyncio.to_thread(
                sentence_model.encode,
                text,
                convert_to_tensor=True,
                normalize_embeddings=True,
            )
            # For each content word, gather WordNet candidates and build one
            # variant sentence per candidate for semantic comparison.
            candidate_data = []
            for token in doc:
                if (
                    token.pos_ in CONTENT_POS_TAGS
                    and len(token.text.strip()) > 2
                    and not token.is_punct
                    and not token.is_space
                ):
                    original_word = token.text
                    word_start = token.idx
                    word_end = token.idx + len(original_word)
                    wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
                    if not wordnet_pos:
                        continue
                    wordnet_candidates = await asyncio.to_thread(
                        self._get_wordnet_synonyms_cached, original_word, wordnet_pos
                    )
                    if not wordnet_candidates:
                        continue
                    for candidate in wordnet_candidates:
                        temp_sentence = text[:word_start] + candidate + text[word_end:]
                        candidate_data.append({
                            "original_word": original_word,
                            "wordnet_candidate": candidate,
                            "temp_sentence": temp_sentence,
                        })
            if not candidate_data:
                return {"suggestions": {}}

            all_candidate_sentences = [c["temp_sentence"] for c in candidate_data]
            all_candidate_embeddings = await asyncio.to_thread(
                sentence_model.encode,
                all_candidate_sentences,
                batch_size=settings.SENTENCE_TRANSFORMER_BATCH_SIZE,
                convert_to_tensor=True,
                normalize_embeddings=True,
            )

            if original_text_embedding.dim() == 1:
                original_text_embedding = original_text_embedding.unsqueeze(0)
            cosine_scores = cos_sim(original_text_embedding, all_candidate_embeddings)[0]

            similarity_threshold = 0.65
            top_n = 5
            # Keep candidates whose variant sentence stays close to the original
            # and that are not just a case variant of the source word.
            temp_scored: Dict[str, List[tuple]] = {}
            for i, data in enumerate(candidate_data):
                word = data["original_word"]
                candidate = data["wordnet_candidate"]
                score = cosine_scores[i].item()
                if score >= similarity_threshold and candidate.lower() != word.lower():
                    temp_scored.setdefault(word, []).append((score, candidate))

            # Per word, keep the top-N unique candidates ranked by similarity.
            final_suggestions: Dict[str, List[str]] = {}
            for word, scored in temp_scored.items():
                sorted_unique = []
                seen = set()
                for score, candidate in sorted(scored, key=lambda x: x[0], reverse=True):
                    if candidate not in seen:
                        sorted_unique.append(candidate)
                        seen.add(candidate)
                    if len(sorted_unique) >= top_n:
                        break
                final_suggestions[word] = sorted_unique

            return {"suggestions": final_suggestions}
        except (ServiceError, ModelNotDownloadedError):
            # Propagate known application errors unchanged instead of masking them as a 500.
            raise
        except Exception as e:
            logger.error(f"Synonym suggestion error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during synonym suggestion.") from e
    @staticmethod
    @lru_cache(maxsize=1024)  # memoize lookups; 1024 entries is an arbitrary default
    def _get_wordnet_synonyms_cached(word: str, pos: str) -> List[str]:
        """Return sorted single-word WordNet lemma names for (word, pos)."""
        synonyms = set()
        for syn in wordnet.synsets(word, pos=pos):
            for lemma in syn.lemmas():
                name = lemma.name().replace("_", " ").lower()
                # isalpha() also drops multi-word lemmas (the space fails the check).
                if name.isalpha() and len(name) > 1:
                    synonyms.add(name)
        synonyms.discard(word.lower())
        return sorted(synonyms)
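
# Minimal usage sketch (assumes the spaCy, sentence-transformers, and NLTK
# WordNet assets referenced in app.core.config are already downloaded; the
# sample sentence and printed output are illustrative only):
#
#     import asyncio
#
#     suggester = SynonymSuggester()
#     result = asyncio.run(suggester.suggest("The quick fox jumps over the lazy dog."))
#     print(result["suggestions"])  # e.g. {"quick": ["fast", "speedy", ...], ...}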