zamalali committed on
Commit
0b55d27
·
0 Parent(s):

Push of DeepGit core files

Browse files
Files changed (5) hide show
  1. .gitignore +5 -0
  2. agent.py +125 -0
  3. app.py +309 -0
  4. main.py +380 -0
  5. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv/
2
+ .env
3
+ __pycache__/
4
+ *.pyc
5
+ .gradio/
agent.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from github import Github, Auth
6
+
7
+ # LangChain imports
8
+ from langchain_groq import ChatGroq
9
+ from langchain_core.tools import tool
10
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
11
+ from langchain import hub
12
+
13
# Load environment variables from .env (no-op if the file is absent,
# e.g. on HuggingFace Spaces where secrets arrive as env vars).
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GITHUB_PAT = os.getenv("GITHUB_API_KEY")

# Fail fast at import time if either credential is missing.
if not (GROQ_API_KEY and GITHUB_PAT):
    raise ValueError("Please set GROQ_API_KEY and GITHUB_API_KEY in your .env")

# Initialize a module-level GitHub client shared by the tool below.
_auth = Auth.Token(GITHUB_PAT)
_gh = Github(auth=_auth)
24
+
25
# Define the GitHub tool
@tool
def get_repo_info(repo_name: str) -> str:
    """Fetch and summarize metadata about a GitHub repository.

    Args:
        repo_name: Full repository name in "owner/repo" form.

    Returns:
        A human-readable multi-line summary of the repository, or an
        error message string if the repository cannot be fetched.
    """
    try:
        repo = _gh.get_repo(repo_name)
    except Exception as e:
        return f" Error fetching '{repo_name}': {e}"

    name = repo.full_name
    desc = repo.description or "No description"
    url = repo.html_url
    owner = repo.owner.login
    stars = repo.stargazers_count
    forks = repo.forks_count
    issues = repo.open_issues_count
    watchers = repo.watchers_count
    default_br = repo.default_branch
    language = repo.language or "None"

    # Narrowed from a bare "except:": a bare except also swallows
    # KeyboardInterrupt/SystemExit; Exception is the right scope for
    # "repo has no license / API error" here.
    try:
        license_name = repo.get_license().license.name
    except Exception:
        license_name = "None"

    topics = repo.get_topics()
    try:
        raw_md = repo.get_readme().decoded_content.decode("utf-8")
        snippet = raw_md[:300].replace("\n", " ") + "..."
    except Exception:
        snippet = "No README found"

    contribs = repo.get_contributors()[:5]
    contrib_list = ", ".join(f"{c.login}({c.contributions})" for c in contribs)

    commits = repo.get_commits()[:3]
    commit_list = "; ".join(c.commit.message.split("\n")[0] for c in commits)

    return f"""
Repository: {name}
Description: {desc}
URL: {url}
Owner: {owner}
⭐ Stars: {stars} 🍴 Forks: {forks} 🐛 Open Issues: {issues}
👁️ Watchers: {watchers} Default branch: {default_br}
⚙️ Language: {language} License: {license_name}
🔍 Topics: {topics}

README Snippet: {snippet}

👥 Top Contributors: {contrib_list}
🧾 Latest Commits: {commit_list}
"""
80
+
81
# Instantiate the Groq LLM used by the agent.
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    temperature=0.3,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)

# Define the tools to pass into the agent.
tools = [get_repo_info]

# Tool-calling agent prompt, built locally (not pulled from the LangChain hub).
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are GitHub Agent, an expert at analyzing repositories.\n"
     "When a user asks about a repo, call the tool and return a clear, concise summary of the repository based on the tool result.\n"
     "Avoid repeating raw tool output or adding unnecessary disclaimers.\n"
     "Respond in complete sentences, in natural language."
    ),
    MessagesPlaceholder(variable_name="chat_history", optional=True),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Create the agent using LangChain's legacy AgentExecutor approach.
agent = create_tool_calling_agent(llm, tools, prompt)

# Executor caps tool iterations at 2 to bound latency and API cost.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, max_iterations=2)

# Public API of this module (consumed by app.py).
__all__ = ["agent_executor"]

