EduardoPacheco committed on
Commit
f2b8171
·
1 Parent(s): a6272af

Utilities to run app

Browse files
Files changed (1) hide show
  1. utils.py +162 -0
utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import plotly.graph_objects as go
8
+ from sklearn.ensemble import GradientBoostingRegressor
9
+
10
class DataGenerator:
    """Generate a reproducible synthetic 1-D regression data set from a formula.

    The inputs ``X`` are drawn uniformly from ``x_range``; the noise-free
    targets ``y_raw`` come from evaluating ``formula_str`` at those inputs, and
    ``y`` adds centred, heteroscedastic log-normal noise on top.

    Args:
        formula_str: Python expression in the variable ``x``. The functions
            ``sin``, ``cos``, ``tan``, ``exp``, ``log``, ``sqrt`` and ``abs``
            resolve to their NumPy counterparts; ``np`` is available as well.
        x_range: Two-element ``[low, high]`` sequence passed to
            ``RandomState.uniform``.
        n_samples: Number of points to draw.
        seed: Seed of the random stream used for both ``X`` and the noise.
    """

    # Names a formula may call, mapped to their vectorised NumPy versions.
    _FUNCTIONS = {
        "sin": np.sin,
        "cos": np.cos,
        "tan": np.tan,
        "exp": np.exp,
        "log": np.log,
        "sqrt": np.sqrt,
        "abs": np.abs,
    }

    def __init__(self, formula_str: str, x_range: list, n_samples: int, seed: int) -> None:
        self.formula_str = formula_str
        self.x_range = x_range
        self.n_samples = n_samples
        self.seed = seed
        self.rng = np.random.RandomState(seed)

    @property
    def X(self) -> np.ndarray:
        """Column vector of shape ``(n_samples, 1)`` with the sampled inputs.

        The random stream is re-created from ``self.seed`` on every access, so
        repeated reads return identical values. (It used to be re-created with
        a hard-coded 42, which made the ``seed`` argument a no-op.)
        """
        self.rng = np.random.RandomState(self.seed)
        samples = self.rng.uniform(*self.x_range, size=self.n_samples)
        return np.atleast_2d(samples).T

    @property
    def y_raw(self) -> np.ndarray:
        """Noise-free targets: the formula evaluated at ``X``, flattened to 1-D."""
        return self._eval_formula().ravel()

    @property
    def y(self) -> np.ndarray:
        """Noisy targets: ``y_raw`` plus centred log-normal noise.

        The noise scale grows linearly with ``x`` (``sigma = 0.5 + x / 10``).
        Subtracting ``exp(sigma**2 / 2)``, the mean of a log-normal with
        ``mean=0``, makes the noise zero-mean.
        """
        # Reading X first resets the stream, so the noise drawn right after it
        # is the same on every access.
        sigma = 0.5 + self.X.ravel() / 10
        noise = self.rng.lognormal(sigma=sigma) - np.exp(sigma**2 / 2)
        return self.y_raw + noise

    def _eval_formula(self) -> np.ndarray:
        """Evaluate ``formula_str`` with ``x`` bound to ``self.X``.

        Names are resolved through an explicit namespace instead of rewriting
        the formula text, so nested calls such as ``sin(cos(x))`` and already
        qualified calls such as ``np.sin(x)`` work. A formula that does not
        depend on ``x`` (e.g. ``"3"``) is broadcast to the shape of ``X``.

        Raises:
            NameError: If the formula uses a name that is not ``x``, ``np`` or
                one of the supported functions.
            SyntaxError: If the formula is not a valid Python expression.
        """
        X = self.X
        namespace = {"__builtins__": {}, "np": np, "x": X, **self._FUNCTIONS}
        # SECURITY: eval() runs arbitrary Python. Removing the builtins narrows
        # what a formula can reach but is NOT a sandbox; only pass formulas
        # from trusted users.
        result = eval(self.formula_str, namespace)
        if np.ndim(result) == 0:
            # Constant formula: give every sample the same value.
            result = np.full(X.shape, result)
        return result
