zamalali committed
Commit · 0b55d27
Parent(s):
Push of DeepGit core files

Files changed:
- .gitignore         +5  -0
- agent.py          +125 -0
- app.py            +309 -0
- main.py           +380 -0
- requirements.txt    +8 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.venv/
+.env
+__pycache__/
+*.pyc
+.gradio/
agent.py
ADDED
@@ -0,0 +1,125 @@
+# agent.py
+
+import os
+from dotenv import load_dotenv
+from github import Github, Auth
+
+# LangChain imports
+from langchain_groq import ChatGroq
+from langchain_core.tools import tool
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.agents import create_tool_calling_agent, AgentExecutor
+
+# Load environment variables
+load_dotenv()
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+GITHUB_PAT = os.getenv("GITHUB_API_KEY")
+
+if not (GROQ_API_KEY and GITHUB_PAT):
+    raise ValueError("Please set GROQ_API_KEY and GITHUB_API_KEY in your .env")
+
+# Initialize the GitHub client
+_auth = Auth.Token(GITHUB_PAT)
+_gh = Github(auth=_auth)
+
+# Define the GitHub tool
+@tool
+def get_repo_info(repo_name: str) -> str:
+    """Fetch and summarize metadata about a GitHub repository."""
+    try:
+        repo = _gh.get_repo(repo_name)
+    except Exception as e:
+        return f"Error fetching '{repo_name}': {e}"
+
+    name = repo.full_name
+    desc = repo.description or "No description"
+    url = repo.html_url
+    owner = repo.owner.login
+    stars = repo.stargazers_count
+    forks = repo.forks_count
+    issues = repo.open_issues_count
+    created = repo.created_at.isoformat()
+    updated = repo.updated_at.isoformat()
+    watchers = repo.watchers_count
+    default_br = repo.default_branch
+    language = repo.language or "None"
+
+    try:
+        license_name = repo.get_license().license.name
+    except Exception:
+        license_name = "None"
+
+    topics = repo.get_topics()
+    try:
+        raw_md = repo.get_readme().decoded_content.decode("utf-8")
+        snippet = raw_md[:300].replace("\n", " ") + "..."
+    except Exception:
+        snippet = "No README found"
+
+    contribs = repo.get_contributors()[:5]
+    contrib_list = ", ".join(f"{c.login}({c.contributions})" for c in contribs)
+
+    commits = repo.get_commits()[:3]
+    commit_list = "; ".join(c.commit.message.split("\n")[0] for c in commits)
+
+    return f"""
+Repository: {name}
+Description: {desc}
+URL: {url}
+Owner: {owner}
+⭐ Stars: {stars}  🍴 Forks: {forks}  🐛 Open Issues: {issues}
+👁️ Watchers: {watchers}  Default branch: {default_br}
+⚙️ Language: {language}  License: {license_name}
+🔍 Topics: {topics}
+
+README Snippet: {snippet}
+
+👥 Top Contributors: {contrib_list}
+🧾 Latest Commits: {commit_list}
+"""
+
+# Instantiate the Groq LLM
+llm = ChatGroq(
+    model="llama-3.1-8b-instant",
+    temperature=0.3,
+    max_tokens=1024,
+    api_key=GROQ_API_KEY,
+)
+
+# Define the tools to pass into the agent
+tools = [get_repo_info]
+
+# Prompt for the tool-calling agent
+prompt = ChatPromptTemplate.from_messages([
+    ("system",
+     "You are GitHub Agent, an expert at analyzing repositories.\n"
+     "When a user asks about a repo, call the tool and return a clear, concise summary of the repository based on the tool result.\n"
+     "Avoid repeating raw tool output or adding unnecessary disclaimers.\n"
+     "Respond in complete sentences, in natural language."
+    ),
+    MessagesPlaceholder(variable_name="chat_history", optional=True),
+    ("human", "{input}"),
+    MessagesPlaceholder(variable_name="agent_scratchpad"),
+])
+
+# Create the agent using LangChain's legacy AgentExecutor approach
+agent = create_tool_calling_agent(llm, tools, prompt)
+agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, max_iterations=2)
+
+# Export it
+__all__ = ["agent_executor"]
+
+# Quick test
+if __name__ == "__main__":
+    import time
+    user_input = "Give me details about the repo zamalali/deepgit"
+    start_time = time.time()
+    result = agent_executor.invoke({"input": user_input})
+    end_time = time.time()
+    print("\nFinal Answer:\n", result["output"])
+    print(f"\nTook {end_time - start_time:.2f} seconds")
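A minimal usage sketch for the exported executor, mirroring how app.py builds its query by prefixing the selected repository in square brackets (the repo name and question below are illustrative):

    # Assumes GROQ_API_KEY and GITHUB_API_KEY are set, as agent.py requires.
    from agent import agent_executor

    query = "[zamalali/deepgit] What does this repository do and how do I run it?"
    result = agent_executor.invoke({"input": query})
    print(result["output"])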
app.py
ADDED
@@ -0,0 +1,309 @@
+import gradio as gr
+import time
+import threading
+import logging
+from gradio.themes.utils import sizes
+from main import run_repository_ranking  # Import the repository ranking function
+import agent  # Import the agent module for the chat agent
+
+# ---------------------------
+# Global Logging Buffer Setup
+# ---------------------------
+LOG_BUFFER = []
+LOG_BUFFER_LOCK = threading.Lock()
+
+class BufferLogHandler(logging.Handler):
+    def emit(self, record):
+        log_entry = self.format(record)
+        with LOG_BUFFER_LOCK:
+            LOG_BUFFER.append(log_entry)
+
+root_logger = logging.getLogger()
+if not any(isinstance(h, BufferLogHandler) for h in root_logger.handlers):
+    handler = BufferLogHandler()
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    handler.setFormatter(formatter)
+    root_logger.addHandler(handler)
+
+def filter_logs(logs):
+    # Collapse noisy "HTTP Request:" lines into a single status message.
+    filtered = []
+    last_was_fetching = False
+    for log in logs:
+        if "HTTP Request:" in log:
+            if not last_was_fetching:
+                filtered.append("Fetching repositories...")
+                last_was_fetching = True
+        else:
+            filtered.append(log)
+            last_was_fetching = False
+    return filtered
+
+def parse_result_to_html(raw_result: str, num_results: int) -> tuple:
+    """
+    Parses the raw string output from run_repository_ranking into an HTML table.
+    Only the top N results are displayed.
+    Returns (html, repo_names).
+    """
+    entries = raw_result.strip().split("Final Rank:")
+    entries = entries[1:num_results + 1]  # Use only the first N entries
+    if not entries:
+        return ("<p>No repositories found for your query.</p>", [])
+    html = """
+    <table border="1" style="width:80%; margin: auto; border-collapse: collapse;">
+      <thead>
+        <tr>
+          <th>Rank</th>
+          <th>Title</th>
+          <th>Link</th>
+          <th>Combined Score</th>
+        </tr>
+      </thead>
+      <tbody>
+    """
+    repo_names = []
+    for entry in entries:
+        lines = entry.strip().split("\n")
+        data = {}
+        data["Final Rank"] = lines[0].strip() if lines else ""
+        for line in lines[1:]:
+            if ": " in line:
+                key, val = line.split(": ", 1)
+                data[key.strip()] = val.strip()
+        # Try to extract the repo name from the Link (github.com/user/repo).
+        link = data.get('Link', '')
+        repo_name = ''
+        if 'github.com/' in link:
+            repo_name = link.split('github.com/')[-1].strip('/ ')
+        if repo_name:
+            repo_names.append(repo_name)
+        html += f"""
+        <tr>
+          <td>{data.get('Final Rank', '')}</td>
+          <td>{data.get('Title', '')}</td>
+          <td><a href="{data.get('Link', '#')}" target="_blank">GitHub</a></td>
+          <td>{data.get('Combined Score', '')}</td>
+        </tr>
+        """
+    html += "</tbody></table>"
+    return html, repo_names
+
+# ---------------------------
+# GPU-enabled Wrapper for Repository Ranking
+# ---------------------------
+def gpu_run_repo(topic: str, num_results: int):
+    return run_repository_ranking(topic, num_results)
+
+def run_lite_workflow(topic, num_results, result_container):
+    result = gpu_run_repo(topic, num_results)
+    result_container["raw_result"] = result
+
+def stream_lite_workflow(topic, num_results):
+    logging.info("[UI] User started a new search for topic: %s", topic)
+    with LOG_BUFFER_LOCK:
+        LOG_BUFFER.clear()
+    result_container = {}
+    workflow_thread = threading.Thread(target=run_lite_workflow, args=(topic, num_results, result_container))
+    workflow_thread.start()
+
+    # Stream new log lines to the UI while the worker thread runs.
+    last_index = 0
+    while workflow_thread.is_alive() or (last_index < len(LOG_BUFFER)):
+        with LOG_BUFFER_LOCK:
+            new_logs = LOG_BUFFER[last_index:]
+            last_index = len(LOG_BUFFER)
+        if new_logs:
+            filtered_logs = filter_logs(new_logs)
+            status_msg = filtered_logs[-1]
+            detail_msg = "<br/>".join(filtered_logs)
+            yield status_msg, detail_msg, []
+        time.sleep(0.5)
+
+    workflow_thread.join()
+    with LOG_BUFFER_LOCK:
+        final_logs = LOG_BUFFER[:]
+    raw_result = result_container.get("raw_result", "No results returned.")
+    html_result, repo_names = parse_result_to_html(raw_result, num_results)
+    yield "", html_result, repo_names
+
+def lite_runner(topic, num_results):
+    logging.info("[UI] Running lite_runner for topic: %s", topic)
+    yield "Workflow started", "<p>Processing your request. Please wait...</p>", []
+    for status, details, repos in stream_lite_workflow(topic, num_results):
+        yield status, details, repos
+
+# ---------------------------
+# App UI Setup Using Gradio Soft Theme with Centered Layout
+# ---------------------------
+with gr.Blocks(
+    theme=gr.themes.Soft(text_size=sizes.text_md),
+    title="DeepGit Lite",
+    css="""
+    /* Center header and footer */
+    #header { text-align: center; margin-bottom: 20px; }
+    #main-container { max-width: 800px; margin: auto; }
+    #footer { text-align: center; margin-top: 20px; }
+    """
+) as demo:
+    gr.Markdown(
+        """
+        <div style="padding-top: 60px;">
+          <div style="display: flex; align-items: center; justify-content: center;">
+            <img src="https://img.icons8.com/?size=100&id=118557&format=png&color=000000"
+                 style="width: 60px; height: 60px; margin-right: 12px;">
+            <h1 style="margin: 0; font-size: 2.5em; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;">
+              DeepGit Lite
+            </h1>
+          </div>
+          <div style="text-align: center; margin-top: 20px; font-size: 1.1em; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;">
+            <p>
+              ✨ DeepGit Lite is the lightweight version of <strong>DeepGit</strong>.<br>
+              It harnesses advanced deep semantic search to explore GitHub repositories and deliver curated results.<br>
+              Under the hood, it leverages a hybrid ranking approach combining dense retrieval, BM25 scoring, and cross-encoder re-ranking for optimal discovery.<br>
+              If the agent reports that no repositories were found, no chain was invoked due to GPU unavailability; please duplicate the Space and re-run.
+            </p>
+            <p>
+              🚀 Check out the full version of DeepGit on
+              <a href="https://github.com/zamalali/DeepGit" target="_blank">GitHub</a> and ⭐
+              <strong>star DeepGit</strong> while you're there!
+            </p>
+          </div>
+        </div>
+        """,
+        elem_id="header"
+    )
+
+    # --- Search UI ---
+    with gr.Column(elem_id="main-container", visible=True) as search_ui:
+        research_input = gr.Textbox(
+            label="Research Query",
+            placeholder="Enter your research topic here, e.g., looking for a low-code/no-code tool to augment images and annotations",
+            lines=3
+        )
+        num_results_slider = gr.Slider(
+            minimum=5, maximum=25, value=10, step=1,
+            label="Number of Results to Display",
+            info="Choose how many top repositories to show (sorted by score)"
+        )
+        run_button = gr.Button("Run DeepGit Lite", variant="primary")
+        status_display = gr.Markdown(label="Status")
+        detail_display = gr.HTML(label="Results")
+        repo_state = gr.State([])
+        go_to_chat_btn = gr.Button("Go to Chat", visible=False)
+
+    # --- Chat UI ---
+    with gr.Column(visible=False) as chat_ui:
+        repo_choice = gr.Radio(choices=[], label="Select a repository", interactive=True)
+        chat_history = gr.Chatbot(label="Chat with GitHub Agent")
+        user_input = gr.Textbox(label="Your question", placeholder="Ask about the selected repo, e.g., tell me a bit more and guide me through setting it up and running")
+        send_btn = gr.Button("Send")
+        chat_state = gr.State([])
+        back_btn = gr.Button("Back to Search")
+
+    def update_chat_button(status, details, repos):
+        logging.info("[UI] Search complete. Showing Go to Chat button: %s", bool(repos))
+        return gr.update(visible=bool(repos)), repos
+
+    def show_chat_ui(repos):
+        logging.info("[UI] Switching to Chat UI. Repositories available: %s", repos)
+        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=repos, value=None), []
+
+    def back_to_search():
+        logging.info("[UI] Switching back to Search UI.")
+        return gr.update(visible=True), gr.update(visible=False), gr.update(value=[]), gr.update(value=None), []
+
+    def chat_with_agent(user_msg, repo, history):
+        logging.info("[Chat] User sent message: '%s' for repo: '%s'", user_msg, repo)
+        if not user_msg or not user_msg.strip():
+            # Block blank messages
+            return history + [["", "Please enter a message before sending."]], history
+        if not repo:
+            return history + [[user_msg, "Please select a repository first."]], history
+        full_query = f"[{repo}] {user_msg}"
+        try:
+            result = agent.agent_executor.invoke({"input": full_query})
+            answer = result["output"]
+            logging.info("[Chat] Agent response received.")
+        except Exception as e:
+            answer = f"Error: {e}"
+            logging.error("[Chat] Error in agent_executor: %s", e)
+        history = history + [[user_msg, answer]]
+        return history, history
+
+    # Disable the send button if no repo is selected or the message is blank,
+    # and show a helpful label on the button itself.
+    def can_send(user_msg, repo):
+        if not user_msg or not user_msg.strip():
+            return gr.update(interactive=False, value="Enter a message to send")
+        if not repo:
+            return gr.update(interactive=False, value="Select a repository")
+        return gr.update(interactive=True, value="Send")
+
+    user_input.change(
+        fn=can_send,
+        inputs=[user_input, repo_choice],
+        outputs=[send_btn],
+        show_progress=False
+    )
+    repo_choice.change(
+        fn=can_send,
+        inputs=[user_input, repo_choice],
+        outputs=[send_btn],
+        show_progress=False
+    )
+
+    run_button.click(
+        fn=lite_runner,
+        inputs=[research_input, num_results_slider],
+        outputs=[status_display, detail_display, repo_state],
+        api_name="deepgit_lite",
+        show_progress=True
+    ).then(
+        fn=update_chat_button,
+        inputs=[status_display, detail_display, repo_state],
+        outputs=[go_to_chat_btn, repo_state]
+    )
+
+    research_input.submit(
+        fn=lite_runner,
+        inputs=[research_input, num_results_slider],
+        outputs=[status_display, detail_display, repo_state],
+        api_name="deepgit_lite_submit",
+        show_progress=True
+    ).then(
+        fn=update_chat_button,
+        inputs=[status_display, detail_display, repo_state],
+        outputs=[go_to_chat_btn, repo_state]
+    )
+
+    go_to_chat_btn.click(
+        fn=show_chat_ui,
+        inputs=[repo_state],
+        outputs=[search_ui, chat_ui, repo_choice, chat_state]
+    )
+
+    back_btn.click(
+        fn=back_to_search,
+        inputs=[],
+        outputs=[search_ui, chat_ui, chat_history, repo_choice, chat_state]
+    )
+
+    send_btn.click(
+        fn=chat_with_agent,
+        inputs=[user_input, repo_choice, chat_state],
+        outputs=[chat_history, chat_state],
+        queue=False
+    )
+    user_input.submit(
+        fn=chat_with_agent,
+        inputs=[user_input, repo_choice, chat_state],
+        outputs=[chat_history, chat_state],
+        queue=False
+    )
+
+    gr.HTML(
+        """
+        <div id="footer">
+          Made with ❤️ by <b>Zamal</b>
+        </div>
+        """
+    )
+
+demo.queue(max_size=10).launch()
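For reference, a minimal sketch of the contract between run_repository_ranking's text output and parse_result_to_html, assuming the function as defined above is in scope (the sample block mirrors the format main.py emits; the values are illustrative):

    sample = (
        "=== Ranked Repositories ===\n"
        "Final Rank: 1\n"
        "Title: deepgit\n"
        "Link: https://github.com/zamalali/deepgit\n"
        "Combined Score: 87.50%\n"
    )
    html, repo_names = parse_result_to_html(sample, num_results=1)
    # html renders a one-row table; repo_names == ["zamalali/deepgit"],
    # extracted from the Link field and later fed to the repo selector.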
main.py
ADDED
@@ -0,0 +1,380 @@
+import os
+import base64
+import requests
+import numpy as np
+import faiss
+import re
+import logging
+from pathlib import Path
+
+# For local development, load environment variables from a .env file.
+# In HuggingFace Spaces, secrets are automatically available as environment variables.
+from dotenv import load_dotenv
+load_dotenv()
+
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from langchain_groq import ChatGroq
+from langchain_core.prompts import ChatPromptTemplate
+
+# Optionally import BM25 for sparse retrieval.
+try:
+    from rank_bm25 import BM25Okapi
+except ImportError:
+    BM25Okapi = None
+
+# ---------------------------
+# Environment Variables & Setup
+# ---------------------------
+# GitHub API key (required for GitHub API calls)
+GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
+# GROQ API key (if required by ChatGroq)
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+# HuggingFace token (if you need it to load private models from HuggingFace)
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+
+CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
+
+# Set up a persistent session for GitHub API requests.
+session = requests.Session()
+session.headers.update({
+    "Authorization": f"token {GITHUB_API_KEY}",
+    "Accept": "application/vnd.github.v3+json"
+})
+
+# ---------------------------
+# Langchain Groq Setup for Search Tag Conversion
+# ---------------------------
+llm = ChatGroq(
+    model="deepseek-r1-distill-llama-70b",
+    temperature=0.3,
+    max_tokens=512,
+    max_retries=3,
+    api_key=GROQ_API_KEY  # Pass GROQ_API_KEY if the ChatGroq library supports it.
+)
+
+prompt = ChatPromptTemplate.from_messages([
+    ("system",
+     """You are a GitHub search optimization expert.
+
+Your job is to:
+1. Read a user's query about tools, research, or tasks.
+2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
+3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
+4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
+   Use as many tags as necessary based on the query's complexity, but never more than five.
+5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
+   If no specific language is mentioned, do not include any target tag.
+
+Output Format:
+tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]
+
+Rules:
+- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
+- Use terms commonly found in GitHub repo names, topics, or descriptions.
+- Avoid generic terms like "python", "ai", "tool", "project".
+- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
+- Prefer real tools, popular methods, or dataset names when mentioned.
+- If your output does not strictly match the required format, correct it after your internal reasoning.
+- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
+
+Excellent Examples:
+
+Input: "No code tool to augment image and annotation"
+Output: image-augmentation:albumentations
+
+Input: "Repos around chain of thought prompting mainly for finetuned models"
+Output: chain-of-thought:finetuned-llm
+
+Input: "Find repositories implementing data augmentation pipelines in JavaScript"
+Output: data-augmentation:target-javascript
+
+Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
+"""),
+    ("human", "{query}")
+])
+chain = prompt | llm
+
+def valid_tags(tags: str) -> bool:
+    pattern = r'^[a-z0-9-]+(?::[a-z0-9-]+){1,5}$'
+    return re.match(pattern, tags) is not None
+
+def parse_search_tags(response: str) -> str:
+    # Remove any text inside <think>...</think> blocks.
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
+    pattern = r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})'
+    match = re.search(pattern, cleaned)
+    if match:
+        return match.group(1).strip()
+    return cleaned.strip()
+
+def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
+    print(f"\n[iterative_convert_to_search_tags] Input Query: {query}")
+    refined_query = query
+    tags_output = ""
+    for iteration in range(max_iterations):
+        print(f"\nIteration {iteration + 1}")
+        response = chain.invoke({"query": refined_query})
+        full_output = response.content.strip()
+        tags_output = parse_search_tags(full_output)
+        print(f"Output Tags: {tags_output}")
+        if valid_tags(tags_output):
+            print("Valid tags format detected.")
+            return tags_output
+        else:
+            print("Invalid tags format. Requesting refinement...")
+            refined_query = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
+    print("Final output (may be invalid):", tags_output)
+    return tags_output
+
+# ---------------------------
+# GitHub API Helper Functions
+# ---------------------------
+def fetch_readme_content(repo_full_name: str) -> str:
+    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
+    response = session.get(readme_url)
+    if response.status_code == 200:
+        readme_data = response.json()
+        try:
+            return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace')
+        except Exception:
+            return ""
+    return ""
+
+def fetch_markdown_contents(repo_full_name: str) -> str:
+    url = f"https://api.github.com/repos/{repo_full_name}/contents"
+    response = session.get(url)
+    contents = ""
+    if response.status_code == 200:
+        items = response.json()
+        for item in items:
+            if item.get("type") == "file" and item.get("name", "").lower().endswith(".md"):
+                file_url = item.get("download_url")
+                if file_url:
+                    file_resp = requests.get(file_url)
+                    if file_resp.status_code == 200:
+                        contents += "\n" + file_resp.text
+    return contents
+
+def fetch_all_markdown(repo_full_name: str) -> str:
+    readme = fetch_readme_content(repo_full_name)
+    other_md = fetch_markdown_contents(repo_full_name)
+    return readme + "\n" + other_md
+
+def fetch_github_repositories(query: str, max_results: int = 10) -> list:
+    url = "https://api.github.com/search/repositories"
+    params = {
+        "q": query,
+        "per_page": max_results
+    }
+    response = session.get(url, params=params)
+    if response.status_code != 200:
+        print(f"Error {response.status_code}: {response.json().get('message')}")
+        return []
+    repo_list = []
+    for repo in response.json().get('items', []):
+        repo_link = repo.get('html_url')
+        description = repo.get('description') or ""
+        combined_markdown = fetch_all_markdown(repo.get('full_name'))
+        combined_text = (description + "\n" + combined_markdown).strip()
+        repo_list.append({
+            "title": repo.get('name', 'No title available'),
+            "link": repo_link,
+            "combined_text": combined_text
+        })
+    return repo_list
+
+# ---------------------------
+# Dense Retrieval Model Setup
+# ---------------------------
+try:
+    # Try the GPU first; the HuggingFace token can be used here for private models.
+    model = SentenceTransformer('all-mpnet-base-v2', device='cuda')
+except Exception as e:
+    print("Error initializing GPU for SentenceTransformer; falling back to CPU:", e)
+    model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
+
+def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
+    min_val = scores.min()
+    max_val = scores.max()
+    if max_val - min_val < 1e-10:
+        return np.ones_like(scores)
+    return (scores - min_val) / (max_val - min_val)
+
+# ---------------------------
+# Cross-Encoder Re-Ranking Function
+# ---------------------------
+def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
+    try:
+        cross_encoder = CrossEncoder(model_name, device='cuda')
+    except Exception as e:
+        print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
+        cross_encoder = CrossEncoder(model_name, device='cpu')
+
+    CHUNK_SIZE = 2000
+    MAX_DOC_LENGTH = 5000
+    MIN_DOC_LENGTH = 200
+
+    def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
+        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+
+    for candidate in candidates:
+        doc = candidate.get("combined_text", "")
+        if len(doc) > MAX_DOC_LENGTH:
+            doc = doc[:MAX_DOC_LENGTH]
+        try:
+            if len(doc) < MIN_DOC_LENGTH:
+                # Short documents are scored in one shot.
+                score = cross_encoder.predict([[query, doc]])
+                if hasattr(score, '__len__') and len(score) == 1:
+                    candidate["cross_encoder_score"] = float(score[0])
+                else:
+                    candidate["cross_encoder_score"] = float(score)
+            else:
+                # Longer documents are chunked; blend the max and mean chunk scores.
+                chunks = split_text(doc)
+                pairs = [[query, chunk] for chunk in chunks]
+                scores = cross_encoder.predict(pairs)
+                scores = np.array(scores)
+                max_score = float(np.max(scores)) if scores.size > 0 else 0.0
+                avg_score = float(np.mean(scores)) if scores.size > 0 else 0.0
+                candidate["cross_encoder_score"] = 0.5 * max_score + 0.5 * avg_score
+        except Exception as e:
+            logging.debug(f"[cross-encoder] Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
+            candidate["cross_encoder_score"] = 0.0
+
+    # Shift scores so the minimum is non-negative.
+    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
+    if all_scores:
+        min_score = min(all_scores)
+        if min_score < 0:
+            for candidate in candidates:
+                candidate["cross_encoder_score"] += -min_score
+
+    return candidates
+
+# ---------------------------
+# Main Ranking Function with Hybrid Retrieval and Combined Scoring
+# ---------------------------
+def run_repository_ranking(query: str, num_results: int = 10) -> str:
+    # Step 1: Generate search tags from the query.
+    logging.info("[DeepGit] Step 1: Generate search tags from the query.")
+    search_tags = iterative_convert_to_search_tags(query)
+    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]
+
+    # Step 2: Handle target language extraction.
+    logging.info("[DeepGit] Step 2: Handle target language extraction.")
+    if any(tag.startswith("target-") for tag in tag_list):
+        target_tag = next(tag for tag in tag_list if tag.startswith("target-"))
+        lang_query = f"language:{target_tag.replace('target-', '')}"
+        tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
+    else:
+        lang_query = "language:python"
+
+    # Step 3: Build advanced search qualifiers and fetch repositories.
+    logging.info("[DeepGit] Step 3: Build advanced search qualifiers and fetch repositories.")
+    advanced_qualifier = "in:name,description,readme"
+    all_repositories = []
+
+    for tag in tag_list:
+        github_query = f"{tag} {advanced_qualifier} {lang_query}"
+        logging.info(f"[DeepGit] GitHub Query: {github_query}")
+        repos = fetch_github_repositories(github_query, max_results=15)
+        all_repositories.extend(repos)
+
+    combined_query = " OR ".join(tag_list)
+    combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
+    logging.info(f"[DeepGit] Combined GitHub Query: {combined_query}")
+    repos = fetch_github_repositories(combined_query, max_results=15)
+    all_repositories.extend(repos)
+
+    # De-duplicate by link, merging the combined text of duplicates.
+    unique_repositories = {}
+    for repo in all_repositories:
+        if repo["link"] not in unique_repositories:
+            unique_repositories[repo["link"]] = repo
+        else:
+            existing_text = unique_repositories[repo["link"]]["combined_text"]
+            unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"]
+    repositories = list(unique_repositories.values())
+
+    if not repositories:
+        return "No repositories found for your query."
+
+    # Step 4: Prepare documents for dense retrieval.
+    logging.info("[DeepGit] Step 4: Prepare documents for dense retrieval.")
+    docs = [repo.get("combined_text", "") for repo in repositories]
+
+    # Step 5: Dense retrieval.
+    logging.info("[DeepGit] Step 5: Compute dense embeddings and scores.")
+    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
+    if doc_embeddings.ndim == 1:
+        doc_embeddings = doc_embeddings.reshape(1, -1)
+    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
+    norm_doc_embeddings = doc_embeddings / (norms + 1e-10)
+
+    query_embedding = model.encode(query, convert_to_numpy=True)
+    if query_embedding.ndim == 1:
+        query_embedding = query_embedding.reshape(1, -1)
+    norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
+
+    dim = norm_doc_embeddings.shape[1]
+    index = faiss.IndexFlatIP(dim)
+    index.add(norm_doc_embeddings)
+    k = norm_doc_embeddings.shape[0]
+    D, I = index.search(norm_query_embedding, k)
+    # faiss returns scores sorted by similarity; map them back to document
+    # order (via the index array I) before mixing with BM25, which is
+    # computed in document order.
+    dense_scores = np.zeros(k, dtype=np.float32)
+    dense_scores[I.squeeze(0)] = D.squeeze(0)
+    norm_dense_scores = robust_min_max_norm(dense_scores)
+
+    # Step 6: BM25 scoring.
+    logging.info("[DeepGit] Step 6: Compute BM25 scores.")
+    if BM25Okapi is not None:
+        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
+        bm25 = BM25Okapi(tokenized_docs)
+        query_tokens = re.findall(r'\w+', query.lower())
+        bm25_scores = np.array(bm25.get_scores(query_tokens))
+        norm_bm25_scores = robust_min_max_norm(bm25_scores)
+    else:
+        norm_bm25_scores = np.zeros_like(norm_dense_scores)
+
+    # Step 7: Combine scores (dense score weighted higher).
+    logging.info("[DeepGit] Step 7: Combine dense and BM25 scores.")
+    alpha = 0.8
+    combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
+    for idx, repo in enumerate(repositories):
+        repo["combined_score"] = float(combined_scores[idx])
+
+    # Step 8: Initial ranking by combined score.
+    logging.info("[DeepGit] Step 8: Initial ranking by combined score.")
+    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)
+
+    # Step 9: Compute cross-encoder scores for the top candidates.
+    logging.info("[DeepGit] Step 9: Cross-encoder re-ranking.")
+    top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
+    cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))
+
+    # Step 10: Combine both metrics: final_score = w1 * combined_score + w2 * cross_encoder_score.
+    logging.info("[DeepGit] Step 10: Final scoring and output formatting.")
+    w1 = 0.7
+    w2 = 0.3
+    for candidate in top_candidates:
+        candidate["final_score"] = w1 * candidate.get("combined_score", 0) + w2 * candidate.get("cross_encoder_score", 0)
+
+    final_ranked = sorted(top_candidates, key=lambda x: x.get("final_score", 0), reverse=True)[:num_results]
+
+    # Step 11: Format the final output with scores as percentages.
+    output = "\n=== Ranked Repositories ===\n"
+    for rank, repo in enumerate(final_ranked, 1):
+        output += f"Final Rank: {rank}\n"
+        output += f"Title: {repo['title']}\n"
+        output += f"Link: {repo['link']}\n"
+        output += f"Combined Score: {repo.get('combined_score', 0) * 100:.2f}%\n"
+        output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0) * 100:.2f}%\n"
+        output += f"Final Score: {repo.get('final_score', 0) * 100:.2f}%\n"
+        snippet = repo['combined_text'][:300].replace('\n', ' ')
+        output += f"Snippet: {snippet}...\n"
+        output += '-' * 80 + "\n"
+    output += "\n=== End of Results ==="
+    return output
+
+# ---------------------------
+# Main Entry Point for Testing
+# ---------------------------
+if __name__ == "__main__":
+    test_query = "Chain of thought prompting for reasoning models"
+    result = run_repository_ranking(test_query)
+    print(result)
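To make the weighting concrete, here is a worked sketch of the score blend main.py applies (the three input scores are made-up values, each already normalized to [0, 1]):

    alpha, w1, w2 = 0.8, 0.7, 0.3
    dense, bm25, cross = 0.90, 0.60, 0.75          # illustrative normalized scores
    combined = alpha * dense + (1 - alpha) * bm25  # 0.8*0.90 + 0.2*0.60 = 0.84
    final = w1 * combined + w2 * cross             # 0.7*0.84 + 0.3*0.75 = 0.813
    print(f"Final Score: {final * 100:.2f}%")      # "Final Score: 81.30%"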
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+requests==2.32.3
+numpy==1.25.2
+python-dotenv==1.0.1
+sentence-transformers==3.4.1
+faiss-cpu==1.9.0.post1
+langgraph==0.2.62
+langchain_groq==0.2.4
+langchain_core==0.3.47
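Note: several modules imported by the files above are not pinned here (gradio in app.py; PyGithub and langchain in agent.py; the optional rank_bm25 in main.py), presumably because the Space's base image provides them. For a local run, a sketch of the extra, unpinned requirements:

    gradio
    PyGithub
    langchain
    rank_bm25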