# Quick smoke test: run a single query and report wall-clock time.
if __name__ == "__main__":
    import time
    user_input = "Give me details about the repo zamalali/deepgit"
    start_time = time.time()
    result = agent_executor.invoke({"input": user_input})
    end_time = time.time()
    print("\n Final Answer:\n", result["output"])
    print(f"\n Took {end_time - start_time:.2f} seconds")
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import threading
4
+ import logging
5
+ from gradio.themes.utils import sizes
6
+ from main import run_repository_ranking # Import the repository ranking function
7
+ import agent # Import the test.py module for chat agent
8
+
9
# ---------------------------
# Global Logging Buffer Setup
# ---------------------------
# Every log record emitted in the process is mirrored into this in-memory
# list so the UI can stream it; guarded by a lock because the ranking
# workflow runs in a background thread.
LOG_BUFFER = []
LOG_BUFFER_LOCK = threading.Lock()

class BufferLogHandler(logging.Handler):
    """Logging handler that appends each formatted record to LOG_BUFFER."""
    def emit(self, record):
        log_entry = self.format(record)
        with LOG_BUFFER_LOCK:
            LOG_BUFFER.append(log_entry)

# Install the buffer handler on the root logger exactly once
# (guards against duplicate handlers on module re-import).
root_logger = logging.getLogger()
if not any(isinstance(h, BufferLogHandler) for h in root_logger.handlers):
    handler = BufferLogHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    root_logger.addHandler(handler)
27
+
28
def filter_logs(logs):
    """Collapse noisy "HTTP Request:" log lines into a single placeholder.

    Each consecutive run of HTTP-request entries becomes one
    "Fetching repositories..." marker; all other lines pass through
    unchanged, in order.
    """
    cleaned = []
    in_fetch_run = False
    for entry in logs:
        is_http = "HTTP Request:" in entry
        if is_http:
            # Only emit the placeholder once per run of HTTP lines.
            if not in_fetch_run:
                cleaned.append("Fetching repositories...")
        else:
            cleaned.append(entry)
        in_fetch_run = is_http
    return cleaned
40
+
41
def parse_result_to_html(raw_result: str, num_results: int) -> tuple:
    """Parse the raw output of run_repository_ranking into an HTML table.

    Only the top `num_results` entries are rendered.
    (Fix: the original annotation `-> (str, list)` evaluates to a runtime
    tuple and is not a valid type annotation; `-> tuple` is used instead.)

    Args:
        raw_result: Formatted report text produced by main.run_repository_ranking.
        num_results: Maximum number of entries to render.

    Returns:
        (html, repo_names): the table markup and the "owner/repo" names
        extracted from each entry's GitHub link.
    """
    # Each repository entry in the raw report begins with "Final Rank:".
    entries = raw_result.strip().split("Final Rank:")
    entries = entries[1:num_results+1]  # drop the header chunk, keep top N
    if not entries:
        return ("<p>No repositories found for your query.</p>", [])
    html = """
    <table border="1" style="width:80%; margin: auto; border-collapse: collapse;">
    <thead>
    <tr>
    <th>Rank</th>
    <th>Title</th>
    <th>Link</th>
    <th>Combined Score</th>
    </tr>
    </thead>
    <tbody>
    """
    repo_names = []
    for entry in entries:
        lines = entry.strip().split("\n")
        data = {}
        # The first line of each entry is the rank number itself.
        data["Final Rank"] = lines[0].strip() if lines else ""
        for line in lines[1:]:
            if ": " in line:
                key, val = line.split(": ", 1)
                data[key.strip()] = val.strip()
        # Extract "owner/repo" from the GitHub link for the chat selector.
        link = data.get('Link', '')
        repo_name = ''
        if 'github.com/' in link:
            repo_name = link.split('github.com/')[-1].strip('/ ')
        if repo_name:
            repo_names.append(repo_name)
        html += f"""
    <tr>
    <td>{data.get('Final Rank', '')}</td>
    <td>{data.get('Title', '')}</td>
    <td><a href="{data.get('Link', '#')}" target="_blank">GitHub</a></td>
    <td>{data.get('Combined Score', '')}</td>
    </tr>
    """
    html += "</tbody></table>"
    return html, repo_names
89
+
90
# ---------------------------
# GPU-enabled Wrapper for Repository Ranking
# ---------------------------
def gpu_run_repo(topic: str, num_results: int):
    """Thin pass-through to the ranking pipeline (kept as a GPU-decoration hook)."""
    return run_repository_ranking(topic, num_results)

def run_lite_workflow(topic, num_results, result_container):
    """Thread target: run the ranking and stash the raw report in the shared dict."""
    result = gpu_run_repo(topic, num_results)
    result_container["raw_result"] = result
