import os import base64 import requests import numpy as np import faiss import re import logging from pathlib import Path from dotenv import load_dotenv from sentence_transformers import SentenceTransformer, CrossEncoder from langchain_groq import ChatGroq from langchain_core.prompts import ChatPromptTemplate # Optionally import BM25 for sparse retrieval. try: from rank_bm25 import BM25Okapi except ImportError: BM25Okapi = None # --------------------------- # Environment Setup # --------------------------- load_dotenv() # Set the cross-encoder model from environment or use a default SOTA model. CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2") # Setup a persistent session for GitHub API requests. session = requests.Session() session.headers.update({ "Authorization": f"token {os.getenv('GITHUB_API_KEY')}", "Accept": "application/vnd.github.v3+json" }) # --------------------------- # Langchain Groq Setup for Search Tag Conversion # --------------------------- llm = ChatGroq( model="deepseek-r1-distill-llama-70b", temperature=0.3, max_tokens=512, max_retries=3, ) prompt = ChatPromptTemplate.from_messages([ ("system", """You are a GitHub search optimization expert. Your job is to: 1. Read a user's query about tools, research, or tasks. 2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language. 3. Think iteratively and generate your internal chain-of-thought enclosed in ... tags. 4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery. Use as many tags as necessary based on the query's complexity, but never more than five. 5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript). If no specific language is mentioned, do not include any target tag. Output Format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]] Rules: - Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought). - Use terms commonly found in GitHub repo names, topics, or descriptions. - Avoid generic terms like "python", "ai", "tool", "project". - Do NOT use full phrases or vague words like "no-code", "framework", or "approach". - Prefer real tools, popular methods, or dataset names when mentioned. - If your output does not strictly match the required format, correct it after your internal reasoning. - Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories. Excellent Examples: Input: "No code tool to augment image and annotation" Output: image-augmentation:albumentations Input: "Repos around chain of thought prompting mainly for finetuned models" Output: chain-of-thought:finetuned-llm Input: "Find repositories implementing data augmentation pipelines in JavaScript" Output: data-augmentation:target-javascript Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations. """), ("human", "{query}") ]) chain = prompt | llm def valid_tags(tags: str) -> bool: """ Validates that the output is one to six colon-separated tokens composed of lowercase letters, numbers, and hyphens. """ pattern = r'^[a-z0-9-]+(?::[a-z0-9-]+){1,5}$' return re.match(pattern, tags) is not None def parse_search_tags(response: str) -> str: """ Extracts a valid colon-separated tag string from the LLM response. This function removes any chain-of-thought commentary. """ # Remove any text inside ... blocks. cleaned = re.sub(r'.*?', '', response, flags=re.DOTALL) # Use regex to find a valid tag pattern. pattern = r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})' match = re.search(pattern, cleaned) if match: return match.group(1).strip() return cleaned.strip() def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str: print(f"\n🧠 [iterative_convert_to_search_tags] Input Query: {query}") refined_query = query tags_output = "" for iteration in range(max_iterations): print(f"\n🔄 Iteration {iteration+1}") response = chain.invoke({"query": refined_query}) full_output = response.content.strip() tags_output = parse_search_tags(full_output) print(f"Output Tags: {tags_output}") if valid_tags(tags_output): print("✅ Valid tags format detected.") return tags_output else: print("⚠️ Invalid tags format. Requesting refinement...") refined_query = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]." print("Final output (may be invalid):", tags_output) return tags_output # --------------------------- # GitHub API Helper Functions # --------------------------- def fetch_readme_content(repo_full_name): """Fetch the README content (if available) using the GitHub API.""" readme_url = f"https://api.github.com/repos/{repo_full_name}/readme" response = session.get(readme_url) if response.status_code == 200: readme_data = response.json() try: return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace') except Exception: return "" return "" def fetch_markdown_contents(repo_full_name): """ Fetch all markdown files (except the README already fetched) from the root of the repository. """ url = f"https://api.github.com/repos/{repo_full_name}/contents" response = session.get(url) contents = "" if response.status_code == 200: items = response.json() for item in items: if item.get("type") == "file" and item.get("name", "").lower().endswith(".md"): file_url = item.get("download_url") if file_url: file_resp = requests.get(file_url) if file_resp.status_code == 200: contents += "\n" + file_resp.text return contents def fetch_all_markdown(repo_full_name): """Combine README with all markdown contents from the repository root.""" readme = fetch_readme_content(repo_full_name) other_md = fetch_markdown_contents(repo_full_name) return readme + "\n" + other_md def fetch_github_repositories(query, max_results=10): """ Searches GitHub repositories using the provided query and retrieves key information. """ url = "https://api.github.com/search/repositories" params = { "q": query, "per_page": max_results } response = session.get(url, params=params) if response.status_code != 200: print(f"Error {response.status_code}: {response.json().get('message')}") return [] repo_list = [] for repo in response.json().get('items', []): repo_link = repo.get('html_url') description = repo.get('description') or "" combined_markdown = fetch_all_markdown(repo.get('full_name')) combined_text = (description + "\n" + combined_markdown).strip() repo_list.append({ "title": repo.get('name', 'No title available'), "link": repo_link, "combined_text": combined_text }) return repo_list # --------------------------- # Initialize SentenceTransformer Model for Dense Retrieval # --------------------------- model = SentenceTransformer('all-mpnet-base-v2') def robust_min_max_norm(scores): """ Performs min-max normalization while avoiding division by zero. """ min_val = scores.min() max_val = scores.max() if max_val - min_val < 1e-10: return np.ones_like(scores) return (scores - min_val) / (max_val - min_val) # --------------------------- # Cross-Encoder Re-Ranking Function # --------------------------- def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10): """ Re-ranks candidate repositories using a cross-encoder model. For long documents, the text is split into chunks and scores are aggregated. """ cross_encoder = CrossEncoder(model_name) CHUNK_SIZE = 2000 # characters per chunk MAX_DOC_LENGTH = 5000 # cap for long docs MIN_DOC_LENGTH = 200 # threshold for short docs def split_text(text, chunk_size=CHUNK_SIZE): return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] for candidate in candidates: doc = candidate.get("combined_text", "") if len(doc) > MAX_DOC_LENGTH: doc = doc[:MAX_DOC_LENGTH] try: if len(doc) < MIN_DOC_LENGTH: score = cross_encoder.predict([[query, doc]]) candidate["cross_encoder_score"] = float(score[0]) else: chunks = split_text(doc) pairs = [[query, chunk] for chunk in chunks] scores = cross_encoder.predict(pairs) max_score = np.max(scores) if len(scores) > 0 else 0.0 avg_score = np.mean(scores) if len(scores) > 0 else 0.0 candidate["cross_encoder_score"] = float(0.5 * max_score + 0.5 * avg_score) except Exception as e: logging.error(f"Error scoring candidate {candidate.get('link', 'unknown')}: {e}") candidate["cross_encoder_score"] = 0.0 all_scores = [candidate["cross_encoder_score"] for candidate in candidates] if all_scores: min_score = min(all_scores) if min_score < 0: for candidate in candidates: candidate["cross_encoder_score"] += -min_score reranked = sorted(candidates, key=lambda x: x["cross_encoder_score"], reverse=True) return reranked[:top_n] # --------------------------- # Main Function: Repository Ranking with Hybrid Retrieval and Cross-Encoder Re-Ranking # --------------------------- def run_repository_ranking(query: str) -> str: """ Converts the user query into search tags, runs multiple GitHub queries (individual and combined), deduplicates results, and applies a hybrid ranking strategy: - Dense embeddings (via SentenceTransformer) combined with BM25 scoring. - Re-ranks top candidates using a cross-encoder for improved contextual alignment. """ # Step 1: Generate search tags from the query. search_tags = iterative_convert_to_search_tags(query) tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()] # Step 2: Handle target language extraction. if any(tag.startswith("target-") for tag in tag_list): target_tag = next(tag for tag in tag_list if tag.startswith("target-")) target_lang = target_tag.replace("target-", "") lang_query = f"language:{target_lang}" tag_list = [tag for tag in tag_list if not tag.startswith("target-")] else: lang_query = "language:python" # Step 3: Build advanced search qualifiers. advanced_qualifier = "in:name,description,readme" all_repositories = [] # Loop over individual tags. for tag in tag_list: github_query = f"{tag} {advanced_qualifier} {lang_query}" print("GitHub Query:", github_query) repos = fetch_github_repositories(github_query, max_results=15) all_repositories.extend(repos) # Combined query using OR logic. combined_query = " OR ".join(tag_list) combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}" print("Combined GitHub Query:", combined_query) repos = fetch_github_repositories(combined_query, max_results=15) all_repositories.extend(repos) # Deduplicate repositories using the repo link. unique_repositories = {} for repo in all_repositories: if repo["link"] not in unique_repositories: unique_repositories[repo["link"]] = repo else: existing_text = unique_repositories[repo["link"]]["combined_text"] unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"] repositories = list(unique_repositories.values()) if not repositories: return "No repositories found for your query." # Step 4: Prepare documents. docs = [repo.get("combined_text", "") for repo in repositories] # Step 5: Dense retrieval. doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16) if doc_embeddings.ndim == 1: doc_embeddings = doc_embeddings.reshape(1, -1) norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True) norm_doc_embeddings = doc_embeddings / (norms + 1e-10) query_embedding = model.encode(query, convert_to_numpy=True) if query_embedding.ndim == 1: query_embedding = query_embedding.reshape(1, -1) norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10) dim = norm_doc_embeddings.shape[1] index = faiss.IndexFlatIP(dim) index.add(norm_doc_embeddings) k = norm_doc_embeddings.shape[0] D, I = index.search(norm_query_embedding, k) dense_scores = D.squeeze() norm_dense_scores = robust_min_max_norm(dense_scores) # Step 6: BM25 scoring. if BM25Okapi is not None: tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs] bm25 = BM25Okapi(tokenized_docs) query_tokens = re.findall(r'\w+', query.lower()) bm25_scores = np.array(bm25.get_scores(query_tokens)) norm_bm25_scores = robust_min_max_norm(bm25_scores) else: norm_bm25_scores = np.zeros_like(norm_dense_scores) # Step 7: Combine scores. alpha = 0.8 combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores for idx, repo in enumerate(repositories): repo["combined_score"] = float(combined_scores[idx]) # Step 8: Initial ranking. ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True) # Step 9: Cross-Encoder Re-Ranking. top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories final_ranked = cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=10) # Step 10: Format output. output = "\n=== Ranked Repositories ===\n" for rank, repo in enumerate(final_ranked, 1): output += f"Final Rank: {rank}\n" output += f"Title: {repo['title']}\n" output += f"Link: {repo['link']}\n" output += f"Combined Score: {repo.get('combined_score', 0):.4f}\n" output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0):.4f}\n" snippet = repo['combined_text'][:300].replace('\n', ' ') output += f"Snippet: {snippet}...\n" output += '-' * 80 + "\n" output += "\n=== End of Results ===" return output # --------------------------- # Main Entry Point for Testing # --------------------------- if __name__ == "__main__": test_query = "Chain of thought prompting for reasoning models" result = run_repository_ranking(test_query) print(result)