Spaces:
Paused
Paused
| from sentence_transformers import CrossEncoder | |
| import json | |
| import math | |
| import numpy as np | |
| from middlewares.search_client import SearchClient | |
| import os | |
| from dotenv import load_dotenv | |
# Run python-dotenv's loader before the environment lookups below.
load_dotenv()

# Search-provider credentials; each lookup yields None when the variable is unset.
GOOGLE_SEARCH_ENGINE_ID = os.environ.get("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.environ.get("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.environ.get("BING_SEARCH_API_KEY")

# Cross-encoder used by rerank() to score [query, chunk] pairs.
# It is instantiated once, when this module is imported.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One client per supported vendor; the Bing client is built without an engine id.
googleSearchClient = SearchClient(
    "google",
    api_key=GOOGLE_SEARCH_API_KEY,
    engine_id=GOOGLE_SEARCH_ENGINE_ID,
)
bingSearchClient = SearchClient(
    "bing",
    api_key=BING_SEARCH_API_KEY,
    engine_id=None,
)
def rerank(query, top_k, search_results, chunk_size=512):
    """Split search results into word chunks and return the best-scoring ones.

    Each result's ``"text"`` is split on whitespace and regrouped into chunks
    of at most ``chunk_size`` words. Every chunk is paired with ``query`` and
    scored by the module-level ``reranker``; chunks are then ordered from the
    highest score to the lowest.

    Args:
        query: The query string each chunk is scored against.
        top_k: Maximum number of chunks to return.
        search_results: Iterable of dicts with at least ``"text"`` and
            ``"link"`` keys.
        chunk_size: Maximum number of whitespace-separated words per chunk.

    Returns:
        A list of ``{"link": ..., "text": ...}`` dicts, best score first,
        truncated to ``top_k`` entries. An empty list is returned when the
        results produce no chunks (no results, or only empty texts).
    """
    chunks = []
    for result in search_results:
        words = result["text"].split()
        num_chunks = math.ceil(len(words) / chunk_size)
        link = result["link"]
        chunks.extend(
            (link, " ".join(words[i * chunk_size:(i + 1) * chunk_size]))
            for i in range(num_chunks)
        )

    # Nothing to score: skip the model call instead of handing it an empty
    # batch, which the cross-encoder may reject.
    if not chunks:
        return []

    # Pair every chunk with the query and score the pairs in one batch.
    sentence_combinations = [[query, text] for _, text in chunks]
    similarity_scores = reranker.predict(sentence_combinations)

    # argsort is ascending; reverse it for best-first, then keep the top K.
    best_first = np.argsort(similarity_scores)[::-1][:top_k]
    return [{"link": chunks[idx][0], "text": chunks[idx][1]} for idx in best_first]
def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Build a prompt augmented with reranked web-search context.

    Runs a web search for ``prompt`` with the selected vendor, reranks the
    returned text with ``rerank`` and embeds the best chunks in a prompt
    template together with the caller-supplied wrapper strings.

    Args:
        prompt: The user prompt; also used as the search and rerank query.
        search_vendor: ``"Google"`` or ``"Bing"``. Any other value skips the
            search, so the context is empty.
        n_crawl: Forwarded as the second argument to the client's ``search``
            call (presumably the number of results to crawl - confirm against
            ``SearchClient``).
        top_k: Maximum number of reranked chunks to include as context.
        pre_context: Text placed before the context.
        post_context: Text placed after the context.
        pre_prompt: Text placed before the prompt.
        post_prompt: Text placed after the prompt.
        pass_prev: When true, ``prev_output`` is appended to the template.
        prev_output: Previous output; ignored unless ``pass_prev`` is true.
        chunk_size: Words per chunk, forwarded to ``rerank``.

    Returns:
        A ``(augmented_prompt, links)`` tuple. ``links`` holds the distinct
        source links of the chunks used, in rerank order (best chunk first).
    """
    reranked_results = []
    try:
        search_results = []
        if search_vendor == "Google":
            search_results = googleSearchClient.search(prompt, n_crawl)
        elif search_vendor == "Bing":
            print('[Bing search enabled]')
            search_results = bingSearchClient.search(prompt, n_crawl)
            print(search_results)
            print('[Bing search completed]')
        if search_results:
            reranked_results = rerank(prompt, top_k, search_results, chunk_size)
    except Exception as e:
        # Best effort: a failed search or rerank is reported and the prompt
        # is still produced, just without web context.
        print(e)

    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # De-duplicate while keeping rerank order; a set would return the links
    # in arbitrary order.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    prev_output = prev_output if pass_prev else ""

    # NOTE(review): the template's internal whitespace is reproduced as it
    # appears in the reviewed copy, whose indentation had been stripped -
    # confirm against the upstream file.
    augmented_prompt = f"""
{pre_context}
{context}
{post_context}
{pre_prompt}
{prompt}
{post_prompt}
{prev_output}
"""
    print(augmented_prompt)
    return augmented_prompt, links