import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import spaces
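# Assumed Space dependencies (not pinned here): gradio, torch, transformers,
# accelerate (needed for device_map="auto"), and bitsandbytes (needed for
# 8-bit loading). The `spaces` package is preinstalled on ZeroGPU Spaces.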
# Model name
model_name = "medalpaca/medalpaca-7b"

# Load tokenizer and model globally for efficiency
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device count: {torch.cuda.device_count()}")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
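# Assumption: MedAlpaca is LLaMA-based, and LLaMA tokenizers typically ship
# without a pad token. Falling back to the EOS token is a common safeguard in
# case padding or batching is added later; it is a no-op for single prompts.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token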
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # Use GPU if available
    # 8-bit quantization on GPU, passed via BitsAndBytesConfig since the bare
    # load_in_8bit kwarg is deprecated in recent transformers releases.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None,
)
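# Rough sizing (back-of-envelope): a 7B-parameter model is ~14 GB in fp16, so
# 8-bit weights bring it down to roughly 7 GB, well within what a ZeroGPU
# allocation provides.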
def format_prompt(message, chat_history):
    prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    if chat_history:
        prompt += "Previous conversation:\n"
        # With gr.Chatbot(type="messages"), each history entry is a
        # {"role": ..., "content": ...} dict rather than a (user, bot) tuple.
        for turn in chat_history:
            speaker = "Human" if turn["role"] == "user" else "Assistant"
            prompt += f"{speaker}: {turn['content']}\n"
        prompt += "\n"
    prompt += f"Human: {message}\nAssistant:"
    return prompt
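# Illustrative example of the assembled prompt (hypothetical conversation):
#
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   Previous conversation:
#   Human: What is hypertension?
#   Assistant: Hypertension is persistently elevated blood pressure.
#
#   Human: How is it usually treated?
#   Assistant: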
@spaces.GPU  # <--- This is REQUIRED for ZeroGPU!
def generate_response(message, chat_history):
    prompt = format_prompt(message, chat_history)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    full_output = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    response = full_output.split("Assistant:")[-1].strip()
    # Store the turn in the messages format expected by gr.Chatbot(type="messages").
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": response})
    return "", chat_history
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("# MedAlpaca Medical Chatbot")
    gr.Markdown("A specialized medical chatbot powered by MedAlpaca-7B.")
    gr.Markdown("Ask medical questions and get responses from a model trained on medical data.")
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your medical question here...")
    clear = gr.Button("Clear")
    msg.submit(generate_response, [msg, chatbot], [msg, chatbot])  # Pass GPU-decorated function!
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    print("Starting Gradio app...")
    demo.launch(server_name="0.0.0.0")