BillyZ1129 committed on
Commit
7d6fb10
·
verified ·
1 Parent(s): 901fa0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -4
app.py CHANGED
@@ -7,6 +7,7 @@ from gtts import gTTS
7
  import tempfile
8
  import os
9
  import base64
 
10
 
11
  # Set page config
12
  st.set_page_config(
@@ -39,10 +40,35 @@ def load_story_generator():
39
 
40
  # Function to generate caption from image
41
  def generate_caption(image, processor, model):
42
- inputs = processor(image, return_tensors="pt")
43
- out = model.generate(**inputs, max_length=30)
44
- caption = processor.decode(out[0], skip_special_tokens=True)
45
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Function to generate story from caption
48
  def generate_story(caption, generator):
@@ -142,6 +168,8 @@ def main():
142
  # Generate caption and story
143
  with st.spinner("Looking at your picture and thinking of a story..."):
144
  caption = generate_caption(original_image, processor, caption_model)
 
 
145
  story = generate_story(caption, story_generator)
146
 
147
  # Display the story and audio
 
7
  import tempfile
8
  import os
9
  import base64
10
+ import numpy as np
11
 
12
  # Set page config
13
  st.set_page_config(
 
40
 
41
  # Function to generate caption from image
42
  def generate_caption(image, processor, model):
43
+ # 确保图像是RGB格式
44
+ if image.mode != 'RGB':
45
+ image = image.convert('RGB')
46
+
47
+ # 标准预处理:调整大小到BLIP模型期望的输入尺寸
48
+ image = image.resize((384, 384))
49
+
50
+ try:
51
+ # 使用处理器准备图像
52
+ inputs = processor(image, return_tensors="pt", padding=True)
53
+
54
+ # 生成caption
55
+ out = model.generate(**inputs, max_length=30)
56
+ caption = processor.decode(out[0], skip_special_tokens=True)
57
+ return caption
58
+ except Exception as e:
59
+ # 如果有错误,使用一个备用方法
60
+ st.warning(f"Caption generation error: {str(e)}. Using fallback method.")
61
+
62
+ # 转换图像为numpy数组
63
+ img_array = np.array(image)
64
+
65
+ # 手动准备图像为模型输入
66
+ pixel_values = processor.image_processor(images=img_array, return_tensors="pt").pixel_values
67
+
68
+ # 生成caption
69
+ generated_ids = model.generate(pixel_values=pixel_values, max_length=30)
70
+ caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
71
+ return caption
72
 
73
  # Function to generate story from caption
74
  def generate_story(caption, generator):
 
168
  # Generate caption and story
169
  with st.spinner("Looking at your picture and thinking of a story..."):
170
  caption = generate_caption(original_image, processor, caption_model)
171
+ # 打印图片描述
172
+ st.info(f"Image caption: {caption}")
173
  story = generate_story(caption, story_generator)
174
 
175
  # Display the story and audio