# GAIA_Agent/agents/verifier_agent.py
import os
import logging
import re
from typing import List

from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI

# Setup logging
logger = logging.getLogger(__name__)


class VerificationError(Exception):
"""Custom exception for verification failures."""
    pass


class Verifier:
"""
Cross-check extracted facts, identify contradictions using LLM,
and assign a confidence score to each fact.
"""
def __init__(self):
"""Initializes the Verifier, loading configuration from environment variables."""
logger.info("Initializing Verifier...")
        self.threshold = float(os.getenv("VERIFIER_CONFIDENCE_THRESHOLD", "0.7"))
        self.verifier_llm_model = os.getenv("VERIFIER_LLM_MODEL", "models/gemini-2.0-flash")  # Fast model for per-fact scoring
        self.agent_llm_model = os.getenv("VERIFIER_AGENT_LLM_MODEL", "models/gemini-1.5-pro")  # Stronger model for agent logic and contradiction detection
self.gemini_api_key = os.getenv("GEMINI_API_KEY")
if not self.gemini_api_key:
logger.error("GEMINI_API_KEY not found in environment variables.")
raise ValueError("GEMINI_API_KEY must be set")
try:
self.verifier_llm = GoogleGenAI(
api_key=self.gemini_api_key,
model=self.verifier_llm_model,
)
self.agent_llm = GoogleGenAI(
api_key=self.gemini_api_key,
model=self.agent_llm_model,
)
logger.info(f"Verifier initialized with threshold {self.threshold}, verifier LLM {self.verifier_llm_model}, agent LLM {self.agent_llm_model}")
except Exception as e:
logger.error(f"Error initializing Verifier LLMs: {e}", exc_info=True)
            raise

    def verify_facts(self, facts: List[str]) -> List[str]:
"""
        Assign a confidence score via LLM to each fact and return formatted strings.

        Args:
            facts (List[str]): Facts to verify.

        Returns:
            List[str]: Each item is "fact: score" with score ∈ [threshold, 1.0].
                Facts whose LLM call fails are reported as
                "fact: ERROR - Verification failed" rather than raising.
"""
logger.info(f"Verifying {len(facts)} facts...")
results: List[str] = []
for fact in facts:
prompt = (
"You are a fact verifier. "
"On a scale from 0.00 to 1.00, where any value below "
f"{self.threshold:.2f} indicates low confidence, rate the following statement’s trustworthiness. "
"Respond with **only** a decimal number rounded to two digits (e.g., 0.82) and no extra text.\n\n"
f"Statement: \"{fact}\""
)
try:
response = self.verifier_llm.complete(prompt)
score_text = response.text.strip()
# Try direct conversion first
try:
score = float(score_text)
except ValueError:
# Fallback: extract first float if model returns extra text
match = re.search(r"0?\.\d+|1(?:\.0+)?", score_text)
if match:
score = float(match.group(0))
logger.warning(f"Extracted score {score} from noisy LLM response: {score_text}")
else:
logger.error(f"Could not parse score from LLM response: {score_text}. Using threshold {self.threshold}.")
score = self.threshold # Fallback to threshold if parsing fails completely
                # Clamp to the range promised by the docstring: cap at 1.0, floor at the threshold
                score = min(score, 1.0)
                if score < self.threshold:
                    logger.info(f"Score {score:.2f} for fact '{fact}' below threshold {self.threshold}, raising to threshold.")
                    score = self.threshold
                results.append(f"{fact}: {score:.2f}")
except Exception as e:
logger.error(f"LLM call failed during fact verification for {fact}: {e}", exc_info=True)
# Option 1: Raise an error
# raise VerificationError(f"LLM call failed for fact: {fact}") from e
# Option 2: Append an error message (current approach)
results.append(f"{fact}: ERROR - Verification failed")
# Option 3: Assign lowest score
# results.append(f"{fact}: {self.threshold:.2f} (Verification Error)")
logger.info(f"Fact verification complete. {len(results)} results generated.")
return results
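
    # Illustrative contract (the score here is hypothetical; actual values
    # depend on the LLM):
    #   verify_facts(["Water boils at 100 degrees Celsius."])
    #   -> ["Water boils at 100 degrees Celsius.: 0.95"]
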
def find_contradictions_llm(self, facts: List[str]) -> List[str]:
"""
        Identify contradictions among a list of facts using an LLM.

        Args:
            facts (List[str]): List of fact strings.

        Returns:
            List[str]: Pairs of facts detected as contradictory, joined by " <> ".

        Raises:
            VerificationError: If the LLM call fails.
"""
logger.info(f"Finding contradictions in {len(facts)} facts using LLM...")
if len(facts) < 2:
logger.info("Not enough facts to find contradictions.")
return []
facts_numbered = "\n".join([f"{i+1}. {fact}" for i, fact in enumerate(facts)])
prompt = (
"You are a logical reasoning assistant. Analyze the following numbered list of statements. "
"Identify any pairs of statements that directly contradict each other. "
"List *only* the numbers of the contradicting pairs, one pair per line, formatted as 'X, Y'. "
"If no contradictions are found, respond with 'None'. Do not include any other text or explanation.\n\n"
f"Statements:\n{facts_numbered}"
)
try:
response = self.agent_llm.complete(prompt) # Use the more powerful agent LLM
response_text = response.text.strip()
logger.info(f"LLM response for contradictions: {response_text}")
            if response_text.lower().rstrip('.') == 'none':
logger.info("LLM reported no contradictions.")
return []
contradiction_pairs = []
lines = response_text.split("\n")
for line in lines:
line = line.strip()
if not line:
continue
try:
# Expect format like "1, 5"
parts = line.split(',')
if len(parts) == 2:
idx1 = int(parts[0].strip()) - 1
idx2 = int(parts[1].strip()) - 1
# Validate indices
if 0 <= idx1 < len(facts) and 0 <= idx2 < len(facts) and idx1 != idx2:
# Ensure pair order doesn't matter and avoid duplicates
pair = tuple(sorted((idx1, idx2)))
fact1 = facts[pair[0]]
fact2 = facts[pair[1]]
contradiction_str = f"{fact1} <> {fact2}"
if contradiction_str not in contradiction_pairs:
contradiction_pairs.append(contradiction_str)
logger.info(f"Identified contradiction: {contradiction_str}")
else:
logger.warning(f"Invalid index pair found in LLM contradiction response: {line}")
else:
logger.warning(f"Could not parse contradiction pair from LLM response line: {line}")
except ValueError:
logger.warning(f"Non-integer index found in LLM contradiction response line: {line}")
except Exception as parse_err:
logger.warning(f"Error parsing LLM contradiction response line {line}: {parse_err}")
logger.info(f"Contradiction check complete. Found {len(contradiction_pairs)} pairs.")
return contradiction_pairs
except Exception as e:
logger.error(f"LLM call failed during contradiction detection: {e}", exc_info=True)
# Option 1: Raise an error
raise VerificationError("LLM call failed during contradiction detection") from e
# Option 2: Return empty list (fail silently)
# return []
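
    # Parsing contract (hypothetical LLM reply for three facts where 1 and 3 conflict):
    #   LLM returns "1, 3"  -> ["<fact 1> <> <fact 3>"]
    #   LLM returns "None"  -> []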


# --- Tool Definitions ---
# Tools are created inside VerifierInitializer so they bind to a live Verifier instance.

# --- Agent Initialization ---
# Store the initializer instance globally to ensure singleton behavior
_verifier_initializer_instance = None


class VerifierInitializer:
def __init__(self):
self.verifier = Verifier() # Initialize the Verifier class
        self._create_tools()

    def _create_tools(self):
self.verify_facts_tool = FunctionTool.from_defaults(
fn=self.verifier.verify_facts, # Bind to instance method
name="verify_facts",
description=(
"Assigns a numerical confidence score (based on plausibility and internal consistency) to each factual assertion in a list. "
"Input: List[str] of statements. Output: List[str] of 'statement: score' pairs."
),
)
self.find_contradictions_tool = FunctionTool.from_defaults(
fn=self.verifier.find_contradictions_llm, # Bind to instance method (using LLM version)
name="find_contradictions",
description=(
"Uses an LLM to detect logical contradictions among a list of statements. "
"Input: List[str] of factual assertions. "
"Output: List[str] where each entry is a conflicting pair in the format 'statement1 <> statement2'. Returns empty list if none found."
)
)
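
    # The tools can also be exercised outside the agent loop for testing, e.g.
    # (assuming llama_index's FunctionTool exposes a synchronous `call`, as in
    # recent versions):
    #   get_verifier_initializer().verify_facts_tool.call(facts=["The sky is blue."])
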
def get_agent(self) -> FunctionAgent:
"""Initializes and returns the Verifier Agent."""
logger.info("Creating VerifierAgent FunctionAgent instance...")
# System prompt (consider loading from file)
system_prompt = """\
You are VerifierAgent, a fact verification assistant. Given a list of factual statements, you must:
1. **Verify Facts**: Call `verify_facts` to assign a confidence score to each statement.
2. **Detect Contradictions**: Call `find_contradictions` to identify logical conflicts between the statements using an LLM.
3. **Present Results**: Output clear bullet points listing each fact with its confidence score, followed by a list of any detected contradictions.
4. **Hand-Off**: If significant contradictions or low-confidence facts require deeper analysis, hand off to **reasoning_agent**; for independent re-validation, hand off to **advanced_validation_agent**. Otherwise, pass the verified facts and contradiction summary to **planner_agent** for integration.
"""
agent = FunctionAgent(
name="verifier_agent",
description=(
"Evaluates factual statements by assigning confidence scores (`verify_facts`) "
"and detecting logical contradictions using an LLM (`find_contradictions`). "
"Hands off to reasoning_agent for complex issues or planner_agent for synthesis."
),
tools=[
self.verify_facts_tool,
self.find_contradictions_tool,
],
llm=self.verifier.agent_llm, # Use the agent LLM from the Verifier instance
system_prompt=system_prompt,
can_handoff_to=["reasoning_agent", "planner_agent", "advanced_validation_agent"],
)
logger.info("VerifierAgent FunctionAgent instance created.")
return agent
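

# Minimal usage sketch (assumes the llama_index workflow-agent API, where a
# FunctionAgent can be awaited directly from an async context):
#
#   agent = initialize_verifier_agent()
#   response = await agent.run(user_msg="Verify these facts: ...")
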
def get_verifier_initializer():
"""Gets the singleton instance of VerifierInitializer."""
global _verifier_initializer_instance
if _verifier_initializer_instance is None:
logger.info("Instantiating VerifierInitializer for the first time.")
_verifier_initializer_instance = VerifierInitializer()
    return _verifier_initializer_instance


def initialize_verifier_agent() -> FunctionAgent:
"""Initializes and returns the Verifier Agent using a singleton initializer."""
logger.info("initialize_verifier_agent called.")
initializer = get_verifier_initializer()
    return initializer.get_agent()


# Example usage (for testing if run directly)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger.info("Running verifier_agent.py directly for testing...")
# Ensure API key is set for testing
if not os.getenv("GEMINI_API_KEY"):
print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
else:
try:
test_agent = initialize_verifier_agent()
print("Verifier Agent initialized successfully for testing.")
# Test contradiction detection
initializer = get_verifier_initializer()
test_facts = [
"The sky is blue.",
"Water boils at 100 degrees Celsius.",
"The sky is not blue.",
"Paris is the capital of France."
]
print(f"\nTesting contradiction detection on: {test_facts}")
contradictions = initializer.verifier.find_contradictions_llm(test_facts)
print(f"Detected contradictions: {contradictions}")
# Test fact verification
print(f"\nTesting fact verification on: {test_facts}")
verified = initializer.verifier.verify_facts(test_facts)
print(f"Verified facts: {verified}")
except Exception as e:
print(f"Error during testing: {e}")