EduardoPacheco committed
Commit c820b57 · 1 Parent(s): 5e03784

App itself

Files changed (1):
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
import numpy as np
import gradio as gr
import plotly.graph_objects as go
from sklearn.datasets import make_circles
from sklearn.naive_bayes import BernoulliNB
from sklearn.decomposition import TruncatedSVD
from sklearn.ensemble import RandomTreesEmbedding, ExtraTreesClassifier

def plot_scatter(X, y, title):
    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=X[:, 0],
            y=X[:, 1],
            mode="markers",
            marker=dict(color=y, size=10, colorscale="Viridis", line=dict(width=1)),
        )
    )

    fig.update_layout(
        title=title,
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )

    return fig

def plot_decision_boundary(X, y, model, data_preprocess=None, title=None):
    # Build a dense grid covering the data range
    h = 0.01
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Predict the probability of class 1 on every grid point,
    # transforming the grid first when a preprocessor is given
    if data_preprocess:
        grid = data_preprocess.transform(grid)
    y_grid_pred = model.predict_proba(grid)[:, 1]

    # Plot the probability surface as a heatmap with the data points on top
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            x=np.arange(x_min, x_max, h),
            y=np.arange(y_min, y_max, h),
            z=y_grid_pred.reshape(xx.shape),
            colorscale="Viridis",
            opacity=0.8,
            showscale=False,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=X[:, 0],
            y=X[:, 1],
            mode="markers",
            marker=dict(color=y, size=10, colorscale="Viridis", line=dict(width=1)),
        )
    )

    fig.update_layout(
        title=title if title else "Decision Boundary",
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )

    return fig

def app_fn(
    factor: float,
    random_state: int,
    noise: float,
    n_estimators: int,
    max_depth: int,
):
    # Make a synthetic dataset of two noisy concentric circles
    X, y = make_circles(factor=factor, random_state=random_state, noise=noise)

    # Use RandomTreesEmbedding to map the data to a high-dimensional sparse representation
    hasher = RandomTreesEmbedding(n_estimators=n_estimators, random_state=random_state, max_depth=max_depth)
    X_transformed = hasher.fit_transform(X)

    # Visualize the result after dimensionality reduction with truncated SVD
    svd = TruncatedSVD(n_components=2)
    X_reduced = svd.fit_transform(X_transformed)

    # Learn a Naive Bayes classifier on the transformed data
    nb = BernoulliNB()
    nb.fit(X_transformed, y)

    # Learn an ExtraTreesClassifier on the original data for comparison
    trees = ExtraTreesClassifier(max_depth=max_depth, n_estimators=n_estimators, random_state=random_state)
    trees.fit(X, y)

    # Plot the original data, its reduced embedding, and both decision boundaries
    fig1 = plot_scatter(X, y, "Original Data")
    fig2 = plot_scatter(X_reduced, y, f"Truncated SVD Reduction (2D) of Transformed Data ({X_transformed.shape[1]}D)")
    fig3 = plot_decision_boundary(X, y, nb, hasher, "Naive Bayes Decision Boundary")
    fig4 = plot_decision_boundary(X, y, trees, title="Extra Trees Decision Boundary")

    return fig1, fig2, fig3, fig4

title = "Hashing Feature Transformation using Totally Random Trees"
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        """
        ### RandomTreesEmbedding provides a way to map data to a very high-dimensional, \
        sparse representation, which might be beneficial for classification. \
        The mapping is completely unsupervised and very efficient.

        ### This example visualizes the partitions given by several trees and shows how \
        the transformation can also be used for non-linear dimensionality reduction \
        or non-linear classification.

        ### Neighboring points often share the same leaf of a tree and therefore share \
        large parts of their hashed representation. This makes it possible to separate \
        two concentric circles simply based on the principal components of the \
        transformed data with truncated SVD.

        ### In high-dimensional spaces, linear classifiers often achieve excellent \
        accuracy. For sparse binary data, BernoulliNB is particularly well suited. \
        The bottom row compares the decision boundary obtained by BernoulliNB in the \
        transformed space with an ExtraTreesClassifier learned on the original data.

        [Original Example](https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_embedding.html#sphx-glr-auto-examples-ensemble-plot-random-forest-embedding-py)
        """
    )
    with gr.Row():
        factor = gr.Slider(minimum=0.05, maximum=1.0, step=0.01, value=0.5, label="Factor")
        noise = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Noise")
        n_estimators = gr.Slider(minimum=1, maximum=100, step=1, value=10, label="Number of Estimators")
        max_depth = gr.Slider(minimum=1, maximum=100, step=1, value=3, label="Max Depth")
        random_state = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Random State")
    btn = gr.Button("Run")
    with gr.Row():
        plot1 = gr.Plot(label="Original Data")
        plot2 = gr.Plot(label="Truncated SVD Reduced Data")
    with gr.Row():
        plot3 = gr.Plot(label="Naive Bayes Decision Boundary")
        plot4 = gr.Plot(label="Extra Trees Decision Boundary")

    btn.click(app_fn, inputs=[factor, random_state, noise, n_estimators, max_depth], outputs=[plot1, plot2, plot3, plot4])
    demo.load(app_fn, inputs=[factor, random_state, noise, n_estimators, max_depth], outputs=[plot1, plot2, plot3, plot4])

demo.launch()
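
The Gradio and Plotly plumbing above can obscure the core idea, so here is a quick orientation (an editor's sketch, not part of the commit): the hash-and-reduce pipeline can be exercised on its own with the same scikit-learn estimators. The parameter values below are illustrative, mirroring the app's slider defaults.

from sklearn.datasets import make_circles
from sklearn.decomposition import TruncatedSVD
from sklearn.ensemble import RandomTreesEmbedding
from sklearn.naive_bayes import BernoulliNB

# Two noisy concentric circles: not linearly separable in the original 2D space
X, y = make_circles(factor=0.5, noise=0.05, random_state=0)

# Each sample is encoded by the leaves it lands in, one-hot across all trees,
# which yields a sparse, high-dimensional, binary feature matrix
hasher = RandomTreesEmbedding(n_estimators=10, max_depth=3, random_state=0)
X_transformed = hasher.fit_transform(X)  # scipy sparse, shape (n_samples, n_leaves)

# Naive Bayes on the sparse binary codes now separates the circles
nb = BernoulliNB().fit(X_transformed, y)
print("BernoulliNB training accuracy:", nb.score(X_transformed, y))

# The first two SVD components of the embedding already pull the classes apart
X_reduced = TruncatedSVD(n_components=2).fit_transform(X_transformed)
print("reduced shape:", X_reduced.shape)

To try the app itself, install gradio, plotly, and scikit-learn, run python app.py, and open the local URL that demo.launch() prints.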