from __future__ import annotations

import asyncio
import base64
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from anthropic import AsyncAnthropic
from anthropic.types import ToolUseBlock
from langgraph.graph import END, StateGraph
from pydantic import BaseModel, Field

from src.agents.prompt import REVIEWER_SYSTEM_PROMPT, EVALUATION_PROMPT_TEMPLATE, TOOLS, TOOL_CHOICE
from src.database import db
from src.config import config
from src.logger import logger
class ConversationState(BaseModel):
    """State for the conversation graph."""

    messages: List[Dict[str, Any]] = Field(default_factory=list)
    response_text: str = ""
    tool_result: Optional[Dict[str, Any]] = None
    arxiv_id: Optional[str] = None
    pdf_path: Optional[str] = None
    output_file: Optional[str] = None
def _load_pdf_as_content(pdf_path: str) -> Dict[str, Any]:
    """Build an Anthropic document content block from a PDF path.

    Local files are embedded as base64; http(s) paths are passed through
    as URL sources.
    """
    if os.path.exists(pdf_path):
        with open(pdf_path, "rb") as f:
            data_b64 = base64.b64encode(f.read()).decode("utf-8")
        return {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": data_b64,
            },
        }
    if pdf_path.startswith("http"):
        return {
            "type": "document",
            "source": {
                "type": "url",
                "url": pdf_path,
            },
        }
    raise FileNotFoundError(f"PDF not found or invalid path: {pdf_path}")
class Evaluator:
    """LangGraph node that evaluates a paper via the Anthropic API."""

    def __init__(self, api_key: Optional[str] = None):
        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError(
                "Anthropic API key is required. Please set HF_SECRET_ANTHROPIC_API_KEY "
                "in Hugging Face Spaces secrets or the ANTHROPIC_API_KEY environment variable."
            )
        self.client = AsyncAnthropic(api_key=api_key)
        self.system_prompt = REVIEWER_SYSTEM_PROMPT
        self.eval_template = EVALUATION_PROMPT_TEMPLATE
    async def __call__(self, state: ConversationState) -> ConversationState:
        """Evaluate the paper using the conversation state."""
        # Prepare messages for the API call
        messages = list(state.messages)

        # Attach the PDF content if a path is provided
        if state.pdf_path:
            try:
                pdf_content = _load_pdf_as_content(state.pdf_path)
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Please evaluate this academic paper:"},
                        pdf_content
                    ]
                })
            except Exception as e:
                state.response_text = f"Error loading PDF: {str(e)}"
                return state

        # Add the evaluation prompt
        messages.append({
            "role": "user",
            "content": self.eval_template
        })

        try:
            # Call the Anthropic API with tools (async)
            response = await self.client.messages.create(
                model=config.model_id,
                max_tokens=10000,
                system=self.system_prompt,
                messages=messages,
                tools=TOOLS,
                tool_choice=TOOL_CHOICE
            )
            # Process the response. The forced tool call should return a
            # ToolUseBlock; search the content blocks rather than assuming
            # it is first.
            tool_use = next(
                (block for block in response.content if isinstance(block, ToolUseBlock)),
                None,
            )
            if tool_use is not None:
                tool_result = dict(tool_use.input)
                # Attach run metadata to the structured result
                tool_result['metadata'] = {
                    'assessed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'model': config.model_id,
                    'version': config.version,
                    'paper_path': state.pdf_path
                }
                state.tool_result = tool_result
                state.response_text = json.dumps(tool_result, ensure_ascii=False, indent=4)
                # Record the tool use in the conversation history
                state.messages.append({
                    "role": "assistant",
                    "content": f"Tool use: {tool_use.name}"
                })
            else:
                # Plain text response
                text_content = response.content[0].text if response.content else ""
                state.messages.append({
                    "role": "assistant",
                    "content": text_content
                })
                state.response_text = text_content
        except Exception as e:
            state.response_text = f"Error during evaluation: {str(e)}"
        return state
def _extract_score_fields(scores: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Pull the overall score and summary tags out of a `scores` dict."""
    overall = scores.get('overall_automatability')
    tags = []
    if 'three_year_feasibility_pct' in scores:
        tags.append(f"3yr_feasibility:{scores['three_year_feasibility_pct']}%")
    if 'task_formalization' in scores:
        tags.append(f"task_formalization:{scores['task_formalization']}/4")
    if 'data_resource_availability' in scores:
        tags.append(f"data_availability:{scores['data_resource_availability']}/4")
    return overall, (",".join(tags) if tags else None)


async def save_node(state: ConversationState) -> ConversationState:
    """Save the evaluation result to the database."""
    try:
        if not state.arxiv_id:
            state.response_text += "\n\nError: No arxiv_id provided for database save"
            return state

        evaluation_content = state.response_text
        evaluation_score = None
        overall_score = None
        evaluation_tags = None

        # Prefer the structured tool result; otherwise fall back to parsing
        # the response text as JSON.
        result = state.tool_result
        if result is None:
            try:
                result = json.loads(evaluation_content)
            except Exception as e:
                logger.warning(f"Could not parse evaluation_content as JSON: {e}")

        if result and 'scores' in result:
            try:
                overall_score, evaluation_tags = _extract_score_fields(result['scores'])
                evaluation_score = overall_score
            except Exception as e:
                logger.warning(f"Could not extract structured data from evaluation: {e}")

        # Save to database
        await db.update_paper_evaluation(
            arxiv_id=state.arxiv_id,
            evaluation_content=evaluation_content,
            evaluation_score=evaluation_score,
            overall_score=overall_score,
            evaluation_tags=evaluation_tags
        )
        state.response_text += f"\n\nEvaluation saved to database for paper: {state.arxiv_id}"
    except Exception as e:
        state.response_text += f"\n\nError saving evaluation to database: {str(e)}"
    return state
def build_graph(api_key: Optional[str] = None):
    """Build the evaluation graph."""
    graph = StateGraph(ConversationState)
    evaluator = Evaluator(api_key=api_key)
    graph.add_node("evaluate", evaluator)
    graph.add_node("save", save_node)

    # Define the flow: evaluate -> save -> END
    graph.set_entry_point("evaluate")
    graph.add_edge("evaluate", "save")
    graph.add_edge("save", END)
    return graph.compile()
async def run_evaluation(pdf_path: str, arxiv_id: Optional[str] = None, output_file: Optional[str] = None, api_key: Optional[str] = None) -> str:
    """Run the evaluate -> save graph for one paper and return the response text."""
    app = build_graph(api_key=api_key)
    initial = ConversationState(pdf_path=pdf_path, arxiv_id=arxiv_id, output_file=output_file)
    # LangGraph uses dict-based state, so pass a dump and normalize the result
    final_state = await app.ainvoke(initial.model_dump())
    if isinstance(final_state, dict):
        return str(final_state.get("response_text", ""))
    if isinstance(final_state, ConversationState):
        return final_state.response_text
    return str(getattr(final_state, "response_text", ""))
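

# Minimal usage sketch, assuming ANTHROPIC_API_KEY is set in the environment.
# The arXiv ID and PDF URL below are hypothetical placeholders for illustration.
if __name__ == "__main__":
    result = asyncio.run(
        run_evaluation(
            pdf_path="https://arxiv.org/pdf/2301.00001",
            arxiv_id="2301.00001",
        )
    )
    print(result)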