99
+
100
def stream_lite_workflow(topic, num_results):
    """Run the ranking in a background thread while streaming log lines to the UI.

    Yields (status, details_html, repo_names) tuples; repo_names is only
    populated on the final yield once the workflow has finished.
    """
    logging.info("[UI] User started a new search for topic: %s", topic)
    with LOG_BUFFER_LOCK:
        LOG_BUFFER.clear()
    result_container = {}
    workflow_thread = threading.Thread(target=run_lite_workflow, args=(topic, num_results, result_container))
    workflow_thread.start()

    last_index = 0
    # Poll the shared log buffer while the worker runs, relaying new lines.
    while workflow_thread.is_alive() or (last_index < len(LOG_BUFFER)):
        with LOG_BUFFER_LOCK:
            new_logs = LOG_BUFFER[last_index:]
            last_index = len(LOG_BUFFER)
        if new_logs:
            filtered_logs = filter_logs(new_logs)
            # Most recent line becomes the status; full batch becomes detail.
            status_msg = filtered_logs[-1]
            detail_msg = "<br/>".join(filtered_logs)
            yield status_msg, detail_msg, []
        time.sleep(0.5)

    workflow_thread.join()
    with LOG_BUFFER_LOCK:
        final_logs = LOG_BUFFER[:]
    raw_result = result_container.get("raw_result", "No results returned.")
    html_result, repo_names = parse_result_to_html(raw_result, num_results)
    yield "", html_result, repo_names
126
+
127
def lite_runner(topic, num_results):
    """Gradio event entry point: emit an initial status, then relay the stream."""
    logging.info("[UI] Running lite_runner for topic: %s", topic)
    yield "Workflow started", "<p>Processing your request. Please wait...</p>", []
    for status, details, repos in stream_lite_workflow(topic, num_results):
        yield status, details, repos
