Create app.py
app.py
ADDED
@@ -0,0 +1,167 @@
import gradio as gr
import torch
import os
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# --- Configuration ---
# ONNX model ID specified by the user
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
# Quantized model file name (check the repository layout; if absent, try the regular model)
# The Q4 model file may live at 'onnx/model_q4.onnx' -> let optimum auto-detect it
# First, try loading without specifying an explicit file name
ONNX_FILE_NAME = None  # e.g., "onnx/model_q4.onnx" if needed and present

# Hugging Face Hub token (optional - the Gemma model can be gated, but the ONNX community version may not be)
# HF_TOKEN = os.getenv("HF_TOKEN")  # set via the Space secrets

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Using GPU (CUDA).")
    # MPS (Apple Silicon) - likely not usable on Gradio Spaces
    # elif torch.backends.mps.is_available():
    #     device = "mps"
    #     provider = "CoreMLExecutionProvider"  # needs verification
    #     print("Using MPS (Apple Silicon).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"
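
# Optional sanity check: onnxruntime only exposes the execution providers it was
# built with, so confirm the chosen provider is actually available and fall back
# to CPU otherwise (uses onnxruntime.get_available_providers()).
import onnxruntime as ort

available_providers = ort.get_available_providers()
print(f"Available ONNX Runtime providers: {available_providers}")
if provider not in available_providers:
    print(f"{provider} is not available; falling back to CPUExecutionProvider.")
    provider = "CPUExecutionProvider"
    device = "cpu"
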
# --- Model and Tokenizer Loading ---
print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)  # , token=HF_TOKEN)
    print("Tokenizer loaded successfully.")

    # Load the ONNX model (via Optimum)
    # provider_options can be set for further optimization if needed
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        # file_name=ONNX_FILE_NAME,  # an explicit file name may not be needed (auto-detected)
        provider=provider,
        # use_auth_token=HF_TOKEN,  # required if the model is gated
        use_cache=True,  # use the KV cache
        # provider_options={'enable_skip_layer_norm_strict_mode': True}  # example option
    )
    # Moving the model to the chosen device (ORTModel may handle this internally; it can also be made explicit)
    # model.to(device)  # ORTModel may not expose .to(); the provider argument covers device placement
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True

except Exception as e:
    print(f"!!!!!!!!!!!!!! Error loading model {MODEL_ID} !!!!!!!!!!!!!!")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
    # On load failure, either stop the Gradio app or surface an error message
    # raise gr.Error(f"CRITICAL: Failed to load model '{MODEL_ID}'. Check logs. Error: {e}")

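# If optimum's automatic detection cannot settle on a single ONNX file (for example,
# when the repository ships several variants such as fp16 / int8 / q4), the desired
# file can be requested explicitly. The path below is a guess based on the comments
# above; check the repository's file listing before relying on it.
# model = ORTModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     file_name="onnx/model_q4.onnx",  # hypothetical path
#     provider=provider,
#     use_cache=True,
# )
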
# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully:
        return "Error: The AI model failed to load. Cannot generate response."

    # Convert the history and the new message into a prompt in the Gemma instruct format
    # Using the AutoTokenizer's chat_template is preferred if one is defined
    try:
        # [[user_msg1, model_msg1], ...] -> [{"role": "user", "content": ...}, ...]
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for user_msg, model_msg in history:
            chat_messages.append({"role": "user", "content": user_msg})
            chat_messages.append({"role": "model", "content": model_msg})
        chat_messages.append({"role": "user", "content": message})

        # Use the tokenizer's apply_chat_template (need to confirm Gemma supports these roles)
        try:
            prompt = tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True  # append the generation prompt so the model starts its reply
            )
        except Exception as template_error:
            # If applying the template fails, build the prompt manually (same approach as the earlier JS version)
            print(f"Warning: Failed to apply chat template ({template_error}). Falling back to manual prompt construction.")
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
            for user_msg, model_msg in history:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
                prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
            prompt_parts.append(f"<start_of_turn>user\n{message}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

        # print("\n--- Prompt ---")
        # print(prompt)
        # print("--------------\n")

        # Tokenize the input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)  # same device as the model

        # Generate the response
        print("Generating response...")
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            # pad_token_id=tokenizer.eos_token_id  # set if padding is required
        )
        print("Generation complete.")

        # Decode the generated tokens (excluding the input portion)
        # inputs['input_ids'][0] may be needed instead of inputs[0]
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Strip the end-of-turn token and any unwanted trailing text
        response = response.replace("<end_of_turn>", "").strip()

        # print("\n--- Response ---")
        # print(response)
        # print("--------------\n")

        # history.append((message, response))  # Gradio manages the history itself
        return response

    except Exception as e:
        print(f"Error during generation: {e}")
        # Return a safe error message that can be shown to the user
        return "Sorry, an error occurred during response generation. Please check the application logs for details."

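# Quick manual smoke test for chat_function outside the Gradio UI. This is only a
# sketch: the SMOKE_TEST environment variable is an illustrative opt-in flag, not
# something the Space defines, and it assumes the model loaded successfully.
if __name__ == "__main__" and os.getenv("SMOKE_TEST"):
    print(chat_function("Hello! Who are you?", []))
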
# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function if model_loaded_successfully else lambda msg, hist: "Model not loaded.",  # fallback function when model loading fails
    title="AI Assistant (Gemma 3 1B ONNX)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    submit_btn="Send",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    theme=gr.themes.Soft(),  # apply a theme
    examples=[["Hello!"], ["Write a poem about the internet."]]
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    # Setting share=True creates an externally accessible link (be careful about security)
    iface.launch()  # share=True)
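# Note: the Space also needs a requirements.txt next to app.py. Based on the imports
# above, a plausible (unpinned) dependency list would be: gradio, torch, transformers,
# and optimum[onnxruntime] (or optimum[onnxruntime-gpu] when the CUDA provider is used).
# Exact extras and versions are assumptions; adjust them to the Space's hardware.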