Pedro Cuenca committed

Commit ffed138 · 1 Parent(s): f62b045
Simple skeleton for a streamlit app
In order to use it, you need to create a file `.streamlit/secrets.toml`
to define the URL of the BACKEND_SERVER:
```
BACKEND_SERVER="<server url>"
```
Former-commit-id: 4d81cb1c805c903c74b82a5706b3a54ce8a2348b
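
The app reads this value through Streamlit's secrets mechanism: `st.secrets` behaves like a dict backed by `.streamlit/secrets.toml`, and a missing key raises `KeyError`, which is exactly what the app catches in the diff below to display these setup instructions. A minimal sketch of the lookup:

```python
import streamlit as st

# st.secrets exposes the values defined in .streamlit/secrets.toml;
# if BACKEND_SERVER is unset, this lookup raises KeyError, which
# app.py catches to show the configuration help above.
backend_url = st.secrets["BACKEND_SERVER"]
```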
- app/app.py +22 -180
app/app.py CHANGED

````diff
@@ -1,196 +1,38 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# Uncomment to run on cpu
-#import os
-#os.environ["JAX_PLATFORM_NAME"] = "cpu"
-
 import random
-
-import jax
-import flax.linen as nn
-from flax.training.common_utils import shard
-from flax.jax_utils import replicate, unreplicate
-
-from transformers.models.bart.modeling_flax_bart import *
-from transformers import BartTokenizer, FlaxBartForConditionalGeneration
-
-
-import requests
-from PIL import Image
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from dalle_mini.backend import ServiceError, get_images_from_backend
+from dalle_mini.helpers import captioned_strip
 
 import streamlit as st
 
-st.write("Loading model...")
-
-# TODO: set those args in a config file
-OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'flax-community/dalle-mini'
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
-
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
-    def setup(self):
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
-# create our model
-# FIXME: Save tokenizer to hub so we can load from there
-tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
-model.config.force_bos_token_to_be_generated = False
-model.config.forced_bos_token_id = None
-model.config.forced_eos_token_id = None
-
-vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
-
-def custom_to_pil(x):
-    x = np.clip(x, 0., 1.)
-    x = (255*x).astype(np.uint8)
-    x = Image.fromarray(x)
-    if not x.mode == "RGB":
-        x = x.convert("RGB")
-    return x
-
-def generate(input, rng, params):
-    return model.generate(
-        **input,
-        max_length=257,
-        num_beams=1,
-        do_sample=True,
-        prng_key=rng,
-        eos_token_id=50000,
-        pad_token_id=50000,
-        params=params,
-    )
-
-def get_images(indices, params):
-    return vqgan.decode_code(indices, params=params)
-
-def plot_images(images):
-    fig = plt.figure(figsize=(40, 20))
-    columns = 4
-    rows = 2
-    plt.subplots_adjust(hspace=0, wspace=0)
-
-    for i in range(1, columns*rows +1):
-        fig.add_subplot(rows, columns, i)
-        plt.imshow(images[i-1])
-    plt.gca().axes.get_yaxis().set_visible(False)
-    plt.show()
-
-def stack_reconstructions(images):
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w, h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i*w,0))
-    return img
-
-p_generate = jax.pmap(generate, "batch")
-p_get_images = jax.pmap(get_images, "batch")
-
-bart_params = replicate(model.params)
-vqgan_params = replicate(vqgan.params)
-
-# ## CLIP Scoring
-from transformers import CLIPProcessor, FlaxCLIPModel
-
-clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-# st.write("FlaxCLIPModel")
-# print("Initialize FlaxCLIPModel")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-# st.write("CLIPProcessor")
-# print("Initialize CLIPProcessor")
-
-def hallucinate(prompt, num_images=64):
-    prompt = [prompt] * jax.device_count()
-    inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
-    inputs = shard(inputs)
-
-    all_images = []
-    for i in range(num_images // jax.device_count()):
-        key = random.randint(0, 1e7)
-        rng = jax.random.PRNGKey(key)
-        rngs = jax.random.split(rng, jax.local_device_count())
-        indices = p_generate(inputs, rngs, bart_params).sequences
-        indices = indices[:, :, 1:]
-
-        images = p_get_images(indices, vqgan_params)
-        images = np.squeeze(np.asarray(images), 1)
-        for image in images:
-            all_images.append(custom_to_pil(image))
-    return all_images
-
-def clip_top_k(prompt, images, k=8):
-    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-    outputs = clip(**inputs)
-    logits = outputs.logits_per_text
-    scores = np.array(logits[0]).argsort()[-k:][::-1]
-    return [images[score] for score in scores]
-
-def captioned_strip(images, caption):
-    increased_h = 0 if caption is None else 48
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w, h + increased_h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i*w, increased_h))
-
-    if caption is not None:
-        draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
-        draw.text((20, 3), caption, (255,255,255), font=font)
-    return img
-
 # Controls
 
-num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
-num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
+# num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
+# num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
 
+st.sidebar.markdown('Visit [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)')
 
 prompt = st.text_input("What do you want to see?")
 
 if prompt != "":
     st.write(f"Generating candidates for: {prompt}")
-    images = hallucinate(prompt, num_images=num_images)
-    images = clip_top_k(prompt, images, k=num_preds)
-    predictions_strip = captioned_strip(images, None)
 
-    
+    try:
+        backend_url = st.secrets["BACKEND_SERVER"]
+        print(f"Getting selections: {prompt}")
+        selected = get_images_from_backend(prompt, backend_url)
+        preds = captioned_strip(selected, prompt)
+        st.image(preds)
+    except ServiceError as error:
+        st.write(f"Service unavailable, status: {error.status_code}")
+    except KeyError:
+        st.write("""
+        **Error: BACKEND_SERVER unset**
+
+        Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
+        ```
+        BACKEND_SERVER="<server url>"
+        ```
+        """)
````
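
The two helpers imported from `dalle_mini.backend` are not part of this diff, so their implementation is not shown here. A minimal sketch of what such a client could look like, assuming a JSON API that accepts a prompt and returns base64-encoded images (everything except the two imported names is a hypothetical placeholder):

```python
# Hypothetical sketch of a backend client compatible with app.py above.
# Assumes the server accepts POST {"prompt": ...} and responds with
# {"images": [<base64-encoded PNG>, ...]}; the real dalle_mini.backend
# module may differ.
import base64
from io import BytesIO

import requests
from PIL import Image


class ServiceError(Exception):
    def __init__(self, status_code):
        self.status_code = status_code  # reported by app.py on failure


def get_images_from_backend(prompt, backend_url):
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code != 200:
        raise ServiceError(r.status_code)
    return [
        Image.open(BytesIO(base64.b64decode(img)))
        for img in r.json()["images"]
    ]
```

Whatever transport the real module uses, the contract visible from `app.py` is only that `get_images_from_backend(prompt, backend_url)` returns a list of images suitable for `captioned_strip`, and that failures raise `ServiceError` with a `status_code` attribute.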