Shabi23 commited on
Commit
27472b6
·
verified ·
1 Parent(s): da1ddb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -35
app.py CHANGED
@@ -1,24 +1,37 @@
1
  import os
2
- from flask import Flask, request, jsonify, render_template, send_from_directory
 
 
 
3
  from huggingface_hub import InferenceClient
4
  from datasets import load_dataset
5
  import markdown2
6
 
 
7
  os.environ["HF_HOME"] = "/app/.cache"
8
 
9
- app = Flask(__name__, static_folder='static', template_folder='templates')
 
10
 
 
 
 
 
 
11
  hf_token = os.getenv("HF_TOKEN")
12
 
 
13
  chat_doctor_dataset = load_dataset("avaliev/chat_doctor")
14
  mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations")
15
 
 
16
  client = InferenceClient(
17
  "meta-llama/Meta-Llama-3-8B-Instruct",
18
  token=hf_token,
19
  )
20
 
21
- def select_relevant_context(user_input):
 
22
  mental_health_keywords = [
23
  "anxious", "depressed", "stress", "mental health", "counseling",
24
  "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety"
@@ -36,44 +49,44 @@ def select_relevant_context(user_input):
36
  context = f"Doctor: {example['input']}\nPatient: {example['output']}"
37
  else:
38
  context = "You are a general assistant. Respond to the user's query in a helpful manner."
39
-
40
  return context
41
 
42
- def create_prompt(context, user_input):
43
- prompt = (
44
- f"{context}\n\n"
45
- f"User: {user_input}\nAssistant:"
46
- )
47
- return prompt
48
 
49
- def render_markdown(text):
 
50
  return markdown2.markdown(text)
51
 
52
- @app.route('/')
53
- def index():
54
- return render_template('index.html')
 
55
 
56
- @app.route('/static/<path:path>')
57
- def send_static(path):
58
- return send_from_directory('static', path)
 
 
 
59
 
60
- @app.route('/chat', methods=['POST'])
61
- def chat():
62
- user_input = request.json['message']
63
- context = select_relevant_context(user_input)
64
- prompt = create_prompt(context, user_input)
65
-
66
- response = ""
67
- for message in client.chat_completion(
68
- messages=[{"role": "user", "content": prompt}],
69
- max_tokens=500,
70
- stream=True,
71
- ):
72
- response += message.choices[0].delta.content
73
-
74
- formatted_response = render_markdown(response)
75
 
76
- return jsonify({"response": formatted_response})
 
77
 
78
- if __name__ == '__main__':
79
- app.run(debug=False)
 
1
  import os
2
+ from fastapi import FastAPI, Request, HTTPException
3
+ from fastapi.responses import JSONResponse, HTMLResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.templating import Jinja2Templates
6
  from huggingface_hub import InferenceClient
7
  from datasets import load_dataset
8
  import markdown2
9
 
10
+ # Set up Hugging Face cache
11
  os.environ["HF_HOME"] = "/app/.cache"
12
 
13
+ # Initialize FastAPI application
14
+ app = FastAPI()
15
 
16
+ # Set up templates and static file serving
17
+ app.mount("/static", StaticFiles(directory="static"), name="static")
18
+ templates = Jinja2Templates(directory="templates")
19
+
20
+ # Hugging Face API token
21
  hf_token = os.getenv("HF_TOKEN")
22
 
23
+ # Load datasets
24
  chat_doctor_dataset = load_dataset("avaliev/chat_doctor")
25
  mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations")
26
 
27
+ # Set up Hugging Face Inference Client
28
  client = InferenceClient(
29
  "meta-llama/Meta-Llama-3-8B-Instruct",
30
  token=hf_token,
31
  )
32
 
33
+ def select_relevant_context(user_input: str) -> str:
34
+ """Select relevant context from the datasets based on user input keywords."""
35
  mental_health_keywords = [
36
  "anxious", "depressed", "stress", "mental health", "counseling",
37
  "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety"
 
49
  context = f"Doctor: {example['input']}\nPatient: {example['output']}"
50
  else:
51
  context = "You are a general assistant. Respond to the user's query in a helpful manner."
52
+
53
  return context
54
 
55
+ def create_prompt(context: str, user_input: str) -> str:
56
+ """Create the final prompt based on the context and user input."""
57
+ return f"{context}\n\nUser: {user_input}\nAssistant:"
 
 
 
58
 
59
+ def render_markdown(text: str) -> str:
60
+ """Render Markdown into HTML."""
61
  return markdown2.markdown(text)
62
 
63
+ @app.get("/", response_class=HTMLResponse)
64
+ async def index(request: Request):
65
+ """Render the homepage."""
66
+ return templates.TemplateResponse("index.html", {"request": request})
67
 
68
+ @app.post("/chat")
69
+ async def chat(request: Request):
70
+ """Handle the chat route and process user input."""
71
+ try:
72
+ data = await request.json()
73
+ user_input = data["message"]
74
 
75
+ context = select_relevant_context(user_input)
76
+ prompt = create_prompt(context, user_input)
77
+
78
+ response = ""
79
+ for message in client.chat_completion(
80
+ messages=[{"role": "user", "content": prompt}],
81
+ max_tokens=500,
82
+ stream=True,
83
+ ):
84
+ response += message.choices[0].delta.content
85
+
86
+ formatted_response = render_markdown(response)
87
+
88
+ return JSONResponse({"response": formatted_response})
 
89
 
90
+ except Exception as e:
91
+ raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")
92