rahideer committed on
Commit
9d1d210
·
verified ·
1 Parent(s): e0edce4

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +9 -7
rag_pipeline.py CHANGED
@@ -1,16 +1,18 @@
 
 
1
  from sentence_transformers import SentenceTransformer
2
  import faiss
3
- import numpy as np
4
- import pandas as pd
5
  from transformers import pipeline
6
 
7
  class RAGPipeline:
8
def __init__(self, dataset_path):
    """Set up the models, read the CSV dataset and build the search index.

    Args:
        dataset_path: Path (or anything ``pandas.read_csv`` accepts) of a
            CSV file that has ``context`` and ``question`` columns.
    """
    # Models are created first, exactly as before, so that any failure
    # order (model loading vs. CSV reading) stays the same.
    embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    text_generator = pipeline("text2text-generation", model="google/flan-t5-base")
    self.embedder = embedding_model
    self.generator = text_generator

    # Keep the full frame on the instance and expose the two columns
    # as plain Python lists.
    frame = pd.read_csv(dataset_path)
    self.data = frame
    self.documents = frame['context'].tolist()
    self.questions = frame['question'].tolist()

    # The index is built by a sibling method that is not shown here.
    self.index = self.build_faiss_index()
16
 
@@ -28,6 +30,6 @@ class RAGPipeline:
28
  def generate_answer(self, query):
29
  docs = self.retrieve(query)
30
  context = " ".join(docs)
31
- prompt = f"Answer the following question using the provided context:\nContext: {context}\nQuestion: {query}"
32
  result = self.generator(prompt, max_length=200, do_sample=True)
33
  return result[0]['generated_text']
 
1
+ from datasets import load_dataset
2
+ import pandas as pd
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
 
 
5
  from transformers import pipeline
6
 
7
  class RAGPipeline:
8
def __init__(self):
    """Load the models and the PubMedQA subset, then build the index.

    Creates the sentence embedder and the text2text generator, loads
    the first 500 rows of the ``pqa_labeled`` training split of
    ``pubmed_qa``, stores documents and questions as plain lists of
    strings, and finally calls ``build_faiss_index``.
    """
    self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    self.generator = pipeline("text2text-generation", model="google/flan-t5-base")

    # Load dataset directly
    ds = load_dataset("pubmed_qa", "pqa_labeled", split="train[:500]")

    # Per the PubMedQA dataset card, each ``context`` value is a mapping
    # with a ``contexts`` list of abstract passages (plus ``labels`` and
    # ``meshes``), not a string. Flatten it to one text per row so the
    # documents can be embedded and later joined with " ".join(...).
    # Values that are not mappings are kept unchanged.
    self.documents = [
        " ".join(record["contexts"]) if isinstance(record, dict) else record
        for record in ds["context"]
    ]
    # Materialise the column as a real list, as the CSV-based version
    # did with ``.tolist()``.
    self.questions = list(ds["question"])

    # NOTE(review): build_faiss_index is not visible here; presumably it
    # embeds self.documents, so it must run after they are set - confirm.
    self.index = self.build_faiss_index()
18
 
 
30
  def generate_answer(self, query):
31
  docs = self.retrieve(query)
32
  context = " ".join(docs)
33
+ prompt = f"Answer the following medical question using the context:\nContext: {context}\nQuestion: {query}"
34
  result = self.generator(prompt, max_length=200, do_sample=True)
35
  return result[0]['generated_text']