VinitT committed
Commit fbe9130 · verified · 1 Parent(s): 8563804

Update app.py

Files changed (1): app.py (+112 -78)
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
 import cv2
@@ -9,89 +9,123 @@ import tempfile
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
+# Load Meta-Llama model and tokenizer for story generation
+llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+
 # Check if CUDA is available and set the device accordingly
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
 model.to(device)
+llama_model.to(device)
 
 # Streamlit app
 st.title("Media Description Generator")
 
-uploaded_file = st.file_uploader("Choose an image or video...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"])
-
-if uploaded_file is not None:
-    file_type = uploaded_file.type.split('/')[0]
-
-    if file_type == 'image':
-        # Open the image
-        image = Image.open(uploaded_file)
-        st.image(image, caption='Uploaded Image.', use_column_width=True)
-        st.write("Generating description...")
-
-    elif file_type == 'video':
-        # Save the uploaded video to a temporary file
-        tfile = tempfile.NamedTemporaryFile(delete=False)
-        tfile.write(uploaded_file.read())
-
-        # Open the video file
-        cap = cv2.VideoCapture(tfile.name)
-
-        # Extract the first frame
-        ret, frame = cap.read()
-        if not ret:
-            st.error("Failed to read the video file.")
-            st.stop()
-        else:
-            # Convert the frame to an image
-            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
-            st.write("Generating description...")
-
-        # Release the video capture object
-        cap.release()
-
-    else:
-        st.error("Unsupported file type.")
-        st.stop()
+uploaded_files = st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)
 
-# Add a text input for the user to ask a question
-user_question = st.text_input("Ask a question about the image or video:")
+if uploaded_files:
+    user_question = st.text_input("Ask a question about the images or videos:")
 
 if user_question:
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image,
-                },
-                {"type": "text", "text": user_question},
-            ],
-        }
-    ]
-
-    # Preparation for inference
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Pass the image to the processor
-    inputs = processor(
-        text=[text],
-        images=[image],
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to(device)  # Ensure inputs are on the same device as the model
-
-    # Inference: Generation of the output
-    generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-
-    st.write("Description:")
-    st.write(output_text[0])
+    all_output_texts = []  # Initialize an empty list to store all output texts
+
+    for uploaded_file in uploaded_files:
+        file_type = uploaded_file.type.split('/')[0]
+
+        if file_type == 'image':
+            # Open the image
+            image = Image.open(uploaded_file)
+            # Resize image to reduce memory usage
+            image = image.resize((512, 512))
+            st.image(image, caption='Uploaded Image.', use_column_width=True)
+            st.write("Generating description...")
+
+        elif file_type == 'video':
+            # Save the uploaded video to a temporary file
+            tfile = tempfile.NamedTemporaryFile(delete=False)
+            tfile.write(uploaded_file.read())
+
+            # Open the video file
+            cap = cv2.VideoCapture(tfile.name)
+
+            # Extract the first frame
+            ret, frame = cap.read()
+            if not ret:
+                st.error("Failed to read the video file.")
+                continue
+            else:
+                # Convert the frame to an image
+                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+                # Resize image to reduce memory usage
+                image = image.resize((512, 512))
+                st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
+                st.write("Generating description...")
+
+            # Release the video capture object
+            cap.release()
+
+        else:
+            st.error("Unsupported file type.")
+            continue
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image,
+                    },
+                    {"type": "text", "text": user_question},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        # Pass the image to the processor
+        inputs = processor(
+            text=[text],
+            images=[image],
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(device)  # Ensure inputs are on the same device as the model
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+        st.write("Description:")
+        st.write(output_text[0])
+
+        # Append the output text to the list
+        all_output_texts.append(output_text[0])
+
+        # Clear memory after processing each file
+        del image, inputs, generated_ids, generated_ids_trimmed, output_text
+        torch.cuda.empty_cache()
+        torch.manual_seed(0)  # Reset the seed to ensure reproducibility
+
+    # Combine all descriptions into a single text
+    combined_text = " ".join(all_output_texts)
+
+    # Create a custom prompt
+    custom_prompt = f"Based on the following descriptions, create a short story:\n\n{combined_text}\n\nStory:"
+
+    # Generate a story using Meta-Llama
+    inputs = llama_tokenizer.encode(custom_prompt, return_tensors="pt").to(device)
+    story_ids = llama_model.generate(inputs, max_length=500, num_return_sequences=1)
+    story = llama_tokenizer.decode(story_ids[0], skip_special_tokens=True)
+
+    # Display the generated story
+    st.write("Generated Story:")
+    st.write(story)
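
For readers skimming the diff: the commit turns app.py into a two-stage pipeline, where Qwen2-VL produces one description per uploaded file and Meta-Llama-3.1-8B-Instruct turns the concatenated descriptions into a short story. Below is a minimal, Streamlit-free sketch of that flow, not part of app.py. Assumptions: the same two checkpoints are used, the meta-llama repository is gated and typically requires an accepted license plus an authenticated Hugging Face login, and describe(), tell_story(), and the image paths are illustrative names that do not exist in the app.

# Minimal sketch of the commit's two-stage flow (illustrative, not app.py itself).
# Assumes both checkpoints are available; Meta-Llama-3.1-8B-Instruct is gated on the Hub.
from PIL import Image
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
)

device = torch.device("cpu")

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
vlm = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct").to(device)

llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
llama = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct").to(device)


def describe(image: Image.Image, question: str) -> str:
    """Stage 1: one Qwen2-VL description per image (illustrative helper)."""
    messages = [{"role": "user",
                 "content": [{"type": "image", "image": image},
                             {"type": "text", "text": question}]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], padding=True, return_tensors="pt").to(device)
    out = vlm.generate(**inputs, max_new_tokens=128)
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, out)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]


def tell_story(descriptions: list[str]) -> str:
    """Stage 2: combine the descriptions and ask the Llama model for a story."""
    prompt = ("Based on the following descriptions, create a short story:\n\n"
              + " ".join(descriptions) + "\n\nStory:")
    ids = llama_tokenizer.encode(prompt, return_tensors="pt").to(device)
    story_ids = llama.generate(ids, max_new_tokens=400)
    return llama_tokenizer.decode(story_ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    # Example paths only; the real app reads Streamlit uploads instead.
    images = [Image.open(p).resize((512, 512)) for p in ["a.jpg", "b.jpg"]]
    captions = [describe(im, "Describe this image.") for im in images]
    print(tell_story(captions))

In app.py the same flow additionally resizes each image or video frame to 512x512 and frees the per-file tensors (del ... and torch.cuda.empty_cache()) after every iteration, which keeps peak memory bounded when several files are uploaded at once.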