Rausda6 commited on
Commit
84117d5
·
verified ·
1 Parent(s): 4211f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -15,12 +15,10 @@ import torch
15
  from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
16
 
17
  # Configuration
18
- # Use this MODEL_ID, adjust if you have a local path instead
19
  MODEL_ID = os.getenv("GEMMA_MODEL_PATH", "tabularisai/german-gemma-3-1b-it")
20
- # Hugging Face token secret (optional, for gated/private models)
21
- HF_TOKEN = os.getenv("Tokentest")
22
 
23
- # Load tokenizer and model
24
  print(f"Loading model {MODEL_ID}...")
25
  tokenizer = AutoTokenizer.from_pretrained(
26
  MODEL_ID,
@@ -35,7 +33,7 @@ model = AutoModelForCausalLM.from_pretrained(
35
  device_map="auto"
36
  ).eval()
37
 
38
- # Optional: set up a simple stopping criteria on <end_of_turn> token
39
  PAD = tokenizer.pad_token_id or tokenizer.eos_token_id
40
  EOT = tokenizer.convert_tokens_to_ids('<end_of_turn>')
41
 
@@ -68,10 +66,9 @@ class PodcastGenerator:
68
 
69
  full_prompt = system_prompt + "\n\n" + user_prompt
70
 
71
- # sync generation in executor
72
  def gen_sync():
73
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
74
- # add stopping criteria
75
  stop_crit = StoppingCriteriaList([StoppingCriteria(max_length=512)])
76
  outputs = model.generate(
77
  **inputs,
@@ -139,4 +136,4 @@ def run_app():
139
  demo.launch()
140
 
141
  if __name__ == '__main__':
142
- run_app()
 
15
  from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
16
 
17
  # Configuration
 
18
  MODEL_ID = os.getenv("GEMMA_MODEL_PATH", "tabularisai/german-gemma-3-1b-it")
19
+ HF_TOKEN = os.getenv("Tokentest") # Optional
 
20
 
21
+ # Load tokenizer and model using external snippet
22
  print(f"Loading model {MODEL_ID}...")
23
  tokenizer = AutoTokenizer.from_pretrained(
24
  MODEL_ID,
 
33
  device_map="auto"
34
  ).eval()
35
 
36
+ # Stopping criteria tokens
37
  PAD = tokenizer.pad_token_id or tokenizer.eos_token_id
38
  EOT = tokenizer.convert_tokens_to_ids('<end_of_turn>')
39
 
 
66
 
67
  full_prompt = system_prompt + "\n\n" + user_prompt
68
 
69
+ # sync generation in executor using model.generate
70
  def gen_sync():
71
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
72
  stop_crit = StoppingCriteriaList([StoppingCriteria(max_length=512)])
73
  outputs = model.generate(
74
  **inputs,
 
136
  demo.launch()
137
 
138
  if __name__ == '__main__':
139
+ run_app()