File size: 3,863 Bytes
949d5bc
 
 
 
 
875c838
 
949d5bc
 
 
 
 
1d95dad
949d5bc
223bbe7
949d5bc
 
 
 
 
 
 
 
 
 
 
3d2faa5
949d5bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875c838
 
 
 
3d2faa5
 
875c838
 
949d5bc
 
 
 
 
 
 
 
 
 
 
 
 
1235a73
3d2faa5
 
875c838
949d5bc
 
 
1235a73
 
 
 
 
 
 
 
 
 
 
949d5bc
 
 
1235a73
949d5bc
3d2faa5
1235a73
 
 
 
3d2faa5
949d5bc
 
 
3d2faa5
 
875c838
949d5bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# rag_dspy.py

import os
from functools import lru_cache

import dspy
from dotenv import load_dotenv
from dspy_qdrant import QdrantRM
from qdrant_client import QdrantClient, models
from qdrant_client.models import Filter, FieldCondition, MatchValue

# Populate os.environ from a .env file (if one is present) before any of the
# os.environ.get(...) lookups below run.
load_dotenv() 
# DSPy setup
# Language model registered as DSPy's default below; MedicalGuardrail calls it
# directly via dspy.settings.lm. Completions are capped at 512 tokens.
lm = dspy.LM("gpt-4o-mini-2024-07-18", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
# Shared Qdrant connection; rerank_with_colbert queries it directly.
# NOTE(review): os.environ.get returns None for unset variables, so a missing
# QDRANT_CLOUD_URL / QDRANT_API_KEY is not reported here — confirm how
# QdrantClient behaves with url=None before relying on this in production.
client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
# Collection holding the indexed passages (payload fields read in this file:
# passage_text, specialty, year).
collection_name = "miriad"
# DSPy retrieval model over the same collection, configured with k=20.
# NOTE(review): nothing visible in this file invokes it (MedicalRAG retrieves
# through rerank_with_colbert) — confirm whether it is still needed.
rm = QdrantRM(
    qdrant_collection_name=collection_name, 
    qdrant_client=client, 
    vector_name="dense",                 # must match the dense vector name used at upsert time
    document_field="passage_text",        # must match the payload field that holds passage text at upsert time
    k=20)

# Register lm and rm as DSPy's global defaults (dspy.settings).
dspy.settings.configure(lm=lm, rm=rm)

@lru_cache(maxsize=1)
def _load_embedders():
    """Build the dense and ColBERT query encoders once and reuse them.

    Constructing a fastembed model loads (and may first download) its
    weights, which is far too expensive to repeat for every query.

    Returns:
        A ``(dense_model, colbert_model)`` tuple of fastembed encoders.
    """
    # Local import so fastembed is only required when reranking is used.
    from fastembed import TextEmbedding, LateInteractionTextEmbedding

    dense_model = TextEmbedding("BAAI/bge-small-en")
    colbert_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    return dense_model, colbert_model


def rerank_with_colbert(query_text, min_year, max_year, specialty, limit=5, prefetch_limit=None):
    """Retrieve passages with dense search, then rerank them with ColBERT.

    Issues a single Qdrant ``query_points`` call. A prefetch over the
    ``"dense"`` vector produces candidates, which are rescored with the
    ``"colbert"`` multivector query. Both stages are restricted to points
    whose ``specialty`` payload matches and whose ``year`` lies in
    ``[min_year, max_year]``.

    Args:
        query_text: The natural-language query.
        min_year: Inclusive lower bound on the ``year`` payload field, or
            ``None`` for no lower bound.
        max_year: Inclusive upper bound on the ``year`` payload field, or
            ``None`` for no upper bound.
        specialty: Exact value required in the ``specialty`` payload field,
            or ``None`` to skip the specialty filter.
        limit: Number of reranked passages to return.
        prefetch_limit: Number of dense candidates handed to the ColBERT
            stage; ``None`` leaves it to Qdrant's default.

    Returns:
        A list of ``passage_text`` strings in the order returned by Qdrant.
        Points whose payload lacks ``passage_text`` are skipped.
    """
    dense_model, colbert_model = _load_embedders()

    # Encode the query once per model. ColBERT must use query_embed: query-side
    # encoding differs from document-side encoding for late-interaction models.
    # .tolist() hands Qdrant plain Python floats instead of numpy arrays.
    dense_query = next(iter(dense_model.embed(query_text))).tolist()
    colbert_query = next(iter(colbert_model.query_embed(query_text))).tolist()

    conditions = []
    if specialty is not None:
        conditions.append(FieldCondition(key="specialty", match=MatchValue(value=specialty)))
    if min_year is not None or max_year is not None:
        conditions.append(
            FieldCondition(key="year", range=models.Range(gte=min_year, lte=max_year))
        )
    payload_filter = Filter(must=conditions) if conditions else None

    results = client.query_points(
        collection_name=collection_name,
        # The filter is set on the prefetch explicitly so the dense candidate
        # pool contains only matching points; this is harmless if the server
        # also propagates the top-level filter.
        prefetch=models.Prefetch(
            query=dense_query,
            using="dense",
            filter=payload_filter,
            limit=prefetch_limit,
        ),
        query=colbert_query,
        using="colbert",
        query_filter=payload_filter,
        limit=limit,
        with_payload=True,
    )

    return [
        point.payload["passage_text"]
        for point in results.points
        if point.payload and "passage_text" in point.payload
    ]

# DSPy Signature and Module
class MedicalAnswer(dspy.Signature):
    """Answer the medical question using the provided context passages."""

    # NOTE: DSPy uses the docstring above as the task instruction in the prompt.
    # Only InputFields are rendered into the prompt, so ``context`` must be an
    # input; declared as an OutputField, the retrieved passages would be dropped
    # and the LM asked to invent its own "context".
    question = dspy.InputField(desc="The medical question to answer")
    min_year = dspy.InputField(desc="The minimum year of the medical paper")
    max_year = dspy.InputField(desc="The maximum year of the medical paper")
    specialty = dspy.InputField(desc="The specialty of the medical paper")
    context = dspy.InputField(desc="Retrieved passages from medical papers, separated by newlines")
    is_medical = dspy.OutputField(desc="Answer 'Yes' if the question is medical, otherwise 'No'")
    final_answer = dspy.OutputField(desc="The answer to the medical question")

class MedicalGuardrail(dspy.Module):
    """Classify whether a question is medical using the configured DSPy LM.

    Sends a Yes/No prompt to ``dspy.settings.lm`` and returns ``True`` only
    when the first completion reads as "yes". Anything else, including an
    empty completion list, is treated as "not medical", so the guardrail
    fails closed.
    """

    # Characters stripped from both ends of the reply before reading the
    # verdict, so formatted replies like "**Yes**", '"Yes"' or "> Yes" still
    # count as yes.
    _DECORATION = " \t\r\n*_`'\"#>-"

    def forward(self, question):
        """Return ``True`` if the LM judges *question* to be medical.

        Args:
            question: The user's question text.

        Returns:
            bool: ``True`` for a "Yes" verdict, ``False`` otherwise.
        """
        prompt = (
            "Is the following question a medical question? Answer with 'Yes' or 'No'.\n"
            f"Question: {question}\n"
            "Answer:"
        )
        completions = dspy.settings.lm(prompt)
        if not completions:
            # No completion to judge: fail closed rather than raise IndexError.
            return False
        answer = str(completions[0]).strip(self._DECORATION).lower()
        return answer.startswith("yes")
        
class MedicalRAG(dspy.Module):
    """Guardrailed retrieval-augmented answering for medical questions.

    Questions rejected by ``MedicalGuardrail`` get a fixed refusal. Otherwise,
    passages are fetched with ``rerank_with_colbert`` under the year and
    specialty filters, newline-joined, and passed as ``context`` to a
    ``ChainOfThought(MedicalAnswer)`` predictor.
    """

    # Fixed reply returned for questions the guardrail rejects.
    REFUSAL_MESSAGE = "Sorry, I can only answer medical questions. Please ask a question related to medicine or healthcare."

    def __init__(self):
        super().__init__()
        self.guardrail = MedicalGuardrail()
        # Created once and stored on self so DSPy registers it as a submodule.
        # That makes it visible to optimizers and save()/load(). Building it
        # inside forward() would create a fresh, untracked predictor per call.
        self.generate_answer = dspy.ChainOfThought(MedicalAnswer)

    def forward(self, question, min_year, max_year, specialty):
        """Answer *question* from filtered, reranked passages.

        Args:
            question: The user's question.
            min_year: Lower bound of the publication-year filter.
            max_year: Upper bound of the publication-year filter.
            specialty: Required specialty of the source papers.

        Returns:
            dspy.Prediction: On the normal path, the ChainOfThought prediction.
            For non-medical questions, a prediction whose ``final_answer`` is
            the refusal message.
        """
        if not self.guardrail.forward(question):
            # dspy.Prediction keeps the ``final_answer`` attribute callers read
            # and matches the type returned on the normal path.
            return dspy.Prediction(final_answer=self.REFUSAL_MESSAGE)
        reranked_docs = rerank_with_colbert(question, min_year, max_year, specialty)
        context_str = "\n".join(reranked_docs)
        return self.generate_answer(
            question=question,
            min_year=min_year,
            max_year=max_year,
            specialty=specialty,
            context=context_str,
        )