# ThinkFlow-llama / app.py
# Hugging Face Space file (uploaded by "openfree", commit 42f4126 verified, 8.38 kB).
# The lines above were web-page scrape residue; preserved here as comments so the
# file remains valid Python.
import re
import threading
import gradio as gr
import spaces
import transformers
from transformers import pipeline
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
model_name = "Qwen/Qwen2-1.5B-Instruct"
if gr.NO_RELOAD:
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype="auto",
)
# ์ตœ์ข… ๋‹ต๋ณ€์„ ๊ฐ์ง€ํ•˜๊ธฐ ์œ„ํ•œ ๋งˆ์ปค
ANSWER_MARKER = "**๋‹ต๋ณ€**"
# ๋‹จ๊ณ„๋ณ„ ์ถ”๋ก ์„ ์‹œ์ž‘ํ•˜๋Š” ๋ฌธ์žฅ๋“ค
rethink_prepends = [
"์ž, ์ด์ œ ๋‹ค์Œ์„ ํŒŒ์•…ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค ",
"์ œ ์ƒ๊ฐ์—๋Š” ",
"์ž ์‹œ๋งŒ์š”, ์ œ ์ƒ๊ฐ์—๋Š” ",
"๋‹ค์Œ ์‚ฌํ•ญ์ด ๋งž๋Š”์ง€ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค ",
"๋˜ํ•œ ๊ธฐ์–ตํ•ด์•ผ ํ•  ๊ฒƒ์€ ",
"๋˜ ๋‹ค๋ฅธ ์ฃผ๋ชฉํ•  ์ ์€ ",
"๊ทธ๋ฆฌ๊ณ  ์ €๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์‚ฌ์‹ค๋„ ๊ธฐ์–ตํ•ฉ๋‹ˆ๋‹ค ",
"์ด์ œ ์ถฉ๋ถ„ํžˆ ์ดํ•ดํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค ",
"์ง€๊ธˆ๊นŒ์ง€์˜ ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ, ์›๋ž˜ ์งˆ๋ฌธ์— ์‚ฌ์šฉ๋œ ์–ธ์–ด๋กœ ๋‹ต๋ณ€ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:"
"\n{question}\n"
f"\n{ANSWER_MARKER}\n",
]
# ์ˆ˜์‹ ํ‘œ์‹œ ๋ฌธ์ œ ํ•ด๊ฒฐ์„ ์œ„ํ•œ ์„ค์ •
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
    """Rewrite MathJax-style delimiters as the KaTeX delimiters Gradio expects.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` (display mode) and ``\\( ... \\)``
    becomes ``$ ... $`` (inline mode); whitespace just inside the delimiters is
    trimmed. This is a workaround for showing math formulas in Gradio — no
    ``latex_delimiters`` configuration was found that makes the MathJax forms
    work directly.
    """
    display_pat = re.compile(r"\\\[\s*(.*?)\s*\\\]", re.DOTALL)
    inline_pat = re.compile(r"\\\(\s*(.*?)\s*\\\)", re.DOTALL)
    text = display_pat.sub(r"$$\1$$", text)
    return inline_pat.sub(r"$\1$", text)
def user_input(message, history: list):
    """Append the submitted message to the chat history and clear the textbox.

    Returns ``("", new_history)``: the empty string resets the input box, and
    the history gains one user ChatMessage. ANSWER_MARKER is stripped from the
    user text so it cannot prematurely trigger final-answer detection.
    """
    sanitized = message.replace(ANSWER_MARKER, "")
    updated_history = history + [gr.ChatMessage(role="user", content=sanitized)]
    return "", updated_history
def rebuild_messages(history: list):
    """Rebuild the message list the model will consume from the chat history,
    dropping past intermediate "thinking" entries.

    Plain dict entries (earlier turns, already serialized by Gradio) are kept
    only when they have no metadata title. ``gr.ChatMessage`` entries are kept
    when they DO carry a title and have string content, converted to role/content
    dicts.
    """
    rebuilt = []
    for entry in history:
        if isinstance(entry, dict):
            # Titled dict entries are intermediate "thought" bubbles — skip them.
            title = entry.get("metadata", {}).get("title", False)
            if not title:
                rebuilt.append(entry)
        elif isinstance(entry, gr.ChatMessage):
            # NOTE(review): asymmetric with the dict branch — titled ChatMessages
            # are KEPT here. This looks intentional (the freshly appended
            # "thinking" scaffold in bot() must reach the model), but confirm.
            if entry.metadata.get("title") and isinstance(entry.content, str):
                rebuilt.append({"role": entry.role, "content": entry.content})
    return rebuilt
