Manasa1 committed
Commit 82385e8 · verified · 1 Parent(s): 36c509c

Update app.py

Files changed (1)
  1. app.py +38 -23
app.py CHANGED
@@ -1,7 +1,8 @@
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from langchain import PromptTemplate
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
-from langchain_community.llms import CTransformers
+from langchain_community.llms import CTransformers  # You might need to change this if GPT-2 isn't directly supported
 from langchain.chains import RetrievalQA
 import gradio as gr
 from huggingface_hub import hf_hub_download
@@ -10,17 +11,18 @@ DB_FAISS_PATH = "vectorstores/db_faiss"
 
 def load_llm():
     """
-    Load the LLaMA model for the language model.
+    Load the GPT-2 model for the language model.
     """
-    model_name = 'TheBloke/Llama-2-7B-Chat-GGML'
-    model_path = hf_hub_download(repo_id=model_name, filename='llama-2-7b-chat.ggmlv3.q8_0.bin', cache_dir='./models')
-    llm = CTransformers(
-        model=model_path,
-        model_type="llama",
-        max_new_tokens=512,
-        temperature=0.5
-    )
-    return llm
+    try:
+        print("Downloading or loading the GPT-2 model and tokenizer...")
+        model_name = 'gpt2'
+        model = GPT2LMHeadModel.from_pretrained(model_name)
+        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+        print("Model and tokenizer successfully loaded!")
+        return model, tokenizer
+    except Exception as e:
+        print(f"An error occurred while loading the model: {e}")
+        return None, None
 
 def set_custom_prompt():
     """
@@ -38,12 +40,19 @@ Helpful answer:
     prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
     return prompt
 
-def retrieval_QA_chain(llm, prompt, db):
+def retrieval_QA_chain(llm, tokenizer, prompt, db):
     """
     Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
     """
+    def generate_answer(query):
+        # Tokenize the input query
+        inputs = tokenizer.encode(query, return_tensors='pt')
+        # Generate response
+        outputs = llm.generate(inputs, max_length=512, temperature=0.5)
+        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
     qachain = RetrievalQA.from_chain_type(
-        llm=llm,
+        llm=generate_answer,
         chain_type="stuff",
         retriever=db.as_retriever(search_kwargs={'k': 2}),
         return_source_documents=True,
@@ -57,9 +66,12 @@ def qa_bot():
     """
     embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-miniLM-L6-V2', model_kwargs={'device': 'cpu'})
     db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
-    llm = load_llm()
+    model, tokenizer = load_llm()
     qa_prompt = set_custom_prompt()
-    qa = retrieval_QA_chain(llm, qa_prompt, db)
+    if model and tokenizer:
+        qa = retrieval_QA_chain(model, tokenizer, qa_prompt, db)
+    else:
+        qa = None
     return qa
 
 bot = qa_bot()
@@ -69,14 +81,17 @@ def chatbot_response(message, history):
     Generate a response from the chatbot based on the user input and conversation history.
     """
     try:
-        response = bot({'query': message})
-        answer = response["result"]
-        sources = response["source_documents"]
-        if sources:
-            answer += f"\nSources: {sources}"
+        if bot:
+            response = bot({'query': message})
+            answer = response["result"]
+            sources = response.get("source_documents", [])
+            if sources:
+                answer += f"\nSources: {sources}"
+            else:
+                answer += "\nNo sources found"
+            history.append((message, answer))
         else:
-            answer += "\nNo sources found"
-        history.append((message, answer))
+            history.append((message, "Model is not loaded properly."))
     except Exception as e:
         history.append((message, f"An error occurred: {str(e)}"))
     return history, history
@@ -97,4 +112,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
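
Note on this change: RetrievalQA.from_chain_type expects its llm= argument to be a LangChain language-model object (an LLM/Runnable), so passing the bare generate_answer function will most likely fail validation before the chain ever runs. One way to keep GPT-2 while satisfying that interface is the HuggingFacePipeline wrapper from langchain_community; the sketch below is illustrative and not part of this commit.

from langchain_community.llms import HuggingFacePipeline

def load_llm():
    # Wrap the 'gpt2' checkpoint in a transformers text-generation
    # pipeline; the resulting object is a LangChain LLM and can be
    # passed straight to RetrievalQA.from_chain_type as llm=.
    # Generation settings mirror this commit (512 tokens, T=0.5);
    # do_sample=True is needed for the temperature to take effect.
    return HuggingFacePipeline.from_model_id(
        model_id="gpt2",
        task="text-generation",
        pipeline_kwargs={
            "max_new_tokens": 512,
            "temperature": 0.5,
            "do_sample": True,
        },
    )

With a wrapper like this, retrieval_QA_chain could keep its original three-argument signature (llm, prompt, db), since tokenization and decoding happen inside the pipeline. Two further caveats: GPT-2's 1024-token context window is tight for a "stuff" chain that inlines two retrieved documents plus the prompt template, and max_length=512 in generate_answer limits prompt and completion combined, unlike the max_new_tokens=512 used with CTransformers before this commit.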