File size: 4,157 Bytes
5ee0a10
 
 
7ccde22
 
017ee94
7ccde22
 
 
 
 
 
 
017ee94
7ccde22
 
 
 
017ee94
7ccde22
 
 
 
 
017ee94
7ccde22
017ee94
 
7ccde22
 
017ee94
 
 
7ccde22
017ee94
 
 
 
7ccde22
017ee94
 
7ccde22
 
 
017ee94
 
7ccde22
 
 
 
 
 
 
017ee94
7ccde22
 
 
 
 
 
 
 
 
 
5ee0a10
017ee94
 
5ee0a10
 
7ccde22
017ee94
 
5ee0a10
 
 
017ee94
5ee0a10
 
017ee94
5ee0a10
017ee94
 
 
5ee0a10
017ee94
 
 
 
 
 
 
5ee0a10
 
017ee94
5ee0a10
017ee94
 
 
 
5ee0a10
 
017ee94
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st

class MetadataManager:
    """Lazy, shard-aware accessor for document metadata stored as parquet files.

    Shards live in ``metadata_shards/`` and are located through ``shard_map``,
    which maps an inclusive ``(start, end)`` global-index range to a shard
    file name. Shards are read on first use and kept in ``loaded_shards``.
    """

    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}      # (start, end) inclusive global range -> file name
        self.loaded_shards = {}  # file name -> DataFrame, filled on demand
        self.total_docs = 0
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract ``metadata_shards.zip`` when the shard directory is absent.

        Raises:
            FileNotFoundError: neither the directory nor the ZIP file exists.
        """
        if self.shard_dir.exists():
            return

        zip_path = Path("metadata_shards.zip")
        if not zip_path.exists():
            st.error("❌ Missing metadata_shards.zip file!")
            raise FileNotFoundError("Metadata ZIP file not found")

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.shard_dir)
        # The message and icon were mis-encoded (mojibake); st.toast only
        # accepts a single emoji as ``icon``, so the garbled text was rejected.
        st.toast("📦 Metadata shards extracted successfully!", icon="✅")

    def _build_shard_map(self):
        """Rebuild the index-range -> shard-file mapping from the shard directory.

        The second and third ``_``-separated fields of each parquet file stem
        are parsed as the shard's inclusive start and end global indices.
        """
        # Reset both pieces of state so a rebuild never keeps stale entries.
        self.shard_map = {}
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)

    def get_metadata(self, global_indices):
        """Return metadata rows for ``global_indices``, in the order given.

        Indices that fall inside no known shard range are skipped, so the
        result may be shorter than the input. Rows keep the relative order of
        the indices that were found, which lets callers attach per-hit values
        (e.g. scores) positionally.

        Args:
            global_indices: iterable of integer document indices.

        Returns:
            DataFrame with a fresh 0..n-1 index. When nothing matches, an
            empty frame with the ``title``/``summary``/``source`` columns.
        """
        columns = ["title", "summary", "source"]

        # shard file name -> [(output position, row offset inside the shard)]
        shard_groups = {}
        position = 0
        for idx in global_indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(
                        (position, int(idx - start))
                    )
                    position += 1
                    break

        if not shard_groups:
            # pd.concat([]) raises ValueError; return an empty frame instead.
            return pd.DataFrame(columns=columns)

        frames = []
        for shard, hits in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=columns
                )
            rows = self.loaded_shards[shard].iloc[[offset for _, offset in hits]]
            # Label each row with its output position so that the per-shard
            # grouping can be undone after concatenation.
            rows.index = [pos for pos, _ in hits]
            frames.append(rows)

        return pd.concat(frames).sort_index().reset_index(drop=True)

class SemanticSearch:
    """Semantic search over sharded FAISS indexes with parquet-backed metadata.

    Each ``*.index`` file in ``compressed_shards/`` is one FAISS shard. A
    shard-local hit is converted to a global document index by offsetting it
    with the sizes of the preceding shards, and metadata for the hit is
    fetched through ``MetadataManager``.
    """

    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None          # set by initialize_system()
        self.index_shards = []     # FAISS indexes, in sorted file order
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []      # ntotal per shard, parallel to index_shards

    @st.cache_resource
    def load_model(_self):
        """Load the sentence-embedding model (cached via ``st.cache_resource``).

        The leading underscore on ``_self`` follows Streamlit's convention for
        arguments that are excluded from cache-key hashing.
        """
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        """Load the embedding model and all FAISS shards."""
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards, replacing any previously loaded ones."""
        # Build both lists from scratch: appending to the old index list while
        # resetting only the sizes would duplicate shards on a second call
        # and desynchronise the two lists.
        index_shards = []
        shard_sizes = []
        # NOTE(review): the sort is lexicographic; shard numbers are presumably
        # zero-padded so that file order equals global-index order — confirm.
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            index_shards.append(index)
            shard_sizes.append(index.ntotal)
        self.index_shards = index_shards
        self.shard_sizes = shard_sizes

    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global index."""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Return the ``top_k`` most similar documents for ``query``.

        Args:
            query: text to embed and search for.
            top_k: number of hits requested per shard and returned overall.

        Returns:
            DataFrame of metadata rows plus a ``similarity`` column, sorted by
            descending similarity; empty when no shard produced a hit.

        Raises:
            ValueError: some hits have no metadata row, so scores cannot be
                aligned with documents.
        """
        if self.model is None:
            # Initialise lazily rather than failing on ``None.encode``.
            self.initialize_system()

        query_embedding = self.model.encode([query], convert_to_numpy=True)

        hits = []  # (global index, distance)
        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            for local_idx, distance in zip(indices[0], distances[0]):
                # FAISS pads with -1 when a shard holds fewer than top_k
                # vectors; offsetting -1 would point at the wrong document.
                if local_idx < 0:
                    continue
                hits.append((int(self._global_index(shard_idx, local_idx)), distance))

        if not hits:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Ascending global order keeps rows and distances aligned even if the
        # metadata lookup groups its output by shard.
        hits.sort(key=lambda hit: hit[0])
        results = self.metadata_mgr.get_metadata([gidx for gidx, _ in hits])
        if len(results) != len(hits):
            raise ValueError(
                f"Metadata lookup returned {len(results)} rows for {len(hits)} hits"
            )

        # Assumes squared-L2 distances over unit-normalised embeddings, where
        # cosine similarity = 1 - d / 2 — TODO confirm the index type.
        all_distances = np.array([distance for _, distance in hits])
        results['similarity'] = 1 - (all_distances / 2)  # Convert L2 to cosine
        return results.sort_values('similarity', ascending=False).head(top_k)