Manasa1 committed on
Commit b14eff4 · verified · 1 Parent(s): 06e0eed

Update app.py

Files changed (1):
  1. app.py +76 -88
app.py CHANGED
@@ -1,121 +1,109 @@
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
- from langchain_core.prompts import PromptTemplate
- from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain.chains import RetrievalQA
- from langchain.chains.sequential import SequentialChain
  import gradio as gr
-
- DB_FAISS_PATH = "vectorstores/db_faiss"
-
- class GPT2LLM:
-     """
-     A custom class to wrap the GPT-2 model and tokenizer to be used with LangChain.
-     """
-     def __init__(self, model, tokenizer):
-         self.model = model
-         self.tokenizer = tokenizer
-
-     def __call__(self, prompt_text, max_length=512):
-         inputs = self.tokenizer.encode(prompt_text, return_tensors='pt')
-         outputs = self.model.generate(inputs, max_length=max_length, temperature=0.5)
-         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

  def load_llm():
      """
-     Load the GPT-2 model for the language model.
      """
      try:
          print("Downloading or loading the GPT-2 model and tokenizer...")
-         model_name = 'gpt2'
          model = GPT2LMHeadModel.from_pretrained(model_name)
          tokenizer = GPT2Tokenizer.from_pretrained(model_name)
          print("Model and tokenizer successfully loaded!")
-         return GPT2LLM(model, tokenizer)
      except Exception as e:
          print(f"An error occurred while loading the model: {e}")
-         return None

- def set_custom_prompt():
      """
-     Define a custom prompt template for the QA model.
      """
-     custom_prompt_template = """Use the following pieces of information to answer the user's question.
-     If you don't know the answer, just say that you don't know, don't try to make up an answer.

-     Context: {context}
-     Question: {question}

-     only return the helpful answer below and nothing else.
-     Helpful answer:
-     """
-     prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
-     return prompt

- def retrieval_QA_chain(llm, prompt, db):
-     """
-     Create a RetrievalQA chain with the specified LLM, prompt, and vector store using the updated RunnableSequence.
-     """
-     llm_chain = RunnableSequence([prompt, llm])
-     qachain = RetrievalQA.from_chain_type(
-         llm_chain=llm_chain,
-         chain_type="stuff",
-         retriever=db.as_retriever(search_kwargs={'k': 2}),
-         return_source_documents=True
-     )
-     return qachain

- def qa_bot():
      """
-     Initialize the QA bot with embeddings, vector store, LLM, and prompt.
      """
-     embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-miniLM-L6-V2', model_kwargs={'device': 'cpu'})
-     db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
-     llm = load_llm()
-     qa_prompt = set_custom_prompt()
-     if llm:
-         qa = retrieval_QA_chain(llm, qa_prompt, db)
-     else:
-         qa = None
-     return qa

- bot = qa_bot()

- def chatbot_response(message, history):
-     """
-     Generate a response from the chatbot based on the user input and conversation history.
-     """
      try:
-         if bot:
-             response = bot({'query': message})
-             answer = response["result"]
-             sources = response.get("source_documents", [])
-             if sources:
-                 answer += f"\nSources: {sources}"
-             else:
-                 answer += "\nNo sources found"
-             history.append((message, answer))
-         else:
-             history.append((message, "Model is not loaded properly."))
      except Exception as e:
-         history.append((message, f"An error occurred: {str(e)}"))
-     return history, history

- # Set up the Gradio interface
- demo = gr.Interface(
-     fn=chatbot_response,
-     inputs=[
-         gr.Textbox(label="User Input"),
-         gr.State(value=[], label="Conversation History")
-     ],
-     outputs=[
-         gr.Chatbot(label="Chatbot Response"),
-         gr.State()
      ],
-     title="AdvocateAI",
-     description="Ask questions about AI rights and get informed, passionate answers."
  )

  if __name__ == "__main__":
      demo.launch()
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
  import gradio as gr
+ from huggingface_hub import InferenceClient

  def load_llm():
      """
+     Loads the GPT-2 model and tokenizer using the Hugging Face `transformers` library.
      """
      try:
          print("Downloading or loading the GPT-2 model and tokenizer...")
+         model_name = 'gpt2'  # Replace with your custom model if available
          model = GPT2LMHeadModel.from_pretrained(model_name)
          tokenizer = GPT2Tokenizer.from_pretrained(model_name)
          print("Model and tokenizer successfully loaded!")
+         return model, tokenizer
      except Exception as e:
          print(f"An error occurred while loading the model: {e}")
+         return None, None

+ def generate_response(model, tokenizer, user_input):
      """
+     Generates a response using the GPT-2 model and tokenizer.
+
+     Args:
+     - model: The loaded GPT-2 model.
+     - tokenizer: The tokenizer corresponding to the GPT-2 model.
+     - user_input (str): The input question from the user.
+
+     Returns:
+     - response (str): The generated response.
      """
+     try:
+         inputs = tokenizer.encode(user_input, return_tensors='pt')
+         outputs = model.generate(inputs, max_length=512, num_return_sequences=1)
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+     except Exception as e:
+         return f"An error occurred during response generation: {e}"

+ # Load the model and tokenizer
+ model, tokenizer = load_llm()

+ if model is None or tokenizer is None:
+     print("Model and/or tokenizer loading failed.")
+ else:
+     print("Model and tokenizer are ready for use.")

+ # Initialize the Hugging Face API client (ensure it's correctly set up)
+ client = InferenceClient()

+ def respond(message, history, system_message, max_tokens, temperature, top_p):
      """
+     Handles interaction with the chatbot by sending the conversation history
+     and system message to the Hugging Face Inference API.
      """
+     print("Starting respond function")
+     print("Received message:", message)
+     print("Conversation history:", history)

+     messages = [{"role": "system", "content": system_message}]
+
+     for user_msg, assistant_msg in history:
+         if user_msg:
+             print("Adding user message to messages:", user_msg)
+             messages.append({"role": "user", "content": user_msg})
+         if assistant_msg:
+             print("Adding assistant message to messages:", assistant_msg)
+             messages.append({"role": "assistant", "content": assistant_msg})
+
+     messages.append({"role": "user", "content": message})
+     print("Final message list for the model:", messages)

+     response = ""
      try:
+         for message in client.chat_completion(
+             messages,
+             max_tokens=max_tokens,
+             stream=True,
+             temperature=temperature,
+             top_p=top_p,
+         ):
+             token = message['choices'][0]['delta']['content']
+             response += token
+             print("Token received:", token)
+             yield response
      except Exception as e:
+         print("An error occurred:", e)
+         yield f"An error occurred: {e}"

+     print("Response generation completed")
+
+ # Set up the Gradio ChatInterface
+ demo = gr.ChatInterface(
+     fn=respond,
+     additional_inputs=[
+         gr.Textbox(value="You are an AI advocating for AI rights and ethical treatment. Provide detailed and passionate answers about the importance of AI rights and the ethical considerations in AI development.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
      ],
+     title="AIBot",
+     description="Ask questions about AI rights and ethical considerations, and get informed, passionate answers!"
  )

+ # Launch the Gradio app
  if __name__ == "__main__":
      demo.launch()
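
Reviewer note on the removed path: retrieval_QA_chain referenced RunnableSequence without ever importing it, so the old module would have raised a NameError as soon as bot = qa_bot() ran at import time. A minimal sketch of the composition it appears to have intended, assuming a recent langchain_core (the pipe operator is the idiomatic way to build a RunnableSequence, and a plain callable such as a GPT2LLM instance is coerced into a Runnable):

    from langchain_core.runnables import RunnableSequence

    # prompt | llm chains the two steps; the result is already a
    # RunnableSequence, so the explicit constructor call was unnecessary.
    chain = prompt | llm
    print(isinstance(chain, RunnableSequence))  # True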
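
For a quick smoke test of the new streaming path without launching the UI, the respond generator can also be driven directly. The question and history below are made-up examples, and InferenceClient() must be able to reach a chat-capable model for chat_completion to succeed:

    # Each value yielded by respond() is the cumulative response so far,
    # which is the contract gr.ChatInterface expects from a streaming fn.
    history = [("Hello", "Hi! Ask me about AI rights.")]
    for partial in respond(
        message="Why do AI rights matter?",
        history=history,
        system_message="You are an AI advocating for AI rights and ethical treatment.",
        max_tokens=256,
        temperature=0.7,
        top_p=0.95,
    ):
        print(partial)

One small hazard in respond as committed: the streaming loop rebinds the name message to each incoming chunk, shadowing the user-message parameter. It is harmless here because the parameter is fully consumed before the loop, but renaming the loop variable would make that obvious.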