MichaelMM2000 commited on
Commit
6f34772
·
1 Parent(s): f54c316

inital commit

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ disposition.md
2
+ week1/data_with_large_residuals.csv
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import gradio as gr
3
+ from sklearn.ensemble import RandomForestRegressor
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pickle
7
+ # Define model filename
8
+ model_filename = "random_forest_regression_extended.pkl"
9
+
10
+ try:
11
+ # Try to load the model
12
+ with open(model_filename, 'rb') as f:
13
+ model_data = pickle.load(f)
14
+ if isinstance(model_data, dict) and 'model' in model_data and 'feature_names' in model_data:
15
+ random_forest_model = model_data['model']
16
+ feature_names = model_data['feature_names']
17
+ # Check scikit-learn version and handle feature information
18
+ if hasattr(random_forest_model, 'n_features_in_'):
19
+ print('Number of features: ', random_forest_model.n_features_in_)
20
+ else:
21
+ print('Number of features: ', len(feature_names))
22
+ print('Features are: ', feature_names)
23
+ else:
24
+ print("Error: Model file does not contain expected dictionary structure")
25
+ print("Expected keys: 'model' and 'feature_names'")
26
+ print(f"Found keys: {model_data.keys() if isinstance(model_data, dict) else 'not a dictionary'}")
27
+ exit(1)
28
+ except FileNotFoundError:
29
+ print(f"Error: Could not find model file '{model_filename}'")
30
+ print("Please run save_model.py first to create the model file.")
31
+ exit(1)
32
+ except Exception as e:
33
+ print(f"Error loading model: {str(e)}")
34
+ print(f"scikit-learn version: {sklearn.__version__}")
35
+ exit(1)
36
+
37
+ # Load and prepare BFS data
38
+ df_bfs_data = pd.read_csv('bfs_municipality_and_tax_data.csv', sep=',', encoding='utf-8')
39
+ df_bfs_data['tax_income'] = df_bfs_data['tax_income'].str.replace("'", "").astype(float)
40
+ df_bfs_data['proximity_to_public_transportation'] = 500 # Default value in meters
41
+
42
+
43
+ # %%
44
+ locations = {
45
+ "Zürich": 261,
46
+ "Kloten": 62,
47
+ "Uster": 198,
48
+ "Illnau-Effretikon": 296,
49
+ "Feuerthalen": 27,
50
+ "Pfäffikon": 177,
51
+ "Ottenbach": 11,
52
+ "Dübendorf": 191,
53
+ "Richterswil": 138,
54
+ "Maur": 195,
55
+ "Embrach": 56,
56
+ "Bülach": 53,
57
+ "Winterthur": 230,
58
+ "Oetwil am See": 157,
59
+ "Russikon": 178,
60
+ "Obfelden": 10,
61
+ "Wald (ZH)": 120,
62
+ "Niederweningen": 91,
63
+ "Dällikon": 84,
64
+ "Buchs (ZH)": 83,
65
+ "Rüti (ZH)": 118,
66
+ "Hittnau": 173,
67
+ "Bassersdorf": 52,
68
+ "Glattfelden": 58,
69
+ "Opfikon": 66,
70
+ "Hinwil": 117,
71
+ "Regensberg": 95,
72
+ "Langnau am Albis": 136,
73
+ "Dietikon": 243,
74
+ "Erlenbach (ZH)": 151,
75
+ "Kappel am Albis": 6,
76
+ "Stäfa": 158,
77
+ "Zell (ZH)": 231,
78
+ "Turbenthal": 228,
79
+ "Oberglatt": 92,
80
+ "Winkel": 72,
81
+ "Volketswil": 199,
82
+ "Kilchberg (ZH)": 135,
83
+ "Wetzikon (ZH)": 121,
84
+ "Zumikon": 160,
85
+ "Weisslingen": 180,
86
+ "Elsau": 219,
87
+ "Hettlingen": 221,
88
+ "Rüschlikon": 139,
89
+ "Stallikon": 13,
90
+ "Dielsdorf": 86,
91
+ "Wallisellen": 69,
92
+ "Dietlikon": 54,
93
+ "Meilen": 156,
94
+ "Wangen-Brüttisellen": 200,
95
+ "Flaach": 28,
96
+ "Regensdorf": 96,
97
+ "Niederhasli": 90,
98
+ "Bauma": 297,
99
+ "Aesch (ZH)": 241,
100
+ "Schlieren": 247,
101
+ "Dürnten": 113,
102
+ "Unterengstringen": 249,
103
+ "Gossau (ZH)": 115,
104
+ "Oberengstringen": 245,
105
+ "Schleinikon": 98,
106
+ "Aeugst am Albis": 1,
107
+ "Rheinau": 38,
108
+ "Höri": 60,
109
+ "Rickenbach (ZH)": 225,
110
+ "Rafz": 67,
111
+ "Adliswil": 131,
112
+ "Zollikon": 161,
113
+ "Urdorf": 250,
114
+ "Hombrechtikon": 153,
115
+ "Birmensdorf (ZH)": 242,
116
+ "Fehraltorf": 172,
117
+ "Weiach": 102,
118
+ "Männedorf": 155,
119
+ "Küsnacht (ZH)": 154,
120
+ "Hausen am Albis": 4,
121
+ "Hochfelden": 59,
122
+ "Fällanden": 193,
123
+ "Greifensee": 194,
124
+ "Mönchaltorf": 196,
125
+ "Dägerlen": 214,
126
+ "Thalheim an der Thur": 39,
127
+ "Uetikon am See": 159,
128
+ "Seuzach": 227,
129
+ "Uitikon": 248,
130
+ "Affoltern am Albis": 2,
131
+ "Geroldswil": 244,
132
+ "Niederglatt": 89,
133
+ "Thalwil": 141,
134
+ "Rorbas": 68,
135
+ "Pfungen": 224,
136
+ "Weiningen (ZH)": 251,
137
+ "Bubikon": 112,
138
+ "Neftenbach": 223,
139
+ "Mettmenstetten": 9,
140
+ "Otelfingen": 94,
141
+ "Flurlingen": 29,
142
+ "Stadel": 100,
143
+ "Grüningen": 116,
144
+ "Henggart": 31,
145
+ "Dachsen": 25,
146
+ "Bonstetten": 3,
147
+ "Bachenbülach": 51,
148
+ "Horgen": 295
149
+ }
150
+ def predict_apartment(rooms, area, town):
151
+ bfs_number = locations[town]
152
+ df = df_bfs_data[df_bfs_data['bfs_number']==bfs_number].copy()
153
+ df.reset_index(inplace=True)
154
+ df.loc[0, 'rooms'] = rooms
155
+ df.loc[0, 'area'] = area
156
+ df.loc[0, 'proximity_to_public_transportation'] = 500 # Default value
157
+
158
+ if len(df) != 1:
159
+ return -1
160
+
161
+ features = ['rooms', 'area', 'pop', 'pop_dens', 'frg_pct', 'emp', 'tax_income', 'proximity_to_public_transportation']
162
+ X = df[features].values # Convert to numpy array without feature names
163
+ prediction = random_forest_model.predict(X)
164
+ return np.round(prediction[0], 0)
165
+
166
+ # Create the Gradio interface
167
+ iface = gr.Interface(
168
+ fn=predict_apartment,
169
+ inputs=["number", "number", gr.Dropdown(choices=locations.keys(), label="Town", type="value")],
170
+ outputs=[gr.Number()],
171
+ examples=[[4.5, 120, "Dietlikon"], [3.5, 60, "Winterthur"]]
172
+ )
173
+
174
+ iface.launch()
bfs_municipality_and_tax_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
random_forest_regression_extended.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b568b08ae4f77ccc8ca6410a8a29b1c36b0b59e9ffed320d810a6d48322613a7
3
+ size 1432809
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ scikit-learn == 1.6.1
2
+ matplotlib == 3.10.1
3
+ pandas == 2.2.3
4
+ numpy == 2.2.3
5
+ geopy