suzall commited on
Commit
fcf3deb
·
verified ·
1 Parent(s): 210da86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -40
app.py CHANGED
@@ -5,47 +5,46 @@ import base64
5
  from PIL import Image
6
  import os
7
 
8
- # Get Hugging Face API key from environment variable
9
  hf_api_key = os.getenv("HUGGINGFACE_TOKEN")
 
10
 
11
  # API URLs
12
  ANALYSIS_API_URL = "https://api-inference.huggingface.co/models/dandelin/vilt-b32-finetuned-vqa"
13
  GENERATION_API_URL = "https://api-inference.huggingface.co/models/thejagstudio/3d-animation-style-sdxl"
14
 
15
- # Set headers with authorization
16
- headers = {"Authorization": f"Bearer {hf_api_key}"}
17
-
18
  # Function to query analysis (VQA) model
19
  def query_analysis(image_bytes, question):
 
 
 
20
  try:
21
- payload = {
22
- "inputs": {
23
- "question": question,
24
- "image": base64.b64encode(image_bytes).decode('utf-8')
25
- }
26
- }
27
  response = requests.post(ANALYSIS_API_URL, headers=headers, json=payload)
28
- response.raise_for_status() # Raise an exception for HTTP errors
29
- if response.json():
30
- return response.json()[0].get('answer', 'unspecified')
31
  except Exception as e:
32
- st.error(f"Error analyzing image: {e}")
33
  return 'unspecified'
34
 
35
  # Function to query image generation model
36
  def query_generation(prompt, image_bytes):
 
 
 
 
37
  try:
38
- payload = {
39
- "inputs": prompt,
40
- "image": base64.b64encode(image_bytes).decode('utf-8')
41
- }
42
  response = requests.post(GENERATION_API_URL, headers=headers, json=payload)
43
- response.raise_for_status() # Raise an exception for HTTP errors
44
  return response.content
45
  except Exception as e:
46
- st.error(f"Error generating image: {e}")
47
  return None
48
 
 
 
 
 
 
49
  # Streamlit app title
50
  st.title("Image Insight & Generation Studio 👻")
51
 
@@ -53,12 +52,12 @@ st.title("Image Insight & Generation Studio 👻")
53
  uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
54
 
55
  if uploaded_file is not None:
56
- # Display uploaded image
57
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
58
-
59
  # Read image bytes
60
  image_bytes = uploaded_file.read()
61
 
 
 
 
62
  # Text input for additional description for image generation
63
  user_prompt = st.text_input("Enter additional description for image generation (optional):")
64
 
@@ -69,32 +68,56 @@ if uploaded_file is not None:
69
  clothing = query_analysis(image_bytes, "What is the person wearing and which color?")
70
  hair_color = query_analysis(image_bytes, "What is the hair color of the person?")
71
  facial_expression = query_analysis(image_bytes, "What is the facial expression of the person?")
 
72
 
73
  # Build generation prompt based on VQA responses and user input
74
  if gender.lower() == "female":
75
- prompt = f"Create a school-going girl with {hair_color} hair, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
76
- elif gender.lower() == "male":
77
- prompt = f"Create a school-going boy with {hair_color} hair, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
78
  else:
79
- prompt = f"Create a school-going child with {hair_color} hair, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
80
-
81
  # Call image generation API
82
  with st.spinner("Generating the image..."):
83
  generated_image_data = query_generation(prompt, image_bytes)
 
 
 
 
84
 
85
- if generated_image_data:
86
- # Display the generated image
87
- generated_image = Image.open(io.BytesIO(generated_image_data))
88
- st.image(generated_image, caption="Generated Image", use_column_width=True)
 
89
 
