kimhyunwoo committed
Commit 3b77cfa · verified · 1 Parent(s): afe67d3

Create app.py

Files changed (1)
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
+ import gradio as gr
+ import torch
+ import os
+ from transformers import AutoTokenizer
+ from optimum.onnxruntime import ORTModelForCausalLM
+
+ # --- Configuration ---
+ # ONNX model ID specified by the user
+ MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
+ # Quantized model file name (check the repository layout; if absent, fall back to the regular model)
+ # The Q4 model file may be at 'onnx/model_q4.onnx' -> let optimum try to auto-detect it
+ # First, try loading without specifying a file explicitly
+ ONNX_FILE_NAME = None  # e.g., "onnx/model_q4.onnx" if needed and present
+
+ # Hugging Face Hub token (if needed - the Gemma model may be gated, but the ONNX community version may not be)
+ # HF_TOKEN = os.getenv("HF_TOKEN")  # set in the Space secrets
+
+ # --- Device Selection ---
+ try:
+     if torch.cuda.is_available():
+         device = "cuda:0"
+         provider = "CUDAExecutionProvider"
+         print("Using GPU (CUDA).")
+     # MPS (Apple Silicon) - most likely unavailable on Gradio Spaces
+     # elif torch.backends.mps.is_available():
+     #     device = "mps"
+     #     provider = "CoreMLExecutionProvider"  # needs verification
+     #     print("Using MPS (Apple Silicon).")
+     else:
+         device = "cpu"
+         provider = "CPUExecutionProvider"
+         print("Using CPU.")
+ except Exception as e:
+     print(f"Device detection error: {e}. Defaulting to CPU.")
+     device = "cpu"
+     provider = "CPUExecutionProvider"
+
+ # --- Model and Tokenizer Loading ---
+ print(f"Attempting to load model: {MODEL_ID}")
+ print(f"Using device: {device}, Execution Provider: {provider}")
+
+ try:
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)  # , token=HF_TOKEN)
+     print("Tokenizer loaded successfully.")
+
+     # Load the ONNX model via Optimum
+     # provider_options can be set here for further optimization if needed
+     model = ORTModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         # file_name=ONNX_FILE_NAME,  # an explicit file name may not be needed (auto-detected)
+         provider=provider,
+         # use_auth_token=HF_TOKEN,  # required for gated models
+         use_cache=True,  # use the KV cache
+         # provider_options={'enable_skip_layer_norm_strict_mode': True}  # example option
+     )
+     # Move the model to the chosen device (ORTModel may handle this internally, but it can be made explicit)
+     # model.to(device)  # ORTModel may not expose .to(); placement is handled via the provider setting
+     print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
+     model_loaded_successfully = True
+
+ except Exception as e:
+     print(f"!!!!!!!!!!!!!! Error loading model {MODEL_ID} !!!!!!!!!!!!!!")
+     print(f"Error type: {type(e).__name__}")
+     print(f"Error message: {e}")
+     print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+     model_loaded_successfully = False
+     # If the model fails to load, either stop the Gradio app or surface an error message
+     # raise gr.Error(f"CRITICAL: Failed to load model '{MODEL_ID}'. Check logs. Error: {e}")
+
+ # --- Chat Function ---
+ def chat_function(message: str, history: list):
+     if not model_loaded_successfully:
+         return "Error: The AI model failed to load. Cannot generate response."
+
+     # Convert the history and the new message into a prompt in Gemma Instruct format.
+     # Using the chat_template defined on the AutoTokenizer is recommended when available.
+     try:
+         # [[user_msg1, model_msg1], ...] -> [{"role": "user", "content": ...}, ...]
+         chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
+         for user_msg, model_msg in history:
+             chat_messages.append({"role": "user", "content": user_msg})
+             # standard chat templates expect "assistant"; Gemma's template renders it as a "model" turn
+             chat_messages.append({"role": "assistant", "content": model_msg})
+         chat_messages.append({"role": "user", "content": message})
+
+         # Use the tokenizer's apply_chat_template (verify that the Gemma template supports these roles)
+         try:
+             prompt = tokenizer.apply_chat_template(
+                 chat_messages,
+                 tokenize=False,
+                 add_generation_prompt=True  # append the prompt that starts the model's turn
+             )
+         except Exception as template_error:
+             # If the template cannot be applied, build the prompt manually (same approach as the earlier JS version)
+             print(f"Warning: Failed to apply chat template ({template_error}). Falling back to manual prompt construction.")
+             prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
+             for user_msg, model_msg in history:
+                 prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
+                 prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
+             prompt_parts.append(f"<start_of_turn>user\n{message}<end_of_turn>")
+             prompt_parts.append("<start_of_turn>model")
+             prompt = "\n".join(prompt_parts)
+
+         # print("\n--- Prompt ---")
+         # print(prompt)
+         # print("--------------\n")
+
+         # Tokenize the input
+         inputs = tokenizer(prompt, return_tensors="pt").to(device)  # move to the same device as the model
+
+         # Generate the response
+         print("Generating response...")
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=512,
+             do_sample=True,
+             temperature=0.7,
+             top_k=50,
+             top_p=0.9,
+             # pad_token_id=tokenizer.eos_token_id  # set if padding is needed
+         )
+         print("Generation complete.")
+
+         # Decode only the newly generated tokens (exclude the input part)
+         # (note: use inputs['input_ids'][0] rather than inputs[0] for the input length)
+         input_token_len = inputs['input_ids'].shape[1]
+         generated_tokens = outputs[0][input_token_len:]
+         response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         # Strip the end-of-turn token and any unnecessary trailing text
+         response = response.replace("<end_of_turn>", "").strip()
+
+         # print("\n--- Response ---")
+         # print(response)
+         # print("--------------\n")
+
+         # history.append((message, response))  # history is managed by Gradio
+         return response
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         # Return a safe error message that can be shown to the user
+         return "Sorry, an error occurred during response generation. Please check the application logs for details."
+
+
+ # --- Gradio Interface ---
+ print("Creating Gradio Interface...")
+ iface = gr.ChatInterface(
+     fn=chat_function if model_loaded_successfully else lambda msg, hist: "Model not loaded.",  # fallback function when the model failed to load
+     title="AI Assistant (Gemma 3 1B ONNX)",
+     description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
+     chatbot=gr.Chatbot(height=600),
+     textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
+     submit_btn="Send",
+     retry_btn="Retry",  # note: retry_btn/undo_btn/clear_btn are Gradio 4.x arguments (removed in Gradio 5)
+     undo_btn="Undo",
+     clear_btn="Clear",
+     theme=gr.themes.Soft(),  # apply a theme
+     examples=[["Hello!"], ["Write a poem about the internet."]]
+ )
+
+ # --- Launch App ---
+ if __name__ == "__main__":
+     print("Launching Gradio App...")
+     # share=True creates a link that is accessible from outside (use with caution)
+     iface.launch()  # share=True
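
For reference, a minimal smoke-test sketch (assuming the same dependencies as app.py — transformers and optimum[onnxruntime] — and that the ONNX export loads through ORTModelForCausalLM) that exercises the same load-and-generate path outside Gradio, so the model can be verified before the Space UI is involved:

    # Hypothetical standalone check, not part of the committed app.py.
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = ORTModelForCausalLM.from_pretrained(MODEL_ID, provider="CPUExecutionProvider")

    # Single user turn formatted with the tokenizer's chat template
    messages = [{"role": "user", "content": "Hello!"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)

    # Decode only the tokens generated after the prompt
    print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

If this prints a sensible reply on CPU, the same MODEL_ID and provider settings should behave identically inside chat_function above.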