comrender commited on
Commit
c84d7da
Β·
verified Β·
1 Parent(s): 46b008e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -33,13 +33,9 @@ css = """
33
  }
34
  """
35
 
36
- # Device setup
37
- if torch.cuda.is_available():
38
- power_device = "GPU"
39
- device = "cuda"
40
- else:
41
- power_device = "CPU"
42
- device = "cpu"
43
 
44
  # Get HuggingFace token
45
  huggingface_token = os.getenv("HF_TOKEN")
@@ -54,26 +50,25 @@ model_path = snapshot_download(
54
  token=huggingface_token,
55
  )
56
 
57
- # Load Florence-2 model for image captioning
58
  print("πŸ“₯ Loading Florence-2 model...")
59
  florence_model = AutoModelForCausalLM.from_pretrained(
60
  "microsoft/Florence-2-large",
61
- torch_dtype=torch.float16,
62
  trust_remote_code=True,
63
- attn_implementation="eager" # Fix for SDPA compatibility issue
64
  ).to(device)
65
  florence_processor = AutoProcessor.from_pretrained(
66
  "microsoft/Florence-2-large",
67
  trust_remote_code=True
68
  )
69
 
70
- # Load FLUX Img2Img pipeline
71
  print("πŸ“₯ Loading FLUX Img2Img...")
72
  pipe = FluxImg2ImgPipeline.from_pretrained(
73
  model_path,
74
- torch_dtype=torch.bfloat16
75
  )
76
- pipe.to(device)
77
  pipe.enable_vae_tiling()
78
  pipe.enable_vae_slicing()
79
 
@@ -90,7 +85,6 @@ if USE_ESRGAN:
90
  state_dict = torch.load(esrgan_path)['params_ema']
91
  esrgan_model.load_state_dict(state_dict)
92
  esrgan_model.eval()
93
- esrgan_model.to(device)
94
 
95
  MAX_SEED = 1000000
96
  MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
@@ -227,6 +221,10 @@ def enhance_image(
227
  progress=gr.Progress(track_tqdm=True),
228
  ):
229
  """Main enhancement function"""
 
 
 
 
230
  # Handle image input
231
  if image_input is not None:
232
  input_image = image_input
@@ -253,13 +251,15 @@ def enhance_image(
253
  else:
254
  prompt = custom_prompt if custom_prompt.strip() else ""
255
 
256
- generator = torch.Generator().manual_seed(seed)
257
 
258
  gr.Info("πŸš€ Upscaling image...")
259
 
260
  # Initial upscale
261
  if USE_ESRGAN and upscale_factor == 4:
 
262
  control_image = esrgan_upscale(input_image, upscale_factor)
 
263
  else:
264
  w, h = input_image.size
265
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
@@ -284,6 +284,10 @@ def enhance_image(
284
  # Resize input image to match output size for slider alignment
285
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
286
 
 
 
 
 
287
  return [resized_input, image]
288
 
289
 
 
33
  }
34
  """
35
 
36
+ # Device setup - Force CPU for startup in ZeroGPU
37
+ power_device = "ZeroGPU"
38
+ device = "cpu"
 
 
 
 
39
 
40
  # Get HuggingFace token
41
  huggingface_token = os.getenv("HF_TOKEN")
 
50
  token=huggingface_token,
51
  )
52
 
53
+ # Load Florence-2 model for image captioning on CPU
54
  print("πŸ“₯ Loading Florence-2 model...")
55
  florence_model = AutoModelForCausalLM.from_pretrained(
56
  "microsoft/Florence-2-large",
57
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
58
  trust_remote_code=True,
59
+ attn_implementation="eager"
60
  ).to(device)
61
  florence_processor = AutoProcessor.from_pretrained(
62
  "microsoft/Florence-2-large",
63
  trust_remote_code=True
64
  )
65
 
66
+ # Load FLUX Img2Img pipeline on CPU
67
  print("πŸ“₯ Loading FLUX Img2Img...")
68
  pipe = FluxImg2ImgPipeline.from_pretrained(
69
  model_path,
70
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
71
  )
 
72
  pipe.enable_vae_tiling()
73
  pipe.enable_vae_slicing()
74
 
 
85
  state_dict = torch.load(esrgan_path)['params_ema']
86
  esrgan_model.load_state_dict(state_dict)
87
  esrgan_model.eval()
 
88
 
89
  MAX_SEED = 1000000
90
  MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
 
221
  progress=gr.Progress(track_tqdm=True),
222
  ):
223
  """Main enhancement function"""
224
+ # Move models to GPU inside the function
225
+ pipe.to("cuda")
226
+ florence_model.to("cuda")
227
+
228
  # Handle image input
229
  if image_input is not None:
230
  input_image = image_input
 
251
  else:
252
  prompt = custom_prompt if custom_prompt.strip() else ""
253
 
254
+ generator = torch.Generator(device="cuda").manual_seed(seed)
255
 
256
  gr.Info("πŸš€ Upscaling image...")
257
 
258
  # Initial upscale
259
  if USE_ESRGAN and upscale_factor == 4:
260
+ esrgan_model.to("cuda")
261
  control_image = esrgan_upscale(input_image, upscale_factor)
262
+ esrgan_model.to("cpu")
263
  else:
264
  w, h = input_image.size
265
  control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
 
284
  # Resize input image to match output size for slider alignment
285
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
286
 
287
+ # Move back to CPU to release GPU
288
+ pipe.to("cpu")
289
+ florence_model.to("cpu")
290
+
291
  return [resized_input, image]
292
 
293