Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
|
@@ -130,16 +130,17 @@ parser = Lark(grammar, start='start')
|
|
| 130 |
|
| 131 |
# Function to extract tags
|
| 132 |
def extract_tags(tree):
|
| 133 |
-
|
| 134 |
def _traverse(node):
|
| 135 |
if isinstance(node, Token) and node.type == '__ANON_1':
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
elif not isinstance(node, Token):
|
| 138 |
for child in node.children:
|
| 139 |
_traverse(child)
|
| 140 |
-
|
| 141 |
_traverse(tree)
|
| 142 |
-
return
|
| 143 |
|
| 144 |
|
| 145 |
special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
|
|
@@ -341,7 +342,7 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
|
|
| 341 |
|
| 342 |
def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
| 343 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
| 344 |
-
html_str = f"<div style='display: inline-block; margin:
|
| 345 |
# Loop through the results and add table rows for each
|
| 346 |
for word, sim in result:
|
| 347 |
word_with_underscores = word.replace(' ', '_')
|
|
@@ -404,24 +405,35 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
| 404 |
if not hasattr(find_similar_tags, "tag2idwiki"):
|
| 405 |
find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
|
| 406 |
|
| 407 |
-
|
|
|
|
| 408 |
|
| 409 |
# Find similar tags and prepare data for tables
|
| 410 |
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 411 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
| 412 |
tags_added = False
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
continue
|
| 416 |
|
| 417 |
-
modified_tag_for_search =
|
| 418 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
| 419 |
result, seen = [], set(transformed_tags)
|
| 420 |
|
| 421 |
if modified_tag_for_search in find_similar_tags.tag2aliases:
|
| 422 |
-
if
|
| 423 |
result.append(modified_tag_for_search.replace('_',' '), 1)
|
| 424 |
-
seen.add(
|
| 425 |
else: #The user correctly did not put underscores in their tag
|
| 426 |
continue
|
| 427 |
else:
|
|
@@ -444,36 +456,60 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
| 444 |
#Adjust score based on context
|
| 445 |
for i in range(len(result)):
|
| 446 |
word, score = result[i] # Unpack the tuple
|
| 447 |
-
geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag !=
|
| 448 |
adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
|
| 449 |
result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
|
| 450 |
#print(word, score, geometric_mean, adjusted_score)
|
| 451 |
|
| 452 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
| 453 |
-
html_content += create_html_tables_for_tags(
|
|
|
|
|
|
|
|
|
|
| 454 |
tags_added=True
|
| 455 |
# If no tags were processed, add a message
|
| 456 |
if not tags_added:
|
| 457 |
html_content = create_html_placeholder(title="Unknown Tags")
|
| 458 |
|
| 459 |
-
return html_content # Return list of lists for Dataframe
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
-
def find_similar_artists(
|
| 463 |
try:
|
| 464 |
-
new_tags_string =
|
| 465 |
new_tags_string, removed_tags = remove_special_tags(new_tags_string)
|
| 466 |
|
| 467 |
# Parse the prompt
|
| 468 |
parsed = parser.parse(new_tags_string)
|
| 469 |
# Extract tags from the parsed tree
|
| 470 |
new_image_tags = extract_tags(parsed)
|
| 471 |
-
|
| 472 |
|
| 473 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
| 474 |
-
unseen_tags_data = find_similar_tags(
|
|
|
|
|
|
|
| 475 |
|
| 476 |
-
|
|
|
|
| 477 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
| 478 |
|
| 479 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
|
@@ -490,7 +526,7 @@ def find_similar_artists(new_tags_string, top_n, similarity_weight, allow_nsfw_t
|
|
| 490 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
| 491 |
image_galleries.append(artists) # Extend the list with artist tuples
|
| 492 |
|
| 493 |
-
return (unseen_tags_data, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
|
| 494 |
except ParseError as e:
|
| 495 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
| 496 |
|
|
@@ -504,6 +540,8 @@ with gr.Blocks() as app:
|
|
| 504 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
| 505 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
| 506 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
|
|
|
|
|
|
| 507 |
with gr.Row():
|
| 508 |
with gr.Column(scale=1):
|
| 509 |
top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
|
|
@@ -521,7 +559,7 @@ with gr.Blocks() as app:
|
|
| 521 |
submit_button.click(
|
| 522 |
find_similar_artists,
|
| 523 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
| 524 |
-
outputs=[unseen_tags, top_artists, dynamic_prompts] + galleries
|
| 525 |
)
|
| 526 |
|
| 527 |
gr.Markdown(faq_content)
|
|
|
|
| 130 |
|
| 131 |
# Function to extract tags
|
| 132 |
def extract_tags(tree):
|
| 133 |
+
tags_with_positions = []
|
| 134 |
def _traverse(node):
|
| 135 |
if isinstance(node, Token) and node.type == '__ANON_1':
|
| 136 |
+
tag_position = node.start_pos
|
| 137 |
+
tag_text = node.value.strip()
|
| 138 |
+
tags_with_positions.append((tag_text, tag_position))
|
| 139 |
elif not isinstance(node, Token):
|
| 140 |
for child in node.children:
|
| 141 |
_traverse(child)
|
|
|
|
| 142 |
_traverse(tree)
|
| 143 |
+
return tags_with_positions
|
| 144 |
|
| 145 |
|
| 146 |
special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
|
|
|
|
| 342 |
|
| 343 |
def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
| 344 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
| 345 |
+
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{tag}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
|
| 346 |
# Loop through the results and add table rows for each
|
| 347 |
for word, sim in result:
|
| 348 |
word_with_underscores = word.replace(' ', '_')
|
|
|
|
| 405 |
if not hasattr(find_similar_tags, "tag2idwiki"):
|
| 406 |
find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
|
| 407 |
|
| 408 |
+
modified_tags = [tag_info['modified_tag'] for tag_info in test_tags]
|
| 409 |
+
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
| 410 |
|
| 411 |
# Find similar tags and prepare data for tables
|
| 412 |
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
| 413 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
| 414 |
tags_added = False
|
| 415 |
+
bad_entities = []
|
| 416 |
+
for tag_info in test_tags:
|
| 417 |
+
original_tag = tag_info['original_tag']
|
| 418 |
+
modified_tag = tag_info['modified_tag']
|
| 419 |
+
start_pos = tag_info['start_pos']
|
| 420 |
+
end_pos = tag_info['end_pos']
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
print(original_tag, modified_tag, start_pos, end_pos)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
if modified_tag in special_tags:
|
| 427 |
continue
|
| 428 |
|
| 429 |
+
modified_tag_for_search = modified_tag.replace(' ','_')
|
| 430 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
| 431 |
result, seen = [], set(transformed_tags)
|
| 432 |
|
| 433 |
if modified_tag_for_search in find_similar_tags.tag2aliases:
|
| 434 |
+
if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
|
| 435 |
result.append(modified_tag_for_search.replace('_',' '), 1)
|
| 436 |
+
seen.add(modified_tag)
|
| 437 |
else: #The user correctly did not put underscores in their tag
|
| 438 |
continue
|
| 439 |
else:
|
|
|
|
| 456 |
#Adjust score based on context
|
| 457 |
for i in range(len(result)):
|
| 458 |
word, score = result[i] # Unpack the tuple
|
| 459 |
+
geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != modified_tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
|
| 460 |
adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
|
| 461 |
result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
|
| 462 |
#print(word, score, geometric_mean, adjusted_score)
|
| 463 |
|
| 464 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
| 465 |
+
html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
| 466 |
+
|
| 467 |
+
bad_entities.append({"entity":"UNKNOWN", "start":start_pos, "end":end_pos})
|
| 468 |
+
|
| 469 |
tags_added=True
|
| 470 |
# If no tags were processed, add a message
|
| 471 |
if not tags_added:
|
| 472 |
html_content = create_html_placeholder(title="Unknown Tags")
|
| 473 |
|
| 474 |
+
return html_content, bad_entities # Return list of lists for Dataframe
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def build_tag_offsets_dicts(new_image_tags_with_positions):
|
| 478 |
+
# Structure the data for HighlightedText
|
| 479 |
+
tag_data = []
|
| 480 |
+
for tag_text, start_pos in new_image_tags_with_positions:
|
| 481 |
+
# Modify the tag
|
| 482 |
+
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
| 483 |
+
# Calculate the end position based on the original tag length
|
| 484 |
+
end_pos = start_pos + len(tag_text)
|
| 485 |
+
# Append the structured data for each tag
|
| 486 |
+
tag_data.append({
|
| 487 |
+
"original_tag": tag_text,
|
| 488 |
+
"start_pos": start_pos,
|
| 489 |
+
"end_pos": end_pos,
|
| 490 |
+
"modified_tag": modified_tag
|
| 491 |
+
})
|
| 492 |
+
return tag_data
|
| 493 |
+
|
| 494 |
|
| 495 |
+
def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
|
| 496 |
try:
|
| 497 |
+
new_tags_string = original_tags_string.lower()
|
| 498 |
new_tags_string, removed_tags = remove_special_tags(new_tags_string)
|
| 499 |
|
| 500 |
# Parse the prompt
|
| 501 |
parsed = parser.parse(new_tags_string)
|
| 502 |
# Extract tags from the parsed tree
|
| 503 |
new_image_tags = extract_tags(parsed)
|
| 504 |
+
tag_data = build_tag_offsets_dicts(new_image_tags)
|
| 505 |
|
| 506 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
| 507 |
+
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
| 508 |
+
|
| 509 |
+
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
| 510 |
|
| 511 |
+
modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
|
| 512 |
+
X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
|
| 513 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
| 514 |
|
| 515 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
|
|
|
| 526 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
| 527 |
image_galleries.append(artists) # Extend the list with artist tuples
|
| 528 |
|
| 529 |
+
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
|
| 530 |
except ParseError as e:
|
| 531 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
| 532 |
|
|
|
|
| 540 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
| 541 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
| 542 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
| 543 |
+
with gr.Row():
|
| 544 |
+
bad_tags_illustrated_string = gr.HighlightedText()
|
| 545 |
with gr.Row():
|
| 546 |
with gr.Column(scale=1):
|
| 547 |
top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
|
|
|
|
| 559 |
submit_button.click(
|
| 560 |
find_similar_artists,
|
| 561 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
| 562 |
+
outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
|
| 563 |
)
|
| 564 |
|
| 565 |
gr.Markdown(faq_content)
|