{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "98d53c05"
},
"source": [
"## Saving a Cats v Dogs Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a minimal example showing how to train a fastai model on Kaggle and save it so you can use it in your app."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, import all the stuff we need from fastai:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "44eb0ad3"
},
"outputs": [],
"source": [
"from fastai.vision.all import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download and decompress our dataset, which is pictures of dogs and cats:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.PETS)/'images'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need a way to label our images as dogs or cats. In this dataset, cat pictures have filenames that start with a capital letter (the breed name is capitalized), while dog pictures start with a lowercase letter:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "44eb0ad3"
},
"outputs": [],
"source": [
"# x is an image filename such as 'Abyssinian_1.jpg'\n",
"def is_cat(x): return x[0].isupper()"
]
},
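{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, this sketch applies the same capital-letter rule to a few filenames written in the dataset's `Breed_number.jpg` style. The sample names are illustrative and not read from disk, and `is_cat` is repeated so the cell runs on its own:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same labeling rule as above: cat breeds are capitalized, dog breeds are not\n",
"def is_cat(x): return x[0].isupper()\n",
"\n",
"# Illustrative filenames only (hypothetical, not loaded from the dataset)\n",
"samples = ['Abyssinian_1.jpg', 'Persian_12.jpg', 'american_bulldog_10.jpg', 'beagle_3.jpg']\n",
"labels = {name: is_cat(name) for name in samples}\n",
"print(labels)"
]
},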
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create our `DataLoaders`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "44eb0ad3"
},
"outputs": [],
"source": [
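"# Label each file with is_cat, hold out a random 20% of images (fixed seed) for validation,\n",
"# and resize every image to 192x192 pixels; '.' makes the current directory the DataLoaders path\n",
"dls = ImageDataLoaders.from_name_func('.',\n",
"    get_image_files(path), valid_pct=0.2, seed=42,\n",
"    label_func=is_cat,\n",
"    item_tfms=Resize(192))"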
]
},
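{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing `valid_pct=0.2, seed=42` asks fastai to hold out a random 20% of the images for validation, with the seed keeping the split identical across runs. The cell below is a rough pure-Python sketch of that idea, assuming a simple shuffle-then-cut approach. The helper `random_split` and the filenames are hypothetical, and fastai's own `RandomSplitter` draws its permutation with PyTorch, so the exact images it picks will differ:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def random_split(items, valid_pct=0.2, seed=42):\n",
"    # Hypothetical helper: shuffle indices reproducibly, then hold out the first valid_pct of them\n",
"    idxs = list(range(len(items)))\n",
"    random.Random(seed).shuffle(idxs)\n",
"    cut = int(valid_pct * len(items))\n",
"    valid = [items[i] for i in idxs[:cut]]\n",
"    train = [items[i] for i in idxs[cut:]]\n",
"    return train, valid\n",
"\n",
"# Ten made-up filenames stand in for the real image list\n",
"files = [f'img_{i}.jpg' for i in range(10)]\n",
"train, valid = random_split(files)\n",
"print(len(train), len(valid))  # 8 2"
]
},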
{
"cell_type": "markdown",
"metadata": {},
"source": [
"... and train our model, a resnet18 (to keep it small and fast):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "c107f724",
"outputId": "fcc1de68-7c8b-43f5-b9eb-fcdb0773ef07"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"