|
import copy |
|
import inspect |
|
|
|
from dspy.primitives.module import BaseModule |
|
from dspy.primitives.assertions import * |
|
import re |
|
|
|
|
|
class ProgramMeta(type):
    """Metaclass used by ``Module``.

    It adds nothing on top of ``type``; classes created with it behave exactly
    like ordinary classes. (Presumably kept as an extension point for
    customizing program class creation — confirm before relying on that.)
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Module(BaseModule, metaclass=ProgramMeta):
    """Base class for composable programs whose parameters include predictors.

    Subclasses supply ``forward``; invoking an instance simply delegates to it.
    """

    def _base_init(self):
        """Put the instance back into its not-yet-compiled state."""
        self._compiled = False

    def __init__(self):
        # Assigned directly (rather than via _base_init) so that a subclass
        # overriding _base_init does not alter construction.
        self._compiled = False

    def __call__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``self.forward``."""
        return self.forward(*args, **kwargs)

    def named_predictors(self):
        """List the ``(name, predictor)`` pairs among this module's parameters.

        Pairs are taken from ``self.named_parameters()`` in order, keeping only
        those whose value is an instance of ``dspy.predict.predict.Predict``.
        """
        # Deferred import; presumably avoids an import cycle with dspy.predict.
        from dspy.predict.predict import Predict

        return [
            (label, component)
            for label, component in self.named_parameters()
            if isinstance(component, Predict)
        ]

    def predictors(self):
        """Return the predictor objects from ``named_predictors``, names dropped."""
        return [component for _label, component in self.named_predictors()]

    def __repr__(self):
        """Render one ``name = predictor`` line per named predictor."""
        lines = [
            f"{label} = {component}"
            for label, component in self.named_predictors()
        ]
        return "\n".join(lines)

    def map_named_predictors(self, func):
        """Replace every named predictor with ``func(predictor)``.

        Each result is written back under the predictor's name via
        ``set_attribute_by_name``. Returns ``self`` to allow chaining.
        """
        for label, component in self.named_predictors():
            replacement = func(component)
            set_attribute_by_name(self, label, replacement)
        return self

    def activate_assertions(self, handler=backtrack_handler, **handler_args):
        """
        Turn on assertion handling for this module.

        Calls ``assert_transform_module(self, handler, **handler_args)``; the
        handler defaults to ``backtrack_handler``. Returns ``self``.
        """
        assert_transform_module(self, handler, **handler_args)
        return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path grammar accepted by set_attribute_by_name: a leading attribute name
# followed by any sequence of ``.attr``, ``[<int>]`` and ``['<key>']`` steps,
# e.g. ``layers[0].predict`` or ``by_name['k'].predict``. Compiled once at
# import time rather than on every (previously recursive) call.
_PATH_HEAD = re.compile(r"[^.\[\]]+")
_PATH_STEP = re.compile(r"\.([^.\[\]]+)|\[([0-9]+)\]|\['(.*?)'\]", re.DOTALL)


def _parse_attribute_path(name):
    """Split ``name`` into ``(kind, token)`` steps, or return None if it does not parse.

    ``kind`` is ``"attr"`` (token: attribute name), ``"index"`` (token: int) or
    ``"key"`` (token: str dict key).
    """
    head = _PATH_HEAD.match(name)
    if head is None:
        return None

    steps = [("attr", head.group())]
    pos = head.end()
    while pos < len(name):
        step = _PATH_STEP.match(name, pos)
        if step is None:
            return None
        attr, index, key = step.groups()
        if attr is not None:
            steps.append(("attr", attr))
        elif index is not None:
            steps.append(("index", int(index)))
        else:
            steps.append(("key", key))
        pos = step.end()
    return steps


def set_attribute_by_name(obj, name, value):
    """Assign ``value`` at the location described by the path ``name`` under ``obj``.

    A path mixes attribute access, integer indexing and string-keyed lookup,
    e.g. ``"predict"``, ``"sub.predict"``, ``"layers[0]"``,
    ``"layers[0].predict"``, ``"by_name['k']"`` or ``"by_name['k'].predict"``.
    Every step except the last is read (``getattr`` / ``[]``); the last step is
    written (``setattr`` / ``[]=``), so lists and dicts are updated in place.

    Names that do not fit this grammar fall back to ``setattr(obj, name, value)``.

    Raises:
        AttributeError, IndexError, KeyError, TypeError: propagated from the
        underlying lookups/assignments (e.g. TypeError if the final container
        is an immutable tuple).
    """
    steps = _parse_attribute_path(name)
    if steps is None:
        # Unrecognized name: keep the historical plain-setattr fallback.
        setattr(obj, name, value)
        return

    *parents, (kind, token) = steps

    # Walk down to the object/container that owns the final step. Splitting
    # the whole path up front (instead of on the first ".") is what lets a
    # subscripted segment such as "layers[0]" be followed by ".predict".
    target = obj
    for parent_kind, parent_token in parents:
        if parent_kind == "attr":
            target = getattr(target, parent_token)
        else:
            target = target[parent_token]

    if kind == "attr":
        setattr(target, token, value)
    else:
        target[token] = value
|
|
|
|
|
# Alternative public name for Module; both names refer to the same class object.
Program = Module
|
|