Spaces:
Paused
Paused
import spaces | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import warnings | |
warnings.filterwarnings("ignore") | |
""" | |
ELYZA(8B)モデルを使用したGradioチャットボット | |
Hugging Face Transformersライブラリを使用してローカルでモデルを実行 | |
""" | |
# モデルとトークナイザーの初期化 | |
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" # Sarashina2 3B | |
# MODEL_NAME = "sbintuitions/sarashina2-7b" # Sarashina2 7B | |
# MODEL_NAME = "sbintuitions/sarashina2-13b" # Sarashina2 13B | |
# MODEL_NAME = "sbintuitions/sarashina2-70b" # Sarashina2 70B | |
# MODEL_NAME = "sbintuitions/sarashina1-65b" # Sarashina1 65B | |
# MODEL_NAME = "elyza/Llama-3-ELYZA-JP-8B" # ELYZA-JP-8B | |
# MODEL_NAME = "lightblue/ao-karasu-72B" # ao-karasu-72B | |
print("モデルを読み込み中〜...") | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" if torch.cuda.is_available() else None, | |
trust_remote_code=True | |
) | |
print("モデルの読み込みが完了しました〜。") | |
print(f"Is CUDA available: {torch.cuda.is_available()}") | |
# True | |
print("あ") | |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
print("い") | |
# Tesla T4 | |
def respond( | |
message, | |
history: list[tuple[str, str]], | |
system_message, | |
max_tokens, | |
temperature, | |
top_p, | |
): | |
""" | |
チャットボットの応答を生成する関数 | |
Gradio ChatInterfaceの標準形式に対応 | |
""" | |
try: | |
# システムメッセージと会話履歴を含むプロンプトを構築 | |
conversation = "" | |
if system_message.strip(): | |
conversation += f"システム: {system_message}\n" | |
# 会話履歴を追加 | |
for user_msg, bot_msg in history: | |
if user_msg: | |
conversation += f"ユーザー: {user_msg}\n" | |
if bot_msg: | |
conversation += f"アシスタント: {bot_msg}\n" | |
# 現在のメッセージを追加 | |
conversation += f"ユーザー: {message}\nアシスタント: " | |
# トークン化 | |
inputs = tokenizer.encode(conversation, return_tensors="pt") | |
# GPU使用時はCUDAに移動 | |
if torch.cuda.is_available(): | |
inputs = inputs.cuda() | |
# 応答生成(ストリーミング対応) | |
response = "" | |
with torch.no_grad(): | |
# 一度に生成してからストリーミング風に出力 | |
outputs = model.generate( | |
inputs, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id, | |
repetition_penalty=1.1 | |
) | |
# 生成されたテキストをデコード | |
generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# 変換できるかテスト用!! | |
# import json | |
# # レスポンス用の辞書を作るときに | |
# return json.dumps({"result": generated}, ensure_ascii=False) | |
# 応答部分のみを抽出 | |
full_response = generated[len(conversation):].strip() | |
# 不要な部分を除去 | |
if "ユーザー:" in full_response: | |
full_response = full_response.split("ユーザー:")[0].strip() | |
# ストリーミング風の出力 | |
#for i in range(len(full_response)): | |
# response = full_response[:i+1] | |
# yield response | |
#response = full_response[:len(full_response)] #追加 | |
#yield response #追加 | |
#yield full_response #追加 | |
return full_response #追加 | |
except Exception as e: | |
#yield f"エラーが発生しました: {str(e)}" | |
return f"エラーが発生しました: {str(e)}" #追加 | |
""" | |
Gradio ChatInterfaceを使用したシンプルなチャットボット | |
カスタマイズ可能なパラメータを含む | |
""" | |
demo = gr.ChatInterface( | |
respond, | |
title="🤖 ELYZA Chatbot", | |
description="ELYZA-JP-8B モデルを使用した日本語チャットボットです。", | |
additional_inputs=[ | |
gr.Textbox( | |
value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。", | |
label="システムメッセージ", | |
lines=3 | |
), | |
gr.Slider( | |
minimum=1, | |
maximum=8192, | |
value=4096, | |
step=1, | |
label="最大新規トークン数" | |
), | |
gr.Slider( | |
minimum=0.1, | |
maximum=2.0, | |
value=0.7, | |
step=0.1, | |
label="Temperature (創造性)" | |
), | |
gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.95, | |
step=0.05, | |
label="Top-p (多様性制御)", | |
), | |
], | |
theme=gr.themes.Soft(), | |
examples=[ | |
["こんにちは!今日はどんなことを話しましょうか?"], | |
["日本の文化について教えてください。"], | |
["簡単なレシピを教えてもらえますか?"], | |
["プログラミングについて質問があります。"], | |
], | |
cache_examples=False, | |
#streaming=False # 追加 ← これで return のみ受け付ける同期モードに | |
) | |
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
show_api=True, # API documentation を表示 | |
debug=True | |
) |