Update generator/llm_inference.py
generator/llm_inference.py  +30 −27
CHANGED
```diff
@@ -1,27 +1,30 @@
-from transformers import pipeline
-"""
+from transformers import pipeline
+import spaces
+
+# 1. Load the model (loaded only once)
+generator = pipeline(
+    "text-generation",
+    model="dasomaru/gemma-3-4bit-it-demo",  # name of the uploaded model
+    tokenizer="dasomaru/gemma-3-4bit-it-demo",
+    device=0,  # use CUDA:0 (GPU); set device=-1 if only a CPU is available
+    max_new_tokens=512,
+    temperature=0.7,
+    top_p=0.9,
+    repetition_penalty=1.1
+)
+
+# 2. Answer generation function
+@spaces.GPU(duration=300)
+def generate_answer(prompt: str) -> str:
+    """
+    Generate an answer from the given prompt.
+    """
+    print(f"🔵 Prompt Length: {len(prompt)} characters")  # added!
+    outputs = generator(
+        prompt,
+        do_sample=True,
+        top_k=50,
+        num_return_sequences=1
+    )
+    return outputs[0]["generated_text"].strip()
```
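One behavioral detail worth noting about the final line: by default, the `transformers` text-generation pipeline returns the prompt followed by the completion in `generated_text`, so the `.strip()` above does not remove the echoed prompt. If only the newly generated text is wanted, the pipeline call accepts `return_full_text=False`; a minimal sketch of that variant (not part of this commit):

```python
outputs = generator(
    prompt,
    do_sample=True,
    top_k=50,
    num_return_sequences=1,
    return_full_text=False,  # return only the completion, not the echoed prompt
)
completion = outputs[0]["generated_text"].strip()
```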
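For reference, a sketch of how this module might be driven from elsewhere in the repo. This assumes the repository root is on `PYTHONPATH` so that `generator.llm_inference` is importable, and the prompt string is an illustrative placeholder:

```python
# Hypothetical usage sketch; not part of this commit.
from generator.llm_inference import generate_answer

if __name__ == "__main__":
    prompt = "Explain the difference between top-k and top-p sampling."  # placeholder
    answer = generate_answer(prompt)
    print(answer)
```

The `@spaces.GPU(duration=300)` decorator requests a GPU for up to 300 seconds per call on Hugging Face Spaces, so each invocation of `generate_answer` runs within that allocation.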