import logging
import os
import time
from typing import TypedDict

import numpy as np
import pandas as pd
import redis
import torch
from chem_mrl.molecular_fingerprinter import MorganFingerprinter
from dotenv import load_dotenv
from rdkit import Chem, RDLogger
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer

from constants import (
    EMBEDDING_DIMENSION,
    HNSW_K,
    HNSW_PARAMETERS,
    MODEL_NAME,
    SUPPORTED_EMBEDDING_DIMENSIONS,
    USE_HALF_PRECISION,
)
from data import ISOMER_DESIGN_DATASET


def setup_logger(clear_handler=False):
    """Configure root logging and return this module's logger.

    Args:
        clear_handler: When true, remove every handler currently attached to
            the root logger before calling ``logging.basicConfig``.

    Returns:
        The logger named after this module.
    """
    if clear_handler:
        # Iterate over a copy: removeHandler mutates the list being walked.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)  # issue with sentence-transformer's logging handler
    RDLogger.DisableLog("rdApp.*")  # type: ignore - DisableLog is an exported function
    logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


load_dotenv("../.env")
logger = setup_logger(clear_handler=True)


class SimilarMolecule(TypedDict):
    """One nearest-neighbour hit returned by ``find_similar_molecules``."""

    smiles: str
    name: str
    properties: str
    # Value of the KNN ``score`` alias returned by Redis, converted to float.
    score: float


