# repochat / app.py
import os
import re
import json
import base64

import requests
import torch
import uvicorn
import nest_asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer, models
import gradio as gr

# Note: torch and transformers are only needed by the commented-out local-LLM path
# below; json, fastapi, pydantic, uvicorn and nest_asyncio are currently unused.

############################################
# Configuration
############################################
# API credentials are read from environment variables (e.g. Hugging Face Space secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
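
# Optional startup check (a minimal sketch, not part of the original app): warn early
# if an expected credential is missing, since the GitHub and Gemini calls below would
# otherwise fail with less obvious errors.
for _name, _value in (("HF_TOKEN", HF_TOKEN), ("GITHUB_TOKEN", GITHUB_TOKEN), ("GEMINI_API_KEY", GEMINI_API_KEY)):
    if not _value:
        print(f"Warning: environment variable {_name} is not set.")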
############################################
# GitHub API Functions
############################################
def extract_repo_info(github_url: str):
pattern = r"github\.com/([^/]+)/([^/]+)"
match = re.search(pattern, github_url)
if match:
owner = match.group(1)
repo = match.group(2).replace('.git', '')
return owner, repo
else:
raise ValueError("Invalid GitHub URL provided.")
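# Example (illustrative):
#   extract_repo_info("https://github.com/huggingface/transformers.git")
#   # -> ("huggingface", "transformers")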
def get_repo_metadata(owner: str, repo: str):
headers = {'Authorization': f'token {GITHUB_TOKEN}'}
repo_url = f"https://api.github.com/repos/{owner}/{repo}"
response = requests.get(repo_url, headers=headers)
return response.json()
def get_repo_tree(owner: str, repo: str, branch: str):
headers = {'Authorization': f'token {GITHUB_TOKEN}'}
tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
response = requests.get(tree_url, headers=headers)
return response.json()
def get_file_content(owner: str, repo: str, file_path: str):
    """Fetch a single file's contents via the GitHub contents API (base64-encoded)."""
    headers = {'Authorization': f'token {GITHUB_TOKEN}'}
    content_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}"
    response = requests.get(content_url, headers=headers)
    data = response.json()
    if 'content' in data:
        try:
            return base64.b64decode(data['content']).decode('utf-8')
        except UnicodeDecodeError:
            return None  # Binary file: nothing sensible to show as text.
    return None
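# Example (illustrative; the owner, repository and path are placeholders for any
# public repository readable with the configured GITHUB_TOKEN):
#   readme_text = get_file_content("octocat", "Hello-World", "README")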
############################################
# Embedding Functions
############################################
def preprocess_text(text: str) -> str:
cleaned_text = text.strip()
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
return cleaned_text
def load_embedding_model(model_name: str = 'huggingface/CodeBERTa-small-v1') -> SentenceTransformer:
transformer_model = models.Transformer(model_name)
pooling_model = models.Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
model = SentenceTransformer(modules=[transformer_model, pooling_model])
return model
def generate_embedding(text: str, model_name: str = 'huggingface/CodeBERTa-small-v1') -> list:
    # Note: the embedding model is re-loaded on every call; see the caching sketch
    # below if this is ever used on a hot path.
    processed_text = preprocess_text(text)
    model = load_embedding_model(model_name)
    embedding = model.encode(processed_text)
    return embedding.tolist()  # Return a plain list, matching the annotation.
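
# A minimal caching sketch (not wired into the functions above): building a
# SentenceTransformer is expensive, so memoising the loader avoids re-loading
# the model on every call. `_cached_embedding_model` is an illustrative helper,
# not part of the original app.
from functools import lru_cache

@lru_cache(maxsize=2)
def _cached_embedding_model(model_name: str = 'huggingface/CodeBERTa-small-v1') -> SentenceTransformer:
    return load_embedding_model(model_name)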
############################################
# LLM Integration Functions
############################################
def is_detailed_query(query: str) -> bool:
keywords = ["detail", "detailed", "thorough", "in depth", "comprehensive", "extensive"]
return any(keyword in query.lower() for keyword in keywords)
def generate_prompt(query: str, context_snippets: list) -> str:
context = "\n\n".join(context_snippets)
if is_detailed_query(query):
instruction = "Provide an extremely detailed and thorough explanation of at least 500 words."
else:
instruction = "Answer concisely."
prompt = (
f"Below is some context from a GitHub repository:\n\n"
f"{context}\n\n"
f"Based on the above, {instruction}\n{query}\n"
f"Answer:"
)
return prompt
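# Example (illustrative): a query containing a keyword such as "detailed" selects the
# long-form instruction; any other query gets the concise "Answer concisely." instruction.
#   generate_prompt("Give a detailed walkthrough of app.py", ["<context snippet>"])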
# def get_llm_response(prompt: str, model_name: str = "meta-llama/Llama-2-7b-chat-hf", max_new_tokens: int = None) -> str:
# if max_new_tokens is None:
# max_new_tokens = 1024 if is_detailed_query(prompt) else 256
# torch.cuda.empty_cache()
# if not os.path.exists("offload"):
# os.makedirs("offload")
# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)
# model = AutoModelForCausalLM.from_pretrained(
# model_name,
# device_map="auto",
# offload_folder="offload", # Specify the folder where weights will be offloaded
# use_safetensors=False,
# trust_remote_code=True,
# torch_dtype=torch.float16,
# token=HF_TOKEN
# )
# text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
# outputs = text_gen(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
# full_response = outputs[0]['generated_text']
# marker = "Answer:"
# if marker in full_response:
# answer = full_response.split(marker, 1)[1].strip()
# else:
# answer = full_response.strip()
# return answer
# def get_llm_response(prompt: str, model_name: str = "EleutherAI/gpt-neo-125M", max_new_tokens: int = None) -> str:
# if max_new_tokens is None:
# max_new_tokens = 256 # You can adjust this value as needed.
# torch.cuda.empty_cache()
# # Load the tokenizer and model for GPT-Neo 125M.
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(
# model_name,
# device_map="auto",
# use_safetensors=False,
# torch_dtype=torch.float32 # Using default precision since model is small.
# )
# text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
# outputs = text_gen(
# prompt,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# temperature=0.9, # Increased temperature
# top_p=0.9, # Using nucleus sampling
# top_k=50 # Limit to top 50 tokens per step
# )
# full_response = outputs[0]['generated_text']
# marker = "Answer:"
# if marker in full_response:
# answer = full_response.split(marker, 1)[1].strip()
# else:
# answer = full_response.strip()
# return answer
def get_gemini_flash_response(prompt: str) -> str:
from google import genai
from google.genai import types
# Create a Gemini client using the API key from the environment.
client = genai.Client(api_key=GEMINI_API_KEY)
# Configure generation settings.
config = types.GenerateContentConfig(
max_output_tokens=500, # Adjust as needed.
temperature=0.1 # Lower temperature for more deterministic responses.
)
# Send the prompt to the Gemini-2.0-flash model.
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=[prompt],
config=config
)
return response.text
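# Usage sketch (assumes GEMINI_API_KEY is set and the google-genai package is installed):
#   answer = get_gemini_flash_response("Summarize what this repository does.")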
############################################
# Gradio Interface Functions
############################################
# For file content retrieval, the file path from the dropdown is used directly.
def get_file_content_for_choice(github_url: str, file_path: str):
    """Return (content, file_path) on success, or an error string on failure."""
    try:
        owner, repo = extract_repo_info(github_url)
    except Exception as e:
        return str(e)
    content = get_file_content(owner, repo, file_path)
    if content is None:
        return f"Could not retrieve content for '{file_path}' (missing, binary, or API error)."
    return content, file_path
def chat_with_file(github_url: str, file_path: str, user_query: str):
# Retrieve file content using the file path directly.
result = get_file_content_for_choice(github_url, file_path)
    if isinstance(result, str):
        return result  # Return the error message if one occurred.
    file_content, selected_file = result
    # Preprocess the file content and use the first 5000 characters as context.
    preprocessed = preprocess_text(file_content)
    context_snippet = preprocessed[:5000]
# Generate the prompt based on context and user query.
prompt = generate_prompt(user_query, [context_snippet])
# Use Gemini Flash to generate a response.
llm_response = get_gemini_flash_response(prompt)
return f"File: {selected_file}\n\nLLM Response:\n{llm_response}"
def load_repo_contents_backend(github_url: str):
try:
owner, repo = extract_repo_info(github_url)
except Exception as e:
return f"Error: {str(e)}"
repo_data = get_repo_metadata(owner, repo)
default_branch = repo_data.get("default_branch", "main")
tree_data = get_repo_tree(owner, repo, default_branch)
if "tree" not in tree_data:
return "Error: Could not fetch repository tree."
file_list = [item["path"] for item in tree_data["tree"] if item["type"] == "blob"]
return file_list
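# Example (illustrative): for a small repository this returns something like
# ["README.md", "app.py", "requirements.txt"]. Note that the git/trees API may
# report "truncated": true for very large repositories, in which case the file
# listing is incomplete.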
############################################
# Gradio Interface Setup
############################################
with gr.Blocks() as demo:
gr.Markdown("# RepoChat - Chat with Repository Files")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Repository Information")
github_url_input = gr.Textbox(label="GitHub Repository URL", placeholder="https://github.com/username/repository")
load_repo_btn = gr.Button("Load Repository Contents")
# Dropdown with choices as file paths; default value is empty.
file_dropdown = gr.Dropdown(label="Select a File", interactive=True, value="", choices=[])
repo_content_output = gr.Textbox(label="File Content", interactive=False, lines=10)
with gr.Column(scale=2):
gr.Markdown("### Chat Interface")
chat_query_input = gr.Textbox(label="Your Query", placeholder="Type your query here")
chat_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=10)
chat_btn = gr.Button("Send Query")
# Callback: Update file dropdown choices.
def update_file_dropdown(github_url):
files = load_repo_contents_backend(github_url)
if isinstance(files, str): # Error message
print("Error loading files:", files)
return gr.update(choices=[], value="")
print("Files loaded:", files)
# Do not pre-select any file (empty value)
return gr.update(choices=files, value="")
load_repo_btn.click(fn=update_file_dropdown, inputs=[github_url_input], outputs=[file_dropdown])
    # Callback: Update repository content when a file is selected.
    def update_repo_content(github_url, file_choice):
        if not file_choice:
            return "No file selected."
        result = get_file_content_for_choice(github_url, file_choice)
        if isinstance(result, str):  # An error message rather than (content, path).
            return result
        content, _ = result
        return content
file_dropdown.change(fn=update_repo_content, inputs=[github_url_input, file_dropdown], outputs=[repo_content_output])
# Callback: Process chat query.
def process_chat(github_url, file_choice, chat_query):
if not file_choice:
return "Please select a file first."
return chat_with_file(github_url, file_choice, chat_query)
chat_btn.click(fn=process_chat, inputs=[github_url_input, file_dropdown, chat_query_input], outputs=[chat_output])
demo.launch(share=True)