safwansajad committed on
Commit
6b8ac59
·
verified ·
1 Parent(s): cf7a2e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -2
app.py CHANGED
@@ -1,5 +1,20 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Load chatbot model
5
  chatbot_model = "microsoft/DialoGPT-medium"
@@ -14,13 +29,32 @@ def generate_response(user_input):
14
  input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
15
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
16
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
17
-
18
  # Detect emotion
19
  emotion_result = emotion_pipeline(user_input)
20
  emotion = emotion_result[0]["label"]
21
 
22
  return response, emotion
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  iface = gr.Interface(
25
  fn=generate_response,
26
  inputs=gr.Textbox(label="Enter your message"),
@@ -28,4 +62,9 @@ iface = gr.Interface(
28
  live=True
29
  )
30
 
31
- iface.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ from fastapi import FastAPI, Request
4
+ import uvicorn
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+
7
# FastAPI application that fronts the chatbot for the mobile client.
app = FastAPI()

# CORS: wide-open so the React Native app can reach the API from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True should
# be narrowed to the real client origin in production.
_cors_options = {
    "allow_origins": ["*"],  # In production, specify your actual domain
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
18
 
19
  # Load chatbot model
20
  chatbot_model = "microsoft/DialoGPT-medium"
 
29
  input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
30
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
31
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
32
+
33
  # Detect emotion
34
  emotion_result = emotion_pipeline(user_input)
35
  emotion = emotion_result[0]["label"]
36
 
37
  return response, emotion
38
 
39
# REST endpoint consumed by the React Native client.
@app.post("/analyze")
async def analyze_text(request: Request):
    """Generate a chatbot reply for a message and classify its emotion.

    Expects a JSON object body like ``{"text": "..."}``.  Returns
    ``{"response", "emotion", "score"}`` on success, or ``{"error": ...}``
    when the body is missing, malformed, or has no text.
    """
    # A malformed body makes request.json() raise (-> HTTP 500); return the
    # endpoint's own error shape instead, consistent with the empty-text case.
    try:
        data = await request.json()
    except Exception:
        return {"error": "Invalid JSON body"}

    # Guard against valid-but-non-object JSON (e.g. a bare list), which would
    # crash on .get below.
    if not isinstance(data, dict):
        return {"error": "Invalid JSON body"}

    user_input = data.get("text", "")
    if not user_input:
        return {"error": "No text provided"}

    response, emotion = generate_response(user_input)

    # NOTE(review): generate_response presumably already runs the emotion
    # classifier; this second call only recovers the confidence score.
    score = emotion_pipeline(user_input)[0]["score"]

    # Return structured response
    return {
        "response": response,
        "emotion": emotion,
        "score": score,
    }
56
+
57
+ # Create Gradio interface (optional, can keep for web testing)
58
  iface = gr.Interface(
59
  fn=generate_response,
60
  inputs=gr.Textbox(label="Enter your message"),
 
62
  live=True
63
  )
64
 
65
# Mount the Gradio UI onto the FastAPI app at the root path, so the web demo
# ("/") and the REST endpoint ("/analyze") are served by one application.
app = gr.mount_gradio_app(app, iface, path="/")

# Only needed if running directly (`python app.py`); a hosting platform that
# imports `app` itself does not hit this branch.
# NOTE(review): 7860 is the Gradio default port — confirm the deploy target.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)