Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Merge branch 'main' of https://github.com/borisdayma/dalle-mini into fix-opt_state
Browse files- .github/workflows/sync_to_hub_debug.yml +17 -0
- CITATION.cff +44 -0
- README.md +76 -7
- app/app.py +22 -4
- app/dalle_mini +0 -1
- app/gradio/dalle_mini +0 -1
- app/img/loading.gif +0 -0
- dalle_mini/text.py +272 -0
- dev/README.md +122 -0
- dev/encoding/vqgan-jax-encoding-streaming.ipynb +562 -0
- dev/encoding/vqgan-jax-encoding-webdataset.ipynb +461 -0
- dev/inference/dalle_mini +0 -1
- dev/inference/inference_pipeline.ipynb +51 -28
- dev/requirements.txt +3 -5
- requirements.txt +0 -2
- setup.cfg +16 -0
- setup.py +4 -0
    	
        .github/workflows/sync_to_hub_debug.yml
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            name: Deploy to debug app
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            on:
         | 
| 4 | 
            +
              # to run this workflow manually from the Actions tab
         | 
| 5 | 
            +
              workflow_dispatch:
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            jobs:
         | 
| 8 | 
            +
              sync-to-hub-debug:
         | 
| 9 | 
            +
                runs-on: ubuntu-latest
         | 
| 10 | 
            +
                steps:
         | 
| 11 | 
            +
                  - uses: actions/checkout@v2
         | 
| 12 | 
            +
                    with:
         | 
| 13 | 
            +
                      fetch-depth: 0
         | 
| 14 | 
            +
                  - name: Push to hub
         | 
| 15 | 
            +
                    env:
         | 
| 16 | 
            +
                      HF_TOKEN: ${{ secrets.HF_TOKEN }}
         | 
| 17 | 
            +
                    run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini-debug +HEAD:main
         | 
    	
        CITATION.cff
    ADDED
    
    | @@ -0,0 +1,44 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # YAML 1.2
         | 
| 2 | 
            +
            ---
         | 
| 3 | 
            +
            abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
         | 
| 4 | 
            +
            authors: 
         | 
| 5 | 
            +
              -
         | 
| 6 | 
            +
                family-names: Dayma
         | 
| 7 | 
            +
                given-names: Boris
         | 
| 8 | 
            +
              -
         | 
| 9 | 
            +
                family-names: Patil
         | 
| 10 | 
            +
                given-names: Suraj
         | 
| 11 | 
            +
              -
         | 
| 12 | 
            +
                family-names: Cuenca
         | 
| 13 | 
            +
                given-names: Pedro
         | 
| 14 | 
            +
              -
         | 
| 15 | 
            +
                family-names: Saifullah
         | 
| 16 | 
            +
                given-names: Khalid
         | 
| 17 | 
            +
              -
         | 
| 18 | 
            +
                family-names: Abraham
         | 
| 19 | 
            +
                given-names: Tanishq
         | 
| 20 | 
            +
              -
         | 
| 21 | 
            +
                family-names: "Lê Khắc"
         | 
| 22 | 
            +
                given-names: "Phúc"
         | 
| 23 | 
            +
              -
         | 
| 24 | 
            +
                family-names: Melas
         | 
| 25 | 
            +
                given-names: Luke
         | 
| 26 | 
            +
              -
         | 
| 27 | 
            +
                family-names: Ghosh
         | 
| 28 | 
            +
                given-names: Ritobrata
         | 
| 29 | 
            +
            cff-version: "1.1.0"
         | 
| 30 | 
            +
            date-released: 2021-07-29
         | 
| 31 | 
            +
            identifiers: 
         | 
| 32 | 
            +
            keywords: 
         | 
| 33 | 
            +
              - dalle
         | 
| 34 | 
            +
              - "text-to-image generation"
         | 
| 35 | 
            +
              - transformer
         | 
| 36 | 
            +
              - "zero-shot"
         | 
| 37 | 
            +
              - JAX
         | 
| 38 | 
            +
            license: "Apache-2.0"
         | 
| 39 | 
            +
            doi: 10.5281/zenodo.5146400
         | 
| 40 | 
            +
            message: "If you use this project, please cite it using these metadata."
         | 
| 41 | 
            +
            repository-code: "https://github.com/borisdayma/dalle-mini"
         | 
| 42 | 
            +
            title: "DALL·E Mini"
         | 
| 43 | 
            +
            version: "v0.1-alpha"
         | 
| 44 | 
            +
            ...
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,8 +1,8 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            title: DALL·E mini
         | 
| 3 | 
             
            emoji: 🥑
         | 
| 4 | 
            -
            colorFrom:  | 
| 5 | 
            -
            colorTo:  | 
| 6 | 
             
            sdk: streamlit
         | 
| 7 | 
             
            app_file: app/app.py
         | 
| 8 | 
             
            pinned: false
         | 
| @@ -16,7 +16,7 @@ _Generate images from a text prompt_ | |
| 16 |  | 
| 17 | 
             
            Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
         | 
| 18 |  | 
| 19 | 
            -
            You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini) | 
| 20 |  | 
| 21 | 
             
            ## How does it work?
         | 
| 22 |  | 
| @@ -26,8 +26,6 @@ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini | |
| 26 |  | 
| 27 | 
             
            ### Dependencies Installation
         | 
| 28 |  | 
| 29 | 
            -
            The root folder and associated [`requirements.txt`](./requirements.txt) is only for the app.
         | 
| 30 | 
            -
             | 
| 31 | 
             
            For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
         | 
| 32 |  | 
| 33 | 
             
            ### Training of VQGAN
         | 
