Spaces: Running on A10G
Commit 34e65b0 · Parent: 5ac7b67

limit to one small model that fits in 24gb vram

Files changed:
- app.py (+9, -9)
- demo_watermark.py (+25, -17)
app.py CHANGED

@@ -24,17 +24,17 @@ arg_dict = {
     # 'model_name_or_path': 'facebook/opt-2.7b', # historical
     # 'model_name_or_path': 'facebook/opt-6.7b', # historical
     # 'model_name_or_path': 'meta-llama/Llama-2-7b-hf', # historical
-    'model_name_or_path': 'meta-llama/Llama-3.
+    'model_name_or_path': 'meta-llama/Llama-3.2-3B',
     'all_models':[
-        "meta-llama/Llama-3.1-8B",
+        # "meta-llama/Llama-3.1-8B", # too big for the A10G 24GB
         "meta-llama/Llama-3.2-3B",
-        "meta-llama/Llama-3.2-1B",
-        "Qwen/Qwen3-8B",
-        "Qwen/Qwen3-4B",
-        "Qwen/Qwen3-1.7B",
-        "Qwen/Qwen3-0.6B",
-        "Qwen/Qwen3-4B-Instruct-2507",
-        "Qwen/Qwen3-4B-Thinking-2507",
+        # "meta-llama/Llama-3.2-1B",
+        # "Qwen/Qwen3-8B", # too big for the A10G 24GB
+        # "Qwen/Qwen3-4B",
+        # "Qwen/Qwen3-1.7B",
+        # "Qwen/Qwen3-0.6B",
+        # "Qwen/Qwen3-4B-Instruct-2507",
+        # "Qwen/Qwen3-4B-Thinking-2507",
         ],
     # 'load_fp16' : True,
     'load_fp16' : False,
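The `# too big for the A10G 24GB` comments come down to simple arithmetic: with `'load_fp16': False` the weights load in full precision at 4 bytes per parameter, so an 8B model needs roughly 32 GB for the weights alone while a 3B model needs about 12 GB, which is why only Llama-3.2-3B survives the cut. A rough back-of-the-envelope check (a sketch only; the parameter counts and the 20% overhead margin are assumptions, and KV cache and activations add more on top):

    # Rough VRAM estimate for model weights only.
    BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}

    def fits_in_vram(params_billion, dtype="fp32", vram_gb=24.0, overhead=1.2):
        """True if the weights (plus a fudge factor for buffers) fit in the given VRAM."""
        weight_gb = params_billion * BYTES_PER_PARAM[dtype]  # 1e9 params * N bytes = N GB
        return weight_gb * overhead <= vram_gb

    print(fits_in_vram(3.2, "fp32"))   # ~12.8 GB of weights -> fits on a 24 GB A10G
    print(fits_in_vram(8.0, "fp32"))   # ~32 GB of weights   -> does not fit
    print(fits_in_vram(8.0, "fp16"))   # ~16 GB of weights   -> fits, but only if fp16 loading is enabled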
demo_watermark.py CHANGED

@@ -19,6 +19,8 @@ import argparse
 from pprint import pprint
 from functools import partial
 
+import gc
+
 import numpy # for gradio hot reload
 import gradio as gr
 
@@ -206,9 +208,11 @@ def load_model(args):
         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
     elif args.is_decoder_only_model:
         if args.load_fp16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16)
         elif args.load_bf16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16)
         else:
             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
     else:
@@ -216,12 +220,18 @@ def load_model(args):
 
     if args.use_gpu:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if args.load_fp16 or args.load_bf16:
-            pass
-        else:
-            model = model.to(device)
+        # if args.load_fp16 or args.load_bf16:
+        #     pass
+        # else:
+        model = model.to(device)
     else:
         device = "cpu"
+
+    if args.load_bf16:
+        model = model.to(torch.bfloat16)
+    if args.load_fp16:
+        model = model.to(torch.float16)
+
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
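Taken together, the two load_model hunks above drop accelerate's `device_map='auto'` placement in favour of loading the weights in the requested dtype and moving the whole model to a single device, with an extra dtype cast at the end as a catch-all. A minimal sketch of the resulting load path, using a hypothetical `Namespace` in place of the demo's parsed arguments:

    from argparse import Namespace

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical stand-in for the demo's argparse namespace.
    args = Namespace(model_name_or_path="meta-llama/Llama-3.2-3B",
                     load_fp16=False, load_bf16=False, use_gpu=True)

    dtype = (torch.float16 if args.load_fp16
             else torch.bfloat16 if args.load_bf16
             else torch.float32)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=dtype)

    # Single-device placement instead of device_map='auto' sharding.
    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

Since everything now lives on one device, a model that does not fit in the A10G's 24 GB fails outright instead of being silently offloaded, which is consistent with trimming `all_models` down to the 3B checkpoint.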
@@ -268,7 +278,7 @@ def generate_with_api(prompt, args):
     yield all_without_words, all_with_words
 
 
-def check_prompt(prompt, args, tokenizer, model, device=None):
+def check_prompt(prompt, args, tokenizer, model=None, device=None):
 
     # This applies to both the local and API model scenarios
     if args.model_name_or_path in API_MODEL_MAP:
@@ -288,7 +298,7 @@ def check_prompt(prompt, args, tokenizer, model, device=None):
 
 
 
-def generate(prompt, args, tokenizer, model, device=None):
+def generate(prompt, args, tokenizer, model=None, device=None):
     """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
        and generate watermarked text by passing it to the generate method of the model
        as a logits processor. """
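The `generate` docstring describes the core mechanism: the watermark is injected by passing a logits processor to the model's `generate` call. A stripped-down sketch of that pattern with the standard transformers plumbing; the `WatermarkLogitsProcessor` import path and constructor arguments shown here are assumptions based on the repo layout, not a copy of the demo's code:

    from transformers import LogitsProcessorList
    from watermark_processor import WatermarkLogitsProcessor  # assumed local module

    # gamma/delta/seeding_scheme values are illustrative, not the session's settings.
    watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                   gamma=0.25,
                                                   delta=2.0,
                                                   seeding_scheme="simple_1")

    inputs = tokenizer("Write a short story about a robot.", return_tensors="pt").to(device)
    output = model.generate(**inputs,
                            max_new_tokens=200,
                            do_sample=True,
                            logits_processor=LogitsProcessorList([watermark_processor]))
    print(tokenizer.decode(output[0], skip_special_tokens=True))

Here `model`, `tokenizer`, and `device` are assumed to come from a `load_model`-style setup such as the sketch above.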
@@ -486,11 +496,10 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
     default_prompt = args.__dict__.pop("default_prompt")
     session_args = gr.State(value=args)
     # note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup
-    session_tokenizer = gr.State(value=lambda : tokenizer)
-    session_model = gr.State(value=lambda : model)
+    session_tokenizer = gr.State(value=lambda : tokenizer)
 
-    check_prompt_partial = partial(check_prompt, device=device)
-    generate_partial = partial(generate, device=device)
+    check_prompt_partial = partial(check_prompt, model=model, device=device)
+    generate_partial = partial(generate, model=model, device=device)
     detect_partial = partial(detect, device=device)
 
     with gr.Tab("Welcome"):
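With `session_model` removed from the Gradio state, the loaded model and device now reach the callbacks by being bound into `functools.partial` at startup rather than shipped through `gr.State` on every event. A small sketch of the binding pattern with hypothetical stand-ins:

    from functools import partial

    def check_prompt(prompt, args, tokenizer, model=None, device=None):
        """Hypothetical callback: only prompt/args/tokenizer arrive from the UI."""
        return f"checked on {device}: {prompt}"

    model, device = object(), "cuda"   # stand-ins for the objects loaded at startup

    # Bind the heavyweight objects once; Gradio only passes the lightweight inputs.
    check_prompt_partial = partial(check_prompt, model=model, device=device)
    print(check_prompt_partial("a prompt", args=None, tokenizer=None))

Because the model is captured when the partial is created, there is no per-session model to refresh, which is why the `.then(update_model, ...)` step is dropped from the model-selector chain later in this diff.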
@@ -704,8 +713,8 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         """)
 
         # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag, then call detection
-        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer
-                           fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer
+        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
+                           fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
                            fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success(
                            fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
         # Show truncated version of prompt if truncation occurred
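The rebuilt click wiring chains each stage with `.success(...)`, so generation only runs if the prompt check completed without error, and each detection pass only runs after a successful generation. A stripped-down sketch of the same chaining pattern with placeholder components and functions (not the demo's actual components):

    import gradio as gr

    def check(p):
        return p.strip()

    def gen(p):
        return p + " ...generated text"

    def detect(t):
        return f"tokens scored: {len(t.split())}"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="prompt")
        redecoded = gr.Textbox(label="redecoded prompt")
        output = gr.Textbox(label="output")
        verdict = gr.Textbox(label="detection")
        btn = gr.Button("Generate")

        # Each stage fires only if the previous handler finished successfully.
        btn.click(fn=check, inputs=[prompt], outputs=[redecoded]).success(
            fn=gen, inputs=[redecoded], outputs=[output]).success(
            fn=detect, inputs=[output], outputs=[verdict])

    demo.queue()
    # demo.launch()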
@@ -781,6 +790,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         def update_model(state, old_model):
             del old_model
             torch.cuda.empty_cache()
+            gc.collect()
             model, _, _ = load_model(state)
             return model
 
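The added `gc.collect()` is the substantive part of this hunk: `torch.cuda.empty_cache()` can only hand back allocator blocks that are already free, so if the old model is still reachable through a reference cycle or a lingering state object, its weights stay resident and the next `load_model` can OOM a 24 GB card. A sketch of the swap sequence in isolation (the helper name is hypothetical):

    import gc

    import torch
    from transformers import AutoModelForCausalLM

    def swap_model(old_model, new_name):
        """Release the old weights before loading the replacement."""
        del old_model              # drop this function's reference to the old weights
        torch.cuda.empty_cache()   # return already-freed cached blocks to the driver
        gc.collect()               # collect anything still held via reference cycles
        return AutoModelForCausalLM.from_pretrained(new_name)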
@@ -803,8 +813,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             update_model_state,inputs=[session_args, model_selector], outputs=[session_args]
         ).then(
             update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]
-        ).then(
-            update_model,inputs=[session_args, session_model], outputs=[session_model]
         ).then(
             lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
         )
@@ -852,7 +860,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
         select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
 
-
+
     demo.queue()
 
     if args.demo_public:
|