Spaces:
Sleeping
Sleeping
import gradio as gr | |
from utils.watermark import Watermarker | |
from utils.config import load_config | |
from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html | |
from renderers.tree import generate_subplot1, generate_subplot2 | |
from pathlib import Path | |
import time | |
from typing import Dict, List, Tuple, Any | |
import plotly.graph_objects as go | |
class WatermarkerInterface:
    """Adapter that drives the ``Watermarker`` pipeline for the Gradio UI.

    Each ``handle_*`` method runs one pipeline step, records the bracketed
    index labels it assigns to sentences so that later steps can refer back
    to them, and returns a fixed-length flat list (or tuple) with one value
    per output component.
    """

    # Number of output slots each handler fills. Results are padded or
    # truncated to these sizes so the returned list always has the same length.
    MAX_TREE_PLOTS = 10
    MAX_REPARAPHRASE_TABS = 150

    # Colors cycled through when pairing each common n-gram with a highlight.
    HIGHLIGHT_COLORS = ("red", "blue", "purple", "green", "orange")

    def __init__(self, config):
        """Create the pipeline from ``config`` and initialise the cached state."""
        self.pipeline = Watermarker(config)
        self.common_grams = {}
        self.highlight_info = []
        self.masked_sentences = []
        # Index bookkeeping shared between the steps:
        # original sentence -> {"<strategy>_<masked sentence>": {index, strategy, masked_sentence}}
        self.masked_sentence_indices = {}
        # masked sentence -> {sampled sentence: {index, strategy}}
        self.sampled_sentence_indices = {}
        # sampled sentence -> {re-paraphrased sentence: index label}
        self.reparaphrased_indices = {}

    @staticmethod
    def _execution_time_html(step: int, start_time: float) -> str:
        """Return the timing banner for ``step``, measured from ``start_time``."""
        elapsed = time.time() - start_time
        return f"<div class='execution-time'>Step {step} completed in {elapsed:.2f} seconds</div>"

    @staticmethod
    def _fit_to_slots(items: List[Any], slots: int, make_filler) -> List[Any]:
        """Return ``items`` truncated, or padded with ``make_filler()``, to ``slots`` entries."""
        # range() of a negative number is empty, so over-long input is only truncated.
        padding = [make_filler() for _ in range(slots - len(items))]
        return items[:slots] + padding

    def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
        """Run paraphrasing for ``prompt`` and render the highlighted results.

        Common n-grams of the original prompt are numbered 1..N in order of
        their first occurrence; the numbering and the color pairing are cached
        on ``self.common_grams`` / ``self.highlight_info`` for later steps.

        Returns:
            ``(highlighted_prompt, highlighted_accepted, highlighted_discarded,
            execution_time_html)``.
        """
        start_time = time.time()
        self.pipeline.Paraphrase(prompt)

        # Step 1: order the original prompt's n-grams by first occurrence.
        original_sentence = self.pipeline.user_prompt
        original_ngrams = self.pipeline.common_grams.get(original_sentence, {})
        ngram_occurrences = sorted(
            (min(indices, key=lambda x: x[0])[0], gram)
            for gram, indices in original_ngrams.items()
        )

        seen_ngrams = {}  # n-gram -> sequential index
        original_indexed_ngrams = []
        for idx, (_, gram) in enumerate(ngram_occurrences, start=1):
            seen_ngrams[gram] = idx
            original_indexed_ngrams.append((idx, gram))
        print("Original Indexed N-grams:", original_indexed_ngrams)

        colors = self.HIGHLIGHT_COLORS
        highlight_info = [
            (gram, colors[i % len(colors)])
            for i, (_, gram) in enumerate(original_indexed_ngrams)
        ]
        common_grams = original_indexed_ngrams
        self.highlight_info = highlight_info
        self.common_grams = common_grams

        # Step 2: index the paraphrases' n-grams, reusing the original's index
        # where the n-gram was already seen and extending the numbering otherwise.
        paraphrase_indexed_ngrams = {}
        for sentence in self.pipeline.paraphrased_sentences:
            sentence_ngrams = []
            for gram in self.pipeline.common_grams.get(sentence, {}):
                if gram not in seen_ngrams:
                    seen_ngrams[gram] = len(seen_ngrams) + 1
                sentence_ngrams.append((seen_ngrams[gram], gram))
            sentence_ngrams.sort()
            paraphrase_indexed_ngrams[sentence] = sentence_ngrams
        print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)

        # Step 3: render the highlighted views.
        highlighted_prompt = highlight_common_words(
            common_grams,
            [self.pipeline.user_prompt],
            "Original Prompt with Highlighted Common Sequences"
        )
        highlighted_accepted = highlight_common_words_dict(
            common_grams,
            self.pipeline.selected_sentences,
            "Accepted Paraphrased Sentences with Entailment Scores"
        )
        highlighted_discarded = highlight_common_words_dict(
            common_grams,
            self.pipeline.discarded_sentences,
            "Discarded Paraphrased Sentences with Entailment Scores"
        )

        execution_time = self._execution_time_html(1, start_time)
        return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time

    def handle_masking(self) -> List[Any]:
        """Run masking and build one tree figure per sentence.

        Every masked variant gets the label ``"<plot><variant>"`` (both
        1-based), stored in ``self.masked_sentence_indices`` under the key
        ``"<strategy>_<masked sentence>"``.

        Returns:
            ``MAX_TREE_PLOTS`` figures (padded with empty figures) followed by
            the execution-time HTML string.
        """
        start_time = time.time()
        masking_results = self.pipeline.Masking()

        # Regroup {strategy: {sentence: data}} into {sentence: [(masked, strategy), ...]}.
        sentence_to_masked = {}
        for strategy, sentence_dict in masking_results.items():
            for sent, data in sentence_dict.items():
                variants = sentence_to_masked.setdefault(sent, [])
                masked_sentence = data.get("masked_sentence", "")
                if masked_sentence:
                    variants.append((masked_sentence, strategy))

        self.masked_sentence_indices = {}
        trees = []
        plot_idx = 1  # only advances for sentences that actually produce a plot
        for original_sentence, variants in sentence_to_masked.items():
            if not variants:
                continue

            sentence_indices = self.masked_sentence_indices.setdefault(original_sentence, {})
            indexed_masked_sentences = []
            strategies = []
            for masked_idx, (masked_sentence, strategy) in enumerate(variants, start=1):
                index = f"{plot_idx}{masked_idx}"
                sentence_indices[f"{strategy}_{masked_sentence}"] = {
                    'index': index,
                    'strategy': strategy,
                    'masked_sentence': masked_sentence
                }
                indexed_masked_sentences.append(f"[{index}] {masked_sentence}")
                strategies.append(strategy)

            # Best effort: a failing plot must not take down the whole step.
            try:
                fig = generate_subplot1(
                    original_sentence,
                    indexed_masked_sentences,
                    strategies,
                    self.highlight_info,
                    self.common_grams
                )
                trees.append(fig)
            except Exception as e:
                print(f"Error generating plot: {e}")
                trees.append(go.Figure())
            plot_idx += 1

        execution_time = self._execution_time_html(2, start_time)
        return self._fit_to_slots(trees, self.MAX_TREE_PLOTS, go.Figure) + [execution_time]

    def handle_sampling(self) -> List[Any]:
        """Run sampling and build one tree figure per original sentence.

        Sampled sentences are labelled ``"<masked index>.<n>"`` and recorded in
        ``self.sampled_sentence_indices``. Only the first three masking
        strategies of each sentence are plotted.

        Returns:
            ``MAX_TREE_PLOTS`` figures (padded with empty figures) followed by
            the execution-time HTML string.
        """
        start_time = time.time()
        sampling_results = self.pipeline.Sampling()
        self.sampled_sentence_indices = {}

        # Regroup {sampling: {masking: {sentence: data}}} into
        # {sentence: {masking: {"masked_sentence": ..., "sampled_sentences": {sampling: ...}}}}.
        organized_results = {}
        for sampling_strategy, masking_dict in sampling_results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sentence, data in sentences.items():
                    per_sentence = organized_results.setdefault(original_sentence, {})
                    entry = per_sentence.setdefault(masking_strategy, {
                        "masked_sentence": data.get("masked_sentence", ""),
                        "sampled_sentences": {}
                    })
                    entry["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")

        trees = []
        for original_sentence, data in organized_results.items():
            masked_indices = self.masked_sentence_indices.get(original_sentence, {})
            indexed_masked_sentences = []
            indexed_sampled_sentences = []

            for masking_strategy, masking_data in list(data.items())[:3]:
                masked_sentence = masking_data.get("masked_sentence", "")

                # Look the index up by the exact key written by handle_masking.
                # A suffix match on the masked sentence alone can pick another
                # strategy's entry when two strategies yield the same text.
                masked_entry = masked_indices.get(f"{masking_strategy}_{masked_sentence}")
                masked_idx = masked_entry['index'] if masked_entry else None

                if masked_sentence:
                    indexed_masked_sentences.append(f"[{masked_idx or ''}] {masked_sentence}")
                if not masked_idx:
                    print(f"Warning: No index found for masked sentence: {masked_sentence}")
                    continue

                sample_count = 1
                for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
                    if not sampled_sentence:
                        continue
                    sample_idx = f"{masked_idx}.{sample_count}"
                    self.sampled_sentence_indices.setdefault(masked_sentence, {})[sampled_sentence] = {
                        'index': sample_idx,
                        'strategy': sampling_strategy
                    }
                    indexed_sampled_sentences.append(f"[{sample_idx}] {sampled_sentence}")
                    sample_count += 1

            if indexed_masked_sentences:
                # Best effort: a failing plot must not take down the whole step.
                try:
                    fig = generate_subplot2(
                        indexed_masked_sentences,
                        indexed_sampled_sentences,
                        self.highlight_info,
                        self.common_grams
                    )
                    trees.append(fig)
                except Exception as e:
                    print(f"Error generating subplot for {original_sentence}: {e}")
                    trees.append(go.Figure())

        print("Sampled sentence indices:", self.sampled_sentence_indices)
        execution_time = self._execution_time_html(3, start_time)
        return self._fit_to_slots(trees, self.MAX_TREE_PLOTS, go.Figure) + [execution_time]

    def handle_reparaphrasing(self) -> List[str]:
        """Run re-paraphrasing and render one HTML fragment per sampled sentence.

        Each re-paraphrase is labelled ``"<tab>.(<sampled index>).<n>"`` and
        recorded in ``self.reparaphrased_indices``. Entries without a sampled
        sentence or without re-paraphrases are skipped.

        Returns:
            ``MAX_REPARAPHRASE_TABS`` HTML strings (padded with ``""``)
            followed by the execution-time HTML string.
        """
        start_time = time.time()
        results = self.pipeline.re_paraphrasing()
        self.reparaphrased_indices = {}

        html_outputs = []
        tab_count = 1
        for sampling_strategy, masking_dict in results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sent, data in sentences.items():
                    sampled_sentence = data.get("sampled_sentence", "")
                    # .get keeps a missing key from raising, matching how
                    # "sampled_sentence" is read above.
                    re_paraphrased = data.get("re_paraphrased_sentences")
                    if not sampled_sentence or not re_paraphrased:
                        continue

                    # First recorded index for this sampled sentence, if any.
                    sampled_index = next(
                        (
                            sampled_dict[sampled_sentence]['index']
                            for sampled_dict in self.sampled_sentence_indices.values()
                            if sampled_sentence in sampled_dict
                        ),
                        None,
                    ) or "unknown"

                    indexed_reparaphrased = []
                    rp_indices = self.reparaphrased_indices.setdefault(sampled_sentence, {})
                    for i, rp_sent in enumerate(re_paraphrased, 1):
                        rp_idx = f"{tab_count}.({sampled_index}).{i}"
                        rp_indices[rp_sent] = rp_idx
                        indexed_reparaphrased.append(f"[{rp_idx}] {rp_sent}")

                    print(f"Reparaphrasing {tab_count}.({sampled_index}): {' '.join(sampled_sentence.split()[:5])}...")
                    html_outputs.append(reparaphrased_sentences_html(indexed_reparaphrased))
                    tab_count += 1

        print("Reparaphrased indices:", self.reparaphrased_indices)
        execution_time = self._execution_time_html(4, start_time)
        return self._fit_to_slots(html_outputs, self.MAX_REPARAPHRASE_TABS, str) + [execution_time]
