# medicalchatbot / rag_dspy.py
# Last commit: "Update rag_dspy.py" by mwitiderrick (223bbe7, verified)
# rag_dspy.py
import os
from functools import lru_cache

import dspy
from dotenv import load_dotenv
from dspy_qdrant import QdrantRM
from qdrant_client import QdrantClient, models
from qdrant_client.models import Filter, FieldCondition, MatchValue
# Populate os.environ from a local .env file (if one exists) so the
# os.environ.get(...) lookups below can find the keys and URL.
load_dotenv()
# DSPy setup
# Chat LM shared by MedicalGuardrail (through dspy.settings.lm) and the answer
# predictor; completions are capped at 512 tokens. os.environ.get returns None
# when a variable is unset — NOTE(review): presumably the client then falls back
# to its own credential lookup or fails at call time; confirm.
lm = dspy.LM("gpt-4o-mini-2024-07-18", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
# Qdrant connection, also used directly by rerank_with_colbert.
client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
# Collection holding the indexed passages; rerank_with_colbert queries its
# "dense" and "colbert" named vectors.
collection_name = "miriad"
# DSPy retrieval model over the same collection, returning k=20 passages.
# NOTE(review): MedicalRAG in this file retrieves through rerank_with_colbert
# and never calls this rm directly; it is only registered globally below.
rm = QdrantRM(
qdrant_collection_name=collection_name,
qdrant_client=client,
vector_name="dense", # must equal the named vector used when points were upserted
document_field="passage_text", # must equal the payload key that holds the passage text
k=20)
# Register lm and rm as DSPy's process-wide defaults (read back via dspy.settings).
dspy.settings.configure(lm=lm, rm=rm)
# Manual reranker using Qdrant's native prefetch + ColBERT query:
# dense-vector candidate retrieval followed by ColBERT late-interaction
# rescoring, both done server-side in one query_points call.


@lru_cache(maxsize=1)
def _get_dense_model():
    """Load the dense query encoder once per process (model loading is expensive)."""
    from fastembed import TextEmbedding

    return TextEmbedding("BAAI/bge-small-en")


@lru_cache(maxsize=1)
def _get_colbert_model():
    """Load the ColBERT late-interaction query encoder once per process."""
    from fastembed import LateInteractionTextEmbedding

    return LateInteractionTextEmbedding("colbert-ir/colbertv2.0")


def _build_filter(min_year, max_year, specialty):
    """Return a Qdrant filter for the given specialty/year bounds, or None if none apply."""
    conditions = []
    if specialty is not None:
        conditions.append(
            FieldCondition(key="specialty", match=MatchValue(value=specialty))
        )
    if min_year is not None or max_year is not None:
        # gte/lte make both bounds inclusive; a None bound is left open.
        conditions.append(
            FieldCondition(key="year", range=models.Range(gte=min_year, lte=max_year))
        )
    return Filter(must=conditions) if conditions else None


def rerank_with_colbert(query_text, min_year, max_year, specialty, limit=5, prefetch_limit=None):
    """Retrieve passages for ``query_text`` and rerank them with ColBERT.

    One Qdrant ``query_points`` call does both stages: the ``prefetch`` gathers
    candidates by similarity on the ``"dense"`` named vector, and the outer
    query rescores those candidates against the ``"colbert"`` multivector field.

    Args:
        query_text: The user's question.
        min_year: Inclusive lower bound on the ``year`` payload field, or None.
        max_year: Inclusive upper bound on the ``year`` payload field, or None.
        specialty: Exact value required in the ``specialty`` payload field,
            or None for no specialty restriction.
        limit: Number of reranked passages to return (default 5).
        prefetch_limit: Number of dense candidates handed to the ColBERT stage;
            None leaves Qdrant's server-side default in effect.

    Returns:
        list[str]: ``passage_text`` payloads in the order Qdrant returns the
        points (highest score first). Points without that payload key are skipped.
    """
    # query_embed (not embed): ColBERT encodes queries differently from
    # documents, so embed() would encode the question as if it were a passage.
    dense_query = next(iter(_get_dense_model().query_embed(query_text)))
    colbert_query = next(iter(_get_colbert_model().query_embed(query_text)))

    search_filter = _build_filter(min_year, max_year, specialty)

    results = client.query_points(
        collection_name=collection_name,
        prefetch=models.Prefetch(
            query=dense_query,
            using="dense",
            # Apply the filter here too, so the candidate pool is drawn from the
            # matching subset instead of being pruned only after reranking.
            # Harmless if the outer filter is already propagated.
            filter=search_filter,
            limit=prefetch_limit,
        ),
        query=colbert_query,
        using="colbert",
        query_filter=search_filter,
        limit=limit,
        with_payload=True,
    )

    return [
        point.payload["passage_text"]
        for point in results.points
        if point.payload and "passage_text" in point.payload
    ]
# DSPy Signature and Module
class MedicalAnswer(dspy.Signature):
    """Answer the medical question using the retrieved passages in `context`, which come from papers in the given specialty published between min_year and max_year."""

    # NOTE: DSPy sends the class docstring above to the LM as the task
    # instructions, so it is written as an instruction, not as developer docs.

    # --- inputs ---
    question = dspy.InputField(desc="The medical question to answer")
    min_year = dspy.InputField(desc="The minimum year of the medical paper")
    max_year = dspy.InputField(desc="The maximum year of the medical paper")
    specialty = dspy.InputField(desc="The specialty of the medical paper")
    # Must be an InputField: MedicalRAG passes the reranked passages as
    # `context`. As an OutputField, the retrieved text was never shown to the LM,
    # and the LM was asked to invent it instead.
    context = dspy.InputField(
        desc="Passages retrieved from the medical literature that are relevant to the question"
    )

    # --- outputs ---
    # Kept for backward compatibility with callers that may read it.
    is_medical = dspy.OutputField(desc="Answer 'Yes' if the question is medical, otherwise 'No'")
    final_answer = dspy.OutputField(desc="The answer to the medical question")
class MedicalGuardrail(dspy.Module):
    """Yes/No LM classifier that decides whether a question is medical.

    The classification prompt goes straight to the globally configured DSPy LM
    (``dspy.settings.lm``). The first completion counts as "medical" when it
    begins with "yes", ignoring case and surrounding whitespace.
    """

    def forward(self, question):
        """Return True if the LM labels ``question`` as medical, otherwise False."""
        prompt_lines = [
            "Is the following question a medical question? Answer with 'Yes' or 'No'.",
            f"Question: {question}",
            "Answer:",
        ]
        completions = dspy.settings.lm("\n".join(prompt_lines))
        first_completion = completions[0]
        return first_completion.strip().lower().startswith("yes")
class MedicalRAG(dspy.Module):
    """Guardrailed retrieval-augmented answering for medical questions.

    Pipeline:
      1. ``MedicalGuardrail`` asks the LM whether the question is medical; if
         not, a fixed refusal is returned and no retrieval happens.
      2. ``rerank_with_colbert`` fetches passages filtered by specialty and year.
      3. A ChainOfThought predictor over ``MedicalAnswer`` writes the answer.
    """

    REFUSAL_MESSAGE = "Sorry, I can only answer medical questions. Please ask a question related to medicine or healthcare."

    def __init__(self):
        super().__init__()
        self.guardrail = MedicalGuardrail()
        # Built once, as an attribute, so it is a registered sub-module. A
        # predictor created inside forward() is invisible to DSPy optimizers and
        # to save()/load(), and is rebuilt on every call.
        self.generate_answer = dspy.ChainOfThought(MedicalAnswer)

    def forward(self, question, min_year, max_year, specialty):
        """Answer ``question`` from literature filtered by year range and specialty.

        Args:
            question: The user's question.
            min_year: Inclusive lower year bound forwarded to ``rerank_with_colbert``.
            max_year: Inclusive upper year bound forwarded to ``rerank_with_colbert``.
            specialty: Specialty value forwarded to ``rerank_with_colbert``.

        Returns:
            dspy.Prediction: the ChainOfThought output (``final_answer`` plus the
            other ``MedicalAnswer`` outputs and the reasoning), or a Prediction
            that carries only ``final_answer`` (the refusal) when the guardrail
            rejects the question.
        """
        # NOTE(review): forward() is called directly because the guardrail
        # returns a plain bool rather than a dspy.Prediction; confirm before
        # switching to self.guardrail(question).
        if not self.guardrail.forward(question):
            return dspy.Prediction(final_answer=self.REFUSAL_MESSAGE)

        passages = rerank_with_colbert(question, min_year, max_year, specialty)
        return self.generate_answer(
            question=question,
            min_year=min_year,
            max_year=max_year,
            specialty=specialty,
            context="\n".join(passages),
        )