import gradio as gr
import os
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from threading import Thread
import torch
import time

# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Apollo system prompt
SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."
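# NOTE: the active prompt path in generate_response_non_streaming() uses a plain
# "User: ... / Assistant:" format, so SYSTEM_PROMPT is only referenced by the
# commented-out chat-template variant further down.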

LICENSE = """
<div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
@misc{wang2024apollo,
      title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People},
      author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang},
      year={2024},
      eprint={2403.03640},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{zheng2024efficientlydemocratizingmedicalllms,
      title={Efficiently Democratizing Medical LLMs for 50 Languages via a Mixture of Language Family Experts},
      author={Guorui Zheng and Xidong Wang and Juhao Liang and Nuo Chen and Yuping Zheng and Benyou Wang},
      year={2024},
      eprint={2410.10626},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.10626},
}
</div>
"""

# Apollo model options
APOLLO_MODELS = {
    "Apollo": [
        "FreedomIntelligence/Apollo-7B",
        "FreedomIntelligence/Apollo-6B",
        "FreedomIntelligence/Apollo-2B",
        "FreedomIntelligence/Apollo-0.5B",
    ],
    "Apollo2": [
        "FreedomIntelligence/Apollo2-7B",
        "FreedomIntelligence/Apollo2-3.8B",
        "FreedomIntelligence/Apollo2-2B",
    ],
    "Apollo-MoE": [
        "FreedomIntelligence/Apollo-MoE-7B",
        "FreedomIntelligence/Apollo-MoE-1.5B",
        "FreedomIntelligence/Apollo-MoE-0.5B",
    ]
}
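
# The two dropdowns below (model series / model size) are populated from
# APOLLO_MODELS: picking a series triggers on_model_series_change(), which swaps
# in that series' checkpoint list.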

# CSS styles
css = """
h1 {
  text-align: center;
  display: block;
}
.gradio-container {
  max-width: 1200px;
  margin: auto;
}
"""

# Global variables to store the currently loaded model and tokenizer
current_model = None
current_tokenizer = None
current_model_path = None
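
# Only one model/tokenizer pair is kept in memory at a time: load_model() caches
# the pair in these globals and frees the previous model from GPU memory before
# switching to a different checkpoint.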
def load_model(model_path, progress=gr.Progress()):
    """Load the selected model and tokenizer."""
    global current_model, current_tokenizer, current_model_path

    # If the same model is already loaded, don't reload it
    if current_model_path == model_path and current_model is not None:
        return "Model already loaded, no need to reload."

    # Clean up the previously loaded model (if any): dropping the references lets
    # the old weights be garbage-collected, and resetting the globals to None keeps
    # later `is None` checks safe if loading the new checkpoint fails below
    if current_model is not None:
        current_model = None
        current_tokenizer = None
        torch.cuda.empty_cache()

    progress(0.1, desc=f"Starting to load model {model_path}...")
    try:
        progress(0.3, desc="Loading tokenizer...")
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        if 'MoE' in model_path:
            from configuration_upcycling_qwen2_moe import UpcyclingQwen2MoeConfig
            config = UpcyclingQwen2MoeConfig.from_pretrained(model_path, trust_remote_code=True)
            # config_moe.auto_map["AutoConfig"] = "./configuration_upcycling_qwen2_moe.UpcyclingQwen2MoeConfig"
            # config_moe.auto_map["AutoModelForCausalLM"] = "./modeling_upcycling_qwen2_moe.UpcyclingQwen2MoeForCausalLM"
        current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

        progress(0.5, desc="Loading model...")
        if 'MoE' in model_path:
            from modeling_upcycling_qwen2_moe import UpcyclingQwen2MoeForCausalLM
            current_model = UpcyclingQwen2MoeForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16,
                config=config,
                trust_remote_code=True
            )
        else:
            current_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16,
                config=config,
                trust_remote_code=True
            )
        current_model_path = model_path
        progress(1.0, desc="Model loading complete!")
        return f"Model {model_path} successfully loaded."
    except Exception as e:
        progress(1.0, desc="Model loading failed!")
        return f"Model loading failed: {str(e)}"
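
# The `spaces` import above suggests this demo targets ZeroGPU hardware, where a
# GPU is only attached while a function decorated with @spaces.GPU is running.
# Decorating the generation entry point is the usual pattern; adjust or remove
# this if the Space runs on dedicated hardware instead.
@spaces.GPU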
def generate_response_non_streaming(instruction, model_name, temperature=0.7, max_tokens=1024):
    """Generate a response from the Apollo model (non-streaming)."""
    global current_model, current_tokenizer, current_model_path
    print("instruction:", instruction)

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_message = load_model(model_name)
        if "failed" in load_message.lower():
            return load_message

    try:
        # Use a simple prompt format directly instead of the model's chat template
        prompt = f"User:{instruction}\nAssistant:"
        print("prompt:", prompt)
        chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)

        # Generate the response
        output = current_model.generate(
            input_ids=chat_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=(temperature > 0),
            eos_token_id=current_tokenizer.eos_token_id  # use <|endoftext|> as the stop token
        )

        # Decode and return the generated text (skipping the prompt tokens)
        generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
        print("generated_text:", generated_text)
        return generated_text
    except Exception as e:
        return f"Error generating response: {str(e)}"
    # try:
    #     # Check whether the model has a chat template
    #     if hasattr(current_tokenizer, 'chat_template') and current_tokenizer.chat_template:
    #         # Use the model's chat template
    #         messages = [
    #             {"role": "system", "content": SYSTEM_PROMPT},
    #             {"role": "user", "content": instruction}
    #         ]
    #         # Format the input with the model's chat template
    #         chat_input = current_tokenizer.apply_chat_template(
    #             messages,
    #             tokenize=True,
    #             return_tensors="pt"
    #         ).to(current_model.device)
    #     else:
    #         # Use the plain prompt format
    #         prompt = f"User:{instruction}\nAssistant:"
    #         chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)
    #     # Get the <|endoftext|> token id used to stop generation
    #     eos_token_id = current_tokenizer.eos_token_id
    #     # Generate the response
    #     output = current_model.generate(
    #         input_ids=chat_input,
    #         max_new_tokens=max_tokens,
    #         temperature=temperature,
    #         do_sample=(temperature > 0),
    #         eos_token_id=current_tokenizer.eos_token_id  # use <|endoftext|> as the stop token
    #     )
    #     # Decode and return the generated text
    #     generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
    #     return generated_text
    # except Exception as e:
    #     return f"Error generating response: {str(e)}"
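
# Note: update_chat_with_response() below is only wired up by the commented-out
# two-step submit handlers near the bottom of this file; the live event handlers
# call process_message() instead.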
def update_chat_with_response(chatbot, instruction, model_name, temperature, max_tokens):
    """Update the chatbot with a non-streaming response."""
    global current_model, current_tokenizer, current_model_path

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_result = load_model(model_name)
        if "failed" in load_result.lower():
            new_chat = list(chatbot)
            new_chat[-1] = (instruction, load_result)
            return new_chat

    # Generate the response using the non-streaming function
    response = generate_response_non_streaming(instruction, model_name, temperature, max_tokens)

    # Create a copy of the current chat history and add the response
    new_chat = list(chatbot)
    new_chat[-1] = (instruction, response)
    return new_chat

def on_model_series_change(model_series):
    """Update the available model list based on the selected model series."""
    if model_series in APOLLO_MODELS:
        return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
    return gr.update(choices=[], value=None)

def process_message(message, chat_history, model_series_value, model_name_value, temperature_value, max_tokens_value):
    """Process the user message and generate a response."""
    if message.strip() == "":
        return "", chat_history

    # Print the submitted message for debugging
    print("instruction:", message)

    # Add the user message to the chat history
    chat_history = list(chat_history)
    chat_history.append((message, None))

    # Automatically load the model if needed
    global current_model, current_tokenizer, current_model_path
    if current_model_path != model_name_value or current_model is None:
        try:
            load_result = load_model(model_name_value)
            if "failed" in load_result.lower():
                chat_history[-1] = (message, f"Model loading failed: {load_result}")
                return "", chat_history
        except Exception as e:
            chat_history[-1] = (message, f"Error loading model: {str(e)}")
            return "", chat_history

    # Generate the response
    try:
        response = generate_response_non_streaming(message, model_name_value, temperature_value, max_tokens_value)
        # Add the response to the chat history
        chat_history[-1] = (message, response)
    except Exception as e:
        chat_history[-1] = (message, f"Error generating response: {str(e)}")
    return "", chat_history

# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    # Title and description
    favicon = "🩺"
    gr.Markdown(
        f"""# {favicon} Apollo Playground
This is a demo of the multilingual medical model series **[Apollo](https://github.com/FreedomIntelligence/Apollo)** made by **[FreedomIntelligence](https://huggingface.co/FreedomIntelligence)**.
[Apollo1](https://arxiv.org/abs/2403.03640) supports 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) and [Apollo-MoE](https://arxiv.org/abs/2410.10626) support 50 languages.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Model selection controls
            model_series = gr.Dropdown(
                choices=list(APOLLO_MODELS.keys()),
                value="Apollo",
                label="Select Model Series",
                info="First choose Apollo, Apollo2 or Apollo-MoE"
            )
            model_name = gr.Dropdown(
                choices=APOLLO_MODELS["Apollo"],
                value=APOLLO_MODELS["Apollo"][0],
                label="Select Model Size",
                info="Select the specific model size based on the chosen model series"
            )

            # Parameter settings
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.05,
                    label="Temperature"
                )
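                # A temperature of 0 makes generate_response_non_streaming() fall
                # back to greedy decoding, since do_sample is set to (temperature > 0).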
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=32,
                    label="Maximum Tokens"
                )

            # The Load Model button and status display were removed; models are
            # loaded on demand when a message is submitted.
            # load_button = gr.Button("Load Model")
            # model_status = gr.Textbox(label="Model Status", value="No model loaded yet")

        with gr.Column(scale=2):
            # Chat interface
            chatbot = gr.Chatbot(label="Conversation", height=500, value=[])  # initialize with an empty history
            user_input = gr.Textbox(
                label="Input Medical Question",
                placeholder="Example: What are the symptoms of hypertension? 高血压有哪些症状?",
                lines=3
            )
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Chat")

    # Event handling
    # Update the model list when the model series changes
    model_series.change(
        fn=on_model_series_change,
        inputs=model_series,
        outputs=model_name
    )

    # Bind message submission (Enter key and Submit button)
    submit_event = user_input.submit(
        fn=process_message,
        inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot]
    )
    submit_button.click(
        fn=process_message,
        inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot]
    )

    # Clear chat
    clear_button.click(
        fn=lambda: [],
        outputs=chatbot
    )
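
    # Both the Enter-key submit and the button click route through process_message(),
    # which appends the user turn, loads the selected checkpoint if necessary, and
    # fills in the model reply in a single step; the commented-out handlers below
    # kept the original two-step (submit, then respond) wiring for reference.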
    # # Handle message submission
    # def user_message_submitted(message, chat_history):
    #     """Handle a user-submitted message."""
    #     # Ensure chat_history is a list
    #     if chat_history is None:
    #         chat_history = []
    #     if message.strip() == "":
    #         return "", chat_history
    #     # Add the user message to the chat history
    #     chat_history = list(chat_history)
    #     chat_history.append((message, None))
    #     return "", chat_history
    # # Bind message submission
    # submit_event = user_input.submit(
    #     fn=user_message_submitted,
    #     inputs=[user_input, chatbot],
    #     outputs=[user_input, chatbot]
    # ).then(
    #     fn=update_chat_with_response,
    #     inputs=[chatbot, user_input, model_name, temperature, max_tokens],
    #     outputs=chatbot
    # )
    # submit_button.click(
    #     fn=user_message_submitted,
    #     inputs=[user_input, chatbot],
    #     outputs=[user_input, chatbot]
    # ).then(
    #     fn=update_chat_with_response,
    #     inputs=[chatbot, user_input, model_name, temperature, max_tokens],
    #     outputs=chatbot
    # )
    # # Clear chat
    # clear_button.click(
    #     fn=lambda: [],
    #     outputs=chatbot
    # )

    examples = [
        ["Últimamente tengo la tensión un poco alta, ¿cómo debo adaptar mis hábitos?"],
        ["What are the common side effects of metformin?"],
        ["中医和西医在治疗高血压方面有什么不同的观点?"],
        ["मेरा सिर दर्द कर रहा है, मुझे क्या करना चाहिए?"],
        ["Comment savoir si je suis diabétique ?"],
        ["ما الدواء الذي يمكنني تناوله إذا لم أستطع النوم ليلاً؟"],
        ["针对一名28岁女性患者，她左小腿挫伤12小时，伤口有分泌物，骨折端外露，小腿成角畸形，描述她的最佳处理方法。"]
    ]
    gr.Examples(
        examples=examples,
        inputs=user_input
    )

    gr.HTML(LICENSE)

if __name__ == "__main__":
    demo.launch()