# Captured from a Hugging Face Space page (Space status at capture time: Sleeping)
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import os | |
from datetime import datetime | |
from datasets import load_dataset | |
# Per-session defaults. Streamlit re-executes this script on every interaction,
# so a key is only seeded when the session does not already hold a value for it.
def _seed_session_state():
    """Insert each missing session-state key with its starting value."""
    defaults = {
        'search_history': [],
        'search_columns': [],
        'dataset_loaded': False,
        'current_page': 0,
        'data_cache': None,
    }
    for name, initial in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = initial


_seed_session_state()

# Number of dataset rows loaded per page.
ROWS_PER_PAGE = 100
@st.cache_resource
def get_model():
    """Return the ``all-MiniLM-L6-v2`` SentenceTransformer used for semantic search.

    ``main`` constructs a new ``FastDatasetSearcher`` on every Streamlit rerun,
    and each construction calls this function. Without caching, the model was
    reloaded on every widget interaction. ``st.cache_resource`` makes Streamlit
    hand back one shared instance instead of rebuilding it.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
class FastDatasetSearcher:
    """Load pages of a Hugging Face dataset and rank their rows against a query.

    Rows are scored by a 50/50 blend of keyword frequency and sentence-embedding
    cosine similarity (see ``quick_search``).
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Create a searcher for ``dataset_id``.

        The Hugging Face token is read from the ``DATASET_KEY`` environment
        variable. If it is unset, an error is shown and the script run is
        halted with ``st.stop()``.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset_info()

    def load_dataset_info(self):
        """Fetch metadata of the ``train`` split via a streaming load.

        Stores it on ``self.dataset_info`` and returns True. On failure, shows
        the error, leaves ``dataset_info`` as None and returns False.
        """
        self.dataset_info = None
        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=True
            )
            self.dataset_info = dataset['train'].info
            return True
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            return False

    def load_page(self, page=0):
        """Return rows ``[page*ROWS_PER_PAGE, (page+1)*ROWS_PER_PAGE)`` of ``train``.

        The most recently loaded page is cached in session state under
        ``'data_cache'``. The page number it holds is tracked separately under
        ``'data_cache_page'``. Keying the cache on ``'current_page'`` was wrong:
        the Previous/Next buttons change ``'current_page'`` before the new page
        is loaded, so the stale page was returned.

        Also records ``page`` as ``'current_page'`` after a successful load.
        Returns an empty DataFrame, and shows an error, if loading fails.
        """
        cached = st.session_state.get('data_cache')
        if cached is not None and st.session_state.get('data_cache_page') == page:
            return cached
        start = page * ROWS_PER_PAGE
        end = start + ROWS_PER_PAGE
        try:
            # NOTE(review): with streaming=False the split is presumably
            # downloaded/prepared locally before slicing -- confirm cost.
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=False,
                split=f'train[{start}:{end}]'
            )
            df = pd.DataFrame(dataset)
        except Exception as e:
            st.error(f"Error loading page {page}: {str(e)}")
            return pd.DataFrame()
        st.session_state['data_cache'] = df
        st.session_state['data_cache_page'] = page
        st.session_state['current_page'] = page
        return df

    def quick_search(self, query, df):
        """Score every row of ``df`` against ``query`` and sort best-first.

        For each row, its str/int/float values are joined with spaces into one
        text. The score is::

            0.5 * cosine(embed(query), embed(text))
          + 0.5 * text.lower().count(query.lower()) / max(word_count(text), 1)

        Returns a NEW DataFrame with a ``'score'`` column, sorted descending.
        ``df`` is not modified. Callers pass the page cached in session state,
        and writing ``'score'`` into it would leak the column into the raw-data
        view and into the text of later searches.
        """
        if df.empty:
            return df.assign(score=np.zeros(0))

        # Series.values is a property, not a method; iterate plain value tuples.
        texts = [
            ' '.join(str(v) for v in values if isinstance(v, (str, int, float)))
            for values in df.itertuples(index=False, name=None)
        ]

        needle = query.lower()
        # max(..., 1) guards rows whose text has no words (empty string).
        keyword_scores = np.array(
            [text.lower().count(needle) / max(len(text.split()), 1) for text in texts],
            dtype=float,
        )

        # Encode all rows in one batch instead of one model call per row.
        query_embedding = self.text_model.encode([query], show_progress_bar=False)
        text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
        semantic_scores = cosine_similarity(query_embedding, text_embeddings)[0]

        scores = 0.5 * semantic_scores + 0.5 * keyword_scores
        return df.assign(score=scores).sort_values('score', ascending=False)
def _change_page(delta):
    """Shift ``current_page`` by ``delta``, drop the cached page, and rerun.

    Clearing ``data_cache`` forces the next run to load the newly selected
    page. Otherwise the page cache could be matched against the updated
    ``current_page`` and serve the previous page's rows again.
    """
    st.session_state['current_page'] += delta
    st.session_state['data_cache'] = None
    st.rerun()


def main():
    """Streamlit entry point: load one page of the dataset and search within it.

    Each run builds a ``FastDatasetSearcher`` and loads the page chosen in the
    "Page" input, which defaults to ``st.session_state['current_page']``. It then
    renders ranked results for a non-empty query, shows the raw page, and
    offers Previous/Next navigation buttons.
    """
    st.title("🎥 Fast Video Dataset Search")
    # Initialize search class
    searcher = FastDatasetSearcher()
    # Page navigation
    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
    # Load current page
    with st.spinner(f"Loading page {page}..."):
        df = searcher.load_page(page)
    if df.empty:
        st.warning("No data available for this page.")
        return
    # Search interface
    query = st.text_input("Search in current page:", help="Searches within currently loaded data")
    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)
        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.iterrows(), 1):
            score = result.pop('score')
            with st.expander(f"Result {i} (Score: {score:.2%})", expanded=i == 1):
                # Display video if available
                if 'youtube_id' in result:
                    st.video(
                        f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
                    )
                # Display other scalar fields
                for key, value in result.items():
                    if isinstance(value, (str, int, float)):
                        st.write(f"**{key}:** {value}")
    # Show raw data
    st.subheader("Raw Data")
    st.dataframe(df)
    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("Previous Page") and page > 0:
            _change_page(-1)
    with cols[1]:
        if st.button("Next Page"):
            _change_page(1)


if __name__ == "__main__":
    main()