14likhit commited on
Commit
f5940c8
·
verified ·
1 Parent(s): 6945f53

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
- from diffusers import DiffusionPipeline
3
  import torch
 
 
4
 
5
  # Set page config
6
  st.set_page_config(
@@ -30,10 +31,13 @@ prompt = st.text_area(
30
  value="Masterpiece portrait of a beautiful young woman with flowing hair, detailed face, photorealistic, 8k, professional photography"
31
  )
32
 
33
- # Generate button
34
  if st.button("Generate Portrait", type="primary"):
35
  with st.spinner("Loading model and generating portrait..."):
36
  try:
 
 
 
37
  # Set up the model pipeline
38
  pipeline = DiffusionPipeline.from_pretrained(
39
  "Shakker-Labs/AWPortraitCN2",
@@ -63,9 +67,6 @@ if st.button("Generate Portrait", type="primary"):
63
  st.image(image, caption="Generated Portrait", use_column_width=True)
64
 
65
  # Option to download
66
- # Convert the PIL image to bytes
67
- import io
68
- from PIL import Image
69
  buf = io.BytesIO()
70
  image.save(buf, format="PNG")
71
  byte_im = buf.getvalue()
 
1
  import streamlit as st
 
2
  import torch
3
+ from PIL import Image
4
+ import io
5
 
6
  # Set page config
7
  st.set_page_config(
 
31
  value="Masterpiece portrait of a beautiful young woman with flowing hair, detailed face, photorealistic, 8k, professional photography"
32
  )
33
 
34
+ # Load the model when the user clicks the button
35
  if st.button("Generate Portrait", type="primary"):
36
  with st.spinner("Loading model and generating portrait..."):
37
  try:
38
+ # Import here to avoid the cached_download issue until the model is actually needed
39
+ from diffusers import DiffusionPipeline
40
+
41
  # Set up the model pipeline
42
  pipeline = DiffusionPipeline.from_pretrained(
43
  "Shakker-Labs/AWPortraitCN2",
 
67
  st.image(image, caption="Generated Portrait", use_column_width=True)
68
 
69
  # Option to download
 
 
 
70
  buf = io.BytesIO()
71
  image.save(buf, format="PNG")
72
  byte_im = buf.getvalue()