import os

# HF_HOME must be set before the Hugging Face libraries are imported so that
# they resolve /app/.cache as their cache directory.
os.environ["HF_HOME"] = "/app/.cache"

from flask import Flask, request, jsonify, render_template, send_from_directory
from huggingface_hub import InferenceClient
from datasets import load_dataset
import markdown2

app = Flask(__name__, static_folder='static', template_folder='templates')

hf_token = os.getenv("HF_TOKEN")

# Datasets used to supply a single few-shot example per domain.
chat_doctor_dataset = load_dataset("avaliev/chat_doctor")
mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations")

# Hosted inference client for the instruction-tuned Llama 3 8B model.
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=hf_token,
)

def select_relevant_context(user_input):
    """Pick a few-shot example for the prompt based on simple keyword matching."""
    mental_health_keywords = [
        "anxious", "depressed", "stress", "mental health", "counseling",
        "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety"
    ]
    medical_keywords = [
        "symptoms", "diagnosis", "treatment", "doctor", "prescription", "medication",
        "pain", "illness", "disease", "infection", "surgery"
    ]

    lowered = user_input.lower()
    if any(keyword in lowered for keyword in mental_health_keywords):
        # 'Context' holds the user's message and 'Response' the counselor's reply.
        example = mental_health_dataset['train'][0]
        context = f"User: {example['Context']}\nCounselor: {example['Response']}"
    elif any(keyword in lowered for keyword in medical_keywords):
        # 'input' holds the patient's question and 'output' the doctor's answer.
        example = chat_doctor_dataset['train'][0]
        context = f"Patient: {example['input']}\nDoctor: {example['output']}"
    else:
        context = "You are a general assistant. Respond to the user's query in a helpful manner."

    return context
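
# Illustrative behavior of the keyword routing above (a rough sketch; the
# exact few-shot text depends on the first training example in each dataset):
#
#   select_relevant_context("I feel anxious all the time")
#       -> context built from the mental-health counseling dataset
#   select_relevant_context("Which medication treats this infection?")
#       -> context built from the ChatDoctor dataset
#   select_relevant_context("What's the weather like today?")
#       -> generic assistant instruction, no dataset example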

def create_prompt(context, user_input):
    prompt = (
        f"{context}\n\n"
        f"User: {user_input}\nAssistant:"
    )
    return prompt

def render_markdown(text):
    return markdown2.markdown(text)

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/static/<path:path>')
def send_static(path):
    return send_from_directory('static', path)

@app.route('/chat', methods=['POST'])
def chat():
    user_input = request.json['message']
    context = select_relevant_context(user_input)
    prompt = create_prompt(context, user_input)

    # Stream the completion and accumulate the chunks into a single string.
    response = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        stream=True,
    ):
        # A streamed chunk's delta content can be None (e.g. the final chunk),
        # so guard against it before concatenating.
        response += message.choices[0].delta.content or ""

    formatted_response = render_markdown(response)

    return jsonify({"response": formatted_response})

if __name__ == '__main__':
    app.run(debug=False)
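
# Example client call once the server is running (a minimal sketch; the host
# and port assume Flask's default development server at 127.0.0.1:5000):
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:5000/chat",
#       json={"message": "I have been feeling anxious and can't sleep"},
#   )
#   print(r.json()["response"])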