import os
import base64
import requests
import numpy as np
import faiss
import re
import logging
from pathlib import Path
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
# Optionally import BM25 for sparse retrieval.
try:
from rank_bm25 import BM25Okapi
except ImportError:
BM25Okapi = None
# ---------------------------
# Environment Setup
# ---------------------------
load_dotenv()
# Cross-encoder used for the final re-ranking step; overridable through the
# CROSS_ENCODER_MODEL environment variable.
CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
# Shared HTTP session for GitHub REST API calls (connection reuse + common headers).
session = requests.Session()
session.headers.update({"Accept": "application/vnd.github.v3+json"})
_github_api_key = os.getenv("GITHUB_API_KEY")
if _github_api_key:
    # Only authenticate when a key is configured. Interpolating a missing key
    # would send the literal header "token None", which GitHub rejects as bad
    # credentials instead of serving the request anonymously.
    session.headers["Authorization"] = f"token {_github_api_key}"
# ---------------------------
# Langchain Groq Setup for Search Tag Conversion
# ---------------------------
# LLM used to turn a free-text query into GitHub search tags.
# NOTE(review): the prompt below asks the model to reason before answering;
# with max_tokens=512 a long reasoning pass may be cut off before any tags are
# emitted — consider raising this if parsed outputs come back empty.
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
)
# System prompt defining the tag format; the human turn is the raw query.
prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are a GitHub search optimization expert.
Your job is to:
1. Read a user's query about tools, research, or tasks.
2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
3. Think iteratively and generate your internal chain-of-thought enclosed in <think>...</think> tags.
4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
Use as many tags as necessary based on the query's complexity, but never more than five.
5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
If no specific language is mentioned, do not include any target tag.
Output Format:
tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]
Rules:
- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
- Use terms commonly found in GitHub repo names, topics, or descriptions.
- Avoid generic terms like "python", "ai", "tool", "project".
- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
- Prefer real tools, popular methods, or dataset names when mentioned.
- If your output does not strictly match the required format, correct it after your internal reasoning.
- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
Excellent Examples:
Input: "No code tool to augment image and annotation"
Output: image-augmentation:albumentations
Input: "Repos around chain of thought prompting mainly for finetuned models"
Output: chain-of-thought:finetuned-llm
Input: "Find repositories implementing data augmentation pipelines in JavaScript"
Output: data-augmentation:target-javascript
Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
"""),
    ("human", "{query}"),
])
# Prompt -> LLM pipeline; invoke with {"query": ...}.
chain = prompt | llm
def valid_tags(tags: str) -> bool:
    """
    Check that ``tags`` is exactly two to six colon-separated tokens.

    Each token must be non-empty and consist only of lowercase ASCII letters,
    digits and hyphens, e.g. ``"image-augmentation:albumentations"``. The whole
    string must match, so any surrounding text (including a trailing newline)
    makes it invalid.
    """
    # fullmatch: unlike `^...$` with re.match, it does not accept a trailing "\n".
    return re.fullmatch(r'[a-z0-9-]+(?::[a-z0-9-]+){1,5}', tags) is not None
def parse_search_tags(response: str) -> str:
    """
    Extract a colon-separated tag string from a raw LLM response.

    Reasoning wrapped in ``<think>...</think>`` is discarded first, so example
    tag strings mentioned while reasoning (e.g. ``tag1:tag2``) are not mistaken
    for the answer. Also handled:
      * a closing ``</think>`` without its opening tag — everything before the
        last closing tag is treated as reasoning;
      * an opening ``<think>`` that is never closed (e.g. output truncated) —
        everything from it onward is treated as reasoning.

    Returns:
        The first run of two to six colon-separated lowercase tokens in the
        remaining text, or the stripped remaining text if there is none
        (callers are expected to validate the result).
    """
    # Remove complete reasoning blocks.
    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL | re.IGNORECASE)
    # Orphan closing tag: the answer is whatever follows the last one.
    cleaned = re.split(r'</think>', cleaned, flags=re.IGNORECASE)[-1]
    # Unterminated opening tag: the rest of the text is unfinished reasoning.
    cleaned = re.sub(r'<think>.*', '', cleaned, flags=re.DOTALL | re.IGNORECASE)
    match = re.search(r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})', cleaned)
    if match:
        return match.group(1).strip()
    return cleaned.strip()
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    """
    Ask the LLM for search tags, re-prompting until the output is well-formed.

    Each attempt invokes ``chain`` and extracts tags with ``parse_search_tags``.
    The first result accepted by ``valid_tags`` is returned. After a rejected
    attempt, the original query is resent with an appended instruction that
    spells out the required format.

    Args:
        query: The user's natural-language request.
        max_iterations: Maximum number of LLM calls.

    Returns:
        The first valid tag string. Otherwise the last parsed output, which may
        be invalid, or "" when no attempt was made.
    """
    print(f"\n🧠 [iterative_convert_to_search_tags] Input Query: {query}")
    format_hint = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    prompt_text = query
    latest = ""
    for attempt in range(1, max_iterations + 1):
        print(f"\n🔄 Iteration {attempt}")
        raw_answer = chain.invoke({"query": prompt_text}).content.strip()
        latest = parse_search_tags(raw_answer)
        print(f"Output Tags: {latest}")
        if valid_tags(latest):
            print("✅ Valid tags format detected.")
            return latest
        print("⚠️ Invalid tags format. Requesting refinement...")
        # Retry from the original query plus an explicit format reminder.
        prompt_text = format_hint
    print("Final output (may be invalid):", latest)
    return latest
# ---------------------------
# GitHub API Helper Functions
# ---------------------------
def fetch_readme_content(repo_full_name, timeout=30):
    """
    Fetch and decode a repository's README through the GitHub REST API.

    Args:
        repo_full_name: Repository identifier in "owner/name" form.
        timeout: Seconds to wait for the API before giving up; without it a
            stalled connection could block the whole pipeline indefinitely.

    Returns:
        The README text decoded as UTF-8 (undecodable bytes replaced). Returns
        "" when the request fails, the status is not 200, or the payload
        cannot be decoded.
    """
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    try:
        response = session.get(readme_url, timeout=timeout)
    except requests.RequestException as exc:
        logging.warning("README request failed for %s: %s", repo_full_name, exc)
        return ""
    if response.status_code != 200:
        return ""
    try:
        readme_data = response.json()
        # The API returns the file body base64-encoded in "content".
        return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace')
    except (ValueError, TypeError, AttributeError):
        # Malformed JSON / base64 or unexpected payload shape: best-effort "".
        return ""
def fetch_markdown_contents(repo_full_name, timeout=30):
    """
    Fetch every root-level markdown file of a repository except README.md.

    README.md is skipped because the README is retrieved separately by
    ``fetch_readme_content``; including it here as well would duplicate it in
    the combined text.

    Args:
        repo_full_name: Repository identifier in "owner/name" form.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The file contents in directory-listing order, each preceded by a
        newline, concatenated. Returns "" if the listing request fails.
        Individual files that fail to download are skipped.
    """
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    try:
        response = session.get(url, timeout=timeout)
    except requests.RequestException as exc:
        logging.warning("Contents request failed for %s: %s", repo_full_name, exc)
        return ""
    if response.status_code != 200:
        return ""
    try:
        items = response.json()
    except ValueError:
        return ""
    if not isinstance(items, list):
        return ""
    parts = []
    for item in items:
        name = item.get("name", "").lower()
        if item.get("type") != "file" or not name.endswith(".md") or name == "readme.md":
            continue
        file_url = item.get("download_url")
        if not file_url:
            continue
        try:
            # Plain request: the API session's headers are not sent to the download host.
            file_resp = requests.get(file_url, timeout=timeout)
        except requests.RequestException as exc:
            logging.warning("Download failed for %s: %s", file_url, exc)
            continue
        if file_resp.status_code == 200:
            parts.append("\n" + file_resp.text)
    return "".join(parts)
def fetch_all_markdown(repo_full_name):
    """
    Return the repository README followed by its other root-level markdown.

    The README is fetched first, then the remaining markdown files. The two
    texts are joined by a single newline, so the result is one newline
    character even when both fetches come back empty.
    """
    pieces = (fetch_readme_content(repo_full_name), fetch_markdown_contents(repo_full_name))
    return "\n".join(pieces)
def fetch_github_repositories(query, max_results=10, timeout=30):
    """
    Search GitHub repositories and collect the text used for ranking.

    For every search hit, the repository description and its markdown (via
    ``fetch_all_markdown``) are combined into one text blob.

    Args:
        query: GitHub search query string, qualifiers included.
        max_results: Page size (``per_page``) requested from the search API.
        timeout: Seconds to wait for the search request.

    Returns:
        A list of dicts with keys "title", "link" and "combined_text". Returns
        an empty list, after printing the error, when the request fails or
        GitHub answers with a non-200 status.
    """
    url = "https://api.github.com/search/repositories"
    params = {"q": query, "per_page": max_results}
    try:
        response = session.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: GitHub search request failed: {exc}")
        return []
    if response.status_code != 200:
        # Error bodies usually carry a JSON "message", but don't crash if the
        # body is not JSON (or not an object).
        try:
            message = response.json().get('message')
        except (ValueError, AttributeError):
            message = response.text[:200]
        print(f"Error {response.status_code}: {message}")
        return []
    repo_list = []
    for repo in response.json().get('items', []):
        description = repo.get('description') or ""
        combined_markdown = fetch_all_markdown(repo.get('full_name'))
        repo_list.append({
            "title": repo.get('name', 'No title available'),
            "link": repo.get('html_url'),
            "combined_text": (description + "\n" + combined_markdown).strip(),
        })
    return repo_list
# ---------------------------
# Initialize SentenceTransformer Model for Dense Retrieval
# ---------------------------
# Bi-encoder used to embed both the repository texts and the user query for
# dense retrieval in run_repository_ranking.
model = SentenceTransformer(model_name_or_path='all-mpnet-base-v2')
def robust_min_max_norm(scores):
    """
    Min-max normalize scores into [0, 1] without dividing by zero.

    Args:
        scores: NumPy array or any array-like (e.g. a list) of numbers.

    Returns:
        A NumPy array of the same shape with values
        ``(scores - min) / (max - min)``. When the range is below 1e-10 (all
        scores effectively equal), returns an array of ones. An empty input
        gives an empty float array.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        # min()/max() raise on zero-size arrays; there is nothing to scale.
        return np.zeros_like(scores, dtype=float)
    min_val = scores.min()
    max_val = scores.max()
    if max_val - min_val < 1e-10:
        return np.ones_like(scores)
    return (scores - min_val) / (max_val - min_val)
