Dan Mo committed · Commit cfb0d15
1 Parent(s): cf957e4

Add script to generate and save embeddings for models
- Implemented `generate_embeddings.py` to load embedding models and generate embeddings for the emotion and event dictionaries (see the usage sketch below).
- Added functionality to save generated embeddings as pickle files in the 'embeddings' directory.
- Included error handling and logging for better debugging and tracking of the embedding generation process.
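A rough usage sketch (assuming the emotion/item dictionary text files referenced in config.py are present locally): the script is run once and writes one emotion and one event pickle per configured model, which are the files added under embeddings/ below.

    python generate_embeddings.py
    # writes e.g. embeddings/all-mpnet-base-v2_emotion.pkl, embeddings/all-mpnet-base-v2_event.pkl,
    # embeddings/thenlper_gte-large_emotion.pkl, ... (one emotion + one event file per model)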
- .gitignore +58 -0
- app.py +136 -17
- config.py +19 -0
- embeddings/BAAI_bge-large-en-v1.5_emotion.pkl +3 -0
- embeddings/BAAI_bge-large-en-v1.5_event.pkl +3 -0
- embeddings/all-mpnet-base-v2_emotion.pkl +3 -0
- embeddings/all-mpnet-base-v2_event.pkl +3 -0
- embeddings/thenlper_gte-large_emotion.pkl +3 -0
- embeddings/thenlper_gte-large_event.pkl +3 -0
- emoji_processor.py +106 -8
- generate_embeddings.py +99 -0
- utils.py +59 -1
.gitignore
ADDED
@@ -0,0 +1,58 @@
+# Python cache files
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+
+# Virtual environments
+venv/
+env/
+ENV/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# VS Code
+.vscode/
+*.code-workspace
+
+# PyCharm
+.idea/
+
+# Logs
+*.log
+logs/
+
+# OS specific files
+.DS_Store
+Thumbs.db
+desktop.ini
+
+# Environment variables
+.env
+.env.local
+
+# Temporary files
+*.swp
+*.swo
+*~
+.temp/
+
+# NOTE: We're keeping the embeddings/*.pkl files since they're pre-generated
+# for faster startup. They're managed by Git LFS as specified in .gitattributes.
+
+# Gradio specific
+gradio_cached_examples/
+flagged/
+
+# Local development files
+.jupyter/
+.local/
+.bash_history
+.python_history
app.py
CHANGED
@@ -6,36 +6,155 @@ This module handles the Gradio interface and application setup.
 import gradio as gr
 from utils import logger
 from emoji_processor import EmojiProcessor
+from config import EMBEDDING_MODELS

 class EmojiMashupApp:
     def __init__(self):
         """Initialize the Gradio application."""
         logger.info("Initializing Emoji Mashup App")
-        self.processor = EmojiProcessor()
+        self.processor = EmojiProcessor(model_key="mpnet", use_cached_embeddings=True)  # Default to mpnet
         self.processor.load_emoji_dictionaries()

+    def create_model_dropdown_choices(self):
+        """Create formatted choices for the model dropdown.
+
+        Returns:
+            List of formatted model choices
+        """
+        return [
+            f"{key} ({info['size']}) - {info['notes']}"
+            for key, info in EMBEDDING_MODELS.items()
+        ]
+
+    def handle_model_change(self, dropdown_value, use_cached_embeddings):
+        """Handle model selection change from dropdown.
+
+        Args:
+            dropdown_value: Selected value from dropdown
+            use_cached_embeddings: Whether to use cached embeddings
+
+        Returns:
+            Status message about model change
+        """
+        # Extract model key from dropdown value (first word before space)
+        model_key = dropdown_value.split()[0] if dropdown_value else "mpnet"
+
+        # Update processor cache setting
+        self.processor.use_cached_embeddings = use_cached_embeddings
+
+        if model_key in EMBEDDING_MODELS:
+            success = self.processor.switch_model(model_key)
+            if success:
+                cache_status = "using cached embeddings" if use_cached_embeddings else "computing fresh embeddings"
+                return f"Switched to {model_key} model ({cache_status}): {EMBEDDING_MODELS[model_key]['notes']}"
+            else:
+                return f"Failed to switch to {model_key} model"
+        else:
+            return f"Unknown model: {model_key}"
+
+    def process_with_model(self, model_selection, text, use_cached_embeddings):
+        """Process text with selected model.
+
+        Args:
+            model_selection: Selected model from dropdown
+            text: User input text
+            use_cached_embeddings: Whether to use cached embeddings
+
+        Returns:
+            Tuple of (emotion emoji, event emoji, mashup image)
+        """
+        # Extract model key from dropdown value (first word before space)
+        model_key = model_selection.split()[0] if model_selection else "mpnet"
+
+        # Update processor cache setting
+        self.processor.use_cached_embeddings = use_cached_embeddings
+
+        if model_key in EMBEDDING_MODELS:
+            self.processor.switch_model(model_key)
+
+        # Process text with current model
+        return self.processor.sentence_to_emojis(text)
+
     def create_interface(self):
         """Create and configure the Gradio interface.

         Returns:
             Gradio Interface object
         """
-        # (old lines 23-38 removed; only a "gr." fragment of the previous interface construction is visible in this page capture)
+        with gr.Blocks(title="Sentence → Emoji Mashup") as interface:
+            gr.Markdown("# Sentence → Emoji Mashup")
+            gr.Markdown("Get the top emotion and event emoji from your sentence, and view the mashup!")
+
+            with gr.Row():
+                with gr.Column(scale=3):
+                    # Model selection dropdown
+                    model_dropdown = gr.Dropdown(
+                        choices=self.create_model_dropdown_choices(),
+                        value=self.create_model_dropdown_choices()[0],  # Default to first model (mpnet)
+                        label="Embedding Model",
+                        info="Select the model used for text-emoji matching"
+                    )
+
+                    # Cache toggle
+                    cache_toggle = gr.Checkbox(
+                        label="Use cached embeddings",
+                        value=True,
+                        info="When enabled, embeddings will be saved to and loaded from disk"
+                    )
+
+                    # Text input
+                    text_input = gr.Textbox(
+                        lines=2,
+                        placeholder="Type a sentence...",
+                        label="Your message"
+                    )
+
+                    # Process button
+                    submit_btn = gr.Button("Generate Emoji Mashup", variant="primary")
+
+                with gr.Column(scale=2):
+                    # Model info display
+                    model_info = gr.Textbox(
+                        value=f"Using mpnet model (using cached embeddings): {EMBEDDING_MODELS['mpnet']['notes']}",
+                        label="Model Info",
+                        interactive=False
+                    )
+
+                    # Output displays
+                    emotion_out = gr.Text(label="Top Emotion Emoji")
+                    event_out = gr.Text(label="Top Event Emoji")
+                    mashup_out = gr.Image(label="Mashup Emoji")
+
+            # Set up event handlers
+            model_dropdown.change(
+                fn=self.handle_model_change,
+                inputs=[model_dropdown, cache_toggle],
+                outputs=[model_info]
+            )
+
+            cache_toggle.change(
+                fn=self.handle_model_change,
+                inputs=[model_dropdown, cache_toggle],
+                outputs=[model_info]
+            )
+
+            submit_btn.click(
+                fn=self.process_with_model,
+                inputs=[model_dropdown, text_input, cache_toggle],
+                outputs=[emotion_out, event_out, mashup_out]
+            )
+
+            # Examples
+            gr.Examples(
+                examples=[
+                    ["I feel so happy today!"],
+                    ["I'm really angry right now"],
+                    ["Feeling tired after a long day"]
+                ],
+                inputs=text_input
+            )
+
+        return interface

     def run(self, share=True):
         """Launch the Gradio application.
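The dropdown stores a human-readable label rather than the bare model key, and both handle_model_change and process_with_model recover the key by taking the first whitespace-separated token. A minimal sketch of that round trip, using only names from the diff above (EMBEDDING_MODELS abbreviated to a single entry):

    # Sketch of the label <-> key round trip used in app.py above.
    EMBEDDING_MODELS = {
        "mpnet": {"id": "all-mpnet-base-v2", "size": "110M", "notes": "Balanced, great general-purpose model"},
    }

    choices = [f"{key} ({info['size']}) - {info['notes']}" for key, info in EMBEDDING_MODELS.items()]
    label = choices[0]            # "mpnet (110M) - Balanced, great general-purpose model"
    model_key = label.split()[0]  # "mpnet" -- the first token before a space
    assert model_key in EMBEDDING_MODELS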
config.py
CHANGED
@@ -9,4 +9,23 @@ CONFIG = {
     "item_file": "google-emoji-kitchen-item.txt",
     "emoji_kitchen_url": "https://emojik.vercel.app/s/{emoji1}_{emoji2}",
     "default_size": 256
+}
+
+# Available embedding models
+EMBEDDING_MODELS = {
+    "mpnet": {
+        "id": "all-mpnet-base-v2",
+        "size": "110M",
+        "notes": "Balanced, great general-purpose model"
+    },
+    "gte": {
+        "id": "thenlper/gte-large",
+        "size": "335M",
+        "notes": "Context-rich, good for emotion & nuance"
+    },
+    "bge": {
+        "id": "BAAI/bge-large-en-v1.5",
+        "size": "350M",
+        "notes": "Tuned for ranking & high-precision similarity"
+    }
 }
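Because app.py builds the dropdown choices directly from this dict, registering another model only requires one more entry. A hypothetical example (for illustration only, not part of this commit):

    # Hypothetical entry, not added by this commit.
    EMBEDDING_MODELS["minilm"] = {
        "id": "all-MiniLM-L6-v2",
        "size": "22M",
        "notes": "Smaller and faster, with some loss of accuracy"
    }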
embeddings/BAAI_bge-large-en-v1.5_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5456af7ceaa04bdc28b9b125e317eaebf503c60b6937f006b54c595850c3830a
+size 463549
embeddings/BAAI_bge-large-en-v1.5_event.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c32a321359bd0a197e906c85731003294a835e7590c655043e6e9ebdfa607de9
+size 2238733
embeddings/all-mpnet-base-v2_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6db3183f80970f30c7dee0cf846c832b5505890a071c1af8009f6ff452083f7c
+size 348852
embeddings/all-mpnet-base-v2_event.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65d434cb2cdd1034e494a87d354345e67bfd25a90a44247cfa3406dc100334c0
+size 1684668
embeddings/thenlper_gte-large_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8b5472bf5008613f76ac06738fa55c91ff2fd6ae7472c9a1f739d210b5f2f0e
+size 463549
embeddings/thenlper_gte-large_event.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81961955fad517b578deb1969c8e84594fe92c8eed32d6b43f85e804f5214b82
+size 2238733
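These six files are Git LFS pointer files, not the pickles themselves; the actual embedding data lives in LFS storage (as noted in the .gitignore comment above). To materialize them in a local clone one would typically use the standard LFS commands:

    git lfs install
    git lfs pull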
emoji_processor.py
CHANGED
@@ -7,23 +7,36 @@ from sklearn.metrics.pairwise import cosine_similarity
 import requests
 from PIL import Image
 from io import BytesIO
+import os

-from config import CONFIG
-from utils import logger, kitchen_txt_to_dict
+from config import CONFIG, EMBEDDING_MODELS
+from utils import (logger, kitchen_txt_to_dict,
+                   save_embeddings_to_pickle, load_embeddings_from_pickle,
+                   get_embeddings_pickle_path)

 class EmojiProcessor:
-    def __init__(self, model_name=
+    def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
         """Initialize the emoji processor with the specified model.

         Args:
-            model_name:
+            model_name: Direct name of the sentence transformer model to use
+            model_key: Key from EMBEDDING_MODELS to use (takes precedence over model_name)
+            use_cached_embeddings: Whether to use cached embeddings from pickle files
         """
+        # Get model name from the key if provided
+        if model_key and model_key in EMBEDDING_MODELS:
+            model_name = EMBEDDING_MODELS[model_key]['id']
+        elif not model_name:
+            model_name = CONFIG["model_name"]
+
         logger.info(f"Loading model: {model_name}")
         self.model = SentenceTransformer(model_name)
+        self.current_model_name = model_name
         self.emotion_dict = {}
         self.event_dict = {}
         self.emotion_embeddings = {}
         self.event_embeddings = {}
+        self.use_cached_embeddings = use_cached_embeddings

     def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
         """Load emoji dictionaries from text files.
@@ -36,10 +49,95 @@ class EmojiProcessor:
         self.emotion_dict = kitchen_txt_to_dict(emotion_file)
         self.event_dict = kitchen_txt_to_dict(item_file)

-        # (old lines 39-42 removed; content truncated in this page capture)
+        # Load or compute embeddings
+        self._load_or_compute_embeddings()
+
+    def _load_or_compute_embeddings(self):
+        """Load embeddings from pickle files if available, otherwise compute them."""
+        if self.use_cached_embeddings:
+            # Try to load emotion embeddings
+            emotion_pickle_path = get_embeddings_pickle_path(self.current_model_name, "emotion")
+            loaded_emotion_embeddings = load_embeddings_from_pickle(emotion_pickle_path)
+
+            # Try to load event embeddings
+            event_pickle_path = get_embeddings_pickle_path(self.current_model_name, "event")
+            loaded_event_embeddings = load_embeddings_from_pickle(event_pickle_path)
+
+            # Check if we need to compute any embeddings
+            compute_emotion = loaded_emotion_embeddings is None
+            compute_event = loaded_event_embeddings is None
+
+            if not compute_emotion:
+                # Verify all emoji keys are present in loaded embeddings
+                for emoji in self.emotion_dict.keys():
+                    if emoji not in loaded_emotion_embeddings:
+                        logger.info(f"Cached emotion embeddings missing emoji: {emoji}, will recompute")
+                        compute_emotion = True
+                        break
+
+            if not compute_emotion:
+                self.emotion_embeddings = loaded_emotion_embeddings
+
+            if not compute_event:
+                # Verify all emoji keys are present in loaded embeddings
+                for emoji in self.event_dict.keys():
+                    if emoji not in loaded_event_embeddings:
+                        logger.info(f"Cached event embeddings missing emoji: {emoji}, will recompute")
+                        compute_event = True
+                        break
+
+            if not compute_event:
+                self.event_embeddings = loaded_event_embeddings
+
+            # Compute any missing embeddings
+            if compute_emotion:
+                logger.info(f"Computing emotion embeddings for model: {self.current_model_name}")
+                self.emotion_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.emotion_dict.items()}
+                # Save for future use
+                save_embeddings_to_pickle(self.emotion_embeddings, emotion_pickle_path)
+
+            if compute_event:
+                logger.info(f"Computing event embeddings for model: {self.current_model_name}")
+                self.event_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.event_dict.items()}
+                # Save for future use
+                save_embeddings_to_pickle(self.event_embeddings, event_pickle_path)
+        else:
+            # Compute embeddings without caching
+            logger.info("Computing embeddings for emoji dictionaries (no caching)")
+            self.emotion_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.emotion_dict.items()}
+            self.event_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.event_dict.items()}
+
+    def switch_model(self, model_key):
+        """Switch to a different embedding model.
+
+        Args:
+            model_key: Key from EMBEDDING_MODELS to use
+
+        Returns:
+            True if model was switched successfully, False otherwise
+        """
+        if model_key not in EMBEDDING_MODELS:
+            logger.error(f"Unknown model key: {model_key}")
+            return False
+
+        model_name = EMBEDDING_MODELS[model_key]['id']
+        if model_name == self.current_model_name:
+            logger.info(f"Model {model_key} is already loaded")
+            return True
+
+        try:
+            logger.info(f"Switching to model: {model_name}")
+            self.model = SentenceTransformer(model_name)
+            self.current_model_name = model_name
+
+            # Load or recompute embeddings with new model
+            if self.emotion_dict and self.event_dict:
+                self._load_or_compute_embeddings()
+
+            return True
+        except Exception as e:
+            logger.error(f"Error switching model: {e}")
+            return False

     def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
         """Find top matching emojis based on cosine similarity.
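A minimal usage sketch of the new constructor and switch_model flow, using only names introduced in the diff above (it assumes the emotion/item dictionary files from CONFIG are available and the models can be downloaded):

    from emoji_processor import EmojiProcessor

    # Cached embeddings are read from embeddings/<model>_{emotion,event}.pkl when present,
    # otherwise computed with the loaded SentenceTransformer and saved for next time.
    processor = EmojiProcessor(model_key="mpnet", use_cached_embeddings=True)
    processor.load_emoji_dictionaries()

    # switch_model reloads the SentenceTransformer and, because the dictionaries are already
    # loaded, re-runs _load_or_compute_embeddings() for the new model.
    if processor.switch_model("bge"):
        print(processor.current_model_name)  # "BAAI/bge-large-en-v1.5"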
generate_embeddings.py
ADDED
@@ -0,0 +1,99 @@
+"""
+Utility script to pre-generate embedding pickle files for all models.
+
+This script will:
+1. Load each embedding model
+2. Generate embeddings for both emotion and event dictionaries
+3. Save the embeddings as pickle files in the 'embeddings' directory
+
+Run this script once locally to create all pickle files before uploading to the repository.
+"""
+
+import os
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+
+from config import CONFIG, EMBEDDING_MODELS
+from utils import (logger, kitchen_txt_to_dict,
+                   save_embeddings_to_pickle, get_embeddings_pickle_path)
+
+def generate_embeddings_for_model(model_key, model_info):
+    """Generate and save embeddings for a specific model.
+
+    Args:
+        model_key: Key of the model in EMBEDDING_MODELS
+        model_info: Model information dictionary
+
+    Returns:
+        Tuple of (success_emotion, success_event)
+    """
+    model_id = model_info['id']
+    print(f"\nProcessing model: {model_key} ({model_id}) - {model_info['size']}")
+
+    try:
+        # Load the model
+        print(f"Loading {model_key} model...")
+        model = SentenceTransformer(model_id)
+
+        # Load emoji dictionaries
+        print("Loading emoji dictionaries...")
+        emotion_dict = kitchen_txt_to_dict(CONFIG["emotion_file"])
+        event_dict = kitchen_txt_to_dict(CONFIG["item_file"])
+
+        if not emotion_dict or not event_dict:
+            print("Error: Failed to load emoji dictionaries")
+            return False, False
+
+        # Generate emotion embeddings
+        print(f"Generating {len(emotion_dict)} emotion embeddings...")
+        emotion_embeddings = {}
+        for emoji, desc in tqdm(emotion_dict.items()):
+            emotion_embeddings[emoji] = model.encode(desc)
+
+        # Generate event embeddings
+        print(f"Generating {len(event_dict)} event embeddings...")
+        event_embeddings = {}
+        for emoji, desc in tqdm(event_dict.items()):
+            event_embeddings[emoji] = model.encode(desc)
+
+        # Save embeddings
+        emotion_pickle_path = get_embeddings_pickle_path(model_id, "emotion")
+        event_pickle_path = get_embeddings_pickle_path(model_id, "event")
+
+        success_emotion = save_embeddings_to_pickle(emotion_embeddings, emotion_pickle_path)
+        success_event = save_embeddings_to_pickle(event_embeddings, event_pickle_path)
+
+        return success_emotion, success_event
+    except Exception as e:
+        print(f"Error generating embeddings for model {model_key}: {e}")
+        return False, False
+
+def main():
+    """Main function to generate embeddings for all models."""
+    # Create embeddings directory if it doesn't exist
+    os.makedirs('embeddings', exist_ok=True)
+
+    print(f"Generating embeddings for {len(EMBEDDING_MODELS)} models...")
+
+    results = {}
+
+    # Generate embeddings for each model
+    for model_key, model_info in EMBEDDING_MODELS.items():
+        success_emotion, success_event = generate_embeddings_for_model(model_key, model_info)
+        results[model_key] = {
+            'emotion': success_emotion,
+            'event': success_event
+        }
+
+    # Print summary
+    print("\n=== Embedding Generation Summary ===")
+    for model_key, result in results.items():
+        status_emotion = "✓ Success" if result['emotion'] else "✗ Failed"
+        status_event = "✓ Success" if result['event'] else "✗ Failed"
+        print(f"{model_key:<10}: Emotion: {status_emotion}, Event: {status_event}")
+
+    print("\nDone! Embedding pickle files are stored in the 'embeddings' directory.")
+    print("You can now upload these files to your repository.")
+
+if __name__ == "__main__":
+    main()
utils.py
CHANGED
@@ -3,6 +3,8 @@ Utility functions for the Emoji Mashup application.
 """

 import logging
+import os
+import pickle

 # Configure logging
 def setup_logging():
@@ -36,4 +38,60 @@ def kitchen_txt_to_dict(filepath):
         return emoji_dict
     except Exception as e:
         logger.error(f"Error loading emoji dictionary from {filepath}: {e}")
-        return {}
+        return {}
+
+def save_embeddings_to_pickle(embeddings, filepath):
+    """Save embeddings dictionary to a pickle file.
+
+    Args:
+        embeddings: Dictionary of embeddings to save
+        filepath: Path to save the pickle file to
+
+    Returns:
+        True if successful, False otherwise
+    """
+    try:
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        with open(filepath, 'wb') as f:
+            pickle.dump(embeddings, f)
+        logger.info(f"Saved embeddings to {filepath}")
+        return True
+    except Exception as e:
+        logger.error(f"Error saving embeddings to {filepath}: {e}")
+        return False
+
+def load_embeddings_from_pickle(filepath):
+    """Load embeddings dictionary from a pickle file.
+
+    Args:
+        filepath: Path to load the pickle file from
+
+    Returns:
+        Dictionary of embeddings if successful, None otherwise
+    """
+    if not os.path.exists(filepath):
+        logger.info(f"Pickle file {filepath} does not exist")
+        return None
+
+    try:
+        with open(filepath, 'rb') as f:
+            embeddings = pickle.load(f)
+        logger.info(f"Loaded embeddings from {filepath}")
+        return embeddings
+    except Exception as e:
+        logger.error(f"Error loading embeddings from {filepath}: {e}")
+        return None
+
+def get_embeddings_pickle_path(model_id, emoji_type):
+    """Generate the path for an embeddings pickle file.
+
+    Args:
+        model_id: ID of the embedding model
+        emoji_type: Type of emoji ('emotion' or 'event')
+
+    Returns:
+        Path to the embeddings pickle file
+    """
+    # Create a safe filename from the model ID
+    safe_model_id = model_id.replace('/', '_').replace('\\', '_')
+    return os.path.join('embeddings', f"{safe_model_id}_{emoji_type}.pkl")
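A small round-trip sketch of the three new helpers (run from the repository root; the toy values stand in for whatever model.encode() actually returns):

    from utils import (get_embeddings_pickle_path,
                       save_embeddings_to_pickle, load_embeddings_from_pickle)

    path = get_embeddings_pickle_path("thenlper/gte-large", "emotion")
    # -> 'embeddings/thenlper_gte-large_emotion.pkl' ('/' in the model id becomes '_')

    ok = save_embeddings_to_pickle({"😀": [0.1, 0.2]}, path)  # toy embedding, for illustration only
    restored = load_embeddings_from_pickle(path)              # returns None if the file is missing or unreadable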