@spaces.GPU
def bot(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Drive the model through the forced reasoning steps and stream the answer.

    A generator: yields the updated chat ``history`` after each streamed token
    so Gradio re-renders incrementally. ``max_num_tokens`` bounds each reasoning
    step; ``final_num_tokens`` bounds the final-answer step (the one containing
    ANSWER_MARKER).
    """
    # Streamer so tokens can be pulled from the generation thread as they arrive.
    # NOTE(review): one streamer instance is reused across all loop iterations —
    # each inner `for token in streamer` drains one generation; confirm no
    # leftover tokens can leak between steps.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Keep the question so it can be re-inserted into the reasoning if needed
    # (the last prepend formats {question} back in).
    question = history[-1]["content"]

    # Prepare the assistant "thinking" bubble shown in the chat.
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "๐Ÿง  ์ƒ๊ฐ ์ค‘...", "status": "pending"},
        )
    )

    # Reasoning process to be displayed in the current chat.
    messages = rebuild_messages(history)
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        messages[-1]["content"] += prepend.format(question=question)

        # The final-answer step (the prepend containing ANSWER_MARKER) gets its
        # own, usually larger, token budget.
        num_tokens = int(
            max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
        )
        # Run generation in a thread so we can consume the streamer here.
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()

        # Rebuild the visible history with the new step's prefix.
        # NOTE(review): generated tokens are appended to `history` below but never
        # fed back into `messages`, so each step's prompt contains the prepends
        # without the model's earlier step output — confirm this is intended.
        history[-1].content += prepend.format(question=question)
        if ANSWER_MARKER in prepend:
            history[-1].metadata = {"title": "๐Ÿ’ญ ์‚ฌ๊ณ  ๊ณผ์ •", "status": "done"}
            # Thinking is done; from here on it is the answer proper
            # (no metadata title, so it renders as a normal message).
            history.append(gr.ChatMessage(role="assistant", content=""))
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
    yield history
# Build the Gradio UI. The Markdown/label strings are user-facing Korean text
# and are kept verbatim.
with gr.Blocks(fill_height=True, title="๋ชจ๋“  LLM ๋ชจ๋ธ์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๋ถ€์—ฌํ•˜๊ธฐ") as demo:
    with gr.Row(scale=1):
        # Left column: explanation, chat window and input box.
        with gr.Column(scale=5):
            gr.Markdown(f"""
# ๋ชจ๋“  LLM์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๊ฐ•์ œํ•˜๊ธฐ
์ด๊ฒƒ์€ ๋ชจ๋“  LLM(๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ)์ด ์‘๋‹ต ์ „์— ์ถ”๋ก ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ๊ฐœ๋… ์ฆ๋ช…์ž…๋‹ˆ๋‹ค.
์ด ์ธํ„ฐํŽ˜์ด์Šค๋Š” *{model_name}* ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋Š”๋ฐ, **์ด๋Š” ์ถ”๋ก  ๋ชจ๋ธ์ด ์•„๋‹™๋‹ˆ๋‹ค**. ์‚ฌ์šฉ๋œ ๋ฐฉ๋ฒ•์€
๋‹จ์ง€ ์ ‘๋‘์‚ฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์ด ๋‹ต๋ณ€์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ๋ฐ ๋„์›€์ด ๋˜๋Š” "์ถ”๋ก " ๋‹จ๊ณ„๋ฅผ ๊ฐ•์ œํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.
๊ด€๋ จ ๊ธฐ์‚ฌ๋Š” ๋‹ค์Œ์—์„œ ํ™•์ธํ•˜์„ธ์š”: [๋ชจ๋“  ๋ชจ๋ธ์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๋ถ€์—ฌํ•˜๊ธฐ](https://huggingface.co/blog/Metal3d/making-any-model-reasoning)
""")
            chatbot = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
            )
            msg = gr.Textbox(
                submit_btn=True,
                label="",
                show_label=False,
                placeholder="์—ฌ๊ธฐ์— ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”.",
                autofocus=True,
            )
        # Right column: generation parameter controls.
        with gr.Column(scale=1):
            gr.Markdown("""## ๋งค๊ฐœ๋ณ€์ˆ˜ ์กฐ์ •""")
            # Max tokens per intermediate reasoning step (min, max, default).
            num_tokens = gr.Slider(
                50,
                1024,
                100,
                step=1,
                label="์ถ”๋ก  ๋‹จ๊ณ„๋‹น ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            # Max tokens for the final answer step.
            final_num_tokens = gr.Slider(
                50,
                1024,
                512,
                step=1,
                label="์ตœ์ข… ๋‹ต๋ณ€์˜ ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            do_sample = gr.Checkbox(True, label="์ƒ˜ํ”Œ๋ง ์‚ฌ์šฉ")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="์˜จ๋„")
            gr.Markdown("""
์ถ”๋ก  ๋‹จ๊ณ„์—์„œ ๋” ์ ์€ ์ˆ˜์˜ ํ† ํฐ์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ชจ๋ธ์ด
๋” ๋นจ๋ฆฌ ๋‹ต๋ณ€ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ์ถฉ๋ถ„ํžˆ ๊นŠ๊ฒŒ ์ถ”๋ก ํ•˜์ง€ ๋ชปํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ ์ ˆํ•œ ๊ฐ’์€ 100์—์„œ 512์ž…๋‹ˆ๋‹ค.
์ตœ์ข… ๋‹ต๋ณ€์— ๋” ์ ์€ ์ˆ˜์˜ ํ† ํฐ์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ชจ๋ธ์˜
์‘๋‹ต์ด ๋œ ์žฅํ™ฉํ•ด์ง€์ง€๋งŒ, ์™„์ „ํ•œ ๋‹ต๋ณ€์„ ์ œ๊ณตํ•˜์ง€ ๋ชปํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ ์ ˆํ•œ ๊ฐ’์€ 512์—์„œ 1024์ž…๋‹ˆ๋‹ค.
**์ƒ˜ํ”Œ๋ง ์‚ฌ์šฉ**์€ ๋‹ต๋ณ€์„ ์™„์„ฑํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์Œ ํ† ํฐ์„ ์„ ํƒํ•˜๋Š” ๋‹ค๋ฅธ ์ „๋žต์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ผ๋ฐ˜์ ์œผ๋กœ ์ด ์˜ต์…˜์„ ์ฒดํฌํ•ด ๋‘๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
**์˜จ๋„**๋Š” ๋ชจ๋ธ์ด ์–ผ๋งˆ๋‚˜ "์ฐฝ์˜์ "์ผ ์ˆ˜ ์žˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. 0.7์ด ์ผ๋ฐ˜์ ์ธ ๊ฐ’์ž…๋‹ˆ๋‹ค.
๋„ˆ๋ฌด ๋†’์€ ๊ฐ’(์˜ˆ: 1.0)์„ ์„ค์ •ํ•˜๋ฉด ๋ชจ๋ธ์ด ์ผ๊ด€์„ฑ์ด ์—†์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚ฎ์€ ๊ฐ’(์˜ˆ: 0.3)์œผ๋กœ
์„ค์ •ํ•˜๋ฉด ๋ชจ๋ธ์€ ๋งค์šฐ ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
""")
            gr.Markdown("""
์ด ์ธํ„ฐํŽ˜์ด์Šค๋Š” 6GB VRAM์„ ๊ฐ€์ง„ ๊ฐœ์ธ ์ปดํ“จํ„ฐ์—์„œ ์ž‘๋™ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(์˜ˆ: ๋…ธํŠธ๋ถ์˜ NVidia 3050/3060).
์ž์œ ๋กญ๊ฒŒ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์„ ํฌํฌํ•˜์—ฌ ๋‹ค๋ฅธ instruct ๋ชจ๋ธ์„ ์‹œ๋„ํ•ด ๋ณด์„ธ์š”.
""")

    # When the user submits a message, the bot responds: first store the user
    # message and clear the textbox, then run the (streaming) bot generator.
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # effectively the "history" input plus generation parameters
        chatbot,  # the new history is written back via the output
    )

if __name__ == "__main__":
    demo.queue().launch()