# ---------------------------
# Cross-Encoder Re-Ranking Function
# ---------------------------
def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
    """
    Score candidates against the query with a cross-encoder and keep the best.

    Scoring of each candidate's "combined_text":
      * Truncate it to 5000 characters.
      * Under 200 characters, score it as a single (query, text) pair.
      * Otherwise, split it into 2000-character chunks and score it as
        0.5 * best chunk score + 0.5 * mean chunk score.
      * If scoring raises, log the error and use 0.0.

    If the lowest score is negative, every score is shifted up by its magnitude
    so the minimum becomes 0. Relative order is unchanged.

    Side effect: sets "cross_encoder_score" on every candidate dict.

    Args:
        candidates: Dicts that may contain "combined_text" and "link".
        query: The user's query text.
        model_name: Model name or path passed to ``CrossEncoder``.
        top_n: Maximum number of candidates to return.

    Returns:
        Up to ``top_n`` candidates in descending "cross_encoder_score" order.
    """
    scorer = CrossEncoder(model_name)
    chunk_len, max_len, short_len = 2000, 5000, 200

    for cand in candidates:
        text = cand.get("combined_text", "")[:max_len]
        try:
            if len(text) < short_len:
                value = float(scorer.predict([[query, text]])[0])
            else:
                pieces = [text[pos:pos + chunk_len] for pos in range(0, len(text), chunk_len)]
                chunk_scores = scorer.predict([[query, piece] for piece in pieces])
                if len(chunk_scores) > 0:
                    best, mean = np.max(chunk_scores), np.mean(chunk_scores)
                else:
                    best = mean = 0.0
                value = float(0.5 * best + 0.5 * mean)
        except Exception as e:
            logging.error(f"Error scoring candidate {cand.get('link', 'unknown')}: {e}")
            value = 0.0
        cand["cross_encoder_score"] = value

    lowest = min((cand["cross_encoder_score"] for cand in candidates), default=0.0)
    if lowest < 0:
        for cand in candidates:
            cand["cross_encoder_score"] -= lowest

    return sorted(candidates, key=lambda c: c["cross_encoder_score"], reverse=True)[:top_n]
