jwkirchenbauer committed
Commit 34e65b0 · 1 Parent: 5ac7b67

limit to one small model that fits in 24 GB VRAM

Files changed (2):
  1. app.py +9 -9
  2. demo_watermark.py +25 -17
app.py CHANGED
@@ -24,17 +24,17 @@ arg_dict = {
     # 'model_name_or_path': 'facebook/opt-2.7b', # historical
     # 'model_name_or_path': 'facebook/opt-6.7b', # historical
     # 'model_name_or_path': 'meta-llama/Llama-2-7b-hf', # historical
-    'model_name_or_path': 'meta-llama/Llama-3.1-8B',
+    'model_name_or_path': 'meta-llama/Llama-3.2-3B',
     'all_models':[
-        "meta-llama/Llama-3.1-8B",
+        # "meta-llama/Llama-3.1-8B", # too big for the A10G 24GB
         "meta-llama/Llama-3.2-3B",
-        "meta-llama/Llama-3.2-1B",
-        "Qwen/Qwen3-8B",
-        "Qwen/Qwen3-4B",
-        "Qwen/Qwen3-1.7B",
-        "Qwen/Qwen3-0.6B",
-        "Qwen/Qwen3-4B-Instruct-2507",
-        "Qwen/Qwen3-4B-Thinking-2507",
+        # "meta-llama/Llama-3.2-1B",
+        # "Qwen/Qwen3-8B", # too big for the A10G 24GB
+        # "Qwen/Qwen3-4B",
+        # "Qwen/Qwen3-1.7B",
+        # "Qwen/Qwen3-0.6B",
+        # "Qwen/Qwen3-4B-Instruct-2507",
+        # "Qwen/Qwen3-4B-Thinking-2507",
     ],
     # 'load_fp16' : True,
     'load_fp16' : False,
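
A note on the "# too big for the A10G 24GB" markers above: weights alone for an 8B-parameter checkpoint come to roughly 30 GB in fp32 and about 15 GB in fp16/bf16, while a 3B-class model stays near 12 GB / 6 GB, before any activations or KV cache. A quick back-of-the-envelope sketch of that arithmetic (Python; the parameter counts are approximations read off the model names, not measured values):

# Rough VRAM footprint of the weights only; real usage adds activations and KV cache.
BYTES_PER_PARAM = {"fp32": 4, "fp16/bf16": 2}
APPROX_PARAMS_B = {
    "meta-llama/Llama-3.1-8B": 8.0,
    "Qwen/Qwen3-8B": 8.2,
    "meta-llama/Llama-3.2-3B": 3.2,
}

for name, billions in APPROX_PARAMS_B.items():
    for dtype, nbytes in BYTES_PER_PARAM.items():
        weight_gb = billions * 1e9 * nbytes / 1024**3
        print(f"{name:<26} {dtype:<9} ~{weight_gb:4.1f} GB of weights")

On a 24 GB A10G this leaves comfortable headroom only for the 3B-class checkpoint, which matches the single entry kept in all_models.
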
demo_watermark.py CHANGED
@@ -19,6 +19,8 @@ import argparse
 from pprint import pprint
 from functools import partial
 
+import gc
+
 import numpy # for gradio hot reload
 import gradio as gr
 
@@ -206,9 +208,11 @@ def load_model(args):
         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
     elif args.is_decoder_only_model:
         if args.load_fp16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16)
         elif args.load_bf16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16)
         else:
             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
     else:
@@ -216,12 +220,18 @@ def load_model(args):
 
     if args.use_gpu:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if args.load_fp16 or args.load_bf16:
-            pass
-        else:
-            model = model.to(device)
+        # if args.load_fp16 or args.load_bf16:
+        #     pass
+        # else:
+        model = model.to(device)
     else:
         device = "cpu"
+
+    if args.load_bf16:
+        model = model.to(torch.bfloat16)
+    if args.load_fp16:
+        model = model.to(torch.float16)
+
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
@@ -268,7 +278,7 @@ def generate_with_api(prompt, args):
         yield all_without_words, all_with_words
 
 
-def check_prompt(prompt, args, tokenizer, model, device=None):
+def check_prompt(prompt, args, tokenizer, model=None, device=None):
 
     # This applies to both the local and API model scenarios
     if args.model_name_or_path in API_MODEL_MAP:
@@ -288,7 +298,7 @@ def check_prompt(prompt, args, tokenizer, model, device=None):
 
 
 
-def generate(prompt, args, tokenizer, model, device=None):
+def generate(prompt, args, tokenizer, model=None, device=None):
     """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
     and generate watermarked text by passing it to the generate method of the model
     as a logits processor. """
@@ -486,11 +496,10 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
     default_prompt = args.__dict__.pop("default_prompt")
     session_args = gr.State(value=args)
     # note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup
-    session_tokenizer = gr.State(value=lambda : tokenizer)
-    session_model = gr.State(value=lambda : model)
+    session_tokenizer = gr.State(value=lambda : tokenizer)
 
-    check_prompt_partial = partial(check_prompt, device=device)
-    generate_partial = partial(generate, device=device)
+    check_prompt_partial = partial(check_prompt, model=model, device=device)
+    generate_partial = partial(generate, model=model, device=device)
     detect_partial = partial(detect, device=device)
 
     with gr.Tab("Welcome"):
@@ -704,8 +713,8 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         """)
 
     # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag, then call detection
-    generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer, session_model], outputs=[redecoded_input, truncation_warning, session_args]).success(
-        fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer,session_model], outputs=[output_without_watermark, output_with_watermark]).success(
+    generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
+        fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
         fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success(
         fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
     # Show truncated version of prompt if truncation occurred
@@ -781,6 +790,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
     def update_model(state, old_model):
         del old_model
         torch.cuda.empty_cache()
+        gc.collect()
         model, _, _ = load_model(state)
         return model
 
@@ -803,8 +813,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         update_model_state,inputs=[session_args, model_selector], outputs=[session_args]
     ).then(
         update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]
-    ).then(
-        update_model,inputs=[session_args, session_model], outputs=[session_model]
     ).then(
         lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
     )
@@ -852,7 +860,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
     select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
     select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
 
-    # demo.queue(concurrency_count=3)
+
    demo.queue()
 
    if args.demo_public:
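
Condensing the load_model hunks above: the commit drops the accelerate-style device_map='auto' dispatch, loads the checkpoint with an explicit torch_dtype, moves it onto the single device unconditionally, and then re-applies the dtype as a final cast. A minimal sketch of that path, assuming args equivalent to the ones set in app.py (the helper name is illustrative, not from the repo):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_single_gpu_model(model_name_or_path, load_bf16=True, use_gpu=True):
    # explicit dtype at load time instead of device_map='auto' dispatch
    dtype = torch.bfloat16 if load_bf16 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)

    # unconditional move onto the one available device
    device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
    model = model.to(device)

    # the diff re-casts after the move as a belt-and-braces step
    if load_bf16:
        model = model.to(torch.bfloat16)

    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer, device

Without device_map='auto' there is no cross-device or CPU offloading, so the whole model has to fit on the one card, consistent with trimming all_models down to Llama-3.2-3B. The same commit also stops threading the model through a per-session gr.State and instead binds it into check_prompt and generate via functools.partial, keeping the heavyweight object out of Gradio's session state.
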
 
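
On the update_model change: torch.cuda.empty_cache() only releases cached allocator blocks whose tensors have already been freed, so dropping the last Python reference to the old checkpoint, together with the added gc.collect() to sweep up reference cycles, is what actually makes that memory reclaimable before the next load_model call. A minimal sketch of the swap-time teardown (the helper name is illustrative, not from the repo; the diff happens to call empty_cache() before gc.collect(), which also works when plain reference counting frees the tensors at del time):

import gc
import torch

def release_old_model(old_model):
    # drop the last reference so the parameter tensors become collectable
    del old_model
    # break any reference cycles that would keep them alive
    gc.collect()
    # return the now-unused cached blocks so the next checkpoint has room
    torch.cuda.empty_cache()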