Spaces:
Sleeping
Sleeping
import plotly.graph_objects as go | |
import textwrap | |
import re | |
from collections import defaultdict | |
def generate_subplot1(paraphrased_sentence, masked_sentences, strategies=None, highlight_info=None, common_grams=None):
    """
    Generates a subplot visualizing paraphrased and masked sentences in a tree structure.

    The paraphrased sentence is the single root node (level 0) and each masked
    sentence is a child node (level 1) joined to the root by a labelled edge.
    Common n-grams are prefixed with their LCS number and highlight words are
    colored.

    Args:
        paraphrased_sentence (str): The paraphrased sentence drawn as the root node.
        masked_sentences (str or list of str): Masked sentence(s) drawn as child
            nodes. A single string is treated as a one-element list.
        strategies (list of str, optional): Edge label for each masked sentence, in
            order. Edges past the end of this list (all edges when it is None or
            empty) use a built-in default label instead.
        highlight_info (list of tuples, optional): (word, color) pairs. Whole-word,
            case-insensitive matches of ``word`` are rendered in ``color``; the
            sentence keeps its own casing.
        common_grams (list of tuples, optional): (index, phrase) pairs. Every
            whole-word, case-sensitive occurrence of ``phrase`` is prefixed with
            "(index)".

    Returns:
        plotly.graph_objects.Figure: The tree figure, or an empty figure when there
        are no masked sentences.
    """
    if isinstance(masked_sentences, str):
        masked_sentences = [masked_sentences]
    masked_sentences = list(masked_sentences or [])
    if not masked_sentences:
        # The tree needs the root plus at least one child.
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()
    strategies = strategies or []
    common_grams = common_grams or []

    # Case-insensitive word -> color lookup; a later duplicate wins, as with dict().
    color_lookup = {
        str(word).lower(): color
        for word, color in dict(highlight_info or []).items()
        if word  # an empty word would insert markers all over the sentence
    }
    highlight_re = None
    if color_lookup:
        # One alternation (longest entries first) wraps each match exactly once, so a
        # short entry is never wrapped a second time inside a longer phrase.
        alternatives = sorted(color_lookup, key=len, reverse=True)
        highlight_re = re.compile(
            r"(?<!\w)(?:" + "|".join(map(re.escape, alternatives)) + r")(?!\w)",
            re.IGNORECASE,
        )

    def apply_lcs_numbering(sentence):
        """Prefix every whole-word occurrence of each common gram with "(index)"."""
        for idx, gram in common_grams:
            if not gram:
                continue  # an empty gram would insert the prefix all over the sentence
            # Grams are literal text, so regex metacharacters (e.g. "dog.") are escaped.
            # The look-arounds act like \b around word characters but also behave
            # sensibly when a gram starts or ends with punctuation.
            pattern = rf"(?<!\w){re.escape(gram)}(?!\w)"
            sentence = re.sub(pattern, lambda m, tag=f"({idx})": tag + m.group(0), sentence)
        return sentence

    def highlight_words(sentence):
        """Wrap highlight matches in {{...}} markers, keeping the matched casing."""
        if highlight_re is None:
            return sentence
        return highlight_re.sub(lambda m: "{{" + m.group(0) + "}}", sentence)

    def color_highlighted_words(text):
        """Replace {{...}} markers with colored <span> elements."""
        pieces = []
        for part in re.split(r'(\{\{.*?\}\})', text):
            marker = re.match(r'\{\{(.*?)\}\}', part)
            if marker is None:
                pieces.append(part)
                continue
            word = marker.group(1)
            # textwrap may have replaced the space inside a multi-word match with <br>.
            color = color_lookup.get(word.replace("<br>", " ").lower(), 'black')
            pieces.append(f"<span style='color: {color};'>{word}</span>")
        return ''.join(pieces)

    sentences = [paraphrased_sentence] + masked_sentences
    node_texts = [
        color_highlighted_words(
            '<br>'.join(textwrap.wrap(highlight_words(apply_lcs_numbering(sentence)), width=55))
        )
        for sentence in sentences
    ]

    # Node 0 (the paraphrase) is the root at level 0; every masked sentence is a
    # level-1 child connected directly to it.
    levels = [0] + [1] * len(masked_sentences)
    edges = [(0, child) for child in range(1, len(sentences))]
    max_level = max(levels)

    # Levels are laid out left to right; nodes within a level are centred around y=0.
    x_gap = 2
    y_gap = 10
    level_sizes = defaultdict(int)
    for level in levels:
        level_sizes[level] += 1
    next_slot = {level: -(size - 1) / 2 for level, size in level_sizes.items()}
    positions = []
    for level in levels:
        positions.append((level * x_gap, next_slot[level] * y_gap))
        next_slot[level] += 1

    # Fallback edge labels, used when no strategy is supplied for an edge.
    default_edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Tournament Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Tournament Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling", "Tournament Sampling",
        "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
    ]

    fig1 = go.Figure()

    # Nodes: a marker plus a bordered text box holding the highlighted sentence.
    for (x, y), text in zip(positions, node_texts):
        fig1.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=x,
            y=y,
            text=text,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )

    # Edges: a line from root to child with the strategy name above its midpoint.
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]
        if i < len(strategies):
            edge_text = strategies[i]
        else:
            edge_text = default_edge_texts[i % len(default_edge_texts)]
        fig1.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        fig1.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.8,  # lift the label above the line
            text=edge_text,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig1.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=800 + max_level * 200,  # room for each additional level
        height=300 + len(sentences) * 100,  # room for each node
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )
    return fig1
