laureBe committed on
Commit
48bd5aa
·
verified ·
1 Parent(s): 0ae53cb

Upload notebooks_submitted-text.ipynb

notebooks/notebooks_submitted-text.ipynb ADDED
@@ -0,0 +1,945 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Text task notebook template\n",
8
+ "## Loading the necessary libraries"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "metadata": {},
15
+ "outputs": [
16
+ {
17
+ "name": "stderr",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "2025-01-29 12:18:59.954133: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
21
+ "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
22
+ ]
23
+ },
24
+ {
25
+ "name": "stdout",
26
+ "output_type": "stream",
27
+ "text": [
28
+ "{'quote': 'Interesting to note that Oklahoma minimum temperatures in 2011 were in the bottom ten, including the coldest Oklahoma temperature ever recorded, -31F on February 10, 2011.', 'label': '0_not_relevant', 'source': 'FLICC', 'url': 'https://huggingface.co/datasets/fzanartu/FLICCdataset', 'language': 'en', 'subsource': 'CARDS', 'id': None, '__index_level_0__': 1109}\n"
29
+ ]
30
+ },
31
+ {
32
+ "data": {
33
+ "text/plain": [
34
+ "DatasetDict({\n",
35
+ " train: Dataset({\n",
36
+ " features: ['quote', 'label', 'source', 'url', 'language', 'subsource', 'id', '__index_level_0__'],\n",
37
+ " num_rows: 4872\n",
38
+ " })\n",
39
+ " test: Dataset({\n",
40
+ " features: ['quote', 'label', 'source', 'url', 'language', 'subsource', 'id', '__index_level_0__'],\n",
41
+ " num_rows: 1219\n",
42
+ " })\n",
43
+ "})"
44
+ ]
45
+ },
46
+ "execution_count": 1,
47
+ "metadata": {},
48
+ "output_type": "execute_result"
49
+ }
50
+ ],
51
+ "source": [
52
+ "from codecarbon import EmissionsTracker\n",
53
+ "import huggingface_hub\n",
54
+ "from fastapi import APIRouter\n",
55
+ "from datetime import datetime\n",
56
+ "from datasets import load_dataset\n",
57
+ "from sklearn.metrics import accuracy_score\n",
58
+ "import pandas as pd\n",
59
+ "from tqdm import tqdm\n",
60
+ "from sklearn.model_selection import train_test_split\n",
61
+ "import tensorflow as tf\n",
62
+ "from sklearn import preprocessing, decomposition, model_selection, metrics, pipeline\n",
63
+ "from keras.layers import GlobalMaxPooling1D, Conv1D, MaxPooling1D, Flatten, Bidirectional, SpatialDropout1D\n",
64
+ "\n",
65
+ "\n",
66
+ "import sys\n",
67
+ "sys.path.append('../tasks')\n",
68
+ "\n",
69
+ "#from utils.evaluation import TextEvaluationRequest\n",
70
+ "#from utils.emissions import tracker, clean_emissions_data, get_space_info\n",
71
+ "\n",
72
+ "dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
73
+ "print(next(iter(dataset['train'])))\n",
74
+ " # Convert string labels to integers\n",
75
+ "LABEL_MAPPING = {\n",
76
+ " \"0_not_relevant\": 0,\n",
77
+ " \"1_not_happening\": 1,\n",
78
+ " \"2_not_human\": 2,\n",
79
+ " \"3_not_bad\": 3,\n",
80
+ " \"4_solutions_harmful_unnecessary\": 4,\n",
81
+ " \"5_science_unreliable\": 5,\n",
82
+ " \"6_proponents_biased\": 6,\n",
83
+ " \"7_fossil_fuels_needed\": 7\n",
84
+ " }\n",
85
+ "dataset = dataset.map(lambda x: {\"label\": LABEL_MAPPING[x[\"label\"]]})\n",
86
+ "dataset\n"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {},
92
+ "source": [
93
+ "## Loading the datasets and splitting them"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 2,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "#request = TextEvaluationRequest()\n",
103
+ "\n",
104
+ "# Load and prepare the dataset\n",
105
+ "#dataset = load_dataset(request.dataset_name)\n",
106
+ "\n",
107
+ "# Convert string labels to integers\n",
108
+ "#dataset = dataset.map(lambda x: {\"label\": LABEL_MAPPING[x[\"label\"]]})\n",
109
+ "\n",
110
+ "# Split dataset\n",
111
+ "train_test = dataset[\"train\"].train_test_split(test_size=.2, #request.test_size, \n",
112
+ " seed=42 )#request.test_seed)\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 3,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "train_dataset = train_test[\"train\"]\n",
122
+ "test_dataset = train_test[\"test\"]\n"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 4,
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "name": "stderr",
132
+ "output_type": "stream",
133
+ "text": [
134
+ "[nltk_data] Downloading package stopwords to\n",
135
+ "[nltk_data] /Users/laureberti/nltk_data...\n",
136
+ "[nltk_data] Package stopwords is already up-to-date!\n",
137
+ "[nltk_data] Downloading package wordnet to\n",
138
+ "[nltk_data] /Users/laureberti/nltk_data...\n",
139
+ "[nltk_data] Package wordnet is already up-to-date!\n"
140
+ ]
141
+ },
142
+ {
143
+ "data": {
144
+ "text/html": [
145
+ "<div>\n",
146
+ "<style scoped>\n",
147
+ " .dataframe tbody tr th:only-of-type {\n",
148
+ " vertical-align: middle;\n",
149
+ " }\n",
150
+ "\n",
151
+ " .dataframe tbody tr th {\n",
152
+ " vertical-align: top;\n",
153
+ " }\n",
154
+ "\n",
155
+ " .dataframe thead th {\n",
156
+ " text-align: right;\n",
157
+ " }\n",
158
+ "</style>\n",
159
+ "<table border=\"1\" class=\"dataframe\">\n",
160
+ " <thead>\n",
161
+ " <tr style=\"text-align: right;\">\n",
162
+ " <th></th>\n",
163
+ " <th>quote</th>\n",
164
+ " <th>clean_text</th>\n",
165
+ " <th>length_clean_text</th>\n",
166
+ " </tr>\n",
167
+ " </thead>\n",
168
+ " <tbody>\n",
169
+ " <tr>\n",
170
+ " <th>0</th>\n",
171
+ " <td>Americans for Tax Reform opposes a carbon tax ...</td>\n",
172
+ " <td>american tax reform oppose carbon tax work tir...</td>\n",
173
+ " <td>79</td>\n",
174
+ " </tr>\n",
175
+ " <tr>\n",
176
+ " <th>1</th>\n",
177
+ " <td>More than 100 climate models over the past 30 ...</td>\n",
178
+ " <td>100 climate model past 30 year predict actuall...</td>\n",
179
+ " <td>152</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>2</th>\n",
183
+ " <td>As an oil and gas operator who has been in the...</td>\n",
184
+ " <td>oil gas operator ha industry 30 year im fortun...</td>\n",
185
+ " <td>362</td>\n",
186
+ " </tr>\n",
187
+ " <tr>\n",
188
+ " <th>3</th>\n",
189
+ " <td>Climate has always changed, there've been many...</td>\n",
190
+ " <td>climate ha always change thereve many extincti...</td>\n",
191
+ " <td>141</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>4</th>\n",
195
+ " <td>People have made a mistake. They’ve started to...</td>\n",
196
+ " <td>people make mistake theyve start believe human...</td>\n",
197
+ " <td>118</td>\n",
198
+ " </tr>\n",
199
+ " </tbody>\n",
200
+ "</table>\n",
201
+ "</div>"
202
+ ],
203
+ "text/plain": [
204
+ " quote \\\n",
205
+ "0 Americans for Tax Reform opposes a carbon tax ... \n",
206
+ "1 More than 100 climate models over the past 30 ... \n",
207
+ "2 As an oil and gas operator who has been in the... \n",
208
+ "3 Climate has always changed, there've been many... \n",
209
+ "4 People have made a mistake. They’ve started to... \n",
210
+ "\n",
211
+ " clean_text length_clean_text \n",
212
+ "0 american tax reform oppose carbon tax work tir... 79 \n",
213
+ "1 100 climate model past 30 year predict actuall... 152 \n",
214
+ "2 oil gas operator ha industry 30 year im fortun... 362 \n",
215
+ "3 climate ha always change thereve many extincti... 141 \n",
216
+ "4 people make mistake theyve start believe human... 118 "
217
+ ]
218
+ },
219
+ "execution_count": 4,
220
+ "metadata": {},
221
+ "output_type": "execute_result"
222
+ }
223
+ ],
224
+ "source": [
225
+ "import nltk\n",
226
+ "nltk.download('stopwords')\n",
227
+ "nltk.download('wordnet')\n",
228
+ "\n",
229
+ "import re\n",
230
+ "from nltk.stem import WordNetLemmatizer\n",
231
+ "from nltk.corpus import stopwords\n",
232
+ "\n",
233
+ "stop_words = set(stopwords.words(\"english\")) \n",
234
+ "lemmatizer = WordNetLemmatizer()\n",
235
+ "\n",
236
+ "\n",
237
+ "def clean_text(text):\n",
238
+ " text = re.sub(r'[^\\w\\s]','',text, re.UNICODE)\n",
239
+ " text = text.lower()\n",
240
+ " text = [lemmatizer.lemmatize(token) for token in text.split(\" \")]\n",
241
+ " text = [lemmatizer.lemmatize(token, \"v\") for token in text]\n",
242
+ " text = [word for word in text if not word in stop_words]\n",
243
+ " text = \" \".join(text)\n",
244
+ " return text\n",
245
+ "\n",
246
+ "train_df= pd.DataFrame(train_dataset[\"quote\"], columns=['quote']) \n",
247
+ "train_df['clean_text'] = train_df.map(clean_text) \n",
248
+ "train_df['length_clean_text'] = train_df['clean_text'].map(len)\n",
249
+ "\n",
250
+ "train_df.head()\n"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 5,
256
+ "metadata": {},
257
+ "outputs": [
258
+ {
259
+ "data": {
260
+ "text/html": [
261
+ "<div>\n",
262
+ "<style scoped>\n",
263
+ " .dataframe tbody tr th:only-of-type {\n",
264
+ " vertical-align: middle;\n",
265
+ " }\n",
266
+ "\n",
267
+ " .dataframe tbody tr th {\n",
268
+ " vertical-align: top;\n",
269
+ " }\n",
270
+ "\n",
271
+ " .dataframe thead th {\n",
272
+ " text-align: right;\n",
273
+ " }\n",
274
+ "</style>\n",
275
+ "<table border=\"1\" class=\"dataframe\">\n",
276
+ " <thead>\n",
277
+ " <tr style=\"text-align: right;\">\n",
278
+ " <th></th>\n",
279
+ " <th>quote</th>\n",
280
+ " <th>clean_text</th>\n",
281
+ " <th>length_clean_text</th>\n",
282
+ " </tr>\n",
283
+ " </thead>\n",
284
+ " <tbody>\n",
285
+ " <tr>\n",
286
+ " <th>0</th>\n",
287
+ " <td>The term climate change was hijacked by β€œprogr...</td>\n",
288
+ " <td>term climate change wa hijack progressive term...</td>\n",
289
+ " <td>76</td>\n",
290
+ " </tr>\n",
291
+ " <tr>\n",
292
+ " <th>1</th>\n",
293
+ " <td>Climate change is a scam.Banks and Home Owner'...</td>\n",
294
+ " <td>climate change scambanks home owner insurance ...</td>\n",
295
+ " <td>82</td>\n",
296
+ " </tr>\n",
297
+ " <tr>\n",
298
+ " <th>2</th>\n",
299
+ " <td>Against the half-trillion in benefits you can ...</td>\n",
300
+ " <td>halftrillion benefit weigh global warm impact ...</td>\n",
301
+ " <td>337</td>\n",
302
+ " </tr>\n",
303
+ " <tr>\n",
304
+ " <th>3</th>\n",
305
+ " <td>Do you agree with the vast majority of climate...</td>\n",
306
+ " <td>agree vast majority climate scientist climate ...</td>\n",
307
+ " <td>59</td>\n",
308
+ " </tr>\n",
309
+ " <tr>\n",
310
+ " <th>4</th>\n",
311
+ " <td>Global warming and climate change, even if it ...</td>\n",
312
+ " <td>global warm climate change even 100 cause huma...</td>\n",
313
+ " <td>165</td>\n",
314
+ " </tr>\n",
315
+ " </tbody>\n",
316
+ "</table>\n",
317
+ "</div>"
318
+ ],
319
+ "text/plain": [
320
+ " quote \\\n",
321
+ "0 The term climate change was hijacked by β€œprogr... \n",
322
+ "1 Climate change is a scam.Banks and Home Owner'... \n",
323
+ "2 Against the half-trillion in benefits you can ... \n",
324
+ "3 Do you agree with the vast majority of climate... \n",
325
+ "4 Global warming and climate change, even if it ... \n",
326
+ "\n",
327
+ " clean_text length_clean_text \n",
328
+ "0 term climate change wa hijack progressive term... 76 \n",
329
+ "1 climate change scambanks home owner insurance ... 82 \n",
330
+ "2 halftrillion benefit weigh global warm impact ... 337 \n",
331
+ "3 agree vast majority climate scientist climate ... 59 \n",
332
+ "4 global warm climate change even 100 cause huma... 165 "
333
+ ]
334
+ },
335
+ "execution_count": 5,
336
+ "metadata": {},
337
+ "output_type": "execute_result"
338
+ }
339
+ ],
340
+ "source": [
341
+ "test_df= pd.DataFrame(test_dataset[\"quote\"], columns=['quote']) \n",
342
+ "test_df['clean_text'] = test_df.map(clean_text) \n",
343
+ "test_df['length_clean_text'] = test_df['clean_text'].map(len)\n",
344
+ "\n",
345
+ "test_df.head()"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 6,
351
+ "metadata": {},
352
+ "outputs": [
353
+ {
354
+ "data": {
355
+ "text/plain": [
356
+ "27.92250449063382"
357
+ ]
358
+ },
359
+ "execution_count": 6,
360
+ "metadata": {},
361
+ "output_type": "execute_result"
362
+ }
363
+ ],
364
+ "source": [
365
+ "train_df['clean_text'].apply(lambda x: len(x.split(\" \"))).mean()"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 7,
371
+ "metadata": {},
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/plain": [
376
+ "27.25948717948718"
377
+ ]
378
+ },
379
+ "execution_count": 7,
380
+ "metadata": {},
381
+ "output_type": "execute_result"
382
+ }
383
+ ],
384
+ "source": [
385
+ "test_df['clean_text'].apply(lambda x: len(x.split(\" \"))).mean()"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 32,
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "import tensorflow as tf\n",
395
+ "import tensorflow.keras as keras\n",
396
+ "from tensorflow.keras.preprocessing.text import Tokenizer\n",
397
+ "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
398
+ "from tensorflow.keras.layers import Concatenate, Dense, Input, LSTM, Embedding, Dropout, Activation, GRU, Flatten\n",
399
+ "from tensorflow.keras.layers import Bidirectional, GlobalMaxPool1D\n",
400
+ "from tensorflow.keras.models import Model, Sequential\n",
401
+ "from tensorflow.keras.layers import Convolution1D\n",
402
+ "from tensorflow.keras import initializers, regularizers, constraints, optimizers, layers\n",
403
+ "\n",
404
+ "\n",
405
+ "MAX_FEATURES = 6000\n",
406
+ "EMBED_SIZE = 28\n",
407
+ "tokenizer = Tokenizer(num_words=MAX_FEATURES)\n",
408
+ "tokenizer.fit_on_texts(train_df['clean_text'])\n",
409
+ "list_tokenized_train = tokenizer.texts_to_sequences(train_df['clean_text'])\n",
410
+ "\n",
411
+ "RNN_CELL_SIZE = 32\n",
412
+ "\n",
413
+ "MAX_LEN = 30 \n",
414
+ "\n",
415
+ "X_train = pad_sequences(list_tokenized_train, maxlen=MAX_LEN)\n"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": 33,
421
+ "metadata": {},
422
+ "outputs": [],
423
+ "source": [
424
+ "true_labels = test_dataset[\"label\"]\n",
425
+ "y_train = train_dataset[\"label\"]\n",
426
+ "y_test = test_dataset[\"label\"]"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": 34,
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "class Attention(tf.keras.Model):\n",
436
+ " def __init__(self, units):\n",
437
+ " super(Attention, self).__init__()\n",
438
+ " self.W1 = tf.keras.layers.Dense(units)\n",
439
+ " self.W2 = tf.keras.layers.Dense(units)\n",
440
+ " self.V = tf.keras.layers.Dense(1)\n",
441
+ " \n",
442
+ " def call(self, features, hidden):\n",
443
+ " # hidden shape == (batch_size, hidden size)\n",
444
+ " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
445
+ " # we are doing this to perform addition to calculate the score\n",
446
+ " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
447
+ "\n",
448
+ " # score shape == (batch_size, max_length, 1)\n",
449
+ " # we get 1 at the last axis because we are applying score to self.V\n",
450
+ " # the shape of the tensor before applying self.V is (batch_size, max_length, units)\n",
451
+ " score = tf.nn.tanh(\n",
452
+ " self.W1(features) + self.W2(hidden_with_time_axis))\n",
453
+ " \n",
454
+ " # attention_weights shape == (batch_size, max_length, 1)\n",
455
+ " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
456
+ "\n",
457
+ " # context_vector shape after sum == (batch_size, hidden_size)\n",
458
+ " context_vector = attention_weights * features\n",
459
+ " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
460
+ " \n",
461
+ " return context_vector, attention_weights"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": 35,
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "sequence_input = Input(shape=(MAX_LEN,), dtype=\"int32\")\n",
471
+ "embedded_sequences = Embedding(MAX_FEATURES, EMBED_SIZE)(sequence_input)"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 36,
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "lstm = Bidirectional(LSTM(RNN_CELL_SIZE, return_sequences = True), name=\"bi_lstm_0\")(embedded_sequences)\n",
481
+ "\n",
482
+ "# Getting our LSTM outputs\n",
483
+ "(lstm, forward_h, forward_c, backward_h, backward_c) = Bidirectional(LSTM(RNN_CELL_SIZE, return_sequences=True, return_state=True), name=\"bi_lstm_1\")(lstm)"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": 37,
489
+ "metadata": {},
490
+ "outputs": [],
491
+ "source": [
492
+ "state_h = Concatenate()([forward_h, backward_h])\n",
493
+ "state_c = Concatenate()([forward_c, backward_c])\n",
494
+ "\n",
495
+ "context_vector, attention_weights = Attention(10)(lstm, state_h)\n",
496
+ "\n",
497
+ "# Removal of the globalMaxPool1D could be trouble\n",
498
+ "#globmax = GlobalMaxPool1D()(context_vector)\n",
499
+ "dense1 = Dense(20, activation=\"relu\")(context_vector)\n",
500
+ "dropout = Dropout(0.05)(dense1)\n",
501
+ "output = Dense(8, activation=\"sigmoid\")(dropout)\n",
502
+ "\n",
503
+ "model = keras.Model(inputs=sequence_input, outputs=output)"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": 38,
509
+ "metadata": {},
510
+ "outputs": [
511
+ {
512
+ "data": {
513
+ "text/html": [
514
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_1\"</span>\n",
515
+ "</pre>\n"
516
+ ],
517
+ "text/plain": [
518
+ "\u001b[1mModel: \"functional_1\"\u001b[0m\n"
519
+ ]
520
+ },
521
+ "metadata": {},
522
+ "output_type": "display_data"
523
+ },
524
+ {
525
+ "data": {
526
+ "text/html": [
527
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
528
+ "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
529
+ "┑━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
530
+ "β”‚ input_layer_1 β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">30</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> β”‚ - β”‚\n",
531
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) β”‚ β”‚ β”‚ β”‚\n",
532
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
533
+ "β”‚ embedding_1 β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">30</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">168,000</span> β”‚ input_layer_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… β”‚\n",
534
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) β”‚ β”‚ β”‚ β”‚\n",
535
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
536
+ "β”‚ bi_lstm_0 β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">30</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">15,616</span> β”‚ embedding_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] β”‚\n",
537
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) β”‚ β”‚ β”‚ β”‚\n",
538
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
539
+ "β”‚ bi_lstm_1 β”‚ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">30</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">24,832</span> β”‚ bi_lstm_0[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] β”‚\n",
540
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>), β”‚ β”‚ β”‚\n",
541
+ "β”‚ β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>), β”‚ β”‚ β”‚\n",
542
+ "β”‚ β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>), β”‚ β”‚ β”‚\n",
543
+ "β”‚ β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)] β”‚ β”‚ β”‚\n",
544
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
545
+ "β”‚ concatenate_2 β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> β”‚ bi_lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], β”‚\n",
546
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) β”‚ β”‚ β”‚ bi_lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>] β”‚\n",
547
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
548
+ "β”‚ attention_1 β”‚ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,311</span> β”‚ bi_lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], β”‚\n",
549
+ "β”‚ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Attention</span>) β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">30</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>)] β”‚ β”‚ concatenate_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… β”‚\n",
550
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
551
+ "β”‚ dense_8 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">20</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,300</span> β”‚ attention_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] β”‚\n",
552
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
553
+ "β”‚ dropout_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">20</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> β”‚ dense_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] β”‚\n",
554
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
555
+ "β”‚ dense_9 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) β”‚ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>) β”‚ <span style=\"color: #00af00; text-decoration-color: #00af00\">168</span> β”‚ dropout_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] β”‚\n",
556
+ "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n",
557
+ "</pre>\n"
558
+ ],
559
+ "text/plain": [
560
+ "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
561
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
562
+ "┑━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
563
+ "β”‚ input_layer_1 β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m30\u001b[0m) β”‚ \u001b[38;5;34m0\u001b[0m β”‚ - β”‚\n",
564
+ "β”‚ (\u001b[38;5;33mInputLayer\u001b[0m) β”‚ β”‚ β”‚ β”‚\n",
565
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
566
+ "β”‚ embedding_1 β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m30\u001b[0m, \u001b[38;5;34m28\u001b[0m) β”‚ \u001b[38;5;34m168,000\u001b[0m β”‚ input_layer_1[\u001b[38;5;34m0\u001b[0m]… β”‚\n",
567
+ "β”‚ (\u001b[38;5;33mEmbedding\u001b[0m) β”‚ β”‚ β”‚ β”‚\n",
568
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
569
+ "β”‚ bi_lstm_0 β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m30\u001b[0m, \u001b[38;5;34m64\u001b[0m) β”‚ \u001b[38;5;34m15,616\u001b[0m β”‚ embedding_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β”‚\n",
570
+ "β”‚ (\u001b[38;5;33mBidirectional\u001b[0m) β”‚ β”‚ β”‚ β”‚\n",
571
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
572
+ "β”‚ bi_lstm_1 β”‚ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m30\u001b[0m, \u001b[38;5;34m64\u001b[0m), β”‚ \u001b[38;5;34m24,832\u001b[0m β”‚ bi_lstm_0[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β”‚\n",
573
+ "β”‚ (\u001b[38;5;33mBidirectional\u001b[0m) β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m), β”‚ β”‚ β”‚\n",
574
+ "β”‚ β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m), β”‚ β”‚ β”‚\n",
575
+ "β”‚ β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m), β”‚ β”‚ β”‚\n",
576
+ "β”‚ β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m)] β”‚ β”‚ β”‚\n",
577
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
578
+ "β”‚ concatenate_2 β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) β”‚ \u001b[38;5;34m0\u001b[0m β”‚ bi_lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], β”‚\n",
579
+ "β”‚ (\u001b[38;5;33mConcatenate\u001b[0m) β”‚ β”‚ β”‚ bi_lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m3\u001b[0m] β”‚\n",
580
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
581
+ "β”‚ attention_1 β”‚ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), β”‚ \u001b[38;5;34m1,311\u001b[0m β”‚ bi_lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], β”‚\n",
582
+ "β”‚ (\u001b[38;5;33mAttention\u001b[0m) β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m30\u001b[0m, \u001b[38;5;34m1\u001b[0m)] β”‚ β”‚ concatenate_2[\u001b[38;5;34m0\u001b[0m]… β”‚\n",
583
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
584
+ "β”‚ dense_8 (\u001b[38;5;33mDense\u001b[0m) β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) β”‚ \u001b[38;5;34m1,300\u001b[0m β”‚ attention_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β”‚\n",
585
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
586
+ "β”‚ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) β”‚ \u001b[38;5;34m0\u001b[0m β”‚ dense_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β”‚\n",
587
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n",
588
+ "β”‚ dense_9 (\u001b[38;5;33mDense\u001b[0m) β”‚ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m) β”‚ \u001b[38;5;34m168\u001b[0m β”‚ dropout_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β”‚\n",
589
+ "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n"
590
+ ]
591
+ },
592
+ "metadata": {},
593
+ "output_type": "display_data"
594
+ },
595
+ {
596
+ "data": {
597
+ "text/html": [
598
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">211,227</span> (825.11 KB)\n",
599
+ "</pre>\n"
600
+ ],
601
+ "text/plain": [
602
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m211,227\u001b[0m (825.11 KB)\n"
603
+ ]
604
+ },
605
+ "metadata": {},
606
+ "output_type": "display_data"
607
+ },
608
+ {
609
+ "data": {
610
+ "text/html": [
611
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">211,227</span> (825.11 KB)\n",
612
+ "</pre>\n"
613
+ ],
614
+ "text/plain": [
615
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m211,227\u001b[0m (825.11 KB)\n"
616
+ ]
617
+ },
618
+ "metadata": {},
619
+ "output_type": "display_data"
620
+ },
621
+ {
622
+ "data": {
623
+ "text/html": [
624
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
625
+ "</pre>\n"
626
+ ],
627
+ "text/plain": [
628
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
629
+ ]
630
+ },
631
+ "metadata": {},
632
+ "output_type": "display_data"
633
+ },
634
+ {
635
+ "name": "stdout",
636
+ "output_type": "stream",
637
+ "text": [
638
+ "None\n"
639
+ ]
640
+ }
641
+ ],
642
+ "source": [
643
+ "# summarize layers\n",
644
+ "print(model.summary())"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": 39,
650
+ "metadata": {},
651
+ "outputs": [],
652
+ "source": [
653
+ "from keras.callbacks import EarlyStopping\n",
654
+ "from keras import backend \n",
655
+ "\n",
656
+ "es = EarlyStopping(monitor='accuracy', mode='min', verbose=1, patience=5)\n",
657
+ "model.compile(loss='SparseCategoricalCrossentropy', optimizer='adam', metrics=['accuracy'])\n"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": 40,
663
+ "metadata": {},
664
+ "outputs": [],
665
+ "source": [
666
+ "\n",
667
+ "import numpy as np\n",
668
+ "\n",
669
+ "X_train_np = np.array(X_train)\n",
670
+ "y_train_np = np.array(y_train)"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "code",
675
+ "execution_count": 42,
676
+ "metadata": {},
677
+ "outputs": [
678
+ {
679
+ "name": "stdout",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "Epoch 1/30\n",
683
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.7935 - loss: 0.6349\n",
684
+ "Epoch 2/30\n",
685
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 38ms/step - accuracy: 0.8229 - loss: 0.5661\n",
686
+ "Epoch 3/30\n",
687
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 44ms/step - accuracy: 0.8691 - loss: 0.4346\n",
688
+ "Epoch 4/30\n",
689
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.8974 - loss: 0.3836\n",
690
+ "Epoch 5/30\n",
691
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 52ms/step - accuracy: 0.9059 - loss: 0.3363\n",
692
+ "Epoch 6/30\n",
693
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 55ms/step - accuracy: 0.9146 - loss: 0.2993\n",
694
+ "Epoch 7/30\n",
695
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 54ms/step - accuracy: 0.9364 - loss: 0.2439\n",
696
+ "Epoch 8/30\n",
697
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 48ms/step - accuracy: 0.9365 - loss: 0.2423\n",
698
+ "Epoch 9/30\n",
699
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 40ms/step - accuracy: 0.9464 - loss: 0.1978\n",
700
+ "Epoch 10/30\n",
701
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.9516 - loss: 0.1880\n",
702
+ "Epoch 11/30\n",
703
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 49ms/step - accuracy: 0.9478 - loss: 0.1854\n",
704
+ "Epoch 12/30\n",
705
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.9545 - loss: 0.1586\n",
706
+ "Epoch 13/30\n",
707
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.9563 - loss: 0.1485\n",
708
+ "Epoch 14/30\n",
709
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 61ms/step - accuracy: 0.9598 - loss: 0.1378\n",
710
+ "Epoch 15/30\n",
711
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 52ms/step - accuracy: 0.9575 - loss: 0.1429\n",
712
+ "Epoch 16/30\n",
713
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 60ms/step - accuracy: 0.9576 - loss: 0.1285\n",
714
+ "Epoch 17/30\n",
715
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 53ms/step - accuracy: 0.9585 - loss: 0.1384\n",
716
+ "Epoch 18/30\n",
717
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 45ms/step - accuracy: 0.9597 - loss: 0.1333\n",
718
+ "Epoch 19/30\n",
719
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 51ms/step - accuracy: 0.9671 - loss: 0.1189\n",
720
+ "Epoch 20/30\n",
721
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 52ms/step - accuracy: 0.9709 - loss: 0.1102\n",
722
+ "Epoch 21/30\n",
723
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 58ms/step - accuracy: 0.9691 - loss: 0.1136\n",
724
+ "Epoch 22/30\n",
725
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.9774 - loss: 0.0918\n",
726
+ "Epoch 23/30\n",
727
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 63ms/step - accuracy: 0.9777 - loss: 0.0876\n",
728
+ "Epoch 24/30\n",
729
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.9841 - loss: 0.0615\n",
730
+ "Epoch 25/30\n",
731
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 43ms/step - accuracy: 0.9781 - loss: 0.0804\n",
732
+ "Epoch 26/30\n",
733
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 43ms/step - accuracy: 0.9724 - loss: 0.0936\n",
734
+ "Epoch 27/30\n",
735
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 42ms/step - accuracy: 0.9711 - loss: 0.1026\n",
736
+ "Epoch 28/30\n",
737
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 44ms/step - accuracy: 0.9728 - loss: 0.0933\n",
738
+ "Epoch 29/30\n",
739
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 49ms/step - accuracy: 0.9771 - loss: 0.0772\n",
740
+ "Epoch 30/30\n",
741
+ "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 55ms/step - accuracy: 0.9771 - loss: 0.0940\n"
742
+ ]
743
+ }
744
+ ],
745
+ "source": [
746
+ "BATCH_SIZE = 100\n",
747
+ "EPOCHS = 30\n",
748
+ "history = model.fit(X_train_np,y_train_np, shuffle=True,\n",
749
+ " batch_size=BATCH_SIZE, verbose=1,\n",
750
+ " epochs=EPOCHS)#, callbacks=[es])"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": 43,
756
+ "metadata": {},
757
+ "outputs": [],
758
+ "source": [
759
+ "def classifier(input_text,candidate_labels):\n",
760
+ " #PREPROCESS THE INPUT TEXT\n",
761
+ " input_text_cleaned = clean_text(input_text)\n",
762
+ " input_sequence = tokenizer.texts_to_sequences([input_text_cleaned])\n",
763
+ " input_padded = pad_sequences(input_sequence, maxlen = MAX_LEN, padding = 'post')\n",
764
+ " #PREDICTION\n",
765
+ " prediction = np.ravel(model.predict(input_padded))\n",
766
+ " return {'sequence': input_text,'labels': candidate_labels,'scores': list(prediction)}\n"
767
+ ]
768
+ },
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": 44,
772
+ "metadata": {},
773
+ "outputs": [],
774
+ "source": [
775
+ "candidate_labels = [\n",
776
+ " \"Not related to climate change disinformation\",\n",
777
+ " \"Climate change is not real and not happening\",\n",
778
+ " \"Climate change is not human-induced\",\n",
779
+ " \"Climate change impacts are not that bad\",\n",
780
+ " \"Climate change solutions are harmful and unnecessary\",\n",
781
+ " \"Climate change science is unreliable\",\n",
782
+ " \"Climate change proponents are biased\",\n",
783
+ " \"Fossil fuels are needed to address climate change\"\n",
784
+ "]"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": 48,
790
+ "metadata": {},
791
+ "outputs": [
792
+ {
793
+ "data": {
794
+ "text/plain": [
795
+ "[6, 6, 4, 0, 5, 5, 2, 4, 1, 0]"
796
+ ]
797
+ },
798
+ "execution_count": 48,
799
+ "metadata": {},
800
+ "output_type": "execute_result"
801
+ }
802
+ ],
803
+ "source": [
804
+ "true_labels[:10]"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": 49,
810
+ "metadata": {},
811
+ "outputs": [
812
+ {
813
+ "data": {
814
+ "text/plain": [
815
+ "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
816
+ ]
817
+ },
818
+ "execution_count": 49,
819
+ "metadata": {},
820
+ "output_type": "execute_result"
821
+ }
822
+ ],
823
+ "source": [
824
+ "predictions[:10]"
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "code",
829
+ "execution_count": null,
830
+ "metadata": {},
831
+ "outputs": [],
832
+ "source": [
833
+ "# Start tracking emissions\n",
834
+ "tracker.start()\n",
835
+ "tracker.start_task(\"inference\")\n"
836
+ ]
837
+ },
838
+ {
839
+ "cell_type": "code",
840
+ "execution_count": 46,
841
+ "metadata": {},
842
+ "outputs": [],
843
+ "source": [
844
+ "%%capture\n",
845
+ "\n",
846
+ "from tqdm.auto import tqdm\n",
847
+ "predictions = []\n",
848
+ "\n",
849
+ "for i, text in tqdm(enumerate(test_dataset[\"quote\"])):\n",
850
+ "\n",
851
+ " result = classifier(text, candidate_labels)\n",
852
+ "\n",
853
+ " # Get index of highest scoring label\n",
854
+ "\n",
855
+ " pred_label = candidate_labels.index(result[\"labels\"][0])\n",
856
+ "\n",
857
+ " predictions.append(pred_label)\n"
858
+ ]
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "metadata": {},
864
+ "outputs": [],
865
+ "source": [
866
+ "# Stop tracking emissions\n",
867
+ "emissions_data = tracker.stop_task()\n",
868
+ "emissions_data"
869
+ ]
870
+ },
871
+ {
872
+ "cell_type": "code",
873
+ "execution_count": 47,
874
+ "metadata": {},
875
+ "outputs": [
876
+ {
877
+ "data": {
878
+ "text/plain": [
879
+ "0.27"
880
+ ]
881
+ },
882
+ "execution_count": 47,
883
+ "metadata": {},
884
+ "output_type": "execute_result"
885
+ }
886
+ ],
887
+ "source": [
888
+ "# Calculate accuracy\n",
889
+ "accuracy = accuracy_score(true_labels[:100], predictions[:100])\n",
890
+ "accuracy"
891
+ ]
892
+ },
893
+ {
894
+ "cell_type": "code",
895
+ "execution_count": null,
896
+ "metadata": {},
897
+ "outputs": [],
898
+ "source": [
899
+ "# Prepare results dictionary\n",
900
+ "results = {\n",
901
+ " \"submission_timestamp\": datetime.now().isoformat(),\n",
902
+ " \"accuracy\": float(accuracy),\n",
903
+ " \"energy_consumed_wh\": emissions_data.energy_consumed * 1000,\n",
904
+ " \"emissions_gco2eq\": emissions_data.emissions * 1000,\n",
905
+ " \"emissions_data\": clean_emissions_data(emissions_data),\n",
906
+ " \"dataset_config\": {\n",
907
+ " \"dataset_name\": request.dataset_name,\n",
908
+ " \"test_size\": request.test_size,\n",
909
+ " \"test_seed\": request.test_seed\n",
910
+ " }\n",
911
+ "}\n",
912
+ "\n",
913
+ "results"
914
+ ]
915
+ },
916
+ {
917
+ "cell_type": "code",
918
+ "execution_count": null,
919
+ "metadata": {},
920
+ "outputs": [],
921
+ "source": []
922
+ }
923
+ ],
924
+ "metadata": {
925
+ "kernelspec": {
926
+ "display_name": "Python 3 (ipykernel)",
927
+ "language": "python",
928
+ "name": "python3"
929
+ },
930
+ "language_info": {
931
+ "codemirror_mode": {
932
+ "name": "ipython",
933
+ "version": 3
934
+ },
935
+ "file_extension": ".py",
936
+ "mimetype": "text/x-python",
937
+ "name": "python",
938
+ "nbconvert_exporter": "python",
939
+ "pygments_lexer": "ipython3",
940
+ "version": "3.12.8"
941
+ }
942
+ },
943
+ "nbformat": 4,
944
+ "nbformat_minor": 4
945
+ }