suzall commited on
Commit
07da576
·
verified ·
1 Parent(s): 3e5509d

final commit

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import io
4
+ import base64
5
+ from PIL import Image
6
+ import os
7
+
8
+ hf_api_key = os.getenv("HUGGINGFACE_TOKEN")
9
+
10
+ headers = {"Authorization": f"Bearer {hf_api_key}"}
11
+ ANALYSIS_API_URL = "https://api-inference.huggingface.co/models/dandelin/vilt-b32-finetuned-vqa"
12
+ GENERATION_API_URL = "https://api-inference.huggingface.co/models/thejagstudio/3d-animation-style-sdxl"
13
+
14
+ def query_analysis(image_bytes, question):
15
+ payload = {
16
+ "inputs": {
17
+ "question": question,
18
+ "image": base64.b64encode(image_bytes).decode('utf-8')
19
+ }
20
+ }
21
+ response = requests.post(ANALYSIS_API_URL, headers=headers, json=payload)
22
+ if response.status_code == 200 and response.json():
23
+ return response.json()[0].get('answer', 'unspecified')
24
+ return 'unspecified'
25
+
26
+ def query_generation(prompt, image_bytes):
27
+ payload = {
28
+ "inputs": prompt,
29
+ "image": base64.b64encode(image_bytes).decode('utf-8')
30
+ }
31
+ response = requests.post(GENERATION_API_URL, headers=headers, json=payload)
32
+ return response.content
33
+
34
+ st.title("Image Insight & Generation Studio👻")
35
+
36
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
37
+
38
+ if uploaded_file is not None:
39
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
40
+
41
+ image_bytes = uploaded_file.read()
42
+
43
+ user_prompt = st.text_input("Enter additional description for image generation:")
44
+
45
+ if st.button("Generate Image"):
46
+ gender = query_analysis(image_bytes, "What is the gender of the person in the image?")
47
+ clothing = query_analysis(image_bytes, "What is the person wearing and which color?")
48
+ hair_color = query_analysis(image_bytes, "What is the hair color of the person?")
49
+ facial_expression = query_analysis(image_bytes, "What is the facial expression of the person?")
50
+
51
+ if gender.lower() == "female":
52
+ prompt = f"create one school-going girl with {hair_color}, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
53
+ elif gender.lower() == "male":
54
+ prompt = f"create one school-going boy with {hair_color}, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
55
+ else:
56
+ prompt = f"create one school-going child with {hair_color}, wearing {clothing}, showing a {facial_expression}. {user_prompt}"
57
+
58
+ generated_image_data = query_generation(prompt, image_bytes)
59
+ generated_image = Image.open(io.BytesIO(generated_image_data))
60
+
61
+ st.image(generated_image, caption="Generated Image", use_column_width=True)
62
+
63
+ buffered = io.BytesIO()
64
+ generated_image.save(buffered, format="PNG")
65
+ st.download_button(label="Download Generated Image", data=buffered.getvalue(), file_name="generated_image.png", mime="image/png")