# Status banner captured with this file: "Spaces: Runtime error"
# rag_dspy.py
import functools
import os

import dspy
from dotenv import load_dotenv
from dspy_qdrant import QdrantRM
from qdrant_client import QdrantClient, models
load_dotenv()

# DSPy setup
lm = dspy.LM("gpt-4", max_tokens=512, api_key=os.environ.get("OPENAI_API_KEY"))

# Require the cluster URL explicitly. QdrantClient(url=None) falls back to a
# default (local) host rather than erroring, which would hide a missing
# environment variable until the first query fails with a connection error.
_qdrant_url = os.environ.get("QDRANT_CLOUD_URL")
if not _qdrant_url:
    raise RuntimeError(
        "QDRANT_CLOUD_URL is not set; refusing to fall back to a local Qdrant server"
    )
client = QdrantClient(url=_qdrant_url, api_key=os.environ.get("QDRANT_API_KEY"))

collection_name = "medical_chat_bot"
rm = QdrantRM(
    qdrant_collection_name=collection_name,
    qdrant_client=client,
    vector_name="dense",  # must match the dense vector name used at upsert time
    document_field="passage_text",  # must match the payload field used at upsert time
    k=20,
)
dspy.settings.configure(lm=lm, rm=rm)
@functools.lru_cache(maxsize=1)
def _load_embedding_models():
    """Load the dense and ColBERT query encoders once per process.

    Returns:
        A ``(dense_model, colbert_model)`` tuple of fastembed models.
    """
    # Imported lazily so importing this module does not pull in fastembed
    # until retrieval is actually used.
    from fastembed import LateInteractionTextEmbedding, TextEmbedding

    dense_model = TextEmbedding("BAAI/bge-small-en")
    colbert_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    return dense_model, colbert_model


def rerank_with_colbert(query_text, limit=5, prefetch_limit=None):
    """Retrieve passages for *query_text*: dense prefetch, then ColBERT rerank.

    A single Qdrant ``query_points`` call prefetches candidates using the
    ``"dense"`` named vector and re-scores them against the ``"colbert"``
    vector field.

    Args:
        query_text: The question to search for.
        limit: Maximum number of reranked passages to return (default 5).
        prefetch_limit: Number of dense candidates to prefetch; ``None``
            leaves it to Qdrant's default.

    Returns:
        A list of ``passage_text`` payload strings, best match first. Points
        whose payload lacks ``passage_text`` are skipped.
    """
    dense_model, colbert_model = _load_embedding_models()

    # query_embed (not embed) encodes the text as a *query*. For ColBERT this
    # applies the query marker/padding that late-interaction scoring expects;
    # embed() would encode it as a document.
    dense_query = next(iter(dense_model.query_embed(query_text)))
    colbert_query = next(iter(colbert_model.query_embed(query_text)))

    results = client.query_points(
        collection_name=collection_name,
        prefetch=models.Prefetch(
            query=dense_query,
            using="dense",
            limit=prefetch_limit,
        ),
        query=colbert_query,
        using="colbert",
        limit=limit,
        with_payload=True,
    )

    payloads = (point.payload or {} for point in results.points)
    return [payload["passage_text"] for payload in payloads if "passage_text" in payload]
# DSPy Signature and Module
class MedicalAnswer(dspy.Signature):
    """Answer the medical question using the provided context passages."""

    # DSPy uses the class docstring above as the task instruction in the prompt.
    question = dspy.InputField(desc="The medical question to answer")
    # Retrieved passages are supplied by the caller, so this must be an
    # input field; as an output field the passed-in context is not treated
    # as an input to the predictor.
    context = dspy.InputField(desc="Retrieved passages relevant to the question")
    final_answer = dspy.OutputField(desc="The answer to the medical question")
class MedicalRAG(dspy.Module):
    """Retrieve-then-answer pipeline for medical questions.

    ``forward`` reranks passages with ``rerank_with_colbert``, joins them into
    one newline-separated context string, and answers with a chain-of-thought
    predictor over the ``MedicalAnswer`` signature.
    """

    def __init__(self):
        super().__init__()
        # Build the predictor once and store it as an attribute so DSPy can
        # discover it as a sub-module (for optimization and save/load); a
        # predictor constructed inside forward() is invisible to optimizers.
        self.generate_answer = dspy.ChainOfThought(MedicalAnswer)

    def forward(self, question):
        """Answer *question* using reranked passages as context.

        Returns:
            The prediction object returned by the chain-of-thought predictor.
        """
        reranked_docs = rerank_with_colbert(question)
        context_str = "\n".join(reranked_docs)
        return self.generate_answer(question=question, context=context_str)