Manasa1 commited on
Commit
4639b02
·
verified ·
1 Parent(s): 82385e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -2,7 +2,6 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
  from langchain import PromptTemplate
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
- from langchain_community.llms import CTransformers # You might need to change this if GPT-2 isn't directly supported
6
  from langchain.chains import RetrievalQA
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
@@ -40,19 +39,25 @@ Helpful answer:
40
  prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
41
  return prompt
42
 
43
- def retrieval_QA_chain(llm, tokenizer, prompt, db):
44
  """
45
- Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
46
  """
47
- def generate_answer(query):
48
- # Tokenize the input query
49
- inputs = tokenizer.encode(query, return_tensors='pt')
50
- # Generate response
51
- outputs = llm.generate(inputs, max_length=512, temperature=0.5)
52
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
53
 
 
 
 
 
 
 
 
 
 
54
  qachain = RetrievalQA.from_chain_type(
55
- llm=generate_answer,
56
  chain_type="stuff",
57
  retriever=db.as_retriever(search_kwargs={'k': 2}),
58
  return_source_documents=True,
@@ -113,3 +118,4 @@ demo = gr.Interface(
113
 
114
  if __name__ == "__main__":
115
  demo.launch()
 
 
2
  from langchain import PromptTemplate
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
 
5
  from langchain.chains import RetrievalQA
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
 
39
  prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
40
  return prompt
41
 
42
+ def generate_answer(prompt_text, model, tokenizer):
43
  """
44
+ Generate an answer using the GPT-2 model and tokenizer.
45
  """
46
+ inputs = tokenizer.encode(prompt_text, return_tensors='pt')
47
+ outputs = model.generate(inputs, max_length=512, temperature=0.5)
48
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
49
 
50
+ def retrieval_QA_chain(model, tokenizer, prompt, db):
51
+ """
52
+ Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
53
+ """
54
+ def generate_answer_fn(query):
55
+ # Format the query with the prompt
56
+ formatted_prompt = prompt.format(context="Some context here", question=query)
57
+ return generate_answer(formatted_prompt, model, tokenizer)
58
+
59
  qachain = RetrievalQA.from_chain_type(
60
+ llm=generate_answer_fn,
61
  chain_type="stuff",
62
  retriever=db.as_retriever(search_kwargs={'k': 2}),
63
  return_source_documents=True,
 
118
 
119
  if __name__ == "__main__":
120
  demo.launch()
121
+