|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
import pandas as pd |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build and cache a text-generation pipeline for the Granite instruct model.

    `st.cache_resource` ensures the model weights are downloaded and loaded
    only once per Streamlit server process, not on every script rerun.
    """
    checkpoint = "ibm-granite/granite-3.3-2b-instruct"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return pipeline("text-generation", model=lm, tokenizer=tok)
|
|
|
# Instantiate the (cached) text-generation pipeline once at script start;
# every tab below shares this single generator.
generator = load_model()

st.title("π©Ί HealthAI β Intelligent Healthcare Assistant")

# Four feature areas, each rendered in its own tab below.
tab1, tab2, tab3, tab4 = st.tabs([
    "π§ Patient Chat", "π§Ύ Disease Prediction",
    "π Treatment Plans", "π Health Analytics"
])
|
|
|
|
|
|
|
|
|
with tab1:
    # Patient chat: free-form question answered by the language model.
    st.subheader("Ask any health-related question")
    query = st.text_area("Enter your question here")

    if st.button("Get Advice", key="chat"):
        if query.strip() == "":
            st.warning("Please enter a question.")
        else:
            with st.spinner("Thinking..."):
                response = generator(query, max_new_tokens=200)[0]["generated_text"]
            st.success("AI Response:")
            # Bug fix: the original rendered a stray literal "markdown" word
            # (f"markdown\n{response}\n" — apparently a stripped code fence).
            # Render the model output directly instead.
            st.markdown(response)
|
|
|
|
|
|
|
|
|
with tab2:
    # Disease prediction: turn comma-separated symptoms into a diagnostic prompt.
    st.subheader("Enter your symptoms (comma-separated)")
    symptoms = st.text_input("E.g. persistent fever, fatigue, dry cough")

    if st.button("AI Diagnose", key="predict"):
        if symptoms.strip() == "":
            st.warning("Please enter your symptoms.")
        else:
            prompt = (
                f"I am feeling unwell. My symptoms are: {symptoms}.\n"
                "Can you please suggest what possible conditions I might have based on this?\n"
                "List top 3 possible diseases with a short reason for each, and give a seriousness score out of 10."
            )
            with st.spinner("Analyzing symptoms..."):
                result = generator(prompt, max_new_tokens=300, do_sample=True)[0]['generated_text']
            st.success("AI Prediction:")
            # Bug fix: drop the stray literal "markdown" prefix the original
            # emitted (broken code-fence markup); show the model text as-is.
            st.markdown(result)
|
|
|
|
|
|
|
|
|
with tab3:
    # Treatment plan: one model query per plan section for a named condition.
    st.header("π Treatment Plan Generator")
    condition = st.text_input("Enter the known condition (e.g., Asthma, Diabetes)")

    if st.button("Get Full Treatment Plan"):
        if not condition.strip():
            st.warning("Please enter a condition.")
        else:
            with st.spinner("Generating treatment plan..."):

                def get_response(prompt):
                    # Sampled generation; note the pipeline output includes
                    # the prompt text itself — presumably acceptable for
                    # display. TODO confirm whether the prompt should be
                    # stripped from the rendered answer.
                    return generator(prompt, max_new_tokens=1000, temperature=0.7, do_sample=True)[0]['generated_text'].strip()

                # Section title -> prompt sent to the model for that section.
                prompts = {
                    "1οΈβ£ Medications": f"What medications are usually prescribed for {condition}?",
                    "2οΈβ£ Diet": f"What diet is recommended for someone with {condition}?",
                    "3οΈβ£ Exercise": f"What type of physical activities should a person with {condition} follow?",
                    "4οΈβ£ Follow-Up & Monitoring": f"What follow-up steps and monitoring should be done for {condition}?",
                    "5οΈβ£ Precautions": f"What precautions should be taken by someone with {condition}?",
                    "6οΈβ£ Mental Health & Stress": f"How can someone with {condition} manage stress and mental health?"
                }

                for section, prompt in prompts.items():
                    st.subheader(section)
                    # Bug fix: the original printed a stray literal "markdown"
                    # word before each answer (stripped code fence).
                    st.markdown(get_response(prompt))
|
|
|
|
|
|
|
|
|
with tab4:
    # Health analytics: chart each numeric column of an uploaded CSV and
    # compare the latest reading against the column mean.
    st.subheader("Track your health data over time")
    uploaded = st.file_uploader("Upload your CSV file (with columns like 'blood_pressure', 'heart_rate')", type=["csv"])

    if uploaded:
        # Robustness fix: a malformed/empty file previously crashed the app
        # with an unhandled exception from read_csv.
        try:
            df = pd.read_csv(uploaded)
        except (pd.errors.EmptyDataError, pd.errors.ParserError):
            st.error("Could not read the file — please upload a valid CSV.")
            df = None

        if df is not None:
            st.dataframe(df)

            for col in df.select_dtypes(include=['float', 'int']).columns:
                series = df[col].dropna()
                # Robustness fix: df[col].iloc[-1] raised IndexError on a
                # header-only CSV, and an all-NaN column gave a meaningless
                # comparison — skip such columns instead.
                if series.empty:
                    continue
                st.line_chart(df[col])
                # Latest reading below the historical mean is reported as
                # improving (assumes lower is better, e.g. blood pressure).
                if series.mean() > series.iloc[-1]:
                    st.info(f"π {col} is improving.")
                else:
                    st.warning(f"π {col} is rising β consider medical advice.")