132
+
133
+ # ---------------------------
134
+ # App UI Setup Using Gradio Soft Theme with Centered Layout
135
+ # ---------------------------
136
+ with gr.Blocks(
137
+ theme=gr.themes.Soft(text_size=sizes.text_md),
138
+ title="DeepGit Lite",
139
+ css="""
140
+ /* Center header and footer */
141
+ #header { text-align: center; margin-bottom: 20px; }
142
+ #main-container { max-width: 800px; margin: auto; }
143
+ #footer { text-align: center; margin-top: 20px; }
144
+ """
145
+ ) as demo:
146
+ gr.Markdown(
147
+ """
148
+ <div style="padding-top: 60px;">
149
+ <div style="display: flex; align-items: center; justify-content: center;">
150
+ <img src="https://img.icons8.com/?size=100&id=118557&format=png&color=000000"
151
+ style="width: 60px; height: 60px; margin-right: 12px;">
152
+ <h1 style="margin: 0; font-size: 2.5em; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;">
153
+ DeepGit Lite
154
+ </h1>
155
+ </div>
156
+ <div style="text-align: center; margin-top: 20px; font-size: 1.1em; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;">
157
+ <p>
158
+ ✨ DeepGit Lite is the lightweight pro version of <strong>DeepGit</strong>.<br>
159
+ It harnesses advanced deep semantic search to explore GitHub repositories and deliver curated results.<br>
160
+ Under the hood, it leverages a hybrid ranking approach combining dense retrieval, BM25 scoring, and cross-encoder re-ranking for optimal discovery.<br>
161
+ If the agent returns no repositories found, it means no chain was invoked due to GPU unavailability. Please duplicate the space and re-run.
162
+ </p>
163
+ <p>
164
+ 🚀 Check out the full DeepGit version on
165
+ <a href="https://github.com/zamalali/DeepGit" target="_blank">GitHub</a> and ⭐
166
+ <strong>Star DeepGit</strong> on GitHub!
167
+ </p>
168
+ </div>
169
+ </div>
170
+ """,
171
+ elem_id="header"
172
+ )
173
+
174
+ # --- Search UI ---
175
+ with gr.Column(elem_id="main-container", visible=True) as search_ui:
176
+ research_input = gr.Textbox(
177
+ label="Research Query",
178
+ placeholder="Enter your research topic here, e.g., Looking for a low code/no code tool to augment images and annotations?",
179
+ lines=3
180
+ )
181
+ num_results_slider = gr.Slider(
182
+ minimum=5, maximum=25, value=10, step=1,
183
+ label="Number of Results to Display",
184
+ info="Choose how many top repositories to show (sorted by score)"
185
+ )
186
+ run_button = gr.Button("Run DeepGit Lite", variant="primary")
187
+ status_display = gr.Markdown(label="Status")
188
+ detail_display = gr.HTML(label="Results")
189
+ repo_state = gr.State([])
190
+ go_to_chat_btn = gr.Button("Go to Chat", visible=False)
191
+
192
+ # --- Chat UI ---
193
+ with gr.Column(visible=False) as chat_ui:
194
+
195
+ repo_choice = gr.Radio(choices=[], label="Select a repository", interactive=True)
196
+ chat_history = gr.Chatbot(label="Chat with GitHub Agent")
197
+ user_input = gr.Textbox(label="Your question", placeholder="Ask about the selected repo...e.g., tell me a bit more and guide me to set this up and running?")
198
+ send_btn = gr.Button("Send")
199
+ chat_state = gr.State([])
200
+ back_btn = gr.Button("Back to Search")
201
+
202
+ def update_chat_button(status, details, repos):
203
+ logging.info("[UI] Search complete. Showing Go to Chat button: %s", bool(repos))
204
+ return gr.update(visible=bool(repos)), repos
205
+
206
+ def show_chat_ui(repos):
207
+ logging.info("[UI] Switching to Chat UI. Repositories available: %s", repos)
208
+ return gr.update(visible=False), gr.update(visible=True), gr.update(choices=repos, value=None), []
209
+
210
+ def back_to_search():
211
+ logging.info("[UI] Switching back to Search UI.")
212
+ return gr.update(visible=True), gr.update(visible=False), gr.update(value=[]), gr.update(value=None), []
213
+
214
    def chat_with_agent(user_msg, repo, history):
        """Send the user's message (prefixed with the selected repo) to the agent.

        Returns (updated_chatbot_history, updated_state); both are the same
        list because chat_history and chat_state are kept in sync.
        """
        logging.info("[Chat] User sent message: '%s' for repo: '%s'", user_msg, repo)
        if not user_msg or not user_msg.strip():
            # Block blank messages
            return history + [["", "Please enter a message before sending."]], history
        if not repo:
            return history + [[user_msg, "Please select a repository first."]], history
        # Prefix the repo name so the agent knows which repository to inspect.
        full_query = f"[{repo}] {user_msg}"
        try:
            result = agent.agent_executor.invoke({"input": full_query})
            answer = result["output"]
            logging.info("[Chat] Agent response received.")
        except Exception as e:
            answer = f"Error: {e}"
            logging.error("[Chat] Error in agent_executor: %s", e)
        history = history + [[user_msg, answer]]
        return history, history
