# EureCA / dspy / primitives / program.py
# tonneli's picture
# Delete history
# f5776d3
import copy
import inspect
from dspy.primitives.module import BaseModule
from dspy.primitives.assertions import *
import re
class ProgramMeta(type):
    """Metaclass used by ``Module``; it adds no behavior on top of ``type``."""
# def __call__(cls, *args, **kwargs):
# obj = super(ProgramMeta, cls).__call__(*args, **kwargs)
# if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False):
# obj._base_init()
# obj._program_init_called = True
# return obj
class Module(BaseModule, metaclass=ProgramMeta):
    """Callable program base class.

    Calling an instance dispatches to ``self.forward``, which is not defined
    in this class. The helper methods below expose the ``Predict`` parameters
    found through ``self.named_parameters()`` (not defined in this class;
    presumably inherited from ``BaseModule`` - TODO confirm).
    """

    def _base_init(self):
        """Set the ``_compiled`` flag to ``False``."""
        self._compiled = False

    def __init__(self):
        """Create the module in the not-compiled state."""
        self._compiled = False

    def __call__(self, *args, **kwargs):
        """Pass every positional and keyword argument through to ``forward``."""
        return self.forward(*args, **kwargs)

    def named_predictors(self):
        """Return the ``(name, parameter)`` pairs whose parameter is a ``Predict``."""
        # Imported inside the method rather than at module level; presumably
        # this avoids a circular import - TODO confirm.
        from dspy.predict.predict import Predict

        selected = []
        for name, param in self.named_parameters():
            if isinstance(param, Predict):
                selected.append((name, param))
        return selected

    def predictors(self):
        """Return only the predictor objects, dropping their names."""
        return [predictor for _name, predictor in self.named_predictors()]

    def __repr__(self):
        """Render one ``name = predictor`` line per named predictor."""
        return "\n".join(
            f"{name} = {param}" for name, param in self.named_predictors()
        )

    def map_named_predictors(self, func):
        """Applies a function to all named predictors.

        Each predictor is replaced, at the location given by its name, with
        the result of ``func(predictor)``. Returns ``self``.
        """
        for path, current in self.named_predictors():
            replacement = func(current)
            set_attribute_by_name(self, path, replacement)
        return self

    def activate_assertions(self, handler=backtrack_handler, **handler_args):
        """
        Activates assertions for the module.
        The default handler is the backtrack_handler.

        ``handler`` and any extra keyword arguments are handed to
        ``assert_transform_module`` together with this module. Returns ``self``.
        """
        assert_transform_module(self, handler, **handler_args)
        return self
# def __deepcopy__(self, memo):
# # memo is a dict of id's to copies already made during the current call
# # Check if the object is already copied
# if id(self) in memo:
# return memo[id(self)]
# print(f"Deep copying {self.__class__.__name__}...")
# new_copy = copy.copy(self)
# memo[id(self)] = new_copy
# for k, v in self.__dict__.items():
# print(f"Copying attribute {k} of type {type(v)}...")
# setattr(new_copy, k, copy.deepcopy(v, memo))
# print("Done")
# return new_copy
# Path grammar understood by ``set_attribute_by_name``: a leading attribute
# name followed by any mix of ``.attr``, ``[<digits>]`` and ``['<key>']`` steps.
# Compiled once at import time instead of on every call.
_PATH_HEAD = re.compile(r"[^.\[\]]+")
_PATH_STEP = re.compile(r"\.([^.\[\]]+)|\[([0-9]+)\]|\['([^']+)'\]")


def _parse_attribute_path(name):
    """Split ``name`` into ``(kind, key)`` steps, or return ``None`` if it does not fit the grammar."""
    head = _PATH_HEAD.match(name)
    if head is None:
        return None
    steps = [("attr", head.group())]
    pos = head.end()
    while pos < len(name):
        step = _PATH_STEP.match(name, pos)
        if step is None:
            return None
        attr, index, key = step.groups()
        if attr is not None:
            steps.append(("attr", attr))
        elif index is not None:
            steps.append(("item", int(index)))
        else:
            steps.append(("item", key))
        pos = step.end()
    return steps


def set_attribute_by_name(obj, name, value):
    """Assign ``value`` at the location described by ``name``, starting from ``obj``.

    Supported forms, which may be chained and nested freely:

    * ``attr``          -> ``setattr(obj, "attr", value)``
    * ``module.attr``   -> attribute of a sub-object
    * ``items[3]``      -> element of a list-like attribute (integer index)
    * ``table['key']``  -> entry of a dict-like attribute (string key)

    Unlike a first-dot split, subscripts may appear anywhere in the path, so
    names such as ``layers[0].predict``, ``table['k'].predict`` or
    ``grid[0][1]`` resolve correctly, and a dot inside a quoted key
    (``table['a.b']``) is treated as part of the key.

    Args:
        obj: Root object the path is resolved against.
        name: Path string in the grammar above.
        value: Object to store at the final step.

    Returns:
        None. The target object is mutated in place.

    Raises:
        Whatever ``getattr`` or item access raises while walking the
        intermediate steps (for example ``AttributeError``, ``IndexError``,
        ``KeyError``); these propagate unchanged.
    """
    steps = _parse_attribute_path(name)
    if steps is None:
        # Name does not fit the grammar: store it verbatim as a plain attribute.
        setattr(obj, name, value)
        return

    *walk, (last_kind, last_key) = steps
    target = obj
    for kind, key in walk:
        target = getattr(target, key) if kind == "attr" else target[key]

    if last_kind == "attr":
        setattr(target, last_key, value)
    else:
        target[last_key] = value
# ``Program`` is a second name bound to the same class object as ``Module``.
Program = Module