File size: 10,214 Bytes
c899329
bd5a98c
 
 
c899329
 
 
 
8853856
c899329
8853856
c899329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d091eda
f7f0991
c899329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f0991
c899329
f7f0991
c899329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f0991
 
c899329
 
f7f0991
c899329
 
f7f0991
 
c899329
f7f0991
 
 
 
c899329
f7f0991
 
 
c899329
f7f0991
 
 
c899329
 
f7f0991
c899329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a7a05f
 
 
c899329
4a7a05f
 
 
 
c899329
4a7a05f
 
 
 
 
 
c899329
4a7a05f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import os
# # import sys
# # src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "backend"))
# # sys.path.append(src_directory)
from pinecone import Pinecone, ServerlessSpec
import time
from tqdm import tqdm
from dotenv import load_dotenv
from backend.utils import logger
import pandas as pd
from backend.services.embedding_service import get_text_embedding
from sentence_transformers import CrossEncoder

# Cross-encoder used to rerank retrieved question/answer pairs by relevance.
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Load environment variables (expects PINECONE_API_KEY in .env or the environment).
load_dotenv()
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
# Rebind the imported module to the configured logger instance for this module.
logger = logger.get_logger()
# Namespace and index name used for all Pinecone operations in this module.
NAMESPACE = "health-care-dataset"
INDEX_NAME = "health-care-index"
# Shared Pinecone client instance.
PINECONE = Pinecone(api_key=PINECONE_API_KEY)

def rerank_results(query, results, score_threshold=0.5):
    """
    Reranks Pinecone matches against a query using the cross-encoder model.

    Scores each (query, question) pair with the module-level ``reranker``,
    drops results below ``score_threshold``, and returns the survivors sorted
    by their reranker score in descending order.

    Args:
        query (str): The user query to rank against.
        results (list[dict]): Pinecone matches; each must have
            ``metadata["question"]``.
        score_threshold (float, optional): Minimum reranker score to keep a
            result. Default is 0.5.

    Returns:
        list[dict]: Filtered results, most relevant first.
    """
    pairs = [(query, result["metadata"]["question"]) for result in results]
    scores = reranker.predict(pairs)

    # Keep (score, result) pairs so the final ordering uses the reranker
    # score. The previous version sorted by the original vector-similarity
    # score, throwing away the reranker's ranking entirely.
    scored_results = [
        (float(score), result)
        for score, result in zip(scores, results)
        if score >= score_threshold
    ]
    scored_results.sort(key=lambda pair: pair[0], reverse=True)
    return [result for _, result in scored_results]

def initialize_pinecone_index(pinecone, index_name, dimension=384, metric="cosine", cloud="aws", region="us-east-1"):
    """
    Retrieves an existing Pinecone index or creates a new one if it does not exist.

    This method checks for the presence of the specified index. If the index does not exist,
    it initiates the creation process, waits until the index is ready, and then returns the index.

    Args:
        pinecone (Pinecone): Pinecone client instance.
        index_name (str): Name of the index to retrieve or create.
        dimension (int, optional): Vector dimension for the index. Default is 384.
        metric (str, optional): Distance metric for the index. Default is "cosine".
        cloud (str, optional): Cloud provider for hosting the index. Default is "aws".
        region (str, optional): Region where the index will be hosted. Default is "us-east-1".

    Returns:
        pinecone.Index: The Pinecone index instance, or None if an error occurred.

    Raises:
        Exception: Caught internally; errors are logged and None is returned.

    Example:
        >>> index = initialize_pinecone_index(PINECONE, "sample_index")
        Logs: "Index 'sample_index' is ready and accessible."
    """
    try:
        logger.info(f"Checking if the index '{index_name}' exists...")

        # Fast path: the index already exists, just return a handle to it.
        if pinecone.has_index(index_name):
            logger.info(f"Index '{index_name}' already exists. Returning the existing index.")
            return pinecone.Index(index_name)

        logger.info(f"Index '{index_name}' does not exist. Creating a new index...")
        pinecone.create_index(
            name=index_name,
            dimension=dimension,
            metric=metric,
            spec=ServerlessSpec(cloud=cloud, region=region)
        )
        logger.info(f"Index '{index_name}' creation initiated. Waiting for it to be ready...")

        # Poll until the control plane reports the index as ready.
        while not pinecone.describe_index(index_name).status.get("ready", False):
            logger.debug(f"Index '{index_name}' is not ready yet. Checking again in 1 second.")
            time.sleep(1)

        index = pinecone.Index(index_name)
        logger.info(f"Index '{index_name}' is ready and accessible.")
        return index

    except Exception as e:
        logger.error(f"Error occurred while getting or creating the Pinecone index: {str(e)}", exc_info=True)
        return None
    
index = initialize_pinecone_index(PINECONE, INDEX_NAME)
    
def delete_records_by_ids(ids_to_delete):
    """
    Deletes the specified IDs from the Pinecone index.

    Deletes entries from the module-level ``index`` within ``NAMESPACE``.
    Logs a success message on completion; on failure the error is logged and
    a descriptive string is returned.

    Args:
        ids_to_delete (list):
            A list of unique identifiers (IDs) to be deleted from the index.

    Returns:
        None: On success (a success message is logged).
        str: A message describing the failure, if an error occurs.

    Example:
        >>> delete_records_by_ids(['id_123', 'id_456'])
        Logs: "IDs deleted successfully."

    Notes:
        - Deletion occurs within the module-level `NAMESPACE`.
    """
    try:
        index.delete(ids=ids_to_delete, namespace=NAMESPACE)
        logger.info("IDs deleted successfully.")
    except Exception as e:
        # Log as well as return the message so failures are visible in the
        # logs even when the caller ignores the return value.
        logger.error(f"Failed to delete the IDs: {e}", exc_info=True)
        return f"Failed to delete the IDs: {e}"
    

