File size: 3,270 Bytes
17db205
27472b6
 
 
 
17db205
 
 
 
27472b6
17db205
 
27472b6
 
17db205
27472b6
 
 
 
 
17db205
 
27472b6
17db205
 
 
27472b6
17db205
 
 
 
 
27472b6
 
17db205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27472b6
17db205
 
27472b6
 
 
17db205
27472b6
 
17db205
 
27472b6
 
 
 
17db205
27472b6
 
 
 
 
 
17db205
27472b6
 
 
 
 
 
 
 
 
 
 
 
 
 
17db205
27472b6
 
17db205
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from datasets import load_dataset
import markdown2

# Route all Hugging Face downloads and caches to a writable container path.
os.environ["HF_HOME"] = "/app/.cache"

# Initialize FastAPI application
app = FastAPI()

# Serve static assets under /static and HTML templates for the chat UI.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Hugging Face API token, read from the environment (may be None locally).
hf_token = os.getenv("HF_TOKEN")

# Download/cache the example datasets at import time (requires network;
# NOTE(review): this blocks app startup until both downloads complete).
chat_doctor_dataset = load_dataset("avaliev/chat_doctor")
mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations")

# Inference client for the hosted Llama-3 8B Instruct model.
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=hf_token,
)

def select_relevant_context(user_input: str) -> str:
    """Select a one-shot example context for the prompt from the datasets.

    Scans the lower-cased user input for mental-health or medical keywords
    and returns a short example exchange from the matching dataset; falls
    back to a generic assistant instruction when nothing matches.

    Args:
        user_input: Raw text of the user's chat message.

    Returns:
        A context string to prepend to the LLM prompt.
    """
    mental_health_keywords = [
        "anxious", "depressed", "stress", "mental health", "counseling", 
        "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety"
    ]
    medical_keywords = [
        "symptoms", "diagnosis", "treatment", "doctor", "prescription", "medication",
        "pain", "illness", "disease", "infection", "surgery"
    ]

    # Lower-case once instead of once per keyword membership test.
    lowered = user_input.lower()

    if any(keyword in lowered for keyword in mental_health_keywords):
        example = mental_health_dataset['train'][0]
        # 'Context' is the user's statement and 'Response' the counselor's
        # reply, so present them in conversational order (user speaks first).
        context = f"User: {example['Context']}\nCounselor: {example['Response']}"
    elif any(keyword in lowered for keyword in medical_keywords):
        example = chat_doctor_dataset['train'][0]
        # In avaliev/chat_doctor, 'input' is the patient's question and
        # 'output' is the doctor's answer; the original swapped the labels.
        context = f"Patient: {example['input']}\nDoctor: {example['output']}"
    else:
        context = "You are a general assistant. Respond to the user's query in a helpful manner."
    
    return context

def create_prompt(context: str, user_input: str) -> str:
    """Assemble the final LLM prompt from the selected context and the user's message."""
    sections = [context, "", f"User: {user_input}", "Assistant:"]
    return "\n".join(sections)

def render_markdown(text: str) -> str:
    """Convert a Markdown string into an HTML fragment using markdown2."""
    html = markdown2.markdown(text)
    return html

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat UI page from the templates directory."""
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)

@app.post("/chat")
async def chat(request: Request):
    """Handle a chat POST: build a prompt, stream the model reply, return HTML.

    Expects a JSON body with a "message" key. Returns a JSONResponse whose
    "response" field is the model output rendered as an HTML fragment.

    Raises:
        HTTPException: 400 when the body lacks a non-empty "message" field,
            500 on any other processing failure.
    """
    try:
        data = await request.json()
        # A missing/empty message is a client error (400), not a KeyError
        # surfacing as a generic 500 as before.
        user_input = data.get("message")
        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'message' in request body.")

        context = select_relevant_context(user_input)
        prompt = create_prompt(context, user_input)

        response = ""
        for message in client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=500,
            stream=True,
        ):
            # The final streamed chunk may carry delta.content == None;
            # guard so we never do `str += None`.
            response += message.choices[0].delta.content or ""

        formatted_response = render_markdown(response)

        return JSONResponse({"response": formatted_response})

    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged
        # instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")