Sanchit2207 committed on
Commit
06bb798
·
verified ·
1 Parent(s): 6003605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -1,32 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
1
def chat(user_input, history=None):
    """Generate one DialoGPT reply to *user_input* and record the exchange.

    Args:
        user_input: the user's message for this turn.
        history: running list of (user, bot) pairs; a new list is created
            when omitted.

    Returns:
        (history, history) — the updated conversation list, duplicated to
        feed two Gradio outputs.
    """
    global chat_history_ids

    # Fix: a mutable default (history=[]) is shared across calls; create a
    # fresh list per call instead.
    history = [] if history is None else history

    # Encode user input (EOS marks the end of the user's turn)
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Append to chat history or start new
    bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids

    # Generate response. Fix: max_length counts the prompt tokens too, so a
    # growing conversation eventually exceeds it and no reply can be
    # generated; max_new_tokens bounds only the generated part.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_new_tokens=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=1,
    )

    # Decode only the new response part (tokens after the prompt)
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Truncate response if it's too long (hard character limit)
    if len(response) > 1000:
        response = response[:1000] + "..."

    history.append((user_input, response))
    return history, history
31
 
 
 
 
 
 
 
32
 
 
 
 
 
1
# --- Module setup: importing this file downloads/loads model weights ---
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Load pre-trained conversational model (DialoGPT-medium)
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Global chat history: token ids of the whole conversation so far.
# NOTE(review): this is process-global, so it is shared across all users and
# sessions of the app and grows without bound — confirm this is intended.
chat_history_ids = None
12
def chat(user_input, history=None):
    """Generate one DialoGPT reply to *user_input* and record the exchange.

    Args:
        user_input: the user's message for this turn.
        history: running list of (user, bot) pairs; a new list is created
            when omitted.

    Returns:
        (history, history) — the updated conversation list, duplicated.
        NOTE(review): gr.ChatInterface calls fn(message, history) and expects
        a plain string reply; this tuple return matches gr.Interface /
        gr.Chatbot wiring instead — confirm against the UI.
    """
    global chat_history_ids

    # Fix: a mutable default (history=[]) is shared across calls; create a
    # fresh list per call instead.
    history = [] if history is None else history

    # Tokenize user input (EOS marks the end of the user's turn)
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Append to chat history
    bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids

    # Generate response with controlled output. Fix: max_length=500 counted
    # the prompt tokens too, so a long conversation eventually left no budget
    # to generate anything; max_new_tokens bounds only the generated part.
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        chat_history_ids = model.generate(
            bot_input_ids,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

    # Decode model output — only the newly generated tokens after the prompt
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Append to chat history
    history.append((user_input, response))
    return history, history
38
 
39
# Create a Gradio ChatInterface.
# NOTE(review): gr.ChatInterface invokes fn(message, history) and expects a
# string reply, while chat() returns (history, history) — verify the replies
# render correctly, or switch to gr.Interface with a gr.Chatbot output.
chatbot_ui = gr.ChatInterface(
    fn=chat,
    title="Teen Mental Health Chatbot 🤖💬",
    description="Talk to a supportive AI. Not a replacement for professional help.",
)

# Launch the app (required!) — guard keeps imports of this module side-effect
# free with respect to the web server.
if __name__ == "__main__":
    chatbot_ui.launch()