Pavan2k4 commited on
Commit
c11e6fb
·
verified ·
1 Parent(s): 82666ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -34,19 +34,19 @@ if not HF_TOKEN:
34
  REPO_ID = "Pavan2k4/Building_area"
35
  REPO_TYPE = "space"
36
 
37
- # Define subdirectories using relative paths
38
- UPLOAD_DIR = "uploaded_images"
39
- MASK_DIR = "generated_masks"
40
- PATCHES_DIR = "patches"
41
- PRED_PATCHES_DIR = "pred_patches"
42
- CSV_LOG_PATH = "image_log.csv"
43
-
44
- # Create directories
45
- for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
46
  os.makedirs(directory, exist_ok=True)
47
 
48
- # Load model
49
- @st.cache_resource
50
  def load_model():
51
  model = reunet_cbam()
52
  model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
@@ -111,7 +111,8 @@ def split(image_path, patch_size=512):
111
  patch_filename = f"patch_{i}_{j}.png"
112
  patch_path = os.path.join(PATCHES_DIR, patch_filename)
113
  patch.save(patch_path)
114
- st.write(f"Saved patch: {patch_path}") # Debug output
 
115
 
116
  def upload_page():
117
  if 'file_uploaded' not in st.session_state:
 
34
  REPO_ID = "Pavan2k4/Building_area"
35
  REPO_TYPE = "space"
36
 
37
+ # Define temporary directories outside of tracked directories
38
+ TMP_DIR = "/tmp/space_files"
39
+ PATCHES_DIR = os.path.join(TMP_DIR, "patches")
40
+ PRED_PATCHES_DIR = os.path.join(TMP_DIR, "pred_patches")
41
+ MASK_DIR = os.path.join(TMP_DIR, "generated_masks")
42
+
43
+ # Create temporary directories
44
+ os.makedirs(TMP_DIR, exist_ok=True)
45
+ for directory in [PATCHES_DIR, PRED_PATCHES_DIR, MASK_DIR]:
46
  os.makedirs(directory, exist_ok=True)
47
 
48
+ # Load model and cache it to avoid unnecessary reloading
49
+ @st.cache_resource(show_spinner=False)
50
  def load_model():
51
  model = reunet_cbam()
52
  model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
 
111
  patch_filename = f"patch_{i}_{j}.png"
112
  patch_path = os.path.join(PATCHES_DIR, patch_filename)
113
  patch.save(patch_path)
114
+ st.write(f"Saved patch: {patch_path}")
115
+
116
 
117
  def upload_page():
118
  if 'file_uploaded' not in st.session_state: