Sanchit2207 committed on
Commit
6003605
·
verified ·
1 Parent(s): 2eaca46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,33 +1,32 @@
1
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Pretrained conversational model; DialoGPT-medium is a causal LM fine-tuned
# for multi-turn dialogue.
MODEL_NAME = "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Token-id tensor holding the whole conversation so far; None until the
# first turn has been generated.
chat_history_ids = None
11
def chat(user_input, history=None):
    """Generate one bot reply for *user_input* and append the turn to *history*.

    Parameters
    ----------
    user_input : str
        The user's latest message.
    history : list[tuple[str, str]] | None
        Accumulated (user, bot) message pairs. A fresh list is created when
        None — avoids the shared-mutable-default bug of ``history=[]``.

    Returns
    -------
    tuple[list, list]
        The updated history, duplicated (matches a Gradio chatbot + state
        output pair).
    """
    global chat_history_ids

    # Fix: never use a mutable default argument — it is shared across calls.
    if history is None:
        history = []

    # Encode the user turn; EOS marks the turn boundary for DialoGPT.
    new_input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors='pt'
    )

    # Extend the running conversation tensor, or start a new one on turn 1.
    bot_input_ids = (
        torch.cat([chat_history_ids, new_input_ids], dim=-1)
        if chat_history_ids is not None
        else new_input_ids
    )

    # Generate the reply; max_length bounds the TOTAL sequence (prompt +
    # reply), so long conversations will eventually stop growing.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

    history.append((user_input, response))
    return history, history
24
 
25
# Chat UI.
def _chat_reply(message, history):
    """Adapter for gr.ChatInterface.

    gr.ChatInterface calls fn(message, history) and expects the reply text
    back, while chat() returns (history, history); returning that tuple
    directly would not render as a reply. Extract the latest bot message.
    """
    updated_history, _ = chat(message)
    # Each history entry is a (user, bot) pair; the bot text is element 1.
    return updated_history[-1][1]


chatbot_ui = gr.ChatInterface(
    fn=_chat_reply,
    title="Teen Mental Health Chatbot 🤖💬",
    description="Talk to a supportive AI. Not a replacement for professional help.",
)

chatbot_ui.launch()
33
 
 
 
 
 
 
 
 
 
 
 
 
1
def chat(user_input, history=None):
    """Generate a sampled bot reply for *user_input* and record the turn.

    Parameters
    ----------
    user_input : str
        The user's latest message.
    history : list[tuple[str, str]] | None
        Accumulated (user, bot) message pairs. A fresh list is created when
        None — avoids the shared-mutable-default bug of ``history=[]``.

    Returns
    -------
    tuple[list, list]
        The updated history, duplicated (matches a Gradio chatbot + state
        output pair).
    """
    global chat_history_ids

    # Fix: never use a mutable default argument — it is shared across calls.
    if history is None:
        history = []

    # Encode the user turn; EOS marks the turn boundary for DialoGPT.
    new_input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors='pt'
    )

    # Extend the running conversation tensor, or start a new one on turn 1.
    bot_input_ids = (
        torch.cat([chat_history_ids, new_input_ids], dim=-1)
        if chat_history_ids is not None
        else new_input_ids
    )

    # Sampled generation; max_length bounds the TOTAL sequence (prompt +
    # reply). Parameters kept as tuned in the original.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,  # can reduce to 500 if needed
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=1,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

    # Hard character cap on the reply as a safety net against runaway output.
    if len(response) > 1000:
        response = response[:1000] + "..."

    history.append((user_input, response))
    return history, history
31
 
 
 
 
 
 
 
 
 
32