File size: 2,907 Bytes
2d0a9c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838033a
2d0a9c0
 
 
 
 
1c6511e
2d0a9c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import os
from os import system
from graphviz import Source
import dtreeviz
import base64


# Load the Iris dataset from scikit-learn and hold out a test split.
iris = load_iris()
X = iris.data[:, :2]  # Keep only the first two features so the data can be drawn in 2-D
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Sidebar: one widget per DecisionTreeClassifier hyperparameter. The values
# are read further down when the "Run Algorithm" button is pressed.
st.sidebar.markdown("# Decision Tree Classifier")
criterion = st.sidebar.selectbox(
    'Criterion',
    ('gini', 'entropy')
)
splitter = st.sidebar.selectbox(
    'Splitter',
    ('best', 'random')
)
max_depth = st.sidebar.slider('Max Depth', 1, 20)
min_samples_split = st.sidebar.slider('Min Samples Split', 2, 100, 2)
min_samples_leaf = st.sidebar.slider('Min Samples Leaf', 1, 100, 1)
# The upper bound follows the number of columns actually kept in X: asking the
# classifier for more features than the data has makes fit() raise.
max_features = st.sidebar.slider('Max Features', 1, X.shape[1], X.shape[1])
max_leaf_nodes = st.sidebar.slider('Max Leaf Nodes', 2, 50)
# scikit-learn rejects a negative impurity decrease, so do not let one be entered.
min_impurity_decrease = st.sidebar.number_input('Min Impurity Decrease', min_value=0.0)

# Initial view: scatter plot of the two selected features, coloured by class.
# `orig` keeps the handle of the rendered element so it can be cleared later.
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow')
orig = st.pyplot(fig)
if st.sidebar.button('Run Algorithm'):
    def svg_write(svg, center=True):
        """
        Show an SVG document in the Streamlit page as an inline image.

        The SVG text is base64-encoded into a ``data:`` URI and emitted as an
        ``<img>`` wrapped in a flex container.

        Args:
            svg: SVG markup as a string.
            center: Horizontally center the image when true; disable to
                left-align it like other page elements.
        """
        # Encode as base 64 so the markup can live inside a data URI
        b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")

        # Add some CSS on top
        css_justify = "center" if center else "left"
        css = f'<p style="text-align:center; display: flex; justify-content: {css_justify};">'
        html = f'{css}<img src="data:image/svg+xml;base64,{b64}"/></p>'

        # Write the HTML
        st.write(html, unsafe_allow_html=True)

    # Remove the initial scatter plot; the tree rendering takes its place.
    orig.empty()

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=max_depth,
        random_state=42,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        # Never ask for more features than the training matrix has columns,
        # otherwise fit() raises.
        max_features=min(max_features, X_train.shape[1]),
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # `dtreeviz` is imported as a module, so the adaptor is built through
    # dtreeviz.model(); the module object itself is not callable.
    viz_model = dtreeviz.model(
        clf,
        X_train=X, y_train=y,
        # X only holds the leading columns of iris.data, so pass just the
        # matching feature names.
        feature_names=iris.feature_names[:X.shape[1]],
        target_name='iris',
        class_names=iris.target_names,
    )

    v = viz_model.view()    # render as SVG into internal object
    svg_write(v.svg())      # embed the rendered tree in the page

    st.subheader("Accuracy for Decision Tree: " + str(round(accuracy_score(y_test, y_pred), 2)))