Spaces:
Sleeping
Sleeping
# %% | |
import pickle
import sys

import gradio as gr
import numpy as np
import pandas as pd
import sklearn
from sklearn.ensemble import RandomForestRegressor
# Load the pickled model bundle at start-up.
#
# The pickle is expected to hold a dict with two keys:
#   'model'          - the fitted estimator used for prediction
#   'feature_names'  - the feature names stored alongside the model
# On any loading problem the script prints a diagnostic and exits with
# status 1, because the app cannot serve predictions without the model.
model_filename = "random_forest_regression_extended.pkl"

try:
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load model files that come from a trusted source.
    with open(model_filename, 'rb') as f:
        model_data = pickle.load(f)
except FileNotFoundError:
    print(f"Error: Could not find model file '{model_filename}'")
    print("Please run save_model.py first to create the model file.")
    sys.exit(1)
except Exception as e:
    # Unpickling an estimator commonly fails on a scikit-learn version
    # mismatch, so report the installed version next to the error.
    print(f"Error loading model: {e}")
    print(f"scikit-learn version: {sklearn.__version__}")
    sys.exit(1)

has_expected_structure = (
    isinstance(model_data, dict)
    and 'model' in model_data
    and 'feature_names' in model_data
)
if not has_expected_structure:
    print("Error: Model file does not contain expected dictionary structure")
    print("Expected keys: 'model' and 'feature_names'")
    print(f"Found keys: {model_data.keys() if isinstance(model_data, dict) else 'not a dictionary'}")
    sys.exit(1)

random_forest_model = model_data['model']
feature_names = model_data['feature_names']

# Estimators without `n_features_in_` fall back to the length of the stored
# feature-name list.
if hasattr(random_forest_model, 'n_features_in_'):
    print('Number of features: ', random_forest_model.n_features_in_)
else:
    print('Number of features: ', len(feature_names))
print('Features are: ', feature_names)
# Read the BFS municipality/tax table that supplies the per-municipality
# model features.
df_bfs_data = pd.read_csv(
    'bfs_municipality_and_tax_data.csv',
    sep=',',
    encoding='utf-8',
)

# tax_income is read as text containing apostrophes (presumably thousands
# separators - confirm against the CSV); strip them, then convert to float.
_tax_income_text = df_bfs_data['tax_income'].str.replace("'", "")
df_bfs_data['tax_income'] = _tax_income_text.astype(float)

