mwitiderrick committed on
Commit
875c838
·
verified ·
1 Parent(s): 38fec79

Update rag_dspy.py

Browse files
Files changed (1) hide show
  1. rag_dspy.py +18 -5
rag_dspy.py CHANGED
@@ -3,6 +3,8 @@
3
  import dspy
4
  from dspy_qdrant import QdrantRM
5
  from qdrant_client import QdrantClient, models
 
 
6
  from dotenv import load_dotenv
7
  import os
8
 
@@ -10,7 +12,7 @@ load_dotenv()
10
  # DSPy setup
11
  lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
12
  client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
13
- collection_name = "medical_chat_bot"
14
  rm = QdrantRM(
15
  qdrant_collection_name=collection_name,
16
  qdrant_client=client,
@@ -22,7 +24,7 @@ dspy.settings.configure(lm=lm, rm=rm)
22
 
23
  # Manual reranker using ColBERT multivector field
24
  # Manual reranker using Qdrant’s native prefetch + ColBERT query
25
- def rerank_with_colbert(query_text):
26
  from fastembed import TextEmbedding, LateInteractionTextEmbedding
27
 
28
  # Encode query once with both models
@@ -42,7 +44,14 @@ def rerank_with_colbert(query_text):
42
  query=colbert_query,
43
  using="colbert",
44
  limit=5,
45
- with_payload=True
 
 
 
 
 
 
 
46
  )
47
 
48
  points = results.points
@@ -56,6 +65,8 @@ def rerank_with_colbert(query_text):
56
  # DSPy Signature and Module
57
  class MedicalAnswer(dspy.Signature):
58
  question = dspy.InputField(desc="The medical question to answer")
 
 
59
  context = dspy.OutputField(desc="The answer to the medical question")
60
  final_answer = dspy.OutputField(desc="The answer to the medical question")
61
 
@@ -63,12 +74,14 @@ class MedicalRAG(dspy.Module):
63
  def __init__(self):
64
  super().__init__()
65
 
66
- def forward(self, question):
67
- reranked_docs = rerank_with_colbert(question)
68
 
69
  context_str = "\n".join(reranked_docs)
70
 
71
  return dspy.ChainOfThought(MedicalAnswer)(
72
  question=question,
 
 
73
  context=context_str
74
  )
 
3
  import dspy
4
  from dspy_qdrant import QdrantRM
5
  from qdrant_client import QdrantClient, models
6
+ from qdrant_client.models import Filter, FieldCondition, MatchValue
7
+
8
  from dotenv import load_dotenv
9
  import os
10
 
 
12
  # DSPy setup
13
  lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
14
  client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
15
+ collection_name = "indexed_medical_chat_bot"
16
  rm = QdrantRM(
17
  qdrant_collection_name=collection_name,
18
  qdrant_client=client,
 
24
 
25
  # Manual reranker using ColBERT multivector field
26
  # Manual reranker using Qdrant’s native prefetch + ColBERT query
27
+ def rerank_with_colbert(query_text, year, specialty):
28
  from fastembed import TextEmbedding, LateInteractionTextEmbedding
29
 
30
  # Encode query once with both models
 
44
  query=colbert_query,
45
  using="colbert",
46
  limit=5,
47
+ with_payload=True,
48
+ query_filter=Filter(
49
+ must=[
50
+ FieldCondition(key="specialty", match=MatchValue(value=specialty)),
51
+ FieldCondition(key="year", match=MatchValue(value=year))
52
+ ]
53
+
54
+ )
55
  )
56
 
57
  points = results.points
 
65
  # DSPy Signature and Module
66
  class MedicalAnswer(dspy.Signature):
67
  question = dspy.InputField(desc="The medical question to answer")
68
+ year = dspy.InputField(desc="The year of the medical paper")
69
+ specialty = dspy.InputField(desc="The specialty of the medical paper")
70
  context = dspy.OutputField(desc="The answer to the medical question")
71
  final_answer = dspy.OutputField(desc="The answer to the medical question")
72
 
 
74
  def __init__(self):
75
  super().__init__()
76
 
77
+ def forward(self, question, year, specialty):
78
+ reranked_docs = rerank_with_colbert(question, year, specialty)
79
 
80
  context_str = "\n".join(reranked_docs)
81
 
82
  return dspy.ChainOfThought(MedicalAnswer)(
83
  question=question,
84
+ year=year,
85
+ specialty=specialty,
86
  context=context_str
87
  )