Spaces:
Sleeping
Sleeping
File size: 2,907 Bytes
2d0a9c0 838033a 2d0a9c0 1c6511e 2d0a9c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import streamlit as st
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import os
from os import system
from graphviz import Source
import dtreeviz
import base64
# Load the Iris dataset and keep only the first two feature columns so the
# samples can be drawn as a 2-D scatter plot.
iris = load_iris()
X = iris.data[:, :2]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Sidebar: hyperparameters forwarded to DecisionTreeClassifier on "Run Algorithm".
st.sidebar.markdown("# Decision Tree Classifier")
criterion = st.sidebar.selectbox('Criterion', ('gini', 'entropy'))
splitter = st.sidebar.selectbox('Splitter', ('best', 'random'))
max_depth = st.sidebar.slider('Max Depth', 1, 20)
min_samples_split = st.sidebar.slider('Min Samples Split', 2, 100, 2)
min_samples_leaf = st.sidebar.slider('Min Samples Leaf', 1, 100, 1)
# The model only ever sees X.shape[1] columns; a larger max_features makes
# DecisionTreeClassifier.fit raise ValueError, so cap the slider there.
n_features = X.shape[1]
max_features = st.sidebar.slider('Max Features', 1, n_features, n_features)
max_leaf_nodes = st.sidebar.slider('Max Leaf Nodes', 2, 50)
# scikit-learn rejects negative min_impurity_decrease, so forbid it in the widget.
min_impurity_decrease = st.sidebar.number_input('Min Impurity Decrease', min_value=0.0)

# Initial view: all samples coloured by class, replaced once the model is run.
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow')
ax.set_xlabel(iris.feature_names[0])
ax.set_ylabel(iris.feature_names[1])
orig = st.pyplot(fig)
# st.pyplot has already rendered the figure; close it so pyplot does not keep
# one open figure per Streamlit rerun.
plt.close(fig)
def svg_write(svg, center=True):
    """Display an SVG document inline in the Streamlit page.

    The SVG text is base64-encoded into a data URI inside an ``<img>`` tag.

    Args:
        svg: SVG document as a string.
        center: If True (default) centre the image horizontally; if False,
            align it to the left margin like other Streamlit elements.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    css_justify = "center" if center else "left"
    # Both alignment properties follow `center` (text-align used to be
    # hard-coded to center, so center=False was only half-applied).
    html = (
        f'<p style="text-align:{css_justify}; display: flex; '
        f'justify-content: {css_justify};">'
        f'<img src="data:image/svg+xml;base64,{b64}"/></p>'
    )
    st.write(html, unsafe_allow_html=True)


if st.sidebar.button('Run Algorithm'):
    # Remove the placeholder scatter plot before showing the trained model.
    orig.empty()
    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=max_depth,
        random_state=42,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        max_features=max_features,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # The dtreeviz package is a module, not a callable: build the model
    # adapter with dtreeviz.model() (dtreeviz >= 2.0) and render it with view().
    viz_model = dtreeviz.model(
        clf,
        X_train=X_train,  # visualise the data the tree was actually fit on
        y_train=y_train,
        feature_names=iris.feature_names[:2],  # X keeps only the first two columns
        target_name='iris',
        class_names=list(iris.target_names),
    )
    v = viz_model.view()
    # st.graphviz_chart cannot draw dtreeviz output (it embeds node images),
    # so embed the rendered SVG directly instead.
    svg_write(v.svg())

    st.subheader("Accuracy for Decision Tree: " + str(round(accuracy_score(y_test, y_pred), 2)))