# Default value in meters; predict_apartment overwrites it with user input.
df_bfs_data['proximity_to_public_transportation'] = 500
# %% | |
# Town name -> number matched against the 'bfs_number' column of the BFS
# table. Insertion order is the order in which towns appear in the dropdown.
locations = {
    "Zürich": 261,
    "Kloten": 62,
    "Uster": 198,
    "Illnau-Effretikon": 296,
    "Feuerthalen": 27,
    "Pfäffikon": 177,
    "Ottenbach": 11,
    "Dübendorf": 191,
    "Richterswil": 138,
    "Maur": 195,
    "Embrach": 56,
    "Bülach": 53,
    "Winterthur": 230,
    "Oetwil am See": 157,
    "Russikon": 178,
    "Obfelden": 10,
    "Wald (ZH)": 120,
    "Niederweningen": 91,
    "Dällikon": 84,
    "Buchs (ZH)": 83,
    "Rüti (ZH)": 118,
    "Hittnau": 173,
    "Bassersdorf": 52,
    "Glattfelden": 58,
    "Opfikon": 66,
    "Hinwil": 117,
    "Regensberg": 95,
    "Langnau am Albis": 136,
    "Dietikon": 243,
    "Erlenbach (ZH)": 151,
    "Kappel am Albis": 6,
    "Stäfa": 158,
    "Zell (ZH)": 231,
    "Turbenthal": 228,
    "Oberglatt": 92,
    "Winkel": 72,
    "Volketswil": 199,
    "Kilchberg (ZH)": 135,
    "Wetzikon (ZH)": 121,
    "Zumikon": 160,
    "Weisslingen": 180,
    "Elsau": 219,
    "Hettlingen": 221,
    "Rüschlikon": 139,
    "Stallikon": 13,
    "Dielsdorf": 86,
    "Wallisellen": 69,
    "Dietlikon": 54,
    "Meilen": 156,
    "Wangen-Brüttisellen": 200,
    "Flaach": 28,
    "Regensdorf": 96,
    "Niederhasli": 90,
    "Bauma": 297,
    "Aesch (ZH)": 241,
    "Schlieren": 247,
    "Dürnten": 113,
    "Unterengstringen": 249,
    "Gossau (ZH)": 115,
    "Oberengstringen": 245,
    "Schleinikon": 98,
    "Aeugst am Albis": 1,
    "Rheinau": 38,
    "Höri": 60,
    "Rickenbach (ZH)": 225,
    "Rafz": 67,
    "Adliswil": 131,
    "Zollikon": 161,
    "Urdorf": 250,
    "Hombrechtikon": 153,
    "Birmensdorf (ZH)": 242,
    "Fehraltorf": 172,
    "Weiach": 102,
    "Männedorf": 155,
    "Küsnacht (ZH)": 154,
    "Hausen am Albis": 4,
    "Hochfelden": 59,
    "Fällanden": 193,
    "Greifensee": 194,
    "Mönchaltorf": 196,
    "Dägerlen": 214,
    "Thalheim an der Thur": 39,
    "Uetikon am See": 159,
    "Seuzach": 227,
    "Uitikon": 248,
    "Affoltern am Albis": 2,
    "Geroldswil": 244,
    "Niederglatt": 89,
    "Thalwil": 141,
    "Rorbas": 68,
    "Pfungen": 224,
    "Weiningen (ZH)": 251,
    "Bubikon": 112,
    "Neftenbach": 223,
    "Mettmenstetten": 9,
    "Otelfingen": 94,
    "Flurlingen": 29,
    "Stadel": 100,
    "Grüningen": 116,
    "Henggart": 31,
    "Dachsen": 25,
    "Bonstetten": 3,
    "Bachenbülach": 51,
    "Horgen": 295,
}
def predict_apartment(rooms, area, proximity, town):
    """Predict the price of an apartment from user inputs and town data.

    The town is mapped to its number via ``locations``, the matching row of
    ``df_bfs_data`` supplies the municipality-level features, and the user
    inputs fill in the apartment-level features before the row is passed to
    ``random_forest_model``.

    Args:
        rooms: Number of rooms entered by the user.
        area: Apartment area entered by the user.
        proximity: Distance to public transportation; overrides the default
            stored in ``df_bfs_data``.
        town: Town name; must be a key of ``locations``.

    Returns:
        The model's prediction rounded to zero decimals, or ``-1`` when the
        inputs are incomplete, the town is unknown, or the town does not
        match exactly one row of ``df_bfs_data``.
    """
    bfs_number = locations.get(town)
    if bfs_number is None or rooms is None or area is None:
        return -1

    df = df_bfs_data[df_bfs_data['bfs_number'] == bfs_number].copy()
    # Validate BEFORE writing: `df.loc[0, col] = value` on an empty frame
    # enlarges it to one row of NaNs, which would defeat this check.
    if len(df) != 1:
        return -1

    # drop=True keeps the old index out of the frame as a stray column.
    df = df.reset_index(drop=True)
    df.loc[0, 'rooms'] = rooms
    df.loc[0, 'area'] = area
    # Use the user's input instead of the table-wide default.
    df.loc[0, 'proximity_to_public_transportation'] = proximity

    # Column order of the feature matrix handed to the model.
    features = ['rooms', 'area', 'pop', 'pop_dens', 'frg_pct', 'emp',
                'tax_income', 'proximity_to_public_transportation']
    X = df[features].values
    prediction = random_forest_model.predict(X)
    return np.round(prediction[0], 0)
# Build the Gradio UI around predict_apartment. The inputs are listed in the
# same order as the function's parameters: rooms, area, proximity, town.
iface = gr.Interface(
    fn=predict_apartment,
    inputs=[
        gr.Number(label="Number of Rooms"),
        gr.Number(label="Area"),
        gr.Slider(minimum=0, maximum=2000, value=500, step=50,
                  label="Distance to Public Transportation (meters)"),
        # Dropdown documents `choices` as a list, so materialize the town
        # names instead of handing over a dict_keys view.
        gr.Dropdown(choices=list(locations), label="Town", type="value"),
    ],
    outputs=[gr.Number(label="Predicted Price (CHF)")],
    examples=[
        [4.5, 120, 500, "Dietlikon"],
        [3.5, 60, 250, "Winterthur"],
    ],
    description="Predict apartment prices in Zürich based on rooms, area, proximity to public transportation, and location.",
)

# Start the web app when the script runs.
iface.launch()