# File size: 3,488 Bytes | f5776d3
import copy
import inspect
from dspy.primitives.module import BaseModule
from dspy.primitives.assertions import *
import re
class ProgramMeta(type):
    """Metaclass used by ``Module``; it currently adds nothing on top of ``type``.

    The ``__call__`` hook below is disabled. It is kept as a comment for
    reference only and has no effect at runtime.
    """

    # def __call__(cls, *args, **kwargs):
    #     obj = super(ProgramMeta, cls).__call__(*args, **kwargs)
    #     if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False):
    #         obj._base_init()
    #         obj._program_init_called = True
    #     return obj
class Module(BaseModule, metaclass=ProgramMeta):
    """A ``BaseModule`` that is callable and can enumerate its ``Predict`` parameters.

    Calling an instance forwards to ``self.forward``. ``forward`` is not
    defined in this class, so it is presumably supplied by subclasses.
    """

    def _base_init(self):
        """Reset the ``_compiled`` flag to ``False``."""
        self._compiled = False

    def __init__(self):
        """Create the module with ``_compiled`` set to ``False``."""
        self._compiled = False

    def __call__(self, *args, **kwargs):
        """Invoke ``self.forward`` with the given arguments and return its result."""
        return self.forward(*args, **kwargs)

    def named_predictors(self):
        """Return the ``(name, parameter)`` pairs whose parameter is a ``Predict``."""
        # Imported here rather than at module level, as in the original code.
        from dspy.predict.predict import Predict

        selected = []
        for path, candidate in self.named_parameters():
            if isinstance(candidate, Predict):
                selected.append((path, candidate))
        return selected

    def predictors(self):
        """Return the ``Predict`` parameters without their names."""
        found = []
        for _, predictor in self.named_predictors():
            found.append(predictor)
        return found

    def __repr__(self):
        """Render one ``name = predictor`` line per named predictor."""
        return "\n".join(
            f"{name} = {param}" for name, param in self.named_predictors()
        )

    def map_named_predictors(self, func):
        """Applies a function to all named predictors.

        Each predictor is replaced by ``func(predictor)``; the result is
        written back under the same name via ``set_attribute_by_name``.
        Returns ``self`` so calls can be chained.
        """
        for path, predictor in self.named_predictors():
            replacement = func(predictor)
            set_attribute_by_name(self, path, replacement)
        return self

    def activate_assertions(self, handler=backtrack_handler, **handler_args):
        """Activate assertions for this module and return ``self``.

        ``handler`` defaults to ``backtrack_handler``. It is passed, together
        with any extra keyword arguments, to ``assert_transform_module``.
        """
        assert_transform_module(self, handler, **handler_args)
        return self

    # def __deepcopy__(self, memo):
    #     # memo is a dict of id's to copies already made during the current call
    #     # Check if the object is already copied
    #     if id(self) in memo:
    #         return memo[id(self)]
    #     print(f"Deep copying {self.__class__.__name__}...")
    #     new_copy = copy.copy(self)
    #     memo[id(self)] = new_copy
    #     for k, v in self.__dict__.items():
    #         print(f"Copying attribute {k} of type {type(v)}...")
    #         setattr(new_copy, k, copy.deepcopy(v, memo))
    #     print("Done")
    #     return new_copy
# First step of a path: a bare attribute name (no dots or brackets).
_PATH_HEAD = re.compile(r"[^.\[\]]+")

# Any later step: ``.attr``, ``[3]`` or ``['key']``.
_PATH_STEP = re.compile(r"\.([^.\[\]]+)|\[([0-9]+)\]|\['([^']+)'\]")


def _split_attribute_path(name):
    """Split ``name`` into ``(kind, token)`` steps, or return ``None`` if malformed.

    ``kind`` is ``"attr"`` for attribute access and ``"item"`` for a
    subscript. Integer subscripts are converted to ``int``; quoted keys stay
    strings.
    """
    head = _PATH_HEAD.match(name)
    if head is None:
        return None

    steps = [("attr", head.group())]
    pos = head.end()
    while pos < len(name):
        step = _PATH_STEP.match(name, pos)
        if step is None:
            return None
        attr, index, key = step.groups()
        if attr is not None:
            steps.append(("attr", attr))
        elif index is not None:
            steps.append(("item", int(index)))
        else:
            steps.append(("item", key))
        pos = step.end()
    return steps


def set_attribute_by_name(obj, name, value):
    """Assign ``value`` at the location described by ``name``, starting from ``obj``.

    ``name`` is a path built from attribute names, integer subscripts and
    single-quoted string subscripts, in any nesting order, for example
    ``"a"``, ``"a.b"``, ``"a[0]"``, ``"a['k']"``, ``"a[0].b"``,
    ``"a[1][0]"`` or ``"a.b['k'][2].c"``. Dots inside a quoted key are part
    of the key, not separators.

    Args:
        obj: Root object the path is resolved against.
        name: Path string as described above. A string that does not follow
            this grammar is used verbatim as an attribute name on ``obj``.
        value: Object to store at the resolved location.

    Raises:
        AttributeError: An intermediate attribute does not exist.
        IndexError: An integer subscript is out of range.
        KeyError: An intermediate string subscript is missing.
    """
    steps = _split_attribute_path(name)
    if steps is None:
        # Not a recognised path: keep the historical plain-setattr behaviour.
        setattr(obj, name, value)
        return

    # Walk every step but the last to reach the container that owns the slot.
    target = obj
    for kind, token in steps[:-1]:
        target = getattr(target, token) if kind == "attr" else target[token]

    kind, token = steps[-1]
    if kind == "attr":
        setattr(target, token, value)
    else:
        target[token] = value
# ``Program`` is a second name bound to the same class object as ``Module``.
# NOTE(review): presumably kept so code that refers to ``Program`` keeps working — confirm.
Program = Module
|