import inspect
import logging
import os
import uuid
from typing import Any, Callable

import dsp
import dspy

#################### Assertion Helpers ####################

# File that assertion/suggestion failures are logged to (relative to the CWD).
_ASSERTION_LOG_FILE = "assertion.log"


def setup_logger():
    """Return this module's logger, configured to write to ``assertion.log``.

    The logger level is set to DEBUG and a DEBUG-level ``FileHandler`` is
    attached. The function is idempotent: if a file handler for the same path
    is already attached (e.g. the module was reloaded), no second handler is
    added, so records are not written twice.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores the absolute path of its target in `baseFilename`.
    target_path = os.path.abspath(_ASSERTION_LOG_FILE)
    for handler in logger.handlers:
        if (
            isinstance(handler, logging.FileHandler)
            and getattr(handler, "baseFilename", None) == target_path
        ):
            return logger

    file_handler = logging.FileHandler(_ASSERTION_LOG_FILE)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
    return logger


logger = setup_logger()


def _build_error_msg(feedback_msgs):
    """Build an error message from a list of feedback messages.

    The messages are joined with newlines, in the order given.
    """
    return "\n".join(feedback_msgs)


#################### Assertion Exceptions ####################


class DSPyAssertionError(AssertionError):
    """Custom exception raised when a DSPy `Assert` fails.

    Attributes:
        id: Identifier of the constraint that failed.
        msg: Feedback message attached to the constraint.
        target_module: Value the handler compares against predictor
            signatures in the trace to choose where to backtrack (may be None).
        state: Trace captured at the time the constraint failed.
        is_metric: Whether the failure is counted in the failure counters.
    """

    def __init__(
        self,
        id: str,
        msg: str,
        target_module: Any = None,
        state: Any = None,
        is_metric: bool = False,
    ) -> None:
        super().__init__(msg)
        self.id = id
        self.msg = msg
        self.target_module = target_module
        self.state = state
        self.is_metric = is_metric


class DSPySuggestionError(AssertionError):
    """Custom exception raised when a DSPy `Suggest` fails.

    Carries the same attributes as `DSPyAssertionError` (`id`, `msg`,
    `target_module`, `state`, `is_metric`).
    """

    def __init__(
        self,
        id: str,
        msg: str,
        target_module: Any = None,
        state: Any = None,
        is_metric: bool = False,
    ) -> None:
        super().__init__(msg)
        self.id = id
        self.msg = msg
        self.target_module = target_module
        self.state = state
        self.is_metric = is_metric


#################### Assertion Primitives ####################


class Constraint:
    """Base class for `Assert` and `Suggest`.

    Constructing a constraint evaluates it immediately: `__init__` stores the
    arguments, assigns a fresh UUID string as `id`, and then calls
    `self.__call__()`, which subclasses implement.
    """

    def __init__(
        self, result: bool, msg: str = "", target_module=None, is_metric: bool = False
    ):
        self.id = str(uuid.uuid4())
        self.result = result
        self.msg = msg
        self.target_module = target_module
        self.is_metric = is_metric

        # Evaluate right away so `Assert(cond, msg)` acts like a statement.
        self.__call__()


class Assert(Constraint):
    """DSPy Assertion"""

    def __call__(self) -> bool:
        """Check the stored result.

        Returns:
            True if the result is True, or if it is False while
            `dspy.settings.bypass_assert` is set (the failure is only logged).

        Raises:
            DSPyAssertionError: the result is False and assertions are not
                bypassed.
            ValueError: the stored result is not a `bool`.
        """
        if not isinstance(self.result, bool):
            raise ValueError("Assertion function should always return [bool]")

        if self.result:
            return True

        if dspy.settings.bypass_assert:
            logger.error("AssertionError: %s", self.msg)
            return True

        logger.error("AssertionError: %s", self.msg)
        raise DSPyAssertionError(
            id=self.id,
            msg=self.msg,
            target_module=self.target_module,
            state=dsp.settings.trace,
            is_metric=self.is_metric,
        )


class Suggest(Constraint):
    """DSPy Suggestion"""

    def __call__(self) -> Any:
        """Check the stored result.

        Returns:
            True if the result is True, or if it is False while
            `dspy.settings.bypass_suggest` is set (the failure is only logged).

        Raises:
            DSPySuggestionError: the result is False and suggestions are not
                bypassed.
            ValueError: the stored result is not a `bool`.
        """
        if not isinstance(self.result, bool):
            raise ValueError("Suggestion function should always return [bool]")

        if self.result:
            return True

        if dspy.settings.bypass_suggest:
            logger.error("SuggestionFailed: %s", self.msg)
            return True

        logger.error("SuggestionFailed: %s", self.msg)
        raise DSPySuggestionError(
            id=self.id,
            msg=self.msg,
            target_module=self.target_module,
            state=dsp.settings.trace,
            is_metric=self.is_metric,
        )


#################### Assertion Handlers ####################


def noop_handler(func):
    """Handler to bypass assertions and suggestions.

    Now both assertions and suggestions will become noops.
    """

    def wrapper(*args, **kwargs):
        with dspy.settings.context(bypass_assert=True, bypass_suggest=True):
            return func(*args, **kwargs)

    return wrapper


def bypass_suggest_handler(func):
    """Handler to bypass suggest only.

    If a suggestion fails, it will be logged but not raised.
    And if an assertion fails, it will be raised.
    """

    def wrapper(*args, **kwargs):
        with dspy.settings.context(bypass_suggest=True, bypass_assert=False):
            return func(*args, **kwargs)

    return wrapper


def bypass_assert_handler(func):
    """Handler to bypass assertion only.

    If an assertion fails, it will be logged but not raised.
    The `bypass_suggest` setting is not touched by this handler, so
    suggestions behave according to whatever setting is already active.
    """

    def wrapper(*args, **kwargs):
        with dspy.settings.context(bypass_assert=True):
            return func(*args, **kwargs)

    return wrapper


def assert_no_except_handler(func):
    """Handler to ignore assertion failure and return None.

    Only `DSPyAssertionError` is swallowed; any other exception (including
    `DSPySuggestionError`) propagates.
    """

    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except DSPyAssertionError:
            return None

    return wrapper


def backtrack_handler(func, bypass_suggest=True, max_backtracks=2):
    """Handler for backtracking suggestion and assertion.

    Re-run the latest predictor up to `max_backtracks` times, with updated
    signature if an assertion fails. The updated signature adds a new input
    field to the signature, which is the feedback.

    Args:
        func: Callable to run under backtracking.
        bypass_suggest: On the final attempt, run `func` with suggestions
            bypassed. If False, the final attempt is not run and the wrapper
            returns None.
        max_backtracks: Number of attempts made before the final one.

    Returns:
        A wrapper that returns the result of the first attempt that raises no
        assertion/suggestion error. If all regular attempts fail and the first
        recorded error was a `DSPyAssertionError`, that error is re-raised.
    """

    def wrapper(*args, **kwargs):
        error_msg, result = None, None
        with dspy.settings.lock:
            dspy.settings.backtrack_to = None
            dspy.settings.suggest_failures = 0
            dspy.settings.assert_failures = 0

            # Predictor -> List[feedback_msg]
            dspy.settings.predictor_feedbacks = {}

            current_error = None
            # Outputs of the failing predictor from the most recent failure.
            past_outputs = {}

            for attempt in range(max_backtracks + 1):
                if attempt > 0 and dspy.settings.backtrack_to is not None:
                    # generate values for new fields
                    feedback_msg = _build_error_msg(
                        dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to]
                    )
                    dspy.settings.backtrack_to_args = {
                        "feedback": feedback_msg,
                        "past_outputs": past_outputs,
                    }

                # if last backtrack: ignore suggestion errors
                if attempt == max_backtracks:
                    if isinstance(current_error, DSPyAssertionError):
                        raise current_error
                    dsp.settings.trace.clear()
                    result = (
                        bypass_suggest_handler(func)(*args, **kwargs)
                        if bypass_suggest
                        else None
                    )
                    break

                try:
                    dsp.settings.trace.clear()
                    result = func(*args, **kwargs)
                    break
                except (DSPySuggestionError, DSPyAssertionError) as e:
                    # Only the first error is remembered for the final re-raise.
                    if not current_error:
                        current_error = e

                    error_msg = e.msg
                    error_target_module = e.target_module
                    # The failing step is the last trace entry captured by the
                    # exception; guard against an empty/None trace so the
                    # "no trace" branch below is reached instead of crashing.
                    error_state = e.state[-1] if e.state else None

                    # increment failure count depending on type of error
                    if isinstance(e, DSPySuggestionError) and e.is_metric:
                        dspy.settings.suggest_failures += 1
                    elif isinstance(e, DSPyAssertionError) and e.is_metric:
                        dspy.settings.assert_failures += 1

                    if dsp.settings.trace and error_state is not None:
                        if error_target_module:
                            # Most recent trace entry whose predictor signature
                            # matches the requested target.
                            for trace_element in reversed(dsp.settings.trace):
                                mod = trace_element[0]
                                if mod.signature == error_target_module:
                                    dspy.settings.backtrack_to = mod
                                    break
                        else:
                            dspy.settings.backtrack_to = dsp.settings.trace[-1][0]

                        if dspy.settings.backtrack_to is None:
                            logger.error("Specified module not found in trace")

                        # save unique feedback message for predictor
                        feedbacks = dspy.settings.predictor_feedbacks.setdefault(
                            dspy.settings.backtrack_to, []
                        )
                        if error_msg not in feedbacks:
                            feedbacks.append(error_msg)

                        # Collect the failing predictor's output field values.
                        output_fields = vars(error_state[0].signature.signature)
                        past_outputs = {
                            field_name: getattr(error_state[2], field_name, None)
                            for field_name, field_obj in output_fields.items()
                            if isinstance(field_obj, dspy.OutputField)
                        }

                        # Strip assertion bookkeeping keys from the stored
                        # outputs of the failing step (mutates it in place).
                        error_op = error_state[2].__dict__["_store"]
                        error_op.pop("_assert_feedback", None)
                        error_op.pop("_assert_traces", None)
                    else:
                        logger.error(
                            "UNREACHABLE: No trace available, this should not happen. Is this run time?"
                        )

            return result

    return wrapper


def handle_assert_forward(assertion_handler, **handler_args):
    """Build a replacement `forward` that runs `self._forward` under a handler.

    The returned function is meant to be bound to a module instance. On each
    call it wraps `self._forward` with `assertion_handler(**handler_args)` and
    invokes it with the caller's arguments.
    """

    def forward(self, *args, **kwargs):
        args_to_vals = inspect.getcallargs(self._forward, *args, **kwargs)

        # if user has specified a bypass_assert flag, set it
        if "bypass_assert" in args_to_vals:
            dspy.settings.configure(bypass_assert=args_to_vals["bypass_assert"])

        wrapped_forward = assertion_handler(self._forward, **handler_args)
        return wrapped_forward(*args, **kwargs)

    return forward


default_assertion_handler = backtrack_handler


def assert_transform_module(
    module, assertion_handler=default_assertion_handler, **handler_args
):
    """
    Transform a module to handle assertions.

    The module's `forward` is saved as `_forward` and replaced with a wrapper
    that runs it under `assertion_handler`. Unless every predictor is already
    a `dspy.retry.Retry`, all named predictors are mapped through
    `dspy.retry.Retry`. The module is modified in place and returned.

    Calling this again on an already-transformed module keeps the saved
    original `_forward` and only swaps the handler; re-saving the wrapped
    `forward` as `_forward` would make the wrapper call itself forever.

    Raises:
        ValueError: the module has no `forward`.
        RuntimeError: only some of the module's predictors are `Retry`
            instances. This is checked before the module is modified.
    """
    if not getattr(module, "forward", False):
        raise ValueError(
            "Module must have a forward method to have assertions handled."
        )

    # Validate predictors before touching the module so a failure leaves it
    # unmodified.
    retry_flags = [
        isinstance(predictor, dspy.retry.Retry)
        for _, predictor in module.named_predictors()
    ]
    if not all(retry_flags) and any(retry_flags):
        raise RuntimeError("Module has mixed predictors, can't apply Retry mapping.")

    if getattr(module, "_assert_transformed", False):
        logger.info(
            "Module %s is already assertion-transformed. "
            "Keeping its original _forward and replacing the handler.",
            module.__class__.__name__,
        )
    else:
        if getattr(module, "_forward", False):
            logger.info(
                "Module %s already has a _forward method. Overwriting it.",
                module.__class__.__name__,
            )
        module._forward = module.forward

    # Bind the wrapper to this instance so `self` is the module.
    module.forward = handle_assert_forward(assertion_handler, **handler_args).__get__(
        module
    )

    # all(...) is True for an empty list, matching the original: a module
    # with no predictors (or only Retry predictors) is left as-is.
    if not all(retry_flags):
        module.map_named_predictors(dspy.retry.Retry)

    setattr(module, "_assert_transformed", True)

    return module