cptsubtext committed
Commit b8a909c · 1 Parent(s): 1640c7d

update files with app

Files changed (2)
  1. requirements.txt +4 -3
  2. src/streamlit_app.py +141 -38
requirements.txt CHANGED
@@ -1,3 +1,4 @@
- altair
- pandas
- streamlit
+ streamlit
+ Pillow
+ transformers
+ torch
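Note that the new dependency list is unpinned, so a rebuild of the Space may pull whatever versions are current. A minimal pinned sketch (the version numbers are illustrative assumptions, not part of the commit; pin to whatever combination you have tested):

    streamlit==1.32.0
    Pillow==10.3.0
    transformers==4.40.0
    torch==2.2.2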
src/streamlit_app.py CHANGED
@@ -1,40 +1,143 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ import os
+ import json
+ from PIL import Image
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
+ import torch
+ import base64
+ from io import BytesIO
+
+ # Configuration (similar to aim.py, but adapted for Streamlit)
+ DEFAULT_KEYWORD_COUNT = 5
+ DEFAULT_MODEL = "llava-hf/llava-1.5-7b-hf"  # A common Llava model on Hugging Face
+ DEFAULT_TONE = "witty,curious"
+ DEFAULT_TEMP = 0.5
+
+ # Convert a PIL Image to base64 for display
+ def convert_to_base64(pil_image):
+     buffered = BytesIO()
+     pil_image.save(buffered, format="JPEG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+ # Extract keywords from a "Keywords: ..." line (from aim.py)
+ def extract_keywords(keywords_string):
+     if keywords_string.startswith("Keywords: "):
+         keywords = keywords_string.replace("Keywords: ", "").strip().split(",")
+         return [keyword.strip() for keyword in keywords]
+     else:
+         return []
+
+ # Load the Llava processor and model once; st.cache_resource keeps them
+ # in memory across Streamlit reruns instead of reloading on every click.
+ @st.cache_resource
+ def load_llava_model(model_name):
+     processor = AutoProcessor.from_pretrained(model_name)
+     model = LlavaForConditionalGeneration.from_pretrained(model_name)
+     return processor, model
+
+ def generate_metadata(image, prompt_template, model_name, temperature):
+     processor, model = load_llava_model(model_name)
+
+     # Prepare image and prompt for the model. Llava models typically take a
+     # conversation-like prompt, e.g. "USER: <image>\nWhat is this?\nASSISTANT:".
+     # For now we pass the image and the direct prompt; a more robust solution
+     # would build the prompt with a chat template from the processor.
+     inputs = processor(text=prompt_template, images=image, return_tensors="pt")
+
+     # Generate the response (sampling requires temperature > 0)
+     with torch.no_grad():
+         output = model.generate(**inputs, max_new_tokens=100, temperature=temperature, do_sample=True, top_p=0.9)
+
+     generated_text = processor.decode(output[0], skip_special_tokens=True)
+
+     # The generated text contains the prompt followed by the model's answer,
+     # so keep only the part after the prompt. This is a simplification and
+     # may need adjustment for a given model's exact output format.
+     if prompt_template in generated_text:
+         model_response = generated_text.split(prompt_template)[-1].strip()
+     else:
+         model_response = generated_text  # Fallback if the prompt is not a prefix
+
+     return model_response
+
+ # Streamlit App
+ st.set_page_config(layout="wide", page_title="Image Metadata Generator")
+
+ st.title("📸 AI-Powered Image Metadata Generator")
+ st.markdown("Upload an image and let the AI generate a catchy title, description, and keywords!")
+
+ # Sidebar for configuration
+ st.sidebar.header("Configuration")
+ selected_model = st.sidebar.selectbox(
+     "Choose a Llava Model",
+     ["llava-hf/llava-1.5-7b-hf", "llava-hf/bakLlava-v1-hf"],  # Add more Llava models as needed
+     index=0
+ )
+ temperature = st.sidebar.slider("Creativity (Temperature)", 0.0, 1.0, DEFAULT_TEMP, 0.05)
+ keyword_count = st.sidebar.number_input("Number of Keywords", 1, 10, DEFAULT_KEYWORD_COUNT)
+ tone_input = st.sidebar.text_input("Tone (e.g., witty, curious)", DEFAULT_TONE)
+ tone = [t.strip() for t in tone_input.split(',')]
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file).convert("RGB")
+
+     st.subheader("Uploaded Image")
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     if st.button("Generate Metadata"):
+         with st.spinner("Generating metadata... This might take a moment."):
+             prompt_template = f"""
+ As a photojournalist, analyze the following image and provide the following in a {tone[0]} and {tone[1] if len(tone) > 1 else tone[0]} tone:
+ - Image Headline: A short, impactful title
+ - Image Description: A brief, informative summary
+ - {keyword_count} Image Keywords, separated by commas
+ - Return the Image Headline, Image Description, and Image Keywords in the following format: Headline: ..., Description: ..., Keywords: ...
+ """
+
+             # Generate metadata using the selected Llava model
+             llava_response = generate_metadata(image, prompt_template, selected_model, temperature)
+
+             if llava_response:
+                 st.subheader("Generated Metadata")
+
+                 # Parse the response line by line, as in aim.py
+                 lines = llava_response.split('\n')
+
+                 headline = ""
+                 description = ""
+                 keywords = []
+
+                 for line in lines:
+                     if line.startswith("Headline:"):
+                         headline = line.replace("Headline:", "").strip()
+                     elif line.startswith("Description:"):
+                         description = line.replace("Description:", "").strip()
+                     elif line.startswith("Keywords:"):
+                         keywords = extract_keywords(line)
+
+                 # Remove stray quotation marks
+                 headline = headline.strip('"')
+                 description = description.strip('"')
+                 lstkeywords = [x.strip('"') for x in keywords]
+
+                 st.info(f"**Headline:** {headline}")
+                 st.info(f"**Description:** {description}")
+                 st.info(f"**Keywords:** {', '.join(lstkeywords)}")
+             else:
+                 st.error("Failed to generate metadata. Please try again.")
+
+ st.markdown("""
+ ---
+ *This app uses Hugging Face's Transformers library and Llava models to generate image metadata.
+ The quality of the generated metadata depends on the chosen model and the complexity of the image.*
+ """)