Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| import xgboost as xgb | |
| import pickle | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
| import nltk | |
| from nltk.tokenize import word_tokenize | |
| from nltk.corpus import stopwords | |
| import re | |
# 🔹 Download NLTK data only when it is not already installed locally.
def _ensure_nltk_resource(package, resource_path):
    """Download NLTK *package* unless *resource_path* is already findable.

    ``nltk.data.find`` raises ``LookupError`` when the resource is missing from
    every directory on ``nltk.data.path``. Only in that case is the (networked)
    ``nltk.download`` call made, so a warm start does no downloading.
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)


for _package, _resource_path in (
    ("stopwords", "corpora/stopwords"),
    ("punkt", "tokenizers/punkt"),
    ("punkt_tab", "tokenizers/punkt_tab"),
):
    _ensure_nltk_resource(_package, _resource_path)

# Load English stopwords
stop_words = set(stopwords.words("english"))
# ============================
# 🔹 1. Load Pretrained Medical Q&A Model
# ============================
# BioBERT (large, cased) question-answering checkpoint. It replaced the earlier
# general-purpose "deepset/roberta-base-squad2" model used by this app.
model_name = "dmis-lab/biobert-large-cased-v1.1-squad"


def _load_qa_components(checkpoint):
    """Return the ``(tokenizer, question-answering model)`` pair for *checkpoint*.

    The tokenizer is loaded first, then the model, both via ``from_pretrained``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


tokenizer, qa_model = _load_qa_components(model_name)
# ============================
# 🔹 2. Load Symptom Checker Model
# ============================
model = xgb.XGBClassifier()
model.load_model("symptom_disease_model.json")  # Load trained model

