EduardoPacheco commited on
Commit
0864129
·
1 Parent(s): aceec91

Gradio app to run example

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import plotly.graph_objects as go
4
+ from sklearn.datasets import load_diabetes
5
+ from sklearn.ensemble import GradientBoostingRegressor
6
+ from sklearn.ensemble import RandomForestRegressor
7
+ from sklearn.linear_model import LinearRegression
8
+ from sklearn.ensemble import VotingRegressor
9
+
10
+
11
+ def plot_votes(preds: list[tuple[str, np.array]], markers: list[str]=None) -> go.Figure:
12
+ fig = go.Figure()
13
+
14
+ for idx, (name, pred) in enumerate(preds):
15
+ if not markers:
16
+ symbol = "diamond"
17
+ else:
18
+ symbol = markers[idx]
19
+ fig.add_trace(
20
+ go.Scatter(
21
+ y=pred,
22
+ mode="markers",
23
+ name=name,
24
+ marker=dict(symbol=symbol, size=10, line=dict(width=2, color="DarkSlateGrey"))
25
+ )
26
+ )
27
+ fig.update_layout(
28
+ title="Regressor predictions and their average",
29
+ yaxis_title="Predicted",
30
+ xaxis_title="Training Samples",
31
+ height=500,
32
+ width=1000,
33
+ xaxis=dict(showticklabels=False),
34
+ hovermode="x unified"
35
+ )
36
+
37
+ return fig
38
+
39
+
40
+ def app_fn(n: int) -> go.Figure:
41
+ X, y = load_diabetes(return_X_y=True)
42
+
43
+ # Train classifiers
44
+ reg1 = GradientBoostingRegressor(random_state=1)
45
+ reg2 = RandomForestRegressor(random_state=1)
46
+ reg3 = LinearRegression()
47
+
48
+ reg1.fit(X, y)
49
+ reg2.fit(X, y)
50
+ reg3.fit(X, y)
51
+
52
+ ereg = VotingRegressor([("gb", reg1), ("rf", reg2), ("lr", reg3)])
53
+ ereg.fit(X, y)
54
+
55
+ xt = X[:n]
56
+
57
+ pred1 = reg1.predict(xt)
58
+ pred2 = reg2.predict(xt)
59
+ pred3 = reg3.predict(xt)
60
+ pred4 = ereg.predict(xt)
61
+
62
+ preds = [
63
+ ("Gradient Boosting", pred1),
64
+ ("Random Forest", pred2),
65
+ ("Linear Regression", pred3),
66
+ ("Voting Regressor", pred4)
67
+ ]
68
+ markers = ["diamond-tall", "triangle-up", "square", "star"]
69
+ fig = plot_votes(preds, markers)
70
+
71
+ return fig
72
+
73
+ with gr.Blocks() as demo:
74
+ n = gr.inputs.Slider(10, 30, 5, 20, "Number of training samples")
75
+ plot = gr.Plot(label="Individual & Voting Predictions")
76
+ button = gr.Button(label="Update Plot")
77
+ button.click(fn=app_fn, inputs=[n], outputs=[plot])
78
+ demo.load(fn=app_fn, inputs=[n], outputs=[plot])