Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -288,13 +288,19 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
|
|
288 |
torch.cuda.empty_cache()
|
289 |
|
290 |
# Load model
|
291 |
-
|
292 |
-
|
293 |
-
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
|
294 |
print("Loading model...")
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
|
|
298 |
# Gradio UI
|
299 |
with gr.Blocks(theme=gr.themes.Soft(
|
300 |
primary_hue="violet",
|
|
|
288 |
torch.cuda.empty_cache()
|
289 |
|
290 |
# Load model
|
291 |
+
MODEL_NAME = "model_weights.pth" # Updated to look in root folder
|
292 |
+
model_path = MODEL_NAME
|
|
|
293 |
print("Loading model...")
|
294 |
+
try:
|
295 |
+
loaded_model = load_model(model_path, device)
|
296 |
+
print("Model loaded successfully!")
|
297 |
+
except Exception as e:
|
298 |
+
print(f"Failed to load model: {e}")
|
299 |
+
# Create a dummy model for demo purposes
|
300 |
+
print("Creating dummy model for demonstration")
|
301 |
+
loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
|
302 |
|
303 |
+
|
304 |
# Gradio UI
|
305 |
with gr.Blocks(theme=gr.themes.Soft(
|
306 |
primary_hue="violet",
|