Manasa1 committed
Commit 8503786 · verified · 1 Parent(s): 4639b02

Update app.py

Files changed (1):
  app.py +24 -23
app.py CHANGED
@@ -3,11 +3,25 @@ from langchain import PromptTemplate
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain.chains import RetrievalQA
+from langchain.chains.llm import LLMChain
+from langchain.chains.question_answering import load_qa_chain
 import gradio as gr
-from huggingface_hub import hf_hub_download
 
 DB_FAISS_PATH = "vectorstores/db_faiss"
 
+class GPT2LLM:
+    """
+    A custom class to wrap the GPT-2 model and tokenizer to be used with LangChain.
+    """
+    def __init__(self, model, tokenizer):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    def __call__(self, prompt_text, max_length=512):
+        inputs = self.tokenizer.encode(prompt_text, return_tensors='pt')
+        outputs = self.model.generate(inputs, max_length=max_length, temperature=0.5)
+        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
 def load_llm():
     """
     Load the GPT-2 model for the language model.
@@ -18,10 +32,10 @@ def load_llm():
         model = GPT2LMHeadModel.from_pretrained(model_name)
         tokenizer = GPT2Tokenizer.from_pretrained(model_name)
         print("Model and tokenizer successfully loaded!")
-        return model, tokenizer
+        return GPT2LLM(model, tokenizer)
     except Exception as e:
         print(f"An error occurred while loading the model: {e}")
-        return None, None
+        return None
 
 def set_custom_prompt():
     """
@@ -39,29 +53,16 @@ Helpful answer:
     prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
     return prompt
 
-def generate_answer(prompt_text, model, tokenizer):
-    """
-    Generate an answer using the GPT-2 model and tokenizer.
-    """
-    inputs = tokenizer.encode(prompt_text, return_tensors='pt')
-    outputs = model.generate(inputs, max_length=512, temperature=0.5)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-def retrieval_QA_chain(model, tokenizer, prompt, db):
+def retrieval_QA_chain(llm, prompt, db):
     """
     Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
     """
-    def generate_answer_fn(query):
-        # Format the query with the prompt
-        formatted_prompt = prompt.format(context="Some context here", question=query)
-        return generate_answer(formatted_prompt, model, tokenizer)
-
+    llm_chain = LLMChain(llm=llm, prompt=prompt)
     qachain = RetrievalQA.from_chain_type(
-        llm=generate_answer_fn,
+        llm_chain=llm_chain,
         chain_type="stuff",
         retriever=db.as_retriever(search_kwargs={'k': 2}),
-        return_source_documents=True,
-        chain_type_kwargs={'prompt': prompt}
+        return_source_documents=True
     )
     return qachain
 
@@ -71,10 +72,10 @@ def qa_bot():
     """
     embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-miniLM-L6-V2', model_kwargs={'device': 'cpu'})
     db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
-    model, tokenizer = load_llm()
+    llm = load_llm()
     qa_prompt = set_custom_prompt()
-    if model and tokenizer:
-        qa = retrieval_QA_chain(model, tokenizer, qa_prompt, db)
+    if llm:
+        qa = retrieval_QA_chain(llm, qa_prompt, db)
     else:
         qa = None
     return qa
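
Review note: two issues remain in this revision. The new `load_qa_chain` import is unused, and `RetrievalQA.from_chain_type` takes a required `llm=` argument rather than an `llm_chain=` keyword; it also expects a LangChain LLM object, which the plain-callable `GPT2LLM` wrapper is not, so chain construction will likely fail. A working variant would subclass LangChain's `LLM` base class instead. Below is a minimal sketch, not part of the commit: the `GPT2TextGenLLM` name and the `do_sample`/`pad_token_id` generation arguments are illustrative additions, and the import path assumes a pre-0.2 LangChain layout.

from typing import Any, List, Optional

from langchain.llms.base import LLM


class GPT2TextGenLLM(LLM):
    """Hypothetical LangChain-compatible wrapper around a GPT-2 model and tokenizer."""
    model: Any = None
    tokenizer: Any = None

    @property
    def _llm_type(self) -> str:
        return "gpt2"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Encode the prompt, generate a continuation, and decode it back to text.
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")
        outputs = self.model.generate(
            inputs,
            max_length=512,
            do_sample=True,   # temperature only takes effect when sampling is enabled
            temperature=0.5,
            pad_token_id=self.tokenizer.eos_token_id,  # avoids GPT-2's missing-pad-token warning
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

With a wrapper like this, `retrieval_QA_chain` could drop the manual `LLMChain` and call `RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=db.as_retriever(search_kwargs={'k': 2}), return_source_documents=True, chain_type_kwargs={'prompt': prompt})`, restoring the `chain_type_kwargs={'prompt': prompt}` argument that this commit removed so the custom prompt is still applied.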