safwansajad commited on
Commit
bc974ff
·
verified ·
1 Parent(s): f3aaa59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -14
app.py CHANGED
@@ -1,37 +1,119 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Load chatbot model
 
5
  chatbot_model = "microsoft/DialoGPT-medium"
6
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
7
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
8
 
9
  # Load emotion detection model
10
- emotion_pipeline = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def generate_response(user_input):
13
- # Generate chatbot response
14
- input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
15
- output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
16
- response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Detect emotion
19
- emotion_result = emotion_pipeline(user_input)
20
- emotion = emotion_result[0]["label"]
21
- score = emotion_result[0]["score"]
 
 
 
 
 
22
 
23
  return response, emotion, score
24
 
25
  # Gradio Interface
 
 
 
 
 
26
  iface = gr.Interface(
27
- fn=generate_response,
28
- inputs=gr.Textbox(label="Enter your message"),
29
  outputs=[
30
  gr.Textbox(label="Chatbot Response"),
31
  gr.Textbox(label="Emotion Detected"),
32
  gr.Textbox(label="Emotion Score")
33
  ],
34
- live=False
 
 
35
  )
36
 
37
- iface.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from transformers import pipeline
4
+ import torch
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ import json
8
+
9
+ # Initialize FastAPI
10
+ app = FastAPI()
11
+
12
+ # Add CORS middleware to allow requests from your React Native app
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"], # In production, restrict this to your app's domain
16
+ allow_credentials=True,
17
+ allow_methods=["*"],
18
+ allow_headers=["*"],
19
+ )
20
 
21
  # Load chatbot model
22
+ print("Loading chatbot model...")
23
  chatbot_model = "microsoft/DialoGPT-medium"
24
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
25
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
26
 
27
  # Load emotion detection model
28
+ print("Loading emotion detection model...")
29
+ emotion_pipeline = pipeline(
30
+ "text-classification",
31
+ model="bhadresh-savani/distilbert-base-uncased-emotion",
32
+ return_all_scores=False
33
+ )
34
+
35
+ # Store conversation history
36
+ chat_history = {}
37
+
38
+ # Define API endpoint for chatbot
39
+ @app.post("/api/chat")
40
+ async def chat_endpoint(request: Request):
41
+ data = await request.json()
42
+ user_input = data.get("message", "")
43
+ session_id = data.get("session_id", "default")
44
+
45
+ # Process input and get response
46
+ response, emotion, score = generate_response(user_input, session_id)
47
+
48
+ return {
49
+ "response": response,
50
+ "emotion": emotion,
51
+ "score": float(score),
52
+ "session_id": session_id
53
+ }
54
 
55
+ def generate_response(user_input, session_id="default"):
56
+ # Initialize chat history for new sessions
57
+ if session_id not in chat_history:
58
+ chat_history[session_id] = []
59
+
60
+ # Format the input with chat history
61
+ bot_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
62
+
63
+ # Append user input to chat history
64
+ chat_history[session_id].append(bot_input_ids)
65
+
66
+ # Generate a response
67
+ with torch.no_grad():
68
+ chat_history_ids = model.generate(
69
+ bot_input_ids,
70
+ max_length=200,
71
+ pad_token_id=tokenizer.eos_token_id,
72
+ no_repeat_ngram_size=3,
73
+ do_sample=True,
74
+ top_k=50,
75
+ top_p=0.95,
76
+ temperature=0.7
77
+ )
78
+
79
+ # Decode the response
80
+ response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
81
 
82
  # Detect emotion
83
+ try:
84
+ emotion_result = emotion_pipeline(user_input)[0]
85
+ emotion = emotion_result["label"]
86
+ score = emotion_result["score"]
87
+ except Exception as e:
88
+ print(f"Error detecting emotion: {e}")
89
+ emotion = "unknown"
90
+ score = 0.0
91
 
92
  return response, emotion, score
93
 
94
  # Gradio Interface
95
+ def gradio_generate(user_input):
96
+ response, emotion, score = generate_response(user_input)
97
+ return response, emotion, f"{score:.4f}"
98
+
99
+ # Create Gradio interface
100
  iface = gr.Interface(
101
+ fn=gradio_generate,
102
+ inputs=gr.Textbox(label="Enter your message", placeholder="Type your message here..."),
103
  outputs=[
104
  gr.Textbox(label="Chatbot Response"),
105
  gr.Textbox(label="Emotion Detected"),
106
  gr.Textbox(label="Emotion Score")
107
  ],
108
+ title="Mental Health Chatbot",
109
+ description="A simple mental health chatbot with emotion detection",
110
+ allow_flagging="never"
111
  )
112
 
113
+ # Mount the Gradio app to FastAPI
114
+ app = gr.mount_gradio_app(app, iface, path="/")
115
+
116
+ # Run the app
117
+ if __name__ == "__main__":
118
+ import uvicorn
119
+ uvicorn.run(app, host="0.0.0.0", port=7860)