# Author : Justin
# Program : Vectorizer for Hybrid Search
# Instructions : Check README.md
import logging

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from qdrant_client import models
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
# --- Setup Logging ---
# Configure logging to be more descriptive
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local models for vector generation
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
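# (Note: documents are typically indexed with the matching DOC encoder,
# e.g. 'naver/efficient-splade-VI-BT-large-doc', so that query-side and
# document-side sparse vectors share the same vocabulary space.)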
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Global Variables for Models ---
# These will be loaded once when the application starts
dense_model = None
splade_tokenizer = None
splade_model = None

# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.2.0"
)

# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    query_text: str

class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    indices: list[int]
    values: list[float]

class VectorResponse(BaseModel):
    """Final JSON response model containing both vectors."""
    dense_vector: list[float]
    sparse_vector: SparseVectorResponse
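
# A quick sketch of the wire format these models describe (the numbers are
# illustrative, not real model output):
#   request:  {"query_text": "hybrid search with qdrant"}
#   response: {"dense_vector": [0.01, -0.42, ...],
#              "sparse_vector": {"indices": [2054, 3231], "values": [1.7, 0.9]}}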
@app.on_event("startup")
async def load_models():
    """
    Asynchronous event to load ML models on application startup.
    This ensures models are loaded only once.
    """
    global dense_model, splade_tokenizer, splade_model
    logger.info("Server is starting up... Time to load the ML models.")
    logger.info(f"I'll be using the '{DEVICE}' for processing.")
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        splade_model.eval()  # inference mode; disables dropout
        logger.info("Yay! All models have been loaded successfully.")
    except Exception as e:
        logger.critical(f"Oh no, a critical error occurred while loading models: {e}", exc_info=True)
        # In a real-world scenario, you might want the app to fail startup if models don't load
        raise
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.

    Args:
        text: The input text string.

    Returns:
        A Qdrant SparseVector object.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()}  # Move tensors to the correct device
    with torch.no_grad():
        output = splade_model(**tokens)
    logits, attention_mask = output.logits, tokens['attention_mask']
    # SPLADE pooling: log(1 + ReLU(logits)), masked by attention, then max over
    # the sequence dimension, giving one weight per vocabulary term
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()
    # Keep only the non-zero vocabulary positions and their weights
    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()
    # Ensure indices and values are always lists, even for a single-element tensor
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]
    return models.SparseVector(indices=indices, values=values)
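
# Optional debugging aid, a small sketch not present in the original service:
# map sparse indices back to vocabulary tokens to see which terms SPLADE
# activated for a query. `explain_splade_vector` is a hypothetical helper name.
def explain_splade_vector(sparse: models.SparseVector) -> dict:
    """Return a {token: weight} view of a SPLADE vector for log inspection."""
    terms = splade_tokenizer.convert_ids_to_tokens(sparse.indices)
    return dict(zip(terms, sparse.values))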
# The route path below ("/vectorize") is an assumption; point the n8n HTTP
# Request node at whatever path you choose here.
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.

    Args:
        request: A QueryRequest object containing the 'query_text'.

    Returns:
        A JSON response containing the dense and sparse vectors.
    """
    # --- n8n Logging ---
    logger.info("=========================================================")
    logger.info("A new request just arrived! Let's see what we've got.")
    logger.info(f"The incoming search query from n8n is: '{request.query_text}'")

    # 1. Generate Dense Vector
    logger.info("First, generating the dense vector for semantic meaning...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Done with the dense vector. It has %d dimensions.", len(dense_query_vector))
    logger.info("Here's a small sample of the dense vector: %s...", str(dense_query_vector[:4]))

    # 2. Generate Sparse Vector
    logger.info("Next up, creating the sparse vector for keyword matching...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector is ready. It contains %d important terms.", len(sparse_query_vector.indices))
    logger.info("Here's a sample of the sparse vector indices: %s...", str(sparse_query_vector.indices[:4]))

    # 3. Construct and return the response
    logger.info("Everything looks good. Packaging up the vectors to send back.")
    logger.info("=========================================================")
    final_response = VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values
        )
    )
    return final_response
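
# Example call (path as assumed above; host/port depend on how the app is run):
#   curl -X POST http://localhost:8000/vectorize \
#        -H "Content-Type: application/json" \
#        -d '{"query_text": "hybrid search with qdrant"}'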
@app.get("/")
async def root():
    return {"message": "Vector Generation API is running. -- VERSION 2 --"}