File size: 3,303 Bytes
c2cb658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959ee5e
 
 
 
 
c2cb658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959ee5e
c2cb658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e1190d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from flask import Flask, request, jsonify, render_template, send_from_directory
from huggingface_hub import InferenceClient
from datasets import load_dataset
import markdown2
import signal

# Intended cache root for the Hugging Face libraries.
# NOTE: huggingface_hub and datasets resolve their cache paths from HF_HOME
# when they are imported (above), so this assignment alone does not relocate
# caches they have already configured. It still applies to anything that
# reads the variable later, including child processes.
os.environ["HF_HOME"] = "/app/.cache"

application = Flask(__name__, static_folder='static', template_folder='templates')

# Hugging Face API token taken from the environment (None if HF_TOKEN is unset).
hf_token = os.getenv("HF_TOKEN")

# Pass the datasets cache location explicitly so it actually follows the
# HF_HOME set above. This mirrors the library's own default:
# HF_DATASETS_CACHE if set, otherwise $HF_HOME/datasets.
_datasets_cache_dir = os.getenv(
    "HF_DATASETS_CACHE", os.path.join(os.environ["HF_HOME"], "datasets")
)

# Both datasets are loaded once, when this module is imported.
chat_doctor_dataset = load_dataset("avaliev/chat_doctor", cache_dir=_datasets_cache_dir)
mental_health_dataset = load_dataset(
    "Amod/mental_health_counseling_conversations", cache_dir=_datasets_cache_dir
)

# Inference client for the Llama 3 8B Instruct model, authenticated with hf_token.
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=hf_token,
)

def select_relevant_context(user_input):
    """Pick the context text to place before ``user_input`` in the prompt.

    The input is lower-cased and searched, by plain substring, for two keyword
    lists. Mental-health keywords are checked first and win over medical ones.
    On a match, the first training example of the corresponding dataset is
    formatted as a short example exchange. Otherwise a generic assistant
    instruction is returned.

    Args:
        user_input: The raw message sent by the user.

    Returns:
        The context string to prepend to the prompt.
    """
    mental_health_keywords = (
        "anxious", "depressed", "stress", "mental health", "counseling",
        "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety",
    )
    medical_keywords = (
        "symptoms", "diagnosis", "treatment", "doctor", "prescription", "medication",
        "pain", "illness", "disease", "infection", "surgery",
    )

    # Lower-case once. Matching is substring-based, so "pain" also matches
    # e.g. "painful".
    text = user_input.lower()

    if any(keyword in text for keyword in mental_health_keywords):
        example = mental_health_dataset['train'][0]
        # Per the dataset card, 'Context' is the user's message and 'Response'
        # is the counselor's reply, so the user speaks first.
        return f"User: {example['Context']}\nCounselor: {example['Response']}"

    if any(keyword in text for keyword in medical_keywords):
        example = chat_doctor_dataset['train'][0]
        # 'input' holds the patient's description and 'output' the doctor's
        # answer, so each is labelled with its actual speaker.
        return f"Patient: {example['input']}\nDoctor: {example['output']}"

    # No keyword matched: fall back to a general instruction.
    return "You are a general assistant. Respond to the user's query in a helpful manner."

def create_prompt(context, user_input):
    """Assemble the plain-text prompt sent to the model.

    The layout is: the context, a blank line, a ``User:`` line carrying the
    message, and a final open ``Assistant:`` turn for the model to complete.
    """
    header = f"{context}\n"
    turns = f"User: {user_input}\nAssistant:"
    return f"{header}\n{turns}"

# Function to render Markdown into HTML
def render_markdown(text):
    """Convert Markdown ``text`` (the model's reply) into an HTML fragment.

    Raw HTML embedded in ``text`` is escaped rather than passed through. The
    text comes from the language model, whose output the user can steer, and
    the resulting HTML is returned to the client (presumably inserted into the
    page by the front end).
    """
    return markdown2.markdown(text, safe_mode="escape")

@application.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)

@application.route('/static/<path:path>')
def send_static(path):
    """Return the file at ``path`` from the ``static`` directory.

    NOTE(review): ``application`` is created with ``static_folder='static'``,
    which normally makes Flask register its own ``/static/<path:filename>``
    route. That built-in route likely matches first, so this handler may never
    be reached. Confirm before relying on it.
    """
    static_dir = 'static'
    return send_from_directory(static_dir, path)

@application.route('/chat', methods=['POST'])
def chat():
    """Answer a chat message with a model-generated, Markdown-rendered reply.

    Expects a JSON body with a string ``message`` field. The message is wrapped
    in a context from ``select_relevant_context``, turned into a prompt by
    ``create_prompt``, and sent to the model via ``client.chat_completion``.
    The streamed reply is concatenated and rendered to HTML.

    Returns:
        JSON ``{"response": <html>}`` on success, or JSON ``{"error": ...}``
        with HTTP 400 when the body is not JSON or lacks a string ``message``.
    """
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    payload = request.get_json(silent=True)
    user_input = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(user_input, str):
        return jsonify({"error": "Request body must be JSON with a 'message' string."}), 400

    context = select_relevant_context(user_input)
    prompt = create_prompt(context, user_input)

    parts = []
    for chunk in client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        stream=True,
    ):
        # A stream chunk may carry no choices, and a delta's content is
        # optional (may be None), so skip anything without text.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            parts.append(content)

    # Join once instead of repeated string concatenation.
    formatted_response = render_markdown("".join(parts))

    return jsonify({"response": formatted_response})

@application.route('/shutdown', methods=['POST'])
def shutdown():
    """Stop the server process in response to a POST request.

    Delegates to ``shutdown_server``, which already chooses between Werkzeug's
    shutdown hook and sending SIGINT to this process. The same check is
    therefore not repeated here.

    NOTE(review): this endpoint performs no authentication, and the app binds
    to 0.0.0.0 when run directly, so any client that can reach the server can
    stop it. Consider protecting or removing it.
    """
    shutdown_server()
    return jsonify({"message": "Server is shutting down..."})

def shutdown_server():
    """Stop the running server.

    If the current request's WSGI environ provides a
    ``werkzeug.server.shutdown`` callable, it is called. Otherwise SIGINT is
    sent to the current process. Must run inside a request context because it
    reads ``request.environ``.

    NOTE(review): Werkzeug removed the ``werkzeug.server.shutdown`` hook in
    2.1, so on recent versions the SIGINT path is presumably the one taken.
    """
    werkzeug_shutdown = request.environ.get('werkzeug.server.shutdown')
    if werkzeug_shutdown is not None:
        werkzeug_shutdown()
        return
    # No Werkzeug hook available: interrupt this process instead.
    os.kill(os.getpid(), signal.SIGINT)

if __name__ == '__main__':
    # Start Flask's built-in server, listening on all interfaces at port 7860.
    application.run(port=7860, host='0.0.0.0')