Update app.py
app.py
CHANGED
@@ -1,31 +1,23 @@
import gradio as gr
import torch
import os
-from transformers import AutoTokenizer
-from optimum.onnxruntime import ORTModelForCausalLM

# --- Configuration ---
-#
-
-# Name of the quantized model file (the repository structure needs checking; if it is missing, try the regular model)
-# The Q4 model file may exist as 'onnx/model_q4.onnx' -> let optimum try to detect it automatically
-# First, try loading without specifying an explicit file
-ONNX_FILE_NAME = None # e.g., "onnx/model_q4.onnx" if needed and present

-
-

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
-        print("
-    # MPS (Apple Silicon) - likely not usable on Gradio Spaces
-    # elif torch.backends.mps.is_available():
-    #     device = "mps"
-    #     provider = "CoreMLExecutionProvider" # Needs check
-    #     print("Using MPS (Apple Silicon).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"

@@ -36,132 +28,150 @@ except Exception as e:
    provider = "CPUExecutionProvider"

# --- Model and Tokenizer Loading ---
print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
-
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) #, token=HF_TOKEN)
    print("Tokenizer loaded successfully.")

-    # Load the ONNX model
-    # provider_options settings (further optimization possible if needed)
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
-        # filename=ONNX_FILE_NAME, # an explicit filename may not be needed (auto-detected)
        provider=provider,
-        # use_auth_token=HF_TOKEN, # required for gated models
        use_cache=True, # use the KV cache
-        #
    )
-    # Move the model to the specified device (ORTModel may handle this internally, but it can be made explicit)
-    # model.to(device) # ORTModel may not have .to(); handled via the provider setting
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True

except Exception as e:
-
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
-    print("
    model_loaded_successfully = False
-    # On model loading failure, stop the Gradio app or display an error message
-    # raise gr.Error(f"CRITICAL: Failed to load model '{MODEL_ID}'. Check logs. Error: {e}")

# --- Chat Function ---
def chat_function(message: str, history: list):
-    if not model_loaded_successfully:
-

-    # Convert history and message into a prompt in the Gemma Instruct format
-    # If the AutoTokenizer has a chat_template defined, using it is recommended
    try:
-        #
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for user_msg, model_msg in history:
-
-
-
-
-
        try:
            prompt = tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
-                add_generation_prompt=True
            )
        except Exception as template_error:
-
-            print(f"Warning: Failed to apply chat template ({template_error}). Falling back to manual prompt construction.")
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
            for user_msg, model_msg in history:
-                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
-                prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
-            prompt_parts.append(f"<start_of_turn>user\n{message}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

-        #
-
-        # print("--------------\n")
-
-        # Tokenize the input
-        inputs = tokenizer(prompt, return_tensors="pt").to(device) # same device as the model

        # Generate the response
        print("Generating response...")
-
-
-
-
-
-
-
-
-
        print("Generation complete.")

-        #
-        # may need to use inputs['input_ids'][0] instead of inputs[0]
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

-        #
        response = response.replace("<end_of_turn>", "").strip()

-        # print("
-
-        #

-        # history.append((message, response)) # history is managed by Gradio
        return response

    except Exception as e:
-        print(f"Error during generation
-
-

-# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
-    fn=chat_function
-    title="AI Assistant (Gemma 3 1B ONNX)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
-    chatbot=
-
-
-    retry_btn=None,
-    undo_btn=None,
-    clear_btn=None,
-
    examples=[["Hello!"], ["Write a poem about the internet."]]
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
-    #
-    iface.launch()

import gradio as gr
import torch
import os
+from transformers import AutoTokenizer, __version__ as transformers_version # import added for version checking
+from optimum.onnxruntime import ORTModelForCausalLM, __version__ as optimum_version # import added for version checking

# --- Configuration ---
+MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA" # the GQA model specified by the user
+ONNX_FILE_NAME = None # try to auto-detect the filename

+print(f"Using Transformers version: {transformers_version}")
+print(f"Using Optimum version: {optimum_version}")
+print(f"Using Gradio version: {gr.__version__}") # log the Gradio version

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
+        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
    provider = "CPUExecutionProvider"

# --- Model and Tokenizer Loading ---
+model = None
+tokenizer = None
+model_loaded_successfully = False
+
print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")

+    # Attempt to load the ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True, # use the KV cache
+        # use_io_binding=False # if problems occur when using the GPU, try setting this to False
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True

+except ValueError as ve:
+    # A ValueError is most likely an unsupported-model-type error
+    print(f"!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
+    print(f"Model: {MODEL_ID}")
+    print(f"Error message: {ve}")
+    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
+    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
+    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+    # Clearly notify the user when model loading fails
+    model_loaded_successfully = False
+
except Exception as e:
+    # Handle other kinds of exceptions (out of memory, network issues, etc.)
+    print(f"!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
+    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
+    print("Check Space resources (memory limits), network connection, or other dependencies.")
+    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False

# --- Chat Function ---
def chat_function(message: str, history: list):
+    if not model_loaded_successfully or model is None or tokenizer is None:
+        # Return an error message if the model failed to load
+        return "Error: The AI model is not loaded. Please check the application logs."

    try:
+        # Convert the chat history into the messages format
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for user_msg, model_msg in history:
+            # None checks added (None can occur in Gradio's initial state, etc.)
+            if user_msg:
+                chat_messages.append({"role": "user", "content": user_msg})
+            if model_msg:
+                chat_messages.append({"role": "model", "content": model_msg})
+        if message: # add the current user message
+            chat_messages.append({"role": "user", "content": message})
+
+        # Build the prompt (try apply_chat_template, fall back to manual construction on failure)
+        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
+                add_generation_prompt=True
            )
        except Exception as template_error:
+            print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
+            # Note that the model message in history may be None
            for user_msg, model_msg in history:
+                if user_msg: prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
+                if model_msg: prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
+            if message: prompt_parts.append(f"<start_of_turn>user\n{message}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

+        # print(f"--- PROMPT --- \n{prompt}\n--------------")

+        # Tokenize the input and move it to the device
+        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate the response
        print("Generating response...")
+        with torch.no_grad(): # disable gradient computation during inference
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7,
+                top_k=50,
+                top_p=0.9,
+                pad_token_id=tokenizer.eos_token_id # use the EOS token as the padding token
+            )
        print("Generation complete.")

+        # Decode (excluding the input portion)
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

+        # Post-processing
        response = response.replace("<end_of_turn>", "").strip()

+        # print(f"--- RESPONSE --- \n{response}\n--------------")
+
+        # Handle an empty response
+        if not response:
+            print("Warning: Generated empty response.")
+            response = "Sorry, I couldn't generate a response for that."

        return response

    except Exception as e:
+        print(f"!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
+        print(f"Error type: {type(e).__name__}")
+        print(f"Error message: {e}")
+        print("Input message:", message)
+        # traceback.print_exc() # print a detailed traceback if needed
+        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+        return f"Sorry, an error occurred during response generation. Please check logs."


+# --- Gradio Interface (revised) ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
+    fn=chat_function, # model-load failure is handled inside chat_function
+    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
+    # add type='messages' to the chatbot widget
+    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
+    # removed button arguments that are no longer supported
+    # retry_btn=None, # removed
+    # undo_btn=None, # removed
+    # clear_btn=None, # removed
+    # use the default button instead of submit_btn
+    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]]
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
+    # Launch the interface even if model loading failed, but display an error message
+    iface.launch()
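
The ValueError handler above tells the operator to check requirements.txt for a recent transformers release. As a rough sketch only, a requirements.txt matching the imports in this app.py might look like the following; the transformers pin comes from the error message in the commit, while the other entries and the lack of pins are assumptions, not part of this change:

# requirements.txt (hypothetical sketch; only the transformers pin is stated in app.py)
transformers>=4.41.0
optimum[onnxruntime]
torch
gradio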