yongyeol committed (verified)
Commit 07cf72c · 1 Parent(s): 0836597

Update app.py

Files changed (1)
  1. app.py +12 -9
app.py CHANGED
@@ -64,10 +64,8 @@ except ModuleNotFoundError:
 # ─────────────────────────────────────────────────────────────
 caption_model = VisionEncoderDecoderModel.from_pretrained(
     "nlpconnect/vit-gpt2-image-captioning",
-    use_safetensors=True,
-    low_cpu_mem_usage=True   # fine to leave as-is
-)
-caption_model.to("cpu")      # ★ added
+    use_safetensors=True     # OK
+).eval()                     # switch to eval mode
 
 feature_extractor = ViTImageProcessor.from_pretrained(
     "nlpconnect/vit-gpt2-image-captioning"
@@ -76,6 +74,7 @@ tokenizer = AutoTokenizer.from_pretrained(
     "nlpconnect/vit-gpt2-image-captioning"
 )
 
+
 # ─────────────────────────────────────────────────────────────
 # 5. MusicGen model
 # ─────────────────────────────────────────────────────────────
@@ -86,11 +85,15 @@ musicgen.set_generation_params(duration=10)
 # 6. Pipeline functions
 # ─────────────────────────────────────────────────────────────
 def generate_caption(image: Image.Image) -> str:
-    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
-    caption_model.to(pixel_values.device)          # ★ safe device move
-    with torch.no_grad():
-        ids = caption_model.generate(pixel_values, max_length=50)
-    return tokenizer.decode(ids[0], skip_special_tokens=True)
+    with torch.no_grad():                          # ★ saves memory
+        pixel_values = feature_extractor(
+            images=image, return_tensors="pt"
+        ).pixel_values
+        output_ids = caption_model.generate(
+            pixel_values.to(caption_model.device), # keep inputs on the model's (CPU) device
+            max_length=50
+        )
+        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 def generate_music(prompt: str) -> str:
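
For context, below is a minimal usage sketch of the pipeline as it stands after this commit. It assumes app.py's generate_caption and generate_music are importable as defined in the diff; the image path "example.jpg" is a placeholder, not part of the repository.

# Minimal sketch (assumptions noted above)
from PIL import Image
from app import generate_caption, generate_music   # assumes app.py can be imported as a module

image = Image.open("example.jpg").convert("RGB")    # placeholder path
caption = generate_caption(image)                   # ViT-GPT2 captioning under torch.no_grad()
print(caption)
audio_path = generate_music(caption)                # MusicGen renders a 10-second clip from the caption
print(audio_path)

Design-wise, .eval() disables dropout and other training-only behaviour, torch.no_grad() skips gradient tracking during generation (saving memory on a CPU-only Space), and moving pixel_values to caption_model.device keeps inputs and weights on the same device regardless of where the model is loaded.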