safwansajad commited on
Commit
43bfd5c
·
verified ·
1 Parent(s): bc974ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -84
app.py CHANGED
@@ -1,15 +1,15 @@
1
- import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from transformers import pipeline
4
- import torch
5
  from fastapi import FastAPI, Request
6
  from fastapi.middleware.cors import CORSMiddleware
7
- import json
 
 
8
 
9
- # Initialize FastAPI
10
- app = FastAPI()
 
 
11
 
12
- # Add CORS middleware to allow requests from your React Native app
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"], # In production, restrict this to your app's domain
@@ -19,101 +19,82 @@ app.add_middleware(
19
  )
20
 
21
  # Load chatbot model
22
- print("Loading chatbot model...")
23
  chatbot_model = "microsoft/DialoGPT-medium"
24
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
25
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
26
 
27
  # Load emotion detection model
28
  print("Loading emotion detection model...")
29
- emotion_pipeline = pipeline(
30
- "text-classification",
31
- model="bhadresh-savani/distilbert-base-uncased-emotion",
32
- return_all_scores=False
33
- )
 
 
 
 
34
 
35
- # Store conversation history
36
- chat_history = {}
 
 
 
37
 
38
- # Define API endpoint for chatbot
39
- @app.post("/api/chat")
40
- async def chat_endpoint(request: Request):
41
- data = await request.json()
42
- user_input = data.get("message", "")
43
- session_id = data.get("session_id", "default")
 
 
 
 
 
 
44
 
45
- # Process input and get response
46
- response, emotion, score = generate_response(user_input, session_id)
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return {
49
  "response": response,
50
  "emotion": emotion,
51
- "score": float(score),
52
  "session_id": session_id
53
  }
54
 
55
- def generate_response(user_input, session_id="default"):
56
- # Initialize chat history for new sessions
57
- if session_id not in chat_history:
58
- chat_history[session_id] = []
59
 
60
- # Format the input with chat history
61
- bot_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
62
-
63
- # Append user input to chat history
64
- chat_history[session_id].append(bot_input_ids)
65
-
66
- # Generate a response
67
- with torch.no_grad():
68
- chat_history_ids = model.generate(
69
- bot_input_ids,
70
- max_length=200,
71
- pad_token_id=tokenizer.eos_token_id,
72
- no_repeat_ngram_size=3,
73
- do_sample=True,
74
- top_k=50,
75
- top_p=0.95,
76
- temperature=0.7
77
- )
78
-
79
- # Decode the response
80
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
81
-
82
- # Detect emotion
83
- try:
84
- emotion_result = emotion_pipeline(user_input)[0]
85
- emotion = emotion_result["label"]
86
- score = emotion_result["score"]
87
- except Exception as e:
88
- print(f"Error detecting emotion: {e}")
89
- emotion = "unknown"
90
- score = 0.0
91
-
92
- return response, emotion, score
93
-
94
- # Gradio Interface
95
- def gradio_generate(user_input):
96
- response, emotion, score = generate_response(user_input)
97
- return response, emotion, f"{score:.4f}"
98
 
99
- # Create Gradio interface
100
- iface = gr.Interface(
101
- fn=gradio_generate,
102
- inputs=gr.Textbox(label="Enter your message", placeholder="Type your message here..."),
103
- outputs=[
104
- gr.Textbox(label="Chatbot Response"),
105
- gr.Textbox(label="Emotion Detected"),
106
- gr.Textbox(label="Emotion Score")
107
- ],
108
- title="Mental Health Chatbot",
109
- description="A simple mental health chatbot with emotion detection",
110
- allow_flagging="never"
111
- )
112
-
113
- # Mount the Gradio app to FastAPI
114
- app = gr.mount_gradio_app(app, iface, path="/")
115
 
116
- # Run the app
117
  if __name__ == "__main__":
118
  import uvicorn
119
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
1
  from fastapi import FastAPI, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
+ from pydantic import BaseModel
5
+ import uuid
6
 
7
+ # Initialize FastAPI app
8
+ app = FastAPI(title="Mental Health Chatbot API",
9
+ description="API for mental health chatbot with emotion detection",
10
+ version="1.0.0")
11
 
12
+ # Add CORS middleware
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"], # In production, restrict this to your app's domain
 
19
  )
20
 
21
  # Load chatbot model
22
+ print("Loading DialoGPT model...")
23
  chatbot_model = "microsoft/DialoGPT-medium"
24
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
25
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
26
 
27
  # Load emotion detection model
28
  print("Loading emotion detection model...")
29
+ emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
30
+
31
+ # Store chat histories by session ID
32
+ chat_histories = {}
33
+
34
+ # Request models
35
+ class ChatRequest(BaseModel):
36
+ message: str
37
+ session_id: str = None
38
 
39
+ class ChatResponse(BaseModel):
40
+ response: str
41
+ emotion: str
42
+ emotion_score: float
43
+ session_id: str
44
 
45
+ @app.get("/")
46
+ async def root():
47
+ return {"message": "Mental Health Chatbot API is running. Use /api/chat to interact with the chatbot."}
48
+
49
+ @app.post("/api/chat", response_model=ChatResponse)
50
+ async def chat(request: ChatRequest):
51
+ # Create a new session ID if not provided
52
+ session_id = request.session_id if request.session_id else str(uuid.uuid4())
53
+
54
+ # Initialize chat history for new sessions
55
+ if session_id not in chat_histories:
56
+ chat_histories[session_id] = []
57
 
58
+ user_input = request.message
 
59
 
60
+ # Generate chatbot response
61
+ input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
62
+ output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
63
+ response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
64
+
65
+ # Detect emotion
66
+ emotion_result = emotion_pipeline(user_input)[0]
67
+ emotion = emotion_result["label"]
68
+ score = emotion_result["score"]
69
+
70
+ # Store chat history
71
+ chat_histories[session_id].append({"role": "user", "content": user_input})
72
+ chat_histories[session_id].append({"role": "bot", "content": response})
73
+
74
+ # Return the response
75
  return {
76
  "response": response,
77
  "emotion": emotion,
78
+ "emotion_score": float(score),
79
  "session_id": session_id
80
  }
81
 
82
+ @app.get("/api/history/{session_id}")
83
+ async def get_chat_history(session_id: str):
84
+ if session_id not in chat_histories:
85
+ return {"error": "Session not found"}
86
 
87
+ return {"session_id": session_id, "history": chat_histories[session_id]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ @app.delete("/api/history/{session_id}")
90
+ async def clear_chat_history(session_id: str):
91
+ if session_id in chat_histories:
92
+ chat_histories.pop(session_id)
93
+ return {"message": f"Chat history for session {session_id} cleared"}
94
+
95
+ return {"error": "Session not found"}
 
 
 
 
 
 
 
 
 
96
 
97
+ # Launch the API server
98
  if __name__ == "__main__":
99
  import uvicorn
100
  uvicorn.run(app, host="0.0.0.0", port=7860)