File size: 2,521 Bytes
5ee0a10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import pandas as pd
import faiss
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st

class SemanticSearch:
    """Semantic search over a directory of pre-built FAISS index shards.

    A query is embedded with the ``all-MiniLM-L6-v2`` SentenceTransformer,
    searched against every ``*.index`` file in ``shard_dir``, and the
    per-shard hits are merged into one ranked ``pandas.DataFrame``.
    """

    def __init__(self, shard_dir="compressed_shards"):
        """Prepare an (uninitialized) searcher rooted at ``shard_dir``.

        The directory is created if it does not exist. No model or shards are
        loaded here; call :meth:`initialize_system`, or let :meth:`search` do
        it on first use.
        """
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None        # SentenceTransformer, set by initialize_system()
        self.index_shards = []   # loaded FAISS indexes, in sorted filename order

    @st.cache_resource
    def load_model(_self):
        """Return the SentenceTransformer model, cached by Streamlit.

        Streamlit does not hash parameters whose names start with an
        underscore, so ``_self`` is excluded from the cache key and every
        instance shares the same cached model.
        """
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        """Load the embedding model and (re)load all index shards.

        Safe to call repeatedly: shards are replaced, not appended, so a
        second call does not duplicate them.
        """
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load every ``*.index`` FAISS file in ``shard_dir``, replacing any already loaded.

        Files are read in sorted path order; a shard's position in
        ``index_shards`` is the ``shard_idx`` later passed to
        :meth:`_calculate_global_index`.
        """
        # Build the new list first so a read failure leaves the old shards intact.
        self.index_shards = [
            faiss.read_index(str(shard_path))
            for shard_path in sorted(self.shard_dir.glob("*.index"))
        ]

    def search(self, query, top_k=5):
        """Return the ``top_k`` best hits for ``query`` across all shards.

        Each shard is searched for its own ``top_k`` neighbours and the union
        is re-ranked, so (for exact indexes) the result is the overall top
        ``top_k``.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of rows to return.

        Returns:
            DataFrame with columns ``global_index`` and ``similarity``, sorted
            by descending similarity and indexed 0..n-1. Empty when
            ``top_k < 1`` or no shard produced a hit.
        """
        if top_k < 1:
            return self._process_results(np.array([]), np.array([]), 0)
        if self.model is None:
            # Lazily initialize instead of failing on ``None.encode``.
            self.initialize_system()

        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_scores = []
        all_indices = []

        for shard_idx, index in enumerate(self.index_shards):
            distances, labels = index.search(query_embedding, top_k)
            for distance, local_idx in zip(distances[0], labels[0]):
                # FAISS pads with label -1 when a shard cannot supply top_k
                # neighbours; those slots are not real hits.
                if local_idx < 0:
                    continue
                all_scores.append(distance)
                all_indices.append(self._calculate_global_index(shard_idx, local_idx))

        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)

    def _calculate_global_index(self, shard_idx, local_idx):
        """Map a shard-local FAISS label to the identifier reported to callers.

        Currently returns ``local_idx`` unchanged, which is only correct if
        the labels stored in the shards are already globally unique (e.g.
        presumably assigned with explicit ids when the shards were built —
        confirm against the indexing pipeline). Otherwise hits from different
        shards with the same local label are indistinguishable.
        """
        # Implement your specific shard indexing logic here
        # Example: return f"{shard_idx}-{local_idx}"
        return local_idx  # Simple version if using unique IDs

    def _process_results(self, distances, indices, top_k):
        """Rank merged hits and return the best ``top_k`` as a DataFrame.

        Similarity is computed as ``1 - d / 2``, which equals cosine
        similarity when ``d`` is a squared L2 distance between unit-length
        vectors. NOTE(review): this assumes L2 shards built from normalized
        embeddings; for an inner-product index the FAISS score already is the
        similarity and this conversion would be wrong — confirm.
        """
        results = pd.DataFrame({
            'global_index': indices,
            'similarity': 1 - (distances / 2),  # L2 to cosine approximation
        })
        # Stable sort keeps equal scores in shard order, so ranking is deterministic.
        results = results.sort_values('similarity', ascending=False, kind='stable')
        return results.head(top_k).reset_index(drop=True)

    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Return up to ``top_k`` hits whose similarity is strictly above the threshold.

        Runs :meth:`search` with ``2 * top_k`` candidates, keeps rows with
        ``similarity > similarity_threshold``, and returns at most ``top_k``
        of them with a fresh 0-based index. Fewer rows (possibly none) are
        returned when not enough hits clear the threshold.
        """
        results = self.search(query, top_k * 2)
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)