yongyeol commited on
Commit
d7b41a8
Β·
verified Β·
1 Parent(s): faca888

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -88,17 +88,13 @@ musicgen.set_generation_params(duration=10)
88
  # 6. νŒŒμ΄ν”„λΌμΈ ν•¨μˆ˜
89
  # ─────────────────────────────────────────────────────────────
90
  def generate_caption(image: Image.Image) -> str:
91
- with torch.no_grad(): # β˜… λ©”λͺ¨λ¦¬ μ ˆμ•½
92
- pixel_values = feature_extractor(
93
- images=image, return_tensors="pt"
94
- ).pixel_values
95
- output_ids = caption_model.generate(
96
- pixel_values.to(caption_model.device), # CPU λ””λ°”μ΄μŠ€ 톡일
97
- max_length=50
98
- )
99
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
100
 
101
 
 
102
  def generate_music(prompt: str) -> str:
103
  wav = musicgen.generate([prompt]) # batch size = 1
104
  tmpdir = tempfile.mkdtemp()
 
88
  # 6. νŒŒμ΄ν”„λΌμΈ ν•¨μˆ˜
89
  # ─────────────────────────────────────────────────────────────
90
  def generate_caption(image: Image.Image) -> str:
91
+ with torch.no_grad():
92
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
93
+ output_ids = caption_model.generate(pixel_values, max_length=50)
 
 
 
 
 
94
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
95
 
96
 
97
+
98
  def generate_music(prompt: str) -> str:
99
  wav = musicgen.generate([prompt]) # batch size = 1
100
  tmpdir = tempfile.mkdtemp()