# ---------------------------
# Main Function: Repository Ranking with Hybrid Retrieval and Cross-Encoder Re-Ranking
# ---------------------------
def _split_tags_and_language(search_tags):
    """Split a colon-separated tag string into search tags and a ``language:`` qualifier."""
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]
    target_tag = next((tag for tag in tag_list if tag.startswith("target-")), None)
    if target_tag is None:
        return tag_list, "language:python"
    target_lang = target_tag.replace("target-", "")
    tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
    return tag_list, f"language:{target_lang}"


def _search_and_merge(tag_list, lang_query):
    """Run per-tag and OR-combined GitHub searches and return one entry per repository URL."""
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []
    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        print("GitHub Query:", github_query)
        all_repositories.extend(fetch_github_repositories(github_query, max_results=15))
    # The OR query only adds anything with two or more tags. With zero tags it
    # would be the invalid "()" query; with one it repeats the per-tag search.
    if len(tag_list) > 1:
        combined_query = f"({' OR '.join(tag_list)}) {advanced_qualifier} {lang_query}"
        print("Combined GitHub Query:", combined_query)
        all_repositories.extend(fetch_github_repositories(combined_query, max_results=15))
    unique_repositories = {}
    for repo in all_repositories:
        existing = unique_repositories.get(repo["link"])
        # Copies of one repository come from the same fetch logic and normally
        # carry identical text. Concatenating them would duplicate content and
        # skew BM25 and the cross-encoder, so keep the most complete copy.
        if existing is None or len(repo["combined_text"]) > len(existing["combined_text"]):
            unique_repositories[repo["link"]] = repo
    return list(unique_repositories.values())


