Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -15,12 +15,10 @@ import torch
|
|
15 |
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
16 |
|
17 |
# Configuration
|
18 |
-
# Use this MODEL_ID, adjust if you have a local path instead
|
19 |
MODEL_ID = os.getenv("GEMMA_MODEL_PATH", "tabularisai/german-gemma-3-1b-it")
|
20 |
-
|
21 |
-
HF_TOKEN = os.getenv("Tokentest")
|
22 |
|
23 |
-
# Load tokenizer and model
|
24 |
print(f"Loading model {MODEL_ID}...")
|
25 |
tokenizer = AutoTokenizer.from_pretrained(
|
26 |
MODEL_ID,
|
@@ -35,7 +33,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
35 |
device_map="auto"
|
36 |
).eval()
|
37 |
|
38 |
-
#
|
39 |
PAD = tokenizer.pad_token_id or tokenizer.eos_token_id
|
40 |
EOT = tokenizer.convert_tokens_to_ids('<end_of_turn>')
|
41 |
|
@@ -68,10 +66,9 @@ class PodcastGenerator:
|
|
68 |
|
69 |
full_prompt = system_prompt + "\n\n" + user_prompt
|
70 |
|
71 |
-
# sync generation in executor
|
72 |
def gen_sync():
|
73 |
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
|
74 |
-
# add stopping criteria
|
75 |
stop_crit = StoppingCriteriaList([StoppingCriteria(max_length=512)])
|
76 |
outputs = model.generate(
|
77 |
**inputs,
|
@@ -139,4 +136,4 @@ def run_app():
|
|
139 |
demo.launch()
|
140 |
|
141 |
if __name__ == '__main__':
|
142 |
-
run_app()
|
|
|
15 |
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
16 |
|
17 |
# Configuration
|
|
|
18 |
MODEL_ID = os.getenv("GEMMA_MODEL_PATH", "tabularisai/german-gemma-3-1b-it")
|
19 |
+
HF_TOKEN = os.getenv("Tokentest") # Optional
|
|
|
20 |
|
21 |
+
# Load the tokenizer and model for MODEL_ID (loading code taken from an external snippet)
|
22 |
print(f"Loading model {MODEL_ID}...")
|
23 |
tokenizer = AutoTokenizer.from_pretrained(
|
24 |
MODEL_ID,
|
|
|
33 |
device_map="auto"
|
34 |
).eval()
|
35 |
|
36 |
+
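# Token IDs for stopping criteria: PAD falls back to the EOS id when pad_token_id is None or 0; EOT is the id of '<end_of_turn>'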
|
37 |
PAD = tokenizer.pad_token_id or tokenizer.eos_token_id
|
38 |
EOT = tokenizer.convert_tokens_to_ids('<end_of_turn>')
|
39 |
|
|
|
66 |
|
67 |
full_prompt = system_prompt + "\n\n" + user_prompt
|
68 |
|
69 |
+
# Synchronous generation via model.generate, meant to be run in an executor
|
70 |
def gen_sync():
|
71 |
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
|
|
|
72 |
stop_crit = StoppingCriteriaList([StoppingCriteria(max_length=512)])
|
73 |
outputs = model.generate(
|
74 |
**inputs,
|
|
|
136 |
demo.launch()
|
137 |
|
138 |
if __name__ == '__main__':
|
139 |
+
run_app()
|