| @@ -52,7 +50,16 @@ To generate sample predictions and understand the inference pipeline step by ste | |
| 52 |  | 
| 53 | 
             
            [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
         | 
| 54 |  | 
| 55 | 
            -
            ##  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 56 |  | 
| 57 | 
             
            The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
         | 
| 58 |  | 
| @@ -70,4 +77,66 @@ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL | |
| 70 | 
             
            ## Acknowledgements
         | 
| 71 |  | 
| 72 | 
             
            - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
         | 
| 73 | 
            -
            - Google Cloud  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            title: DALL·E mini
         | 
| 3 | 
             
            emoji: 🥑
         | 
| 4 | 
            +
            colorFrom: yellow
         | 
| 5 | 
            +
            colorTo: green
         | 
| 6 | 
             
            sdk: streamlit
         | 
| 7 | 
             
            app_file: app/app.py
         | 
| 8 | 
             
            pinned: false
         | 
|  | |
| 16 |  | 
| 17 | 
             
            Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
         | 
| 18 |  | 
| 19 | 
            +
            You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
         | 
| 20 |  | 
| 21 | 
             
            ## How does it work?
         | 
| 22 |  | 
|  | |
| 26 |  | 
| 27 | 
             
            ### Dependencies Installation
         | 
| 28 |  | 
|  | |
|  | |
| 29 | 
             
            For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
         | 
| 30 |  | 
| 31 | 
             
            ### Training of VQGAN
         | 
|  | |
| 50 |  | 
| 51 | 
             
            [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
         | 
| 52 |  | 
| 53 | 
            +
            ## FAQ
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            ### Where to find the latest models?
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            Trained models are on 🤗 Model Hub:
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            - [VQGAN-f16-16384](https://huggingface.co/flax-community/vqgan_f16_16384) for encoding/decoding images
         | 
| 60 | 
            +
            - [DALL·E mini](https://huggingface.co/flax-community/dalle-mini) for generating images from a text prompt
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            ### Where does the logo come from?
         | 
| 63 |  | 
| 64 | 
             
            The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
         | 
| 65 |  | 
|  | |
| 77 | 
             
            ## Acknowledgements
         | 
| 78 |  | 
| 79 | 
             
            - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
         | 
| 80 | 
            +
            - Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
         | 
| 81 | 
            +
            - [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            ## Citing DALL·E mini
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            If you find DALL·E mini useful in your research or wish to refer to it, please use the following BibTeX entry.
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            ```
         | 
| 88 | 
            +
            @misc{Dayma_DALL·E_Mini_2021,
         | 
| 89 | 
            +
            author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
         | 
| 90 | 
            +
            doi = {10.5281/zenodo.5146400},
         | 
| 91 | 
            +
            month = {7},
         | 
| 92 | 
            +
            title = {DALL·E Mini},
         | 
| 93 | 
            +
            url = {https://github.com/borisdayma/dalle-mini},
         | 
| 94 | 
            +
            year = {2021}
         | 
| 95 | 
            +
            }
         | 
| 96 | 
            +
            ```
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            ## References
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            ```
         | 
| 101 | 
            +
            @misc{ramesh2021zeroshot,
         | 
| 102 | 
            +
                  title={Zero-Shot Text-to-Image Generation}, 
         | 
| 103 | 
            +
                  author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
         | 
| 104 | 
            +
                  year={2021},
         | 
| 105 | 
            +
                  eprint={2102.12092},
         | 
| 106 | 
            +
                  archivePrefix={arXiv},
         | 
| 107 | 
            +
                  primaryClass={cs.CV}
         | 
| 108 | 
            +
            }
         | 
| 109 | 
            +
            ```
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            ```
         | 
| 112 | 
            +
            @misc{esser2021taming,
         | 
| 113 | 
            +
                  title={Taming Transformers for High-Resolution Image Synthesis}, 
         | 
| 114 | 
            +
                  author={Patrick Esser and Robin Rombach and Björn Ommer},
         | 
| 115 | 
            +
                  year={2021},
         | 
| 116 | 
            +
                  eprint={2012.09841},
         | 
| 117 | 
            +
                  archivePrefix={arXiv},
         | 
| 118 | 
            +
                  primaryClass={cs.CV}
         | 
| 119 | 
            +
            }
         | 
| 120 | 
            +
            ```
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            ```
         | 
| 123 | 
            +
            @misc{lewis2019bart,
         | 
| 124 | 
            +
                  title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension}, 
         | 
| 125 | 
            +
                  author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
         | 
| 126 | 
            +
                  year={2019},
         | 
| 127 | 
            +
                  eprint={1910.13461},
         | 
| 128 | 
            +
                  archivePrefix={arXiv},
         | 
| 129 | 
            +
                  primaryClass={cs.CL}
         | 
| 130 | 
            +
            }
         | 
| 131 | 
            +
            ```
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            ```
         | 
| 134 | 
            +
            @misc{radford2021learning,
         | 
| 135 | 
            +
                  title={Learning Transferable Visual Models From Natural Language Supervision}, 
         | 
| 136 | 
            +
                  author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
         | 
| 137 | 
            +
                  year={2021},
         | 
| 138 | 
            +
                  eprint={2103.00020},
         | 
| 139 | 
            +
                  archivePrefix={arXiv},
         | 
| 140 | 
            +
                  primaryClass={cs.CV}
         | 
| 141 | 
            +
            }
         | 
| 142 | 
            +
            ```
         | 
    	
        app/app.py
    CHANGED
    
    | @@ -1,7 +1,6 @@ | |
| 1 | 
             
            #!/usr/bin/env python
         | 
| 2 | 
             
            # coding: utf-8
         | 
| 3 |  | 
| 4 | 
            -
            import random
         | 
| 5 | 
             
            from dalle_mini.backend import ServiceError, get_images_from_backend
         | 
| 6 |  | 
| 7 | 
             
            import streamlit as st
         | 
| @@ -55,12 +54,31 @@ st.subheader('Generate images from text') | |
| 55 |  | 
| 56 | 
             
            prompt = st.text_input("What do you want to see?")
         | 
| 57 |  | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
             
            DEBUG = False
         | 
| 61 | 
             
            if prompt != "" or (should_run_again and prompt != ""):
         | 
| 62 | 
             
                container = st.empty()
         | 
| 63 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 64 |  | 
| 65 | 
             
                try:
         | 
| 66 | 
             
                    backend_url = st.secrets["BACKEND_SERVER"]
         | 
|  | |
| 1 | 
             
            #!/usr/bin/env python
         | 
| 2 | 
             
            # coding: utf-8
         | 
| 3 |  | 
|  | |
| 4 | 
             
            from dalle_mini.backend import ServiceError, get_images_from_backend
         | 
| 5 |  | 
| 6 | 
             
            import streamlit as st
         | 
|  | |
| 54 |  | 
| 55 | 
             
            prompt = st.text_input("What do you want to see?")
         | 
| 56 |  | 
| 57 | 
            +
            test = st.empty()
         | 
|  | |
| 58 | 
             
            DEBUG = False
         | 
| 59 | 
             
            if prompt != "" or (should_run_again and prompt != ""):
         | 
| 60 | 
             
                container = st.empty()
         | 
| 61 | 
            +
                # The following mimics `streamlit.info()`.
         | 
| 62 | 
            +
                # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
         | 
| 63 | 
            +
                # but it returns None.
         | 
| 64 | 
            +
                container.markdown(f"""
         | 
| 65 | 
            +
                    <style> p {{ margin:0 }} div {{ margin:0 }} </style>
         | 
| 66 | 
            +
                    <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
         | 
| 67 | 
            +
                    <div class="stAlert">
         | 
| 68 | 
            +
                    <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
         | 
| 69 | 
            +
                    <div class="st-b7">
         | 
| 70 | 
            +
                    <div class="css-whx05o e13vu3m50">
         | 
| 71 | 
            +
                    <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
         | 
| 72 | 
            +
                            <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
         | 
| 73 | 
            +
                            Generating predictions for: <b>{prompt}</b>
         | 
| 74 | 
            +
                    </div>
         | 
| 75 | 
            +
                    </div>
         | 
| 76 | 
            +
                    </div>
         | 
| 77 | 
            +
                    </div>
         | 
| 78 | 
            +
                    </div>
         | 
| 79 | 
            +
                    </div>
         | 
| 80 | 
            +
                    <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
         | 
| 81 | 
            +
                """, unsafe_allow_html=True)
         | 
| 82 |  | 
| 83 | 
             
                try:
         | 
| 84 | 
             
                    backend_url = st.secrets["BACKEND_SERVER"]
         | 
    	
        app/dalle_mini
    DELETED
    
    | @@ -1 +0,0 @@ | |
| 1 | 
            -
            ../dalle_mini/
         | 
|  | |
|  | 
    	
        app/gradio/dalle_mini
    DELETED
    
    | @@ -1 +0,0 @@ | |
| 1 | 
            -
            ../../dalle_mini/
         | 
|  | |
|  | 
    	
        app/img/loading.gif
    ADDED
    
    |   | 
    	
        dalle_mini/text.py
    ADDED
    
    | @@ -0,0 +1,272 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Utilities for processing text.
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import requests
         | 
| 6 | 
            +
            from pathlib import Path
         | 
| 7 | 
            +
            from unidecode import unidecode
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import re, math, random, html
         | 
| 10 | 
            +
            import ftfy
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            WIKI_STATS_URL = "https://github.com/borisdayma/wikipedia-word-frequency/raw/feat-update/results/enwiki-20210820-words-frequency.txt"
         | 
| 13 | 
            +
            WIKI_STATS_LOCAL = Path(WIKI_STATS_URL).parts[-1]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # based on wiki word occurrence
         | 
| 16 | 
            +
            person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
         | 
| 17 | 
            +
            temp_token = "xtokx"  # avoid repeating chars
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def get_wiki_file():
         | 
| 21 | 
            +
                if not Path(WIKI_STATS_LOCAL).exists():
         | 
| 22 | 
            +
                    r = requests.get(WIKI_STATS_URL, stream=True)
         | 
| 23 | 
            +
                    with open(WIKI_STATS_LOCAL, "wb") as fd:
         | 
| 24 | 
            +
                        for chunk in r.iter_content(chunk_size=128):
         | 
| 25 | 
            +
                            fd.write(chunk)
         | 
| 26 | 
            +
                return WIKI_STATS_LOCAL
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            class HashtagProcessor:
         | 
| 30 | 
            +
                # Adapted from wordninja library
         | 
| 31 | 
            +
                # We use our wikipedia word count + a good heuristic to make it work
         | 
| 32 | 
            +
                def __init__(self):
         | 
| 33 | 
            +
                    self._word_cost = (
         | 
| 34 | 
            +
                        l.split()[0] for l in Path(get_wiki_file()).read_text().splitlines()
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
                    self._word_cost = {
         | 
| 37 | 
            +
                        str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
         | 
| 38 | 
            +
                    }
         | 
| 39 | 
            +
                    self._max_word = max(len(x) for x in self._word_cost.keys())
         | 
| 40 | 
            +
                    self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def __call__(self, s):
         | 
| 43 | 
            +
                    """Uses dynamic programming to infer the location of spaces in a string without spaces."""
         | 
| 44 | 
            +
                    l = [self._split(x) for x in self._SPLIT_RE.split(s)]
         | 
| 45 | 
            +
                    return " ".join([item for sublist in l for item in sublist])
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def _split(self, s):
         | 
| 48 | 
            +
                    # Find the best match for the i first characters, assuming cost has
         | 
| 49 | 
            +
                    # been built for the i-1 first characters.
         | 
| 50 | 
            +
                    # Returns a pair (match_cost, match_length).
         | 
| 51 | 
            +
                    def best_match(i):
         | 
| 52 | 
            +
                        candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
         | 
| 53 | 
            +
                        return min(
         | 
| 54 | 
            +
                            (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
         | 
| 55 | 
            +
                            for k, c in candidates
         | 
| 56 | 
            +
                        )
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # Build the cost array
         | 
| 59 | 
            +
                    cost = [0]
         | 
| 60 | 
            +
                    for i in range(1, len(s) + 1):
         | 
| 61 | 
            +
                        c, k = best_match(i)
         | 
| 62 | 
            +
                        cost.append(c)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # Backtrack to recover the minimal-cost string.
         | 
| 65 | 
            +
                    out = []
         | 
| 66 | 
            +
                    i = len(s)
         | 
| 67 | 
            +
                    while i > 0:
         | 
| 68 | 
            +
                        c, k = best_match(i)
         | 
| 69 | 
            +
                        assert c == cost[i]
         | 
| 70 | 
            +
                        newToken = True
         | 
| 71 | 
            +
                        if not s[i - k : i] == "'":  # ignore a lone apostrophe
         | 
| 72 | 
            +
                            if len(out) > 0:
         | 
| 73 | 
            +
                                # re-attach split 's and split digits
         | 
| 74 | 
            +
                                if out[-1] == "'s" or (
         | 
| 75 | 
            +
                                    s[i - 1].isdigit() and out[-1][0].isdigit()
         | 
| 76 | 
            +
                                ):  # digit followed by digit
         | 
| 77 | 
            +
                                    out[-1] = (
         | 
| 78 | 
            +
                                        s[i - k : i] + out[-1]
         | 
| 79 | 
            +
                                    )  # combine current token with previous token
         | 
| 80 | 
            +
                                    newToken = False
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                        if newToken:
         | 
| 83 | 
            +
                            out.append(s[i - k : i])
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        i -= k
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    return reversed(out)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def replace_person_token(t):
         | 
| 91 | 
            +
                "Used for CC12M"
         | 
| 92 | 
            +
                t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
         | 
| 93 | 
            +
                while "<person>" in t:
         | 
| 94 | 
            +
                    t = t.replace(
         | 
| 95 | 
            +
                        "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                return t
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            def fix_html(t):
         | 
| 101 | 
            +
                "Adapted from fastai"
         | 
| 102 | 
            +
                t = (
         | 
| 103 | 
            +
                    t.replace("#39;", "'")
         | 
| 104 | 
            +
                    .replace("&amp;", "&")
         | 
| 105 | 
            +
                    .replace("amp;", "&")
         | 
| 106 | 
            +
                    .replace("#146;", "'")
         | 
| 107 | 
            +
                    .replace("nbsp;", " ")
         | 
| 108 | 
            +
                    .replace("#36;", "$")
         | 
| 109 | 
            +
                    .replace("\\n", "\n")
         | 
| 110 | 
            +
                    .replace("quot;", "'")
         | 
| 111 | 
            +
                    .replace("<br />", "\n")
         | 
| 112 | 
            +
                    .replace('\\"', '"')
         | 
| 113 | 
            +
                    .replace("<unk>", " ")
         | 
| 114 | 
            +
                    .replace(" @.@ ", ".")
         | 
| 115 | 
            +
                    .replace(" @-@ ", "-")
         | 
| 116 | 
            +
                )
         | 
| 117 | 
            +
                return html.unescape(t)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def replace_punctuation_with_commas(t):
         | 
| 121 | 
            +
                return re.sub("""([()[\].,|:;?!=+~\-])""", ",", t)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            def simplify_quotes(t):
         | 
| 125 | 
            +
                return re.sub("""['"`]""", ' " ', t)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def merge_quotes(t):
         | 
| 129 | 
            +
                return re.sub('(\s*"+\s*)+', ' " ', t)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            def remove_comma_numbers(t):
         | 
| 133 | 
            +
                def _f(t):
         | 
| 134 | 
            +
                    return re.sub("(\d),(\d{3})", r"\1\2", t)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                return _f(_f(t))
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            def pre_process_dot_numbers(t):
         | 
| 140 | 
            +
                return re.sub("(\d)\.(\d)", fr"\1{temp_token}dot{temp_token}\2", t)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def post_process_dot_numbers(t):
         | 
| 144 | 
            +
                return re.sub(f"{temp_token}dot{temp_token}", ".", t)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def pre_process_quotes(t):
         | 
| 148 | 
            +
                # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
         | 
| 149 | 
            +
                return re.sub(
         | 
| 150 | 
            +
                    r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", fr"{temp_token}quote{temp_token}", t
         | 
| 151 | 
            +
                )
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def post_process_quotes(t):
         | 
| 155 | 
            +
                return re.sub(f"{temp_token}quote{temp_token}", "'", t)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def merge_commas(t):
         | 
| 159 | 
            +
                return re.sub("(\s*,+\s*)+", ", ", t)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            def add_space_after_commas(t):
         | 
| 163 | 
            +
                return re.sub(",", ", ", t)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            def handle_special_chars(t):
         | 
| 167 | 
            +
                "Handle special characters"
         | 
| 168 | 
            +
                # replace "-" with a space when between words without space
         | 
| 169 | 
            +
                t = re.sub("([a-zA-Z])-([a-zA-Z])", r"\1 \2", t)
         | 
| 170 | 
            +
                # always add space around &
         | 
| 171 | 
            +
                return re.sub("&", " & ", t)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            def expand_hashtags(t, hashtag_processor):
         | 
| 175 | 
            +
                "Remove # and try to split words"
         | 
| 176 | 
            +
                return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            _re_ignore_chars = """[_#\/\\%]"""
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def ignore_chars(t):
         | 
| 183 | 
            +
                "Ignore useless characters"
         | 
| 184 | 
            +
                return re.sub(_re_ignore_chars, " ", t)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            def remove_extra_spaces(t):
         | 
| 188 | 
            +
                "Remove extra spaces (including \t and \n)"
         | 
| 189 | 
            +
                return re.sub("\s+", " ", t)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            def remove_repeating_chars(t):
         | 
| 193 | 
            +
                "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
         | 
| 194 | 
            +
                return re.sub(r"(\D)(\1{3,})", r"\1", t)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
             | 
| 197 | 
            +
            def remove_urls(t):
         | 
| 198 | 
            +
                return re.sub(r"http\S+", "", t)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            def remove_html_tags(t):
         | 
| 202 | 
            +
                return re.sub("<[^<]+?>", "", t)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
             | 
| 205 | 
            +
            def remove_first_last_commas(t):
         | 
| 206 | 
            +
                t = t.strip()
         | 
| 207 | 
            +
                t = t[:-1] if t and t[-1] == "," else t
         | 
| 208 | 
            +
                t = t[1:] if t and t[0] == "," else t
         | 
| 209 | 
            +
                return t.strip()
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            def remove_wiki_ref(t):
         | 
| 213 | 
            +
                t = re.sub(r"\A\s*\[\d+\]", "", t)
         | 
| 214 | 
            +
                return re.sub(r"\[\d+\]\s*\Z", "", t)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
             | 
| 217 | 
            +
            class TextNormalizer:
         | 
| 218 | 
            +
                "Normalize text"
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                def __init__(self):
         | 
| 221 | 
            +
                    self._hashtag_processor = HashtagProcessor()
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                def __call__(self, t, clip=False):
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    # fix some characters
         | 
| 226 | 
            +
                    t = ftfy.fix_text(t)
         | 
| 227 | 
            +
                    # fix html
         | 
| 228 | 
            +
                    t = fix_html(t)
         | 
| 229 | 
            +
                    if not clip:
         | 
| 230 | 
            +
                        # decode and simplify text: see unidecode library
         | 
| 231 | 
            +
                        t = unidecode(t)
         | 
| 232 | 
            +
                    # lower case
         | 
| 233 | 
            +
                    t = t.lower()
         | 
| 234 | 
            +
                    # replace <PERSON> (for CC12M)
         | 
| 235 | 
            +
                    t = replace_person_token(t)
         | 
| 236 | 
            +
                    # remove wiki reference (for WIT)
         | 
| 237 | 
            +
                    t = remove_wiki_ref(t)
         | 
| 238 | 
            +
                    # remove html tags
         | 
| 239 | 
            +
                    t = remove_html_tags(t)
         | 
| 240 | 
            +
                    # remove urls
         | 
| 241 | 
            +
                    t = remove_urls(t)
         | 
| 242 | 
            +
                    # remove commas in numbers
         | 
| 243 | 
            +
                    t = remove_comma_numbers(t)
         | 
| 244 | 
            +
                    if not clip:
         | 
| 245 | 
            +
                        # handle dots in numbers and quotes - Part 1
         | 
| 246 | 
            +
                        t = pre_process_dot_numbers(t)
         | 
| 247 | 
            +
                        t = pre_process_quotes(t)
         | 
| 248 | 
            +
                        # handle special characters
         | 
| 249 | 
            +
                        t = handle_special_chars(t)
         | 
| 250 | 
            +
                        # handle hashtags
         | 
| 251 | 
            +
                        t = expand_hashtags(t, self._hashtag_processor)
         | 
| 252 | 
            +
                        # ignore useless characters
         | 
| 253 | 
            +
                        t = ignore_chars(t)
         | 
| 254 | 
            +
                        # simplify quotes
         | 
| 255 | 
            +
                        t = simplify_quotes(t)
         | 
| 256 | 
            +
                        # all punctuation becomes commas
         | 
| 257 | 
            +
                        t = replace_punctuation_with_commas(t)
         | 
| 258 | 
            +
                        # handle dots in numbers and quotes - Part 2
         | 
| 259 | 
            +
                        t = post_process_dot_numbers(t)
         | 
| 260 | 
            +
                        t = post_process_quotes(t)
         | 
| 261 | 
            +
                        # handle repeating characters
         | 
| 262 | 
            +
                        t = remove_repeating_chars(t)
         | 
| 263 | 
            +
                        # merge commas
         | 
| 264 | 
            +
                        t = merge_commas(t)
         | 
| 265 | 
            +
                        # merge quotes
         | 
| 266 | 
            +
                        t = merge_quotes(t)
         | 
| 267 | 
            +
                    # remove multiple spaces
         | 
| 268 | 
            +
                    t = remove_extra_spaces(t)
         | 
| 269 | 
            +
                    # remove first and last comma
         | 
| 270 | 
            +
                    t = remove_first_last_commas(t)
         | 
| 271 | 
            +
                    # always start with a space
         | 
| 272 | 
            +
                    return f" {t}" if not clip else t
         | 
    	
        dev/README.md
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Development Instructions for TPU
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            ## Setup
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            - Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're eligible.
         | 
| 6 | 
            +
            - Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
         | 
| 7 | 
            +
            - Verify `gcloud config list`, in particular account, project & zone.
         | 
| 8 | 
            +
            - Create a TPU VM per the guide and connect to it.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            When needing a larger disk:
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            - Create a balanced persistent disk (SSD, so pricier than default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
         | 
| 13 | 
            +
            - Attach the disk to your instance by adding `--data-disk source=REF` per ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
         | 
| 14 | 
            +
            - Format the partition as described in the guide.
         | 
| 15 | 
            +
            - Make sure to set up automatic remount of disk at restart.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            ## Connect VS Code
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            - Find external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
         | 
| 20 | 
            +
            - Verify you can connect in terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
         | 
| 21 | 
            +
            - Add the same command as ssh host in VS Code.
         | 
| 22 | 
            +
            - Check config file
         | 
| 23 | 
            +
             | 
| 24 | 
            +
              ```
         | 
| 25 | 
            +
              Host INSTANCE_NAME
         | 
| 26 | 
            +
                HostName EXTERNAL_IP
         | 
| 27 | 
            +
                IdentityFile ~/.ssh/google_compute_engine
         | 
| 28 | 
            +
              ```
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            ## Environment configuration
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            ### Use virtual environments (optional)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            If you want to use `pyenv` and `pyenv-virtualenv`:
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            - Installation
         | 
| 39 | 
            +
             | 
| 40 | 
            +
              - [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
         | 
| 41 | 
            +
              - Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
         | 
| 42 | 
            +
              - bash set-up:
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                ```bash
         | 
| 45 | 
            +
                echo -e '\n'\
         | 
| 46 | 
            +
                    '# pyenv setup \n'\
         | 
| 47 | 
            +
                    'export PYENV_ROOT="$HOME/.pyenv" \n'\
         | 
| 48 | 
            +
                    'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
         | 
| 49 | 
            +
                    'eval "$(pyenv init --path)" \n'\
         | 
| 50 | 
            +
                    'eval "$(pyenv init -)" \n'\
         | 
| 51 | 
            +
                    'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
         | 
| 52 | 
            +
                ```
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            - Usage
         | 
| 55 | 
            +
             | 
| 56 | 
            +
              - Install a python version: `pyenv install X.X.X`
         | 
| 57 | 
            +
              - Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
         | 
| 58 | 
            +
              - Activate: `pyenv activate dalle_env`
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            ### Tools
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            - Git
         | 
| 65 | 
            +
             | 
| 66 | 
            +
              - `git config --global user.email "you@example.com"`
         | 
| 67 | 
            +
              - `git config --global user.name "First Last"`
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            - Github CLI
         | 
| 70 | 
            +
             | 
| 71 | 
            +
              - See [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
         | 
| 72 | 
            +
              - `gh auth login`
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            - Direnv
         | 
| 75 | 
            +
             | 
| 76 | 
            +
              - Install direnv: `sudo apt-get update && sudo apt-get install direnv`
         | 
| 77 | 
            +
              - bash set-up:
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                ```bash
         | 
| 80 | 
            +
                echo -e '\n'\
         | 
| 81 | 
            +
                    '# direnv setup \n'\
         | 
| 82 | 
            +
                    'eval "$(direnv hook bash)" \n' >> ~/.bashrc
         | 
| 83 | 
            +
                ```
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            ### Set up repo
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            - Clone repo: `gh repo clone borisdayma/dalle-mini`
         | 
| 88 | 
            +
            - If using `pyenv-virtualenv`, auto-activate env: `echo dalle_env >> .python-version`
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            ## Environment
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            - Install the following (use it later to update our dev requirements.txt)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            ```
         | 
| 95 | 
            +
            requests
         | 
| 96 | 
            +
            pillow
         | 
| 97 | 
            +
            jupyterlab
         | 
| 98 | 
            +
            ipywidgets
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            -e ../datasets[streaming]
         | 
| 101 | 
            +
            -e ../transformers
         | 
| 102 | 
            +
            -e ../webdataset
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            # JAX
         | 
| 105 | 
            +
            --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
         | 
| 106 | 
            +
            jax[tpu]>=0.2.16
         | 
| 107 | 
            +
            flax
         | 
| 108 | 
            +
            ```
         | 
| 109 | 
            +
             | 
| 110 | 
            +
            - `transformers-cli login`
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            ---
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            - set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have required permissions, then restart.
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            ## Working with datasets or models
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            - Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
         | 
| 119 | 
            +
            - Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
         | 
| 120 | 
            +
            - Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
         | 
| 121 | 
            +
            - Track specific extensions: `git lfs track "*.ext"`
         | 
| 122 | 
            +
            - See files tracked with LFS with `git lfs ls-files`
         | 
    	
        dev/encoding/vqgan-jax-encoding-streaming.ipynb
    ADDED
    
    | @@ -0,0 +1,562 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "markdown",
         | 
| 5 | 
            +
               "id": "d0b72877",
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "source": [
         | 
| 8 | 
            +
                "# VQGAN JAX Encoding for 🤗 Datasets in streaming mode"
         | 
| 9 | 
            +
               ]
         | 
| 10 | 
            +
              },
         | 
| 11 | 
            +
              {
         | 
| 12 | 
            +
               "cell_type": "markdown",
         | 
| 13 | 
            +
               "id": "ba7b31e6",
         | 
| 14 | 
            +
               "metadata": {},
         | 
| 15 | 
            +
               "source": [
         | 
| 16 | 
            +
                "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.\n",
         | 
| 17 | 
            +
                "\n",
         | 
| 18 | 
            +
                "This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub."
         | 
| 19 | 
            +
               ]
         | 
| 20 | 
            +
              },
         | 
| 21 | 
            +
              {
         | 
| 22 | 
            +
               "cell_type": "code",
         | 
| 23 | 
            +
               "execution_count": null,
         | 
| 24 | 
            +
               "id": "3b59489e",
         | 
| 25 | 
            +
               "metadata": {},
         | 
| 26 | 
            +
               "outputs": [],
         | 
| 27 | 
            +
               "source": [
         | 
| 28 | 
            +
                "import io\n",
         | 
| 29 | 
            +
                "\n",
         | 
| 30 | 
            +
                "import requests\n",
         | 
| 31 | 
            +
                "from PIL import Image\n",
         | 
| 32 | 
            +
                "import numpy as np\n",
         | 
| 33 | 
            +
                "from tqdm import tqdm\n",
         | 
| 34 | 
            +
                "\n",
         | 
| 35 | 
            +
                "import torch\n",
         | 
| 36 | 
            +
                "import torchvision.transforms as T\n",
         | 
| 37 | 
            +
                "import torchvision.transforms.functional as TF\n",
         | 
| 38 | 
            +
                "from torchvision.transforms import InterpolationMode\n",
         | 
| 39 | 
            +
                "import os\n",
         | 
| 40 | 
            +
                "\n",
         | 
| 41 | 
            +
                "import jax\n",
         | 
| 42 | 
            +
                "from jax import pmap"
         | 
| 43 | 
            +
               ]
         | 
| 44 | 
            +
              },
         | 
| 45 | 
            +
              {
         | 
| 46 | 
            +
               "cell_type": "markdown",
         | 
| 47 | 
            +
               "id": "c7c4c1e6",
         | 
| 48 | 
            +
               "metadata": {},
         | 
| 49 | 
            +
               "source": [
         | 
| 50 | 
            +
                "## Dataset and Parameters"
         | 
| 51 | 
            +
               ]
         | 
| 52 | 
            +
              },
         | 
| 53 | 
            +
              {
         | 
| 54 | 
            +
               "cell_type": "code",
         | 
| 55 | 
            +
               "execution_count": null,
         | 
| 56 | 
            +
               "id": "d45a289e",
         | 
| 57 | 
            +
               "metadata": {},
         | 
| 58 | 
            +
               "outputs": [],
         | 
| 59 | 
            +
               "source": [
         | 
| 60 | 
            +
                "import datasets\n",
         | 
| 61 | 
            +
                "from datasets import Dataset, load_dataset"
         | 
| 62 | 
            +
               ]
         | 
| 63 | 
            +
              },
         | 
| 64 | 
            +
              {
         | 
| 65 | 
            +
               "cell_type": "markdown",
         | 
| 66 | 
            +
               "id": "f26e4f18",
         | 
| 67 | 
            +
               "metadata": {},
         | 
| 68 | 
            +
               "source": [
         | 
| 69 | 
            +
                "We'll use the `validation` set for testing. Adjust accordingly."
         | 
| 70 | 
            +
               ]
         | 
| 71 | 
            +
              },
         | 
| 72 | 
            +
              {
         | 
| 73 | 
            +
               "cell_type": "code",
         | 
| 74 | 
            +
               "execution_count": null,
         | 
| 75 | 
            +
               "id": "28893c3e",
         | 
| 76 | 
            +
               "metadata": {},
         | 
| 77 | 
            +
               "outputs": [],
         | 
| 78 | 
            +
               "source": [
         | 
| 79 | 
            +
                "dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')"
         | 
| 80 | 
            +
               ]
         | 
| 81 | 
            +
              },
         | 
| 82 | 
            +
              {
         | 
| 83 | 
            +
               "cell_type": "code",
         | 
| 84 | 
            +
               "execution_count": null,
         | 
| 85 | 
            +
               "id": "33861477",
         | 
| 86 | 
            +
               "metadata": {},
         | 
| 87 | 
            +
               "outputs": [],
         | 
| 88 | 
            +
               "source": [
         | 
| 89 | 
            +
                "from pathlib import Path\n",
         | 
| 90 | 
            +
                "\n",
         | 
| 91 | 
            +
                "yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'\n",
         | 
| 92 | 
            +
                "yfcc100m_output = yfcc100m/'encoded'      # Output directory for encoded files"
         | 
| 93 | 
            +
               ]
         | 
| 94 | 
            +
              },
         | 
| 95 | 
            +
              {
         | 
| 96 | 
            +
               "cell_type": "code",
         | 
| 97 | 
            +
               "execution_count": null,
         | 
| 98 | 
            +
               "id": "6e7b71c4",
         | 
| 99 | 
            +
               "metadata": {},
         | 
| 100 | 
            +
               "outputs": [],
         | 
| 101 | 
            +
               "source": [
         | 
| 102 | 
            +
                "batch_size = 128     # Per device\n",
         | 
| 103 | 
            +
                "num_workers = 16     # Unused in streaming mode"
         | 
| 104 | 
            +
               ]
         | 
| 105 | 
            +
              },
         | 
| 106 | 
            +
              {
         | 
| 107 | 
            +
               "cell_type": "markdown",
         | 
| 108 | 
            +
               "id": "0793c26a",
         | 
| 109 | 
            +
               "metadata": {},
         | 
| 110 | 
            +
               "source": [
         | 
| 111 | 
            +
                "### Data preparation"
         | 
| 112 | 
            +
               ]
         | 
| 113 | 
            +
              },
         | 
| 114 | 
            +
              {
         | 
| 115 | 
            +
               "cell_type": "markdown",
         | 
| 116 | 
            +
               "id": "86415769",
         | 
| 117 | 
            +
               "metadata": {},
         | 
| 118 | 
            +
               "source": [
         | 
| 119 | 
            +
                "* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.\n",
         | 
| 120 | 
            +
                "* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.\n",
         | 
| 121 | 
            +
                "\n",
         | 
| 122 | 
            +
                "These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations run lazily as items are requested, instead of pre-processing the whole dataset up front."
         | 
| 123 | 
            +
               ]
         | 
| 124 | 
            +
              },
         | 
| 125 | 
            +
              {
         | 
| 126 | 
            +
               "cell_type": "markdown",
         | 
| 127 | 
            +
               "id": "0fdf1851",
         | 
| 128 | 
            +
               "metadata": {},
         | 
| 129 | 
            +
               "source": [
         | 
| 130 | 
            +
                "This helper function is used to decode images from the bytes retrieved in `streaming` mode."
         | 
| 131 | 
            +
               ]
         | 
| 132 | 
            +
              },
         | 
| 133 | 
            +
              {
         | 
| 134 | 
            +
               "cell_type": "code",
         | 
| 135 | 
            +
               "execution_count": null,
         | 
| 136 | 
            +
               "id": "5bbca804",
         | 
| 137 | 
            +
               "metadata": {},
         | 
| 138 | 
            +
               "outputs": [],
         | 
| 139 | 
            +
               "source": [
         | 
| 140 | 
            +
                "from PIL import Image\n",
         | 
| 141 | 
            +
                "import io\n",
         | 
| 142 | 
            +
                "\n",
         | 
| 143 | 
            +
                "def get_image(byte_stream):\n",
         | 
| 144 | 
            +
                "    image = Image.open(io.BytesIO(byte_stream))\n",
         | 
| 145 | 
            +
                "    return image.convert('RGB')"
         | 
| 146 | 
            +
               ]
         | 
| 147 | 
            +
              },
         | 
| 148 | 
            +
              {
         | 
| 149 | 
            +
               "cell_type": "markdown",
         | 
| 150 | 
            +
               "id": "b435290b",
         | 
| 151 | 
            +
               "metadata": {},
         | 
| 152 | 
            +
               "source": [
         | 
| 153 | 
            +
                "Image processing"
         | 
| 154 | 
            +
               ]
         | 
| 155 | 
            +
              },
         | 
| 156 | 
            +
              {
         | 
| 157 | 
            +
               "cell_type": "code",
         | 
| 158 | 
            +
               "execution_count": null,
         | 
| 159 | 
            +
               "id": "7e73dfa3",
         | 
| 160 | 
            +
               "metadata": {},
         | 
| 161 | 
            +
               "outputs": [],
         | 
| 162 | 
            +
               "source": [
         | 
| 163 | 
            +
                "def center_crop(image, max_size=256):\n",
         | 
| 164 | 
            +
                "    # Note: we allow upscaling too. We should exclude small images.    \n",
         | 
| 165 | 
            +
                "    image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
         | 
| 166 | 
            +
                "    image = TF.center_crop(image, output_size=2 * [max_size])\n",
         | 
| 167 | 
            +
                "    return image\n",
         | 
| 168 | 
            +
                "\n",
         | 
| 169 | 
            +
                "preprocess_image = T.Compose([\n",
         | 
| 170 | 
            +
                "    get_image,\n",
         | 
| 171 | 
            +
                "    center_crop,\n",
         | 
| 172 | 
            +
                "    T.ToTensor(),\n",
         | 
| 173 | 
            +
                "    lambda t: t.permute(1, 2, 0)   # Reorder, we need dimensions last\n",
         | 
| 174 | 
            +
                "])"
         | 
| 175 | 
            +
               ]
         | 
| 176 | 
            +
              },
         | 
| 177 | 
            +
              {
         | 
| 178 | 
            +
               "cell_type": "markdown",
         | 
| 179 | 
            +
               "id": "1e3ac8de",
         | 
| 180 | 
            +
               "metadata": {},
         | 
| 181 | 
            +
               "source": [
         | 
| 182 | 
            +
                "Caption preparation"
         | 
| 183 | 
            +
               ]
         | 
| 184 | 
            +
              },
         | 
| 185 | 
            +
              {
         | 
| 186 | 
            +
               "cell_type": "code",
         | 
| 187 | 
            +
               "execution_count": null,
         | 
| 188 | 
            +
               "id": "aadb4d23",
         | 
| 189 | 
            +
               "metadata": {},
         | 
| 190 | 
            +
               "outputs": [],
         | 
| 191 | 
            +
               "source": [
         | 
| 192 | 
            +
                "import string\n",
         | 
| 193 | 
            +
                "\n",
         | 
| 194 | 
            +
                "def create_caption(title, description):\n",
         | 
| 195 | 
            +
                "    title = title.strip()\n",
         | 
| 196 | 
            +
                "    description = description.strip()\n",
         | 
| 197 | 
            +
                "    if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
         | 
| 198 | 
            +
                "    return f'{title} {description}'"
         | 
| 199 | 
            +
               ]
         | 
| 200 | 
            +
              },
         | 
| 201 | 
            +
              {
         | 
| 202 | 
            +
               "cell_type": "markdown",
         | 
| 203 | 
            +
               "id": "3c4522b9",
         | 
| 204 | 
            +
               "metadata": {},
         | 
| 205 | 
            +
               "source": [
         | 
| 206 | 
            +
                "And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need."
         | 
| 207 | 
            +
               ]
         | 
| 208 | 
            +
              },
         | 
| 209 | 
            +
              {
         | 
| 210 | 
            +
               "cell_type": "code",
         | 
| 211 | 
            +
               "execution_count": null,
         | 
| 212 | 
            +
               "id": "2566ff68",
         | 
| 213 | 
            +
               "metadata": {},
         | 
| 214 | 
            +
               "outputs": [],
         | 
| 215 | 
            +
               "source": [
         | 
| 216 | 
            +
                "def prepare_item(item):\n",
         | 
| 217 | 
            +
                "    return {\n",
         | 
| 218 | 
            +
                "        'key': item['key'],\n",
         | 
| 219 | 
            +
                "        'caption': create_caption(item['title_clean'], item['description_clean']),\n",
         | 
| 220 | 
            +
                "        'image': preprocess_image(item['img'])\n",
         | 
| 221 | 
            +
                "    }"
         | 
| 222 | 
            +
               ]
         | 
| 223 | 
            +
              },
         | 
| 224 | 
            +
              {
         | 
| 225 | 
            +
               "cell_type": "markdown",
         | 
| 226 | 
            +
               "id": "e519e475",
         | 
| 227 | 
            +
               "metadata": {},
         | 
| 228 | 
            +
               "source": [
         | 
| 229 | 
            +
                "Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported."
         | 
| 230 | 
            +
               ]
         | 
| 231 | 
            +
              },
         | 
| 232 | 
            +
              {
         | 
| 233 | 
            +
               "cell_type": "code",
         | 
| 234 | 
            +
               "execution_count": null,
         | 
| 235 | 
            +
               "id": "10d7750e",
         | 
| 236 | 
            +
               "metadata": {},
         | 
| 237 | 
            +
               "outputs": [],
         | 
| 238 | 
            +
               "source": [
         | 
| 239 | 
            +
                "prepared_dataset = dataset.map(prepare_item, batched=False)"
         | 
| 240 | 
            +
               ]
         | 
| 241 | 
            +
              },
         | 
| 242 | 
            +
              {
         | 
| 243 | 
            +
               "cell_type": "code",
         | 
| 244 | 
            +
               "execution_count": null,
         | 
| 245 | 
            +
               "id": "a8595539",
         | 
| 246 | 
            +
               "metadata": {},
         | 
| 247 | 
            +
               "outputs": [],
         | 
| 248 | 
            +
               "source": [
         | 
| 249 | 
            +
                "%%time\n",
         | 
| 250 | 
            +
                "item = next(iter(prepared_dataset))"
         | 
| 251 | 
            +
               ]
         | 
| 252 | 
            +
              },
         | 
| 253 | 
            +
              {
         | 
| 254 | 
            +
               "cell_type": "code",
         | 
| 255 | 
            +
               "execution_count": null,
         | 
| 256 | 
            +
               "id": "04a6eeb4",
         | 
| 257 | 
            +
               "metadata": {},
         | 
| 258 | 
            +
               "outputs": [],
         | 
| 259 | 
            +
               "source": [
         | 
| 260 | 
            +
                "assert(list(item.keys()) == ['key', 'caption', 'image'])"
         | 
| 261 | 
            +
               ]
         | 
| 262 | 
            +
              },
         | 
| 263 | 
            +
              {
         | 
| 264 | 
            +
               "cell_type": "code",
         | 
| 265 | 
            +
               "execution_count": null,
         | 
| 266 | 
            +
               "id": "40d3115f",
         | 
| 267 | 
            +
               "metadata": {},
         | 
| 268 | 
            +
               "outputs": [],
         | 
| 269 | 
            +
               "source": [
         | 
| 270 | 
            +
                "item['image'].shape"
         | 
| 271 | 
            +
               ]
         | 
| 272 | 
            +
              },
         | 
| 273 | 
            +
              {
         | 
| 274 | 
            +
               "cell_type": "code",
         | 
| 275 | 
            +
               "execution_count": null,
         | 
| 276 | 
            +
               "id": "dd844e1c",
         | 
| 277 | 
            +
               "metadata": {},
         | 
| 278 | 
            +
               "outputs": [],
         | 
| 279 | 
            +
               "source": [
         | 
| 280 | 
            +
                "T.ToPILImage()(item['image'].permute(2, 0, 1))"
         | 
| 281 | 
            +
               ]
         | 
| 282 | 
            +
              },
         | 
| 283 | 
            +
              {
         | 
| 284 | 
            +
               "cell_type": "markdown",
         | 
| 285 | 
            +
               "id": "44d50a51",
         | 
| 286 | 
            +
               "metadata": {},
         | 
| 287 | 
            +
               "source": [
         | 
| 288 | 
            +
                "### Torch DataLoader"
         | 
| 289 | 
            +
               ]
         | 
| 290 | 
            +
              },
         | 
| 291 | 
            +
              {
         | 
| 292 | 
            +
               "cell_type": "markdown",
         | 
| 293 | 
            +
               "id": "17a4bbc6",
         | 
| 294 | 
            +
               "metadata": {},
         | 
| 295 | 
            +
               "source": [
         | 
| 296 | 
            +
                "We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.\n",
         | 
| 297 | 
            +
                "\n",
         | 
| 298 | 
            +
                "We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13"
         | 
| 299 | 
            +
               ]
         | 
| 300 | 
            +
              },
         | 
| 301 | 
            +
              {
         | 
| 302 | 
            +
               "cell_type": "code",
         | 
| 303 | 
            +
               "execution_count": null,
         | 
| 304 | 
            +
               "id": "e1c08b7e",
         | 
| 305 | 
            +
               "metadata": {},
         | 
| 306 | 
            +
               "outputs": [],
         | 
| 307 | 
            +
               "source": [
         | 
| 308 | 
            +
                "import torch\n",
         | 
| 309 | 
            +
                "from torch.utils.data import DataLoader"
         | 
| 310 | 
            +
               ]
         | 
| 311 | 
            +
              },
         | 
| 312 | 
            +
              {
         | 
| 313 | 
            +
               "cell_type": "code",
         | 
| 314 | 
            +
               "execution_count": null,
         | 
| 315 | 
            +
               "id": "6a296677",
         | 
| 316 | 
            +
               "metadata": {},
         | 
| 317 | 
            +
               "outputs": [],
         | 
| 318 | 
            +
               "source": [
         | 
| 319 | 
            +
                "torch_dataset = prepared_dataset.with_format(\"torch\")"
         | 
| 320 | 
            +
               ]
         | 
| 321 | 
            +
              },
         | 
| 322 | 
            +
              {
         | 
| 323 | 
            +
               "cell_type": "markdown",
         | 
| 324 | 
            +
               "id": "29ab13bc",
         | 
| 325 | 
            +
               "metadata": {},
         | 
| 326 | 
            +
               "source": [
         | 
| 327 | 
            +
                "**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress."
         | 
| 328 | 
            +
               ]
         | 
| 329 | 
            +
              },
         | 
| 330 | 
            +
              {
         | 
| 331 | 
            +
               "cell_type": "code",
         | 
| 332 | 
            +
               "execution_count": null,
         | 
| 333 | 
            +
               "id": "e2df5e13",
         | 
| 334 | 
            +
               "metadata": {},
         | 
| 335 | 
            +
               "outputs": [],
         | 
| 336 | 
            +
               "source": [
         | 
| 337 | 
            +
                "dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())"
         | 
| 338 | 
            +
               ]
         | 
| 339 | 
            +
              },
         | 
| 340 | 
            +
              {
         | 
| 341 | 
            +
               "cell_type": "code",
         | 
| 342 | 
            +
               "execution_count": null,
         | 
| 343 | 
            +
               "id": "c15e3783",
         | 
| 344 | 
            +
               "metadata": {},
         | 
| 345 | 
            +
               "outputs": [],
         | 
| 346 | 
            +
               "source": [
         | 
| 347 | 
            +
                "batch = next(iter(dataloader))"
         | 
| 348 | 
            +
               ]
         | 
| 349 | 
            +
              },
         | 
| 350 | 
            +
              {
         | 
| 351 | 
            +
               "cell_type": "code",
         | 
| 352 | 
            +
               "execution_count": null,
         | 
| 353 | 
            +
               "id": "71d027fe",
         | 
| 354 | 
            +
               "metadata": {},
         | 
| 355 | 
            +
               "outputs": [],
         | 
| 356 | 
            +
               "source": [
         | 
| 357 | 
            +
                "batch['image'].shape"
         | 
| 358 | 
            +
               ]
         | 
| 359 | 
            +
              },
         | 
| 360 | 
            +
              {
         | 
| 361 | 
            +
               "cell_type": "markdown",
         | 
| 362 | 
            +
               "id": "a354472b",
         | 
| 363 | 
            +
               "metadata": {},
         | 
| 364 | 
            +
               "source": [
         | 
| 365 | 
            +
                "## VQGAN-JAX model"
         | 
| 366 | 
            +
               ]
         | 
| 367 | 
            +
              },
         | 
| 368 | 
            +
              {
         | 
| 369 | 
            +
               "cell_type": "code",
         | 
| 370 | 
            +
               "execution_count": null,
         | 
| 371 | 
            +
               "id": "2fcf01d7",
         | 
| 372 | 
            +
               "metadata": {},
         | 
| 373 | 
            +
               "outputs": [],
         | 
| 374 | 
            +
               "source": [
         | 
| 375 | 
            +
                "from vqgan_jax.modeling_flax_vqgan import VQModel"
         | 
| 376 | 
            +
               ]
         | 
| 377 | 
            +
              },
         | 
| 378 | 
            +
              {
         | 
| 379 | 
            +
               "cell_type": "markdown",
         | 
| 380 | 
            +
               "id": "9daa636d",
         | 
| 381 | 
            +
               "metadata": {},
         | 
| 382 | 
            +
               "source": [
         | 
| 383 | 
            +
                "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
         | 
| 384 | 
            +
               ]
         | 
| 385 | 
            +
              },
         | 
| 386 | 
            +
              {
         | 
| 387 | 
            +
               "cell_type": "code",
         | 
| 388 | 
            +
               "execution_count": null,
         | 
| 389 | 
            +
               "id": "47a8b818",
         | 
| 390 | 
            +
               "metadata": {
         | 
| 391 | 
            +
                "scrolled": true
         | 
| 392 | 
            +
               },
         | 
| 393 | 
            +
               "outputs": [],
         | 
| 394 | 
            +
               "source": [
         | 
| 395 | 
            +
                "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
         | 
| 396 | 
            +
               ]
         | 
| 397 | 
            +
              },
         | 
| 398 | 
            +
              {
         | 
| 399 | 
            +
               "cell_type": "markdown",
         | 
| 400 | 
            +
               "id": "62ad01c3",
         | 
| 401 | 
            +
               "metadata": {},
         | 
| 402 | 
            +
               "source": [
         | 
| 403 | 
            +
                "## Encoding"
         | 
| 404 | 
            +
               ]
         | 
| 405 | 
            +
              },
         | 
| 406 | 
            +
              {
         | 
| 407 | 
            +
               "cell_type": "markdown",
         | 
| 408 | 
            +
               "id": "20357f74",
         | 
| 409 | 
            +
               "metadata": {},
         | 
| 410 | 
            +
               "source": [
         | 
| 411 | 
            +
                "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, which will be jitted on first use."
         | 
| 412 | 
            +
               ]
         | 
| 413 | 
            +
              },
         | 
| 414 | 
            +
              {
         | 
| 415 | 
            +
               "cell_type": "code",
         | 
| 416 | 
            +
               "execution_count": null,
         | 
| 417 | 
            +
               "id": "6686b004",
         | 
| 418 | 
            +
               "metadata": {},
         | 
| 419 | 
            +
               "outputs": [],
         | 
| 420 | 
            +
               "source": [
         | 
| 421 | 
            +
                "from flax.training.common_utils import shard\n",
         | 
| 422 | 
            +
                "from functools import partial"
         | 
| 423 | 
            +
               ]
         | 
| 424 | 
            +
              },
         | 
| 425 | 
            +
              {
         | 
| 426 | 
            +
               "cell_type": "code",
         | 
| 427 | 
            +
               "execution_count": null,
         | 
| 428 | 
            +
               "id": "322a4619",
         | 
| 429 | 
            +
               "metadata": {},
         | 
| 430 | 
            +
               "outputs": [],
         | 
| 431 | 
            +
               "source": [
         | 
| 432 | 
            +
                "@partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 433 | 
            +
                "def encode(batch):\n",
         | 
| 434 | 
            +
                "    # Not sure if we should `replicate` params, does not seem to have any effect\n",
         | 
| 435 | 
            +
                "    _, indices = model.encode(batch)\n",
         | 
| 436 | 
            +
                "    return indices"
         | 
| 437 | 
            +
               ]
         | 
| 438 | 
            +
              },
         | 
| 439 | 
            +
              {
         | 
| 440 | 
            +
               "cell_type": "markdown",
         | 
| 441 | 
            +
               "id": "14375a41",
         | 
| 442 | 
            +
               "metadata": {},
         | 
| 443 | 
            +
               "source": [
         | 
| 444 | 
            +
                "### Encoding loop"
         | 
| 445 | 
            +
               ]
         | 
| 446 | 
            +
              },
         | 
| 447 | 
            +
              {
         | 
| 448 | 
            +
               "cell_type": "code",
         | 
| 449 | 
            +
               "execution_count": null,
         | 
| 450 | 
            +
               "id": "ff6c10d4",
         | 
| 451 | 
            +
               "metadata": {},
         | 
| 452 | 
            +
               "outputs": [],
         | 
| 453 | 
            +
               "source": [
         | 
| 454 | 
            +
                "import os\n",
         | 
| 455 | 
            +
                "import pandas as pd\n",
         | 
| 456 | 
            +
                "\n",
         | 
| 457 | 
            +
                "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
         | 
| 458 | 
            +
                "    output_dir.mkdir(parents=True, exist_ok=True)\n",
         | 
| 459 | 
            +
                "        \n",
         | 
| 460 | 
            +
                "    # Saving strategy:\n",
         | 
| 461 | 
            +
                "    # - Create a new file every so often to prevent excessive file seeking.\n",
         | 
| 462 | 
            +
                "    # - Save each batch after processing.\n",
         | 
| 463 | 
            +
                "    # - Keep the file open until we are done with it.\n",
         | 
| 464 | 
            +
                "    file = None        \n",
         | 
| 465 | 
            +
                "    for n, batch in enumerate(tqdm(iter(dataloader))):\n",
         | 
| 466 | 
            +
                "        if (n % save_every == 0):\n",
         | 
| 467 | 
            +
                "            if file is not None:\n",
         | 
| 468 | 
            +
                "                file.close()\n",
         | 
| 469 | 
            +
                "            split_num = n // save_every\n",
         | 
| 470 | 
            +
                "            file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
         | 
| 471 | 
            +
                "\n",
         | 
| 472 | 
            +
                "        images = batch[\"image\"].numpy()\n",
         | 
| 473 | 
            +
                "        images = shard(images.squeeze())\n",
         | 
| 474 | 
            +
                "        encoded = encode(images)\n",
         | 
| 475 | 
            +
                "        encoded = encoded.reshape(-1, encoded.shape[-1])\n",
         | 
| 476 | 
            +
                "\n",
         | 
| 477 | 
            +
                "        keys = batch[\"key\"]\n",
         | 
| 478 | 
            +
                "        captions = batch[\"caption\"]\n",
         | 
| 479 | 
            +
                "\n",
         | 
| 480 | 
            +
                "        encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
         | 
| 481 | 
            +
                "        batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
         | 
| 482 | 
            +
                "        batch_df.to_json(file, orient='records', lines=True)"
         | 
| 483 | 
            +
               ]
         | 
| 484 | 
            +
              },
         | 
| 485 | 
            +
              {
         | 
| 486 | 
            +
               "cell_type": "markdown",
         | 
| 487 | 
            +
               "id": "09ff75a3",
         | 
| 488 | 
            +
               "metadata": {},
         | 
| 489 | 
            +
               "source": [
         | 
| 490 | 
            +
                "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
         | 
| 491 | 
            +
               ]
         | 
| 492 | 
            +
              },
         | 
| 493 | 
            +
              {
         | 
| 494 | 
            +
               "cell_type": "code",
         | 
| 495 | 
            +
               "execution_count": null,
         | 
| 496 | 
            +
               "id": "96222bb4",
         | 
| 497 | 
            +
               "metadata": {},
         | 
| 498 | 
            +
               "outputs": [],
         | 
| 499 | 
            +
               "source": [
         | 
| 500 | 
            +
                "save_every = 318"
         | 
| 501 | 
            +
               ]
         | 
| 502 | 
            +
              },
         | 
| 503 | 
            +
              {
         | 
| 504 | 
            +
               "cell_type": "code",
         | 
| 505 | 
            +
               "execution_count": null,
         | 
| 506 | 
            +
               "id": "7704863d",
         | 
| 507 | 
            +
               "metadata": {},
         | 
| 508 | 
            +
               "outputs": [
         | 
| 509 | 
            +
                {
         | 
| 510 | 
            +
                 "name": "stderr",
         | 
| 511 | 
            +
                 "output_type": "stream",
         | 
| 512 | 
            +
                 "text": [
         | 
| 513 | 
            +
                  "28it [01:17,  1.60s/it]"
         | 
| 514 | 
            +
                 ]
         | 
| 515 | 
            +
                }
         | 
| 516 | 
            +
               ],
         | 
| 517 | 
            +
               "source": [
         | 
| 518 | 
            +
                "encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)"
         | 
| 519 | 
            +
               ]
         | 
| 520 | 
            +
              },
         | 
| 521 | 
            +
              {
         | 
| 522 | 
            +
               "cell_type": "markdown",
         | 
| 523 | 
            +
               "id": "e266a70a",
         | 
| 524 | 
            +
               "metadata": {},
         | 
| 525 | 
            +
               "source": [
         | 
| 526 | 
            +
                "This is ~10-15x slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13."
         | 
| 527 | 
            +
               ]
         | 
| 528 | 
            +
              },
         | 
| 529 | 
            +
              {
         | 
| 530 | 
            +
               "cell_type": "markdown",
         | 
| 531 | 
            +
               "id": "8953dd84",
         | 
| 532 | 
            +
               "metadata": {},
         | 
| 533 | 
            +
               "source": [
         | 
| 534 | 
            +
                "----"
         | 
| 535 | 
            +
               ]
         | 
| 536 | 
            +
              }
         | 
| 537 | 
            +
             ],
         | 
| 538 | 
            +
             "metadata": {
         | 
| 539 | 
            +
              "interpreter": {
         | 
| 540 | 
            +
               "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
         | 
| 541 | 
            +
              },
         | 
| 542 | 
            +
              "kernelspec": {
         | 
| 543 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 544 | 
            +
               "language": "python",
         | 
| 545 | 
            +
               "name": "python3"
         | 
| 546 | 
            +
              },
         | 
| 547 | 
            +
              "language_info": {
         | 
| 548 | 
            +
               "codemirror_mode": {
         | 
| 549 | 
            +
                "name": "ipython",
         | 
| 550 | 
            +
                "version": 3
         | 
| 551 | 
            +
               },
         | 
| 552 | 
            +
               "file_extension": ".py",
         | 
| 553 | 
            +
               "mimetype": "text/x-python",
         | 
| 554 | 
            +
               "name": "python",
         | 
| 555 | 
            +
               "nbconvert_exporter": "python",
         | 
| 556 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 557 | 
            +
               "version": "3.8.10"
         | 
| 558 | 
            +
              }
         | 
| 559 | 
            +
             },
         | 
| 560 | 
            +
             "nbformat": 4,
         | 
| 561 | 
            +
             "nbformat_minor": 5
         | 
| 562 | 
            +
            }
         | 
    	
        dev/encoding/vqgan-jax-encoding-webdataset.ipynb
    ADDED
    
    | @@ -0,0 +1,461 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "markdown",
         | 
| 5 | 
            +
               "id": "d0b72877",
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "source": [
         | 
| 8 | 
            +
                "# VQGAN JAX Encoding for `webdataset`"
         | 
| 9 | 
            +
               ]
         | 
| 10 | 
            +
              },
         | 
| 11 | 
            +
              {
         | 
| 12 | 
            +
               "cell_type": "markdown",
         | 
| 13 | 
            +
               "id": "ba7b31e6",
         | 
| 14 | 
            +
               "metadata": {},
         | 
| 15 | 
            +
               "source": [
         | 
| 16 | 
            +
                "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
         | 
| 17 | 
            +
                "\n",
         | 
| 18 | 
            +
                "This example uses a small subset of YFCC100M we created for testing, but it should be easy to adapt to any other image/caption dataset in the `webdataset` format."
         | 
| 19 | 
            +
               ]
         | 
| 20 | 
            +
              },
         | 
| 21 | 
            +
              {
         | 
| 22 | 
            +
               "cell_type": "code",
         | 
| 23 | 
            +
               "execution_count": null,
         | 
| 24 | 
            +
               "id": "3b59489e",
         | 
| 25 | 
            +
               "metadata": {},
         | 
| 26 | 
            +
               "outputs": [],
         | 
| 27 | 
            +
               "source": [
         | 
| 28 | 
            +
                "import numpy as np\n",
         | 
| 29 | 
            +
                "from tqdm import tqdm\n",
         | 
| 30 | 
            +
                "\n",
         | 
| 31 | 
            +
                "import torch\n",
         | 
| 32 | 
            +
                "import torchvision.transforms as T\n",
         | 
| 33 | 
            +
                "import torchvision.transforms.functional as TF\n",
         | 
| 34 | 
            +
                "from torchvision.transforms import InterpolationMode\n",
         | 
| 35 | 
            +
                "import math\n",
         | 
| 36 | 
            +
                "\n",
         | 
| 37 | 
            +
                "import webdataset as wds\n",
         | 
| 38 | 
            +
                "\n",
         | 
| 39 | 
            +
                "import jax\n",
         | 
| 40 | 
            +
                "from jax import pmap"
         | 
| 41 | 
            +
               ]
         | 
| 42 | 
            +
              },
         | 
| 43 | 
            +
              {
         | 
| 44 | 
            +
               "cell_type": "markdown",
         | 
| 45 | 
            +
               "id": "c7c4c1e6",
         | 
| 46 | 
            +
               "metadata": {},
         | 
| 47 | 
            +
               "source": [
         | 
| 48 | 
            +
                "## Dataset and Parameters"
         | 
| 49 | 
            +
               ]
         | 
| 50 | 
            +
              },
         | 
| 51 | 
            +
              {
         | 
| 52 | 
            +
               "cell_type": "markdown",
         | 
| 53 | 
            +
               "id": "9822850f",
         | 
| 54 | 
            +
               "metadata": {},
         | 
| 55 | 
            +
               "source": [
         | 
| 56 | 
            +
                "The following is the list of shards we'll process. We hardcode the length of data so that we can see nice progress bars using `tqdm`."
         | 
| 57 | 
            +
               ]
         | 
| 58 | 
            +
              },
         | 
| 59 | 
            +
              {
         | 
| 60 | 
            +
               "cell_type": "code",
         | 
| 61 | 
            +
               "execution_count": null,
         | 
| 62 | 
            +
               "id": "1265dbfe",
         | 
| 63 | 
            +
               "metadata": {},
         | 
| 64 | 
            +
               "outputs": [],
         | 
| 65 | 
            +
               "source": [
         | 
| 66 | 
            +
                "shards = 'https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset/resolve/main/data/shard-{0000..0008}.tar'\n",
         | 
| 67 | 
            +
                "length = 8320"
         | 
| 68 | 
            +
               ]
         | 
| 69 | 
            +
              },
         | 
| 70 | 
            +
              {
         | 
| 71 | 
            +
               "cell_type": "markdown",
         | 
| 72 | 
            +
               "id": "7e38fa14",
         | 
| 73 | 
            +
               "metadata": {},
         | 
| 74 | 
            +
               "source": [
         | 
| 75 | 
            +
                "If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
         | 
| 76 | 
            +
               ]
         | 
| 77 | 
            +
              },
         | 
| 78 | 
            +
              {
         | 
| 79 | 
            +
               "cell_type": "code",
         | 
| 80 | 
            +
               "execution_count": null,
         | 
| 81 | 
            +
               "id": "4c8c5960",
         | 
| 82 | 
            +
               "metadata": {},
         | 
| 83 | 
            +
               "outputs": [],
         | 
| 84 | 
            +
               "source": [
         | 
| 85 | 
            +
                "# Enable curl retries to try to work around temporary network / server errors.\n",
         | 
| 86 | 
            +
                "# This shouldn't be necessary when using reliable servers.\n",
         | 
| 87 | 
            +
                "# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
         | 
| 88 | 
            +
               ]
         | 
| 89 | 
            +
              },
         | 
| 90 | 
            +
              {
         | 
| 91 | 
            +
               "cell_type": "code",
         | 
| 92 | 
            +
               "execution_count": null,
         | 
| 93 | 
            +
               "id": "13c6631b",
         | 
| 94 | 
            +
               "metadata": {},
         | 
| 95 | 
            +
               "outputs": [],
         | 
| 96 | 
            +
               "source": [
         | 
| 97 | 
            +
                "from pathlib import Path\n",
         | 
| 98 | 
            +
                "\n",
         | 
| 99 | 
            +
                "# Output directory for encoded files\n",
         | 
| 100 | 
            +
                "encoded_output = Path.home()/'data'/'wds'/'encoded'\n",
         | 
| 101 | 
            +
                "\n",
         | 
| 102 | 
            +
                "batch_size = 128           # Per device\n",
         | 
| 103 | 
            +
                "num_workers = 8            # For parallel processing"
         | 
| 104 | 
            +
               ]
         | 
| 105 | 
            +
              },
         | 
| 106 | 
            +
              {
         | 
| 107 | 
            +
               "cell_type": "code",
         | 
| 108 | 
            +
               "execution_count": null,
         | 
| 109 | 
            +
               "id": "3435fb85",
         | 
| 110 | 
            +
               "metadata": {},
         | 
| 111 | 
            +
               "outputs": [],
         | 
| 112 | 
            +
               "source": [
         | 
| 113 | 
            +
                "bs = batch_size * jax.device_count()    # You can use a smaller size while testing\n",
         | 
| 114 | 
            +
                "batches = math.ceil(length / bs)"
         | 
| 115 | 
            +
               ]
         | 
| 116 | 
            +
              },
         | 
| 117 | 
            +
              {
         | 
| 118 | 
            +
               "cell_type": "markdown",
         | 
| 119 | 
            +
               "id": "88598e4b",
         | 
| 120 | 
            +
               "metadata": {},
         | 
| 121 | 
            +
               "source": [
         | 
| 122 | 
            +
                "Image processing"
         | 
| 123 | 
            +
               ]
         | 
| 124 | 
            +
              },
         | 
| 125 | 
            +
              {
         | 
| 126 | 
            +
               "cell_type": "code",
         | 
| 127 | 
            +
               "execution_count": null,
         | 
| 128 | 
            +
               "id": "669b35df",
         | 
| 129 | 
            +
               "metadata": {},
         | 
| 130 | 
            +
               "outputs": [],
         | 
| 131 | 
            +
               "source": [
         | 
| 132 | 
            +
                "def center_crop(image, max_size=256):\n",
         | 
| 133 | 
            +
                "    # Note: we allow upscaling too. We should exclude small images.    \n",
         | 
| 134 | 
            +
                "    image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
         | 
| 135 | 
            +
                "    image = TF.center_crop(image, output_size=2 * [max_size])\n",
         | 
| 136 | 
            +
                "    return image\n",
         | 
| 137 | 
            +
                "\n",
         | 
| 138 | 
            +
                "preprocess_image = T.Compose([\n",
         | 
| 139 | 
            +
                "    center_crop,\n",
         | 
| 140 | 
            +
                "    T.ToTensor(),\n",
         | 
| 141 | 
            +
                "    lambda t: t.permute(1, 2, 0)   # Reorder, we need dimensions last\n",
         | 
| 142 | 
            +
                "])"
         | 
| 143 | 
            +
               ]
         | 
| 144 | 
            +
              },
         | 
| 145 | 
            +
              {
         | 
| 146 | 
            +
               "cell_type": "markdown",
         | 
| 147 | 
            +
               "id": "a185e90c",
         | 
| 148 | 
            +
               "metadata": {},
         | 
| 149 | 
            +
               "source": [
         | 
| 150 | 
            +
                "Caption preparation.\n",
         | 
| 151 | 
            +
                "\n",
         | 
| 152 | 
            +
                "Note that we receive the contents of the `json` structure, which will be replaced by the string we return.\n",
         | 
| 153 | 
            +
                "If we want to keep other fields inside `json`, we can add `caption` as a new field."
         | 
| 154 | 
            +
               ]
         | 
| 155 | 
            +
              },
         | 
| 156 | 
            +
              {
         | 
| 157 | 
            +
               "cell_type": "code",
         | 
| 158 | 
            +
               "execution_count": null,
         | 
| 159 | 
            +
               "id": "423ee10e",
         | 
| 160 | 
            +
               "metadata": {},
         | 
| 161 | 
            +
               "outputs": [],
         | 
| 162 | 
            +
               "source": [
         | 
| 163 | 
            +
                "def create_caption(item):\n",
         | 
| 164 | 
            +
                "    title = item['title_clean'].strip()\n",
         | 
| 165 | 
            +
                "    description = item['description_clean'].strip()\n",
         | 
| 166 | 
            +
                "    if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
         | 
| 167 | 
            +
                "    return f'{title} {description}'"
         | 
| 168 | 
            +
               ]
         | 
| 169 | 
            +
              },
         | 
| 170 | 
            +
              {
         | 
| 171 | 
            +
               "cell_type": "markdown",
         | 
| 172 | 
            +
               "id": "8d3a95db",
         | 
| 173 | 
            +
               "metadata": {},
         | 
| 174 | 
            +
               "source": [
         | 
| 175 | 
            +
                "When an error occurs (a download is disconnected, an image cannot be decoded, etc.) the process stops with an exception. We can use one of the exception handlers provided by the `webdataset` library, such as `wds.warn_and_continue` or `wds.ignore_and_continue`, to ignore the offending entry and keep iterating.\n",
         | 
| 176 | 
            +
                "\n",
         | 
| 177 | 
            +
                "**IMPORTANT WARNING:** Do not use error handlers to ignore exceptions until you have tested that your processing pipeline works fine. Otherwise, the process will continue trying to find a valid entry, and it will consume your whole dataset without doing any work.\n",
         | 
| 178 | 
            +
                "\n",
         | 
| 179 | 
            +
                "We can also create our custom exception handler as demonstrated here:"
         | 
| 180 | 
            +
               ]
         | 
| 181 | 
            +
              },
         | 
| 182 | 
            +
              {
         | 
| 183 | 
            +
               "cell_type": "code",
         | 
| 184 | 
            +
               "execution_count": null,
         | 
| 185 | 
            +
               "id": "369d9719",
         | 
| 186 | 
            +
               "metadata": {},
         | 
| 187 | 
            +
               "outputs": [],
         | 
| 188 | 
            +
               "source": [
         | 
| 189 | 
            +
                "# UNUSED - Log exceptions to a file\n",
         | 
| 190 | 
            +
                "def ignore_and_log(exn):\n",
         | 
| 191 | 
            +
                "    with open('errors.txt', 'a') as f:\n",
         | 
| 192 | 
            +
                "        f.write(f'{repr(exn)}\\n')\n",
         | 
| 193 | 
            +
                "    return True"
         | 
| 194 | 
            +
               ]
         | 
| 195 | 
            +
              },
         | 
| 196 | 
            +
              {
         | 
| 197 | 
            +
               "cell_type": "code",
         | 
| 198 | 
            +
               "execution_count": null,
         | 
| 199 | 
            +
               "id": "27de1414",
         | 
| 200 | 
            +
               "metadata": {},
         | 
| 201 | 
            +
               "outputs": [],
         | 
| 202 | 
            +
               "source": [
         | 
| 203 | 
            +
                "# Or simply use a built-in handler: `wds.warn_and_continue` (warns on each error) or `wds.ignore_and_continue` (silent)\n",
         | 
| 204 | 
            +
                "exception_handler = wds.warn_and_continue"
         | 
| 205 | 
            +
               ]
         | 
| 206 | 
            +
              },
         | 
| 207 | 
            +
              {
         | 
| 208 | 
            +
               "cell_type": "code",
         | 
| 209 | 
            +
               "execution_count": null,
         | 
| 210 | 
            +
               "id": "5149b6d5",
         | 
| 211 | 
            +
               "metadata": {},
         | 
| 212 | 
            +
               "outputs": [],
         | 
| 213 | 
            +
               "source": [
         | 
| 214 | 
            +
                "dataset = wds.WebDataset(shards,\n",
         | 
| 215 | 
            +
                "                         length=batches,              # Hint so `len` is implemented\n",
         | 
| 216 | 
            +
                "                         shardshuffle=False,          # Keep same order for encoded files for easier bookkeeping. Set to `True` for training.\n",
         | 
| 217 | 
            +
                "                         handler=exception_handler,   # Ignore read errors instead of failing.\n",
         | 
| 218 | 
            +
                ")\n",
         | 
| 219 | 
            +
                "\n",
         | 
| 220 | 
            +
                "dataset = (dataset           \n",
         | 
| 221 | 
            +
                "      .decode('pil')                     # decode image with PIL\n",
         | 
| 222 | 
            +
                "#       .map_dict(jpg=preprocess_image, json=create_caption, handler=exception_handler)    # Process fields with functions defined above\n",
         | 
| 223 | 
            +
                "      .map_dict(jpg=preprocess_image, json=create_caption)    # Process fields with functions defined above\n",
         | 
| 224 | 
            +
                "      .to_tuple('__key__', 'jpg', 'json') # filter to keep only key (for reference), image, caption.\n",
         | 
| 225 | 
            +
                "      .batched(bs))                      # better to batch in the dataset (but we could also do it in the dataloader) - this arg does not affect speed and we could remove it"
         | 
| 226 | 
            +
               ]
         | 
| 227 | 
            +
              },
         | 
| 228 | 
            +
              {
         | 
| 229 | 
            +
               "cell_type": "code",
         | 
| 230 | 
            +
               "execution_count": null,
         | 
| 231 | 
            +
               "id": "8cac98cb",
         | 
| 232 | 
            +
               "metadata": {
         | 
| 233 | 
            +
                "scrolled": true
         | 
| 234 | 
            +
               },
         | 
| 235 | 
            +
               "outputs": [],
         | 
| 236 | 
            +
               "source": [
         | 
| 237 | 
            +
                "%%time\n",
         | 
| 238 | 
            +
                "keys, images, captions = next(iter(dataset))"
         | 
| 239 | 
            +
               ]
         | 
| 240 | 
            +
              },
         | 
| 241 | 
            +
              {
         | 
| 242 | 
            +
               "cell_type": "code",
         | 
| 243 | 
            +
               "execution_count": null,
         | 
| 244 | 
            +
               "id": "cd268fbf",
         | 
| 245 | 
            +
               "metadata": {},
         | 
| 246 | 
            +
               "outputs": [],
         | 
| 247 | 
            +
               "source": [
         | 
| 248 | 
            +
                "images.shape"
         | 
| 249 | 
            +
               ]
         | 
| 250 | 
            +
              },
         | 
| 251 | 
            +
              {
         | 
| 252 | 
            +
               "cell_type": "code",
         | 
| 253 | 
            +
               "execution_count": null,
         | 
| 254 | 
            +
               "id": "c24693c0",
         | 
| 255 | 
            +
               "metadata": {},
         | 
| 256 | 
            +
               "outputs": [],
         | 
| 257 | 
            +
               "source": [
         | 
| 258 | 
            +
                "T.ToPILImage()(images[0].permute(2, 0, 1))"
         | 
| 259 | 
            +
               ]
         | 
| 260 | 
            +
              },
         | 
| 261 | 
            +
              {
         | 
| 262 | 
            +
               "cell_type": "markdown",
         | 
| 263 | 
            +
               "id": "44d50a51",
         | 
| 264 | 
            +
               "metadata": {},
         | 
| 265 | 
            +
               "source": [
         | 
| 266 | 
            +
                "### Torch DataLoader"
         | 
| 267 | 
            +
               ]
         | 
| 268 | 
            +
              },
         | 
| 269 | 
            +
              {
         | 
| 270 | 
            +
               "cell_type": "code",
         | 
| 271 | 
            +
               "execution_count": null,
         | 
| 272 | 
            +
               "id": "e2df5e13",
         | 
| 273 | 
            +
               "metadata": {},
         | 
| 274 | 
            +
               "outputs": [],
         | 
| 275 | 
            +
               "source": [
         | 
| 276 | 
            +
                "dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
         | 
| 277 | 
            +
               ]
         | 
| 278 | 
            +
              },
         | 
| 279 | 
            +
              {
         | 
| 280 | 
            +
               "cell_type": "markdown",
         | 
| 281 | 
            +
               "id": "a354472b",
         | 
| 282 | 
            +
               "metadata": {},
         | 
| 283 | 
            +
               "source": [
         | 
| 284 | 
            +
                "## VQGAN-JAX model"
         | 
| 285 | 
            +
               ]
         | 
| 286 | 
            +
              },
         | 
| 287 | 
            +
              {
         | 
| 288 | 
            +
               "cell_type": "code",
         | 
| 289 | 
            +
               "execution_count": null,
         | 
| 290 | 
            +
               "id": "2fcf01d7",
         | 
| 291 | 
            +
               "metadata": {},
         | 
| 292 | 
            +
               "outputs": [],
         | 
| 293 | 
            +
               "source": [
         | 
| 294 | 
            +
                "from vqgan_jax.modeling_flax_vqgan import VQModel"
         | 
| 295 | 
            +
               ]
         | 
| 296 | 
            +
              },
         | 
| 297 | 
            +
              {
         | 
| 298 | 
            +
               "cell_type": "markdown",
         | 
| 299 | 
            +
               "id": "9daa636d",
         | 
| 300 | 
            +
               "metadata": {},
         | 
| 301 | 
            +
               "source": [
         | 
| 302 | 
            +
                "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
         | 
| 303 | 
            +
               ]
         | 
| 304 | 
            +
              },
         | 
| 305 | 
            +
              {
         | 
| 306 | 
            +
               "cell_type": "code",
         | 
| 307 | 
            +
               "execution_count": null,
         | 
| 308 | 
            +
               "id": "47a8b818",
         | 
| 309 | 
            +
               "metadata": {
         | 
| 310 | 
            +
                "scrolled": true
         | 
| 311 | 
            +
               },
         | 
| 312 | 
            +
               "outputs": [],
         | 
| 313 | 
            +
               "source": [
         | 
| 314 | 
            +
                "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
         | 
| 315 | 
            +
               ]
         | 
| 316 | 
            +
              },
         | 
| 317 | 
            +
              {
         | 
| 318 | 
            +
               "cell_type": "markdown",
         | 
| 319 | 
            +
               "id": "62ad01c3",
         | 
| 320 | 
            +
               "metadata": {},
         | 
| 321 | 
            +
               "source": [
         | 
| 322 | 
            +
                "## Encoding"
         | 
| 323 | 
            +
               ]
         | 
| 324 | 
            +
              },
         | 
| 325 | 
            +
              {
         | 
| 326 | 
            +
               "cell_type": "markdown",
         | 
| 327 | 
            +
               "id": "20357f74",
         | 
| 328 | 
            +
               "metadata": {},
         | 
| 329 | 
            +
               "source": [
         | 
| 330 | 
            +
                "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, which will be jitted on first use."
         | 
| 331 | 
            +
               ]
         | 
| 332 | 
            +
              },
         | 
| 333 | 
            +
              {
         | 
| 334 | 
            +
               "cell_type": "code",
         | 
| 335 | 
            +
               "execution_count": null,
         | 
| 336 | 
            +
               "id": "6686b004",
         | 
| 337 | 
            +
               "metadata": {},
         | 
| 338 | 
            +
               "outputs": [],
         | 
| 339 | 
            +
               "source": [
         | 
| 340 | 
            +
                "from flax.training.common_utils import shard\n",
         | 
| 341 | 
            +
                "from functools import partial"
         | 
| 342 | 
            +
               ]
         | 
| 343 | 
            +
              },
         | 
| 344 | 
            +
              {
         | 
| 345 | 
            +
               "cell_type": "code",
         | 
| 346 | 
            +
               "execution_count": null,
         | 
| 347 | 
            +
               "id": "322a4619",
         | 
| 348 | 
            +
               "metadata": {},
         | 
| 349 | 
            +
               "outputs": [],
         | 
| 350 | 
            +
               "source": [
         | 
| 351 | 
            +
                "@partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 352 | 
            +
                "def encode(batch):\n",
         | 
| 353 | 
            +
                "    # Not sure if we should `replicate` params, does not seem to have any effect\n",
         | 
| 354 | 
            +
                "    _, indices = model.encode(batch)\n",
         | 
| 355 | 
            +
                "    return indices"
         | 
| 356 | 
            +
               ]
         | 
| 357 | 
            +
              },
         | 
| 358 | 
            +
              {
         | 
| 359 | 
            +
               "cell_type": "markdown",
         | 
| 360 | 
            +
               "id": "14375a41",
         | 
| 361 | 
            +
               "metadata": {},
         | 
| 362 | 
            +
               "source": [
         | 
| 363 | 
            +
                "### Encoding loop"
         | 
| 364 | 
            +
               ]
         | 
| 365 | 
            +
              },
         | 
| 366 | 
            +
              {
         | 
| 367 | 
            +
               "cell_type": "code",
         | 
| 368 | 
            +
               "execution_count": null,
         | 
| 369 | 
            +
               "id": "ff6c10d4",
         | 
| 370 | 
            +
               "metadata": {},
         | 
| 371 | 
            +
               "outputs": [],
         | 
| 372 | 
            +
               "source": [
         | 
| 373 | 
            +
                "import os\n",
         | 
| 374 | 
            +
                "import pandas as pd\n",
         | 
| 375 | 
            +
                "\n",
         | 
| 376 | 
            +
                "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
         | 
| 377 | 
            +
                "    output_dir.mkdir(parents=True, exist_ok=True)\n",
         | 
| 378 | 
            +
                "\n",
         | 
| 379 | 
            +
                "    # Saving strategy:\n",
         | 
| 380 | 
            +
                "    # - Create a new file every so often to prevent excessive file seeking.\n",
         | 
| 381 | 
            +
                "    # - Save each batch after processing.\n",
         | 
| 382 | 
            +
                "    # - Keep the file open until we are done with it.\n",
         | 
| 383 | 
            +
                "    file = None        \n",
         | 
| 384 | 
            +
                "    for n, (keys, images, captions) in enumerate(tqdm(dataloader)):\n",
         | 
| 385 | 
            +
                "        if (n % save_every == 0):\n",
         | 
| 386 | 
            +
                "            if file is not None:\n",
         | 
| 387 | 
            +
                "                file.close()\n",
         | 
| 388 | 
            +
                "            split_num = n // save_every\n",
         | 
| 389 | 
            +
                "            file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
         | 
| 390 | 
            +
                "\n",
         | 
| 391 | 
            +
                "        images = shard(images.numpy().squeeze())\n",
         | 
| 392 | 
            +
                "        encoded = encode(images)\n",
         | 
| 393 | 
            +
                "        encoded = encoded.reshape(-1, encoded.shape[-1])\n",
         | 
| 394 | 
            +
                "\n",
         | 
| 395 | 
            +
                "        encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
         | 
| 396 | 
            +
                "        batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
         | 
| 397 | 
            +
                "        batch_df.to_json(file, orient='records', lines=True)"
         | 
| 398 | 
            +
               ]
         | 
| 399 | 
            +
              },
         | 
| 400 | 
            +
              {
         | 
| 401 | 
            +
               "cell_type": "markdown",
         | 
| 402 | 
            +
               "id": "09ff75a3",
         | 
| 403 | 
            +
               "metadata": {},
         | 
| 404 | 
            +
               "source": [
         | 
| 405 | 
            +
                "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
         | 
| 406 | 
            +
               ]
         | 
| 407 | 
            +
              },
         | 
| 408 | 
            +
              {
         | 
| 409 | 
            +
               "cell_type": "code",
         | 
| 410 | 
            +
               "execution_count": null,
         | 
| 411 | 
            +
               "id": "96222bb4",
         | 
| 412 | 
            +
               "metadata": {},
         | 
| 413 | 
            +
               "outputs": [],
         | 
| 414 | 
            +
               "source": [
         | 
| 415 | 
            +
                "save_every = 318"
         | 
| 416 | 
            +
               ]
         | 
| 417 | 
            +
              },
         | 
| 418 | 
            +
              {
         | 
| 419 | 
            +
               "cell_type": "code",
         | 
| 420 | 
            +
               "execution_count": null,
         | 
| 421 | 
            +
               "id": "7704863d",
         | 
| 422 | 
            +
               "metadata": {},
         | 
| 423 | 
            +
               "outputs": [],
         | 
| 424 | 
            +
               "source": [
         | 
| 425 | 
            +
                "encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
         | 
| 426 | 
            +
               ]
         | 
| 427 | 
            +
              },
         | 
| 428 | 
            +
              {
         | 
| 429 | 
            +
               "cell_type": "markdown",
         | 
| 430 | 
            +
               "id": "8953dd84",
         | 
| 431 | 
            +
               "metadata": {},
         | 
| 432 | 
            +
               "source": [
         | 
| 433 | 
            +
                "----"
         | 
| 434 | 
            +
               ]
         | 
| 435 | 
            +
              }
         | 
| 436 | 
            +
             ],
         | 
| 437 | 
            +
             "metadata": {
         | 
| 438 | 
            +
              "interpreter": {
         | 
| 439 | 
            +
               "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
         | 
| 440 | 
            +
              },
         | 
| 441 | 
            +
              "kernelspec": {
         | 
| 442 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 443 | 
            +
               "language": "python",
         | 
| 444 | 
            +
               "name": "python3"
         | 
| 445 | 
            +
              },
         | 
| 446 | 
            +
              "language_info": {
         | 
| 447 | 
            +
               "codemirror_mode": {
         | 
| 448 | 
            +
                "name": "ipython",
         | 
| 449 | 
            +
                "version": 3
         | 
| 450 | 
            +
               },
         | 
| 451 | 
            +
               "file_extension": ".py",
         | 
| 452 | 
            +
               "mimetype": "text/x-python",
         | 
| 453 | 
            +
               "name": "python",
         | 
| 454 | 
            +
               "nbconvert_exporter": "python",
         | 
| 455 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 456 | 
            +
               "version": "3.8.10"
         | 
| 457 | 
            +
              }
         | 
| 458 | 
            +
             },
         | 
| 459 | 
            +
             "nbformat": 4,
         | 
| 460 | 
            +
             "nbformat_minor": 5
         | 
| 461 | 
            +
            }
         | 
    	
        dev/inference/dalle_mini
    DELETED
    
    | @@ -1 +0,0 @@ | |
| 1 | 
            -
            ../../dalle_mini
         | 
|  | |
|  | 
    	
        dev/inference/inference_pipeline.ipynb
    CHANGED
    
    | @@ -6,7 +6,7 @@ | |
| 6 | 
             
                  "name": "DALL·E mini - Inference pipeline.ipynb",
         | 
| 7 | 
             
                  "provenance": [],
         | 
| 8 | 
             
                  "collapsed_sections": [],
         | 
| 9 | 
            -
                  "authorship_tag": " | 
| 10 | 
             
                  "include_colab_link": true
         | 
| 11 | 
             
                },
         | 
| 12 | 
             
                "kernelspec": {
         | 
| @@ -22,6 +22,7 @@ | |
| 22 | 
             
                    "49304912717a4995ae45d04a59d1f50e": {
         | 
| 23 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 24 | 
             
                      "model_name": "HBoxModel",
         | 
|  | |
| 25 | 
             
                      "state": {
         | 
| 26 | 
             
                        "_view_name": "HBoxView",
         | 
| 27 | 
             
                        "_dom_classes": [],
         | 
| @@ -42,6 +43,7 @@ | |
| 42 | 
             
                    "5fd9f97986024e8db560a6737ade9e2e": {
         | 
| 43 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 44 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 45 | 
             
                      "state": {
         | 
| 46 | 
             
                        "_view_name": "LayoutView",
         | 
| 47 | 
             
                        "grid_template_rows": null,
         | 
| @@ -93,6 +95,7 @@ | |
| 93 | 
             
                    "caced43e3a4c493b98fb07cb41db045c": {
         | 
| 94 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 95 | 
             
                      "model_name": "FloatProgressModel",
         | 
|  | |
| 96 | 
             
                      "state": {
         | 
| 97 | 
             
                        "_view_name": "ProgressView",
         | 
| 98 | 
             
                        "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
         | 
| @@ -116,6 +119,7 @@ | |
| 116 | 
             
                    "0acc161f2e9948b68b3fc4e57ef333c9": {
         | 
| 117 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 118 | 
             
                      "model_name": "HTMLModel",
         | 
|  | |
| 119 | 
             
                      "state": {
         | 
| 120 | 
             
                        "_view_name": "HTMLView",
         | 
| 121 | 
             
                        "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
         | 
| @@ -136,6 +140,7 @@ | |
| 136 | 
             
                    "40c54b9454d346aabd197f2bcf189467": {
         | 
| 137 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 138 | 
             
                      "model_name": "ProgressStyleModel",
         | 
|  | |
| 139 | 
             
                      "state": {
         | 
| 140 | 
             
                        "_view_name": "StyleView",
         | 
| 141 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
| @@ -151,6 +156,7 @@ | |
| 151 | 
             
                    "8b25334a48244a14aa9ba0176887e655": {
         | 
| 152 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 153 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 154 | 
             
                      "state": {
         | 
| 155 | 
             
                        "_view_name": "LayoutView",
         | 
| 156 | 
             
                        "grid_template_rows": null,
         | 
| @@ -202,6 +208,7 @@ | |
| 202 | 
             
                    "7e7c488f57fc4acb8d261e2db81d61f0": {
         | 
| 203 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 204 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
|  | |
| 205 | 
             
                      "state": {
         | 
| 206 | 
             
                        "_view_name": "StyleView",
         | 
| 207 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
| @@ -216,6 +223,7 @@ | |
| 216 | 
             
                    "72c401062a5348b1a366dffb5a403568": {
         | 
| 217 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 218 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 219 | 
             
                      "state": {
         | 
| 220 | 
             
                        "_view_name": "LayoutView",
         | 
| 221 | 
             
                        "grid_template_rows": null,
         | 
| @@ -267,6 +275,7 @@ | |
| 267 | 
             
                    "022c124dfff348f285335732781b0887": {
         | 
| 268 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 269 | 
             
                      "model_name": "HBoxModel",
         | 
|  | |
| 270 | 
             
                      "state": {
         | 
| 271 | 
             
                        "_view_name": "HBoxView",
         | 
| 272 | 
             
                        "_dom_classes": [],
         | 
| @@ -287,6 +296,7 @@ | |
| 287 | 
             
                    "a44e47e9d26c4deb81a5a11a9db92a9f": {
         | 
| 288 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 289 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 290 | 
             
                      "state": {
         | 
| 291 | 
             
                        "_view_name": "LayoutView",
         | 
| 292 | 
             
                        "grid_template_rows": null,
         | 
| @@ -338,6 +348,7 @@ | |
| 338 | 
             
                    "cd9c7016caae47c1b41fb2608c78b0bf": {
         | 
| 339 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 340 | 
             
                      "model_name": "FloatProgressModel",
         | 
|  | |
| 341 | 
             
                      "state": {
         | 
| 342 | 
             
                        "_view_name": "ProgressView",
         | 
| 343 | 
             
                        "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
         | 
| @@ -361,6 +372,7 @@ | |
| 361 | 
             
                    "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
         | 
| 362 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 363 | 
             
                      "model_name": "HTMLModel",
         | 
|  | |
| 364 | 
             
                      "state": {
         | 
| 365 | 
             
                        "_view_name": "HTMLView",
         | 
| 366 | 
             
                        "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
         | 
| @@ -381,6 +393,7 @@ | |
| 381 | 
             
                    "c22f207311cf4fb69bd9328eabfd4ebb": {
         | 
| 382 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 383 | 
             
                      "model_name": "ProgressStyleModel",
         | 
|  | |
| 384 | 
             
                      "state": {
         | 
| 385 | 
             
                        "_view_name": "StyleView",
         | 
| 386 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
| @@ -396,6 +409,7 @@ | |
| 396 | 
             
                    "5a38c6d83a264bedbf7efe6e97eba953": {
         | 
| 397 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 398 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 399 | 
             
                      "state": {
         | 
| 400 | 
             
                        "_view_name": "LayoutView",
         | 
| 401 | 
             
                        "grid_template_rows": null,
         | 
| @@ -447,6 +461,7 @@ | |
| 447 | 
             
                    "037563a7eadd4ac5abb7249a2914d346": {
         | 
| 448 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 449 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
|  | |
| 450 | 
             
                      "state": {
         | 
| 451 | 
             
                        "_view_name": "StyleView",
         | 
| 452 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
| @@ -461,6 +476,7 @@ | |
| 461 | 
             
                    "3975e7ed0b704990b1fa05909a9bb9b6": {
         | 
| 462 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 463 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 464 | 
             
                      "state": {
         | 
| 465 | 
             
                        "_view_name": "LayoutView",
         | 
| 466 | 
             
                        "grid_template_rows": null,
         | 
| @@ -512,6 +528,7 @@ | |
| 512 | 
             
                    "f9f1fdc3819a4142b85304cd3c6358a2": {
         | 
| 513 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 514 | 
             
                      "model_name": "HBoxModel",
         | 
|  | |
| 515 | 
             
                      "state": {
         | 
| 516 | 
             
                        "_view_name": "HBoxView",
         | 
| 517 | 
             
                        "_dom_classes": [],
         | 
| @@ -532,6 +549,7 @@ | |
| 532 | 
             
                    "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
         | 
| 533 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 534 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 535 | 
             
                      "state": {
         | 
| 536 | 
             
                        "_view_name": "LayoutView",
         | 
| 537 | 
             
                        "grid_template_rows": null,
         | 
| @@ -583,6 +601,7 @@ | |
| 583 | 
             
                    "29d42e94b3b34c86a117b623da68faed": {
         | 
| 584 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 585 | 
             
                      "model_name": "FloatProgressModel",
         | 
|  | |
| 586 | 
             
                      "state": {
         | 
| 587 | 
             
                        "_view_name": "ProgressView",
         | 
| 588 | 
             
                        "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
         | 
| @@ -606,6 +625,7 @@ | |
| 606 | 
             
                    "8b73de7dbdfe40dbbb39fb593520b984": {
         | 
| 607 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 608 | 
             
                      "model_name": "HTMLModel",
         | 
|  | |
| 609 | 
             
                      "state": {
         | 
| 610 | 
             
                        "_view_name": "HTMLView",
         | 
| 611 | 
             
                        "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
         | 
| @@ -626,6 +646,7 @@ | |
| 626 | 
             
                    "8ce4d20d004a4382afa0abdd3b1f7191": {
         | 
| 627 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 628 | 
             
                      "model_name": "ProgressStyleModel",
         | 
|  | |
| 629 | 
             
                      "state": {
         | 
| 630 | 
             
                        "_view_name": "StyleView",
         | 
| 631 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
| @@ -641,6 +662,7 @@ | |
| 641 | 
             
                    "efc4812245c8459c92e6436889b4f600": {
         | 
| 642 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 643 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 644 | 
             
                      "state": {
         | 
| 645 | 
             
                        "_view_name": "LayoutView",
         | 
| 646 | 
             
                        "grid_template_rows": null,
         | 
| @@ -692,6 +714,7 @@ | |
| 692 | 
             
                    "717ccef4df1f477abb51814650eb47da": {
         | 
| 693 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 694 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
|  | |
| 695 | 
             
                      "state": {
         | 
| 696 | 
             
                        "_view_name": "StyleView",
         | 
| 697 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
| @@ -706,6 +729,7 @@ | |
| 706 | 
             
                    "7dba58f0391c485a86e34e8039ec6189": {
         | 
| 707 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 708 | 
             
                      "model_name": "LayoutModel",
         | 
|  | |
| 709 | 
             
                      "state": {
         | 
| 710 | 
             
                        "_view_name": "LayoutView",
         | 
| 711 | 
             
                        "grid_template_rows": null,
         | 
| @@ -804,8 +828,7 @@ | |
| 804 | 
             
                  "source": [
         | 
| 805 | 
             
                    "!pip install -q transformers flax\n",
         | 
| 806 | 
             
                    "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git  # VQGAN model in JAX\n",
         | 
| 807 | 
            -
                    "! | 
| 808 | 
            -
                    "%cd dalle-mini/"
         | 
| 809 | 
             
                  ],
         | 
| 810 | 
             
                  "execution_count": null,
         | 
| 811 | 
             
                  "outputs": []
         | 
| @@ -833,7 +856,7 @@ | |
| 833 | 
             
                    "import random\n",
         | 
| 834 | 
             
                    "from tqdm.notebook import tqdm, trange"
         | 
| 835 | 
             
                  ],
         | 
| 836 | 
            -
                  "execution_count":  | 
| 837 | 
             
                  "outputs": []
         | 
| 838 | 
             
                },
         | 
| 839 | 
             
                {
         | 
| @@ -846,7 +869,7 @@ | |
| 846 | 
             
                    "DALLE_REPO = 'flax-community/dalle-mini'\n",
         | 
| 847 | 
             
                    "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
         | 
| 848 | 
             
                  ],
         | 
| 849 | 
            -
                  "execution_count":  | 
| 850 | 
             
                  "outputs": []
         | 
| 851 | 
             
                },
         | 
| 852 | 
             
                {
         | 
| @@ -871,7 +894,7 @@ | |
| 871 | 
             
                    "# set a prompt\n",
         | 
| 872 | 
             
                    "prompt = 'picture of a waterfall under the sunset'"
         | 
| 873 | 
             
                  ],
         | 
| 874 | 
            -
                  "execution_count":  | 
| 875 | 
             
                  "outputs": []
         | 
| 876 | 
             
                },
         | 
| 877 | 
             
                {
         | 
| @@ -888,7 +911,7 @@ | |
| 888 | 
             
                    "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
         | 
| 889 | 
             
                    "tokenized_prompt"
         | 
| 890 | 
             
                  ],
         | 
| 891 | 
            -
                  "execution_count":  | 
| 892 | 
             
                  "outputs": [
         | 
| 893 | 
             
                    {
         | 
| 894 | 
             
                      "output_type": "execute_result",
         | 
| @@ -956,7 +979,7 @@ | |
| 956 | 
             
                    "subkeys = jax.random.split(key, num=n_predictions)\n",
         | 
| 957 | 
             
                    "subkeys"
         | 
| 958 | 
             
                  ],
         | 
| 959 | 
            -
                  "execution_count":  | 
| 960 | 
             
                  "outputs": [
         | 
| 961 | 
             
                    {
         | 
| 962 | 
             
                      "output_type": "execute_result",
         | 
| @@ -1004,7 +1027,7 @@ | |
| 1004 | 
             
                    "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
         | 
| 1005 | 
             
                    "encoded_images[0]"
         | 
| 1006 | 
             
                  ],
         | 
| 1007 | 
            -
                  "execution_count":  | 
| 1008 | 
             
                  "outputs": [
         | 
| 1009 | 
             
                    {
         | 
| 1010 | 
             
                      "output_type": "display_data",
         | 
| @@ -1099,7 +1122,7 @@ | |
| 1099 | 
             
                    "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
         | 
| 1100 | 
             
                    "encoded_images[0]"
         | 
| 1101 | 
             
                  ],
         | 
| 1102 | 
            -
                  "execution_count":  | 
| 1103 | 
             
                  "outputs": [
         | 
| 1104 | 
             
                    {
         | 
| 1105 | 
             
                      "output_type": "execute_result",
         | 
| @@ -1167,7 +1190,7 @@ | |
| 1167 | 
             
                  "source": [
         | 
| 1168 | 
             
                    "encoded_images[0].shape"
         | 
| 1169 | 
             
                  ],
         | 
| 1170 | 
            -
                  "execution_count":  | 
| 1171 | 
             
                  "outputs": [
         | 
| 1172 | 
             
                    {
         | 
| 1173 | 
             
                      "output_type": "execute_result",
         | 
| @@ -1204,7 +1227,7 @@ | |
| 1204 | 
             
                    "import numpy as np\n",
         | 
| 1205 | 
             
                    "from PIL import Image"
         | 
| 1206 | 
             
                  ],
         | 
| 1207 | 
            -
                  "execution_count":  | 
| 1208 | 
             
                  "outputs": []
         | 
| 1209 | 
             
                },
         | 
| 1210 | 
             
                {
         | 
| @@ -1217,7 +1240,7 @@ | |
| 1217 | 
             
                    "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
         | 
| 1218 | 
             
                    "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
         | 
| 1219 | 
             
                  ],
         | 
| 1220 | 
            -
                  "execution_count":  | 
| 1221 | 
             
                  "outputs": []
         | 
| 1222 | 
             
                },
         | 
| 1223 | 
             
                {
         | 
| @@ -1233,7 +1256,7 @@ | |
| 1233 | 
             
                    "# set up VQGAN\n",
         | 
| 1234 | 
             
                    "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
         | 
| 1235 | 
             
                  ],
         | 
| 1236 | 
            -
                  "execution_count":  | 
| 1237 | 
             
                  "outputs": [
         | 
| 1238 | 
             
                    {
         | 
| 1239 | 
             
                      "output_type": "stream",
         | 
| @@ -1269,7 +1292,7 @@ | |
| 1269 | 
             
                    "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
         | 
| 1270 | 
             
                    "decoded_images[0]"
         | 
| 1271 | 
             
                  ],
         | 
| 1272 | 
            -
                  "execution_count":  | 
| 1273 | 
             
                  "outputs": [
         | 
| 1274 | 
             
                    {
         | 
| 1275 | 
             
                      "output_type": "display_data",
         | 
| @@ -1373,7 +1396,7 @@ | |
| 1373 | 
             
                    "# normalize images\n",
         | 
| 1374 | 
             
                    "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
         | 
| 1375 | 
             
                  ],
         | 
| 1376 | 
            -
                  "execution_count":  | 
| 1377 | 
             
                  "outputs": []
         | 
| 1378 | 
             
                },
         | 
| 1379 | 
             
                {
         | 
| @@ -1385,7 +1408,7 @@ | |
| 1385 | 
             
                    "# convert to image\n",
         | 
| 1386 | 
             
                    "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
         | 
| 1387 | 
             
                  ],
         | 
| 1388 | 
            -
                  "execution_count":  | 
| 1389 | 
             
                  "outputs": []
         | 
| 1390 | 
             
                },
         | 
| 1391 | 
             
                {
         | 
| @@ -1402,7 +1425,7 @@ | |
| 1402 | 
             
                    "# display an image\n",
         | 
| 1403 | 
             
                    "images[0]"
         | 
| 1404 | 
             
                  ],
         | 
| 1405 | 
            -
                  "execution_count":  | 
| 1406 | 
             
                  "outputs": [
         | 
| 1407 | 
             
                    {
         | 
| 1408 | 
             
                      "output_type": "execute_result",
         | 
| @@ -1438,7 +1461,7 @@ | |
| 1438 | 
             
                  "source": [
         | 
| 1439 | 
             
                    "from transformers import CLIPProcessor, FlaxCLIPModel"
         | 
| 1440 | 
             
                  ],
         | 
| 1441 | 
            -
                  "execution_count":  | 
| 1442 | 
             
                  "outputs": []
         | 
| 1443 | 
             
                },
         | 
| 1444 | 
             
                {
         | 
| @@ -1474,7 +1497,7 @@ | |
| 1474 | 
             
                    "logits = clip(**inputs).logits_per_image\n",
         | 
| 1475 | 
             
                    "scores = jax.nn.softmax(logits, axis=0).squeeze()  # normalize and sum all scores to 1"
         | 
| 1476 | 
             
                  ],
         | 
| 1477 | 
            -
                  "execution_count":  | 
| 1478 | 
             
                  "outputs": []
         | 
| 1479 | 
             
                },
         | 
| 1480 | 
             
                {
         | 
| @@ -1495,7 +1518,7 @@ | |
| 1495 | 
             
                    "    display(images[idx])\n",
         | 
| 1496 | 
             
                    "    print()"
         | 
| 1497 | 
             
                  ],
         | 
| 1498 | 
            -
                  "execution_count":  | 
| 1499 | 
             
                  "outputs": [
         | 
| 1500 | 
             
                    {
         | 
| 1501 | 
             
                      "output_type": "stream",
         | 
| @@ -1690,7 +1713,7 @@ | |
| 1690 | 
             
                    "from flax.training.common_utils import shard\n",
         | 
| 1691 | 
             
                    "from flax.jax_utils import replicate"
         | 
| 1692 | 
             
                  ],
         | 
| 1693 | 
            -
                  "execution_count":  | 
| 1694 | 
             
                  "outputs": []
         | 
| 1695 | 
             
                },
         | 
| 1696 | 
             
                {
         | 
| @@ -1706,7 +1729,7 @@ | |
| 1706 | 
             
                    "# check we can access TPU's or GPU's\n",
         | 
| 1707 | 
             
                    "jax.devices()"
         | 
| 1708 | 
             
                  ],
         | 
| 1709 | 
            -
                  "execution_count":  | 
| 1710 | 
             
                  "outputs": [
         | 
| 1711 | 
             
                    {
         | 
| 1712 | 
             
                      "output_type": "execute_result",
         | 
| @@ -1744,7 +1767,7 @@ | |
| 1744 | 
             
                    "# one set of inputs per device\n",
         | 
| 1745 | 
             
                    "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
         | 
| 1746 | 
             
                  ],
         | 
| 1747 | 
            -
                  "execution_count":  | 
| 1748 | 
             
                  "outputs": []
         | 
| 1749 | 
             
                },
         | 
| 1750 | 
             
                {
         | 
| @@ -1757,7 +1780,7 @@ | |
| 1757 | 
             
                    "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
         | 
| 1758 | 
             
                    "tokenized_prompt = shard(tokenized_prompt)"
         | 
| 1759 | 
             
                  ],
         | 
| 1760 | 
            -
                  "execution_count":  | 
| 1761 | 
             
                  "outputs": []
         | 
| 1762 | 
             
                },
         | 
| 1763 | 
             
                {
         | 
| @@ -1793,7 +1816,7 @@ | |
| 1793 | 
             
                    "def p_decode(indices, params):\n",
         | 
| 1794 | 
             
                    "    return vqgan.decode_code(indices, params=params)"
         | 
| 1795 | 
             
                  ],
         | 
| 1796 | 
            -
                  "execution_count":  | 
| 1797 | 
             
                  "outputs": []
         | 
| 1798 | 
             
                },
         | 
| 1799 | 
             
                {
         | 
| @@ -1834,7 +1857,7 @@ | |
| 1834 | 
             
                    "    for img in decoded_images:\n",
         | 
| 1835 | 
             
                    "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
         | 
| 1836 | 
             
                  ],
         | 
| 1837 | 
            -
                  "execution_count":  | 
| 1838 | 
             
                  "outputs": [
         | 
| 1839 | 
             
                    {
         | 
| 1840 | 
             
                      "output_type": "display_data",
         | 
| @@ -1877,7 +1900,7 @@ | |
| 1877 | 
             
                    "    display(img)\n",
         | 
| 1878 | 
             
                    "    print()"
         | 
| 1879 | 
             
                  ],
         | 
| 1880 | 
            -
                  "execution_count":  | 
| 1881 | 
             
                  "outputs": [
         | 
| 1882 | 
             
                    {
         | 
| 1883 | 
             
                      "output_type": "display_data",
         | 
|  | |
| 6 | 
             
                  "name": "DALL·E mini - Inference pipeline.ipynb",
         | 
| 7 | 
             
                  "provenance": [],
         | 
| 8 | 
             
                  "collapsed_sections": [],
         | 
| 9 | 
            +
                  "authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
         | 
| 10 | 
             
                  "include_colab_link": true
         | 
| 11 | 
             
                },
         | 
| 12 | 
             
                "kernelspec": {
         | 
|  | |
| 22 | 
             
                    "49304912717a4995ae45d04a59d1f50e": {
         | 
| 23 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 24 | 
             
                      "model_name": "HBoxModel",
         | 
| 25 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 26 | 
             
                      "state": {
         | 
| 27 | 
             
                        "_view_name": "HBoxView",
         | 
| 28 | 
             
                        "_dom_classes": [],
         | 
|  | |
| 43 | 
             
                    "5fd9f97986024e8db560a6737ade9e2e": {
         | 
| 44 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 45 | 
             
                      "model_name": "LayoutModel",
         | 
| 46 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 47 | 
             
                      "state": {
         | 
| 48 | 
             
                        "_view_name": "LayoutView",
         | 
| 49 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 95 | 
             
                    "caced43e3a4c493b98fb07cb41db045c": {
         | 
| 96 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 97 | 
             
                      "model_name": "FloatProgressModel",
         | 
| 98 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 99 | 
             
                      "state": {
         | 
| 100 | 
             
                        "_view_name": "ProgressView",
         | 
| 101 | 
             
                        "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
         | 
|  | |
| 119 | 
             
                    "0acc161f2e9948b68b3fc4e57ef333c9": {
         | 
| 120 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 121 | 
             
                      "model_name": "HTMLModel",
         | 
| 122 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 123 | 
             
                      "state": {
         | 
| 124 | 
             
                        "_view_name": "HTMLView",
         | 
| 125 | 
             
                        "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
         | 
|  | |
| 140 | 
             
                    "40c54b9454d346aabd197f2bcf189467": {
         | 
| 141 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 142 | 
             
                      "model_name": "ProgressStyleModel",
         | 
| 143 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 144 | 
             
                      "state": {
         | 
| 145 | 
             
                        "_view_name": "StyleView",
         | 
| 146 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
|  | |
| 156 | 
             
                    "8b25334a48244a14aa9ba0176887e655": {
         | 
| 157 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 158 | 
             
                      "model_name": "LayoutModel",
         | 
| 159 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 160 | 
             
                      "state": {
         | 
| 161 | 
             
                        "_view_name": "LayoutView",
         | 
| 162 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 208 | 
             
                    "7e7c488f57fc4acb8d261e2db81d61f0": {
         | 
| 209 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 210 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
| 211 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 212 | 
             
                      "state": {
         | 
| 213 | 
             
                        "_view_name": "StyleView",
         | 
| 214 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
|  | |
| 223 | 
             
                    "72c401062a5348b1a366dffb5a403568": {
         | 
| 224 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 225 | 
             
                      "model_name": "LayoutModel",
         | 
| 226 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 227 | 
             
                      "state": {
         | 
| 228 | 
             
                        "_view_name": "LayoutView",
         | 
| 229 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 275 | 
             
                    "022c124dfff348f285335732781b0887": {
         | 
| 276 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 277 | 
             
                      "model_name": "HBoxModel",
         | 
| 278 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 279 | 
             
                      "state": {
         | 
| 280 | 
             
                        "_view_name": "HBoxView",
         | 
| 281 | 
             
                        "_dom_classes": [],
         | 
|  | |
| 296 | 
             
                    "a44e47e9d26c4deb81a5a11a9db92a9f": {
         | 
| 297 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 298 | 
             
                      "model_name": "LayoutModel",
         | 
| 299 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 300 | 
             
                      "state": {
         | 
| 301 | 
             
                        "_view_name": "LayoutView",
         | 
| 302 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 348 | 
             
                    "cd9c7016caae47c1b41fb2608c78b0bf": {
         | 
| 349 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 350 | 
             
                      "model_name": "FloatProgressModel",
         | 
| 351 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 352 | 
             
                      "state": {
         | 
| 353 | 
             
                        "_view_name": "ProgressView",
         | 
| 354 | 
             
                        "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
         | 
|  | |
| 372 | 
             
                    "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
         | 
| 373 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 374 | 
             
                      "model_name": "HTMLModel",
         | 
| 375 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 376 | 
             
                      "state": {
         | 
| 377 | 
             
                        "_view_name": "HTMLView",
         | 
| 378 | 
             
                        "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
         | 
|  | |
| 393 | 
             
                    "c22f207311cf4fb69bd9328eabfd4ebb": {
         | 
| 394 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 395 | 
             
                      "model_name": "ProgressStyleModel",
         | 
| 396 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 397 | 
             
                      "state": {
         | 
| 398 | 
             
                        "_view_name": "StyleView",
         | 
| 399 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
|  | |
| 409 | 
             
                    "5a38c6d83a264bedbf7efe6e97eba953": {
         | 
| 410 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 411 | 
             
                      "model_name": "LayoutModel",
         | 
| 412 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 413 | 
             
                      "state": {
         | 
| 414 | 
             
                        "_view_name": "LayoutView",
         | 
| 415 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 461 | 
             
                    "037563a7eadd4ac5abb7249a2914d346": {
         | 
| 462 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 463 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
| 464 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 465 | 
             
                      "state": {
         | 
| 466 | 
             
                        "_view_name": "StyleView",
         | 
| 467 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
|  | |
| 476 | 
             
                    "3975e7ed0b704990b1fa05909a9bb9b6": {
         | 
| 477 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 478 | 
             
                      "model_name": "LayoutModel",
         | 
| 479 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 480 | 
             
                      "state": {
         | 
| 481 | 
             
                        "_view_name": "LayoutView",
         | 
| 482 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 528 | 
             
                    "f9f1fdc3819a4142b85304cd3c6358a2": {
         | 
| 529 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 530 | 
             
                      "model_name": "HBoxModel",
         | 
| 531 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 532 | 
             
                      "state": {
         | 
| 533 | 
             
                        "_view_name": "HBoxView",
         | 
| 534 | 
             
                        "_dom_classes": [],
         | 
|  | |
| 549 | 
             
                    "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
         | 
| 550 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 551 | 
             
                      "model_name": "LayoutModel",
         | 
| 552 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 553 | 
             
                      "state": {
         | 
| 554 | 
             
                        "_view_name": "LayoutView",
         | 
| 555 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 601 | 
             
                    "29d42e94b3b34c86a117b623da68faed": {
         | 
| 602 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 603 | 
             
                      "model_name": "FloatProgressModel",
         | 
| 604 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 605 | 
             
                      "state": {
         | 
| 606 | 
             
                        "_view_name": "ProgressView",
         | 
| 607 | 
             
                        "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
         | 
|  | |
| 625 | 
             
                    "8b73de7dbdfe40dbbb39fb593520b984": {
         | 
| 626 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 627 | 
             
                      "model_name": "HTMLModel",
         | 
| 628 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 629 | 
             
                      "state": {
         | 
| 630 | 
             
                        "_view_name": "HTMLView",
         | 
| 631 | 
             
                        "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
         | 
|  | |
| 646 | 
             
                    "8ce4d20d004a4382afa0abdd3b1f7191": {
         | 
| 647 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 648 | 
             
                      "model_name": "ProgressStyleModel",
         | 
| 649 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 650 | 
             
                      "state": {
         | 
| 651 | 
             
                        "_view_name": "StyleView",
         | 
| 652 | 
             
                        "_model_name": "ProgressStyleModel",
         | 
|  | |
| 662 | 
             
                    "efc4812245c8459c92e6436889b4f600": {
         | 
| 663 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 664 | 
             
                      "model_name": "LayoutModel",
         | 
| 665 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 666 | 
             
                      "state": {
         | 
| 667 | 
             
                        "_view_name": "LayoutView",
         | 
| 668 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 714 | 
             
                    "717ccef4df1f477abb51814650eb47da": {
         | 
| 715 | 
             
                      "model_module": "@jupyter-widgets/controls",
         | 
| 716 | 
             
                      "model_name": "DescriptionStyleModel",
         | 
| 717 | 
            +
                      "model_module_version": "1.5.0",
         | 
| 718 | 
             
                      "state": {
         | 
| 719 | 
             
                        "_view_name": "StyleView",
         | 
| 720 | 
             
                        "_model_name": "DescriptionStyleModel",
         | 
|  | |
| 729 | 
             
                    "7dba58f0391c485a86e34e8039ec6189": {
         | 
| 730 | 
             
                      "model_module": "@jupyter-widgets/base",
         | 
| 731 | 
             
                      "model_name": "LayoutModel",
         | 
| 732 | 
            +
                      "model_module_version": "1.2.0",
         | 
| 733 | 
             
                      "state": {
         | 
| 734 | 
             
                        "_view_name": "LayoutView",
         | 
| 735 | 
             
                        "grid_template_rows": null,
         | 
|  | |
| 828 | 
             
                  "source": [
         | 
| 829 | 
             
                    "!pip install -q transformers flax\n",
         | 
| 830 | 
             
                    "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git  # VQGAN model in JAX\n",
         | 
| 831 | 
            +
                    "!pip install -q git+https://github.com/borisdayma/dalle-mini.git  # Model files"
         | 
|  | |
| 832 | 
             
                  ],
         | 
| 833 | 
             
                  "execution_count": null,
         | 
| 834 | 
             
                  "outputs": []
         | 
|  | |
| 856 | 
             
                    "import random\n",
         | 
| 857 | 
             
                    "from tqdm.notebook import tqdm, trange"
         | 
| 858 | 
             
                  ],
         | 
| 859 | 
            +
                  "execution_count": null,
         | 
| 860 | 
             
                  "outputs": []
         | 
| 861 | 
             
                },
         | 
| 862 | 
             
                {
         | 
|  | |
| 869 | 
             
                    "DALLE_REPO = 'flax-community/dalle-mini'\n",
         | 
| 870 | 
             
                    "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
         | 
| 871 | 
             
                  ],
         | 
| 872 | 
            +
                  "execution_count": null,
         | 
| 873 | 
             
                  "outputs": []
         | 
| 874 | 
             
                },
         | 
| 875 | 
             
                {
         | 
|  | |
| 894 | 
             
                    "# set a prompt\n",
         | 
| 895 | 
             
                    "prompt = 'picture of a waterfall under the sunset'"
         | 
| 896 | 
             
                  ],
         | 
| 897 | 
            +
                  "execution_count": null,
         | 
| 898 | 
             
                  "outputs": []
         | 
| 899 | 
             
                },
         | 
| 900 | 
             
                {
         | 
|  | |
| 911 | 
             
                    "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
         | 
| 912 | 
             
                    "tokenized_prompt"
         | 
| 913 | 
             
                  ],
         | 
| 914 | 
            +
                  "execution_count": null,
         | 
| 915 | 
             
                  "outputs": [
         | 
| 916 | 
             
                    {
         | 
| 917 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 979 | 
             
                    "subkeys = jax.random.split(key, num=n_predictions)\n",
         | 
| 980 | 
             
                    "subkeys"
         | 
| 981 | 
             
                  ],
         | 
| 982 | 
            +
                  "execution_count": null,
         | 
| 983 | 
             
                  "outputs": [
         | 
| 984 | 
             
                    {
         | 
| 985 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 1027 | 
             
                    "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
         | 
| 1028 | 
             
                    "encoded_images[0]"
         | 
| 1029 | 
             
                  ],
         | 
| 1030 | 
            +
                  "execution_count": null,
         | 
| 1031 | 
             
                  "outputs": [
         | 
| 1032 | 
             
                    {
         | 
| 1033 | 
             
                      "output_type": "display_data",
         | 
|  | |
| 1122 | 
             
                    "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
         | 
| 1123 | 
             
                    "encoded_images[0]"
         | 
| 1124 | 
             
                  ],
         | 
| 1125 | 
            +
                  "execution_count": null,
         | 
| 1126 | 
             
                  "outputs": [
         | 
| 1127 | 
             
                    {
         | 
| 1128 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 1190 | 
             
                  "source": [
         | 
| 1191 | 
             
                    "encoded_images[0].shape"
         | 
| 1192 | 
             
                  ],
         | 
| 1193 | 
            +
                  "execution_count": null,
         | 
| 1194 | 
             
                  "outputs": [
         | 
| 1195 | 
             
                    {
         | 
| 1196 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 1227 | 
             
                    "import numpy as np\n",
         | 
| 1228 | 
             
                    "from PIL import Image"
         | 
| 1229 | 
             
                  ],
         | 
| 1230 | 
            +
                  "execution_count": null,
         | 
| 1231 | 
             
                  "outputs": []
         | 
| 1232 | 
             
                },
         | 
| 1233 | 
             
                {
         | 
|  | |
| 1240 | 
             
                    "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
         | 
| 1241 | 
             
                    "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
         | 
| 1242 | 
             
                  ],
         | 
| 1243 | 
            +
                  "execution_count": null,
         | 
| 1244 | 
             
                  "outputs": []
         | 
| 1245 | 
             
                },
         | 
| 1246 | 
             
                {
         | 
|  | |
| 1256 | 
             
                    "# set up VQGAN\n",
         | 
| 1257 | 
             
                    "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
         | 
| 1258 | 
             
                  ],
         | 
| 1259 | 
            +
                  "execution_count": null,
         | 
| 1260 | 
             
                  "outputs": [
         | 
| 1261 | 
             
                    {
         | 
| 1262 | 
             
                      "output_type": "stream",
         | 
|  | |
| 1292 | 
             
                    "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
         | 
| 1293 | 
             
                    "decoded_images[0]"
         | 
| 1294 | 
             
                  ],
         | 
| 1295 | 
            +
                  "execution_count": null,
         | 
| 1296 | 
             
                  "outputs": [
         | 
| 1297 | 
             
                    {
         | 
| 1298 | 
             
                      "output_type": "display_data",
         | 
|  | |
| 1396 | 
             
                    "# normalize images\n",
         | 
| 1397 | 
             
                    "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
         | 
| 1398 | 
             
                  ],
         | 
| 1399 | 
            +
                  "execution_count": null,
         | 
| 1400 | 
             
                  "outputs": []
         | 
| 1401 | 
             
                },
         | 
| 1402 | 
             
                {
         | 
|  | |
| 1408 | 
             
                    "# convert to image\n",
         | 
| 1409 | 
             
                    "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
         | 
| 1410 | 
             
                  ],
         | 
| 1411 | 
            +
                  "execution_count": null,
         | 
| 1412 | 
             
                  "outputs": []
         | 
| 1413 | 
             
                },
         | 
| 1414 | 
             
                {
         | 
|  | |
| 1425 | 
             
                    "# display an image\n",
         | 
| 1426 | 
             
                    "images[0]"
         | 
| 1427 | 
             
                  ],
         | 
| 1428 | 
            +
                  "execution_count": null,
         | 
| 1429 | 
             
                  "outputs": [
         | 
| 1430 | 
             
                    {
         | 
| 1431 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 1461 | 
             
                  "source": [
         | 
| 1462 | 
             
                    "from transformers import CLIPProcessor, FlaxCLIPModel"
         | 
| 1463 | 
             
                  ],
         | 
| 1464 | 
            +
                  "execution_count": null,
         | 
| 1465 | 
             
                  "outputs": []
         | 
| 1466 | 
             
                },
         | 
| 1467 | 
             
                {
         | 
|  | |
| 1497 | 
             
                    "logits = clip(**inputs).logits_per_image\n",
         | 
| 1498 | 
             
                    "scores = jax.nn.softmax(logits, axis=0).squeeze()  # normalize and sum all scores to 1"
         | 
| 1499 | 
             
                  ],
         | 
| 1500 | 
            +
                  "execution_count": null,
         | 
| 1501 | 
             
                  "outputs": []
         | 
| 1502 | 
             
                },
         | 
| 1503 | 
             
                {
         | 
|  | |
| 1518 | 
             
                    "    display(images[idx])\n",
         | 
| 1519 | 
             
                    "    print()"
         | 
| 1520 | 
             
                  ],
         | 
| 1521 | 
            +
                  "execution_count": null,
         | 
| 1522 | 
             
                  "outputs": [
         | 
| 1523 | 
             
                    {
         | 
| 1524 | 
             
                      "output_type": "stream",
         | 
|  | |
| 1713 | 
             
                    "from flax.training.common_utils import shard\n",
         | 
| 1714 | 
             
                    "from flax.jax_utils import replicate"
         | 
| 1715 | 
             
                  ],
         | 
| 1716 | 
            +
                  "execution_count": null,
         | 
| 1717 | 
             
                  "outputs": []
         | 
| 1718 | 
             
                },
         | 
| 1719 | 
             
                {
         | 
|  | |
| 1729 | 
             
                    "# check we can access TPU's or GPU's\n",
         | 
| 1730 | 
             
                    "jax.devices()"
         | 
| 1731 | 
             
                  ],
         | 
| 1732 | 
            +
                  "execution_count": null,
         | 
| 1733 | 
             
                  "outputs": [
         | 
| 1734 | 
             
                    {
         | 
| 1735 | 
             
                      "output_type": "execute_result",
         | 
|  | |
| 1767 | 
             
                    "# one set of inputs per device\n",
         | 
| 1768 | 
             
                    "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
         | 
| 1769 | 
             
                  ],
         | 
| 1770 | 
            +
                  "execution_count": null,
         | 
| 1771 | 
             
                  "outputs": []
         | 
| 1772 | 
             
                },
         | 
| 1773 | 
             
                {
         | 
|  | |
| 1780 | 
             
                    "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
         | 
| 1781 | 
             
                    "tokenized_prompt = shard(tokenized_prompt)"
         | 
| 1782 | 
             
                  ],
         | 
| 1783 | 
            +
                  "execution_count": null,
         | 
| 1784 | 
             
                  "outputs": []
         | 
| 1785 | 
             
                },
         | 
| 1786 | 
             
                {
         | 
|  | |
| 1816 | 
             
                    "def p_decode(indices, params):\n",
         | 
| 1817 | 
             
                    "    return vqgan.decode_code(indices, params=params)"
         | 
| 1818 | 
             
                  ],
         | 
| 1819 | 
            +
                  "execution_count": null,
         | 
| 1820 | 
             
                  "outputs": []
         | 
| 1821 | 
             
                },
         | 
| 1822 | 
             
                {
         | 
|  | |
| 1857 | 
             
                    "    for img in decoded_images:\n",
         | 
| 1858 | 
             
                    "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
         | 
| 1859 | 
             
                  ],
         | 
| 1860 | 
            +
                  "execution_count": null,
         | 
| 1861 | 
             
                  "outputs": [
         | 
| 1862 | 
             
                    {
         | 
| 1863 | 
             
                      "output_type": "display_data",
         | 
|  | |
| 1900 | 
             
                    "    display(img)\n",
         | 
| 1901 | 
             
                    "    print()"
         | 
| 1902 | 
             
                  ],
         | 
| 1903 | 
            +
                  "execution_count": null,
         | 
| 1904 | 
             
                  "outputs": [
         | 
| 1905 | 
             
                    {
         | 
| 1906 | 
             
                      "output_type": "display_data",
         | 
    	
        dev/requirements.txt
    CHANGED
    
    | @@ -1,10 +1,8 @@ | |
| 1 | 
            -
            # Note: install with the following command:
         | 
| 2 | 
            -
            # pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
         | 
| 3 | 
            -
            # Otherwise it won't find the appropriate libtpu_nightly
         | 
| 4 | 
             
            requests
         | 
|  | |
| 5 | 
             
            jax[tpu]>=0.2.16
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
             
            flax
         | 
| 9 | 
             
            jupyter
         | 
| 10 | 
             
            wandb
         | 
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            requests
         | 
| 2 | 
            +
            -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
         | 
| 3 | 
             
            jax[tpu]>=0.2.16
         | 
| 4 | 
            +
            transformers
         | 
| 5 | 
            +
            datasets
         | 
| 6 | 
             
            flax
         | 
| 7 | 
             
            jupyter
         | 
| 8 | 
             
            wandb
         | 
    	
        requirements.txt
    DELETED
    
    | @@ -1,2 +0,0 @@ | |
| 1 | 
            -
            # Requirements for huggingface spaces
         | 
| 2 | 
            -
            streamlit>=0.84.2
         | 
|  | |
|  | |
|  | 
    	
        setup.cfg
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            [metadata]
         | 
| 2 | 
            +
            name = dalle_mini
         | 
| 3 | 
            +
            version = attr: dalle_mini.__version__
         | 
| 4 | 
            +
            description = DALL·E mini - Generate images from a text prompt
         | 
| 5 | 
            +
            long_description = file: README.md
         | 
| 6 | 
            +
            long_description_content_type = text/markdown
         | 
| 7 | 
            +
            url = https://github.com/borisdayma/dalle-mini
         | 
| 8 | 
            +
            project_urls =
         | 
| 9 | 
            +
                Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            [options]
         | 
| 12 | 
            +
            packages = find:
         | 
| 13 | 
            +
            install_requires =
         | 
| 14 | 
            +
                transformers
         | 
| 15 | 
            +
                jax
         | 
| 16 | 
            +
                flax
         | 
    	
        setup.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from setuptools import setup
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            if __name__ == "__main__":
         | 
| 4 | 
            +
                setup()
         | 