231
+
232
+ # Disable send button if no repo is selected or message is blank, and show a helpful message
233
+ def can_send(user_msg, repo):
234
+ if not user_msg or not user_msg.strip():
235
+ return gr.update(interactive=False, value="Enter a message to send")
236
+ if not repo:
237
+ return gr.update(interactive=False, value="Select a repository")
238
+ return gr.update(interactive=True, value="Send")
239
+ user_input.change(
240
+ fn=can_send,
241
+ inputs=[user_input, repo_choice],
242
+ outputs=[send_btn],
243
+ show_progress=False
244
+ )
245
+ repo_choice.change(
246
+ fn=can_send,
247
+ inputs=[user_input, repo_choice],
248
+ outputs=[send_btn],
249
+ show_progress=False
250
+ )
251
+
252
+ run_button.click(
253
+ fn=lite_runner,
254
+ inputs=[research_input, num_results_slider],
255
+ outputs=[status_display, detail_display, repo_state],
256
+ api_name="deepgit_lite",
257
+ show_progress=True
258
+ ).then(
259
+ fn=update_chat_button,
260
+ inputs=[status_display, detail_display, repo_state],
261
+ outputs=[go_to_chat_btn, repo_state]
262
+ )
263
+
264
+ research_input.submit(
265
+ fn=lite_runner,
266
+ inputs=[research_input, num_results_slider],
267
+ outputs=[status_display, detail_display, repo_state],
268
+ api_name="deepgit_lite_submit",
269
+ show_progress=True
270
+ ).then(
271
+ fn=update_chat_button,
272
+ inputs=[status_display, detail_display, repo_state],
273
+ outputs=[go_to_chat_btn, repo_state]
274
+ )
275
+
276
+ go_to_chat_btn.click(
277
+ fn=show_chat_ui,
278
+ inputs=[repo_state],
279
+ outputs=[search_ui, chat_ui, repo_choice, chat_state]
280
+ )
281
+
282
+ back_btn.click(
283
+ fn=back_to_search,
284
+ inputs=[],
285
+ outputs=[search_ui, chat_ui, chat_history, repo_choice, chat_state]
286
+ )
287
+
288
+ send_btn.click(
289
+ fn=chat_with_agent,
290
+ inputs=[user_input, repo_choice, chat_state],
291
+ outputs=[chat_history, chat_state],
292
+ queue=False
293
+ )
294
+ user_input.submit(
295
+ fn=chat_with_agent,
296
+ inputs=[user_input, repo_choice, chat_state],
297
+ outputs=[chat_history, chat_state],
298
+ queue=False
299
+ )
300
+
301
+ gr.HTML(
302
+ """
303
+ <div id="footer">
304
+ Made with ❤️ by <b>Zamal</b>
305
+ </div>
306
+ """
307
+ )
308
+
309
+ demo.queue(max_size=10).launch()
main.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import requests
4
+ import numpy as np
5
+ import faiss
6
+ import re
7
+ import logging
8
+ from pathlib import Path
9
+
10
+ # For local development, load environment variables from a .env file.
11
+ # In HuggingFace Spaces, secrets are automatically available as environment variables.
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
+
15
+ from sentence_transformers import SentenceTransformer, CrossEncoder
16
+ from langchain_groq import ChatGroq
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+
19
+ # Optionally import BM25 for sparse retrieval.
20
+ try:
21
+ from rank_bm25 import BM25Okapi
22
+ except ImportError:
23
+ BM25Okapi = None
24
+
25
+ # ---------------------------
26
+ # Environment Variables & Setup
27
+ # ---------------------------
28
+ # GitHub API key (required for GitHub API calls)
29
+ GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
30
+ # GROQ API key (if required by ChatGroq)
31
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
32
+ # HuggingFace token (if you need it to load private models from HuggingFace)
33
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
34
+
35
+ CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
36
+
37
+ # Set up a persistent session for GitHub API requests.
38
+ session = requests.Session()
39
+ session.headers.update({
40
+ "Authorization": f"token {GITHUB_API_KEY}",
41
+ "Accept": "application/vnd.github.v3+json"
42
+ })
43
+
44
+ # ---------------------------
45
+ # Langchain Groq Setup for Search Tag Conversion
46
+ # ---------------------------
47
+ llm = ChatGroq(
48
+ model="deepseek-r1-distill-llama-70b",
49
+ temperature=0.3,
50
+ max_tokens=512,
51
+ max_retries=3,
52
+ api_key=GROQ_API_KEY # Pass GROQ_API_KEY if the ChatGroq library supports it.
53
+ )
54
+
55
+ prompt = ChatPromptTemplate.from_messages([
56
+ ("system",
57
+ """You are a GitHub search optimization expert.
58
+
59
+ Your job is to:
60
+ 1. Read a user's query about tools, research, or tasks.
61
+ 2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
62
+ 3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
63
+ 4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
64
+ Use as many tags as necessary based on the query's complexity, but never more than five.
65
+ 5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
66
+ If no specific language is mentioned, do not include any target tag.
67
+
68
+ Output Format:
69
+ tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]
70
+
71
+ Rules:
72
+ - Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
73
+ - Use terms commonly found in GitHub repo names, topics, or descriptions.
74
+ - Avoid generic terms like "python", "ai", "tool", "project".
75
+ - Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
76
+ - Prefer real tools, popular methods, or dataset names when mentioned.
77
+ - If your output does not strictly match the required format, correct it after your internal reasoning.
78
+ - Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
79
+
80
+ Excellent Examples:
81
+
82
+ Input: "No code tool to augment image and annotation"
83
+ Output: image-augmentation:albumentations
84
+
85
+ Input: "Repos around chain of thought prompting mainly for finetuned models"
86
+ Output: chain-of-thought:finetuned-llm
87
+
88
+ Input: "Find repositories implementing data augmentation pipelines in JavaScript"
89
+ Output: data-augmentation:target-javascript
90
+
91
+ Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
92
+ """),
93
+ ("human", "{query}")
94
+ ])
95
+ chain = prompt | llm
96
+
97
def valid_tags(tags: str) -> bool:
    """Return True when `tags` is 2-6 colon-separated lowercase/hyphen tokens."""
    token = r'[a-z0-9-]+'
    return re.fullmatch(rf'{token}(?::{token}){{1,5}}', tags) is not None
