"""Streamlit playground for scikit-learn's DecisionTreeClassifier on Iris.

The sidebar exposes the tree's main hyper-parameters; the model itself is
trained on the first two Iris features.
"""
import base64
import os
from os import system

import dtreeviz
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from graphviz import Source
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load the Iris dataset from scikit-learn.
iris = load_iris()
# Use only the first two features so the data can be drawn as a 2-D scatter.
X = iris.data[:, :2]
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

st.sidebar.markdown("# Decision Tree Classifier")

criterion = st.sidebar.selectbox('Criterion', ('gini', 'entropy'))
splitter = st.sidebar.selectbox('Splitter', ('best', 'random'))
max_depth = st.sidebar.slider('Max Depth', 1, 20)
min_samples_split = st.sidebar.slider('Min Samples Split', 2, 100, 2)
min_samples_leaf = st.sidebar.slider('Min Samples Leaf', 1, 100, 1)

# DecisionTreeClassifier raises ValueError when max_features exceeds the
# number of input columns; X keeps only X.shape[1] (= 2) of them, so the
# slider is capped there instead of at 4. Default stays at 2.
n_features = X.shape[1]
max_features = st.sidebar.slider('Max Features', 1, n_features, n_features)

max_leaf_nodes = st.sidebar.slider('Max Leaf Nodes', 2, 50)

# sklearn requires min_impurity_decrease >= 0; the default remains 0.0.
min_impurity_decrease = st.sidebar.number_input(
    'Min Impurity Decrease', min_value=0.0
)
def svg_write(svg, center=True):
    """Render an SVG string inline in the Streamlit page.

    The SVG is base64-encoded into a data URI and shown through an <img>
    tag, so it keeps its native size.

    Args:
        svg: SVG document as a string.
        center: When True (default) the image is centred; pass False to
            left-align it like other Streamlit elements.
    """
    # Encode as base 64 so the SVG can be embedded as a data URI.
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")

    # Wrap the image in a flex paragraph to control horizontal alignment.
    css_justify = "center" if center else "left"
    css = (
        f'<p style="text-align:{css_justify}; display: flex; '
        f'justify-content: {css_justify};">'
    )
    html = f'{css}<img src="data:image/svg+xml;base64,{b64}"/></p>'

    # Raw HTML must be explicitly allowed, otherwise Streamlit escapes it.
    st.write(html, unsafe_allow_html=True)


# Initial view: scatter plot of the two used features, coloured by class.
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow')
orig = st.pyplot(fig)
# Release the figure; pyplot otherwise keeps it alive across reruns.
plt.close(fig)

if st.sidebar.button('Run Algorithm'):
    # Replace the raw scatter plot with the fitted-model output.
    orig.empty()

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=max_depth,
        random_state=42,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        # sklearn rejects max_features > n_features; only X.shape[1]
        # columns are used, so clamp regardless of the sidebar range.
        max_features=min(max_features, X.shape[1]),
        max_leaf_nodes=max_leaf_nodes,
        # sklearn rejects negative values here.
        min_impurity_decrease=max(0.0, min_impurity_decrease),
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # `dtreeviz` is imported as a module, so it cannot be called directly;
    # the dtreeviz >= 2.0 entry point is `dtreeviz.model(...)`. Pass the
    # split the tree was actually fitted on and only the feature names for
    # the columns in X.
    viz_model = dtreeviz.model(
        clf,
        X_train=X_train,
        y_train=y_train,
        feature_names=list(iris.feature_names[:X.shape[1]]),
        target_name='iris',
        class_names=list(iris.target_names),
    )
    # Render to SVG and embed it; st.graphviz_chart expects DOT source,
    # not a dtreeviz object.
    svg_write(viz_model.view().svg())

    st.subheader(
        "Accuracy for Decision Tree: "
        + str(round(accuracy_score(y_test, y_pred), 2))
    )