def retrieve_relevant_metadata(embedding, prompt, n_result=3, score_threshold=0.47):
    """
    Retrieves relevant context metadata from Pinecone for a given embedding.

    Queries the module-level ``index`` with the embedding, then keeps only the
    matches whose similarity score meets ``score_threshold``.

    Args:
        embedding (list): Query embedding vector.
        prompt (str): The original user prompt (used for error logging).
        n_result (int, optional): Number of top matches to request. Default is 3.
        score_threshold (float, optional): Minimum similarity score to keep a
            match. Default is 0.47.

    Returns:
        list[dict]: One dict per kept match with keys ``question``, ``answer``,
            ``instruction``, ``score`` and ``id`` (score/id as strings).
            Falls back to ``[{"response": ...}]`` when nothing qualifies or an
            error occurs.
    """
    try:
        response = index.query(
            top_k=n_result,
            vector=embedding,
            namespace=NAMESPACE,
            include_metadata=True
        )

        # Extract metadata and filter by score threshold.
        filtered_results = [
            {
                "question": entry.get('metadata', {}).get('question', 'N/A'),
                "answer": entry.get('metadata', {}).get('answer', 'N/A'),
                "instruction": entry.get('metadata', {}).get('instruction', 'N/A'),
                "score": str(entry.get('score', 0)),
                "id": str(entry.get('id', 'N/A'))
            }
            for entry in response.get('matches', [])
            if float(entry.get('score', 0)) >= score_threshold
        ]

        logger.info(f"Retrieved filtered data: {filtered_results}")
        return filtered_results if filtered_results else [{"response": "No relevant data found."}]

    except Exception as e:
        logger.error(f"Failed to fetch context for prompt: '{prompt}'. Error: {e}")
        return [{"response": "Failed to fetch data due to an error."}]


def upsert_vector_data(df: pd.DataFrame):
    """
    Embeds each row of the DataFrame and upserts it into Pinecone in batches.

    Parameters:
    - df (pd.DataFrame): DataFrame with 'input', 'output', and 'instruction'
      columns; an 'embedding' column is added in place.

    Returns:
    - None
    """
    # Generate one embedding per 'input' text; abort the whole upload if
    # embedding generation fails.
    try:
        embeddings = []
        for text in tqdm(df["input"], desc="Generating Embeddings"):
            embeddings.append(get_text_embedding([text])[0])
        df["embedding"] = embeddings
    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return

    BATCH_SIZE = 500

    for start in tqdm(range(0, len(df), BATCH_SIZE), desc="Uploading Data to Pinecone"):
        chunk = df.iloc[start : start + BATCH_SIZE]

        payload = []
        offset = 0
        for _, row in chunk.iterrows():
            text = row.get("input")
            # ID combines a question prefix with the global row position so
            # IDs remain unique across batches.
            record_id = f"{text[:50]}:{start + offset}"
            payload.append((
                record_id,
                row["embedding"],
                {
                    "question": row.get("input"),
                    "answer": row.get("output"),
                    "instruction": row.get("instruction"),
                },
            ))
            offset += 1

        try:
            index.upsert(vectors=payload, namespace=NAMESPACE)
        except Exception as e:
            logger.error(f"Error uploading batch starting at index {start}: {e}")

    logger.info("All question-answer pairs stored successfully!")

def retrieve_context_from_pinecone(embedding, n_result=3, score_threshold=0.4):
    """
    Queries Pinecone with an embedding and returns matching answers as text.

    Args:
    - embedding (list): Embedding vector for query.
    - n_result (int): Number of top results to retrieve.
    - score_threshold (float): Minimum score threshold for relevance.

    Returns:
    - str: Combined context or fallback message.
    """
    # Guard clause: reject anything that is not a non-empty list.
    if not embedding or not isinstance(embedding, list):
        logger.warning("Invalid embedding received.")
        return "No relevant context found."

    try:
        query_response = index.query(
            top_k=n_result,
            vector=embedding,
            namespace=NAMESPACE,
            include_metadata=True
        )

        # Bail out early if the response does not look like a query result.
        if 'matches' not in query_response:
            logger.error("Unexpected response structure from Pinecone.")
            return "No relevant context found."

        # Collect the answers whose similarity clears the threshold,
        # logging each match that gets skipped.
        kept_answers = []
        for match in query_response.get('matches', []):
            match_score = match.get('score', 0)
            match_answer = match.get('metadata', {}).get('answer', 'N/A')

            if match_score < score_threshold:
                logger.info(f"Entry skipped due to low score: {match_score:.2f}")
                continue
            kept_answers.append(f"{match_answer} (Score: {match_score:.2f})")

        if not kept_answers:
            return "No relevant context found."
        return "\n".join(kept_answers)

    except Exception as e:
        logger.error(f"Unexpected error in Pinecone retrieval: {e}", exc_info=True)
        return "Error retrieving context. Please try again later."