100
+
101
def parse_search_tags(response: str) -> str:
    """Strip <think>...</think> reasoning blocks and extract the tag string.

    Falls back to the stripped remainder when no colon-separated tag
    sequence can be found.
    """
    without_thoughts = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    found = re.search(r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})', without_thoughts)
    return found.group(1).strip() if found else without_thoughts.strip()
109
+
110
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    """Ask the LLM for GitHub search tags, re-prompting when the format is invalid.

    Returns the last tags string produced; note it may still be invalid
    after max_iterations attempts (callers should tolerate that).
    """
    print(f"\n [iterative_convert_to_search_tags] Input Query: {query}")
    refined_query = query
    tags_output = ""
    for iteration in range(max_iterations):
        print(f"\n Iteration {iteration+1}")
        response = chain.invoke({"query": refined_query})
        full_output = response.content.strip()
        tags_output = parse_search_tags(full_output)
        print(f"Output Tags: {tags_output}")
        if valid_tags(tags_output):
            print("Valid tags format detected.")
            return tags_output
        else:
            print(" Invalid tags format. Requesting refinement...")
            # Re-ask with an explicit format reminder appended to the query.
            refined_query = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    print("Final output (may be invalid):", tags_output)
    return tags_output
128
+
129
# ---------------------------
# GitHub API Helper Functions
# ---------------------------
def fetch_readme_content(repo_full_name: str) -> str:
    """Return the repository README decoded as UTF-8, or "" on any failure."""
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    response = session.get(readme_url)
    if response.status_code != 200:
        return ""
    payload = response.json()
    try:
        # The contents API returns base64-encoded file content.
        return base64.b64decode(payload.get('content', '')).decode('utf-8', errors='replace')
    except Exception:
        return ""
142
+
143
def fetch_markdown_contents(repo_full_name: str) -> str:
    """Concatenate the text of every top-level .md file in the repository.

    Returns "" when the contents listing cannot be fetched; individual file
    download failures are skipped silently.
    """
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    response = session.get(url)
    contents = ""
    if response.status_code == 200:
        items = response.json()
        for item in items:
            if item.get("type") == "file" and item.get("name", "").lower().endswith(".md"):
                file_url = item.get("download_url")
                if file_url:
                    # Fix: use the shared authenticated session here — the
                    # original called bare requests.get, bypassing the auth
                    # header and hitting the unauthenticated rate limit.
                    file_resp = session.get(file_url)
                    if file_resp.status_code == 200:
                        contents += "\n" + file_resp.text
    return contents
157
+
158
def fetch_all_markdown(repo_full_name: str) -> str:
    """Combine the README with all other top-level markdown files' text."""
    parts = (fetch_readme_content(repo_full_name), fetch_markdown_contents(repo_full_name))
    return "\n".join(parts)
162
+
163
def fetch_github_repositories(query: str, max_results: int = 10) -> list:
    """Search GitHub and return [{"title", "link", "combined_text"}, ...].

    combined_text is the description plus all top-level markdown content,
    used downstream for dense/BM25 retrieval.
    """
    response = session.get(
        "https://api.github.com/search/repositories",
        params={"q": query, "per_page": max_results},
    )
    if response.status_code != 200:
        print(f"Error {response.status_code}: {response.json().get('message')}")
        return []
    results = []
    for item in response.json().get('items', []):
        description = item.get('description') or ""
        markdown = fetch_all_markdown(item.get('full_name'))
        results.append({
            "title": item.get('name', 'No title available'),
            "link": item.get('html_url'),
            "combined_text": (description + "\n" + markdown).strip(),
        })
    return results
185
+
186
# ---------------------------
# Dense Retrieval Model Setup
# ---------------------------
# NOTE(review): both the try and except branches load the model on CPU, so
# the original "falling back to CPU" wording was false; the except clause is
# kept only as a one-shot retry against transient load failures.
try:
    model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
except Exception as e:
    print("Error initializing SentenceTransformer; retrying once on CPU:", e)
    model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
195
+
196
def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    """Min-max normalize scores to [0, 1].

    Returns an all-ones array when the scores are (near-)constant, avoiding
    division by a vanishing range.
    """
    lo, hi = scores.min(), scores.max()
    spread = hi - lo
    if spread < 1e-10:
        return np.ones_like(scores)
    return (scores - lo) / spread