# Load the label encoder that maps model class ids back to disease names.
# NOTE(security): pickle.load can execute arbitrary code; only ship a
# label_encoder.pkl produced by this project's own training pipeline.
# The `with` block closes the file handle (the original open() leaked it).
with open("label_encoder.pkl", "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)

X_train = pd.read_csv("X_train.csv")  # Load symptoms
# Column names define the feature order used by predict_disease; presumably
# this matches the order the model was trained on — confirm against training.
symptom_list = X_train.columns.tolist()
# ============================
# 🔹 3. Load Precaution Data
# ============================
def _build_precaution_dict(frame, max_precautions=4):
    """Map each normalised disease name to its list of non-missing precautions.

    Keys are the ``Disease`` cells with surrounding whitespace removed and
    lower-cased. Rows whose ``Disease`` cell is empty/NaN are skipped; the
    previous comprehension crashed on them, since NaN is a float with no
    ``.strip()``. Columns ``Precaution_1`` .. ``Precaution_{max_precautions}``
    are read in order and missing (NaN) cells are dropped. If a disease occurs
    more than once, the last row wins.

    Raises:
        KeyError: if the ``Disease`` or a ``Precaution_i`` column is absent.
    """
    table = {}
    for record in frame.to_dict("records"):
        disease = record["Disease"]
        if pd.isna(disease):
            continue
        table[str(disease).strip().lower()] = [
            record[f"Precaution_{i}"]
            for i in range(1, max_precautions + 1)
            if pd.notna(record[f"Precaution_{i}"])
        ]
    return table


precaution_df = pd.read_csv("Disease precaution.csv")
precaution_dict = _build_precaution_dict(precaution_df)
# ============================
# 🔹 4. Load Medical Context
# ============================
def load_medical_context(path="medical_context.txt"):
    """Read and return the reference text the Q&A model extracts answers from.

    Args:
        path: UTF-8 text file to read. Defaults to ``medical_context.txt``,
            the file this module has always used.

    Returns:
        The whole file contents as a single string.

    Raises:
        OSError: (e.g. FileNotFoundError) if the file cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


medical_context = load_medical_context()
# ============================
# 🔹 5. Doctor Database
# ============================
# Static lookup: lower-cased disease name -> list of doctor records.
# book_appointment uses the first record in each list.
doctor_database = {
    "malaria": [
        {
            "name": "Dr. Rajesh Kumar",
            "specialty": "Infectious Diseases",
            "location": "Apollo Hospital",
            "contact": "9876543210",
        },
    ],
    "diabetes": [
        {
            "name": "Dr. Anil Mehta",
            "specialty": "Endocrinologist",
            "location": "AIIMS Delhi",
            "contact": "9876543233",
        },
    ],
    "heart attack": [
        {
            "name": "Dr. Vikram Singh",
            "specialty": "Cardiologist",
            "location": "Medanta Hospital",
            "contact": "9876543255",
        },
    ],
}
# ============================
# 🔹 6. Predict Disease from Symptoms
# ============================
def predict_disease(user_symptoms):
    """Return the disease name the XGBoost classifier predicts for *user_symptoms*.

    A single-row 0/1 feature matrix is built over ``symptom_list``: each given
    symptom that exactly equals a column name sets that column to 1, and
    anything unrecognised is ignored. The predicted class id is then mapped
    back to a disease name through ``label_encoder``.
    """
    features = np.zeros((1, len(symptom_list)))
    for name in user_symptoms:
        try:
            column = symptom_list.index(name)
        except ValueError:
            continue  # not a known symptom column
        features[0, column] = 1
    class_id = model.predict(features)[0]
    return label_encoder.inverse_transform([class_id])[0]
# ============================
# 🔹 7. Get Precautions for a Disease
# ============================
def get_precautions(disease):
    """Return the precautions recorded for *disease*.

    The name is normalised with ``strip().lower()``, the same normalisation
    applied to the ``Disease`` column when ``precaution_dict`` is built. This
    lets inputs with stray whitespace (e.g. " Malaria ") still match; the
    previous version only lower-cased.

    Args:
        disease: Disease name, e.g. the value returned by ``predict_disease``.

    Returns:
        A new list of precaution strings. It is a copy, so callers cannot
        mutate the shared table. ``["No precautions available"]`` is returned
        for an unknown or empty name.
    """
    fallback = ["No precautions available"]
    if not disease:
        return fallback
    key = str(disease).strip().lower()
    return list(precaution_dict.get(key, fallback))
# ============================
# 🔹 8. Answer Medical Questions (Q&A Model)
# ============================
def get_medical_answer(question, context=None, max_answer_tokens=64, stride=128):
    """Answer *question* by extracting a span from the reference medical text.

    The context is tokenized into overlapping windows of at most 512 tokens,
    with ``stride`` tokens shared between neighbours, so text beyond the first
    window is still searched. The previous single ``truncation=True`` pass
    silently dropped it. In each window only spans that lie entirely inside
    the context, end at or after their start, and are at most
    ``max_answer_tokens`` long are scored (start logit + end logit). The
    best-scoring span across all windows is returned.

    Args:
        question: Free-text question.
        context: Passage to search. Defaults to the module-level
            ``medical_context`` read from ``medical_context.txt``.
        max_answer_tokens: Longest answer span to consider, in tokens.
        stride: Overlap between consecutive context windows, in tokens.

    Returns:
        The decoded answer text. If the question or context is empty, or no
        span can be extracted, a "consult a medical professional" message is
        returned instead.

    Note:
        ``return_overflowing_tokens`` and ``sequence_ids`` need a "fast"
        tokenizer. Presumably ``AutoTokenizer`` resolves to one for this BERT
        checkpoint — TODO confirm.
    """
    fallback = "I'm not sure. Please consult a medical professional."
    if context is None:
        context = medical_context
    if not question or not question.strip() or not context:
        return fallback

    encoded = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",  # split the context, never the question
        max_length=512,
        stride=stride,
        return_overflowing_tokens=True,
        padding=True,  # windows must share one length to stack into a tensor
    )
    # Bookkeeping key added by return_overflowing_tokens; forward() rejects it.
    model_inputs = {
        key: value
        for key, value in encoded.items()
        if key != "overflow_to_sample_mapping"
    }
    with torch.no_grad():
        outputs = qa_model(**model_inputs)

    best_score = float("-inf")
    best_ids = None
    for window in range(encoded["input_ids"].shape[0]):
        start_logits = outputs.start_logits[window]
        end_logits = outputs.end_logits[window]
        # sequence_ids: 0 = question, 1 = context, None = special/padding token.
        in_context = torch.tensor(
            [seq_id == 1 for seq_id in encoded.sequence_ids(window)],
            dtype=torch.bool,
            device=start_logits.device,
        )
        if not bool(in_context.any()):
            continue
        start_logits = start_logits.masked_fill(~in_context, float("-inf"))
        end_logits = end_logits.masked_fill(~in_context, float("-inf"))

        # scores[i, j] = start_logits[i] + end_logits[j]; keep only spans with
        # 0 <= j - i < max_answer_tokens (no inverted or overly long spans).
        scores = start_logits[:, None] + end_logits[None, :]
        positions = torch.arange(scores.shape[0], device=scores.device)
        span_length = positions[None, :] - positions[:, None]
        valid = (span_length >= 0) & (span_length < max_answer_tokens)
        scores = scores.masked_fill(~valid, float("-inf"))

        flat_index = int(torch.argmax(scores))
        start_idx, end_idx = divmod(flat_index, scores.shape[1])
        score = float(scores[start_idx, end_idx])
        if score > best_score:
            best_score = score
            best_ids = encoded["input_ids"][window][start_idx:end_idx + 1]

    if best_ids is None:
        return fallback
    answer = tokenizer.decode(best_ids, skip_special_tokens=True).strip()
    return answer or fallback
# ============================
# 🔹 9. Book a Doctor's Appointment
# ============================
def book_appointment(disease):
    """Return a confirmation message naming the first doctor listed for *disease*.

    The name is lower-cased and stripped before being looked up in
    ``doctor_database``. No external booking system is contacted here; the
    result is only a formatted message. If no doctor is listed, an apology
    naming the normalised disease is returned.
    """
    key = disease.lower().strip()
    available = doctor_database.get(key, [])
    if available:
        first = available[0]
        return (
            f"Appointment booked with **{first['name']}** ({first['specialty']}) "
            f"at **{first['location']}**.\nContact: {first['contact']}"
        )
    return f"Sorry, no available doctors found for {key}."
# ============================
# 🔹 10. Handle User Queries
# ============================
# Whole-word intent keywords. Plain substring checks misfired, e.g. "signs"
# inside "designs".
_SYMPTOM_PATTERN = re.compile(r"\b(?:symptoms?|signs?)\b")
_TREATMENT_PATTERN = re.compile(r"\btreat\w*")
_DOCTOR_PATTERN = re.compile(r"\bwho should i see\b")
_BOOKING_PATTERN = re.compile(r"\bbook\s+(?:an?\s+)?appointment\b")
# Words that typically surround a disease name in a question but are not part of it.
_QUERY_FILLER = frozenset({
    "a", "about", "an", "are", "can", "common", "do", "does", "for", "how",
    "i", "is", "of", "the", "to", "what", "which", "with",
})


def _extract_subject(query, match):
    """Return the disease/topic named next to the keyword *match* in *query*.

    Text after the keyword is preferred ("symptoms of malaria?" -> "malaria").
    If nothing meaningful follows it, the text before the keyword is used
    ("heart attack symptoms" -> "heart attack"). Punctuation is dropped and
    filler words are trimmed from both ends. Returns "" when nothing remains.
    """
    for segment in (query[match.end():], query[:match.start()]):
        words = re.findall(r"[\w'-]+", segment)
        while words and words[0] in _QUERY_FILLER:
            words.pop(0)
        while words and words[-1] in _QUERY_FILLER:
            words.pop()
        if words:
            return " ".join(words)
    return ""


def handle_user_query(user_query):
    """Route a user query to the symptom, treatment, doctor, or general Q&A handler.

    Intents are checked in priority order:
    1. symptoms/signs
    2. treatment
    3. "who should i see"
    4. "book (an) appointment"

    The first intent whose keyword appears as a whole word wins. The disease
    name is taken from the words around that keyword. Any other query is sent
    to the Q&A model verbatim (lower-cased and stripped). ``None`` is treated
    as an empty query.
    """
    user_query = (user_query or "").lower().strip()

    routes = (
        (_SYMPTOM_PATTERN,
         lambda disease: get_medical_answer(f"What are the symptoms of {disease}?")),
        (_TREATMENT_PATTERN,
         lambda disease: get_medical_answer(f"What is the treatment for {disease}?")),
        (_DOCTOR_PATTERN, book_appointment),
        (_BOOKING_PATTERN, book_appointment),
    )
    for pattern, handler in routes:
        match = pattern.search(user_query)
        if match:
            return handler(_extract_subject(user_query, match))

    # Default case: general medical question
    return get_medical_answer(user_query)