Dhan98 commited on
Commit
07e5984
·
verified ·
1 Parent(s): 99b4d49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -28
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  from PIL import Image
8
  import tempfile
9
  import os
 
10
 
11
  @st.cache_resource
12
  def load_models():
@@ -22,44 +23,32 @@ def load_models():
22
 
23
  return pipeline, blip, blip_processor
24
 
25
- def enhance_image(image):
26
- img = np.array(image)
27
- denoised = cv2.fastNlMeansDenoisingColored(img)
28
- lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
29
- l, a, b = cv2.split(lab)
30
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
31
- l = clahe.apply(l)
32
- enhanced = cv2.cvtColor(cv2.merge([l,a,b]), cv2.COLOR_LAB2RGB)
33
- return Image.fromarray(enhanced)
34
-
35
- def get_description(image, blip_model, blip_processor):
36
- inputs = blip_processor(images=image, return_tensors="pt")
37
- output = blip_model.generate(**inputs, max_length=50)
38
- return blip_processor.decode(output[0], skip_special_tokens=True)
39
-
40
- def save_video_frames(frames, fps=8):
41
  temp_dir = tempfile.mkdtemp()
42
  temp_path = os.path.join(temp_dir, "output.mp4")
43
 
44
- height, width = frames[0].shape[:2]
45
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
46
- video_writer = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
47
 
48
- for frame in frames:
49
  video_writer.write(frame)
50
  video_writer.release()
51
 
52
  return temp_path
53
 
54
- def generate_video(pipeline, description):
55
- video_frames = pipeline(
56
- description,
57
- num_inference_steps=50,
58
- num_frames=24
59
- ).frames
60
-
61
- video_path = save_video_frames(video_frames)
62
- return video_path
63
 
64
  def main():
65
  st.title("Video Generator")
@@ -80,6 +69,7 @@ def main():
80
  with st.spinner("Generating video..."):
81
  video_path = generate_video(pipeline, description)
82
  st.video(video_path)
 
83
 
84
  if __name__ == "__main__":
85
  main()
 
7
  from PIL import Image
8
  import tempfile
9
  import os
10
+ import base64
11
 
12
  @st.cache_resource
13
  def load_models():
 
23
 
24
  return pipeline, blip, blip_processor
25
 
26
+ def generate_video(pipeline, description):
27
+ video_frames = pipeline(
28
+ description,
29
+ num_inference_steps=30, # Reduced from 50
30
+ num_frames=16 # Reduced from 24
31
+ ).frames
32
+
 
 
 
 
 
 
 
 
 
33
  temp_dir = tempfile.mkdtemp()
34
  temp_path = os.path.join(temp_dir, "output.mp4")
35
 
36
+ height, width = video_frames[0].shape[:2]
37
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
38
+ video_writer = cv2.VideoWriter(temp_path, fourcc, 8, (width, height))
39
 
40
+ for frame in video_frames:
41
  video_writer.write(frame)
42
  video_writer.release()
43
 
44
  return temp_path
45
 
46
+ def get_binary_file_downloader_html(bin_file, file_label='File'):
47
+ with open(bin_file, 'rb') as f:
48
+ data = f.read()
49
+ bin_str = base64.b64encode(data).decode()
50
+ href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label}</a>'
51
+ return href
 
 
 
52
 
53
  def main():
54
  st.title("Video Generator")
 
69
  with st.spinner("Generating video..."):
70
  video_path = generate_video(pipeline, description)
71
  st.video(video_path)
72
+ st.markdown(get_binary_file_downloader_html(video_path, 'video.mp4'), unsafe_allow_html=True)
73
 
74
  if __name__ == "__main__":
75
  main()