Spaces:
Running
Running
feat: improve inference demo
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
|
@@ -41,10 +41,10 @@
|
|
| 41 |
"outputs": [],
|
| 42 |
"source": [
|
| 43 |
"# Install required libraries\n",
|
| 44 |
-
"
|
| 45 |
-
"
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
|
@@ -70,8 +70,8 @@
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
-
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-
|
| 74 |
-
"DALLE_COMMIT_ID = None
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
| 77 |
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
|
@@ -91,13 +91,20 @@
|
|
| 91 |
"import jax\n",
|
| 92 |
"import jax.numpy as jnp\n",
|
| 93 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"# type used for computation - use bfloat16 on TPUs\n",
|
| 95 |
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
|
| 96 |
"\n",
|
| 97 |
-
"# TODO
|
| 98 |
-
"# - we currently have an issue with model.generate() in bfloat16\n",
|
| 99 |
-
"# - https://github.com/google/jax/pull/9089 should fix it\n",
|
| 100 |
-
"# - remove below line and test on TPU with next release of JAX\n",
|
| 101 |
"dtype = jnp.float32"
|
| 102 |
]
|
| 103 |
},
|
|
@@ -115,35 +122,18 @@
|
|
| 115 |
"outputs": [],
|
| 116 |
"source": [
|
| 117 |
"# Load models & tokenizer\n",
|
| 118 |
-
"from dalle_mini.model import DalleBart\n",
|
| 119 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
| 120 |
-
"from transformers import
|
| 121 |
"import wandb\n",
|
| 122 |
"\n",
|
| 123 |
"# Load dalle-mini\n",
|
| 124 |
-
"
|
| 125 |
-
"
|
| 126 |
-
"
|
| 127 |
-
"
|
| 128 |
-
"
|
| 129 |
-
"
|
| 130 |
-
" \"flax_model.msgpack\",\n",
|
| 131 |
-
" \"merges.txt\",\n",
|
| 132 |
-
" \"special_tokens_map.json\",\n",
|
| 133 |
-
" \"tokenizer.json\",\n",
|
| 134 |
-
" \"tokenizer_config.json\",\n",
|
| 135 |
-
" \"vocab.json\",\n",
|
| 136 |
-
" ]\n",
|
| 137 |
-
" for f in model_files:\n",
|
| 138 |
-
" artifact.get_path(f).download(\"model\")\n",
|
| 139 |
-
" model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
|
| 140 |
-
" tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
|
| 141 |
-
"else:\n",
|
| 142 |
-
" # local folder or 🤗 Hub\n",
|
| 143 |
-
" model = DalleBart.from_pretrained(\n",
|
| 144 |
-
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
|
| 145 |
-
" )\n",
|
| 146 |
-
" tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
|
| 147 |
"\n",
|
| 148 |
"# Load VQGAN\n",
|
| 149 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
|
@@ -210,7 +200,8 @@
|
|
| 210 |
" prng_key=key,\n",
|
| 211 |
" params=params,\n",
|
| 212 |
" top_k=top_k,\n",
|
| 213 |
-
" top_p=top_p
|
|
|
|
| 214 |
" )\n",
|
| 215 |
"\n",
|
| 216 |
"\n",
|
|
@@ -233,7 +224,7 @@
|
|
| 233 |
"id": "HmVN6IBwapBA"
|
| 234 |
},
|
| 235 |
"source": [
|
| 236 |
-
"Keys are passed to the model on each device to generate unique
|
| 237 |
]
|
| 238 |
},
|
| 239 |
{
|
|
@@ -247,7 +238,7 @@
|
|
| 247 |
"import random\n",
|
| 248 |
"\n",
|
| 249 |
"# create a random key\n",
|
| 250 |
-
"seed = random.randint(0, 2
|
| 251 |
"key = jax.random.PRNGKey(seed)"
|
| 252 |
]
|
| 253 |
},
|
|
@@ -299,7 +290,7 @@
|
|
| 299 |
},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
-
"prompt = \"a
|
| 303 |
]
|
| 304 |
},
|
| 305 |
{
|
|
@@ -316,27 +307,19 @@
|
|
| 316 |
},
|
| 317 |
{
|
| 318 |
"cell_type": "markdown",
|
| 319 |
-
"metadata": {
|
| 320 |
-
"id": "iFVOyYboP0L-"
|
| 321 |
-
},
|
| 322 |
"source": [
|
| 323 |
-
"We
|
| 324 |
]
|
| 325 |
},
|
| 326 |
{
|
| 327 |
"cell_type": "code",
|
| 328 |
"execution_count": null,
|
| 329 |
-
"metadata": {
|
| 330 |
-
"id": "Rii_FJ7POw1y"
|
| 331 |
-
},
|
| 332 |
"outputs": [],
|
| 333 |
"source": [
|
| 334 |
-
"# repeat the prompt on each device\n",
|
| 335 |
-
"repeated_prompts = [processed_prompt] * jax.device_count()\n",
|
| 336 |
-
"\n",
|
| 337 |
-
"# tokenize\n",
|
| 338 |
"tokenized_prompt = tokenizer(\n",
|
| 339 |
-
"
|
| 340 |
" return_tensors=\"jax\",\n",
|
| 341 |
" padding=\"max_length\",\n",
|
| 342 |
" truncation=True,\n",
|
|
@@ -360,24 +343,18 @@
|
|
| 360 |
},
|
| 361 |
{
|
| 362 |
"cell_type": "markdown",
|
| 363 |
-
"metadata": {
|
| 364 |
-
"id": "2wiDtG3_SH2u"
|
| 365 |
-
},
|
| 366 |
"source": [
|
| 367 |
-
"Finally we
|
| 368 |
]
|
| 369 |
},
|
| 370 |
{
|
| 371 |
"cell_type": "code",
|
| 372 |
"execution_count": null,
|
| 373 |
-
"metadata": {
|
| 374 |
-
"id": "AImyrxHtR9TG"
|
| 375 |
-
},
|
| 376 |
"outputs": [],
|
| 377 |
"source": [
|
| 378 |
-
"
|
| 379 |
-
"\n",
|
| 380 |
-
"tokenized_prompt = shard(tokenized_prompt)"
|
| 381 |
]
|
| 382 |
},
|
| 383 |
{
|
|
@@ -455,6 +432,8 @@
|
|
| 455 |
},
|
| 456 |
"outputs": [],
|
| 457 |
"source": [
|
|
|
|
|
|
|
| 458 |
"# get clip scores\n",
|
| 459 |
"clip_inputs = processor(\n",
|
| 460 |
" text=[prompt] * jax.device_count(),\n",
|
|
|
|
| 41 |
"outputs": [],
|
| 42 |
"source": [
|
| 43 |
"# Install required libraries\n",
|
| 44 |
+
"#!pip install -q transformers\n",
|
| 45 |
+
"#!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
| 46 |
+
"#!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
|
| 47 |
+
"#!pip install -q wandb"
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
|
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
|
| 74 |
+
"DALLE_COMMIT_ID = None\n",
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
| 77 |
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
|
|
|
| 91 |
"import jax\n",
|
| 92 |
"import jax.numpy as jnp\n",
|
| 93 |
"\n",
|
| 94 |
+
"# check how many devices are available\n",
|
| 95 |
+
"jax.local_device_count()"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
"# type used for computation - use bfloat16 on TPUs\n",
|
| 105 |
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
|
| 106 |
"\n",
|
| 107 |
+
"# TODO: fix issue with model.generate() in bfloat16 - float32 is forced below until it is resolved\n",
|
|
|
|
|
|
|
|
|
|
| 108 |
"dtype = jnp.float32"
|
| 109 |
]
|
| 110 |
},
|
|
|
|
| 122 |
"outputs": [],
|
| 123 |
"source": [
|
| 124 |
"# Load models & tokenizer\n",
|
| 125 |
+
"from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
|
| 126 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
| 127 |
+
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
| 128 |
"import wandb\n",
|
| 129 |
"\n",
|
| 130 |
"# Load dalle-mini\n",
|
| 131 |
+
"model = DalleBart.from_pretrained(\n",
|
| 132 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
|
| 133 |
+
")\n",
|
| 134 |
+
"tokenizer = DalleBartTokenizer.from_pretrained(\n",
|
| 135 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
|
| 136 |
+
")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
"\n",
|
| 138 |
"# Load VQGAN\n",
|
| 139 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
|
|
|
| 200 |
" prng_key=key,\n",
|
| 201 |
" params=params,\n",
|
| 202 |
" top_k=top_k,\n",
|
| 203 |
+
" top_p=top_p,\n",
|
| 204 |
+
" max_length=257\n",
|
| 205 |
" )\n",
|
| 206 |
"\n",
|
| 207 |
"\n",
|
|
|
|
| 224 |
"id": "HmVN6IBwapBA"
|
| 225 |
},
|
| 226 |
"source": [
|
| 227 |
+
"Keys are passed to the model on each device so that every device generates a different result."
|
| 228 |
]
|
| 229 |
},
|
| 230 |
{
|
|
|
|
| 238 |
"import random\n",
|
| 239 |
"\n",
|
| 240 |
"# create a random key\n",
|
| 241 |
+
"seed = random.randint(0, 2**32 - 1)\n",
|
| 242 |
"key = jax.random.PRNGKey(seed)"
|
| 243 |
]
|
| 244 |
},
|
|
|
|
| 290 |
},
|
| 291 |
"outputs": [],
|
| 292 |
"source": [
|
| 293 |
+
"prompt = \"a waterfall under the sunset\""
|
| 294 |
]
|
| 295 |
},
|
| 296 |
{
|
|
|
|
| 307 |
},
|
| 308 |
{
|
| 309 |
"cell_type": "markdown",
|
| 310 |
+
"metadata": {},
|
|
|
|
|
|
|
| 311 |
"source": [
|
| 312 |
+
"We tokenize the prompt."
|
| 313 |
]
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
| 317 |
"execution_count": null,
|
| 318 |
+
"metadata": {},
|
|
|
|
|
|
|
| 319 |
"outputs": [],
|
| 320 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
"tokenized_prompt = tokenizer(\n",
|
| 322 |
+
" processed_prompt,\n",
|
| 323 |
" return_tensors=\"jax\",\n",
|
| 324 |
" padding=\"max_length\",\n",
|
| 325 |
" truncation=True,\n",
|
|
|
|
| 343 |
},
|
| 344 |
{
|
| 345 |
"cell_type": "markdown",
|
| 346 |
+
"metadata": {},
|
|
|
|
|
|
|
| 347 |
"source": [
|
| 348 |
+
"Finally, we replicate the tokenized prompt onto each device."
|
| 349 |
]
|
| 350 |
},
|
| 351 |
{
|
| 352 |
"cell_type": "code",
|
| 353 |
"execution_count": null,
|
| 354 |
+
"metadata": {},
|
|
|
|
|
|
|
| 355 |
"outputs": [],
|
| 356 |
"source": [
|
| 357 |
+
"tokenized_prompt = replicate(tokenized_prompt)"
|
|
|
|
|
|
|
| 358 |
]
|
| 359 |
},
|
| 360 |
{
|
|
|
|
| 432 |
},
|
| 433 |
"outputs": [],
|
| 434 |
"source": [
|
| 435 |
+
"from flax.training.common_utils import shard\n",
|
| 436 |
+
"\n",
|
| 437 |
"# get clip scores\n",
|
| 438 |
"clip_inputs = processor(\n",
|
| 439 |
" text=[prompt] * jax.device_count(),\n",
|