51
+
52
class GradientBoostingCoverage:
    """Prediction interval built from two quantile gradient-boosting models.

    One regressor is created for the ``lower`` quantile and one for the
    ``upper`` quantile; any extra keyword arguments are forwarded unchanged to
    both ``GradientBoostingRegressor`` constructors.
    """

    def __init__(self, lower: float, upper: float, **kwargs) -> None:
        self.lower = lower
        self.upper = upper
        self.kwargs = kwargs
        self.models = self._build_models()

    @property
    def expected_coverage(self) -> float:
        """Nominal coverage of the interval, i.e. ``upper - lower``."""
        return self.upper - self.lower

    def _build_models(self) -> dict[str, GradientBoostingRegressor]:
        """Create the unfitted ``"lower"`` and ``"upper"`` quantile regressors."""
        quantiles = {"lower": self.lower, "upper": self.upper}
        return {
            label: GradientBoostingRegressor(loss="quantile", alpha=q, **self.kwargs)
            for label, q in quantiles.items()
        }

    def fit(self, X: np.ndarray, y: np.array) -> None:
        """Fit both quantile regressors on the same ``(X, y)`` data."""
        for label in self.models:
            self.models[label].fit(X, y)

    def predict(self, X: np.ndarray) -> tuple[np.array, np.array]:
        """Return the ``(lower, upper)`` interval bounds predicted for ``X``."""
        return tuple(self.models[label].predict(X) for label in ("lower", "upper"))

    def coverage_fraction(self, X: np.ndarray, y: np.array) -> float:
        """Return the fraction of ``y`` lying inside the predicted interval (bounds inclusive)."""
        bound_lo, bound_hi = self.predict(X)
        inside = (y >= bound_lo) & (y <= bound_hi)
        return np.mean(inside)
81
+
82
+
83
def fit_gradientboosting(X, y, **kwargs) -> GradientBoostingRegressor:
    """Build a ``GradientBoostingRegressor`` from ``kwargs``, fit it on ``(X, y)`` and return it."""
    # scikit-learn estimators return themselves from fit(), so this hands back
    # the fitted model.
    return GradientBoostingRegressor(**kwargs).fit(X, y)
87
+
88
def plot_interval(
    xx: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    y_upper: np.ndarray,
    y_lower: np.ndarray,
    y_med: np.ndarray,
    y_mean: np.ndarray,
    formula_str: Optional[str] = None,
    interval: Optional[Union[int, str]] = None,
) -> go.Figure:
    """Plot a predicted interval, median and mean against test observations.

    Args:
        xx: x-coordinates of the prediction curves; flattened with ``ravel()``.
        X_test: x-coordinates of the test observations; flattened with ``ravel()``.
        y_test: Observed targets, drawn as markers.
        y_upper: Upper bound of the interval evaluated on ``xx``.
        y_lower: Lower bound of the interval evaluated on ``xx``.
        y_med: Predicted median on ``xx`` (solid red line).
        y_mean: Predicted mean on ``xx`` (dashed red line).
        formula_str: Label for the y-axis; ``"f(x)"`` is used when empty or None.
        interval: Interval width shown in the title as ``"Predicted {interval}%
            Interval"``. When None the title is just ``"Predicted Interval"``
            (previously it rendered as "Predicted None% Interval").

    Returns:
        The assembled plotly figure.
    """
    curve_x = xx.ravel()
    # Fully transparent outline: only the filled band should be visible.
    invisible = "rgba(255,255,0,0)"

    fig = go.Figure()

    # The lower bound must be added directly after the upper bound:
    # fill="tonexty" fills the area between a trace and the one added before it.
    fig.add_trace(
        go.Scatter(
            x=curve_x,
            y=y_upper,
            fill=None,
            mode="lines",
            line_color=invisible,
            name="",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=curve_x,
            y=y_lower,
            fill="tonexty",
            mode="lines",
            line_color=invisible,
            name="Predicted Interval",
        )
    )

    fig.add_trace(
        go.Scatter(
            x=curve_x,
            y=y_med,
            mode="lines",
            line_color="red",
            name="Predicted Median",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=curve_x,
            y=y_mean,
            mode="lines",
            name="Predicted Mean",
            line=dict(color="red", dash="dash"),
        )
    )

    fig.add_trace(
        go.Scatter(
            x=X_test.ravel(),
            y=y_test,
            mode="markers",
            name="Test Observations",
            marker=dict(
                color="blue",
                size=5,
                line=dict(width=2, color="DarkSlateGrey"),
            ),
        )
    )

    if interval is None:
        title = "Predicted Interval"
    else:
        title = f"Predicted {interval}% Interval"

    fig.update_layout(
        title=title,
        xaxis_title="x",
        yaxis_title=formula_str or "f(x)",
        height=600,
    )

    return fig