# DeepGit-lite / main.py
import os
import base64
import requests
import numpy as np
import faiss
import re
import logging
from pathlib import Path
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
# Optionally import BM25 for sparse retrieval.
try:
    from rank_bm25 import BM25Okapi
except ImportError:
    BM25Okapi = None
# ---------------------------
# Environment Setup
# ---------------------------
load_dotenv()
# Set the cross-encoder model from environment or use a default SOTA model.
CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
# Setup a persistent session for GitHub API requests.
session = requests.Session()
session.headers.update({
    "Authorization": f"token {os.getenv('GITHUB_API_KEY')}",
    "Accept": "application/vnd.github.v3+json"
})
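# GITHUB_API_KEY is read from the environment (e.g., the .env file loaded
# above); authenticated requests get substantially higher GitHub API rate
# limits than anonymous ones.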
# ---------------------------
# Langchain Groq Setup for Search Tag Conversion
# ---------------------------
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
)
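# ChatGroq authenticates via the GROQ_API_KEY environment variable, so it
# must be set alongside GITHUB_API_KEY for this script to run.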
prompt = ChatPromptTemplate.from_messages([
    ("system",
"""You are a GitHub search optimization expert.
Your job is to:
1. Read a user's query about tools, research, or tasks.
2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
Use as many tags as necessary based on the query's complexity, but never more than five.
5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
If no specific language is mentioned, do not include any target tag.
Output Format:
tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]
Rules:
- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
- Use terms commonly found in GitHub repo names, topics, or descriptions.
- Avoid generic terms like "python", "ai", "tool", "project".
- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
- Prefer real tools, popular methods, or dataset names when mentioned.
- If your output does not strictly match the required format, correct it after your internal reasoning.
- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
Excellent Examples:
Input: "No code tool to augment image and annotation"
Output: image-augmentation:albumentations
Input: "Repos around chain of thought prompting mainly for finetuned models"
Output: chain-of-thought:finetuned-llm
Input: "Find repositories implementing data augmentation pipelines in JavaScript"
Output: data-augmentation:target-javascript
Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
"""),
("human", "{query}")
])
chain = prompt | llm
def valid_tags(tags: str) -> bool:
    """
    Validates that the output is two to six colon-separated tokens composed
    of lowercase letters, numbers, and hyphens.
    """
    pattern = r'^[a-z0-9-]+(?::[a-z0-9-]+){1,5}$'
    return re.match(pattern, tags) is not None
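# Example (illustrative): valid_tags("chain-of-thought:finetuned-llm") is True,
# while valid_tags("Chain of Thought") is False (uppercase letters and spaces
# are rejected).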
def parse_search_tags(response: str) -> str:
    """
    Extracts a valid colon-separated tag string from the LLM response.
    This function removes any chain-of-thought commentary.
    """
    # Remove any text inside <think>...</think> blocks.
    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    # Use regex to find a valid tag pattern.
    pattern = r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})'
    match = re.search(pattern, cleaned)
    if match:
        return match.group(1).strip()
    return cleaned.strip()
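# Example (illustrative): an LLM response such as
# "<think>reasoning...</think>\nchain-of-thought:finetuned-llm" is reduced
# to "chain-of-thought:finetuned-llm".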
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    print(f"\n🧠 [iterative_convert_to_search_tags] Input Query: {query}")
    refined_query = query
    tags_output = ""
    for iteration in range(max_iterations):
        print(f"\n🔄 Iteration {iteration+1}")
        response = chain.invoke({"query": refined_query})
        full_output = response.content.strip()
        tags_output = parse_search_tags(full_output)
        print(f"Output Tags: {tags_output}")
        if valid_tags(tags_output):
            print("✅ Valid tags format detected.")
            return tags_output
        else:
            print("⚠️ Invalid tags format. Requesting refinement...")
            refined_query = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    print("Final output (may be invalid):", tags_output)
    return tags_output
# ---------------------------
# GitHub API Helper Functions
# ---------------------------
def fetch_readme_content(repo_full_name):
    """Fetch the README content (if available) using the GitHub API."""
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    response = session.get(readme_url)
    if response.status_code == 200:
        readme_data = response.json()
        try:
            return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace')
        except Exception:
            return ""
    return ""
def fetch_markdown_contents(repo_full_name):
    """
    Fetch all markdown files (except the README already fetched) from the root of the repository.
    """
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    response = session.get(url)
    contents = ""
    if response.status_code == 200:
        items = response.json()
        for item in items:
            name = item.get("name", "").lower()
            # Skip the README so fetch_all_markdown does not duplicate it.
            if item.get("type") == "file" and name.endswith(".md") and not name.startswith("readme"):
                file_url = item.get("download_url")
                if file_url:
                    file_resp = requests.get(file_url)
                    if file_resp.status_code == 200:
                        contents += "\n" + file_resp.text
    return contents
def fetch_all_markdown(repo_full_name):
    """Combine README with all markdown contents from the repository root."""
    readme = fetch_readme_content(repo_full_name)
    other_md = fetch_markdown_contents(repo_full_name)
    return readme + "\n" + other_md
def fetch_github_repositories(query, max_results=10):
    """
    Searches GitHub repositories using the provided query and retrieves key information.
    """
    url = "https://api.github.com/search/repositories"
    params = {
        "q": query,
        "per_page": max_results
    }
    response = session.get(url, params=params)
    if response.status_code != 200:
        print(f"Error {response.status_code}: {response.json().get('message')}")
        return []
    repo_list = []
    for repo in response.json().get('items', []):
        repo_link = repo.get('html_url')
        description = repo.get('description') or ""
        combined_markdown = fetch_all_markdown(repo.get('full_name'))
        combined_text = (description + "\n" + combined_markdown).strip()
        repo_list.append({
            "title": repo.get('name', 'No title available'),
            "link": repo_link,
            "combined_text": combined_text
        })
    return repo_list
# ---------------------------
# Initialize SentenceTransformer Model for Dense Retrieval
# ---------------------------
model = SentenceTransformer('all-mpnet-base-v2')
def robust_min_max_norm(scores):
    """
    Performs min-max normalization while avoiding division by zero.
    """
    min_val = scores.min()
    max_val = scores.max()
    if max_val - min_val < 1e-10:
        return np.ones_like(scores)
    return (scores - min_val) / (max_val - min_val)
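# Worked example (illustrative): scores [2.0, 4.0, 6.0] normalize to
# [0.0, 0.5, 1.0]; a constant array such as [3.0, 3.0] has spread below
# 1e-10, so every document receives 1.0 instead of triggering a zero division.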
# ---------------------------
# Cross-Encoder Re-Ranking Function
# ---------------------------
def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
    """
    Re-ranks candidate repositories using a cross-encoder model.
    For long documents, the text is split into chunks and scores are aggregated.
    """
    cross_encoder = CrossEncoder(model_name)
    CHUNK_SIZE = 2000      # characters per chunk
    MAX_DOC_LENGTH = 5000  # cap for long docs
    MIN_DOC_LENGTH = 200   # threshold for short docs

    def split_text(text, chunk_size=CHUNK_SIZE):
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    for candidate in candidates:
        doc = candidate.get("combined_text", "")
        if len(doc) > MAX_DOC_LENGTH:
            doc = doc[:MAX_DOC_LENGTH]
        try:
            if len(doc) < MIN_DOC_LENGTH:
                score = cross_encoder.predict([[query, doc]])
                candidate["cross_encoder_score"] = float(score[0])
            else:
                chunks = split_text(doc)
                pairs = [[query, chunk] for chunk in chunks]
                scores = cross_encoder.predict(pairs)
                max_score = np.max(scores) if len(scores) > 0 else 0.0
                avg_score = np.mean(scores) if len(scores) > 0 else 0.0
                candidate["cross_encoder_score"] = float(0.5 * max_score + 0.5 * avg_score)
        except Exception as e:
            logging.error(f"Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
            candidate["cross_encoder_score"] = 0.0
    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
    if all_scores:
        min_score = min(all_scores)
        if min_score < 0:
            for candidate in candidates:
                candidate["cross_encoder_score"] += -min_score
    reranked = sorted(candidates, key=lambda x: x["cross_encoder_score"], reverse=True)
    return reranked[:top_n]
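# Worked example of the chunk aggregation above (illustrative numbers): chunk
# scores [0.2, 0.8, 0.5] give max = 0.8 and mean = 0.5, so the blended
# cross-encoder score is 0.5 * 0.8 + 0.5 * 0.5 = 0.65, rewarding documents
# with at least one highly relevant chunk without ignoring overall relevance.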
# ---------------------------
# Main Function: Repository Ranking with Hybrid Retrieval and Cross-Encoder Re-Ranking
# ---------------------------
def run_repository_ranking(query: str) -> str:
    """
    Converts the user query into search tags, runs multiple GitHub queries (individual and combined),
    deduplicates results, and applies a hybrid ranking strategy:
    - Dense embeddings (via SentenceTransformer) combined with BM25 scoring.
    - Re-ranks top candidates using a cross-encoder for improved contextual alignment.
    """
    # Step 1: Generate search tags from the query.
    search_tags = iterative_convert_to_search_tags(query)
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]
    # Step 2: Handle target language extraction.
    if any(tag.startswith("target-") for tag in tag_list):
        target_tag = next(tag for tag in tag_list if tag.startswith("target-"))
        target_lang = target_tag.replace("target-", "")
        lang_query = f"language:{target_lang}"
        tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
    else:
        lang_query = "language:python"
    # Step 3: Build advanced search qualifiers.
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []
    # Loop over individual tags.
    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        print("GitHub Query:", github_query)
        repos = fetch_github_repositories(github_query, max_results=15)
        all_repositories.extend(repos)
    # Combined query using OR logic.
    combined_query = " OR ".join(tag_list)
    combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
    print("Combined GitHub Query:", combined_query)
    repos = fetch_github_repositories(combined_query, max_results=15)
    all_repositories.extend(repos)
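    # Illustrative query shapes (assumed tags, not from a real run): for
    # ["chain-of-thought", "finetuned-llm"] with no target tag, an individual
    # query is "chain-of-thought in:name,description,readme language:python"
    # and the combined query is
    # "(chain-of-thought OR finetuned-llm) in:name,description,readme language:python".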
    # Deduplicate repositories using the repo link.
    unique_repositories = {}
    for repo in all_repositories:
        if repo["link"] not in unique_repositories:
            unique_repositories[repo["link"]] = repo
        else:
            existing_text = unique_repositories[repo["link"]]["combined_text"]
            unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"]
    repositories = list(unique_repositories.values())
    if not repositories:
        return "No repositories found for your query."
    # Step 4: Prepare documents.
    docs = [repo.get("combined_text", "") for repo in repositories]
    # Step 5: Dense retrieval.
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    norm_doc_embeddings = doc_embeddings / (norms + 1e-10)
    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
    dim = norm_doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(norm_doc_embeddings)
    k = norm_doc_embeddings.shape[0]
    D, I = index.search(norm_query_embedding, k)
    # FAISS returns results sorted by score, not in document order, so map
    # the scores back to document order before combining with BM25 below.
    dense_scores = np.empty(k, dtype=np.float32)
    dense_scores[I.flatten()] = D.flatten()
    norm_dense_scores = robust_min_max_norm(dense_scores)
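    # Because both document and query embeddings are L2-normalized, the inner
    # product computed by IndexFlatIP equals cosine similarity, so the raw
    # dense scores lie in [-1, 1] before min-max normalization.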
    # Step 6: BM25 scoring.
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        bm25_scores = np.array(bm25.get_scores(query_tokens))
        norm_bm25_scores = robust_min_max_norm(bm25_scores)
    else:
        norm_bm25_scores = np.zeros_like(norm_dense_scores)
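    # Note: when rank_bm25 is not installed, the sparse component is all
    # zeros and the combined score below reduces to the dense score alone.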
    # Step 7: Combine scores.
    alpha = 0.8
    combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
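    # Worked example (illustrative numbers): a normalized dense score of 0.9
    # and a normalized BM25 score of 0.5 combine to 0.8 * 0.9 + 0.2 * 0.5
    # = 0.82, weighting semantic similarity heavily while keeping a lexical
    # signal.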
    for idx, repo in enumerate(repositories):
        repo["combined_score"] = float(combined_scores[idx])
    # Step 8: Initial ranking.
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)
    # Step 9: Cross-Encoder Re-Ranking (slicing already handles lists shorter than 100).
    top_candidates = ranked_repositories[:100]
    final_ranked = cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=10)
    # Step 10: Format output.
    output = "\n=== Ranked Repositories ===\n"
    for rank, repo in enumerate(final_ranked, 1):
        output += f"Final Rank: {rank}\n"
        output += f"Title: {repo['title']}\n"
        output += f"Link: {repo['link']}\n"
        output += f"Combined Score: {repo.get('combined_score', 0):.4f}\n"
        output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0):.4f}\n"
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        output += f"Snippet: {snippet}...\n"
        output += '-' * 80 + "\n"
    output += "\n=== End of Results ==="
    return output
# ---------------------------
# Main Entry Point for Testing
# ---------------------------
if __name__ == "__main__":
    test_query = "Chain of thought prompting for reasoning models"
    result = run_repository_ranking(test_query)
    print(result)