90
- # Provide a download option for the generated image
91
- buffered = io.BytesIO()
92
- generated_image.save(buffered, format="PNG")
93
- st.download_button(label="Download Generated Image", data=buffered.getvalue(), file_name="generated_image.png", mime="image/png")
94
- else:
95
- st.error("Failed to generate image. Please try again.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Footer for better UI experience
98
  st.markdown("---")
99
  st.markdown("❤️‍🔥 *Made by Sujal Tamrakar*")
100
- st.markdown("💡 *Powered by Hugging Face and Streamlit*")
 
5
  from PIL import Image
6
  import os
7
 
8
+ # Hugging Face API Key
9
  hf_api_key = os.getenv("HUGGINGFACE_TOKEN")
10
+ headers = {"Authorization": f"Bearer {hf_api_key}"}
11
 
12
  # API URLs
13
  ANALYSIS_API_URL = "https://api-inference.huggingface.co/models/dandelin/vilt-b32-finetuned-vqa"
14
  GENERATION_API_URL = "https://api-inference.huggingface.co/models/thejagstudio/3d-animation-style-sdxl"
15
 
 
 
 
16
  # Function to query analysis (VQA) model
17
  def query_analysis(image_bytes, question):
18
+ payload = {
19
+ "inputs": {"question": question, "image": base64.b64encode(image_bytes).decode('utf-8')}
20
+ }
21
  try:
 
 
 
 
 
 
22
  response = requests.post(ANALYSIS_API_URL, headers=headers, json=payload)
23
+ response.raise_for_status()
24
+ return response.json()[0].get('answer', 'unspecified')
 
25
  except Exception as e:
26
+ st.error(f"Error: {e}")
27
  return 'unspecified'
28
 
29
  # Function to query image generation model
30
  def query_generation(prompt, image_bytes):
31
+ payload = {
32
+ "inputs": prompt,
33
+ "image": base64.b64encode(image_bytes).decode('utf-8')
34
+ }
35
  try:
 
 
 
 
36
  response = requests.post(GENERATION_API_URL, headers=headers, json=payload)
37
+ response.raise_for_status()
38
  return response.content
39
  except Exception as e:
40
+ st.error(f"Error: {e}")
41
  return None
42
 
43
+ # Function to save feedback to a file
44
+ def save_feedback(name, feedback, rating):
45
+ with open("feedback.txt", "a") as f:
46
+ f.write(f"Name: {name}\nFeedback: {feedback}\nRating: {rating}/5\n\n")
47
+
48
  # Streamlit app title
49
  st.title("Image Insight & Generation Studio 👻")
50
 
 
52
  uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
53
 
54
  if uploaded_file is not None:
 
 
 
55
  # Read image bytes
56
  image_bytes = uploaded_file.read()
57
 
58
+ # Display the uploaded image separately before generating the new image
59
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
60
+
61
  # Text input for additional description for image generation
62
  user_prompt = st.text_input("Enter additional description for image generation (optional):")
63
 
 
68
  clothing = query_analysis(image_bytes, "What is the person wearing and which color?")
69
  hair_color = query_analysis(image_bytes, "What is the hair color of the person?")
70
  facial_expression = query_analysis(image_bytes, "What is the facial expression of the person?")
71
+ age = query_analysis(image_bytes, "What is the estimated age of the person?")
72
 
73
  # Build generation prompt based on VQA responses and user input
74
  if gender.lower() == "female":
75
+ prompt = f"Create a {age}-year-old girl with {hair_color} hair, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
 
 
76
  else:
77
+ prompt = f"Create a {age}-year-old person with {hair_color} hair, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
78
+
79
  # Call image generation API
80
  with st.spinner("Generating the image..."):
81
  generated_image_data = query_generation(prompt, image_bytes)
82
+ if generated_image_data:
83
+ # Store the generated image in session state
84
+ st.session_state.generated_image_data = generated_image_data
85
+ st.success("Image generated successfully!")
86
 
87
+ # Display the generated image if available
88
+ if 'generated_image_data' in st.session_state:
89
+ st.markdown("### Generated Image")
90
+ generated_image = Image.open(io.BytesIO(st.session_state.generated_image_data))
91
+ st.image(generated_image, caption="Generated Image", use_column_width=True)
92
 
93
+ # Provide download option for the generated image
94
+ buffered = io.BytesIO()
95
+ generated_image.save(buffered, format="PNG")
96
+ st.download_button(
97
+ label="Download Generated Image",
98
+ data=buffered.getvalue(),
99
+ file_name="generated_image.png",
100
+ mime="image/png"
101
+ )
102
+
103
+ # Ask for feedback after the image is generated
104
+ with st.form(key='feedback_form'):
105
+ name = st.text_input("Your Name")
106
+ feedback = st.text_area("Please leave your feedback")
107
+ rating = st.slider("Rate the image quality", 1, 5)
108
+ submit_button = st.form_submit_button(label='Submit Feedback')
109
+
110
+ if submit_button:
111
+ save_feedback(name, feedback, rating)
112
+ st.success("Thank you for your feedback!")
113
+
114
+ # Ensure that the generated image does not disappear after feedback or download
115
+ if 'generated_image_data' in st.session_state:
116
+ # st.markdown("### Generated Image (Persistent)")
117
+ # st.image(Image.open(io.BytesIO(st.session_state.generated_image_data)), caption="Generated Image", use_column_width=True)
118
+ pass
119
 
120
+ # Footer for a better UI experience
121
  st.markdown("---")
122
  st.markdown("❤️‍🔥 *Made by Sujal Tamrakar*")
123
+ st.markdown("💡 *Powered by Hugging Face and Streamlit*")