Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -124,14 +124,14 @@ class FastDatasetSearcher:
|
|
| 124 |
return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
|
| 125 |
|
| 126 |
def quick_search(self, query, df):
|
| 127 |
-
"""Enhanced search with
|
| 128 |
if df.empty or not query.strip():
|
| 129 |
return df
|
| 130 |
|
| 131 |
try:
|
| 132 |
-
# Define
|
| 133 |
-
|
| 134 |
-
|
| 135 |
|
| 136 |
# Get searchable columns
|
| 137 |
searchable_cols = []
|
|
@@ -150,34 +150,55 @@ class FastDatasetSearcher:
|
|
| 150 |
for _, row in df.iterrows():
|
| 151 |
text_parts = []
|
| 152 |
row_matched = False
|
|
|
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
val = row[col]
|
| 157 |
if val is not None:
|
| 158 |
val_str = str(val).lower()
|
| 159 |
-
if
|
|
|
|
|
|
|
| 160 |
row_matched = True
|
| 161 |
text_parts.append(str(val))
|
| 162 |
|
| 163 |
text = ' '.join(text_parts)
|
| 164 |
|
| 165 |
if text.strip():
|
| 166 |
-
# Calculate
|
| 167 |
-
|
| 168 |
-
matching_terms = query_terms.intersection(
|
| 169 |
keyword_score = len(matching_terms) / len(query_terms)
|
| 170 |
|
| 171 |
# Calculate semantic score
|
| 172 |
text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
|
| 173 |
semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
|
| 174 |
|
| 175 |
-
# Weighted
|
| 176 |
-
combined_score = 0.
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
| 181 |
else:
|
| 182 |
combined_score = 0.0
|
| 183 |
row_matched = False
|
|
@@ -460,6 +481,7 @@ def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_au
|
|
| 460 |
st.audio(audio_file_full)
|
| 461 |
|
| 462 |
def render_result(result):
|
|
|
|
| 463 |
score = result.get('relevance_score', 0)
|
| 464 |
result_filtered = {k: v for k, v in result.items()
|
| 465 |
if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}
|
|
@@ -469,12 +491,36 @@ def render_result(result):
|
|
| 469 |
|
| 470 |
cols = st.columns([2, 1])
|
| 471 |
with cols[0]:
|
|
|
|
| 472 |
for key, value in result_filtered.items():
|
| 473 |
if isinstance(value, (str, int, float)):
|
| 474 |
st.write(f"**{key}:** {value}")
|
|
|
|
|
|
|
| 475 |
|
| 476 |
with cols[1]:
|
| 477 |
st.metric("Relevance Score", f"{score:.2%}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
def main():
|
| 480 |
st.title("🎥 Advanced Video & Dataset Search with Voice")
|
|
|
|
| 124 |
return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
|
| 125 |
|
| 126 |
def quick_search(self, query, df):
|
| 127 |
+
"""Enhanced search with strict token matching and semantic relevance"""
|
| 128 |
if df.empty or not query.strip():
|
| 129 |
return df
|
| 130 |
|
| 131 |
try:
|
| 132 |
+
# Define stricter thresholds
|
| 133 |
+
MIN_SEMANTIC_SCORE = 0.5 # Higher semantic threshold
|
| 134 |
+
EXACT_MATCH_BOOST = 2.0 # Boost for exact matches
|
| 135 |
|
| 136 |
# Get searchable columns
|
| 137 |
searchable_cols = []
|
|
|
|
| 150 |
for _, row in df.iterrows():
|
| 151 |
text_parts = []
|
| 152 |
row_matched = False
|
| 153 |
+
exact_match = False
|
| 154 |
|
| 155 |
+
# Prioritize description and matched_text fields
|
| 156 |
+
priority_fields = ['description', 'matched_text']
|
| 157 |
+
other_fields = [col for col in searchable_cols if col not in priority_fields]
|
| 158 |
+
|
| 159 |
+
# First check priority fields for exact matches
|
| 160 |
+
for col in priority_fields:
|
| 161 |
+
if col in row:
|
| 162 |
+
val = row[col]
|
| 163 |
+
if val is not None:
|
| 164 |
+
val_str = str(val).lower()
|
| 165 |
+
# Check for exact token matches
|
| 166 |
+
if query_lower in val_str.split():
|
| 167 |
+
exact_match = True
|
| 168 |
+
if any(term in val_str.split() for term in query_terms):
|
| 169 |
+
row_matched = True
|
| 170 |
+
text_parts.append(str(val))
|
| 171 |
+
|
| 172 |
+
# Then check other fields
|
| 173 |
+
for col in other_fields:
|
| 174 |
val = row[col]
|
| 175 |
if val is not None:
|
| 176 |
val_str = str(val).lower()
|
| 177 |
+
if query_lower in val_str.split():
|
| 178 |
+
exact_match = True
|
| 179 |
+
if any(term in val_str.split() for term in query_terms):
|
| 180 |
row_matched = True
|
| 181 |
text_parts.append(str(val))
|
| 182 |
|
| 183 |
text = ' '.join(text_parts)
|
| 184 |
|
| 185 |
if text.strip():
|
| 186 |
+
# Calculate exact token matches
|
| 187 |
+
text_tokens = set(text.lower().split())
|
| 188 |
+
matching_terms = query_terms.intersection(text_tokens)
|
| 189 |
keyword_score = len(matching_terms) / len(query_terms)
|
| 190 |
|
| 191 |
# Calculate semantic score
|
| 192 |
text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
|
| 193 |
semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
|
| 194 |
|
| 195 |
+
# Weighted scoring with priority for exact matches
|
| 196 |
+
combined_score = 0.8 * keyword_score + 0.2 * semantic_score
|
| 197 |
|
| 198 |
+
if exact_match:
|
| 199 |
+
combined_score *= EXACT_MATCH_BOOST
|
| 200 |
+
elif row_matched:
|
| 201 |
+
combined_score *= 1.2
|
| 202 |
else:
|
| 203 |
combined_score = 0.0
|
| 204 |
row_matched = False
|
|
|
|
| 481 |
st.audio(audio_file_full)
|
| 482 |
|
| 483 |
def render_result(result):
|
| 484 |
+
"""Render a search result with voice selection and TTS options"""
|
| 485 |
score = result.get('relevance_score', 0)
|
| 486 |
result_filtered = {k: v for k, v in result.items()
|
| 487 |
if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}
|
|
|
|
| 491 |
|
| 492 |
cols = st.columns([2, 1])
|
| 493 |
with cols[0]:
|
| 494 |
+
text_content = [] # Collect text for TTS
|
| 495 |
for key, value in result_filtered.items():
|
| 496 |
if isinstance(value, (str, int, float)):
|
| 497 |
st.write(f"**{key}:** {value}")
|
| 498 |
+
if isinstance(value, str) and len(value.strip()) > 0:
|
| 499 |
+
text_content.append(f"{key}: {value}")
|
| 500 |
|
| 501 |
with cols[1]:
|
| 502 |
st.metric("Relevance Score", f"{score:.2%}")
|
| 503 |
+
|
| 504 |
+
# Voice selection for TTS
|
| 505 |
+
voices = {
|
| 506 |
+
"Aria (US Female)": "en-US-AriaNeural",
|
| 507 |
+
"Guy (US Male)": "en-US-GuyNeural",
|
| 508 |
+
"Sonia (UK Female)": "en-GB-SoniaNeural",
|
| 509 |
+
"Tony (UK Male)": "en-GB-TonyNeural",
|
| 510 |
+
"Jenny (US Female)": "en-US-JennyNeural"
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
selected_voice = st.selectbox(
|
| 514 |
+
"Select Voice",
|
| 515 |
+
list(voices.keys()),
|
| 516 |
+
key=f"voice_{result.get('video_id', '')}"
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
if st.button("🔊 Read Description", key=f"read_{result.get('video_id', '')}"):
|
| 520 |
+
text_to_read = ". ".join(text_content)
|
| 521 |
+
audio_file = asyncio.run(generate_speech(text_to_read, voices[selected_voice]))
|
| 522 |
+
if audio_file:
|
| 523 |
+
st.audio(audio_file)
|
| 524 |
|
| 525 |
def main():
|
| 526 |
st.title("🎥 Advanced Video & Dataset Search with Voice")
|