def generate_subplot2(masked_sentences, sampled_sentences, highlight_info=None, common_grams=None):
    """
    Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.

    Each masked sentence is a root node (level 0) with one child (level 1) per
    sampling technique. ``sampled_sentences`` is consumed in order: the first
    ``len(sampling_techniques)`` entries belong to the first masked sentence, the
    next block to the second, and so on, each block following the order of the
    technique list defined below.

    Args:
        masked_sentences (str or list of str): Masked sentence(s) drawn as root nodes.
            A single string is treated as a one-element list.
        sampled_sentences (str or list of str): Sampled sentences grouped per masked
            sentence as described above. Missing entries are padded with placeholder
            text and surplus entries are ignored. The caller's list is not modified.
        highlight_info (list of tuples, optional): (word, color) pairs. Whole-word,
            case-insensitive matches are rendered in that color with the sentence's
            own casing kept.
        common_grams (list of tuples, optional): (index, phrase) pairs. Every
            whole-word, case-sensitive occurrence of the phrase is prefixed with
            "(index)".

    Returns:
        plotly.graph_objects.Figure: The tree figure, or an empty figure when there
        are no masked sentences.
    """
    # Edge labels, in the order children are attached to each root.
    sampling_techniques = [
        "Inverse Transform Sampling",
        "Exponential Minimum Sampling",
        "Temperature Sampling",
        "Greedy Sampling",
        "Tournament Sampling",
    ]

    if isinstance(masked_sentences, str):
        masked_sentences = [masked_sentences]
    if isinstance(sampled_sentences, str):
        sampled_sentences = [sampled_sentences]
    masked_sentences = list(masked_sentences or [])
    # Work on a copy so the padding below never appends to the caller's list.
    sampled_sentences = list(sampled_sentences or [])
    common_grams = common_grams or []

    if not masked_sentences:
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()

    num_masked = len(masked_sentences)
    num_sampled_per_masked = len(sampling_techniques)
    expected_sampled_count = num_masked * num_sampled_per_masked
    if len(sampled_sentences) < expected_sampled_count:
        print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
        sampled_sentences += [
            f"Placeholder sampled sentence {n}"
            for n in range(len(sampled_sentences) + 1, expected_sampled_count + 1)
        ]
    # Surplus samples have no parent slot and are dropped.
    sampled_sentences = sampled_sentences[:expected_sampled_count]

    # Case-insensitive word -> color lookup; a later duplicate wins, as with dict().
    color_lookup = {
        str(word).lower(): color
        for word, color in dict(highlight_info or []).items()
        if word  # an empty word would insert markers all over the sentence
    }
    highlight_re = None
    if color_lookup:
        # One alternation (longest entries first) wraps each match exactly once, so a
        # short entry is never wrapped a second time inside a longer phrase.
        alternatives = sorted(color_lookup, key=len, reverse=True)
        highlight_re = re.compile(
            r"(?<!\w)(?:" + "|".join(map(re.escape, alternatives)) + r")(?!\w)",
            re.IGNORECASE,
        )

    def apply_lcs_numbering(sentence):
        """Prefix every whole-word occurrence of each common gram with "(index)"."""
        for idx, gram in common_grams:
            if not gram:
                continue  # an empty gram would insert the prefix all over the sentence
            # Grams are literal text, so regex metacharacters (e.g. "dog.") are escaped.
            pattern = rf"(?<!\w){re.escape(gram)}(?!\w)"
            sentence = re.sub(pattern, lambda m, tag=f"({idx})": tag + m.group(0), sentence)
        return sentence

    def highlight_words(sentence):
        """Wrap highlight matches in {{...}} markers, keeping the matched casing."""
        if highlight_re is None:
            return sentence
        return highlight_re.sub(lambda m: "{{" + m.group(0) + "}}", sentence)

    def color_highlighted_words(text):
        """Replace {{...}} markers with colored <span> elements."""
        pieces = []
        for part in re.split(r'(\{\{.*?\}\})', text):
            marker = re.match(r'\{\{(.*?)\}\}', part)
            if marker is None:
                pieces.append(part)
                continue
            word = marker.group(1)
            # textwrap may have replaced the space inside a multi-word match with <br>.
            color = color_lookup.get(word.replace("<br>", " ").lower(), 'black')
            pieces.append(f"<span style='color: {color};'>{word}</span>")
        return ''.join(pieces)

    # Node order: all masked sentences (roots), then the sampled sentences (children).
    sentences = masked_sentences + sampled_sentences
    node_texts = [
        color_highlighted_words(
            '<br>'.join(textwrap.wrap(highlight_words(apply_lcs_numbering(sentence)), width=80))
        )
        for sentence in sentences
    ]
    levels = [0] * num_masked + [1] * expected_sampled_count

    # Child index for root m and technique t is num_masked + m * num_sampled_per_masked + t;
    # the technique name is carried on the edge so no label lookup is needed later.
    edges = [
        (masked_idx, num_masked + masked_idx * num_sampled_per_masked + technique_idx, technique)
        for masked_idx in range(num_masked)
        for technique_idx, technique in enumerate(sampling_techniques)
    ]

    # Roots are stacked at x=0 and centred around y=0; each root's children sit at
    # x=3, centred vertically on their parent.
    root_x = 0
    root_y_spacing = 8.0
    sampled_x = 3
    children_y_spacing = 1.5
    root_y_start = -(num_masked - 1) * root_y_spacing / 2
    root_ys = [root_y_start + i * root_y_spacing for i in range(num_masked)]
    positions = [(root_x, root_y) for root_y in root_ys]
    for root_y in root_ys:
        children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2
        positions.extend(
            (sampled_x, children_y_start + technique_idx * children_y_spacing)
            for technique_idx in range(num_sampled_per_masked)
        )

    fig2 = go.Figure()

    # Nodes: blue markers for roots, green for samples, each with a text box.
    for (x, y), level, text in zip(positions, levels, node_texts):
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color='blue' if level == 0 else 'green', line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        fig2.add_annotation(
            x=x,
            y=y,
            text=text,
            showarrow=False,
            xshift=15,
            align="left",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=450,
            height=100
        )

    # Edges: a line from root to sample with the technique name near its midpoint.
    for src, dst, technique_label in edges:
        x0, y0 = positions[src]
        x1, y1 = positions[dst]
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        fig2.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.1,  # slight lift so the label does not sit on the line
            text=technique_label,
            showarrow=False,
            font=dict(size=8),
            align="center"
        )

    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=2000,
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )
    return fig2
if __name__ == "__main__":
    # Small demo: one paraphrase with two masked variants, then their sampled variants.
    paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
    masked_sentences = [
        "A fast brown fox leaps over the lazy dog.",
        "A quick brown fox hops over a lazy dog."
    ]
    highlight_info = [
        ("quick", "red"),
        ("brown", "green"),
        ("fox", "blue"),
        ("lazy", "purple")
    ]
    common_grams = [
        (1, "quick brown fox"),
        (2, "lazy dog")
    ]
    # strategies is generate_subplot1's third positional parameter. Passing only four
    # positional arguments shifted highlight_info/common_grams by one slot and left
    # common_grams missing (TypeError), so the remaining arguments go by keyword.
    fig1 = generate_subplot1(
        paraphrased_sentence,
        masked_sentences,
        strategies=None,
        highlight_info=highlight_info,
        common_grams=common_grams,
    )
    fig1.show()
    # Fewer samples than masked_sentences x techniques: generate_subplot2 pads the rest.
    sampled_sentence = ["A fast brown fox jumps over a lazy dog."]
    fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
    fig2.show()