awacke1 commited on
Commit
8057c09
·
1 Parent(s): 552817b

Create new file

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, util
2
+ from huggingface_hub import hf_hub_download
3
+ import os
4
+ import pickle
5
+ import pandas as pd
6
+ import gradio as gr
7
+
8
+ pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
9
+
10
+ auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
11
+ pickled = pickle.load(open(hf_hub_download("NimaBoscarino/playlist-generator", repo_type="dataset", filename="clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl"), "rb"))
12
+ songs = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", repo_type="dataset", filename="songs_new.csv"))
13
+ verses = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator-private", repo_type="dataset", filename="verses.csv", use_auth_token=auth_token))
14
+ lyrics = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator-private", repo_type="dataset", filename="lyrics_new.csv", use_auth_token=auth_token))
15
+
16
+ embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
17
+
18
+ song_ids = pickled["song_ids"]
19
+ corpus_embeddings = pickled["embeddings"]
20
+
21
+
22
+ def generate_playlist(prompt):
23
+ prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
24
+ hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
25
+ hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
26
+
27
+ verse_match = verses.iloc[hits['corpus_id']]
28
+ verse_match = verse_match.drop_duplicates(subset=["song_id"])
29
+ song_match = songs[songs["song_id"].isin(verse_match["song_id"].values)]
30
+ song_match.song_id = pd.Categorical(song_match.song_id, categories=verse_match["song_id"].values)
31
+ song_match = song_match.sort_values("song_id")
32
+ song_match = song_match[0:9] # Only grab the top 9
33
+
34
+ song_names = list(song_match["full_title"])
35
+ song_art = list(song_match["art"].fillna("https://i.imgur.com/bgCDfT1.jpg"))
36
+ images = [gr.Image.update(value=art, visible=True) for art in song_art]
37
+
38
+ return (
39
+ gr.Radio.update(label="Songs", interactive=True, choices=song_names),
40
+ *images
41
+ )
42
+
43
+
44
+ def set_lyrics(full_title):
45
+ lyrics_text = lyrics[lyrics["song_id"].isin(songs[songs["full_title"] == full_title]["song_id"])]["text"].iloc[0]
46
+ return gr.Textbox.update(value=lyrics_text)
47
+
48
+
49
+ def set_example_prompt(example):
50
+ return gr.TextArea.update(value=example[0])
51
+
52
+
53
+ demo = gr.Blocks()
54
+
55
+ with demo:
56
+ gr.Markdown(
57
+ """
58
+ # Playlist Generator 📻 🎵
59
+ """)
60
+
61
+ with gr.Row():
62
+ with gr.Column():
63
+ gr.Markdown(
64
+ """
65
+ Enter a prompt and generate a playlist based on ✨semantic similarity✨
66
+ This was built using Sentence Transformers and Gradio – blog post coming soon!
67
+ """)
68
+
69
+ song_prompt = gr.TextArea(
70
+ value="Running wild and free",
71
+ placeholder="Enter a song prompt, or choose an example"
72
+ )
73
+ example_prompts = gr.Dataset(
74
+ components=[song_prompt],
75
+ samples=[
76
+ ["I feel nostalgic for the past"],
77
+ ["Running wild and free"],
78
+ ["I'm deeply in love with someone I just met!"],
79
+ ["My friends mean the world to me"],
80
+ ["Sometimes I feel like no one understands"],
81
+ ]
82
+ )
83
+
84
+ with gr.Column():
85
+ fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽‍🎤").style(full_width=True)
86
+
87
+ with gr.Row():
88
+ tile1 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
89
+ tile2 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
90
+ tile3 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
91
+ with gr.Row():
92
+ tile4 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
93
+ tile5 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
94
+ tile6 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
95
+ with gr.Row():
96
+ tile7 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
97
+ tile8 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
98
+ tile9 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
99
+
100
+ # Workaround because of the Gallery issues
101
+ tiles = [tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9]
102
+
103
+ song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")
104
+
105
+ with gr.Column():
106
+ verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")
107
+
108
+ fetch_songs.click(
109
+ fn=generate_playlist,
110
+ inputs=[song_prompt],
111
+ outputs=[song_option, *tiles],
112
+ )
113
+
114
+ example_prompts.click(
115
+ fn=set_example_prompt,
116
+ inputs=example_prompts,
117
+ outputs=example_prompts.components,
118
+ )
119
+
120
+ song_option.change(
121
+ fn=set_lyrics,
122
+ inputs=[song_option],
123
+ outputs=[verse]
124
+ )
125
+
126
+ demo.launch()