fractalz committed · verified
Commit 2435506 · Parent(s): 36ac795

Update app.py

Files changed (1):
  app.py  +102 -11
app.py CHANGED
@@ -2,9 +2,11 @@ import gradio as gr
 import numpy as np
 from scipy.spatial.distance import cosine
 import pandas as pd
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
 
 # --- Simulate a small pre-trained Word2Vec model ---
-# Dummy word vectors for demonstration
+# Dummy word vectors for demonstration (4D for richer visualization)
 dummy_word_vectors = {
     'cat': np.array([0.9, 0.7, 0.1, 0.2]),
     'dog': np.array([0.8, 0.8, 0.3, 0.1]),
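
Note: the raw dummy vectors above are not unit-length as written; the loop in the next hunk rescales them. A quick standalone sketch of what that normalization does (illustrative only, not part of the commit):

    import numpy as np

    cat = np.array([0.9, 0.7, 0.1, 0.2])
    print(np.linalg.norm(cat))           # ~1.162: 'cat' is not yet a unit vector
    cat_unit = cat / np.linalg.norm(cat)
    print(np.linalg.norm(cat_unit))      # 1.0 after dividing by the norm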
@@ -20,25 +22,33 @@ dummy_word_vectors = {
     'king': np.array([0.9, 0.1, 0.1, 0.8]),
     'queen': np.array([0.8, 0.2, 0.2, 0.9]),
     'man': np.array([0.9, 0.15, 0.05, 0.7]),
-    'woman': np.array([0.85, 0.1, 0.15, 0.85])
+    'woman': np.array([0.85, 0.1, 0.15, 0.85]),
+    'prince': np.array([0.88, 0.12, 0.12, 0.82]),
+    'princess': np.array([0.83, 0.18, 0.18, 0.88])
 }
 
 # Normalize vectors (important for cosine similarity)
 for word, vec in dummy_word_vectors.items():
     dummy_word_vectors[word] = vec / np.linalg.norm(vec)
 
-# --- Function to find nearest neighbors ---
-def find_nearest_neighbors(search_word_input):
+# --- Function to find nearest neighbors and generate plot ---
+def find_nearest_neighbors_and_plot(search_word_input):
     search_word = search_word_input.lower()
 
     if search_word not in dummy_word_vectors:
         return (
+            None,  # No plot
             pd.DataFrame([{"Message": f"'{search_word}' not found in our dummy vocabulary. Try one of these: {', '.join(list(dummy_word_vectors.keys()))}"}]),
             "Warning: Word not found!"
         )
 
     target_vector = dummy_word_vectors[search_word]
     similarities = []
+
+    # Collect words and vectors for PCA
+    words_to_plot = [search_word]
+    vectors_to_plot = [target_vector]
+
     for word, vector in dummy_word_vectors.items():
         if word != search_word:  # Don't compare a word to itself
             similarity = 1 - cosine(target_vector, vector)
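
Cosine similarity is scale-invariant, so the normalization above doesn't change the `1 - cosine(...)` values; what it buys is that similarity then reduces to a plain dot product of unit vectors. A minimal sketch with two vectors from this file (illustrative, not part of the commit):

    import numpy as np
    from scipy.spatial.distance import cosine

    king = np.array([0.9, 0.1, 0.1, 0.8])
    queen = np.array([0.8, 0.2, 0.2, 0.9])
    king_u = king / np.linalg.norm(king)
    queen_u = queen / np.linalg.norm(queen)

    # scipy's cosine() is a *distance*; 1 - distance gives the similarity
    print(round(1 - cosine(king, queen), 4))    # ~0.9869
    print(round(float(king_u @ queen_u), 4))    # same value as a plain dot product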
@@ -48,21 +58,100 @@ def find_nearest_neighbors(search_word_input):
         by="Cosine Similarity", ascending=False
     ).reset_index(drop=True)
 
+    # Add top N neighbors to plot (e.g., top 5)
+    top_n = 5
+    for _, row in results_df.head(top_n).iterrows():
+        words_to_plot.append(row["Word"])
+        vectors_to_plot.append(dummy_word_vectors[row["Word"]])
+
+    # Convert to numpy array for PCA
+    vectors_array = np.array(vectors_to_plot)
+
+    # Perform PCA to reduce to 2 dimensions for plotting
+    pca = PCA(n_components=2)
+    # Fit PCA on all dummy vectors first to get a consistent mapping.
+    # This helps keep the relative positions meaningful across different searches.
+    all_vectors_array = np.array(list(dummy_word_vectors.values()))
+    pca.fit(all_vectors_array)
+
+    # Transform only the selected vectors
+    transformed_vectors = pca.transform(vectors_array)
+
+    # Create the plot
+    fig, ax = plt.subplots(figsize=(8, 8))
+
+    # Plot all words in the dummy vocabulary as light grey points
+    # to provide some context for the PCA space
+    all_transformed_vectors = pca.transform(all_vectors_array)
+    all_words = list(dummy_word_vectors.keys())
+    for i, word in enumerate(all_words):
+        ax.scatter(all_transformed_vectors[i, 0], all_transformed_vectors[i, 1],
+                   color='lightgray', alpha=0.5, s=50)
+        ax.text(all_transformed_vectors[i, 0] + 0.01, all_transformed_vectors[i, 1] + 0.01, word,
+                fontsize=8, color='darkgray')
+
+    # Plot selected words
+    for i, word in enumerate(words_to_plot):
+        x, y = transformed_vectors[i]
+        color = 'red' if word == search_word else 'blue'
+        marker = 'D' if word == search_word else 'o'  # Diamond for search word
+
+        ax.scatter(x, y, color=color, label=word, marker=marker, s=150 if word == search_word else 100, edgecolor='black', zorder=5)
+        ax.text(x + 0.01, y + 0.01, word, fontsize=10, weight='bold' if word == search_word else 'normal', color=color, zorder=6)
+
+        # Draw vector from origin to point (simulating conceptual vectors)
+        ax.plot([0, x], [0, y], color=color, linestyle='--', linewidth=1, alpha=0.7)
+
+    # Draw arrows from search word to its neighbors (optional, but good for intuition)
+    search_word_x, search_word_y = transformed_vectors[0]
+    for i in range(1, len(transformed_vectors)):
+        neighbor_x, neighbor_y = transformed_vectors[i]
+        # Calculate angle and display for top 1
+        if i == 1:  # Only for the closest neighbor
+            vec1 = transformed_vectors[0] - np.array([0, 0])  # Vector from origin to search word
+            vec2 = transformed_vectors[i] - np.array([0, 0])  # Vector from origin to neighbor
+
+            # Use original 4D vectors for actual cosine similarity calculation
+            original_vec1 = target_vector
+            original_vec2 = dummy_word_vectors[words_to_plot[i]]
+
+            sim_val = 1 - cosine(original_vec1, original_vec2)
+            angle_rad = np.arccos(np.clip(sim_val, -1.0, 1.0))  # Clip to handle potential float precision issues
+            angle_deg = np.degrees(angle_rad)
+            ax.annotate(f"{angle_deg:.1f}°", xy=((vec1[0]+vec2[0])/2, (vec1[1]+vec2[1])/2),
+                        xytext=(search_word_x + 0.05, search_word_y + 0.05),
+                        arrowprops=dict(facecolor='black', shrink=0.05, width=0.5, headwidth=5),
+                        fontsize=9, color='green', weight='bold')
+
+    ax.set_title(f"2D Projection of '{search_word}' and its Nearest Neighbors")
+    ax.set_xlabel(f"PCA Component 1 (explains {pca.explained_variance_ratio_[0]*100:.1f}%)")
+    ax.set_ylabel(f"PCA Component 2 (explains {pca.explained_variance_ratio_[1]*100:.1f}%)")
+    ax.grid(True, linestyle=':', alpha=0.6)
+    ax.axhline(0, color='gray', linewidth=0.5)
+    ax.axvline(0, color='gray', linewidth=0.5)
+    ax.set_aspect('equal', adjustable='box')
+    plt.tight_layout()
+
     # Format the DataFrame for better display in Gradio
     results_df["Cosine Similarity"] = results_df["Cosine Similarity"].round(4)
     results_df.columns = ["Neighbor Word", "Similarity Score"]  # Rename for UI clarity
 
-    message = f"Found nearest neighbors for '{search_word}'!"
-    return results_df, message
+    message = f"Found nearest neighbors for '{search_word}'! " \
+              f"Red diamond is the search word, blue circles are its closest neighbors. " \
+              f"The angle annotation shows the angle between the search word and its closest neighbor."
+
+    return fig, results_df, message
 
 # --- Gradio Interface ---
 iface = gr.Interface(
-    fn=find_nearest_neighbors,
+    fn=find_nearest_neighbors_and_plot,
     inputs=gr.Textbox(
         label="Enter a word to explore its neighbors:",
         placeholder="e.g., cat, king, fish"
     ),
     outputs=[
+        gr.Plot(label="Word Vector Visualization (PCA 2D)"),
        gr.DataFrame(
            headers=["Neighbor Word", "Similarity Score"],
            row_count=5,  # Display up to 5 rows by default
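
Two things worth noting about the plot code above: the angle annotation is computed from the original 4D vectors (via arccos of the cosine similarity), while the dashed arrows live in the 2D PCA projection, so the drawn angle only approximates the annotated one; and the np.clip guards arccos against out-of-domain values like 1.0000000002 from floating-point error. The conversion in isolation (illustrative numbers):

    import numpy as np

    sim_val = 0.9869                                     # e.g., king vs. queen from the sketch above
    angle_rad = np.arccos(np.clip(sim_val, -1.0, 1.0))   # clip keeps arccos in its valid domain
    print(f"{np.degrees(angle_rad):.1f}°")               # ~9.3°: small angle = high similarity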
@@ -74,11 +163,13 @@ iface = gr.Interface(
             label="Status"
         )
     ],
-    title="🚀 Word Vector Explorer (Gradio POC)",
+    title="🚀 Word Vector Explorer: Visualize & Understand Cosine Similarity!",
     description=(
-        "Discover the semantic neighbors of words using word embeddings! "
-        "Type a word, and see its closest companions in the vector space."
-        "<br>_Note: This POC uses dummy word vectors. In a full version, this would connect to a large pre-trained Word2Vec model!_"
+        "Type a word to see its nearest semantic neighbors in the vector space, along with a 2D visualization! "
+        "The angle between vectors on the plot is a visual representation of **Cosine Similarity** "
+        "(smaller angle = higher similarity). "
+        "<br>_Note: This POC uses dummy 4D word vectors projected to 2D using PCA. "
+        "In a full version, this would connect to a large pre-trained Word2Vec model!_"
     ),
     allow_flagging="never",  # Optional: disables the "Flag" button
     examples=[