"""Gradio demo: guess which SynthID Text generations are watermarked."""

from collections.abc import Sequence
import random
from typing import Optional

import gradio as gr
import spaces
import torch
import transformers

# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy), or it
# could be that the text simply was not generated with the watermark.

_MODEL_IDENTIFIER = 'hf-internal-testing/tiny-random-gpt2'

_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

# Maps each generated response to whether it was watermarked.
_CORRECT_ANSWERS: dict[str, bool] = {}

_TORCH_DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# SynthID Text watermarking configuration; the same configuration must be used
# for generation and, later, for detection.
_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    ngram_len=5,
    keys=[
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
# GPT-2 has no pad token; reuse EOS so batched prompts can be padded.
tokenizer.pad_token_id = tokenizer.eos_token_id
model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
model.to(_TORCH_DEVICE)


@spaces.GPU
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> Sequence[str]:
    """Generates responses, watermarked iff a watermarking config is given."""
    # Padding is required to batch prompts of different lengths.
    tokenized_prompts = tokenizer(
        prompts, return_tensors='pt', padding=True
    ).to(_TORCH_DEVICE)
    output_sequences = model.generate(
        **tokenized_prompts,
        watermarking_config=watermarking_config,
        do_sample=True,
        max_length=500,
        top_k=40,
    )
    return tokenizer.batch_decode(output_sequences)


with gr.Blocks() as demo:
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt') for prompt in _PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selections are '
                'marked as correct or incorrect in the text.'
            ),
        )
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts):
        # Generate an unwatermarked and a watermarked response for each prompt,
        # then shuffle them so the order gives nothing away.
        standard = generate_outputs(prompts=prompts)
        watermarked = generate_outputs(
            prompts=prompts,
            watermarking_config=_WATERMARK_CONFIG,
        )
        responses = standard + watermarked
        random.shuffle(responses)

        # Record the ground truth for grading in reveal().
        _CORRECT_ANSWERS.update({
            response: response in watermarked for response in responses
        })

        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(choices=responses),
            reveal_btn: gr.Button(visible=True),
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[generate_btn, generations_col, generations_grp, reveal_btn],
    )

    def reveal(user_selections: list[str]):
        # Grade the user's selections against the recorded ground truth, and
        # check the generations that were actually watermarked.
        choices: list[str] = []
        value: list[str] = []

        for response, is_watermarked in _CORRECT_ANSWERS.items():
            if is_watermarked and response in user_selections:
                choice = f'Correct! {response}'
            elif not is_watermarked and response not in user_selections:
                choice = f'Correct! {response}'
            else:
                choice = f'Incorrect. {response}'

            choices.append(choice)
            if is_watermarked:
                value.append(choice)

        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=generations_grp,
        outputs=[reveal_btn, detections_col, revealed_grp, detect_btn],
    )
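
# The Detect button is revealed by reveal() but is not wired to a handler in
# this file. The commented-out sketch below shows one way it could be wired,
# using the SynthID Text detector classes shipped with `transformers` (4.46+)
# and assuming a trained Bayesian detector checkpoint is available.
# `_DETECTOR_IDENTIFIER` is a hypothetical placeholder, not a real checkpoint,
# and the detector's output shape should be checked against the installed
# `transformers` version before use.
#
#   _DETECTOR_IDENTIFIER = 'your-org/your-bayesian-detector'  # hypothetical
#
#   detector_module = transformers.BayesianDetectorModel.from_pretrained(
#       _DETECTOR_IDENTIFIER
#   ).to(_TORCH_DEVICE)
#   # The logits processor must mirror _WATERMARK_CONFIG exactly.
#   logits_processor = transformers.generation.SynthIDTextWatermarkLogitsProcessor(
#       ngram_len=_WATERMARK_CONFIG.ngram_len,
#       keys=_WATERMARK_CONFIG.keys,
#       sampling_table_size=_WATERMARK_CONFIG.sampling_table_size,
#       sampling_table_seed=_WATERMARK_CONFIG.sampling_table_seed,
#       context_history_size=_WATERMARK_CONFIG.context_history_size,
#       device=_TORCH_DEVICE,
#   )
#   detector = transformers.SynthIDTextWatermarkDetector(
#       detector_module, logits_processor, tokenizer
#   )
#
#   def detect(responses: list[str]) -> list[str]:
#       inputs = tokenizer(responses, return_tensors='pt', padding=True)
#       # Higher scores indicate a higher likelihood the text is watermarked.
#       scores = detector(inputs.input_ids.to(_TORCH_DEVICE))
#       return [f'{score:.3f}' for score in scores.flatten().tolist()]
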
if __name__ == '__main__':
    demo.launch()