# Spaces: Sleeping
import streamlit as st | |
import difflib | |
import os | |
import re | |
import hashlib | |
from groq import Groq | |
# --- Page config ---
st.set_page_config(
    page_title="π AI Assistant with Workflow + Semantic Search",
    layout="wide",
)

# --- Groq API Setup ---
# The API key comes from the environment. When it is absent the app shows an
# error and calls st.stop() before any client is constructed.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("β Please set your GROQ_API_KEY environment variable.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# --- Cache for embeddings ---
# Maps sha256("<dim>:<text>") -> embedding vector. Unbounded; entries are
# never evicted.
embedding_cache = {}


def get_embedding(text, dim=256):
    """Return a hashed bag-of-words vector for *text* as ``dim`` floats.

    Each lower-cased run of letters/digits in *text* is hashed with SHA-256
    (stable across processes, unlike the built-in ``hash``) into one of
    ``dim`` buckets, and the bucket counts form the vector. Texts that share
    tokens therefore get a positive dot product regardless of where the
    tokens appear. All vectors for a given ``dim`` have the same length.
    This is a lightweight lexical stand-in, not a learned semantic embedding.

    Results are memoised in ``embedding_cache``. Repeated calls return the
    same list object, so callers must not mutate it.

    Args:
        text: Text to embed. May be empty, which yields an all-zero vector.
        dim: Number of buckets (vector length). Must be >= 1.

    Raises:
        ValueError: if ``dim`` is less than 1.
    """
    if dim < 1:
        raise ValueError("dim must be a positive integer")
    key = hashlib.sha256(f"{dim}:{text}".encode()).hexdigest()
    cached = embedding_cache.get(key)
    if cached is not None:
        return cached

    vector = [0.0] * dim
    # [^\W_]+ = runs of letters/digits. Splitting on "_" as well lets an
    # identifier such as "add_numbers" match a question about "add numbers".
    for token in re.findall(r"[^\W_]+", text.lower()):
        digest = hashlib.sha256(token.encode("utf-8")).digest()
        vector[int.from_bytes(digest[:4], "big") % dim] += 1.0

    embedding_cache[key] = vector
    return vector
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two numeric vectors.

    The dot product covers only the overlapping prefix (``zip`` stops at the
    shorter vector), while each norm is taken over its whole vector. A 1e-8
    term in the denominator keeps zero vectors from dividing by zero; they
    score 0.0.
    """
    numerator = sum(x * y for x, y in zip(vec1, vec2))
    magnitude_1, magnitude_2 = (sum(v * v for v in vec) ** 0.5 for vec in (vec1, vec2))
    return numerator / (magnitude_1 * magnitude_2 + 1e-8)
def split_code_into_chunks(code, lang):
    """Split source code into chunks for semantic search.

    For Python, a new chunk starts at each line whose first non-blank text is
    ``def``, ``async def``, ``class`` or a decorator (``@``).
    - A decorator stays attached to the definition that follows it.
    - Code before the first definition (imports, constants) is kept as its
      own chunk instead of being discarded.
    - Multi-line and annotated signatures are handled, because only the
      first line of a definition is inspected.
    - Whitespace-only chunks are dropped.

    Other languages are returned unsplit.

    Returns:
        A non-empty list of strings; ``[code]`` when nothing was split off.
    """
    if lang.lower() != "python":
        return [code]

    header = re.compile(r"[ \t]*(?:@|(?:async[ \t]+)?def[ \t]|class[ \t])")
    chunks = []
    current = []
    after_decorator = False  # True while the last non-blank line was "@..."
    for line in code.splitlines(keepends=True):
        # Start a new chunk at a definition header, unless it continues a
        # decorator stack (keep "@dec" and its "def" together).
        if header.match(line) and current and not after_decorator:
            chunks.append("".join(current))
            current = []
        current.append(line)
        if line.strip():
            after_decorator = line.lstrip().startswith("@")
    if current:
        chunks.append("".join(current))

    chunks = [chunk for chunk in chunks if chunk.strip()]
    return chunks if chunks else [code]
# Model id used for completions. Can be overridden with the GROQ_MODEL
# environment variable without editing code.
# NOTE(review): Groq has reportedly deprecated "llama3-70b-8192" (suggested
# replacement "llama-3.3-70b-versatile") - verify on Groq's deprecations page
# and update the default if so.
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama3-70b-8192")


def groq_call(prompt, model=None):
    """Send *prompt* as a single user message and return the reply text.

    Args:
        prompt: The full user message.
        model: Optional Groq model id. Defaults to ``GROQ_MODEL``.

    Returns:
        The text of the first choice, or ``""`` when the response carries no
        content, so callers can always treat the result as a string.
    """
    resp = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model or GROQ_MODEL,
    )
    # The SDK types message.content as optional. Callers run string
    # operations ("in", split) on the result, so normalise None to "".
    return resp.choices[0].message.content or ""
def semantic_search_improved(code, question, lang, skill, role, explain_lang):
    """Answer *question* about *code* using only its most relevant chunks.

    How it works:
    - The code is split with split_code_into_chunks().
    - Each chunk is ranked by cosine_similarity() between its embedding and
      the question's embedding. The sort is stable, so ties keep their
      original order.
    - The three best chunks are joined with blank lines and sent to
      groq_call() together with the question.

    Returns:
        The model's answer text.
    """
    pieces = split_code_into_chunks(code, lang)
    query_vec = get_embedding(question)
    ranked = sorted(
        pieces,
        key=lambda piece: cosine_similarity(query_vec, get_embedding(piece)),
        reverse=True,
    )
    context = "\n\n".join(ranked[:3])
    return groq_call(
        f"You are a friendly and insightful {lang} expert helping a {skill} {role}.\n"
        f"Based on these relevant code snippets:\n{context}\n"
        f"Answer this question in {explain_lang}:\n{question}\n"
        f"Explain which parts handle the question and how to modify them if needed."
    )
def error_detection_and_fixes(refactored_code, lang, skill, role, explain_lang):
    """Ask the model to audit *refactored_code* for bugs, security and performance issues.

    Args:
        refactored_code: Source code to analyse.
        lang: Programming language name used in the prompt.
        skill: The user's skill level. It is included in the prompt so the
            advice is pitched at the same audience as the other workflow steps.
        role: The user's role, also included in the prompt.
        explain_lang: Natural language the explanations should be written in.

    Returns:
        The model's reply text from groq_call().
    """
    prompt = (
        f"You are a senior {lang} developer mentoring a {skill} {role}. Analyze this code for bugs, security flaws, "
        f"and performance issues. Suggest fixes with explanations in {explain_lang}:\n\n{refactored_code}"
    )
    return groq_call(prompt)
# Code-fence info strings the model may use for each sidebar language (lower-case).
_FENCE_TAGS = {
    "python": {"python", "python3", "py"},
    "javascript": {"javascript", "js", "jsx"},
    "typescript": {"typescript", "ts", "tsx"},
    "c++": {"cpp", "c++", "cc", "cxx", "hpp"},
    "java": {"java"},
    "c#": {"csharp", "cs", "c#"},
}

# A fenced block: ``` + optional info string on the opening line, then the body.
_FENCED_BLOCK = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)


def _extract_code_block(response, language):
    """Return the code inside the most relevant fenced block of *response*.

    Preference order:
    1. The first non-empty block whose info string names *language*.
    2. The first non-empty block of any kind.
    3. The whole *response*, when it has no usable fenced block.
    """
    language = language.lower()
    wanted = _FENCE_TAGS.get(language, {language})
    bodies = []
    for info, body in _FENCED_BLOCK.findall(response):
        body = body.strip("\n")
        if not body.strip():
            continue  # an empty fence carries no code
        tag = info.strip().lower().split(maxsplit=1)
        if tag and tag[0] in wanted:
            return body
        bodies.append(body)
    return bodies[0] if bodies else response


def agentic_workflow(code, skill_level, programming_language, explanation_language, user_role):
    """Run the explain, refactor, review, error-check and test stages on *code*.

    Each stage makes one model call. The review, error-detection and test
    stages all work on the refactored code extracted from the refactor
    stage's reply.

    Returns:
        ``(timeline, suggestions)``.
        - ``timeline`` is a list with one dict per stage, in order, each with
          keys ``"step"``, ``"description"``, ``"output"`` and ``"code"``.
        - ``suggestions`` is a list with one follow-up hint per stage.
    """
    timeline = []
    suggestions = []

    # Explanation
    explain_prompt = (
        f"You are a friendly and insightful {programming_language} expert helping a {skill_level} {user_role}. "
        f"Explain this code in {explanation_language} with clear examples, analogies, and why each part matters:\n\n{code}"
    )
    explanation = groq_call(explain_prompt)
    timeline.append({"step": "Explain", "description": "Detailed explanation", "output": explanation, "code": code})
    suggestions.append("Consider refactoring your code to improve readability and performance.")

    # Refactor: keep only the code from the reply, so it can be diffed against
    # the input and fed to the later stages.
    refactor_prompt = (
        f"Refactor this {programming_language} code. Explain the changes like a mentor helping a {skill_level} {user_role}. "
        f"Include best practices and improvements:\n\n{code}"
    )
    refactored_code = _extract_code_block(groq_call(refactor_prompt), programming_language)
    timeline.append({"step": "Refactor", "description": "Refactored code with improvements", "output": refactored_code, "code": refactored_code})
    suggestions.append("Review the refactored code and adapt it to your style or project needs.")

    # Review
    review_prompt = (
        f"As a senior {programming_language} developer, review the refactored code. "
        f"Give constructive feedback on strengths, weaknesses, performance, security, and improvements in {explanation_language}:\n\n{refactored_code}"
    )
    review = groq_call(review_prompt)
    timeline.append({"step": "Review", "description": "Code review and suggestions", "output": review, "code": refactored_code})
    suggestions.append("Incorporate review feedback for cleaner, robust code.")

    # Error detection & fixes
    errors = error_detection_and_fixes(refactored_code, programming_language, skill_level, user_role, explanation_language)
    timeline.append({"step": "Error Detection", "description": "Bugs, security, performance suggestions", "output": errors, "code": refactored_code})
    suggestions.append("Apply fixes to improve code safety and performance.")

    # Test generation
    test_prompt = (
        f"Write clear, effective unit tests for this {programming_language} code. "
        f"Explain what each test does in {explanation_language}, for a {skill_level} {user_role}:\n\n{refactored_code}"
    )
    tests = groq_call(test_prompt)
    timeline.append({"step": "Test Generation", "description": "Generated unit tests", "output": tests, "code": tests})
    suggestions.append("Run generated tests locally to validate changes.")

    return timeline, suggestions
# Stylesheet for the CSS classes that difflib.HtmlDiff.make_table() emits.
# make_table() returns only the <table>; the <style> block belongs to the page
# that make_file() builds. Without this, added/changed/deleted text is not
# highlighted. Colours mirror make_file()'s stylesheet. The explicit dark text
# keeps highlighted cells readable on a dark theme. The CSS is kept free of
# blank lines so Markdown treats it as a single HTML block.
_DIFF_TABLE_CSS = (
    "<style>"
    "table.diff {font-family: Courier, monospace; border: medium;}"
    ".diff_header {background-color: #e0e0e0; color: #000;}"
    "td.diff_header {text-align: right;}"
    ".diff_next {background-color: #c0c0c0; color: #000;}"
    ".diff_add {background-color: #aaffaa; color: #000;}"
    ".diff_chg {background-color: #ffff77; color: #000;}"
    ".diff_sub {background-color: #ffaaaa; color: #000;}"
    "</style>"
)


def get_inline_diff_html(original, modified):
    """Return a side-by-side HTML diff of two code strings, ready for st.markdown.

    Diff settings:
    - Only changed regions are shown, with 2 lines of context.
    - Tabs expand to 4 columns and lines wrap at 80 characters.
    - Columns are labelled "Original" and "Refactored".
    - The table sits in a horizontally scrollable div at most 400px tall,
      preceded by the highlight stylesheet.
    """
    differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=80)
    table = differ.make_table(
        original.splitlines(), modified.splitlines(),
        "Original", "Refactored", context=True, numlines=2,
    )
    return f'{_DIFF_TABLE_CSS}\n<div style="overflow-x:auto; max-height:400px;">{table}</div>'
def detect_code_type(code, programming_language):
    """Classify *code* as 'data_science', 'frontend', 'backend' or 'unknown'.

    Keyword checks run in priority order:
    1. Data-science keywords.
    2. Frontend keywords.
    3. Backend keywords. These apply to every language, so Node.js code
       using ``express`` or ``mongoose`` counts as backend.

    If nothing matches, Python/Java/C# default to 'backend',
    JavaScript/TypeScript to 'frontend', and anything else is 'unknown'.

    Keyword matching rules:
    - Plain-word keywords must match as whole words (an optional trailing
      "s" is allowed). So "angular" does not fire on "triangular", "react"
      on "reaction", or "api" on "capital".
    - Keywords containing punctuation, such as "window." or "<div", are
      matched as substrings.
    """
    backend_keywords = [
        'flask', 'django', 'express', 'fastapi', 'spring', 'controller', 'api', 'server', 'database', 'sql', 'mongoose'
    ]
    frontend_keywords = [
        'react', 'vue', 'angular', 'component', 'html', 'css', 'document.getelementbyid', 'window.', 'render', 'jsx',
        '<html', '<body', '<script', '<div', 'getelementbyid', 'queryselector', 'addeventlistener', 'innerhtml'
    ]
    data_science_keywords = [
        'pandas', 'numpy', 'sklearn', 'matplotlib', 'seaborn', 'plt', 'train_test_split', 'randomforestclassifier', 'classification_report'
    ]
    code_lower = code.lower()
    language = programming_language.lower()

    def mentions_any(keywords):
        """Return True if any keyword occurs in code_lower (rules above)."""
        for word in keywords:
            if re.fullmatch(r"\w+", word):
                if re.search(rf"\b{re.escape(word)}s?\b", code_lower):
                    return True
            elif word in code_lower:
                return True
        return False

    if mentions_any(data_science_keywords):
        return 'data_science'
    if mentions_any(frontend_keywords):
        return 'frontend'
    if mentions_any(backend_keywords):
        return 'backend'
    if language in ('python', 'java', 'c#'):
        return 'backend'
    if language in ('javascript', 'typescript'):
        return 'frontend'
    return 'unknown'
def code_complexity(code):
    """Return a one-line summary of rough size metrics for *code*.

    The counts are lexical heuristics, not a parse:
    - Lines: physical lines (``str.splitlines``). A trailing newline adds no
      phantom line, and empty input reports 0.
    - Functions / Classes: the whole word ``def`` / ``class`` followed by
      whitespace. So ``#undef X`` and "subclass of" are not counted.
    - Comments: lines whose first non-blank text is ``#`` or ``//``.
      C/C++/C# preprocessor directives such as ``#include`` and ``#define``
      are excluded.
    """
    lines = code.splitlines()
    functions = len(re.findall(r"\bdef\s", code))
    classes = len(re.findall(r"\bclass\s", code))
    directive = re.compile(
        r"#(?:include|define|undef|ifdef|ifndef|if|elif|else|endif|pragma|region|endregion)\b"
    )
    comments = 0
    for line in lines:
        text = line.lstrip()
        if text.startswith("//") or (text.startswith("#") and not directive.match(text)):
            comments += 1
    return f"Lines: {len(lines)}, Functions: {functions}, Classes: {classes}, Comments: {comments}"
def code_matches_language(code: str, language: str, min_matches: int = 1) -> bool:
    """Heuristically check whether *code* looks like *language*.

    Counts how many characteristic substrings of *language* occur in *code*.
    The comparison is case-insensitive: both the code and the patterns are
    lower-cased. Without that, mixed-case patterns such as
    "System.out.println" or "Console.WriteLine" could never match.

    This is a loose sanity check, not a parser. Many patterns (e.g.
    "import ", "class ") are shared between languages.

    Args:
        code: Source text to test.
        language: Language name, case-insensitive. Names without a pattern
            list never match (unless ``min_matches`` is 0).
        min_matches: Number of distinct patterns that must be present.
            Defaults to 1.

    Returns:
        True if at least ``min_matches`` patterns are found.
    """
    code_lower = code.strip().lower()
    patterns = {
        "python": [
            "def ", "class ", "import ", "from ", "try:", "except", "raise", "lambda",
            "with ", "yield", "async ", "await", "print(", "self.", "__init__", "__name__",
            "if __name__ == '__main__':", "#!",  # shebang for executable scripts
        ],
        "c++": [
            "#include", "int main(", "std::", "::", "cout <<", "cin >>", "new ", "delete ",
            "try {", "catch(", "template<", "using namespace", "class ", "struct ", "#define",
        ],
        "java": [
            "package ", "import java.", "public class", "private ", "protected ", "public static void main",
            "System.out.println", "try {", "catch(", "throw new ", "implements ", "extends ",
            "@Override", "interface ", "enum ", "synchronized ", "final ",
        ],
        "c#": [
            "using System", "namespace ", "class ", "interface ", "public static void Main",
            "Console.WriteLine", "try {", "catch(", "throw ", "async ", "await ", "get;", "set;",
            "List<", "Dictionary<", "[Serializable]", "[Obsolete]",
        ],
        "javascript": [
            "function ", "const ", "let ", "var ", "document.", "window.", "console.log",
            "if(", "for(", "while(", "switch(", "try {", "catch(", "export ", "import ", "async ",
            "await ", "=>", "this.", "class ", "prototype", "new ", "$(",
        ],
        "typescript": [
            "function ", "const ", "let ", "interface ", "type ", ": string", ": number", ": boolean",
            "implements ", "extends ", "enum ", "public ", "private ", "protected ", "readonly ",
            "import ", "export ", "console.log", "async ", "await ", "=>", "this.",
        ],
        "html": [
            "<!doctype html", "<html", "<head>", "<body>", "<script", "<style", "<meta ", "<link ",
            "<title>", "<div", "<span", "<p>", "<h1>", "<ul>", "<li>", "<form", "<input", "<button",
            "<table", "<footer", "<header", "<section", "<article", "<nav", "<img", "<a ", "</html>",
        ],
    }
    match_patterns = patterns.get(language.lower(), [])
    # Lower-case each pattern too, since code_lower is lower-cased.
    match_count = sum(1 for pattern in match_patterns if pattern.lower() in code_lower)
    return match_count >= min_matches
# --- Sidebar ---
# Every widget below lives in the sidebar. Its selections (lang, skill, role,
# explain_lang) drive both tabs.
with st.sidebar:
    st.title("π§ Configuration")
    lang = st.selectbox("Programming Language", ["Python", "JavaScript", "C++", "Java", "C#", "TypeScript"])
    skill = st.selectbox("Skill Level", ["Beginner", "Intermediate", "Expert"])
    role = st.selectbox("Your Role", ["Student", "Frontend Developer", "Backend Developer", "Data Scientist"])
    explain_lang = st.selectbox("Explanation Language", ["English", "Spanish", "Chinese", "Urdu"])
    st.markdown("---")
    st.markdown("<span style='color:#fff;'>Powered by <b>BLACKBOX.AI</b></span>", unsafe_allow_html=True)
tabs = st.tabs(["π§ Full AI Workflow", "π Semantic Search"])
# --- Tab 1: Full AI Workflow ---
with tabs[0]:
    st.title("π§ Full AI Workflow")
    file_types = {
        "Python": ["py"],
        "JavaScript": ["js"],
        "C++": ["cpp", "h", "hpp"],
        "Java": ["java"],
        "C#": ["cs"],
        "TypeScript": ["ts"],
    }
    # st.code() needs a highlighter language id; "c++" and "c#" are not ids.
    highlight_lang = {"C++": "cpp", "C#": "csharp"}.get(lang, lang.lower())
    uploaded_file = st.file_uploader(
        f"Upload {', '.join(file_types.get(lang, []))} file(s)",
        type=file_types.get(lang, None),
    )
    if uploaded_file:
        # getvalue() returns the whole upload regardless of the stream
        # position. "utf-8-sig" also drops a leading BOM, and errors="replace"
        # keeps a non-UTF-8 file from aborting the run with UnicodeDecodeError.
        code_input = uploaded_file.getvalue().decode("utf-8-sig", errors="replace")
    else:
        code_input = st.text_area("Your Code", height=300, placeholder="Paste your code here...")
    if code_input:
        st.markdown(f"<b>Complexity:</b> {code_complexity(code_input)}", unsafe_allow_html=True)
    if st.button("Run AI Workflow"):
        if not code_input.strip():
            st.warning("Please paste or upload your code.")
        elif not code_matches_language(code_input, lang):
            st.error(f"The pasted code doesn't look like valid {lang} code. Please check your code or select the correct language.")
        else:
            code_type = detect_code_type(code_input, lang)
            # "Student" is a generalist role, so the role check is skipped for
            # it. Otherwise a student could never run Python, Java, C#,
            # JavaScript or TypeScript code: detect_code_type() falls back to
            # "backend" or "frontend" for those languages.
            enforce_role = role != "Student"
            if enforce_role and code_type == "data_science" and role != "Data Scientist":
                st.error("Data science code detected. Please select 'Data Scientist' role.")
            elif enforce_role and code_type == "frontend" and role != "Frontend Developer":
                st.error("Frontend code detected. Please select 'Frontend Developer' role.")
            elif enforce_role and code_type == "backend" and role != "Backend Developer":
                st.error("Backend code detected. Please select 'Backend Developer' role.")
            else:
                with st.spinner("Running agentic workflow..."):
                    timeline, suggestions = agentic_workflow(code_input, skill, lang, explain_lang, role)
                # Show each step in an expander
                for step in timeline:
                    with st.expander(f"β {step['step']} - {step['description']}"):
                        if step['step'] == "Refactor":
                            diff_html = get_inline_diff_html(code_input, step['code'])
                            st.markdown(diff_html, unsafe_allow_html=True)
                            st.code(step['output'], language=highlight_lang)
                        else:
                            st.markdown(step['output'])
                st.markdown("#### Agent Suggestions")
                for s in suggestions:
                    st.markdown(f"- {s}")
                # Download buttons after suggestions
                st.markdown("---")
                st.markdown("### π₯ Download Results")
                report_text = "".join(
                    f"## {step['step']}\n{step['description']}\n\n{step['output']}\n\n"
                    for step in timeline
                )
                st.download_button(
                    label="π Download Full Workflow Report",
                    data=report_text,
                    file_name="ai_workflow_report.txt",
                    mime="text/plain",
                )
# --- Tab 2: Semantic Search ---
# Asks a question about pasted code. Only the chunks most similar to the
# question are sent to the model (see semantic_search_improved).
with tabs[1]:
    st.title("π Semantic Search")
    sem_code = st.text_area("Your Code", height=300, placeholder="Paste your code...")
    sem_q = st.text_input("Your Question", placeholder="E.g., What does this function do?")
    run_clicked = st.button("Run Semantic Search")
    if run_clicked and not (sem_code.strip() and sem_q.strip()):
        st.warning("Code and question required.")
    elif run_clicked:
        with st.spinner("Running semantic search..."):
            answer = semantic_search_improved(sem_code, sem_q, lang, skill, role, explain_lang)
        for fragment in ("### π Answer", answer, "---"):
            st.markdown(fragment)