Spaces:
Running
on
Zero
Running
on
Zero
import re | |
import threading | |
import gradio as gr | |
import spaces | |
import transformers | |
from transformers import pipeline | |
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ก๋ฉ | |
model_name = "Qwen/Qwen2-1.5B-Instruct" | |
if gr.NO_RELOAD: | |
pipe = pipeline( | |
"text-generation", | |
model=model_name, | |
device_map="auto", | |
torch_dtype="auto", | |
) | |
# ์ต์ข ๋ต๋ณ์ ๊ฐ์งํ๊ธฐ ์ํ ๋ง์ปค | |
ANSWER_MARKER = "**๋ต๋ณ**" | |
# ๋จ๊ณ๋ณ ์ถ๋ก ์ ์์ํ๋ ๋ฌธ์ฅ๋ค | |
rethink_prepends = [ | |
"์, ์ด์ ๋ค์์ ํ์ ํด์ผ ํฉ๋๋ค ", | |
"์ ์๊ฐ์๋ ", | |
"์ ์๋ง์, ์ ์๊ฐ์๋ ", | |
"๋ค์ ์ฌํญ์ด ๋ง๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค ", | |
"๋ํ ๊ธฐ์ตํด์ผ ํ ๊ฒ์ ", | |
"๋ ๋ค๋ฅธ ์ฃผ๋ชฉํ ์ ์ ", | |
"๊ทธ๋ฆฌ๊ณ ์ ๋ ๋ค์๊ณผ ๊ฐ์ ์ฌ์ค๋ ๊ธฐ์ตํฉ๋๋ค ", | |
"์ด์ ์ถฉ๋ถํ ์ดํดํ๋ค๊ณ ์๊ฐํฉ๋๋ค ", | |
"์ง๊ธ๊น์ง์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก, ์๋ ์ง๋ฌธ์ ์ฌ์ฉ๋ ์ธ์ด๋ก ๋ต๋ณํ๊ฒ ์ต๋๋ค:" | |
"\n{question}\n" | |
f"\n{ANSWER_MARKER}\n", | |
] | |
# ์์ ํ์ ๋ฌธ์ ํด๊ฒฐ์ ์ํ ์ค์ | |
latex_delimiters = [ | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "$", "right": "$", "display": False}, | |
] | |
def reformat_math(text): | |
"""Gradio ๊ตฌ๋ฌธ(Katex)์ ์ฌ์ฉํ๋๋ก MathJax ๊ตฌ๋ถ ๊ธฐํธ ์์ . | |
์ด๊ฒ์ Gradio์์ ์ํ ๊ณต์์ ํ์ํ๊ธฐ ์ํ ์์ ํด๊ฒฐ์ฑ ์ ๋๋ค. ํ์ฌ๋ก์๋ | |
๋ค๋ฅธ latex_delimiters๋ฅผ ์ฌ์ฉํ์ฌ ์์๋๋ก ์๋ํ๊ฒ ํ๋ ๋ฐฉ๋ฒ์ ์ฐพ์ง ๋ชปํ์ต๋๋ค... | |
""" | |
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL) | |
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL) | |
return text | |
def user_input(message, history: list): | |
"""์ฌ์ฉ์ ์ ๋ ฅ์ ํ์คํ ๋ฆฌ์ ์ถ๊ฐํ๊ณ ์ ๋ ฅ ํ ์คํธ ์์ ๋น์ฐ๊ธฐ""" | |
return "", history + [ | |
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, "")) | |
] | |
def rebuild_messages(history: list): | |
"""์ค๊ฐ ์๊ฐ ๊ณผ์ ์์ด ๋ชจ๋ธ์ด ์ฌ์ฉํ ํ์คํ ๋ฆฌ์์ ๋ฉ์์ง ์ฌ๊ตฌ์ฑ""" | |
messages = [] | |
for h in history: | |
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False): | |
messages.append(h) | |
elif ( | |
isinstance(h, gr.ChatMessage) | |
and h.metadata.get("title") | |
and isinstance(h.content, str) | |
): | |
messages.append({"role": h.role, "content": h.content}) | |
return messages | |
def bot( | |
history: list, | |
max_num_tokens: int, | |
final_num_tokens: int, | |
do_sample: bool, | |
temperature: float, | |
): | |
"""๋ชจ๋ธ์ด ์ง๋ฌธ์ ๋ต๋ณํ๋๋ก ํ๊ธฐ""" | |
# ๋์ค์ ์ค๋ ๋์์ ํ ํฐ์ ์คํธ๋ฆผ์ผ๋ก ๊ฐ์ ธ์ค๊ธฐ ์ํจ | |
streamer = transformers.TextIteratorStreamer( | |
pipe.tokenizer, # pyright: ignore | |
skip_special_tokens=True, | |
skip_prompt=True, | |
) | |
# ํ์ํ ๊ฒฝ์ฐ ์ถ๋ก ์ ์ง๋ฌธ์ ๋ค์ ์ฝ์ ํ๊ธฐ ์ํจ | |
question = history[-1]["content"] | |
# ๋ณด์กฐ์ ๋ฉ์์ง ์ค๋น | |
history.append( | |
gr.ChatMessage( | |
role="assistant", | |
content=str(""), | |
metadata={"title": "๐ง ์๊ฐ ์ค...", "status": "pending"}, | |
) | |
) | |
# ํ์ฌ ์ฑํ ์ ํ์๋ ์ถ๋ก ๊ณผ์ | |
messages = rebuild_messages(history) | |
for i, prepend in enumerate(rethink_prepends): | |
if i > 0: | |
messages[-1]["content"] += "\n\n" | |
messages[-1]["content"] += prepend.format(question=question) | |
num_tokens = int( | |
max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens | |
) | |
t = threading.Thread( | |
target=pipe, | |
args=(messages,), | |
kwargs=dict( | |
max_new_tokens=num_tokens, | |
streamer=streamer, | |
do_sample=do_sample, | |
temperature=temperature, | |
), | |
) | |
t.start() | |
# ์ ๋ด์ฉ์ผ๋ก ํ์คํ ๋ฆฌ ์ฌ๊ตฌ์ฑ | |
history[-1].content += prepend.format(question=question) | |
if ANSWER_MARKER in prepend: | |
history[-1].metadata = {"title": "๐ญ ์ฌ๊ณ ๊ณผ์ ", "status": "done"} | |
# ์๊ฐ ์ข ๋ฃ, ์ด์ ๋ต๋ณ์ ๋๋ค (์ค๊ฐ ๋จ๊ณ์ ๋ํ ๋ฉํ๋ฐ์ดํฐ ์์) | |
history.append(gr.ChatMessage(role="assistant", content="")) | |
for token in streamer: | |
history[-1].content += token | |
history[-1].content = reformat_math(history[-1].content) | |
yield history | |
t.join() | |
yield history | |
with gr.Blocks(fill_height=True, title="๋ชจ๋ LLM ๋ชจ๋ธ์ ์ถ๋ก ๋ฅ๋ ฅ ๋ถ์ฌํ๊ธฐ") as demo: | |
with gr.Row(scale=1): | |
with gr.Column(scale=5): | |
gr.Markdown(f""" | |
# ๋ชจ๋ LLM์ ์ถ๋ก ๋ฅ๋ ฅ ๊ฐ์ ํ๊ธฐ | |
์ด๊ฒ์ ๋ชจ๋ LLM(๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ)์ด ์๋ต ์ ์ ์ถ๋ก ํ ์ ์๋๋ก ํ๋ ๊ฐ๋จํ ๊ฐ๋ ์ฆ๋ช ์ ๋๋ค. | |
์ด ์ธํฐํ์ด์ค๋ *{model_name}* ๋ชจ๋ธ์ ์ฌ์ฉํ๋๋ฐ, **์ด๋ ์ถ๋ก ๋ชจ๋ธ์ด ์๋๋๋ค**. ์ฌ์ฉ๋ ๋ฐฉ๋ฒ์ | |
๋จ์ง ์ ๋์ฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ด ๋ต๋ณ์ ํฅ์์ํค๋ ๋ฐ ๋์์ด ๋๋ "์ถ๋ก " ๋จ๊ณ๋ฅผ ๊ฐ์ ํ๋ ๊ฒ์ ๋๋ค. | |
๊ด๋ จ ๊ธฐ์ฌ๋ ๋ค์์์ ํ์ธํ์ธ์: [๋ชจ๋ ๋ชจ๋ธ์ ์ถ๋ก ๋ฅ๋ ฅ ๋ถ์ฌํ๊ธฐ](https://huggingface.co/blog/Metal3d/making-any-model-reasoning) | |
""") | |
chatbot = gr.Chatbot( | |
scale=1, | |
type="messages", | |
latex_delimiters=latex_delimiters, | |
) | |
msg = gr.Textbox( | |
submit_btn=True, | |
label="", | |
show_label=False, | |
placeholder="์ฌ๊ธฐ์ ์ง๋ฌธ์ ์ ๋ ฅํ์ธ์.", | |
autofocus=True, | |
) | |
with gr.Column(scale=1): | |
gr.Markdown("""## ๋งค๊ฐ๋ณ์ ์กฐ์ """) | |
num_tokens = gr.Slider( | |
50, | |
1024, | |
100, | |
step=1, | |
label="์ถ๋ก ๋จ๊ณ๋น ์ต๋ ํ ํฐ ์", | |
interactive=True, | |
) | |
final_num_tokens = gr.Slider( | |
50, | |
1024, | |
512, | |
step=1, | |
label="์ต์ข ๋ต๋ณ์ ์ต๋ ํ ํฐ ์", | |
interactive=True, | |
) | |
do_sample = gr.Checkbox(True, label="์ํ๋ง ์ฌ์ฉ") | |
temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="์จ๋") | |
gr.Markdown(""" | |
์ถ๋ก ๋จ๊ณ์์ ๋ ์ ์ ์์ ํ ํฐ์ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ด | |
๋ ๋นจ๋ฆฌ ๋ต๋ณํ ์ ์์ง๋ง, ์ถฉ๋ถํ ๊น๊ฒ ์ถ๋ก ํ์ง ๋ชปํ ์ ์์ต๋๋ค. | |
์ ์ ํ ๊ฐ์ 100์์ 512์ ๋๋ค. | |
์ต์ข ๋ต๋ณ์ ๋ ์ ์ ์์ ํ ํฐ์ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ | |
์๋ต์ด ๋ ์ฅํฉํด์ง์ง๋ง, ์์ ํ ๋ต๋ณ์ ์ ๊ณตํ์ง ๋ชปํ ์ ์์ต๋๋ค. | |
์ ์ ํ ๊ฐ์ 512์์ 1024์ ๋๋ค. | |
**์ํ๋ง ์ฌ์ฉ**์ ๋ต๋ณ์ ์์ฑํ๊ธฐ ์ํด ๋ค์ ํ ํฐ์ ์ ํํ๋ ๋ค๋ฅธ ์ ๋ต์ ์ฌ์ฉํฉ๋๋ค. | |
์ผ๋ฐ์ ์ผ๋ก ์ด ์ต์ ์ ์ฒดํฌํด ๋๋ ๊ฒ์ด ์ข์ต๋๋ค. | |
**์จ๋**๋ ๋ชจ๋ธ์ด ์ผ๋ง๋ "์ฐฝ์์ "์ผ ์ ์๋์ง๋ฅผ ๋ํ๋ ๋๋ค. 0.7์ด ์ผ๋ฐ์ ์ธ ๊ฐ์ ๋๋ค. | |
๋๋ฌด ๋์ ๊ฐ(์: 1.0)์ ์ค์ ํ๋ฉด ๋ชจ๋ธ์ด ์ผ๊ด์ฑ์ด ์์ ์ ์์ต๋๋ค. ๋ฎ์ ๊ฐ(์: 0.3)์ผ๋ก | |
์ค์ ํ๋ฉด ๋ชจ๋ธ์ ๋งค์ฐ ์์ธก ๊ฐ๋ฅํ ๋ต๋ณ์ ์์ฑํฉ๋๋ค. | |
""") | |
gr.Markdown(""" | |
์ด ์ธํฐํ์ด์ค๋ 6GB VRAM์ ๊ฐ์ง ๊ฐ์ธ ์ปดํจํฐ์์ ์๋ํ ์ ์์ต๋๋ค(์: ๋ ธํธ๋ถ์ NVidia 3050/3060). | |
์์ ๋กญ๊ฒ ์ ํ๋ฆฌ์ผ์ด์ ์ ํฌํฌํ์ฌ ๋ค๋ฅธ instruct ๋ชจ๋ธ์ ์๋ํด ๋ณด์ธ์. | |
""") | |
# ์ฌ์ฉ์๊ฐ ๋ฉ์์ง๋ฅผ ์ ์ถํ๋ฉด ๋ด์ด ์๋ตํฉ๋๋ค | |
msg.submit( | |
user_input, | |
[msg, chatbot], # ์ ๋ ฅ | |
[msg, chatbot], # ์ถ๋ ฅ | |
).then( | |
bot, | |
[ | |
chatbot, | |
num_tokens, | |
final_num_tokens, | |
do_sample, | |
temperature, | |
], # ์ค์ ๋ก๋ "history" ์ ๋ ฅ | |
chatbot, # ์ถ๋ ฅ์์ ์ ํ์คํ ๋ฆฌ ์ ์ฅ | |
) | |
if __name__ == "__main__": | |
demo.queue().launch() |