# TestPalmyraMed / app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Define the model and tokenizer
model_id = "Writer/Palmyra-Med-70B-32k"
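# Palmyra-Med-70B-32k is a large checkpoint; device_map="auto" below shards the
# float16 weights across whatever GPUs are visible to the process.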
# Cache the tokenizer and model across Streamlit reruns; st.cache_resource replaces
# the deprecated st.cache(allow_output_mutation=True) for heavyweight objects like models.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    return tokenizer, model
tokenizer, model = load_model()
# Define Streamlit app
st.title("Medical Query Model")
st.write(
    "You are interacting with a highly knowledgeable medical model. Enter your medical question below:"
)
user_input = st.text_area("Your Question")
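# Generate a response only when the button is clicked and a question was entered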
if st.button("Get Response"):
    if user_input:
        # Build the chat-formatted prompt: a system instruction plus the user's question
        messages = [
            {
                "role": "system",
                "content": "You are a highly knowledgeable and experienced expert in the healthcare and biomedical field, possessing extensive medical knowledge and practical expertise.",
            },
            {
                "role": "user",
                "content": user_input,
            },
        ]
        # Tokenize with the model's chat template and move the ids to the model's device
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        gen_conf = {
            "max_new_tokens": 256,
            # The end-of-turn token string was lost in the page export; "<|im_end|>" is
            # assumed here, matching the ChatML-style template used by Palmyra chat models.
            "eos_token_id": [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")],
            # With do_sample left at its default (False), decoding is greedy and
            # temperature/top_p are effectively ignored.
            "temperature": 0.0,
            "top_p": 0.9,
        }
        # Generate the response (no gradients are needed for inference)
        with torch.no_grad():
            output_id = model.generate(input_ids, **gen_conf)
        # Decode only the newly generated tokens, skipping the prompt and special tokens
        output_text = tokenizer.decode(
            output_id[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        st.write("Response:")
        st.write(output_text)
    else:
        st.warning("Please enter a question.")
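# Note: launch locally with `streamlit run app.py`. The 70B checkpoint needs GPUs with
# roughly 140 GB of combined memory just for the float16 weights.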