import gradio as gr
import os
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from threading import Thread
import torch
import time
# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)
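# NOTE: HF_TOKEN is read here but not forwarded below; if the checkpoints were gated,
# it could be passed as token=HF_TOKEN to the from_pretrained calls (assumption).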
# Apollo system prompt
SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."
LICENSE = """
@misc{wang2024apollo,
title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People},
author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang},
year={2024},
eprint={2403.03640},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{zheng2024efficientlydemocratizingmedicalllms,
title={Efficiently Democratizing Medical LLMs for 50 Languages via a Mixture of Language Family Experts},
author={Guorui Zheng and Xidong Wang and Juhao Liang and Nuo Chen and Yuping Zheng and Benyou Wang},
year={2024},
eprint={2410.10626},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2410.10626},
}
"""
# Apollo model options
APOLLO_MODELS = {
"Apollo": [
"FreedomIntelligence/Apollo-7B",
"FreedomIntelligence/Apollo-6B",
"FreedomIntelligence/Apollo-2B",
"FreedomIntelligence/Apollo-0.5B",
],
"Apollo2": [
"FreedomIntelligence/Apollo2-7B",
"FreedomIntelligence/Apollo2-3.8B",
"FreedomIntelligence/Apollo2-2B",
],
"Apollo-MoE": [
"FreedomIntelligence/Apollo-MoE-7B",
"FreedomIntelligence/Apollo-MoE-1.5B",
"FreedomIntelligence/Apollo-MoE-0.5B",
]
}
# CSS styles
css = """
h1 {
text-align: center;
display: block;
}
.gradio-container {
max-width: 1200px;
margin: auto;
}
"""
# Global variables to store currently loaded model and tokenizer
current_model = None
current_tokenizer = None
current_model_path = None
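# @spaces.GPU requests a (Zero)GPU worker on Hugging Face Spaces for up to `duration`
# seconds per call; outside Spaces the decorator is effectively a no-op.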
@spaces.GPU(duration=120)
def load_model(model_path, progress=gr.Progress()):
"""Load the selected model and tokenizer"""
global current_model, current_tokenizer, current_model_path
# If the same model is already loaded, don't reload it
if current_model_path == model_path and current_model is not None:
return "Model already loaded, no need to reload."
    # Release the previously loaded model (if any); assigning None instead of del keeps
    # the globals defined even if the subsequent load fails
    if current_model is not None:
        current_model = None
        current_tokenizer = None
        torch.cuda.empty_cache()
progress(0.1, desc=f"Starting to load model {model_path}...")
try:
progress(0.3, desc="Loading tokenizer...")
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if 'MoE' in model_path:
from configuration_upcycling_qwen2_moe import UpcyclingQwen2MoeConfig
config = UpcyclingQwen2MoeConfig.from_pretrained(model_path, trust_remote_code=True)
        current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
progress(0.5, desc="Loading model...")
if 'MoE' in model_path:
from modeling_upcycling_qwen2_moe import UpcyclingQwen2MoeForCausalLM
current_model = UpcyclingQwen2MoeForCausalLM.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.float16,
config=config,
trust_remote_code=True
)
else:
current_model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.float16,
config=config,
trust_remote_code=True
)
current_model_path = model_path
progress(1.0, desc="Model loading complete!")
return f"Model {model_path} successfully loaded."
except Exception as e:
progress(1.0, desc="Model loading failed!")
return f"Model loading failed: {str(e)}"
@spaces.GPU(duration=120)
def generate_response_non_streaming(instruction, model_name, temperature=0.7, max_tokens=1024):
"""Generate a response from the Apollo model (non-streaming)"""
global current_model, current_tokenizer, current_model_path
print("instruction:",instruction)
# If model is not yet loaded, load it first
if current_model_path != model_name or current_model is None:
load_message = load_model(model_name)
if "failed" in load_message.lower():
return load_message
try:
        # Use a simple prompt format directly instead of the model's chat template
        prompt = f"User:{instruction}\nAssistant:"
        print("prompt:", prompt)
chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)
        # Generate the response
output = current_model.generate(
input_ids=chat_input,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=(temperature > 0),
            eos_token_id=current_tokenizer.eos_token_id  # use <|endoftext|> as the stop token
)
        # Decode and return only the newly generated text
        generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
        print("generated_text:", generated_text)
return generated_text
except Exception as e:
return f"生成响应时出错: {str(e)}"
def update_chat_with_response(chatbot, instruction, model_name, temperature, max_tokens):
"""Updates the chatbot with non-streaming response"""
global current_model, current_tokenizer, current_model_path
# If model is not yet loaded, load it first
if current_model_path != model_name or current_model is None:
load_result = load_model(model_name)
if "failed" in load_result.lower():
new_chat = list(chatbot)
new_chat[-1] = (instruction, load_result)
return new_chat
# Generate response using the non-streaming function
response = generate_response_non_streaming(instruction, model_name, temperature, max_tokens)
# Create a copy of the current chatbot and add the response
new_chat = list(chatbot)
new_chat[-1] = (instruction, response)
return new_chat
def on_model_series_change(model_series):
"""Update available model list based on selected model series"""
if model_series in APOLLO_MODELS:
return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
return gr.update(choices=[], value=None)
def process_message(message, chat_history, model_series_value, model_name_value, temperature_value, max_tokens_value):
"""Process user message and generate response"""
if message.strip() == "":
return "", chat_history
    # Print the submitted message for debugging
    print("instruction:", message)
# Add user message to chat history
chat_history = list(chat_history)
chat_history.append((message, None))
    # Automatically load the model if needed
global current_model, current_tokenizer, current_model_path
if current_model_path != model_name_value or current_model is None:
try:
load_result = load_model(model_name_value)
if "failed" in load_result.lower():
chat_history[-1] = (message, f"模型加载失败: {load_result}")
return "", chat_history
except Exception as e:
chat_history[-1] = (message, f"模型加载出错: {str(e)}")
return "", chat_history
# Generate response
try:
response = generate_response_non_streaming(message, model_name_value, temperature_value, max_tokens_value)
# Add response to chat history
chat_history[-1] = (message, response)
except Exception as e:
chat_history[-1] = (message, f"生成响应时出错: {str(e)}")
return "", chat_history
# Create Gradio interface
with gr.Blocks(css=css) as demo:
# Title and description
favicon = "🩺"
gr.Markdown(
f"""# {favicon} Apollo Playground
This is a demo of the multilingual medical model series **[Apollo](https://github.com/FreedomIntelligence/Apollo)** made by **[FreedomIntelligence](https://huggingface.co/FreedomIntelligence)**.
    [Apollo1](https://arxiv.org/abs/2403.03640) supports 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) and [Apollo-MoE](https://arxiv.org/abs/2410.10626) support 50 languages.
"""
)
with gr.Row():
with gr.Column(scale=1):
# Model selection controls
model_series = gr.Dropdown(
choices=list(APOLLO_MODELS.keys()),
value="Apollo",
label="Select Model Series",
info="First choose Apollo, Apollo2 or Apollo-MoE"
)
model_name = gr.Dropdown(
choices=APOLLO_MODELS["Apollo"],
value=APOLLO_MODELS["Apollo"][0],
label="Select Model Size",
info="Select the specific model size based on the chosen model series"
)
# Parameter settings
with gr.Accordion("Generation Parameters", open=False):
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.05,
label="Temperature"
)
max_tokens = gr.Slider(
minimum=128,
maximum=2048,
value=1024,
step=32,
label="Maximum Tokens"
)
            # The Load Model button and status display were removed; models are loaded automatically on first use
with gr.Column(scale=2):
# Chat interface
chatbot = gr.Chatbot(label="Conversation", height=500, value=[]) # Initialize with empty list
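            # Note: the (user, assistant) tuple history used here is Gradio's classic Chatbot
            # format; newer Gradio releases also accept type="messages" (not assumed here).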
user_input = gr.Textbox(
label="Input Medical Question",
placeholder="Example: What are the symptoms of hypertension? 高血压有哪些症状?",
lines=3
)
submit_button = gr.Button("Submit")
clear_button = gr.Button("Clear Chat")
# Event handling
# Update model selection when model series changes
model_series.change(
fn=on_model_series_change,
inputs=model_series,
outputs=model_name
)
    # Bind the submit events (textbox Enter key and Submit button)
submit_event = user_input.submit(
fn=process_message,
inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
outputs=[user_input, chatbot]
)
submit_button.click(
fn=process_message,
inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
outputs=[user_input, chatbot]
)
# Clear chat
clear_button.click(
fn=lambda: [],
outputs=chatbot
)
examples = [
["Últimamente tengo la tensión un poco alta, ¿cómo debo adaptar mis hábitos?"],
["What are the common side effects of metformin?"],
["中医和西医在治疗高血压方面有什么不同的观点?"],
["मेरा सिर दर्द कर रहा है, मुझे क्या करना चाहिए? "],
["Comment savoir si je suis diabétique ?"],
["ما الدواء الذي يمكنني تناوله إذا لم أستطع النوم ليلاً؟"],
["针对一名28岁女性患者,她左小腿挫伤12小时,伤口有分泌物,骨折端外露,小腿成角畸形,描述她的最佳处理方法。"]
]
gr.Examples(
examples=examples,
inputs=user_input
)
gr.HTML(LICENSE)
if __name__ == "__main__":
demo.launch()