samaonline commited on
Commit
56dc12e
·
1 Parent(s): a2e060c

add gpu annotation

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -4,12 +4,14 @@ import sys
4
 
5
  import gradio as gr
6
  import numpy as np
 
7
 
8
- # import spaces
9
  # from huggingface_hub import hf_hub_download
10
  from huggingface_hub import snapshot_download
11
  from PIL import Image, ImageDraw, ImageFont
12
 
 
 
13
  # Set the working directory to the root directory
14
  # root_dir = os.path.abspath("..")
15
  # os.chdir(root_dir)
@@ -137,6 +139,7 @@ def viz(x, cmap="gray", vmin=0, vmax=1):
137
  return img_array
138
 
139
 
 
140
  def forward(model, idx, rate):
141
  if rate == 4:
142
  dataset = val_dataset_4x
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
+ import spaces
8
 
 
9
  # from huggingface_hub import hf_hub_download
10
  from huggingface_hub import snapshot_download
11
  from PIL import Image, ImageDraw, ImageFont
12
 
13
+ zero = torch.Tensor([0]).cuda()
14
+
15
  # Set the working directory to the root directory
16
  # root_dir = os.path.abspath("..")
17
  # os.chdir(root_dir)
 
139
  return img_array
140
 
141
 
142
+ @spaces.GPU
143
  def forward(model, idx, rate):
144
  if rate == 4:
145
  dataset = val_dataset_4x