Spaces:
Sleeping
Sleeping
Update rag_dspy.py
Browse files- rag_dspy.py +18 -5
rag_dspy.py
CHANGED
@@ -3,6 +3,8 @@
|
|
3 |
import dspy
|
4 |
from dspy_qdrant import QdrantRM
|
5 |
from qdrant_client import QdrantClient, models
|
|
|
|
|
6 |
from dotenv import load_dotenv
|
7 |
import os
|
8 |
|
@@ -10,7 +12,7 @@ load_dotenv()
|
|
10 |
# DSPy setup
# Language model handle; the OpenAI key is read from the environment
# (a missing variable resolves to None rather than raising here).
lm = dspy.LM(
    "gpt-4",
    max_tokens=512,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Qdrant client; endpoint URL and API key are both read from the environment.
client = QdrantClient(
    url=os.getenv("QDRANT_CLOUD_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)
|
13 |
-
collection_name = "
|
14 |
rm = QdrantRM(
|
15 |
qdrant_collection_name=collection_name,
|
16 |
qdrant_client=client,
|
@@ -22,7 +24,7 @@ dspy.settings.configure(lm=lm, rm=rm)
|
|
22 |
|
23 |
# Manual reranker using ColBERT multivector field
|
24 |
# Manual reranker using Qdrant’s native prefetch + ColBERT query
|
25 |
-
def rerank_with_colbert(query_text):
|
26 |
from fastembed import TextEmbedding, LateInteractionTextEmbedding
|
27 |
|
28 |
# Encode query once with both models
|
@@ -42,7 +44,14 @@ def rerank_with_colbert(query_text):
|
|
42 |
query=colbert_query,
|
43 |
using="colbert",
|
44 |
limit=5,
|
45 |
-
with_payload=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
)
|
47 |
|
48 |
points = results.points
|
@@ -56,6 +65,8 @@ def rerank_with_colbert(query_text):
|
|
56 |
# DSPy Signature and Module
|
57 |
class MedicalAnswer(dspy.Signature):
    # Signature for answering a medical question from retrieved context.
    #
    # NOTE: documented with comments rather than a docstring on purpose:
    # DSPy uses a Signature's docstring as the prompt instructions, so adding
    # one would change what is sent to the language model.

    # Inputs supplied by the caller.
    question = dspy.InputField(desc="The medical question to answer")
    # `context` is passed in by MedicalRAG.forward (context=context_str), so it
    # must be an input field; it was previously declared as an OutputField with
    # the answer's description copied onto it.
    context = dspy.InputField(desc="Retrieved passages relevant to the medical question")

    # Output produced by the language model.
    final_answer = dspy.OutputField(desc="The answer to the medical question")
|
61 |
|
@@ -63,12 +74,14 @@ class MedicalRAG(dspy.Module):
|
|
63 |
def __init__(self):
    """Initialise the module and create its answer-generation predictor."""
    super().__init__()
    # Build the predictor once and keep it as an attribute instead of
    # re-creating it on every forward() call: attributes assigned here are
    # the sub-modules DSPy can discover (e.g. for optimisation and saving).
    self.generate_answer = dspy.ChainOfThought(MedicalAnswer)


def forward(self, question):
    """Answer `question` using reranked retrieval results as context.

    Args:
        question: The medical question; forwarded unchanged to
            rerank_with_colbert and to the predictor.

    Returns:
        Whatever the ChainOfThought(MedicalAnswer) predictor returns.
    """
    # The result is joined with "\n", so rerank_with_colbert is expected to
    # return an iterable of strings — TODO confirm against its definition.
    reranked_docs = rerank_with_colbert(question)

    context_str = "\n".join(reranked_docs)

    return self.generate_answer(
        question=question,
        context=context_str,
    )
|
|
|
3 |
import dspy
|
4 |
from dspy_qdrant import QdrantRM
|
5 |
from qdrant_client import QdrantClient, models
|
6 |
+
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
7 |
+
|
8 |
from dotenv import load_dotenv
|
9 |
import os
|
10 |
|
|
|
12 |
# DSPy setup
# Language model handle; the OpenAI key is read from the environment
# (a missing variable resolves to None rather than raising here).
lm = dspy.LM(
    "gpt-4",
    max_tokens=512,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Qdrant client; endpoint URL and API key are both read from the environment.
client = QdrantClient(
    url=os.getenv("QDRANT_CLOUD_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

# Name of the Qdrant collection handed to the retriever below.
collection_name = "indexed_medical_chat_bot"
|
16 |
rm = QdrantRM(
|
17 |
qdrant_collection_name=collection_name,
|
18 |
qdrant_client=client,
|
|
|
24 |
|
25 |
# Manual reranker using ColBERT multivector field
|
26 |
# Manual reranker using Qdrant’s native prefetch + ColBERT query
|
27 |
+
def rerank_with_colbert(query_text, year, specialty):
|
28 |
from fastembed import TextEmbedding, LateInteractionTextEmbedding
|
29 |
|
30 |
# Encode query once with both models
|
|
|
44 |
query=colbert_query,
|
45 |
using="colbert",
|
46 |
limit=5,
|
47 |
+
with_payload=True,
|
48 |
+
query_filter=Filter(
|
49 |
+
must=[
|
50 |
+
FieldCondition(key="specialty", match=MatchValue(value=specialty)),
|
51 |
+
FieldCondition(key="year", match=MatchValue(value=year))
|
52 |
+
]
|
53 |
+
|
54 |
+
)
|
55 |
)
|
56 |
|
57 |
points = results.points
|
|
|
65 |
# DSPy Signature and Module
|
66 |
class MedicalAnswer(dspy.Signature):
    # Signature for answering a medical question from retrieved context.
    #
    # NOTE: documented with comments rather than a docstring on purpose:
    # DSPy uses a Signature's docstring as the prompt instructions, so adding
    # one would change what is sent to the language model.

    # Inputs supplied by the caller.
    question = dspy.InputField(desc="The medical question to answer")
    year = dspy.InputField(desc="The year of the medical paper")
    specialty = dspy.InputField(desc="The specialty of the medical paper")
    # `context` is passed in by MedicalRAG.forward (context=context_str), so it
    # must be an input field; it was previously declared as an OutputField with
    # the answer's description copied onto it.
    context = dspy.InputField(desc="Retrieved passages relevant to the medical question")

    # Output produced by the language model.
    final_answer = dspy.OutputField(desc="The answer to the medical question")
|
72 |
|
|
|
74 |
def __init__(self):
    """Initialise the module and create its answer-generation predictor."""
    super().__init__()
    # Build the predictor once and keep it as an attribute instead of
    # re-creating it on every forward() call: attributes assigned here are
    # the sub-modules DSPy can discover (e.g. for optimisation and saving).
    self.generate_answer = dspy.ChainOfThought(MedicalAnswer)


def forward(self, question, year, specialty):
    """Answer `question` using filtered, reranked retrieval results as context.

    Args:
        question: The medical question; used for retrieval and passed to the
            predictor.
        year: Passed to rerank_with_colbert and to the predictor.
        specialty: Passed to rerank_with_colbert and to the predictor.

    Returns:
        Whatever the ChainOfThought(MedicalAnswer) predictor returns.
    """
    # The result is joined with "\n", so rerank_with_colbert is expected to
    # return an iterable of strings — TODO confirm against its definition.
    reranked_docs = rerank_with_colbert(question, year, specialty)

    context_str = "\n".join(reranked_docs)

    return self.generate_answer(
        question=question,
        year=year,
        specialty=specialty,
        context=context_str,
    )
|