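"""NeuroResearch AI: a Streamlit research assistant.

Indexes a small in-memory corpus of technical documents in Chroma, retrieves
them with MMR, and drives a LangGraph workflow (ingest -> retrieve -> analyze ->
validate/refine) backed by the DeepSeek chat API.
"""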
# ------------------------------
# Imports & Dependencies
# ------------------------------
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict, Annotated
from typing import Sequence, Dict, List, Optional, Any
from langgraph.graph.message import add_messages  # reducer that accumulates messages in AgentState
import chromadb
import numpy as np
import os
import streamlit as st
import requests
import hashlib
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity
# ------------------------------
# State Schema Definition
# ------------------------------
class AgentState(TypedDict):
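"""Shared graph state: accumulated chat messages plus working context and run metadata."""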
messages: Annotated[Sequence[AIMessage | HumanMessage], add_messages]
context: Dict[str, Any]
metadata: Dict[str, Any]
# ------------------------------
# Configuration
# ------------------------------
class ResearchConfig:
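"""Central configuration: API key, Chroma path, chunking/embedding parameters, the in-memory document corpus, and the analysis prompt template."""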
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
CHROMA_PATH = "chroma_db"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
MAX_CONCURRENT_REQUESTS = 5
EMBEDDING_DIMENSIONS = 1536
RESEARCH_EMBEDDING = np.random.randn(1536)
DOCUMENT_MAP = {
"CV-Transformer Hybrid Architecture": {
"title": "Research Report: CV-Transformer Model (98% Accuracy)",
"content": """
Hybrid architecture combining CNNs and Transformers achieves 98% image recognition accuracy.
Key equation: $f(x) = \text{Attention}(\text{CNN}(x))$
Validation on ImageNet-1k: Top-1 Accuracy 98.2%, Inference Speed 42ms/img
"""
},
"Transformer Architecture Analysis": {
"title": "Academic Paper: Transformers in NLP",
"content": """
Self-attention mechanism remains core innovation:
$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$
GLUE Benchmark Score: 92.4%, Training Efficiency: 1.8x vs RNNs
"""
},
"Quantum ML Frontiers": {
"title": "Quantum Machine Learning Review",
"content": """
Quantum gradient descent enables faster optimization:
$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)$
100x speedup on optimization tasks, 58% energy reduction
"""
}
}
ANALYSIS_TEMPLATE = """Analyze these technical documents:
{context}
Respond in MARKDOWN with:
1. **Key Technical Contributions** (bullet points with equations)
2. **Novel Methodologies** (algorithms with math notation)
3. **Empirical Results** (comparative metrics)
4. **Applications** (domain-specific implementations)
5. **Limitations** (theoretical/practical boundaries)
Include LaTeX equations where applicable."""
if not ResearchConfig.DEEPSEEK_API_KEY:
st.error("""**Configuration Required**
1. Get DeepSeek API key: [platform.deepseek.com](https://platform.deepseek.com/)
2. Set secret: `DEEPSEEK_API_KEY`
3. Rebuild deployment""")
st.stop()
# ------------------------------
# Document Processing System
# ------------------------------
class QuantumDocumentManager:
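"""Splits the in-memory corpus into chunks and indexes them in Chroma using OpenAI embeddings (requires OPENAI_API_KEY in the environment)."""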
def __init__(self):
self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
self.embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
)
def create_collection(self, document_map: Dict[str, Dict[str, str]], collection_name: str) -> Chroma:
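"""Chunk every document in document_map and build a Chroma collection; each chunk carries title, source, and a short content hash as metadata."""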
splitter = RecursiveCharacterTextSplitter(
chunk_size=ResearchConfig.CHUNK_SIZE,
chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
separators=["\n\n", "\n", "|||"]
)
docs = []
for key, data in document_map.items():
chunks = splitter.split_text(data["content"])
for chunk in chunks:
docs.append(Document(
page_content=chunk,
metadata={
"title": data["title"],
"source": collection_name,
"hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
}
))
return Chroma.from_documents(
documents=docs,
embedding=self.embeddings,
collection_name=collection_name,
ids=[self._document_id(doc.page_content) for doc in docs]
)
def _document_id(self, content: str) -> str:
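"""Derive a chunk ID from a content hash plus the current timestamp."""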
return f"{hashlib.sha256(content.encode()).hexdigest()[:16]}-{int(time.time())}"
# Initialize document system
qdm = QuantumDocumentManager()
research_docs = qdm.create_collection(ResearchConfig.DOCUMENT_MAP, "research")
# ------------------------------
# Intelligent Retrieval System
# ------------------------------
class ResearchRetriever:
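"""MMR retriever over the research collection: returns 4 diverse chunks drawn from the top 20 candidates (lambda_mult=0.85)."""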
def __init__(self):
self.retriever = research_docs.as_retriever(
search_type="mmr",
search_kwargs={
'k': 4,
'fetch_k': 20,
'lambda_mult': 0.85
}
)
def retrieve(self, query: str) -> List[Document]:
try:
docs = self.retriever.invoke(query)
if not docs:
raise ValueError("No relevant documents found")
return docs
except Exception as e:
st.error(f"Retrieval Error: {str(e)}")
return []
# ------------------------------
# Robust Processing Core
# ------------------------------
class CognitiveProcessor:
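"""Sends the same prompt to the DeepSeek API several times in parallel and keeps the answer with the most inline LaTeX (a rough proxy for technical depth)."""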
def __init__(self):
self.executor = ThreadPoolExecutor(max_workers=ResearchConfig.MAX_CONCURRENT_REQUESTS)
def process_query(self, prompt: str) -> Dict:
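"""Submit the prompt three times concurrently and return the strongest response."""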
futures = [self.executor.submit(self._api_request, prompt) for _ in range(3)]
return self._best_result([f.result() for f in as_completed(futures)])
def _api_request(self, prompt: str) -> Dict:
headers = {
"Authorization": f"Bearer {ResearchConfig.DEEPSEEK_API_KEY}",
"Content-Type": "application/json"
}
try:
response = requests.post(
"https://api.deepseek.com/v1/chat/completions",
headers=headers,
json={
"model": "deepseek-chat",
"messages": [{
"role": "user",
"content": f"Respond as Senior AI Researcher:\n{prompt}"
}],
"temperature": 0.7,
"max_tokens": 1500,
"top_p": 0.9
},
timeout=45
)
response.raise_for_status()
return response.json()
except Exception as e:
return {"error": str(e)}
def _best_result(self, results: List[Dict]) -> Dict:
valid = [r for r in results if "error" not in r]
if not valid:
return {"error": "All API requests failed"}
# Select response with most technical content
contents = [r.get('choices', [{}])[0].get('message', {}).get('content', '') for r in valid]
tech_scores = [len(re.findall(r"\$.*?\$", c)) for c in contents]
return valid[np.argmax(tech_scores)]
# ------------------------------
# Validation Workflow Engine
# ------------------------------
class ResearchWorkflow:
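"""LangGraph state machine: ingest -> retrieve -> analyze, then either validate (END) or refine, which loops back to retrieve."""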
def __init__(self):
self.retriever = ResearchRetriever()
self.processor = CognitiveProcessor()
self.workflow = StateGraph(AgentState)
self._build_workflow()
def _build_workflow(self):
self.workflow.add_node("ingest", self.ingest_query)
self.workflow.add_node("retrieve", self.retrieve_documents)
self.workflow.add_node("analyze", self.analyze_content)
self.workflow.add_node("validate", self.validate_output)
self.workflow.add_node("refine", self.refine_results)
self.workflow.set_entry_point("ingest")
self.workflow.add_edge("ingest", "retrieve")
self.workflow.add_edge("retrieve", "analyze")
self.workflow.add_conditional_edges(
"analyze",
self._quality_check,
{"valid": "validate", "invalid": "refine"}
)
self.workflow.add_edge("validate", END)
self.workflow.add_edge("refine", "retrieve")
self.app = self.workflow.compile()
def ingest_query(self, state: AgentState) -> Dict:
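"""Store the raw user query in context and timestamp the run."""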
try:
query = state["messages"][-1].content
return {
"messages": [AIMessage(content="Query ingested successfully")],
"context": {"raw_query": query},
"metadata": {"timestamp": datetime.now().isoformat()}
}
except Exception as e:
return self._error_state(f"Ingestion Error: {str(e)}")
def retrieve_documents(self, state: AgentState) -> Dict:
try:
docs = self.retriever.retrieve(state["context"]["raw_query"])
if not docs:
return self._error_state("Document correlation failure - no relevant papers found")
return {
"messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
# Merge into the existing context so raw_query survives the refine -> retrieve loop
"context": {**state["context"], "documents": docs}
}
except Exception as e:
return self._error_state(f"Retrieval Error: {str(e)}")
def analyze_content(self, state: AgentState) -> Dict:
try:
docs = state["context"]["documents"]
context = "\n\n".join([f"### {doc.metadata['title']}\n{doc.page_content}" for doc in docs])
prompt = ResearchConfig.ANALYSIS_TEMPLATE.format(context=context)
response = self.processor.process_query(prompt)
if "error" in response:
raise RuntimeError(response["error"])
analysis = response['choices'][0]['message']['content']
self._validate_analysis_structure(analysis)
return {
"messages": [AIMessage(content=analysis)],
# Carry raw_query and documents forward alongside the new analysis
"context": {**state["context"], "analysis": analysis}
}
except Exception as e:
return self._error_state(f"Analysis Error: {str(e)}")
def validate_output(self, state: AgentState) -> Dict:
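"""Ask the model to audit the analysis; the state is marked valid only if the reply contains 'VALID'."""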
validation_prompt = f"""Validate this technical analysis:
{state["messages"][-1].content}
Check for:
1. Mathematical accuracy
2. Empirical evidence
3. Technical depth
4. Logical consistency
Respond with 'VALID' or 'INVALID'"""
response = self.processor.process_query(validation_prompt)
content = response.get('choices', [{}])[0].get('message', {}).get('content', '')
return {
"messages": [AIMessage(content=f"{state['messages'][-1].content}\n\n## Validation\n{content}")],
# Keep documents in context so the UI can still render source excerpts
"context": {**state["context"], "valid": "VALID" in content}
}
def refine_results(self, state: AgentState) -> Dict:
refinement_prompt = f"""Improve this analysis:
{state["messages"][-1].content}
Focus on:
1. Enhancing mathematical rigor
2. Adding empirical references
3. Strengthening technical arguments"""
response = self.processor.process_query(refinement_prompt)
return {
"messages": [AIMessage(content=response['choices'][0]['message']['content'])],
"context": state["context"]
}
def _quality_check(self, state: AgentState) -> str:
return "valid" if state.get("context", {}).get("valid", False) else "invalid"
def _validate_analysis_structure(self, content: str):
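"""Reject an analysis that is missing a required section or contains no inline LaTeX."""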
required_sections = [
"Key Technical Contributions",
"Novel Methodologies",
"Empirical Results",
"Applications",
"Limitations"
]
# The template asks for numbered bold sections, so match on the section names
# themselves rather than requiring literal "## " headings
missing = [s for s in required_sections if s not in content]
if missing:
raise ValueError(f"Missing critical sections: {', '.join(missing)}")
if not re.search(r"\$.*?\$", content):
raise ValueError("Analysis lacks required mathematical notation")
def _error_state(self, message: str) -> Dict:
return {
"messages": [AIMessage(content=f"❌ {message}")],
"context": {"error": True},
"metadata": {"status": "error"}
}
# ------------------------------
# Research Interface
# ------------------------------
class ResearchInterface:
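"""Streamlit front end: page styling, sidebar corpus browser, query box, and result/error rendering."""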
def __init__(self):
self.workflow = ResearchWorkflow()
self._initialize_interface()
def _initialize_interface(self):
st.set_page_config(
page_title="NeuroResearch AI",
layout="wide",
initial_sidebar_state="expanded"
)
self._inject_styles()
self._build_sidebar()
self._build_main_interface()
def _inject_styles(self):
st.markdown("""
<style>
:root {
--primary: #2ecc71;
--secondary: #3498db;
--background: #0a0a0a;
--text: #ecf0f1;
}
.stApp {
background: var(--background);
color: var(--text);
font-family: 'Roboto', sans-serif;
}
.stTextArea textarea {
background: #1a1a1a !important;
color: var(--text) !important;
border: 2px solid var(--secondary);
border-radius: 8px;
padding: 1rem;
}
.stButton>button {
background: linear-gradient(135deg, var(--primary), var(--secondary));
border: none;
border-radius: 8px;
padding: 1rem 2rem;
transition: all 0.3s;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(46, 204, 113, 0.3);
}
.stExpander {
background: #1a1a1a;
border: 1px solid #2a2a2a;
border-radius: 8px;
margin: 1rem 0;
}
code {
color: #2ecc71;
background: #002200;
padding: 2px 4px;
border-radius: 4px;
}
</style>
""", unsafe_allow_html=True)
def _build_sidebar(self):
with st.sidebar:
st.title("🔍 Research Database")
for key, data in ResearchConfig.DOCUMENT_MAP.items():
with st.expander(data["title"]):
st.markdown(f"```\n{data['content']}\n```")
st.metric("Embedding Dimensions", ResearchConfig.EMBEDDING_DIMENSIONS)
st.metric("Document Chunks", len(research_docs.get()['ids']))
def _build_main_interface(self):
st.title("🧠 NeuroResearch AI")
query = st.text_area("Research Query:", height=200,
placeholder="Enter technical research question...")
if st.button("Execute Analysis", type="primary"):
self._execute_analysis(query)
def _execute_analysis(self, query: str):
try:
with st.spinner("Performing deep technical analysis..."):
result = self.workflow.app.invoke(
{"messages": [HumanMessage(content=query)]}
)
if result.get("context", {}).get("error"):
self._show_error(result["context"].get("error", "Unknown error"))
else:
self._display_results(result)
except Exception as e:
self._show_error(str(e))
def _display_results(self, result):
content = result["messages"][-1].content
with st.expander("Technical Analysis Report", expanded=True):
st.markdown(content)
with st.expander("Source Documents", expanded=False):
for doc in result["context"].get("documents", []):
st.markdown(f"**{doc.metadata['title']}**")
st.code(doc.page_content, language='latex')
def _show_error(self, message):
st.error(f"""
⚠️ Analysis Failed: {message}
Troubleshooting Steps:
1. Check query specificity
2. Verify document connections
3. Ensure mathematical notation in sources
4. Review API key validity
5. Simplify complex query structures
""")
if __name__ == "__main__":
ResearchInterface()