class MolecularEmbeddingService:
    """Embed molecules and look up their nearest neighbours in Redis.

    Constructing an instance loads the SentenceTransformer model, connects to
    Redis, creates the HNSW index if it does not exist yet and populates it
    with ``ISOMER_DESIGN_DATASET``.
    """

    def __init__(self):
        self.model_name = MODEL_NAME
        self.index_name = "molecule_embeddings"
        self.model_embed_dim = EMBEDDING_DIMENSION
        self.model = self._initialize_model()
        self.redis_client = self._initialize_redis()
        self._initialize_datastore()

    def _initialize_model(self):
        """Initialize the Hugging Face transformers model.

        Returns:
            The loaded ``SentenceTransformer`` switched to eval mode.

        Raises:
            Exception: Whatever the model loader raised, after logging it.
        """
        try:
            model = SentenceTransformer(
                self.model_name,
                model_kwargs={
                    "torch_dtype": torch.float16 if USE_HALF_PRECISION else torch.float32,
                },
            )
            model.eval()
            return model
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _initialize_redis(self):
        """Initialize the Redis connection and wait until the server answers.

        Connection settings come from the ``REDIS_HOST``, ``REDIS_PORT`` and
        ``REDIS_PASSWORD`` environment variables.

        Returns:
            A ``redis.Redis`` client that has answered a ``PING``.

        Raises:
            Exception: Any configuration or connection error, after logging it.
        """
        retry_delay_seconds = 5
        try:
            redis_host = os.getenv("REDIS_HOST", "localhost")
            redis_port = int(os.getenv("REDIS_PORT", "6379"))
            redis_password = os.getenv("REDIS_PASSWORD", None)

            logger.info(
                f"Connecting to Redis at {redis_host}:{redis_port} with password: {'***' if redis_password else 'None'}"
            )
            redis_client = redis.Redis(
                host=redis_host,
                port=redis_port,
                password=redis_password,
                decode_responses=True,
            )

            # Building the client does not open a socket; the first command
            # does. Keeping the ping inside this try block ensures real
            # connection/auth failures are logged before they propagate.
            while True:
                try:
                    redis_client.ping()
                    break
                except redis.exceptions.BusyLoadingError:
                    logger.warning(
                        f"Redis is loading the dataset in memory. Retrying in {retry_delay_seconds} seconds..."
                    )
                    time.sleep(retry_delay_seconds)
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

        return redis_client

    def _initialize_datastore(self):
        """Create the search index (if missing) and load the sample dataset."""
        self.__create_hnsw_index()
        self.__populate_sample_data(ISOMER_DESIGN_DATASET)

    def __create_hnsw_index(self):
        """Create HNSW index for molecular embeddings.

        The index holds a ``smiles`` text field plus one HNSW vector field per
        entry in ``SUPPORTED_EMBEDDING_DIMENSIONS``. Nothing is done when an
        index with this name already exists.

        Raises:
            Exception: Whatever index creation raised, after logging it.
        """
        try:
            self.redis_client.ft(self.index_name).info()
            logger.info(f"Index {self.index_name} already exists")
            return
        except redis.exceptions.ResponseError:
            # info() failed, which is treated as "index does not exist yet".
            pass

        try:
            schema: list[TextField | VectorField] = [
                VectorField(
                    self.embedding_field_name(dim),
                    "HNSW",
                    {
                        **HNSW_PARAMETERS,
                        "DIM": dim,
                    },
                )
                for dim in SUPPORTED_EMBEDDING_DIMENSIONS
            ]
            schema.insert(0, TextField("smiles"))
            self.redis_client.ft(self.index_name).create_index(
                schema,
                definition=IndexDefinition(prefix=[self.molecule_index_prefix("")], index_type=IndexType.HASH),
            )
            logger.info(f"Created HNSW index: {self.index_name}")
        except Exception as e:
            logger.error(f"Failed to create HNSW index: {e}")
            raise

    def __populate_sample_data(self, df: pd.DataFrame):
        """Populate Redis with sample molecular data.

        Each row is stored as a hash keyed by its SMILES string. The hash
        contains one embedding per supported dimension plus every column of
        the row. Rows whose key already exists are skipped; rows that fail are
        logged and skipped so one bad molecule does not abort the load.

        Args:
            df: Dataset with at least a ``smiles`` column.
        """
        logger.info("Populating Redis with sample molecular data...")
        inserted = skipped = failed = 0
        for _, row in df.iterrows():
            try:
                key = self.molecule_index_prefix(row["smiles"])
                if self.redis_client.exists(key):
                    skipped += 1
                    continue

                # Run the model once per molecule, then derive every supported
                # dimension from that single embedding. The helper returns a
                # new array, so the cached embedding is never modified.
                embedding_cache: np.ndarray = self.get_molecular_embedding(row["smiles"], EMBEDDING_DIMENSION)
                mapping: dict[str, bytes | str] = {
                    self.embedding_field_name(embed_dim): self._truncate_and_normalize_embedding(
                        embedding_cache, embed_dim
                    ).tobytes()
                    for embed_dim in SUPPORTED_EMBEDDING_DIMENSIONS
                }
                mapping = {**mapping, **row.to_dict()}
                self.redis_client.hset(
                    key,
                    mapping=mapping,  # type: ignore
                )
                inserted += 1
            except Exception as e:
                logger.error(f"Failed to process molecule {row}: {e}")
                failed += 1
                continue
        # Report what actually happened rather than the dataset size.
        logger.info(f"Populated {inserted} sample molecules ({skipped} already present, {failed} failed)")

    def get_molecular_embedding(self, smiles: str, embed_dim: int) -> np.ndarray:
        """Generate molecular embedding using ChemMRL.

        Args:
            smiles: SMILES string to embed. It is canonicalized first; the
                input is used unchanged when canonicalization yields nothing.
            embed_dim: Number of leading components to keep. Must be positive.

        Returns:
            The truncated, L2-normalized embedding.

        Raises:
            ValueError: If ``embed_dim`` is not positive.
            Exception: Whatever encoding raised, after logging it.
        """
        try:
            if embed_dim <= 0:
                raise ValueError("embed_dim must be positive")

            # Preprocess smiles similarly as training data for optimal performance
            smiles = MorganFingerprinter.canonicalize_smiles(smiles) or smiles
            embedding: np.ndarray = self.model.encode(
                [smiles],
                show_progress_bar=False,
                convert_to_numpy=True,
            )[0]
            return self._truncate_and_normalize_embedding(embedding, embed_dim)
        except Exception as e:
            logger.error(f"Failed to generate embedding for {smiles}: {e}")
            raise

    def _truncate_and_normalize_embedding(self, embedding: np.ndarray, embed_dim: int) -> np.ndarray:
        """Truncate and normalize embedding.

        Args:
            embedding: Embedding vector. Callers in this file pass a single
                1-D vector.
            embed_dim: Number of leading components to keep. Values at or above
                the vector length leave the length unchanged.

        Returns:
            A new array scaled to unit L2 norm. An all-zero vector is returned
            as zeros instead of dividing by zero.
        """
        if embed_dim < len(embedding):
            embedding = embedding[:embed_dim]
        norms = np.linalg.norm(embedding, ord=2, keepdims=True)
        return embedding / np.where(norms == 0, 1, norms)

    def find_similar_molecules(
        self, query_embedding: np.ndarray, embed_dim: int, k: int = HNSW_K
    ) -> list[SimilarMolecule]:
        """Find k most similar molecules using HNSW.

        Args:
            query_embedding: Query vector; its raw bytes are sent to Redis.
                NOTE(review): presumably its length and dtype must match the
                indexed field for ``embed_dim`` - confirm against
                ``HNSW_PARAMETERS``.
            embed_dim: Selects which embedding field is searched.
            k: Number of neighbours to request.

        Returns:
            The hits sorted by ``score``, or an empty list when the search
            fails (the failure is logged, not raised).
        """
        try:
            query_vector = query_embedding.tobytes()
            query = (
                Query(f"*=>[KNN {k} @{self.embedding_field_name(embed_dim)} $vec AS score]")
                .sort_by("score")
                .return_fields("smiles", "name", "properties", "score")
                .dialect(2)
            )

            results = self.redis_client.ft(self.index_name).search(
                query,
                query_params={
                    "vec": query_vector,  # type: ignore
                },
            )

            neighbors: list[SimilarMolecule] = [
                {"smiles": doc.smiles, "name": doc.name, "properties": doc.properties, "score": float(doc.score)}
                for doc in results.docs
            ]
            return neighbors
        except Exception as e:
            logger.error(f"Failed to find similar molecules: {e}")
            return []

    @staticmethod
    def get_canonical_smiles(smiles: str | None) -> str:
        """Convert SMILES to canonical SMILES representation.

        Returns:
            ``""`` for ``None``, empty or whitespace-only input; otherwise the
            canonical form, or the stripped input when canonicalization
            returns ``None``.
        """
        stripped = smiles.strip() if smiles else ""
        if not stripped:
            return ""
        canonical = MorganFingerprinter.canonicalize_smiles(stripped)
        if canonical is None:
            return stripped
        return canonical

    @staticmethod
    def get_smiles_from_mol_file(mol_file: str) -> str:
        """Convert the contents of a MOL block into a canonical SMILES string.

        Returns:
            ``""`` for empty or whitespace-only input, or when RDKit cannot
            parse the block; otherwise the canonical SMILES.
        """
        if not mol_file or mol_file.strip() == "":
            return ""
        mol = Chem.rdmolfiles.MolFromMolBlock(mol_file)
        if mol is None:
            return ""
        return Chem.MolToSmiles(mol, canonical=True)

    @staticmethod
    def embedding_field_name(dim: int) -> str:
        """Return the hash/index field name holding the ``dim``-sized embedding."""
        return f"embedding_{dim}"

    @staticmethod
    def molecule_index_prefix(smiles: str) -> str:
        """Return the Redis key for ``smiles``; with ``""`` this is the index key prefix."""
        return f"mol:{smiles}"