# app/services/synonyms.py
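"""Synonym suggestion service.

Collects WordNet synonym candidates for each content word in the input text,
then ranks them with sentence-transformer embeddings: a candidate is kept only
if substituting it leaves the sentence semantically close to the original.
"""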
import logging
import asyncio
from typing import Dict, List, Tuple
from functools import lru_cache
from app.services.base import (
load_spacy_model,
load_sentence_transformer_model,
ensure_nltk_resource
)
from app.core.config import (
settings,
APP_NAME,
SPACY_MODEL_ID,
WORDNET_NLTK_ID,
SENTENCE_TRANSFORMER_MODEL_ID
)
from app.core.exceptions import ServiceError, ModelNotDownloadedError
from nltk.corpus import wordnet
from sentence_transformers.util import cos_sim
logger = logging.getLogger(f"{APP_NAME}.services.synonyms")
# Map spaCy coarse-grained POS tags to their WordNet equivalents. The WordNet
# POS constants are plain one-character strings; string literals are used here
# so that merely importing this module does not trigger NLTK's lazy corpus
# loader before ensure_nltk_resource has had a chance to download WordNet.
SPACY_TO_WORDNET_POS = {
    "NOUN": "n",  # wordnet.NOUN
    "VERB": "v",  # wordnet.VERB
    "ADJ": "a",   # wordnet.ADJ
    "ADV": "r",   # wordnet.ADV
}
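# Coarse POS tags treated as content words and eligible for suggestions.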
CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
class SynonymSuggester:
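    """Suggest context-aware synonyms for the content words of a text.

    Candidates come from WordNet; they are ranked by the cosine similarity
    between the original sentence and the sentence with the word replaced.
    """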
def __init__(self):
self._sentence_model = None
self._nlp = None
def _get_sentence_model(self):
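        """Lazily load and cache the sentence-transformer model."""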
if self._sentence_model is None:
self._sentence_model = load_sentence_transformer_model(
SENTENCE_TRANSFORMER_MODEL_ID
)
return self._sentence_model
def _get_nlp(self):
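        """Lazily load and cache the spaCy pipeline."""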
if self._nlp is None:
self._nlp = load_spacy_model(
SPACY_MODEL_ID
)
return self._nlp
    async def suggest(self, text: str) -> Dict[str, Dict[str, List[str]]]:
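        """Suggest synonyms for each eligible word in ``text``.

        Returns ``{"suggestions": {original_word: [synonym, ...]}}``.
        """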
try:
text = text.strip()
if not text:
raise ServiceError(status_code=400, detail="Input text is empty for synonym suggestion.")
sentence_model = self._get_sentence_model()
nlp = self._get_nlp()
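            # Make sure the WordNet corpus is available before any synset lookups.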
await asyncio.to_thread(ensure_nltk_resource, WORDNET_NLTK_ID)
doc = await asyncio.to_thread(nlp, text)
original_text_embedding = await asyncio.to_thread(
sentence_model.encode, text,
convert_to_tensor=True,
normalize_embeddings=True
)
candidate_data = []
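            # Build one full-sentence variant per (word, candidate) pair so the
            # embedding model can judge each replacement in context.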
for token in doc:
                # Only consider content words longer than two characters.
                if (
                    token.pos_ in CONTENT_POS_TAGS
                    and not token.is_punct
                    and not token.is_space
                    and len(token.text.strip()) > 2
                ):
original_word = token.text
word_start = token.idx
word_end = token.idx + len(original_word)
wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
if not wordnet_pos:
continue
wordnet_candidates = await asyncio.to_thread(
self._get_wordnet_synonyms_cached, original_word, wordnet_pos
)
if not wordnet_candidates:
continue
for candidate in wordnet_candidates:
temp_sentence = text[:word_start] + candidate + text[word_end:]
candidate_data.append({
"original_word": original_word,
"wordnet_candidate": candidate,
"temp_sentence": temp_sentence,
})
if not candidate_data:
return {"suggestions": {}}
all_candidate_sentences = [c["temp_sentence"] for c in candidate_data]
all_candidate_embeddings = await asyncio.to_thread(
sentence_model.encode,
all_candidate_sentences,
batch_size=settings.SENTENCE_TRANSFORMER_BATCH_SIZE,
convert_to_tensor=True,
normalize_embeddings=True
)
if original_text_embedding.dim() == 1:
original_text_embedding = original_text_embedding.unsqueeze(0)
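            # cos_sim returns a (1, N) matrix; take its single row of scores.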
cosine_scores = cos_sim(original_text_embedding, all_candidate_embeddings)[0]
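            # Tuning knobs: minimum full-sentence similarity for a candidate to
            # survive, and the maximum number of suggestions kept per word.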
similarity_threshold = 0.65
top_n = 5
            temp_scored: Dict[str, List[Tuple[float, str]]] = {}
for i, data in enumerate(candidate_data):
word = data["original_word"]
candidate = data["wordnet_candidate"]
score = cosine_scores[i].item()
if score >= similarity_threshold and candidate.lower() != word.lower():
                    temp_scored.setdefault(word, []).append((score, candidate))
final_suggestions = {}
            for word, scored in temp_scored.items():
                # Keep the top-N unique candidates, best score first.
                sorted_unique = []
                seen = set()
                for score, candidate in sorted(scored, key=lambda x: x[0], reverse=True):
                    if candidate not in seen:
                        sorted_unique.append(candidate)
                        seen.add(candidate)
                    if len(sorted_unique) >= top_n:
                        break
                final_suggestions[word] = sorted_unique
return {"suggestions": final_suggestions}
        except (ServiceError, ModelNotDownloadedError):
            # Propagate deliberate service errors (e.g. the 400 for empty
            # input) and missing-model errors with their original status codes
            # instead of masking them as a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"Synonym suggestion error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during synonym suggestion.") from e
    @staticmethod
    @lru_cache(maxsize=5000)
    def _get_wordnet_synonyms_cached(word: str, pos: str) -> List[str]:
        """Return sorted WordNet lemma names for ``word`` with the given POS.

        Declared static so ``lru_cache`` does not key every entry on ``self``
        and pin the suggester instance in memory for the cache's lifetime.
        """
synonyms = set()
for syn in wordnet.synsets(word, pos=pos):
for lemma in syn.lemmas():
name = lemma.name().replace("_", " ").lower()
if name.isalpha() and len(name) > 1:
synonyms.add(name)
synonyms.discard(word.lower())
return sorted(synonyms)
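# Minimal usage sketch (illustrative only; assumes the models and the WordNet
# corpus are already downloaded, and that the surrounding app normally calls
# this service from its own event loop):
#
#     suggester = SynonymSuggester()
#     result = asyncio.run(suggester.suggest("The quick brown fox jumps."))
#     print(result["suggestions"])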