Yokky009 commited on
Commit
917dba9
·
verified ·
1 Parent(s): 8baf7bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -8
app.py CHANGED
@@ -1,14 +1,174 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
2
  import torch
 
 
 
3
 
4
- tokenizer = AutoTokenizer.from_pretrained("lightblue/ao-karasu-72B")
5
- model = AutoModelForCausalLM.from_pretrained("lightblue/ao-karasu-72B", device_map="auto")
 
 
6
 
7
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
8
 
9
- messages = [{"role": "system", "content": "あなたはAIアシスタントです。"}]
10
- messages.append({"role": "user", "content": "イギリスの首相は誰ですか?"})
 
 
 
 
 
 
 
11
 
12
- prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
 
 
 
 
 
13
 
14
- pipe(prompt, max_new_tokens=100, do_sample=False, temperature=0.0, return_full_text=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
 
8
+ """
9
+ ELYZA(8B)モデルを使用したGradioチャットボット
10
+ Hugging Face Transformersライブラリを使用してローカルでモデルを実行
11
+ """
12
 
13
+ # モデルとトークナイザーの初期化
14
+ MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" # Sarashina2 3B
15
+ # MODEL_NAME = "sbintuitions/sarashina2-7b" # Sarashina2 7B
16
+ # MODEL_NAME = "sbintuitions/sarashina2-13b" # Sarashina2 13B
17
+ # MODEL_NAME = "sbintuitions/sarashina2-70b" # Sarashina2 70B
18
+ # MODEL_NAME = "sbintuitions/sarashina1-65b" # Sarashina1 65B
19
+ # MODEL_NAME = "elyza/Llama-3-ELYZA-JP-8B" # ELYZA-JP-8B
20
+ # MODEL_NAME = "lightblue/ao-karasu-72B" # ao-karasu-72B
21
 
22
+ print("モデルを読み込み中〜...")
23
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ MODEL_NAME,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ device_map="auto" if torch.cuda.is_available() else None,
28
+ trust_remote_code=True
29
+ )
30
+ print("モデルの読み込みが完了しました〜。")
31
 
32
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
33
+ # True
34
+ print("あ")
35
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
36
+ print("い")
37
+ # Tesla T4
38
 
39
+ @spaces.GPU
40
+
41
+ def respond(
42
+ message,
43
+ history: list[tuple[str, str]],
44
+ system_message,
45
+ max_tokens,
46
+ temperature,
47
+ top_p,
48
+ ):
49
+ """
50
+ チャットボットの応答を生成する関数
51
+ Gradio ChatInterfaceの標準形式に対応
52
+ """
53
+ try:
54
+ # システムメッセージと会話履歴を含むプロンプトを構築
55
+ conversation = ""
56
+ if system_message.strip():
57
+ conversation += f"システム: {system_message}\n"
58
+
59
+ # 会話履歴を追加
60
+ for user_msg, bot_msg in history:
61
+ if user_msg:
62
+ conversation += f"ユーザー: {user_msg}\n"
63
+ if bot_msg:
64
+ conversation += f"アシスタント: {bot_msg}\n"
65
+
66
+ # 現在のメッセージを追加
67
+ conversation += f"ユーザー: {message}\nアシスタント: "
68
+
69
+ # トークン化
70
+ inputs = tokenizer.encode(conversation, return_tensors="pt")
71
+
72
+ # GPU使用時はCUDAに移動
73
+ if torch.cuda.is_available():
74
+ inputs = inputs.cuda()
75
+
76
+ # 応答生成(ストリーミング対応)
77
+ response = ""
78
+ with torch.no_grad():
79
+ # 一度に生成してからストリーミング風に出力
80
+ outputs = model.generate(
81
+ inputs,
82
+ max_new_tokens=max_tokens,
83
+ temperature=temperature,
84
+ top_p=top_p,
85
+ do_sample=True,
86
+ pad_token_id=tokenizer.eos_token_id,
87
+ eos_token_id=tokenizer.eos_token_id,
88
+ repetition_penalty=1.1
89
+ )
90
+
91
+ # 生成されたテキストをデコード
92
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
93
+
94
+ # 変換できるかテスト用!!
95
+ # import json
96
+ # # レスポンス用の辞書を作るときに
97
+ # return json.dumps({"result": generated}, ensure_ascii=False)
98
+
99
+ # 応答部分のみを抽出
100
+ full_response = generated[len(conversation):].strip()
101
+
102
+ # 不要な部分を除去
103
+ if "ユーザー:" in full_response:
104
+ full_response = full_response.split("ユーザー:")[0].strip()
105
+
106
+ # ストリーミング風の出力
107
+ #for i in range(len(full_response)):
108
+ # response = full_response[:i+1]
109
+ # yield response
110
+
111
+ #response = full_response[:len(full_response)] #追加
112
+ #yield response #追加
113
+ #yield full_response #追加
114
+ return full_response #追加
115
+
116
+ except Exception as e:
117
+ #yield f"エラーが発生しました: {str(e)}"
118
+ return f"エラーが発生しました: {str(e)}" #追加
119
+
120
+ """
121
+ Gradio ChatInterfaceを使用したシンプルなチャットボット
122
+ カスタマイズ可能なパラメータを含む
123
+ """
124
+ demo = gr.ChatInterface(
125
+ respond,
126
+ title="🤖 ELYZA Chatbot",
127
+ description="ELYZA-JP-8B モデルを使用した日本語チャットボットです。",
128
+ additional_inputs=[
129
+ gr.Textbox(
130
+ value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。",
131
+ label="システムメッセージ",
132
+ lines=3
133
+ ),
134
+ gr.Slider(
135
+ minimum=1,
136
+ maximum=8192,
137
+ value=4096,
138
+ step=1,
139
+ label="最大新規トークン数"
140
+ ),
141
+ gr.Slider(
142
+ minimum=0.1,
143
+ maximum=2.0,
144
+ value=0.7,
145
+ step=0.1,
146
+ label="Temperature (創造性)"
147
+ ),
148
+ gr.Slider(
149
+ minimum=0.1,
150
+ maximum=1.0,
151
+ value=0.95,
152
+ step=0.05,
153
+ label="Top-p (多様性制御)",
154
+ ),
155
+ ],
156
+ theme=gr.themes.Soft(),
157
+ examples=[
158
+ ["こんにちは!今日はどんなことを話しましょうか?"],
159
+ ["日本の文化について教えてください。"],
160
+ ["簡単なレシピを教えてもらえますか?"],
161
+ ["プログラミングについて質問があります。"],
162
+ ],
163
+ cache_examples=False,
164
+ #streaming=False # 追加 ← これで return のみ受け付ける同期モードに
165
+ )
166
+
167
+ if __name__ == "__main__":
168
+ demo.launch(
169
+ server_name="0.0.0.0",
170
+ server_port=7860,
171
+ share=False,
172
+ show_api=True, # API documentation を表示
173
+ debug=True
174
+ )