Spaces:
Running
Running
| import logging | |
| import os | |
| import time | |
| from typing import TypedDict | |
| import numpy as np | |
| import pandas as pd | |
| import redis | |
| import torch | |
| from chem_mrl.molecular_fingerprinter import MorganFingerprinter | |
| from dotenv import load_dotenv | |
| from rdkit import Chem, RDLogger | |
| from redis.commands.search.field import TextField, VectorField | |
| from redis.commands.search.indexDefinition import IndexDefinition, IndexType | |
| from redis.commands.search.query import Query | |
| from sentence_transformers import SentenceTransformer | |
| from constants import ( | |
| EMBEDDING_DIMENSION, | |
| HNSW_K, | |
| HNSW_PARAMETERS, | |
| MODEL_NAME, | |
| SUPPORTED_EMBEDDING_DIMENSIONS, | |
| USE_HALF_PRECISION, | |
| ) | |
| from data import ISOMER_DESIGN_DATASET | |
def setup_logger(clear_handler=False):
    """Configure root logging for this module and return its logger.

    Args:
        clear_handler: When true, every handler currently attached to the root
            logger is removed before ``logging.basicConfig`` is called.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    root_logger = logging.root
    if clear_handler:
        # Workaround for an issue with sentence-transformer's logging handler:
        # drop whatever is already installed on the root logger. Iterate over a
        # snapshot because removeHandler mutates the underlying list.
        for installed in list(root_logger.handlers):
            root_logger.removeHandler(installed)

    # Silence RDKit's own log output.
    RDLogger.DisableLog("rdApp.*")  # type: ignore - DisableLog is an exported function

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)
# Read environment variables from the .env file one directory above the
# current working directory (the path is relative, so it depends on where the
# process is started from).
load_dotenv(dotenv_path="../.env")

# Module-wide logger; existing root handlers are cleared first so that only the
# format configured in setup_logger is used.
logger = setup_logger(clear_handler=True)
class SimilarMolecule(TypedDict):
    """One hit returned by a nearest-neighbour molecule search."""

    # SMILES string stored for the matched molecule.
    smiles: str
    # Name stored alongside the molecule.
    name: str
    # Free-form properties text stored alongside the molecule.
    properties: str
    # Value of the ``score`` alias produced by the KNN query, as a float.
    score: float
class MolecularEmbeddingService:
    """Generate molecular embeddings and run similarity search on a Redis HNSW index.

    Constructing the service loads the ``MODEL_NAME`` SentenceTransformer model,
    connects to Redis, creates the search index when it does not exist yet and
    stores embeddings for every row of ``ISOMER_DESIGN_DATASET`` that is not
    already present in Redis.
    """

    def __init__(self):
        """Load the model, connect to Redis and prepare the datastore."""
        self.model_name = MODEL_NAME
        self.index_name = "molecule_embeddings"
        self.model_embed_dim = EMBEDDING_DIMENSION
        self.model = self._initialize_model()
        self.redis_client = self._initialize_redis()
        self._initialize_datastore()

    def _initialize_model(self):
        """Load the SentenceTransformer model and switch it to eval mode.

        The torch dtype is float16 when ``USE_HALF_PRECISION`` is truthy and
        float32 otherwise.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            model = SentenceTransformer(
                self.model_name,
                model_kwargs={
                    "torch_dtype": torch.float16 if USE_HALF_PRECISION else torch.float32,
                },
            )
            model.eval()
            return model
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _initialize_redis(self):
        """Create the Redis client and block until the server answers ``PING``.

        Connection settings are read from the ``REDIS_HOST`` (default
        ``localhost``), ``REDIS_PORT`` (default ``6379``) and ``REDIS_PASSWORD``
        environment variables. While the server reports that it is still
        loading its dataset, the ping is retried every five seconds without an
        upper bound.

        Returns:
            The connected ``redis.Redis`` client (``decode_responses=True``).

        Raises:
            Exception: Configuration errors (e.g. a non-numeric port) and any
                Redis error other than "busy loading" are logged and re-raised.
        """
        try:
            redis_host = os.getenv("REDIS_HOST", "localhost")
            # The default is given as a string so both branches go through int().
            redis_port = int(os.getenv("REDIS_PORT", "6379"))
            redis_password = os.getenv("REDIS_PASSWORD", None)
            logger.info(
                f"Connecting to Redis at {redis_host}:{redis_port} with password: {'***' if redis_password else 'None'}"
            )
            redis_client = redis.Redis(
                host=redis_host,
                port=redis_port,
                password=redis_password,
                decode_responses=True,
            )
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

        retry_delay = 5
        while True:
            try:
                redis_client.ping()
                break
            except redis.exceptions.BusyLoadingError:
                # Must be handled before RedisError: it is the one failure we wait out.
                logger.warning(f"Redis is loading the dataset in memory. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            except redis.exceptions.RedisError as e:
                # NOTE(review): the first real network round-trip is presumably this
                # ping rather than the constructor above, so connection failures are
                # logged here before being propagated to the caller unchanged.
                logger.error(f"Failed to connect to Redis: {e}")
                raise
        return redis_client

    def _initialize_datastore(self):
        """Ensure the search index exists, then seed it with the bundled dataset."""
        self.__create_hnsw_index()
        self.__populate_sample_data(ISOMER_DESIGN_DATASET)

    def __create_hnsw_index(self):
        """Create the HNSW index for molecular embeddings if it is missing.

        The schema holds a ``smiles`` text field followed by one HNSW vector
        field per entry of ``SUPPORTED_EMBEDDING_DIMENSIONS``; the index covers
        hashes whose key starts with the molecule key prefix.

        Raises:
            Exception: Index-creation errors are logged and re-raised.
        """
        try:
            self.redis_client.ft(self.index_name).info()
        except redis.exceptions.ResponseError:
            # Deliberately ignored: info() failing this way is treated as
            # "index does not exist yet", so fall through and create it.
            pass
        else:
            logger.info(f"Index {self.index_name} already exists")
            return

        try:
            schema: list[TextField | VectorField] = [
                TextField("smiles"),
                *(
                    VectorField(
                        self.embedding_field_name(dim),
                        "HNSW",
                        {
                            **HNSW_PARAMETERS,
                            "DIM": dim,
                        },
                    )
                    for dim in SUPPORTED_EMBEDDING_DIMENSIONS
                ),
            ]
            self.redis_client.ft(self.index_name).create_index(
                schema,
                definition=IndexDefinition(prefix=[self.molecule_index_prefix("")], index_type=IndexType.HASH),
            )
            logger.info(f"Created HNSW index: {self.index_name}")
        except Exception as e:
            logger.error(f"Failed to create HNSW index: {e}")
            raise

    def __populate_sample_data(self, df: pd.DataFrame):
        """Store every row of ``df`` in Redis together with its embeddings.

        Each row must provide a ``smiles`` column; all row columns are written
        to the hash as-is. Rows whose key already exists are skipped. A failure
        on one row is logged and does not stop the remaining rows.

        Args:
            df: Molecules to store, one per row.
        """
        logger.info("Populating Redis with sample molecular data...")
        for _, row in df.iterrows():
            try:
                key = self.molecule_index_prefix(row["smiles"])
                if self.redis_client.exists(key):
                    continue

                # Run the model once per molecule and derive every supported
                # dimension from that single embedding.
                embedding_cache: np.ndarray = self.get_molecular_embedding(row["smiles"], EMBEDDING_DIMENSION)
                # No defensive copy needed: the truncate/normalize helper slices
                # and divides into new arrays and never writes to its input.
                mapping: dict[str, bytes | str] = {
                    self.embedding_field_name(embed_dim): self._truncate_and_normalize_embedding(
                        embedding_cache, embed_dim
                    ).tobytes()
                    for embed_dim in SUPPORTED_EMBEDDING_DIMENSIONS
                }
                # Row columns are merged last, so they win on a key clash.
                mapping = {**mapping, **row.to_dict()}
                self.redis_client.hset(
                    key,
                    mapping=mapping,  # type: ignore
                )
            except Exception as e:
                logger.error(f"Failed to process molecule {row}: {e}")
                continue
        logger.info(f"Populated {len(df)} sample molecules")

    def get_molecular_embedding(self, smiles: str, embed_dim: int) -> np.ndarray:
        """Generate a truncated, L2-normalized embedding for one SMILES string.

        The SMILES is canonicalized first (falling back to the raw input when
        canonicalization yields a falsy value) so it matches the preprocessing
        used for the training data.

        Args:
            smiles: Molecule to embed.
            embed_dim: Number of leading embedding components to keep.

        Returns:
            The normalized embedding with at most ``embed_dim`` components.

        Raises:
            ValueError: If ``embed_dim`` is not positive.
            Exception: Any encoding error is logged and re-raised.
        """
        try:
            if embed_dim <= 0:
                raise ValueError("embed_dim must be positive")

            # Preprocess smiles similarly as training data for optimal performance
            smiles = MorganFingerprinter.canonicalize_smiles(smiles) or smiles

            embedding: np.ndarray = self.model.encode(
                [smiles],
                show_progress_bar=False,
                convert_to_numpy=True,
            )[0]
            return self._truncate_and_normalize_embedding(embedding, embed_dim)
        except Exception as e:
            logger.error(f"Failed to generate embedding for {smiles}: {e}")
            raise

    def _truncate_and_normalize_embedding(self, embedding: np.ndarray, embed_dim: int) -> np.ndarray:
        """Keep the first ``embed_dim`` components and scale to unit L2 norm.

        An embedding that is already no longer than ``embed_dim`` is not
        truncated. A zero vector is returned unchanged in value (division by 1)
        rather than producing NaNs. The input array is not modified.
        """
        if embed_dim < len(embedding):
            embedding = embedding[:embed_dim]
        norms = np.linalg.norm(embedding, ord=2, keepdims=True)
        return embedding / np.where(norms == 0, 1, norms)

    def find_similar_molecules(
        self, query_embedding: np.ndarray, embed_dim: int, k: int = HNSW_K
    ) -> list[SimilarMolecule]:
        """Find the ``k`` nearest stored molecules for a query embedding.

        Args:
            query_embedding: Query vector; its raw bytes are sent to Redis, so
                it presumably must match the indexed field's dimension and
                dtype (verify against ``HNSW_PARAMETERS``).
            embed_dim: Selects which ``embedding_<dim>`` vector field to search.
            k: Number of neighbours requested from the KNN query.

        Returns:
            Matches sorted by ascending ``score``. Best-effort: on any error the
            failure is logged and an empty list is returned.
        """
        try:
            query_vector = query_embedding.tobytes()
            query = (
                Query(f"*=>[KNN {k} @{self.embedding_field_name(embed_dim)} $vec AS score]")
                .sort_by("score")
                .return_fields("smiles", "name", "properties", "score")
                .dialect(2)
            )

            results = self.redis_client.ft(self.index_name).search(
                query,
                query_params={
                    "vec": query_vector,  # type: ignore
                },
            )

            neighbors: list[SimilarMolecule] = [
                {"smiles": doc.smiles, "name": doc.name, "properties": doc.properties, "score": float(doc.score)}
                for doc in results.docs
            ]
            return neighbors
        except Exception as e:
            logger.error(f"Failed to find similar molecules: {e}")
            return []

    # The helpers below take no ``self`` but are invoked as ``self.<name>(...)``
    # elsewhere in this class, so they must be static methods; without the
    # decorator the instance would be passed as an extra positional argument.

    @staticmethod
    def get_canonical_smiles(smiles: str | None) -> str:
        """Convert SMILES to its canonical representation.

        Returns ``""`` for ``None``, empty or whitespace-only input, and the
        stripped input itself when canonicalization returns ``None``.
        """
        stripped = smiles.strip() if smiles else ""
        if stripped == "":
            return ""

        canonical = MorganFingerprinter.canonicalize_smiles(stripped)
        if canonical is None:
            return stripped
        return canonical

    @staticmethod
    def get_smiles_from_mol_file(mol_file: str) -> str:
        """Convert the text of a MOL block into a canonical SMILES string.

        Returns ``""`` for empty or whitespace-only input and when RDKit cannot
        parse the block into a molecule.
        """
        if not mol_file or mol_file.strip() == "":
            return ""

        mol = Chem.rdmolfiles.MolFromMolBlock(mol_file)
        if mol is None:
            return ""
        return Chem.MolToSmiles(mol, canonical=True)

    @staticmethod
    def embedding_field_name(dim: int) -> str:
        """Return the Redis hash/index field name holding ``dim``-sized embeddings."""
        return f"embedding_{dim}"

    @staticmethod
    def molecule_index_prefix(smiles: str) -> str:
        """Return the Redis key for ``smiles``; with ``""`` this is the bare key prefix."""
        return f"mol:{smiles}"