feat(log_inference_samples): cleanup
tools/inference/log_inference_samples.ipynb
CHANGED
@@ -100,11 +100,12 @@
 "outputs": [],
 "source": [
 "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
-"clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
-"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
-"clip_params = replicate(clip.params)\n",
 "vqgan_params = replicate(vqgan.params)\n",
 "\n",
+"clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+"processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+"clip16_params = replicate(clip16.params)\n",
+"\n",
 "if add_clip_32:\n",
 "    clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
 "    processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
@@ -123,8 +124,8 @@
 "    return vqgan.decode_code(indices, params=params)\n",
 "\n",
 "@partial(jax.pmap, axis_name=\"batch\")\n",
-"def p_clip(inputs, params):\n",
-"    logits = clip(params=params, **inputs).logits_per_image\n",
+"def p_clip16(inputs, params):\n",
+"    logits = clip16(params=params, **inputs).logits_per_image\n",
 "    return logits\n",
 "\n",
 "if add_clip_32:\n",
@@ -229,7 +230,7 @@
 "outputs": [],
 "source": [
 "run_id = run_ids[0]\n",
-"# TODO: turn everything into a class"
+"# TODO: turn everything into a class or loop over runs"
 ]
 },
 {
@@ -248,10 +249,8 @@
 "for artifact in artifact_versions:\n",
 "    print(f'Processing artifact: {artifact.name}')\n",
 "    version = int(artifact.version[1:])\n",
-"
-"
-"    results32 = []\n",
-"    columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
+"    results16, results32 = [], []\n",
+"    columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
 "    \n",
 "    if latest_only:\n",
 "        assert last_inference_version is None or version > last_inference_version\n",
@@ -307,34 +306,13 @@
 "        for img in decoded_images:\n",
 "            images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
 "\n",
-"
-"
-"        clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
-"        # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
-"        images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
-"        clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
-"        clip_inputs = shard(clip_inputs)\n",
-"        logits = p_clip(clip_inputs, clip_params)\n",
-"        logits = logits.reshape(-1, num_images)\n",
-"        top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
-"        logits = jax.device_get(logits)\n",
-"        # add to results table\n",
-"        for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
-"            if sample == padding_item: continue\n",
-"            cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
-"            top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
-"            top_scores = [scores[x] for x in idx]\n",
-"            results.append([sample] + top_images + top_scores)\n",
-"        \n",
-"        # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
-"        if add_clip_32:\n",
-"            print('Calculating CLIP 32 scores')\n",
-"            clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+"        def add_clip_results(results, processor, p_clip, clip_params): \n",
+"            clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
 "            # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
 "            images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
 "            clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
 "            clip_inputs = shard(clip_inputs)\n",
-"            logits =
+"            logits = p_clip(clip_inputs, clip32_params)\n",
 "            logits = logits.reshape(-1, num_images)\n",
 "            top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
 "            logits = jax.device_get(logits)\n",
@@ -342,13 +320,24 @@
 "            for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
 "                if sample == padding_item: continue\n",
 "                cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
-"                top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
-"
-"
+"                top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
+"                results.append([sample] + top_images)\n",
+"        \n",
+"        # get clip scores\n",
+"        pbar.set_description('Calculating CLIP 16 scores')\n",
+"        add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
+"        \n",
+"        # get clip 32 scores\n",
+"        if add_clip_32:\n",
+"            pbar.set_description('Calculating CLIP 32 scores')\n",
+"            add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
+"\n",
 "    pbar.close()\n",
 "\n",
+"    \n",
+"\n",
 "    # log results\n",
-"    table = wandb.Table(columns=columns, data=
+"    table = wandb.Table(columns=columns, data=results16)\n",
 "    run.log({'Samples': table, 'version': version})\n",
 "    wandb.finish()\n",
 "    \n",
@@ -359,19 +348,6 @@
 "    wandb.finish()\n",
 "    run = None # ensure we don't log on this run"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "4e4c7d0c-2848-4f88-b967-82fd571534f1",
-"metadata": {},
-"outputs": [],
-"source": [
-"# TODO: not implemented\n",
-"def log_runs(runs):\n",
-"    for run in tqdm(runs):\n",
-"        log_run(run)"
-]
 }
 ],
 "metadata": {