def _hybrid_scores(docs, query, alpha=0.8):
    """Return alpha * normalized dense similarity + (1 - alpha) * normalized BM25, one score per doc."""
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    # FAISS needs C-contiguous float32 input.
    norm_doc_embeddings = np.ascontiguousarray(doc_embeddings / (norms + 1e-10), dtype=np.float32)
    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = np.ascontiguousarray(
        query_embedding / (np.linalg.norm(query_embedding) + 1e-10), dtype=np.float32
    )
    num_docs = norm_doc_embeddings.shape[0]
    # Inner product of unit vectors == cosine similarity.
    index = faiss.IndexFlatIP(norm_doc_embeddings.shape[1])
    index.add(norm_doc_embeddings)
    distances, indices = index.search(norm_query_embedding, num_docs)
    # search() returns scores ordered best-first together with their document
    # ids. Scatter them back so dense_scores[i] belongs to docs[i]; this also
    # keeps the array 1-D when there is a single document.
    dense_scores = np.zeros(num_docs, dtype=np.float32)
    dense_scores[indices[0]] = distances[0]
    norm_dense_scores = robust_min_max_norm(dense_scores)
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        norm_bm25_scores = robust_min_max_norm(np.array(bm25.get_scores(query_tokens)))
    else:
        # Sparse retrieval unavailable (rank_bm25 not installed).
        norm_bm25_scores = np.zeros_like(norm_dense_scores)
    return alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores


def _format_ranked_output(final_ranked):
    """Render the final ranking as the plain-text report returned to the caller."""
    parts = ["\n=== Ranked Repositories ===\n"]
    for rank, repo in enumerate(final_ranked, 1):
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        parts.append(
            f"Final Rank: {rank}\n"
            f"Title: {repo['title']}\n"
            f"Link: {repo['link']}\n"
            f"Combined Score: {repo.get('combined_score', 0):.4f}\n"
            f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0):.4f}\n"
            f"Snippet: {snippet}...\n"
            + '-' * 80 + "\n"
        )
    parts.append("\n=== End of Results ===")
    return "".join(parts)


def run_repository_ranking(query: str) -> str:
    """
    Find and rank GitHub repositories relevant to a natural-language query.

    Pipeline:
      1. Convert the query into colon-separated search tags with the LLM.
      2. Split off an optional ``target-<language>`` tag. Without one, the
         search is restricted to ``language:python``.
      3. Run one GitHub search per tag, plus an OR-combined search when there
         are several tags, and merge the results by repository URL.
      4. Score repositories with a hybrid of dense similarity
         (SentenceTransformer + FAISS inner product on normalized embeddings)
         and BM25, weighted 0.8 / 0.2. BM25 is used only if rank_bm25 is
         installed.
      5. Re-rank the 100 best hybrid candidates with a cross-encoder and keep 10.

    Args:
        query: Free-text description of what the user is looking for.

    Returns:
        A plain-text report of up to 10 ranked repositories, or
        "No repositories found for your query." when the searches return nothing.
    """
    search_tags = iterative_convert_to_search_tags(query)
    tag_list, lang_query = _split_tags_and_language(search_tags)
    repositories = _search_and_merge(tag_list, lang_query)
    if not repositories:
        return "No repositories found for your query."
    docs = [repo.get("combined_text", "") for repo in repositories]
    combined_scores = _hybrid_scores(docs, query)
    for repo, score in zip(repositories, combined_scores):
        repo["combined_score"] = float(score)
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)
    top_candidates = ranked_repositories[:100]
    final_ranked = cross_encoder_rerank_candidates(
        top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=10
    )
    return _format_ranked_output(final_ranked)
# ---------------------------
# Main Entry Point for Testing
# ---------------------------
if __name__ == "__main__":
    # Manual smoke test: runs the whole pipeline (LLM, GitHub API, embedding and
    # cross-encoder models) on a sample query and prints the report.
    print(run_repository_ranking("Chain of thought prompting for reasoning models"))