import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces
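
# Gradio chat demo for XiaomiMiMo/MiMo-7B-RL: the model and tokenizer are
# loaded once at import time, warmed up in __main__, and `predict` builds a
# plain "Human:/Assistant:" prompt for each chat turn.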

MODEL_ID = "XiaomiMiMo/MiMo-7B-RL"
MAX_NEW_TOKENS = 512
CPU_THREAD_COUNT = 4

# Cap CPU threads on CPU-only hosts so generation does not oversubscribe them.
if not torch.cuda.is_available():
    torch.set_num_threads(CPU_THREAD_COUNT)

print("--- Environment ---")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
print(f"Torch threads: {torch.get_num_threads()}")
print(f"--- λͺ¨λΈ λ‘λ© μ€: {MODEL_ID} ---") |
|
print("첫 μ€ν μ λͺ λΆ μ λ μμλ μ μμ΅λλ€...") |
|
|
|
model = None |
|
tokenizer = None |
|
load_successful = False |
|
stop_token_ids_list = [] |
|
|
|
try:
    start_load_time = time.time()

    device_map = "auto" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True
    )

    model.eval()
    load_time = time.time() - start_load_time
    print(f"--- Model and tokenizer loaded in {load_time:.2f}s ---")
    load_successful = True

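    # Note: on memory-constrained GPUs, quantized loading (e.g. passing a
    # BitsAndBytesConfig via `quantization_config` to from_pretrained) is a
    # common way to shrink a 7B model's footprint; it requires the
    # bitsandbytes package and is not enabled here.
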
    # Collect stop token IDs so generation halts at end-of-text markers.
    stop_token_strings = ["</s>", "<|endoftext|>"]
    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]

    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
        temp_stop_ids.append(tokenizer.eos_token_id)
    elif tokenizer.eos_token_id is None:
        print("Warning: tokenizer.eos_token_id is None; cannot add it to the stop tokens.")

    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]

    if not stop_token_ids_list:
        print("Warning: no stop token IDs found. Falling back to the default EOS if available; otherwise generation may not stop correctly.")
        if tokenizer.eos_token_id is not None:
            stop_token_ids_list = [tokenizer.eos_token_id]
        else:
            print("Error: no stop tokens found, not even a default EOS. Generation may run indefinitely.")

    print(f"Stop token IDs in use: {stop_token_ids_list}")

except Exception as e:
    print(f"!!! Model loading error: {e}")
    if 'model' in locals() and model is not None: del model
    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
    gc.collect()
    raise gr.Error(f"Failed to load model {MODEL_ID}. The application cannot start. Error: {e}")


def get_system_prompt():
    current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    return (
        f"- The AI language model is named \"MiMo\" and was built by XiaomiMiMo.\n"
        f"- Today's date is {current_date}.\n"
        f"- Answer the user's questions kindly, in detail, and in Korean."
    )

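
# Optional sketch (not used below): if the MiMo tokenizer ships a chat
# template, `tokenizer.apply_chat_template` is usually a more robust way to
# build prompts than the manual "Human:/Assistant:" format used in this file.
# Whether the template exists should be verified before switching to this.
def build_chat_prompt(message, history_messages):
    # history_messages is assumed to be a list of {"role": ..., "content": ...} dicts.
    messages = [{"role": "system", "content": get_system_prompt()}]
    messages.extend(history_messages)
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
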

def warmup_model():
    if not load_successful or model is None or tokenizer is None:
        print("Skipping warm-up: the model did not load successfully.")
        return

    print("--- Starting model warm-up ---")
    try:
        start_warmup_time = time.time()
        warmup_message = "Hello"

        system_prompt = get_system_prompt()

        # Prepend the system prompt so the warm-up request mirrors a real chat turn.
        prompt = f"{system_prompt}\n\nHuman: {warmup_message}\nAssistant:"

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        gen_kwargs = {
            "max_new_tokens": 10,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": False
        }

        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Warm-up warning: no stop tokens defined for generation.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        del inputs
        del output_ids
        gc.collect()
        warmup_time = time.time() - start_warmup_time
        print(f"--- Model warm-up complete in {warmup_time:.2f}s ---")

    except Exception as e:
        print(f"!!! Error during model warm-up: {e}")
    finally:
        gc.collect()

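
# `@spaces.GPU()` requests a ZeroGPU worker on Hugging Face Spaces for the
# duration of each call; outside Spaces the decorator has no practical effect.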
@spaces.GPU()
def predict(message, history):
    """
    Generate a response with XiaomiMiMo/MiMo-7B-RL.

    'history' may arrive either as (user, assistant) tuples or, when the
    interface uses type="messages", as a list of {"role", "content"} dicts;
    both formats are handled below.
    """
    if model is None or tokenizer is None:
        return "Error: the model is not loaded."

    # Flatten the chat history into the "Human:/Assistant:" prompt format.
    history_text = ""
    if isinstance(history, list):
        for turn in history:
            if isinstance(turn, (tuple, list)) and len(turn) == 2:
                history_text += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
            elif isinstance(turn, dict) and turn.get("role") and turn.get("content"):
                role = "Human" if turn["role"] == "user" else "Assistant"
                history_text += f"{role}: {turn['content']}\n"

    # Prepend the system prompt so its instructions actually reach the model.
    prompt = f"{get_system_prompt()}\n\n{history_text}Human: {message}\nAssistant:"

    inputs = None
    output_ids = None

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[1]
        print(f"\nInput token count: {input_length}")

    except Exception as e:
        print(f"!!! Error while tokenizing input: {e}")
        return f"Error: there was a problem processing the input. ({e})"

    try:
        print("Generating response...")
        generation_start_time = time.time()

        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.1
        }

        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation warning: no stop tokens defined.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start_time
        print(f"Generation finished in {generation_time:.2f}s.")

    except Exception as e:
        print(f"!!! Error during generation: {e}")
        if inputs is not None: del inputs
        if output_ids is not None: del output_ids
        gc.collect()
        return f"Error: there was a problem generating the response. ({e})"

    # Decode only the newly generated tokens, skipping the prompt.
    response = "Error: failed to generate a response."
    if output_ids is not None:
        try:
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output token count: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error while decoding the response: {e}")
            response = "Error: there was a problem decoding the response."

    if inputs is not None: del inputs
    if output_ids is not None: del output_ids
    gc.collect()
    print("Memory cleanup complete.")

    return response.strip()

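
# Optional sketch (not wired into the interface below): streaming partial
# output with transformers' TextIteratorStreamer, so Gradio can render tokens
# as they are produced. History handling is omitted here for brevity.
from threading import Thread
from transformers import TextIteratorStreamer

def predict_stream(message, history):
    prompt = f"{get_system_prompt()}\n\nHuman: {message}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_NEW_TOKENS}
    Thread(target=model.generate, kwargs=gen_kwargs).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
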
print("--- Gradio μΈν°νμ΄μ€ μ€μ μ€ ---") |
|
|
|
examples = [ |
|
["μλ
νμΈμ! μκΈ°μκ° μ’ ν΄μ£ΌμΈμ."], |
|
["μΈκ³΅μ§λ₯κ³Ό λ¨Έμ λ¬λμ μ°¨μ΄μ μ 무μμΈκ°μ?"], |
|
["λ₯λ¬λ λͺ¨λΈ νμ΅ κ³Όμ μ λ¨κ³λ³λ‘ μλ €μ£ΌμΈμ."], |
|
["μ μ£Όλ μ¬ν κ³νμ μΈμ°κ³ μλλ°, 3λ° 4μΌ μΆμ² μ½μ€ μ’ μλ €μ£ΌμΈμ."], |
|
] |
|
|
|
|
|
demo = gr.ChatInterface( |
|
fn=predict, |
|
title="π€ XiaomiMiMo/MiMo-7B-RL νκ΅μ΄ λ°λͺ¨", |
|
description=( |
|
f"**λͺ¨λΈ:** {MODEL_ID}\n" |
|
f"**νκ²½:** {'GPU' if torch.cuda.is_available() else 'CPU'}\n" |
|
f"**μ£Όμ:** {'GPUμμ μ€ν μ€μ
λλ€.' if torch.cuda.is_available() else 'CPUμμ μ€νλλ―λ‘ μλ΅ μμ±μ λ€μ μκ°μ΄ 걸릴 μ μμ΅λλ€.'}\n" |
|
f"μ΅λ μμ± ν ν° μλ {MAX_NEW_TOKENS}κ°λ‘ μ νλ©λλ€." |
|
), |
|
examples=examples, |
|
cache_examples=False, |
|
theme=gr.themes.Soft(), |
|
) |
|
|
|
|
|

if __name__ == "__main__":
    if load_successful:
        warmup_model()
    else:
        print("Skipping warm-up because model loading failed.")

    print("--- Launching the Gradio app ---")
    demo.queue().launch()