cc-api / response_formatter.py
Severian's picture
Update response_formatter.py
3c1c117 verified
raw
history blame
10.8 kB
from typing import Dict, Optional, Tuple, List, Any, Set, Union
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging
from enum import Enum
# Module-level logger for this file; records below INFO are dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a console handler only when the logger has none yet, so running
# this setup again for the same logger does not stack a second handler.
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)
class StreamingFormatter:
    """Mutable bookkeeping for one streamed message.

    Holds the set of already-handled event ids, collected tool outputs,
    citations and metadata, the id of the message being streamed, and a
    text buffer that accumulates message chunks.
    """

    def __init__(self):
        self.processed_events = set()
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}
        self.current_message_id = None
        self.current_message_buffer = ""

    def reset(self):
        """Return every field to its initial, empty state."""
        # Containers are emptied in place (not rebound), so any outside
        # reference to them observes the reset as well.
        for container in (
            self.processed_events,
            self.current_tool_outputs,
            self.current_citations,
            self.current_metadata,
        ):
            container.clear()
        self.current_message_buffer = ""
        self.current_message_id = None

    def append_to_buffer(self, text: str):
        """Add ``text`` to the end of the message buffer."""
        self.current_message_buffer = self.current_message_buffer + text

    def get_and_clear_buffer(self) -> str:
        """Return the buffered text and leave the buffer empty."""
        content, self.current_message_buffer = self.current_message_buffer, ""
        return content
class ToolType(Enum):
    """Enum for supported tool types.

    Each member's value is a tool identifier string (presumably the name
    the agent reports for that tool -- confirm against the caller).
    """

    DUCKDUCKGO = "ddgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_output"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    @classmethod
    def get_tool_type(cls, tool_name: str) -> Optional['ToolType']:
        """Resolve a tool name string to its enum member.

        The lookup is case-insensitive and accepts either the member name
        (e.g. ``"mermaid"`` / ``"MERMAID"``) or the member value
        (e.g. ``"mermaid_output"``).

        Args:
            tool_name: Name or identifier of the tool.

        Returns:
            The matching member, or ``None`` when nothing matches or
            ``tool_name`` is not a string.
        """
        if not isinstance(tool_name, str):
            return None

        # Member-name lookup first: this is the original behaviour and is
        # kept so every input that used to resolve still resolves the same.
        member = cls.__members__.get(tool_name.upper())
        if member is not None:
            return member

        # Fall back to the member value, so identifiers such as
        # "heatmap_code" resolve instead of being reported as unknown.
        try:
            return cls(tool_name.lower())
        except ValueError:
            return None
