Spaces:
Build error
Build error
Upload 3 files
Browse files
- README.md +1 -1
- app.py +90 -29
- tfidfreducedfiles.joblib +3 -0
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 🐿️
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: gray
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Prompt Squirrel
|
| 3 |
emoji: 🐿️
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: gray
|
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import gradio as gr
|
|
| 2 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 3 |
from scipy.sparse import csr_matrix
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
from joblib import load
|
| 6 |
import h5py
|
| 7 |
from io import BytesIO
|
|
@@ -19,6 +20,7 @@ import io
|
|
| 19 |
import os
|
| 20 |
import glob
|
| 21 |
import itertools
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
|
|
@@ -32,7 +34,7 @@ Since Stable Diffusion's initial release in 2022, users have developed a myriad
|
|
| 32 |
Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
|
| 33 |
This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
|
| 34 |
|
| 35 |
-
When you enter a txt2img prompt and press the "submit" button,
|
| 36 |
If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
|
| 37 |
Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
|
| 38 |
This is useful to align your prompt with the expected input to an e621-trained model.
|
|
@@ -114,18 +116,12 @@ See SamplePrompts.csv for the list of prompts used and their descriptions.
|
|
| 114 |
|
| 115 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
#commas: double_comma | comma
|
| 124 |
-
#double_comma: comma WHITESPACE* comma
|
| 125 |
-
#WHITESPACE: /\s+/
|
| 126 |
-
#plain: /([^,\\\[\]():|]|\\.)+/
|
| 127 |
-
#%import common.SIGNED_NUMBER -> NUMBER
|
| 128 |
-
#"""
|
| 129 |
|
| 130 |
grammar=r"""
|
| 131 |
!start: (prompt | /[][():]/+)*
|
|
@@ -353,11 +349,11 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
|
|
| 353 |
return geometric_mean
|
| 354 |
|
| 355 |
|
| 356 |
-
def create_html_tables_for_tags(
|
| 357 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
| 358 |
-
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{
|
| 359 |
# Loop through the results and add table rows for each
|
| 360 |
-
for word, sim in
|
| 361 |
word_with_underscores = word.replace(' ', '_')
|
| 362 |
count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
|
| 363 |
tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
|
|
@@ -379,7 +375,7 @@ def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
|
| 379 |
|
| 380 |
def create_top_artists_table(top_artists):
|
| 381 |
# Add a heading above the table
|
| 382 |
-
html_str = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 383 |
html_str += "<h1>Top Artists</h1>" # Heading for the table
|
| 384 |
# Start the table with increased font size and no borders between rows
|
| 385 |
html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
|
|
@@ -396,16 +392,70 @@ def create_top_artists_table(top_artists):
|
|
| 396 |
return html_str
|
| 397 |
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
    """Return an HTML snippet that reserves space for a panel not yet filled in.

    The snippet is a centered ``<h1>`` heading, an optional centered paragraph,
    and an empty transparent ``<div>`` of the requested size.

    Args:
        title: Text placed inside the ``<h1>`` heading.
        content: Optional text placed in a ``<p>`` under the heading; the
            paragraph block is omitted entirely when this is falsy.
        placeholder_height: Height of the empty block; ``px`` is appended.
        placeholder_width: Width of the empty block, inserted verbatim
            (so it must carry its own CSS unit).

    Returns:
        The concatenated HTML string. Values are interpolated as-is, with no
        HTML escaping.
    """
    heading = f"<div style='text-align: center;'><h1>{title}</h1></div>"
    # The paragraph section only exists when there is something to show.
    body = (
        f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
        if content
        else ""
    )
    spacer = f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
    return "".join((heading, body, spacer))
|
| 408 |
-
|
| 409 |
|
| 410 |
def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
| 411 |
#Initialize stuff
|
|
@@ -425,7 +475,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
| 425 |
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
| 426 |
|
| 427 |
# Find similar tags and prepare data for tables
|
| 428 |
-
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 429 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
| 430 |
tags_added = False
|
| 431 |
bad_entities = []
|
|
@@ -561,14 +611,21 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
| 561 |
|
| 562 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
| 563 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
| 564 |
-
|
|
|
|
| 565 |
bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
|
| 566 |
bad_entities.sort(key=lambda x: x['start'])
|
| 567 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
| 568 |
|
| 569 |
-
#
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
|
| 573 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
| 574 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
|
@@ -586,12 +643,12 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
| 586 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
| 587 |
image_galleries.append(artists) # Extend the list with artist tuples
|
| 588 |
|
| 589 |
-
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
| 590 |
except ParseError as e:
|
| 591 |
-
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
| 592 |
|
| 593 |
|
| 594 |
-
with gr.Blocks() as app:
|
| 595 |
with gr.Group():
|
| 596 |
with gr.Row():
|
| 597 |
with gr.Column(scale=3):
|
|
@@ -609,7 +666,11 @@ with gr.Blocks() as app:
|
|
| 609 |
with gr.Row():
|
| 610 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
| 611 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
| 612 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
with gr.Column(scale=1):
|
| 614 |
with gr.Group():
|
| 615 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
|
@@ -626,7 +687,7 @@ with gr.Blocks() as app:
|
|
| 626 |
submit_button.click(
|
| 627 |
find_similar_artists,
|
| 628 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
| 629 |
-
outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
|
| 630 |
)
|
| 631 |
|
| 632 |
gr.Markdown(faq_content)
|
|
|
|
| 2 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 3 |
from scipy.sparse import csr_matrix
|
| 4 |
import numpy as np
|
| 5 |
+
import joblib
|
| 6 |
from joblib import load
|
| 7 |
import h5py
|
| 8 |
from io import BytesIO
|
|
|
|
| 20 |
import os
|
| 21 |
import glob
|
| 22 |
import itertools
|
| 23 |
+
from itertools import islice
|
| 24 |
|
| 25 |
|
| 26 |
|
|
|
|
| 34 |
Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
|
| 35 |
This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
|
| 36 |
|
| 37 |
+
When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
|
| 38 |
If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
|
| 39 |
Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
|
| 40 |
This is useful to align your prompt with the expected input to an e621-trained model.
|
|
|
|
| 116 |
|
| 117 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
| 118 |
|
| 119 |
+
css = """
|
| 120 |
+
.scrollable-content {
|
| 121 |
+
max-height: 500px;
|
| 122 |
+
overflow-y: auto;
|
| 123 |
+
}
|
| 124 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
grammar=r"""
|
| 127 |
!start: (prompt | /[][():]/+)*
|
|
|
|
| 349 |
return geometric_mean
|
| 350 |
|
| 351 |
|
| 352 |
+
def create_html_tables_for_tags(subtable_heading, word_similarity_tuples, tag2count, tag2idwiki):
|
| 353 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
| 354 |
+
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{subtable_heading}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
|
| 355 |
# Loop through the results and add table rows for each
|
| 356 |
+
for word, sim in word_similarity_tuples:
|
| 357 |
word_with_underscores = word.replace(' ', '_')
|
| 358 |
count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
|
| 359 |
tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
|
|
|
|
| 375 |
|
| 376 |
def create_top_artists_table(top_artists):
|
| 377 |
# Add a heading above the table
|
| 378 |
+
html_str = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 379 |
html_str += "<h1>Top Artists</h1>" # Heading for the table
|
| 380 |
# Start the table with increased font size and no borders between rows
|
| 381 |
html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
|
|
|
|
| 392 |
return html_str
|
| 393 |
|
| 394 |
|
| 395 |
+
def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
    """Build a 1 x N vector holding the IDF weight of every known term.

    Args:
        pseudo_doc_terms: Iterable of terms that make up the pseudo document.
            Terms absent from ``tag_to_row_loaded`` are ignored.
        idf_loaded: Mapping of term -> IDF weight. A term that has a position
            but no IDF entry contributes 0.
        tag_to_row_loaded: Mapping of term -> position in the vector. Its
            length determines the vector length, so positions are assumed to
            lie in ``0..len-1`` -- TODO confirm against how the mapping is built.

    Returns:
        A NumPy array of shape ``(1, len(tag_to_row_loaded))``; 2D so it can be
        handed straight to an estimator's ``transform``.
    """
    pseudo_vector = np.zeros(len(tag_to_row_loaded))

    for term in pseudo_doc_terms:
        # Single lookup instead of a membership test followed by indexing.
        try:
            index = tag_to_row_loaded[term]
        except KeyError:
            continue
        pseudo_vector[index] = idf_loaded.get(term, 0)

    return pseudo_vector.reshape(1, -1)
|
| 407 |
+
|
| 408 |
+
def get_top_indices(reduced_pseudo_vector, reduced_matrix, top_n=None):
    """Rank the rows of ``reduced_matrix`` by cosine similarity to a query.

    Args:
        reduced_pseudo_vector: 2D array holding the query vector (a single
            row is expected, since the similarities are flattened).
        reduced_matrix: 2D array whose rows are compared with the query.
        top_n: Optional cap on the number of indices returned. ``None``
            (the default) returns the complete ordering, which is what the
            function did before this parameter existed.

    Returns:
        A NumPy array of row indices into ``reduced_matrix``, ordered from
        most to least similar.
    """
    similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()

    # Negate so that argsort's ascending order yields descending similarity.
    sorted_indices = np.argsort(-similarities)

    if top_n is None:
        return sorted_indices
    return sorted_indices[:top_n]
|
| 417 |
+
|
| 418 |
+
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
    """Score every known tag by similarity to a pseudo document of tags.

    The pseudo document is turned into an IDF-weighted vector, projected with
    the stored SVD model, and compared by cosine similarity with each row of
    the stored reduced matrix.

    Args:
        pseudo_doc_terms: Iterable of tags describing the query.
        allow_nsfw_tags: When falsy, tags whose underscore form appears in the
            module-level ``nsfw_tags`` collection are dropped from the result.

    Returns:
        An ``OrderedDict`` mapping tag -> similarity, sorted by similarity in
        descending order.
    """
    # Load the artefacts once and cache them on the function object so that
    # later calls skip the disk read.
    if not hasattr(get_tfidf_reduced_similar_tags, "components"):
        get_tfidf_reduced_similar_tags.components = joblib.load('tfidfreducedfiles.joblib')

    components = get_tfidf_reduced_similar_tags.components
    idf_loaded = components['idf']
    tag_to_row_loaded = components['tag_to_row']
    reduced_matrix_loaded = components['reduced_matrix']
    svd_loaded = components['svd_model']

    pseudo_vector = construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded)
    reduced_pseudo_vector = svd_loaded.transform(pseudo_vector)
    similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix_loaded).flatten()

    # Rank rows directly from the similarities already in hand rather than
    # recomputing the whole cosine-similarity matrix a second time.
    top_indices_reduced = np.argsort(-similarities)

    # Materialise the key list once; rebuilding it for every index made the
    # comprehension below quadratic in the vocabulary size.
    # NOTE(review): indexing keys by position assumes the insertion order of
    # tag_to_row matches the row numbers of reduced_matrix -- confirm.
    tags_by_position = list(tag_to_row_loaded)
    tag_similarity_dict = {tags_by_position[i]: similarities[i] for i in top_indices_reduced}

    if not allow_nsfw_tags:
        tag_similarity_dict = {
            tag: similarity
            for tag, similarity in tag_similarity_dict.items()
            if tag.replace(' ', '_') not in nsfw_tags
        }

    # sorted() is stable, so tags with equal similarity keep their ranked order.
    return OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
|
| 447 |
+
|
| 448 |
+
|
| 449 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
    """Return an HTML snippet that reserves space for a panel not yet filled in.

    The snippet is a centered ``<h1>`` heading (its wrapper carries the
    ``scrollable-content`` CSS class), an optional centered paragraph, and an
    empty transparent ``<div>`` of the requested size.

    Args:
        title: Text placed inside the ``<h1>`` heading.
        content: Optional text placed in a ``<p>`` under the heading; the
            paragraph block is omitted entirely when this is falsy.
        placeholder_height: Height of the empty block; ``px`` is appended.
        placeholder_width: Width of the empty block, inserted verbatim
            (so it must carry its own CSS unit).

    Returns:
        The concatenated HTML string. Values are interpolated as-is, with no
        HTML escaping.
    """
    heading = f"<div class=\"scrollable-content\" style='text-align: center;'><h1>{title}</h1></div>"
    # The paragraph section only exists when there is something to show.
    body = (
        f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
        if content
        else ""
    )
    spacer = f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
    return "".join((heading, body, spacer))
|
| 458 |
+
|
| 459 |
|
| 460 |
def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
| 461 |
#Initialize stuff
|
|
|
|
| 475 |
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
| 476 |
|
| 477 |
# Find similar tags and prepare data for tables
|
| 478 |
+
html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 479 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
| 480 |
tags_added = False
|
| 481 |
bad_entities = []
|
|
|
|
| 611 |
|
| 612 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
| 613 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
| 614 |
+
|
| 615 |
+
#Bad tags stuff
|
| 616 |
bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
|
| 617 |
bad_entities.sort(key=lambda x: x['start'])
|
| 618 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
| 619 |
|
| 620 |
+
#Suggested tags stuff
|
| 621 |
+
suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 622 |
+
|
| 623 |
+
suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
|
| 624 |
+
suggested_tags = get_tfidf_reduced_similar_tags([item["artist_matrix_tag"] for item in tag_data], allow_nsfw_tags)
|
| 625 |
+
topnsuggestions = list(islice(suggested_tags.items(), 100))
|
| 626 |
+
suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
| 627 |
+
|
| 628 |
+
#Artist stuff
|
| 629 |
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
|
| 630 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
| 631 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
|
|
|
| 643 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
| 644 |
image_galleries.append(artists) # Extend the list with artist tuples
|
| 645 |
|
| 646 |
+
return (unseen_tags_data, bad_tags_illustrated_string, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
| 647 |
except ParseError as e:
|
| 648 |
+
return [], "Parse Error: Check for mismatched parentheses or something", "", "", None, None
|
| 649 |
|
| 650 |
|
| 651 |
+
with gr.Blocks(css=css) as app:
|
| 652 |
with gr.Group():
|
| 653 |
with gr.Row():
|
| 654 |
with gr.Column(scale=3):
|
|
|
|
| 666 |
with gr.Row():
|
| 667 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
| 668 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
| 669 |
+
with gr.Row():
|
| 670 |
+
with gr.Column(scale=2):
|
| 671 |
+
unseen_tags = gr.HTML(label="Unknown Tags", value=create_html_placeholder(title="Unknown Tags"))
|
| 672 |
+
with gr.Column(scale=1):
|
| 673 |
+
suggested_tags = gr.HTML(label="Suggested Tags", value=create_html_placeholder(title="Suggested Tags"))
|
| 674 |
with gr.Column(scale=1):
|
| 675 |
with gr.Group():
|
| 676 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
|
|
|
| 687 |
submit_button.click(
|
| 688 |
find_similar_artists,
|
| 689 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
| 690 |
+
outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
|
| 691 |
)
|
| 692 |
|
| 693 |
gr.Markdown(faq_content)
|
tfidfreducedfiles.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a325f75a94c8a6c47034fba0e96a89039a3550463f916690b74c16d139f32504
|
| 3 |
+
size 68245886
|