rahul5035 commited on
Commit
2d0a9c0
·
1 Parent(s): 9be5306

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ from sklearn.datasets import load_iris
5
+ from sklearn.tree import DecisionTreeClassifier
6
+ from sklearn.metrics import accuracy_score
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ from os import system
10
+ from graphviz import Source
11
+ import dtreeviz
12
+ import base64
13
+
14
+
15
+ os.environ["PATH"] += os.pathsep + '/home/adminuser/venv/lib/python3.9/site-packages/graphviz/bin/'
16
+ # Load the Iris dataset from scikit-learn
17
+ iris = load_iris()
18
+ X = iris.data[:, :2] # Using only two features for visualization purposes
19
+ y = iris.target
20
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
21
+ st.sidebar.markdown("# Decision Tree Classifier")
22
+ criterion = st.sidebar.selectbox(
23
+ 'Criterion',
24
+ ('gini', 'entropy')
25
+ )
26
+ splitter = st.sidebar.selectbox(
27
+ 'Splitter',
28
+ ('best', 'random')
29
+ )
30
+ max_depth = st.sidebar.slider('Max Depth', 1, 20)
31
+ min_samples_split = st.sidebar.slider('Min Samples Split', 2, 100, 2)
32
+ min_samples_leaf = st.sidebar.slider('Min Samples Leaf', 1, 100, 1)
33
+ max_features = st.sidebar.slider('Max Features', 1, 4, 2)
34
+ max_leaf_nodes = st.sidebar.slider('Max Leaf Nodes', 2, 50)
35
+ min_impurity_decrease = st.sidebar.number_input('Min Impurity Decrease')
36
+ # Rest of your sidebar inputs...
37
+ # Load initial graph
38
+ fig, ax = plt.subplots()
39
+ # Plot initial graph
40
+ ax.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow')
41
+ orig = st.pyplot(fig)
42
+ if st.sidebar.button('Run Algorithm'):
43
+ orig.empty()
44
+ clf = DecisionTreeClassifier(criterion=criterion, splitter=splitter, max_depth=max_depth, random_state=42,
45
+ min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf,
46
+ max_features=max_features, max_leaf_nodes=max_leaf_nodes,
47
+ min_impurity_decrease=min_impurity_decrease)
48
+ clf.fit(X_train, y_train)
49
+ y_pred = clf.predict(X_test)
50
+ viz_model = dtreeviz.model(clf,
51
+ X_train=X, y_train=y,
52
+ feature_names=iris.feature_names,
53
+ target_name='iris',
54
+ class_names=iris.target_names)
55
+
56
+ v = viz_model.view() # render as SVG into internal object
57
+ # v.show() # pop up window
58
+ # v.save("/tmp/iris.svg")
59
+ def svg_write(svg, center=True):
60
+ """
61
+ Disable center to left-margin align like other objects.
62
+ """
63
+ # Encode as base 64
64
+ b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
65
+
66
+ # Add some CSS on top
67
+ css_justify = "center" if center else "left"
68
+ css = f'<p style="text-align:center; display: flex; justify-content: {css_justify};">'
69
+ html = f'{css}<img src="data:image/svg+xml;base64,{b64}"/>'
70
+
71
+ # Write the HTML
72
+ st.write(html, unsafe_allow_html=True)
73
+ st.write(v)
74
+ svg=v.svg()
75
+ # svg_write(svg)
76
+ st.subheader("Accuracy for Decision Tree: " + str(round(accuracy_score(y_test, y_pred), 2)))