def create_gradio_interface(config):
    """Assemble the four-step watermarking demo and return the ``gr.Blocks`` app.

    The page has a prompt box followed by one section per step (paraphrasing,
    masking, sampling, re-paraphrasing). Every section has a heading with a
    timing slot, a trigger button and its output components; each button is
    wired to the matching ``WatermarkerInterface`` handler.
    """
    handlers = WatermarkerInterface(config)

    # CSS that makes the sampling and re-paraphrasing tab containers scrollable
    # and styles the execution-time banner.
    custom_css = """
    /* Set fixed height for the reparaphrased tabs container only */
    .gradio-container .tabs[id="reparaphrased-tabs"],
    .gradio-container .tabs[id="sampling-tabs"] {
        overflow-x: hidden;
        white-space: normal;
        border-radius: 8px;
        max-height: 600px; /* Set fixed height for the entire tabs component */
        overflow-y: auto; /* Enable vertical scrolling inside the container */
    }
    /* Tab content styling for reparaphrased and sampling tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
    .gradio-container .tabs[id="sampling-tabs"] .tabitem {
        overflow-x: hidden;
        white-space: normal;
        display: block;
        border-radius: 8px;
    }
    /* Make the tab navigation fixed at the top for scrollable tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
        display: flex;
        overflow-x: auto;
        white-space: nowrap;
        scrollbar-width: thin;
        border-radius: 8px;
        scrollbar-color: #888 #f1f1f1;
        position: sticky;
        top: 0;
        background: white;
        z-index: 100;
    }
    /* Dropdown menu for scrollable tabs styling */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
        position: relative;
        display: inline-block;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
        display: none;
        position: absolute;
        background-color: #f9f9f9;
        min-width: 160px;
        box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
        z-index: 1;
        max-height: 300px;
        overflow-y: auto;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
        display: block;
    }
    /* Scrollbar styling for scrollable tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
        height: 8px;
        border-radius: 8px;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
        background: #f1f1f1;
        border-radius: 8px;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
        background: #888;
        border-radius: 8px;
    }
    /* Tab button styling for scrollable tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
        flex: 0 0 auto;
        border-radius: 8px;
    }
    /* Plot container styling specifically for sampling tabs */
    .gradio-container .tabs[id="sampling-tabs"] .plot-container {
        min-height: 600px;
        max-height: 1800px;
        overflow-y: auto;
    }
    /* Ensure text wraps in HTML components */
    .gradio-container .prose {
        white-space: normal;
        word-wrap: break-word;
        overflow-wrap: break-word;
    }
    /* Dropdown button styling for scrollable tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
        background-color: #f0f0f0;
        border: 1px solid #ddd;
        border-radius: 4px;
        padding: 5px 10px;
        cursor: pointer;
        margin: 2px;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
        background-color: #e0e0e0;
    }
    /* Style dropdown content items for scrollable tabs */
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
        padding: 8px 12px;
        cursor: pointer;
    }
    .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
    .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
        background-color: #e0e0e0;
    }
    /* Custom styling for execution time display */
    .execution-time {
        text-align: right;
        padding: 8px 16px;
        font-family: inherit;
        color: #555;
        font-size: 0.9rem;
        font-style: italic;
        margin-left: auto;
        width: 100%;
        border-top: 1px solid #eee;
        margin-top: 8px;
    }
    /* Layout for section headers with execution time */
    .section-header {
        display: flex;
        justify-content: space-between;
        align-items: center;
        width: 100%;
        margin-bottom: 12px;
    }
    .section-header h3 {
        margin: 0;
    }
    """

    def step_header(title):
        """Place a step heading next to an HTML slot for its timing; return the slot."""
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown(title)
            with gr.Column(scale=1):
                timing_slot = gr.HTML()
        return timing_slot

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        demo.css = custom_css

        gr.Markdown("# **AIISC Watermarking Model**")

        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Type your text here..."
            )

        # Step 1: paraphrasing.
        step1_time = step_header("## Step 1: Paraphrasing, LCS and Entailment Analysis")
        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()

        # Step 2: masking, one plot tab per sentence.
        step2_time = step_header("## Step 2: Where to Mask?")
        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs():
            for tab_no in range(1, 11):
                with gr.TabItem(f"Masked Sentence {tab_no}"):
                    tree1_plots.append(gr.Plot())

        # Step 3: sampling; the elem_id hooks this container into the scroll CSS.
        step3_time = step_header("## Step 3: How to Mask?")
        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")
        tree2_plots = []
        with gr.Tabs(elem_id="sampling-tabs"):
            for tab_no in range(1, 11):
                with gr.TabItem(f"Sampled Sentence {tab_no}"):
                    # The class lets the plot-container CSS rule apply.
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2_plots.append(gr.Plot())

        # Step 4: re-paraphrasing, one HTML tab per batch.
        step4_time = step_header("## Step 4: Re-paraphrasing")
        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs"):
            for tab_no in range(1, 151):
                with gr.TabItem(f"Reparaphrased Batch {tab_no}"):
                    reparaphrased_sentences_tabs.append(gr.HTML())

        # Wire each button to its handler; the timing slot is always the last output.
        paraphrase_button.click(
            handlers.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time,
            ],
        )
        masking_button.click(
            handlers.handle_masking,
            inputs=None,
            outputs=[*tree1_plots, step2_time],
        )
        sampling_button.click(
            handlers.handle_sampling,
            inputs=None,
            outputs=[*tree2_plots, step3_time],
        )
        reparaphrase_button.click(
            handlers.handle_reparaphrasing,
            inputs=None,
            outputs=[*reparaphrased_sentences_tabs, step4_time],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: read the PECCAVI_TEXT section of utils/config.yaml,
    # located two directory levels above this file, and launch the demo.
    project_root = Path(__file__).parent.parent
    config_path = project_root.joinpath("utils", "config.yaml")
    config = load_config(config_path)['PECCAVI_TEXT']
    demo_app = create_gradio_interface(config)
    demo_app.launch()