class ResponseFormatter:
    """Singleton that renders agent events as (terminal, XML) string pairs.

    Every call to ``ResponseFormatter()`` returns the same instance, which
    carries a ``streaming_state`` (a ``StreamingFormatter``) used to skip
    events that were already formatted and to accumulate streamed message
    text, plus a ``logger``.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.streaming_state = StreamingFormatter()
            instance.logger = logger
            # Publish only after setup succeeded, so a failure above cannot
            # leave a half-initialised singleton behind.
            cls._instance = instance
        return cls._instance

    # ------------------------------------------------------------------
    # Shared streaming-state helpers
    # ------------------------------------------------------------------
    def _is_duplicate_event(self, event_id: Optional[str]) -> bool:
        """Return True when ``event_id`` is set and was already formatted."""
        return bool(event_id) and event_id in self.streaming_state.processed_events

    def _sync_message_state(self, message_id: Optional[str]) -> None:
        """Reset the streaming state when a different message id arrives."""
        state = self.streaming_state
        if message_id != state.current_message_id:
            state.reset()
            state.current_message_id = message_id

    def _mark_processed(self, event_id: Optional[str]) -> None:
        """Remember ``event_id`` (if any) so later duplicates are skipped."""
        if event_id:
            self.streaming_state.processed_events.add(event_id)

    @staticmethod
    def _xml_text(value: Any) -> str:
        """Convert ``value`` into a string that ElementTree can serialise.

        ElementTree raises ``TypeError`` for non-string text or attribute
        values, so containers are rendered as JSON and anything else via
        ``str``. ``None`` becomes the empty string.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (dict, list, tuple)):
            try:
                return json.dumps(value)
            except (TypeError, ValueError):
                # Contains something JSON cannot encode; fall through.
                pass
        return str(value)

    @staticmethod
    def _single_element_xml(tag: str, text: str) -> str:
        """Build ``<agent_response><tag>text</tag></agent_response>``."""
        root = ET.Element("agent_response")
        ET.SubElement(root, tag).text = text
        return ET.tostring(root, encoding='unicode')

    # ------------------------------------------------------------------
    # Public formatting API
    # ------------------------------------------------------------------
    def format_thought(
        self,
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an agent thought for both terminal and XML output.

        Args:
            thought: The thought text.
            observation: Observation text; emitted in the XML only.
            citations: Optional list of dicts; each becomes a ``<citation>``
                element whose attributes are the dict's items (XML only).
            metadata: Optional dict included in the terminal JSON.
            tool_outputs: Optional list of dicts with ``type`` and
                ``content`` keys. Entries with the same type and content
                are emitted once.
            event_id: Optional id used to skip already formatted events.
            message_id: Id of the message this event belongs to; a change
                resets the streaming state.

        Returns:
            ``(terminal_json, xml)`` or ``None`` when the event is a
            duplicate or there is no thought, observation or tool output.
        """
        if self._is_duplicate_event(event_id):
            return None

        self._sync_message_state(message_id)

        # Skip empty thoughts
        if not thought and not observation and not tool_outputs:
            return None

        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {}
        }

        unique_outputs = []
        if tool_outputs:
            # Deduplicate on (type, content), keeping first occurrences.
            seen_outputs = set()
            for output in tool_outputs:
                output_key = f"{output.get('type')}:{output.get('content')}"
                if output_key not in seen_outputs:
                    seen_outputs.add(output_key)
                    unique_outputs.append(output)
            terminal_output["tool_outputs"] = unique_outputs

        # XML format
        root = ET.Element("agent_response")
        if thought:
            ET.SubElement(root, "thought").text = thought
        if observation:
            ET.SubElement(root, "observation").text = observation
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in unique_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Tool content may be structured (e.g. a heatmap dict) and
                # the type may be missing; both must be strings for XML.
                tool_elem.set("type", self._xml_text(tool_output.get("type", "")))
                tool_elem.text = self._xml_text(tool_output.get("content", ""))
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.set(key, str(value))
        xml_output = ET.tostring(root, encoding='unicode')

        self._mark_processed(event_id)
        return json.dumps(terminal_output), xml_output

    def format_message(
        self,
        message: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an agent message for both terminal and XML output.

        Chunks are appended to the streaming buffer, and the whole
        accumulated buffer (stripped) is returned, so successive calls for
        the same ``message_id`` return progressively longer text.

        Args:
            message: The next chunk of message text. ``None`` or empty is
                treated as "no new text".
            event_id: Optional id used to skip already formatted events.
            message_id: Id of the message; a change resets the buffer.

        Returns:
            ``(text, xml)`` or ``None`` when the event is a duplicate or
            the buffer holds only whitespace.
        """
        if self._is_duplicate_event(event_id):
            return None

        self._sync_message_state(message_id)

        # Accumulate message content
        if message:
            self.streaming_state.append_to_buffer(message)

        # Only output if we have meaningful content
        terminal_output = self.streaming_state.current_message_buffer.strip()
        if not terminal_output:
            return None

        xml_output = self._single_element_xml("message", terminal_output)

        self._mark_processed(event_id)
        return terminal_output, xml_output

    def format_error(
        self,
        error: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format an error message for both terminal and XML output.

        Args:
            error: The error text.
            event_id: Optional id used to skip already formatted events.
            message_id: Id of the message; a change resets the state.

        Returns:
            ``("Error: <error>", xml)`` or ``None`` when the event is a
            duplicate or ``error`` is empty.
        """
        if self._is_duplicate_event(event_id):
            return None

        self._sync_message_state(message_id)

        # Skip empty errors
        if not error:
            return None

        terminal_output = f"Error: {error}"
        xml_output = self._single_element_xml("error", error)

        self._mark_processed(event_id)
        return terminal_output, xml_output

    def format_tool_output(
        self,
        tool_type: str,
        content: Union[str, Dict],
        metadata: Optional[Dict] = None
    ) -> Dict:
        """Format tool output into a standardized structure.

        Args:
            tool_type: Tool name, resolved via ``ToolType.get_tool_type``.
            content: Raw tool output.
            metadata: Optional metadata passed through unchanged.

        Returns:
            A dict with ``type``, ``content`` and ``metadata`` keys. Unknown
            tools are passed through with a warning; mermaid and heatmap
            content is normalised; an unexpected failure yields an entry
            with ``type`` set to ``"error"``.
        """
        try:
            tool = ToolType.get_tool_type(tool_type)
            if tool is None:
                self.logger.warning("Unknown tool type: %s", tool_type)
                return {
                    "type": tool_type,
                    "content": content,
                    "metadata": metadata or {}
                }

            # Format based on tool type
            if tool == ToolType.MERMAID:
                formatted_type = "mermaid"
                formatted_content = self._clean_mermaid_content(content)
            elif tool == ToolType.HEATMAP:
                formatted_type = "heatmap"
                formatted_content = self._format_heatmap_data(content)
            else:
                # Default formatting for other tools
                formatted_type = tool.value
                formatted_content = content

            return {
                "type": formatted_type,
                "content": formatted_content,
                "metadata": metadata or {}
            }
        except Exception as e:
            # Best-effort boundary: formatting must not break the stream.
            self.logger.error("Error formatting tool output: %s", e)
            return {
                "type": "error",
                "content": str(e),
                "metadata": metadata or {}
            }

    def _clean_mermaid_content(self, content: Union[str, Dict]) -> str:
        """Clean and standardize mermaid diagram content.

        A dict is reduced to its ``"mermaid_diagram"`` entry (empty string
        when absent); Markdown code fences are removed and the result is
        stripped. On failure the error is logged and ``str(content)`` is
        returned.
        """
        try:
            if isinstance(content, dict):
                content = content.get("mermaid_diagram", "")
            # Remove markdown formatting
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            # Clean up whitespace
            return content.strip()
        except Exception as e:
            self.logger.error("Error cleaning mermaid content: %s", e)
            return str(content)

    def _format_heatmap_data(self, content: Union[str, Dict]) -> Dict:
        """Format heatmap data into a standardized structure.

        A string is parsed as JSON first. Returns a dict with ``data``,
        ``options`` and ``metadata`` keys (defaulting to ``[]``, ``{}`` and
        ``{}``), or ``{"error": <message>}`` when parsing or access fails.
        """
        try:
            if isinstance(content, str):
                content = json.loads(content)
            return {
                "data": content.get("data", []),
                "options": content.get("options", {}),
                "metadata": content.get("metadata", {})
            }
        except Exception as e:
            self.logger.error("Error formatting heatmap data: %s", e)
            return {"error": str(e)}

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Clean markdown formatting from text.

        Removes fenced code blocks entirely, deletes the characters
        ``*``, ``_``, backtick and ``#``, strips the result and collapses
        runs of three or more newlines into two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())