jamino30 commited on
Commit
6421088
·
verified ·
1 Parent(s): 28ac920

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
5
 
6
  import spaces
7
  import torch
 
8
  import torchvision.models as models
9
  import numpy as np
10
  import gradio as gr
@@ -53,12 +54,19 @@ for style_name, style_img_path in style_options.items():
53
  cached_style_features[style_name] = style_features
54
 
55
  @spaces.GPU(duration=30)
56
- def run(content_image, style_name, style_strength=10):
57
  yield [None] * 3
58
  content_img, original_size = preprocess_img(content_image, img_size)
59
  content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
60
  content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
61
 
 
 
 
 
 
 
 
62
  print('-'*15)
63
  print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
64
  print('STYLE:', style_name)
@@ -85,7 +93,8 @@ def run(content_image, style_name, style_strength=10):
85
  content_image_norm=content_img_normalized,
86
  style_features=style_features,
87
  lr=lrs[style_strength-1],
88
- apply_to_background=apply_to_background
 
89
  )
90
 
91
  with ThreadPoolExecutor() as executor:
@@ -124,6 +133,8 @@ with gr.Blocks(css=css) as demo:
124
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
125
  with gr.Group():
126
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
 
 
127
  submit_button = gr.Button('Submit', variant='primary')
128
 
129
  examples = gr.Examples(
 
5
 
6
  import spaces
7
  import torch
8
+ import torch.optim as optim
9
  import torchvision.models as models
10
  import numpy as np
11
  import gradio as gr
 
54
  cached_style_features[style_name] = style_features
55
 
56
  @spaces.GPU(duration=30)
57
+ def run(content_image, style_name, style_strength=10, optim_name='AdamW'):
58
  yield [None] * 3
59
  content_img, original_size = preprocess_img(content_image, img_size)
60
  content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
61
  content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
62
 
63
+ if optim_name == 'Adam':
64
+ optim_caller = torch.optim.Adam
65
+ elif optim_name == 'AdamW':
66
+ optim_caller = torch.optim.AdamW
67
+ else:
68
+ optim_caller = torch.optim.LBFGS
69
+
70
  print('-'*15)
71
  print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
72
  print('STYLE:', style_name)
 
93
  content_image_norm=content_img_normalized,
94
  style_features=style_features,
95
  lr=lrs[style_strength-1],
96
+ apply_to_background=apply_to_background,
97
+ optim_caller=optim_caller
98
  )
99
 
100
  with ThreadPoolExecutor() as executor:
 
133
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
134
  with gr.Group():
135
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
136
+ with gr.Accordion(label='Advanced Options', open=False):
137
+ optim_dropdown = gr.Radio(choices=['Adam', 'AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
138
  submit_button = gr.Button('Submit', variant='primary')
139
 
140
  examples = gr.Examples(