import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer
import streamlit as st

class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()  # no Streamlit UI calls during construction
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Handle ZIP extraction without Streamlit elements"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")
                
    def _build_shard_map(self):
        """Create index range to shard mapping"""
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)

    def get_metadata(self, global_indices):
        """Retrieve metadata with validation"""
        # numpy arrays are ambiguous in boolean context, so check .size explicitly
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
        
        # Convert numpy array to list for processing
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
        
        # Filter valid indices
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        if not valid_indices:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
    
        # Group indices by shard with boundary check
        shard_groups = {}
        for idx in valid_indices:
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    found = True
                    break
            if not found:
                st.warning(f"Index {idx} out of shard range (0-{self.total_docs-1})")
    
        # Load and process shards
        results = []
        for shard, local_indices in shard_groups.items():
            try:
                if shard not in self.loaded_shards:
                    self.loaded_shards[shard] = pd.read_parquet(
                        self.shard_dir / shard,
                        columns=["title", "summary", "source"]
                    )
                
                if local_indices:
                    results.append(self.loaded_shards[shard].iloc[local_indices])
            except Exception as e:
                st.error(f"Error loading shard {shard}: {str(e)}")
                continue
    
        return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()

class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()  # No Streamlit elements in constructor
        self.shard_sizes = []
        
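    # The leading underscore on `_self` tells st.cache_resource to skip
    # hashing the instance when building the cache key.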
    @st.cache_resource
    def load_model(_self):
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards"""
        self.index_shards = []
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert local index to global index"""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Search with validation"""
        if not query or not self.index_shards:
            return pd.DataFrame()
        
        try:
            query_embedding = self.model.encode([query], convert_to_numpy=True)
        except Exception as e:
            st.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()
    
        all_distances = []
        all_global_indices = []
    
        # Search with index validation
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                continue
                
            try:
                distances, indices = index.search(query_embedding, top_k)
                # Keep each distance paired with its own index; FAISS pads
                # missing results with -1, which the range check drops.
                for dist, idx in zip(distances[0], indices[0]):
                    if 0 <= idx < index.ntotal:
                        all_distances.append(dist)
                        all_global_indices.append(self._global_index(shard_idx, idx))
            except Exception as e:
                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue
    
        # Distances and indices are appended in lockstep above, so the two
        # arrays are always the same length.
        return self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into formatted DataFrame"""
        # Guard against empty inputs (.size, not truthiness, for numpy arrays)
        if global_indices.size == 0 or distances.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
        
        try:
            # Convert numpy indices to Python list for metadata retrieval
            indices_list = global_indices.tolist()
            
            # Get metadata for matched indices
            results = self.metadata_mgr.get_metadata(indices_list)
            
            # Ensure distances match results length
            if len(results) != len(distances):
                distances = distances[:len(results)]
                
            # Calculate similarity scores
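            # FAISS L2 indexes return *squared* distances; for unit-normalized
            # embeddings d^2 = 2 - 2*cos, so cos = 1 - d^2/2. This assumes the
            # indexed vectors and the query embedding are both normalized.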
            results['similarity'] = 1 - (distances / 2)
            
            # Deduplicate and sort results
            results = (
                results.drop_duplicates(subset=["title", "source"])
                .sort_values("similarity", ascending=False)
                .head(top_k)
            )
            
            return results.reset_index(drop=True)
            
        except Exception as e:
            st.error(f"Result processing failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
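
# --- Usage sketch ---
# A minimal sketch of wiring the classes above into a Streamlit page. The
# widget labels and layout are illustrative assumptions, not part of the
# original module.
@st.cache_resource
def get_searcher():
    # Cache one SemanticSearch instance per session so the FAISS shards are
    # not re-read from disk on every Streamlit rerun.
    searcher = SemanticSearch()
    searcher.initialize_system()
    return searcher

def main():
    st.title("Semantic Search")
    query = st.text_input("Search query")
    if query:
        results = get_searcher().search(query, top_k=5)
        if results.empty:
            st.info("No matching documents found.")
        else:
            st.dataframe(results)

if __name__ == "__main__":
    main()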