202
+
203
# ---------------------------
# Cross-Encoder Re-Ranking Function
# ---------------------------
def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
    """Score every candidate against the query with a cross-encoder, in place.

    Adds a "cross_encoder_score" key to each candidate dict and returns the
    (mutated) list. Long documents are scored chunk-wise and aggregated as
    0.5*max + 0.5*mean. NOTE: the top_n parameter is currently unused.
    """
    # NOTE(review): both branches load on CPU, so the "falling back to CPU"
    # message is misleading; the except acts only as a one-shot retry.
    try:
        cross_encoder = CrossEncoder(model_name, device='cpu')
    except Exception as e:
        print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
        cross_encoder = CrossEncoder(model_name, device='cpu')

    CHUNK_SIZE = 2000       # characters per scored chunk
    MAX_DOC_LENGTH = 5000   # hard cap on document length before chunking
    MIN_DOC_LENGTH = 200    # docs shorter than this are scored as one pair

    def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
        # Fixed-width character chunking, no overlap.
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    for candidate in candidates:
        doc = candidate.get("combined_text", "")
        if len(doc) > MAX_DOC_LENGTH:
            doc = doc[:MAX_DOC_LENGTH]
        try:
            if len(doc) < MIN_DOC_LENGTH:
                score = cross_encoder.predict([[query, doc]])
                # predict may return a scalar or a length-1 sequence.
                if hasattr(score, '__len__') and len(score) == 1:
                    candidate["cross_encoder_score"] = float(score[0])
                else:
                    candidate["cross_encoder_score"] = float(score)
            else:
                chunks = split_text(doc)
                pairs = [[query, chunk] for chunk in chunks]
                scores = cross_encoder.predict(pairs)
                scores = np.array(scores)
                max_score = float(np.max(scores)) if scores.size > 0 else 0.0
                avg_score = float(np.mean(scores)) if scores.size > 0 else 0.0
                candidate["cross_encoder_score"] = 0.5 * max_score + 0.5 * avg_score
        except Exception as e:
            # Scoring failures demote the candidate rather than crashing the run.
            logging.debug(f"[cross-encoder] Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
            candidate["cross_encoder_score"] = 0.0

    # Shift all scores up so the minimum is non-negative (keeps the later
    # weighted sum from being dragged down by negative logits).
    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
    if all_scores:
        min_score = min(all_scores)
        if min_score < 0:
            for candidate in candidates:
                candidate["cross_encoder_score"] += -min_score

    return candidates
251
+
252
# ---------------------------
# Main Ranking Function with Hybrid Retrieval and Combined Scoring
# ---------------------------
def run_repository_ranking(query: str, num_results: int = 10) -> str:
    """Run the full DeepGit ranking pipeline for a natural-language query.

    Pipeline: LLM tag generation -> GitHub search -> dense retrieval (FAISS)
    -> BM25 -> weighted combination -> cross-encoder re-rank -> formatted text.

    Args:
        query: The user's natural-language research query.
        num_results: Number of top repositories to include in the report.

    Returns:
        A formatted multi-entry report string (consumed by app.parse_result_to_html),
        or a "No repositories found" message.
    """
    logging.info("[DeepGit] Step 1: Generate search tags from the query.")
    search_tags = iterative_convert_to_search_tags(query)
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]

    # Step 2: Handle target language extraction.
    logging.info("[DeepGit] Step 2: Handle target language extraction.")
    if any(tag.startswith("target-") for tag in tag_list):
        target_tag = next(tag for tag in tag_list if tag.startswith("target-"))
        lang_query = f"language:{target_tag.replace('target-', '')}"
        tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
    else:
        # Default to Python when the query names no other language.
        lang_query = "language:python"

    # Step 3: Build advanced search qualifiers.
    logging.info("[DeepGit] Step 3: Build advanced search qualifiers and fetch repositories.")
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []

    # One search per tag, plus one combined OR search for extra recall.
    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        logging.info(f"[DeepGit] GitHub Query: {github_query}")
        repos = fetch_github_repositories(github_query, max_results=15)
        all_repositories.extend(repos)

    combined_query = " OR ".join(tag_list)
    combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
    logging.info(f"[DeepGit] Combined GitHub Query: {combined_query}")
    repos = fetch_github_repositories(combined_query, max_results=15)
    all_repositories.extend(repos)

    # De-duplicate by link, merging the text of duplicate hits.
    unique_repositories = {}
    for repo in all_repositories:
        if repo["link"] not in unique_repositories:
            unique_repositories[repo["link"]] = repo
        else:
            existing_text = unique_repositories[repo["link"]]["combined_text"]
            unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"]
    repositories = list(unique_repositories.values())

    if not repositories:
        return "No repositories found for your query."

    # Step 4: Prepare documents.
    logging.info("[DeepGit] Step 4: Prepare documents for dense retrieval.")
    docs = [repo.get("combined_text", "") for repo in repositories]

    # Step 5: Dense retrieval (cosine similarity via normalized inner product).
    logging.info("[DeepGit] Step 5: Compute dense embeddings and scores.")
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    norm_doc_embeddings = doc_embeddings / (norms + 1e-10)

    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)

    dim = norm_doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(norm_doc_embeddings)
    k = norm_doc_embeddings.shape[0]  # rank every document
    D, I = index.search(norm_query_embedding, k)
    # BUG FIX: faiss returns D sorted by similarity with I holding the
    # document indices; the original used D positionally, attaching dense
    # scores to the wrong repositories and mixing them with the
    # document-ordered BM25 scores. Scatter D back into document order.
    ranked_scores = D.squeeze(axis=0)
    ranked_order = I.squeeze(axis=0)
    dense_scores = np.empty_like(ranked_scores)
    dense_scores[ranked_order] = ranked_scores
    norm_dense_scores = robust_min_max_norm(dense_scores)

    # Step 6: BM25 scoring (sparse lexical signal; zeros if rank_bm25 absent).
    logging.info("[DeepGit] Step 6: Compute BM25 scores.")
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        bm25_scores = np.array(bm25.get_scores(query_tokens))
        norm_bm25_scores = robust_min_max_norm(bm25_scores)
    else:
        norm_bm25_scores = np.zeros_like(norm_dense_scores)

    # Step 7: Combine scores (dense score weighted higher).
    logging.info("[DeepGit] Step 7: Combine dense and BM25 scores.")
    alpha = 0.8  # weight of the dense signal
    combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
    for idx, repo in enumerate(repositories):
        repo["combined_score"] = float(combined_scores[idx])

    # Step 8: Initial ranking by combined score.
    logging.info("[DeepGit] Step 8: Initial ranking by combined score.")
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)

    # Step 9: Compute cross-encoder scores for the top candidates.
    logging.info("[DeepGit] Step 9: Cross-encoder re-ranking.")
    top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
    cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))

    # Step 10: final_score = w1 * combined_score + w2 * cross_encoder_score.
    logging.info("[DeepGit] Step 10: Final scoring and output formatting.")
    w1 = 0.7
    w2 = 0.3
    for candidate in top_candidates:
        candidate["final_score"] = w1 * candidate.get("combined_score", 0) + w2 * candidate.get("cross_encoder_score", 0)

    final_ranked = sorted(top_candidates, key=lambda x: x.get("final_score", 0), reverse=True)[:num_results]

    # Step 11: Format final output with scores as percentages.
    output = "\n=== Ranked Repositories ===\n"
    for rank, repo in enumerate(final_ranked, 1):
        output += f"Final Rank: {rank}\n"
        output += f"Title: {repo['title']}\n"
        output += f"Link: {repo['link']}\n"
        output += f"Combined Score: {repo.get('combined_score', 0) * 100:.2f}%\n"
        output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0) * 100:.2f}%\n"
        output += f"Final Score: {repo.get('final_score', 0) * 100:.2f}%\n"
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        output += f"Snippet: {snippet}...\n"
        output += '-' * 80 + "\n"
    output += "\n=== End of Results ==="
    return output
373
+
374
# ---------------------------
# Main Entry Point for Testing
# ---------------------------
if __name__ == "__main__":
    # Smoke test: run the full pipeline for a sample query and print the report.
    test_query = "Chain of thought prompting for reasoning models"
    result = run_repository_ranking(test_query)
    print(result)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ requests==2.32.3
2
+ numpy==1.25.2
3
+ python-dotenv==1.0.1
4
+ sentence-transformers==3.4.1
5
+ faiss-cpu==1.9.0.post1
6
+ langgraph==0.2.62
7
+ langchain_groq==0.2.4
8
+ langchain_core==0.3.47