{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Classification using AutoTrain Advanced\n",
    "\n",
    "In this notebook, we will train a text classification model using AutoTrain Advanced.\n",
    "You can replace the model with any Hugging Face Transformers-compatible model, and the dataset with any other dataset that is properly formatted.\n",
    "For dataset formatting, please take a look at the [docs](https://huggingface.co/docs/autotrain/index)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from autotrain.params import TextClassificationParams\n",
    "from autotrain.project import AutoTrainProject"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "HF_USERNAME = \"your_huggingface_username\"\n",
    "HF_TOKEN = \"your_huggingface_write_token\"  # get it from https://huggingface.co/settings/tokens\n",
    "# It is recommended to store your HF_TOKEN in secrets or environment variables rather than in the notebook.\n",
    "# Your token is required if push_to_hub is set to True or if you are accessing a gated model/dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = TextClassificationParams(\n",
    "    model=\"google-bert/bert-base-uncased\",\n",
    "    data_path=\"stanfordnlp/imdb\", # path to the dataset on huggingface hub\n",
    "    text_column=\"text\", # the column in the dataset that contains the text\n",
    "    target_column=\"label\", # the column in the dataset that contains the labels\n",
    "    train_split=\"train\",\n",
    "    valid_split=\"test\",\n",
    "    epochs=3,\n",
    "    batch_size=8,\n",
    "    max_seq_length=512,\n",
    "    lr=1e-5,\n",
    "    optimizer=\"adamw_torch\",\n",
    "    scheduler=\"linear\",\n",
    "    gradient_accumulation=1,\n",
    "    mixed_precision=\"fp16\",\n",
    "    project_name=\"autotrain-model\",\n",
    "    log=\"tensorboard\",\n",
    "    push_to_hub=True,\n",
    "    username=HF_USERNAME,\n",
    "    token=HF_TOKEN,\n",
    ")\n",
    "# tip: you can use `?TextClassificationParams` to see the full list of allowed parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your dataset is stored locally in CSV or JSONL format (JSONL is preferred), make the following changes to `params`:\n",
    "\n",
    "```python\n",
    "params = TextClassificationParams(\n",
    "    data_path=\"data/\",  # path to the folder where train.jsonl / train.csv is located\n",
    "    text_column=\"text\",  # name of the column in the CSV/JSONL file that contains the text\n",
    "    train_split=\"train\",  # filename without the extension\n",
    "    valid_split=\"valid\",  # filename without the extension\n",
    "    # ... keep the remaining parameters as in the cell above\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this will train the model locally\n",
    "project = AutoTrainProject(params=params, backend=\"local\", process=True)\n",
    "project.create()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "autotrain",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}