HarmonyBot / app.py
Shabi23's picture
Update app.py
27472b6 verified
import os
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from datasets import load_dataset
import markdown2
# Set up Hugging Face cache
os.environ["HF_HOME"] = "/app/.cache"
# Initialize FastAPI application
app = FastAPI()
# Set up templates and static file serving
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Hugging Face API token
hf_token = os.getenv("HF_TOKEN")
# Load datasets
chat_doctor_dataset = load_dataset("avaliev/chat_doctor")
mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations")
# Set up Hugging Face Inference Client
client = InferenceClient(
"meta-llama/Meta-Llama-3-8B-Instruct",
token=hf_token,
)
def select_relevant_context(user_input: str) -> str:
"""Select relevant context from the datasets based on user input keywords."""
mental_health_keywords = [
"anxious", "depressed", "stress", "mental health", "counseling",
"therapy", "feelings", "worthless", "suicidal", "panic", "anxiety"
]
medical_keywords = [
"symptoms", "diagnosis", "treatment", "doctor", "prescription", "medication",
"pain", "illness", "disease", "infection", "surgery"
]
if any(keyword in user_input.lower() for keyword in mental_health_keywords):
example = mental_health_dataset['train'][0]
context = f"Counselor: {example['Response']}\nUser: {example['Context']}"
elif any(keyword in user_input.lower() for keyword in medical_keywords):
example = chat_doctor_dataset['train'][0]
context = f"Doctor: {example['input']}\nPatient: {example['output']}"
else:
context = "You are a general assistant. Respond to the user's query in a helpful manner."
return context
def create_prompt(context: str, user_input: str) -> str:
"""Create the final prompt based on the context and user input."""
return f"{context}\n\nUser: {user_input}\nAssistant:"
def render_markdown(text: str) -> str:
"""Render Markdown into HTML."""
return markdown2.markdown(text)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""Render the homepage."""
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/chat")
async def chat(request: Request):
"""Handle the chat route and process user input."""
try:
data = await request.json()
user_input = data["message"]
context = select_relevant_context(user_input)
prompt = create_prompt(context, user_input)
response = ""
for message in client.chat_completion(
messages=[{"role": "user", "content": prompt}],
max_tokens=500,
stream=True,
):
response += message.choices[0].delta.content
formatted_response = render_markdown(response)
return JSONResponse({"response": formatted_response})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")