Spaces:
Runtime error
Runtime error
feat(inference_notebook): dalle-mini is installable
Browse files
dev/inference/inference_pipeline.ipynb
CHANGED
|
@@ -6,7 +6,7 @@
|
|
| 6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
| 7 |
"provenance": [],
|
| 8 |
"collapsed_sections": [],
|
| 9 |
-
"authorship_tag": "
|
| 10 |
"include_colab_link": true
|
| 11 |
},
|
| 12 |
"kernelspec": {
|
|
@@ -22,6 +22,7 @@
|
|
| 22 |
"49304912717a4995ae45d04a59d1f50e": {
|
| 23 |
"model_module": "@jupyter-widgets/controls",
|
| 24 |
"model_name": "HBoxModel",
|
|
|
|
| 25 |
"state": {
|
| 26 |
"_view_name": "HBoxView",
|
| 27 |
"_dom_classes": [],
|
|
@@ -42,6 +43,7 @@
|
|
| 42 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
| 43 |
"model_module": "@jupyter-widgets/base",
|
| 44 |
"model_name": "LayoutModel",
|
|
|
|
| 45 |
"state": {
|
| 46 |
"_view_name": "LayoutView",
|
| 47 |
"grid_template_rows": null,
|
|
@@ -93,6 +95,7 @@
|
|
| 93 |
"caced43e3a4c493b98fb07cb41db045c": {
|
| 94 |
"model_module": "@jupyter-widgets/controls",
|
| 95 |
"model_name": "FloatProgressModel",
|
|
|
|
| 96 |
"state": {
|
| 97 |
"_view_name": "ProgressView",
|
| 98 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
|
@@ -116,6 +119,7 @@
|
|
| 116 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
| 117 |
"model_module": "@jupyter-widgets/controls",
|
| 118 |
"model_name": "HTMLModel",
|
|
|
|
| 119 |
"state": {
|
| 120 |
"_view_name": "HTMLView",
|
| 121 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
|
@@ -136,6 +140,7 @@
|
|
| 136 |
"40c54b9454d346aabd197f2bcf189467": {
|
| 137 |
"model_module": "@jupyter-widgets/controls",
|
| 138 |
"model_name": "ProgressStyleModel",
|
|
|
|
| 139 |
"state": {
|
| 140 |
"_view_name": "StyleView",
|
| 141 |
"_model_name": "ProgressStyleModel",
|
|
@@ -151,6 +156,7 @@
|
|
| 151 |
"8b25334a48244a14aa9ba0176887e655": {
|
| 152 |
"model_module": "@jupyter-widgets/base",
|
| 153 |
"model_name": "LayoutModel",
|
|
|
|
| 154 |
"state": {
|
| 155 |
"_view_name": "LayoutView",
|
| 156 |
"grid_template_rows": null,
|
|
@@ -202,6 +208,7 @@
|
|
| 202 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
| 203 |
"model_module": "@jupyter-widgets/controls",
|
| 204 |
"model_name": "DescriptionStyleModel",
|
|
|
|
| 205 |
"state": {
|
| 206 |
"_view_name": "StyleView",
|
| 207 |
"_model_name": "DescriptionStyleModel",
|
|
@@ -216,6 +223,7 @@
|
|
| 216 |
"72c401062a5348b1a366dffb5a403568": {
|
| 217 |
"model_module": "@jupyter-widgets/base",
|
| 218 |
"model_name": "LayoutModel",
|
|
|
|
| 219 |
"state": {
|
| 220 |
"_view_name": "LayoutView",
|
| 221 |
"grid_template_rows": null,
|
|
@@ -267,6 +275,7 @@
|
|
| 267 |
"022c124dfff348f285335732781b0887": {
|
| 268 |
"model_module": "@jupyter-widgets/controls",
|
| 269 |
"model_name": "HBoxModel",
|
|
|
|
| 270 |
"state": {
|
| 271 |
"_view_name": "HBoxView",
|
| 272 |
"_dom_classes": [],
|
|
@@ -287,6 +296,7 @@
|
|
| 287 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
| 288 |
"model_module": "@jupyter-widgets/base",
|
| 289 |
"model_name": "LayoutModel",
|
|
|
|
| 290 |
"state": {
|
| 291 |
"_view_name": "LayoutView",
|
| 292 |
"grid_template_rows": null,
|
|
@@ -338,6 +348,7 @@
|
|
| 338 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
| 339 |
"model_module": "@jupyter-widgets/controls",
|
| 340 |
"model_name": "FloatProgressModel",
|
|
|
|
| 341 |
"state": {
|
| 342 |
"_view_name": "ProgressView",
|
| 343 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
|
@@ -361,6 +372,7 @@
|
|
| 361 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
| 362 |
"model_module": "@jupyter-widgets/controls",
|
| 363 |
"model_name": "HTMLModel",
|
|
|
|
| 364 |
"state": {
|
| 365 |
"_view_name": "HTMLView",
|
| 366 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
|
@@ -381,6 +393,7 @@
|
|
| 381 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
| 382 |
"model_module": "@jupyter-widgets/controls",
|
| 383 |
"model_name": "ProgressStyleModel",
|
|
|
|
| 384 |
"state": {
|
| 385 |
"_view_name": "StyleView",
|
| 386 |
"_model_name": "ProgressStyleModel",
|
|
@@ -396,6 +409,7 @@
|
|
| 396 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
| 397 |
"model_module": "@jupyter-widgets/base",
|
| 398 |
"model_name": "LayoutModel",
|
|
|
|
| 399 |
"state": {
|
| 400 |
"_view_name": "LayoutView",
|
| 401 |
"grid_template_rows": null,
|
|
@@ -447,6 +461,7 @@
|
|
| 447 |
"037563a7eadd4ac5abb7249a2914d346": {
|
| 448 |
"model_module": "@jupyter-widgets/controls",
|
| 449 |
"model_name": "DescriptionStyleModel",
|
|
|
|
| 450 |
"state": {
|
| 451 |
"_view_name": "StyleView",
|
| 452 |
"_model_name": "DescriptionStyleModel",
|
|
@@ -461,6 +476,7 @@
|
|
| 461 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
| 462 |
"model_module": "@jupyter-widgets/base",
|
| 463 |
"model_name": "LayoutModel",
|
|
|
|
| 464 |
"state": {
|
| 465 |
"_view_name": "LayoutView",
|
| 466 |
"grid_template_rows": null,
|
|
@@ -512,6 +528,7 @@
|
|
| 512 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
| 513 |
"model_module": "@jupyter-widgets/controls",
|
| 514 |
"model_name": "HBoxModel",
|
|
|
|
| 515 |
"state": {
|
| 516 |
"_view_name": "HBoxView",
|
| 517 |
"_dom_classes": [],
|
|
@@ -532,6 +549,7 @@
|
|
| 532 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
| 533 |
"model_module": "@jupyter-widgets/base",
|
| 534 |
"model_name": "LayoutModel",
|
|
|
|
| 535 |
"state": {
|
| 536 |
"_view_name": "LayoutView",
|
| 537 |
"grid_template_rows": null,
|
|
@@ -583,6 +601,7 @@
|
|
| 583 |
"29d42e94b3b34c86a117b623da68faed": {
|
| 584 |
"model_module": "@jupyter-widgets/controls",
|
| 585 |
"model_name": "FloatProgressModel",
|
|
|
|
| 586 |
"state": {
|
| 587 |
"_view_name": "ProgressView",
|
| 588 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
|
@@ -606,6 +625,7 @@
|
|
| 606 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
| 607 |
"model_module": "@jupyter-widgets/controls",
|
| 608 |
"model_name": "HTMLModel",
|
|
|
|
| 609 |
"state": {
|
| 610 |
"_view_name": "HTMLView",
|
| 611 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
|
@@ -626,6 +646,7 @@
|
|
| 626 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
| 627 |
"model_module": "@jupyter-widgets/controls",
|
| 628 |
"model_name": "ProgressStyleModel",
|
|
|
|
| 629 |
"state": {
|
| 630 |
"_view_name": "StyleView",
|
| 631 |
"_model_name": "ProgressStyleModel",
|
|
@@ -641,6 +662,7 @@
|
|
| 641 |
"efc4812245c8459c92e6436889b4f600": {
|
| 642 |
"model_module": "@jupyter-widgets/base",
|
| 643 |
"model_name": "LayoutModel",
|
|
|
|
| 644 |
"state": {
|
| 645 |
"_view_name": "LayoutView",
|
| 646 |
"grid_template_rows": null,
|
|
@@ -692,6 +714,7 @@
|
|
| 692 |
"717ccef4df1f477abb51814650eb47da": {
|
| 693 |
"model_module": "@jupyter-widgets/controls",
|
| 694 |
"model_name": "DescriptionStyleModel",
|
|
|
|
| 695 |
"state": {
|
| 696 |
"_view_name": "StyleView",
|
| 697 |
"_model_name": "DescriptionStyleModel",
|
|
@@ -706,6 +729,7 @@
|
|
| 706 |
"7dba58f0391c485a86e34e8039ec6189": {
|
| 707 |
"model_module": "@jupyter-widgets/base",
|
| 708 |
"model_name": "LayoutModel",
|
|
|
|
| 709 |
"state": {
|
| 710 |
"_view_name": "LayoutView",
|
| 711 |
"grid_template_rows": null,
|
|
@@ -804,8 +828,7 @@
|
|
| 804 |
"source": [
|
| 805 |
"!pip install -q transformers flax\n",
|
| 806 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
| 807 |
-
"!
|
| 808 |
-
"%cd dalle-mini/"
|
| 809 |
],
|
| 810 |
"execution_count": null,
|
| 811 |
"outputs": []
|
|
@@ -833,7 +856,7 @@
|
|
| 833 |
"import random\n",
|
| 834 |
"from tqdm.notebook import tqdm, trange"
|
| 835 |
],
|
| 836 |
-
"execution_count":
|
| 837 |
"outputs": []
|
| 838 |
},
|
| 839 |
{
|
|
@@ -846,7 +869,7 @@
|
|
| 846 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
| 847 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
| 848 |
],
|
| 849 |
-
"execution_count":
|
| 850 |
"outputs": []
|
| 851 |
},
|
| 852 |
{
|
|
@@ -871,7 +894,7 @@
|
|
| 871 |
"# set a prompt\n",
|
| 872 |
"prompt = 'picture of a waterfall under the sunset'"
|
| 873 |
],
|
| 874 |
-
"execution_count":
|
| 875 |
"outputs": []
|
| 876 |
},
|
| 877 |
{
|
|
@@ -888,7 +911,7 @@
|
|
| 888 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
| 889 |
"tokenized_prompt"
|
| 890 |
],
|
| 891 |
-
"execution_count":
|
| 892 |
"outputs": [
|
| 893 |
{
|
| 894 |
"output_type": "execute_result",
|
|
@@ -956,7 +979,7 @@
|
|
| 956 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
| 957 |
"subkeys"
|
| 958 |
],
|
| 959 |
-
"execution_count":
|
| 960 |
"outputs": [
|
| 961 |
{
|
| 962 |
"output_type": "execute_result",
|
|
@@ -1004,7 +1027,7 @@
|
|
| 1004 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
| 1005 |
"encoded_images[0]"
|
| 1006 |
],
|
| 1007 |
-
"execution_count":
|
| 1008 |
"outputs": [
|
| 1009 |
{
|
| 1010 |
"output_type": "display_data",
|
|
@@ -1099,7 +1122,7 @@
|
|
| 1099 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
| 1100 |
"encoded_images[0]"
|
| 1101 |
],
|
| 1102 |
-
"execution_count":
|
| 1103 |
"outputs": [
|
| 1104 |
{
|
| 1105 |
"output_type": "execute_result",
|
|
@@ -1167,7 +1190,7 @@
|
|
| 1167 |
"source": [
|
| 1168 |
"encoded_images[0].shape"
|
| 1169 |
],
|
| 1170 |
-
"execution_count":
|
| 1171 |
"outputs": [
|
| 1172 |
{
|
| 1173 |
"output_type": "execute_result",
|
|
@@ -1204,7 +1227,7 @@
|
|
| 1204 |
"import numpy as np\n",
|
| 1205 |
"from PIL import Image"
|
| 1206 |
],
|
| 1207 |
-
"execution_count":
|
| 1208 |
"outputs": []
|
| 1209 |
},
|
| 1210 |
{
|
|
@@ -1217,7 +1240,7 @@
|
|
| 1217 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
| 1218 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
| 1219 |
],
|
| 1220 |
-
"execution_count":
|
| 1221 |
"outputs": []
|
| 1222 |
},
|
| 1223 |
{
|
|
@@ -1233,7 +1256,7 @@
|
|
| 1233 |
"# set up VQGAN\n",
|
| 1234 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
| 1235 |
],
|
| 1236 |
-
"execution_count":
|
| 1237 |
"outputs": [
|
| 1238 |
{
|
| 1239 |
"output_type": "stream",
|
|
@@ -1269,7 +1292,7 @@
|
|
| 1269 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
| 1270 |
"decoded_images[0]"
|
| 1271 |
],
|
| 1272 |
-
"execution_count":
|
| 1273 |
"outputs": [
|
| 1274 |
{
|
| 1275 |
"output_type": "display_data",
|
|
@@ -1373,7 +1396,7 @@
|
|
| 1373 |
"# normalize images\n",
|
| 1374 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
| 1375 |
],
|
| 1376 |
-
"execution_count":
|
| 1377 |
"outputs": []
|
| 1378 |
},
|
| 1379 |
{
|
|
@@ -1385,7 +1408,7 @@
|
|
| 1385 |
"# convert to image\n",
|
| 1386 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
| 1387 |
],
|
| 1388 |
-
"execution_count":
|
| 1389 |
"outputs": []
|
| 1390 |
},
|
| 1391 |
{
|
|
@@ -1402,7 +1425,7 @@
|
|
| 1402 |
"# display an image\n",
|
| 1403 |
"images[0]"
|
| 1404 |
],
|
| 1405 |
-
"execution_count":
|
| 1406 |
"outputs": [
|
| 1407 |
{
|
| 1408 |
"output_type": "execute_result",
|
|
@@ -1438,7 +1461,7 @@
|
|
| 1438 |
"source": [
|
| 1439 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
| 1440 |
],
|
| 1441 |
-
"execution_count":
|
| 1442 |
"outputs": []
|
| 1443 |
},
|
| 1444 |
{
|
|
@@ -1474,7 +1497,7 @@
|
|
| 1474 |
"logits = clip(**inputs).logits_per_image\n",
|
| 1475 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
| 1476 |
],
|
| 1477 |
-
"execution_count":
|
| 1478 |
"outputs": []
|
| 1479 |
},
|
| 1480 |
{
|
|
@@ -1495,7 +1518,7 @@
|
|
| 1495 |
" display(images[idx])\n",
|
| 1496 |
" print()"
|
| 1497 |
],
|
| 1498 |
-
"execution_count":
|
| 1499 |
"outputs": [
|
| 1500 |
{
|
| 1501 |
"output_type": "stream",
|
|
@@ -1690,7 +1713,7 @@
|
|
| 1690 |
"from flax.training.common_utils import shard\n",
|
| 1691 |
"from flax.jax_utils import replicate"
|
| 1692 |
],
|
| 1693 |
-
"execution_count":
|
| 1694 |
"outputs": []
|
| 1695 |
},
|
| 1696 |
{
|
|
@@ -1706,7 +1729,7 @@
|
|
| 1706 |
"# check we can access TPU's or GPU's\n",
|
| 1707 |
"jax.devices()"
|
| 1708 |
],
|
| 1709 |
-
"execution_count":
|
| 1710 |
"outputs": [
|
| 1711 |
{
|
| 1712 |
"output_type": "execute_result",
|
|
@@ -1744,7 +1767,7 @@
|
|
| 1744 |
"# one set of inputs per device\n",
|
| 1745 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
| 1746 |
],
|
| 1747 |
-
"execution_count":
|
| 1748 |
"outputs": []
|
| 1749 |
},
|
| 1750 |
{
|
|
@@ -1757,7 +1780,7 @@
|
|
| 1757 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
| 1758 |
"tokenized_prompt = shard(tokenized_prompt)"
|
| 1759 |
],
|
| 1760 |
-
"execution_count":
|
| 1761 |
"outputs": []
|
| 1762 |
},
|
| 1763 |
{
|
|
@@ -1793,7 +1816,7 @@
|
|
| 1793 |
"def p_decode(indices, params):\n",
|
| 1794 |
" return vqgan.decode_code(indices, params=params)"
|
| 1795 |
],
|
| 1796 |
-
"execution_count":
|
| 1797 |
"outputs": []
|
| 1798 |
},
|
| 1799 |
{
|
|
@@ -1834,7 +1857,7 @@
|
|
| 1834 |
" for img in decoded_images:\n",
|
| 1835 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
| 1836 |
],
|
| 1837 |
-
"execution_count":
|
| 1838 |
"outputs": [
|
| 1839 |
{
|
| 1840 |
"output_type": "display_data",
|
|
@@ -1877,7 +1900,7 @@
|
|
| 1877 |
" display(img)\n",
|
| 1878 |
" print()"
|
| 1879 |
],
|
| 1880 |
-
"execution_count":
|
| 1881 |
"outputs": [
|
| 1882 |
{
|
| 1883 |
"output_type": "display_data",
|
|
|
|
| 6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
| 7 |
"provenance": [],
|
| 8 |
"collapsed_sections": [],
|
| 9 |
+
"authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
|
| 10 |
"include_colab_link": true
|
| 11 |
},
|
| 12 |
"kernelspec": {
|
|
|
|
| 22 |
"49304912717a4995ae45d04a59d1f50e": {
|
| 23 |
"model_module": "@jupyter-widgets/controls",
|
| 24 |
"model_name": "HBoxModel",
|
| 25 |
+
"model_module_version": "1.5.0",
|
| 26 |
"state": {
|
| 27 |
"_view_name": "HBoxView",
|
| 28 |
"_dom_classes": [],
|
|
|
|
| 43 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
| 44 |
"model_module": "@jupyter-widgets/base",
|
| 45 |
"model_name": "LayoutModel",
|
| 46 |
+
"model_module_version": "1.2.0",
|
| 47 |
"state": {
|
| 48 |
"_view_name": "LayoutView",
|
| 49 |
"grid_template_rows": null,
|
|
|
|
| 95 |
"caced43e3a4c493b98fb07cb41db045c": {
|
| 96 |
"model_module": "@jupyter-widgets/controls",
|
| 97 |
"model_name": "FloatProgressModel",
|
| 98 |
+
"model_module_version": "1.5.0",
|
| 99 |
"state": {
|
| 100 |
"_view_name": "ProgressView",
|
| 101 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
|
|
|
| 119 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
| 120 |
"model_module": "@jupyter-widgets/controls",
|
| 121 |
"model_name": "HTMLModel",
|
| 122 |
+
"model_module_version": "1.5.0",
|
| 123 |
"state": {
|
| 124 |
"_view_name": "HTMLView",
|
| 125 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
|
|
|
| 140 |
"40c54b9454d346aabd197f2bcf189467": {
|
| 141 |
"model_module": "@jupyter-widgets/controls",
|
| 142 |
"model_name": "ProgressStyleModel",
|
| 143 |
+
"model_module_version": "1.5.0",
|
| 144 |
"state": {
|
| 145 |
"_view_name": "StyleView",
|
| 146 |
"_model_name": "ProgressStyleModel",
|
|
|
|
| 156 |
"8b25334a48244a14aa9ba0176887e655": {
|
| 157 |
"model_module": "@jupyter-widgets/base",
|
| 158 |
"model_name": "LayoutModel",
|
| 159 |
+
"model_module_version": "1.2.0",
|
| 160 |
"state": {
|
| 161 |
"_view_name": "LayoutView",
|
| 162 |
"grid_template_rows": null,
|
|
|
|
| 208 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
| 209 |
"model_module": "@jupyter-widgets/controls",
|
| 210 |
"model_name": "DescriptionStyleModel",
|
| 211 |
+
"model_module_version": "1.5.0",
|
| 212 |
"state": {
|
| 213 |
"_view_name": "StyleView",
|
| 214 |
"_model_name": "DescriptionStyleModel",
|
|
|
|
| 223 |
"72c401062a5348b1a366dffb5a403568": {
|
| 224 |
"model_module": "@jupyter-widgets/base",
|
| 225 |
"model_name": "LayoutModel",
|
| 226 |
+
"model_module_version": "1.2.0",
|
| 227 |
"state": {
|
| 228 |
"_view_name": "LayoutView",
|
| 229 |
"grid_template_rows": null,
|
|
|
|
| 275 |
"022c124dfff348f285335732781b0887": {
|
| 276 |
"model_module": "@jupyter-widgets/controls",
|
| 277 |
"model_name": "HBoxModel",
|
| 278 |
+
"model_module_version": "1.5.0",
|
| 279 |
"state": {
|
| 280 |
"_view_name": "HBoxView",
|
| 281 |
"_dom_classes": [],
|
|
|
|
| 296 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
| 297 |
"model_module": "@jupyter-widgets/base",
|
| 298 |
"model_name": "LayoutModel",
|
| 299 |
+
"model_module_version": "1.2.0",
|
| 300 |
"state": {
|
| 301 |
"_view_name": "LayoutView",
|
| 302 |
"grid_template_rows": null,
|
|
|
|
| 348 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
| 349 |
"model_module": "@jupyter-widgets/controls",
|
| 350 |
"model_name": "FloatProgressModel",
|
| 351 |
+
"model_module_version": "1.5.0",
|
| 352 |
"state": {
|
| 353 |
"_view_name": "ProgressView",
|
| 354 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
|
|
|
| 372 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
| 373 |
"model_module": "@jupyter-widgets/controls",
|
| 374 |
"model_name": "HTMLModel",
|
| 375 |
+
"model_module_version": "1.5.0",
|
| 376 |
"state": {
|
| 377 |
"_view_name": "HTMLView",
|
| 378 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
|
|
|
| 393 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
| 394 |
"model_module": "@jupyter-widgets/controls",
|
| 395 |
"model_name": "ProgressStyleModel",
|
| 396 |
+
"model_module_version": "1.5.0",
|
| 397 |
"state": {
|
| 398 |
"_view_name": "StyleView",
|
| 399 |
"_model_name": "ProgressStyleModel",
|
|
|
|
| 409 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
| 410 |
"model_module": "@jupyter-widgets/base",
|
| 411 |
"model_name": "LayoutModel",
|
| 412 |
+
"model_module_version": "1.2.0",
|
| 413 |
"state": {
|
| 414 |
"_view_name": "LayoutView",
|
| 415 |
"grid_template_rows": null,
|
|
|
|
| 461 |
"037563a7eadd4ac5abb7249a2914d346": {
|
| 462 |
"model_module": "@jupyter-widgets/controls",
|
| 463 |
"model_name": "DescriptionStyleModel",
|
| 464 |
+
"model_module_version": "1.5.0",
|
| 465 |
"state": {
|
| 466 |
"_view_name": "StyleView",
|
| 467 |
"_model_name": "DescriptionStyleModel",
|
|
|
|
| 476 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
| 477 |
"model_module": "@jupyter-widgets/base",
|
| 478 |
"model_name": "LayoutModel",
|
| 479 |
+
"model_module_version": "1.2.0",
|
| 480 |
"state": {
|
| 481 |
"_view_name": "LayoutView",
|
| 482 |
"grid_template_rows": null,
|
|
|
|
| 528 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
| 529 |
"model_module": "@jupyter-widgets/controls",
|
| 530 |
"model_name": "HBoxModel",
|
| 531 |
+
"model_module_version": "1.5.0",
|
| 532 |
"state": {
|
| 533 |
"_view_name": "HBoxView",
|
| 534 |
"_dom_classes": [],
|
|
|
|
| 549 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
| 550 |
"model_module": "@jupyter-widgets/base",
|
| 551 |
"model_name": "LayoutModel",
|
| 552 |
+
"model_module_version": "1.2.0",
|
| 553 |
"state": {
|
| 554 |
"_view_name": "LayoutView",
|
| 555 |
"grid_template_rows": null,
|
|
|
|
| 601 |
"29d42e94b3b34c86a117b623da68faed": {
|
| 602 |
"model_module": "@jupyter-widgets/controls",
|
| 603 |
"model_name": "FloatProgressModel",
|
| 604 |
+
"model_module_version": "1.5.0",
|
| 605 |
"state": {
|
| 606 |
"_view_name": "ProgressView",
|
| 607 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
|
|
|
| 625 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
| 626 |
"model_module": "@jupyter-widgets/controls",
|
| 627 |
"model_name": "HTMLModel",
|
| 628 |
+
"model_module_version": "1.5.0",
|
| 629 |
"state": {
|
| 630 |
"_view_name": "HTMLView",
|
| 631 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
|
|
|
| 646 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
| 647 |
"model_module": "@jupyter-widgets/controls",
|
| 648 |
"model_name": "ProgressStyleModel",
|
| 649 |
+
"model_module_version": "1.5.0",
|
| 650 |
"state": {
|
| 651 |
"_view_name": "StyleView",
|
| 652 |
"_model_name": "ProgressStyleModel",
|
|
|
|
| 662 |
"efc4812245c8459c92e6436889b4f600": {
|
| 663 |
"model_module": "@jupyter-widgets/base",
|
| 664 |
"model_name": "LayoutModel",
|
| 665 |
+
"model_module_version": "1.2.0",
|
| 666 |
"state": {
|
| 667 |
"_view_name": "LayoutView",
|
| 668 |
"grid_template_rows": null,
|
|
|
|
| 714 |
"717ccef4df1f477abb51814650eb47da": {
|
| 715 |
"model_module": "@jupyter-widgets/controls",
|
| 716 |
"model_name": "DescriptionStyleModel",
|
| 717 |
+
"model_module_version": "1.5.0",
|
| 718 |
"state": {
|
| 719 |
"_view_name": "StyleView",
|
| 720 |
"_model_name": "DescriptionStyleModel",
|
|
|
|
| 729 |
"7dba58f0391c485a86e34e8039ec6189": {
|
| 730 |
"model_module": "@jupyter-widgets/base",
|
| 731 |
"model_name": "LayoutModel",
|
| 732 |
+
"model_module_version": "1.2.0",
|
| 733 |
"state": {
|
| 734 |
"_view_name": "LayoutView",
|
| 735 |
"grid_template_rows": null,
|
|
|
|
| 828 |
"source": [
|
| 829 |
"!pip install -q transformers flax\n",
|
| 830 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
| 831 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git # Model files"
|
|
|
|
| 832 |
],
|
| 833 |
"execution_count": null,
|
| 834 |
"outputs": []
|
|
|
|
| 856 |
"import random\n",
|
| 857 |
"from tqdm.notebook import tqdm, trange"
|
| 858 |
],
|
| 859 |
+
"execution_count": null,
|
| 860 |
"outputs": []
|
| 861 |
},
|
| 862 |
{
|
|
|
|
| 869 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
| 870 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
| 871 |
],
|
| 872 |
+
"execution_count": null,
|
| 873 |
"outputs": []
|
| 874 |
},
|
| 875 |
{
|
|
|
|
| 894 |
"# set a prompt\n",
|
| 895 |
"prompt = 'picture of a waterfall under the sunset'"
|
| 896 |
],
|
| 897 |
+
"execution_count": null,
|
| 898 |
"outputs": []
|
| 899 |
},
|
| 900 |
{
|
|
|
|
| 911 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
| 912 |
"tokenized_prompt"
|
| 913 |
],
|
| 914 |
+
"execution_count": null,
|
| 915 |
"outputs": [
|
| 916 |
{
|
| 917 |
"output_type": "execute_result",
|
|
|
|
| 979 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
| 980 |
"subkeys"
|
| 981 |
],
|
| 982 |
+
"execution_count": null,
|
| 983 |
"outputs": [
|
| 984 |
{
|
| 985 |
"output_type": "execute_result",
|
|
|
|
| 1027 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
| 1028 |
"encoded_images[0]"
|
| 1029 |
],
|
| 1030 |
+
"execution_count": null,
|
| 1031 |
"outputs": [
|
| 1032 |
{
|
| 1033 |
"output_type": "display_data",
|
|
|
|
| 1122 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
| 1123 |
"encoded_images[0]"
|
| 1124 |
],
|
| 1125 |
+
"execution_count": null,
|
| 1126 |
"outputs": [
|
| 1127 |
{
|
| 1128 |
"output_type": "execute_result",
|
|
|
|
| 1190 |
"source": [
|
| 1191 |
"encoded_images[0].shape"
|
| 1192 |
],
|
| 1193 |
+
"execution_count": null,
|
| 1194 |
"outputs": [
|
| 1195 |
{
|
| 1196 |
"output_type": "execute_result",
|
|
|
|
| 1227 |
"import numpy as np\n",
|
| 1228 |
"from PIL import Image"
|
| 1229 |
],
|
| 1230 |
+
"execution_count": null,
|
| 1231 |
"outputs": []
|
| 1232 |
},
|
| 1233 |
{
|
|
|
|
| 1240 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
| 1241 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
| 1242 |
],
|
| 1243 |
+
"execution_count": null,
|
| 1244 |
"outputs": []
|
| 1245 |
},
|
| 1246 |
{
|
|
|
|
| 1256 |
"# set up VQGAN\n",
|
| 1257 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
| 1258 |
],
|
| 1259 |
+
"execution_count": null,
|
| 1260 |
"outputs": [
|
| 1261 |
{
|
| 1262 |
"output_type": "stream",
|
|
|
|
| 1292 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
| 1293 |
"decoded_images[0]"
|
| 1294 |
],
|
| 1295 |
+
"execution_count": null,
|
| 1296 |
"outputs": [
|
| 1297 |
{
|
| 1298 |
"output_type": "display_data",
|
|
|
|
| 1396 |
"# normalize images\n",
|
| 1397 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
| 1398 |
],
|
| 1399 |
+
"execution_count": null,
|
| 1400 |
"outputs": []
|
| 1401 |
},
|
| 1402 |
{
|
|
|
|
| 1408 |
"# convert to image\n",
|
| 1409 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
| 1410 |
],
|
| 1411 |
+
"execution_count": null,
|
| 1412 |
"outputs": []
|
| 1413 |
},
|
| 1414 |
{
|
|
|
|
| 1425 |
"# display an image\n",
|
| 1426 |
"images[0]"
|
| 1427 |
],
|
| 1428 |
+
"execution_count": null,
|
| 1429 |
"outputs": [
|
| 1430 |
{
|
| 1431 |
"output_type": "execute_result",
|
|
|
|
| 1461 |
"source": [
|
| 1462 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
| 1463 |
],
|
| 1464 |
+
"execution_count": null,
|
| 1465 |
"outputs": []
|
| 1466 |
},
|
| 1467 |
{
|
|
|
|
| 1497 |
"logits = clip(**inputs).logits_per_image\n",
|
| 1498 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
| 1499 |
],
|
| 1500 |
+
"execution_count": null,
|
| 1501 |
"outputs": []
|
| 1502 |
},
|
| 1503 |
{
|
|
|
|
| 1518 |
" display(images[idx])\n",
|
| 1519 |
" print()"
|
| 1520 |
],
|
| 1521 |
+
"execution_count": null,
|
| 1522 |
"outputs": [
|
| 1523 |
{
|
| 1524 |
"output_type": "stream",
|
|
|
|
| 1713 |
"from flax.training.common_utils import shard\n",
|
| 1714 |
"from flax.jax_utils import replicate"
|
| 1715 |
],
|
| 1716 |
+
"execution_count": null,
|
| 1717 |
"outputs": []
|
| 1718 |
},
|
| 1719 |
{
|
|
|
|
| 1729 |
"# check we can access TPU's or GPU's\n",
|
| 1730 |
"jax.devices()"
|
| 1731 |
],
|
| 1732 |
+
"execution_count": null,
|
| 1733 |
"outputs": [
|
| 1734 |
{
|
| 1735 |
"output_type": "execute_result",
|
|
|
|
| 1767 |
"# one set of inputs per device\n",
|
| 1768 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
| 1769 |
],
|
| 1770 |
+
"execution_count": null,
|
| 1771 |
"outputs": []
|
| 1772 |
},
|
| 1773 |
{
|
|
|
|
| 1780 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
| 1781 |
"tokenized_prompt = shard(tokenized_prompt)"
|
| 1782 |
],
|
| 1783 |
+
"execution_count": null,
|
| 1784 |
"outputs": []
|
| 1785 |
},
|
| 1786 |
{
|
|
|
|
| 1816 |
"def p_decode(indices, params):\n",
|
| 1817 |
" return vqgan.decode_code(indices, params=params)"
|
| 1818 |
],
|
| 1819 |
+
"execution_count": null,
|
| 1820 |
"outputs": []
|
| 1821 |
},
|
| 1822 |
{
|
|
|
|
| 1857 |
" for img in decoded_images:\n",
|
| 1858 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
| 1859 |
],
|
| 1860 |
+
"execution_count": null,
|
| 1861 |
"outputs": [
|
| 1862 |
{
|
| 1863 |
"output_type": "display_data",
|
|
|
|
| 1900 |
" display(img)\n",
|
| 1901 |
" print()"
|
| 1902 |
],
|
| 1903 |
+
"execution_count": null,
|
| 1904 |
"outputs": [
|
| 1905 |
{
|
| 1906 |
"output_type": "display_data",
|