openfree commited on
Commit
3e75e0e
ยท
verified ยท
1 Parent(s): 0d45b9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -3
app.py CHANGED
@@ -1,12 +1,41 @@
1
  import os
 
 
2
 
3
- # Set this environment variable to disable torch.compiler features
4
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
5
  os.environ["TRANSFORMERS_COMPILER_DISABLED"] = "1"
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import yaml
8
  import torch
9
- import sys
10
  sys.path.append(os.path.abspath('./'))
11
  from inference.utils import *
12
  from train import WurstCoreB
@@ -18,7 +47,6 @@ import argparse
18
  import gradio as gr
19
  import spaces
20
  from huggingface_hub import hf_hub_url
21
- import subprocess
22
  from huggingface_hub import hf_hub_download
23
  from transformers import pipeline
24
 
 
1
  import os
2
+ import subprocess
3
+ import sys
4
 
5
+ # ํ•„์š”ํ•œ ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
6
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
7
  os.environ["TRANSFORMERS_COMPILER_DISABLED"] = "1"
8
 
9
+ # ํ•„์š”ํ•œ ๋ชจ๋“ˆ ์„ค์น˜ ํ•จ์ˆ˜
10
+ def install_required_packages():
11
+ required_packages = [
12
+ "warmup_scheduler",
13
+ "cosine_annealing_warmup_restarts"
14
+ ]
15
+
16
+ for package in required_packages:
17
+ try:
18
+ __import__(package)
19
+ print(f"{package} is already installed")
20
+ except ImportError:
21
+ print(f"Installing {package}...")
22
+ try:
23
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
24
+ print(f"{package} installed successfully")
25
+ except subprocess.CalledProcessError:
26
+ # ์ผ๋ถ€ ํŒจํ‚ค์ง€๋Š” PyPI์— ์—†์„ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ GitHub์—์„œ ์ง์ ‘ ์„ค์น˜
27
+ if package == "warmup_scheduler":
28
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git"])
29
+ print(f"{package} installed from GitHub successfully")
30
+ else:
31
+ print(f"Failed to install {package}")
32
+
33
+ # ํ•„์š”ํ•œ ๋ชจ๋“ˆ ์„ค์น˜
34
+ install_required_packages()
35
+
36
+ # ๊ทธ ํ›„ ๋‚˜๋จธ์ง€ imports
37
  import yaml
38
  import torch
 
39
  sys.path.append(os.path.abspath('./'))
40
  from inference.utils import *
41
  from train import WurstCoreB
 
47
  import gradio as gr
48
  import spaces
49
  from huggingface_hub import hf_hub_url
 
50
  from huggingface_hub import hf_hub_download
51
  from transformers import pipeline
52