{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.14","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":9407065,"sourceType":"datasetVersion","datasetId":5711615}],"dockerImageVersionId":30762,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n for filename in filenames:\n print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","jupyter":{"source_hidden":true},"execution":{"iopub.status.busy":"2024-09-16T03:32:51.651185Z","iopub.execute_input":"2024-09-16T03:32:51.651963Z","iopub.status.idle":"2024-09-16T03:32:51.660806Z","shell.execute_reply.started":"2024-09-16T03:32:51.651918Z","shell.execute_reply":"2024-09-16T03:32:51.659896Z"},"trusted":true},"execution_count":26,"outputs":[{"name":"stdout","text":"/kaggle/input/ocrsampledataner/train50000ocred.csv\n/kaggle/input/ocrsampledataner/train50000.csv\n","output_type":"stream"}]},{"cell_type":"code","source":"import pandas as pd\nimport os","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:51.662421Z","iopub.execute_input":"2024-09-16T03:32:51.662773Z","iopub.status.idle":"2024-09-16T03:32:51.672588Z","shell.execute_reply.started":"2024-09-16T03:32:51.662733Z","shell.execute_reply":"2024-09-16T03:32:51.671659Z"},"trusted":true},"execution_count":27,"outputs":[]},{"cell_type":"code","source":"datasetDir = '/kaggle/input/ocrsampledataner/'\ndf = pd.read_csv(os.path.join(datasetDir, 'train50000ocred.csv'))","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:51.674245Z","iopub.execute_input":"2024-09-16T03:32:51.674576Z","iopub.status.idle":"2024-09-16T03:32:52.497978Z","shell.execute_reply.started":"2024-09-16T03:32:51.674541Z","shell.execute_reply":"2024-09-16T03:32:52.496917Z"},"trusted":true},"execution_count":28,"outputs":[]},{"cell_type":"code","source":"df.reset_index(drop = True, inplace = True)\ndf = df.drop(columns = ['Unnamed: 0', 'image_link', 'imageName', 'ocrdata', 
'result'])","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:52.499784Z","iopub.execute_input":"2024-09-16T03:32:52.500114Z","iopub.status.idle":"2024-09-16T03:32:52.518678Z","shell.execute_reply.started":"2024-09-16T03:32:52.500080Z","shell.execute_reply":"2024-09-16T03:32:52.517724Z"},"trusted":true},"execution_count":29,"outputs":[]},{"cell_type":"code","source":"df","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:52.520042Z","iopub.execute_input":"2024-09-16T03:32:52.520405Z","iopub.status.idle":"2024-09-16T03:32:52.536535Z","shell.execute_reply.started":"2024-09-16T03:32:52.520355Z","shell.execute_reply":"2024-09-16T03:32:52.535623Z"},"trusted":true},"execution_count":30,"outputs":[{"execution_count":30,"output_type":"execute_result","data":{"text/plain":" group_id entity_name entity_value \\\n0 997176 wattage 3.0 kilowatt \n1 403664 width 9.0 centimetre \n2 681445 height 11.8 inch \n3 599772 height 20.63 inch \n4 767202 item_weight 120.0 gram \n... ... ... ... \n49995 145452 height 6.0 centimetre \n49996 664736 width 6.6 inch \n49997 952470 depth 8.5 inch \n49998 459516 item_weight 550.0 milligram \n49999 653767 depth 45.0 centimetre \n\n cleandata \n0 ['3kW', '8'] \n1 ['34 cm', '9cm'] \n2 ['19.5cm(7.6in)', '30cm (11.8in)'] \n3 ['615mm/24.21in', '459mm/18.07in', '250mm/9.84... \n4 ['20Gm', '120Gms', '120Gm', '120Gm', '120Gm', ... \n... ... \n49995 ['2.36inch/6.0cm', '.77inch/4.5cm', '2.75inch/... \n49996 ['15.7\"', '2.2', '6.6\"', 'CO2'] \n49997 ['8.5\"', '15\"', '12.5\"', '8.7\"', '6.2\"'] \n49998 ['ServingSize:2 Capsules/Servings Per Containe... \n49999 ['45cm/18in', '45cm/18in', '45cm/18in', '45cm/... \n\n[50000 rows x 4 columns]","text/html":"
"},"metadata":{}}]},{"cell_type":"code","source":"import re\n\n# Abbreviation mapping\nabbreviation_map = {\n 'cm': 'centimetre',\n 'mm': 'millimetre',\n 'm': 'metre',\n 'in': 'inch',\n 'ft': 'foot',\n 'yd': 'yard',\n 'g': 'gram',\n 'kg': 'kilogram',\n 'mg': 'milligram',\n 'µg': 'microgram',\n 'lb': 'pound',\n 'oz': 'ounce',\n 't': 'ton',\n 'ml': 'millilitre',\n 'l': 'litre',\n 'cl': 'centilitre',\n 'dl': 'decilitre',\n 'fl oz': 'fluid ounce',\n 'gal': 'gallon',\n 'pt': 'pint',\n 'qt': 'quart',\n 'cu ft': 'cubic foot',\n 'cu in': 'cubic inch',\n 'v': 'volt',\n 'kv': 'kilovolt',\n 'mv': 'millivolt',\n 'w': 'watt',\n 'kw': 'kilowatt'\n}\n\n# Unit and entity_name mapping (your map)\nunit_map = {\n 'width': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},\n 'depth': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},\n 'height': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},\n 'item_weight': {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},\n 'maximum_weight_recommendation': {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},\n 'voltage': {'kilovolt', 'millivolt', 'volt'},\n 'wattage': {'kilowatt', 'watt'},\n 'item_volume': {'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce', 'gallon',\n 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart'}\n}\n\n# Function to preprocess a single string (helper)\ndef preprocess_string(text, entity_type):\n # Add a space between a number and an abbreviation (e.g., 65cm to 65 cm)\n text = re.sub(r'(\\d+)([a-zA-Z]+)', r'\\1 \\2', text)\n\n # Replace abbreviations with full unit names\n for abbr, full in abbreviation_map.items():\n text = re.sub(fr'\\b{abbr}\\b', full, text)\n\n # Ensure units are fully written and check the map for each entity type\n if entity_type in unit_map:\n valid_units = unit_map[entity_type]\n for unit in valid_units:\n text = re.sub(fr'\\b{unit.capitalize()}\\b', unit, text.lower()) # Standardize capitalization\n\n return text\n\n# Function to preprocess the 'cleandata' column\ndef preprocess_clean_data_row(clean_data_list, entity_type):\n # Process each string in the list individually\n if isinstance(clean_data_list, list):\n processed_list = [preprocess_string(item, entity_type) for item in clean_data_list]\n return processed_list\n else:\n # Handle cases where clean_data_list is not a list\n return [preprocess_string(clean_data_list, entity_type)]\n\n# Applying the function to the DataFrame\ndef preprocess_clean_data_df(df):\n df['cleaned_data'] = df.apply(lambda row: preprocess_clean_data_row(row['cleandata'], row['entity_name']), axis=1)\n return df\n","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:52.539175Z","iopub.execute_input":"2024-09-16T03:32:52.539554Z","iopub.status.idle":"2024-09-16T03:32:52.553799Z","shell.execute_reply.started":"2024-09-16T03:32:52.539510Z","shell.execute_reply":"2024-09-16T03:32:52.552879Z"},"trusted":true},"execution_count":31,"outputs":[]},{"cell_type":"code","source":"# Assuming df is your DataFrame\ndf = preprocess_clean_data_df(df)\n\n# Check the first few rows of the cleaned data\nprint(df[['cleandata', 
'cleaned_data']].head())\n","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:32:52.555503Z","iopub.execute_input":"2024-09-16T03:32:52.556452Z","iopub.status.idle":"2024-09-16T03:33:03.189642Z","shell.execute_reply.started":"2024-09-16T03:32:52.556401Z","shell.execute_reply":"2024-09-16T03:33:03.188513Z"},"trusted":true},"execution_count":32,"outputs":[{"name":"stdout","text":" cleandata \\\n0 ['3kW', '8'] \n1 ['34 cm', '9cm'] \n2 ['19.5cm(7.6in)', '30cm (11.8in)'] \n3 ['615mm/24.21in', '459mm/18.07in', '250mm/9.84... \n4 ['20Gm', '120Gms', '120Gm', '120Gm', '120Gm', ... \n\n cleaned_data \n0 [['3 kw', '8']] \n1 [['34 centimetre', '9 centimetre']] \n2 [['19.5 centimetre(7.6 inch)', '30 centimetre ... \n3 [['615 millimetre/24.21 inch', '459 millimetre... \n4 [['20 gm', '120 gms', '120 gm', '120 gm', '120... \n","output_type":"stream"}]},{"cell_type":"code","source":"df","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:03.190863Z","iopub.execute_input":"2024-09-16T03:33:03.191184Z","iopub.status.idle":"2024-09-16T03:33:03.206608Z","shell.execute_reply.started":"2024-09-16T03:33:03.191150Z","shell.execute_reply":"2024-09-16T03:33:03.205747Z"},"trusted":true},"execution_count":33,"outputs":[{"execution_count":33,"output_type":"execute_result","data":{"text/plain":" group_id entity_name entity_value \\\n0 997176 wattage 3.0 kilowatt \n1 403664 width 9.0 centimetre \n2 681445 height 11.8 inch \n3 599772 height 20.63 inch \n4 767202 item_weight 120.0 gram \n... ... ... ... \n49995 145452 height 6.0 centimetre \n49996 664736 width 6.6 inch \n49997 952470 depth 8.5 inch \n49998 459516 item_weight 550.0 milligram \n49999 653767 depth 45.0 centimetre \n\n cleandata \\\n0 ['3kW', '8'] \n1 ['34 cm', '9cm'] \n2 ['19.5cm(7.6in)', '30cm (11.8in)'] \n3 ['615mm/24.21in', '459mm/18.07in', '250mm/9.84... \n4 ['20Gm', '120Gms', '120Gm', '120Gm', '120Gm', ... \n... ... \n49995 ['2.36inch/6.0cm', '.77inch/4.5cm', '2.75inch/... \n49996 ['15.7\"', '2.2', '6.6\"', 'CO2'] \n49997 ['8.5\"', '15\"', '12.5\"', '8.7\"', '6.2\"'] \n49998 ['ServingSize:2 Capsules/Servings Per Containe... \n49999 ['45cm/18in', '45cm/18in', '45cm/18in', '45cm/... \n\n cleaned_data \n0 [['3 kw', '8']] \n1 [['34 centimetre', '9 centimetre']] \n2 [['19.5 centimetre(7.6 inch)', '30 centimetre ... \n3 [['615 millimetre/24.21 inch', '459 millimetre... \n4 [['20 gm', '120 gms', '120 gm', '120 gm', '120... \n... ... \n49995 [['2.36 inch/6.0 centimetre', '.77 inch/4.5 ce... \n49996 [['15.7\"', '2.2', '6.6\"', 'co2']] \n49997 [['8.5\"', '15\"', '12.5\"', '8.7\"', '6.2\"']] \n49998 [['servingsize:2 capsules/servings per contain... \n49999 [['45 centimetre/18 inch', '45 centimetre/18 i... \n\n[50000 rows x 5 columns]","text/html":"
"},"metadata":{}}]},{"cell_type":"markdown","source":"# Model Work","metadata":{}},{"cell_type":"code","source":"from transformers import BertTokenizerFast, AutoModelForTokenClassification\nfrom sklearn.model_selection import train_test_split\nfrom torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\nimport torch","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:03.207958Z","iopub.execute_input":"2024-09-16T03:33:03.208308Z","iopub.status.idle":"2024-09-16T03:33:03.218049Z","shell.execute_reply.started":"2024-09-16T03:33:03.208256Z","shell.execute_reply":"2024-09-16T03:33:03.217178Z"},"trusted":true},"execution_count":34,"outputs":[]},{"cell_type":"code","source":"def addContext(row):\n return f\"For Group {row['group_id']}, the {row['entity_name']} is {row['entity_value']} from {row['cleaned_data']}.\"\n\ndef addSpecialContext(row):\n group_id_token = f'[GROUP_ID_{row[\"group_id\"]}]'\n entity_name_token = f'[{row[\"entity_name\"].upper()}]'\n\n return f\"{group_id_token} {entity_name_token} is {row['entity_value']} from {row['cleaned_data']}.\"\n\n#Preprocessing\ndf['context'] = df.apply(addContext, axis = 1)\ndf['specialContext'] = df.apply(addSpecialContext, axis = 1)\n\nuniqueGroups = df['group_id'].unique()\nuniqueEntities = df['entity_name'].unique()\n\nspecialGroups = [f'[group_id_{group_id}]' for group_id in uniqueGroups]\nspecialEntities = [f'[{entity_name}]' for entity_name in uniqueEntities]\nspecialTokens = specialGroups + specialEntities\nspecials = {'group_id': specialGroups, 'entity_name': specialEntities}\n\n#trainText, valText, trainLabels, valLabels = train_test_split(df['specialContext'], df['entity_value'], random_state = 42, test_size = 0.2)","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:03.219196Z","iopub.execute_input":"2024-09-16T03:33:03.219543Z","iopub.status.idle":"2024-09-16T03:33:05.926507Z","shell.execute_reply.started":"2024-09-16T03:33:03.219511Z","shell.execute_reply":"2024-09-16T03:33:05.925485Z"},"trusted":true},"execution_count":35,"outputs":[]},{"cell_type":"code","source":"tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')\ntokenizer.add_tokens(specialTokens)\n\ntokenizer.save_pretrained('/kaggle/working/nertokenizer')\n\n# tokenizedInputs = tokenizer(df['specialContext'].tolist(), padding=True, truncation=True, return_tensors='pt')","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:05.928212Z","iopub.execute_input":"2024-09-16T03:33:05.928653Z","iopub.status.idle":"2024-09-16T03:33:06.092617Z","shell.execute_reply.started":"2024-09-16T03:33:05.928603Z","shell.execute_reply":"2024-09-16T03:33:06.091724Z"},"trusted":true},"execution_count":36,"outputs":[{"name":"stderr","text":"/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n warnings.warn(\n","output_type":"stream"},{"execution_count":36,"output_type":"execute_result","data":{"text/plain":"('/kaggle/working/nertokenizer/tokenizer_config.json',\n '/kaggle/working/nertokenizer/special_tokens_map.json',\n '/kaggle/working/nertokenizer/vocab.txt',\n '/kaggle/working/nertokenizer/added_tokens.json',\n '/kaggle/working/nertokenizer/tokenizer.json')"},"metadata":{}}]},{"cell_type":"code","source":"def label_tokens(contextual_data, entity_value, special_tokens):\n tokens = tokenizer.tokenize(contextual_data)\n entity_tokens = tokenizer.tokenize(entity_value)\n labels = ['O'] * len(tokens)\n\n for idx, token in enumerate(tokens):\n if token in special_tokens['group_id']:\n labels[idx] = 'B-GROUP_ID'\n elif token in special_tokens['entity_name']:\n labels[idx] = 'B-ENTITY_TYPE'\n\n for i in range(len(tokens) - len(entity_tokens) + 1):\n if tokens[i:i + len(entity_tokens)] == entity_tokens:\n labels[i] = 'B-ENTITY'\n for j in range(1, len(entity_tokens)):\n labels[i + j] = 'I-ENTITY'\n break\n \n return tokens, labels","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:06.095867Z","iopub.execute_input":"2024-09-16T03:33:06.096192Z","iopub.status.idle":"2024-09-16T03:33:06.103201Z","shell.execute_reply.started":"2024-09-16T03:33:06.096158Z","shell.execute_reply":"2024-09-16T03:33:06.102152Z"},"trusted":true},"execution_count":37,"outputs":[]},{"cell_type":"code","source":"df['tokens'], df['labels'] = zip(*df.apply(lambda row: label_tokens(row['specialContext'], row['entity_value'], specials), axis = 1))\nprint(df[['specialContext', 'tokens', 'labels']].head())","metadata":{"execution":{"iopub.status.busy":"2024-09-16T03:33:06.104535Z","iopub.execute_input":"2024-09-16T03:33:06.105006Z"},"trusted":true},"execution_count":null,"outputs":[{"name":"stderr","text":"Token indices sequence length is longer than the specified maximum sequence length for this model (886 > 512). 
Running this sequence through the model will result in indexing errors\n","output_type":"stream"}]},{"cell_type":"code","source":"# Single source of truth for the tag -> id mapping\nlabel_mapping = {'B-ENTITY': 0, 'I-ENTITY': 1, 'O': 2, 'B-GROUP_ID': 3, 'B-ENTITY_TYPE': 4}","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def map_labels_to_int(examples):\n # Convert string tags to integer ids using the shared label_mapping\n examples['labels'] = [[label_mapping[label] for label in label_seq] for label_seq in examples['labels']]\n return examples","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def tokenize_and_align_labels(examples):\n tokenized_inputs = tokenizer(\n examples['tokens'],\n is_split_into_words=True,\n truncation=True,\n padding=True,\n max_length=512 # Adjust based on your data\n )\n \n labels = []\n for i, label in enumerate(examples['labels']):\n word_ids = tokenized_inputs.word_ids(batch_index=i) # Map tokens to words\n previous_word_idx = None\n label_ids = []\n for word_idx in word_ids:\n if word_idx is None:\n label_ids.append(-100) # Special tokens\n elif word_idx != previous_word_idx:\n label_ids.append(label[word_idx]) # Use the label of the word\n else:\n label_ids.append(-100) # Subword tokens get -100 label\n previous_word_idx = word_idx\n labels.append(label_ids)\n\n tokenized_inputs['labels'] = labels\n return tokenized_inputs\n","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Converting DataFrame into Hugging Face Dataset","metadata":{}},{"cell_type":"code","source":"from datasets import Dataset\n\ndataset = Dataset.from_pandas(df[['specialContext', 'tokens', 'labels']])\ndataset = dataset.map(map_labels_to_int, batched = True)\n\ntokenized_dataset = dataset.map(tokenize_and_align_labels, batched = True)\n\nprint(tokenized_dataset[0])","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Use a distinct name so sklearn's train_test_split import is not shadowed\nsplit_dataset = tokenized_dataset.train_test_split(test_size=0.2)\ntrain_dataset = split_dataset['train']\ntest_dataset = split_dataset['test']\n\nprint(f\"Training samples: {len(train_dataset)}\")\nprint(f\"Testing samples: {len(test_dataset)}\")","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"!pip install seqeval","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Index -> tag lookup; relies on label_mapping preserving insertion order\nlabelled = list(label_mapping.keys())","metadata":{"trusted":true},"execution_count":null,"outputs":[]},
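{"cell_type":"markdown","source":"Before wiring the metrics into the Trainer, a toy check of how seqeval scores tag sequences (illustrative only; the tags below are hand-made):","metadata":{}},{"cell_type":"code","source":"# Hand-made tag sequences showing that seqeval scores whole entity spans\nfrom seqeval.metrics import f1_score\ny_true = [['O', 'B-ENTITY', 'I-ENTITY', 'O']]\ny_pred = [['O', 'B-ENTITY', 'O', 'O']]\nprint(f1_score(y_true, y_pred)) # 0.0: the predicted span must match exactly","metadata":{"trusted":true},"execution_count":null,"outputs":[]},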
{"cell_type":"code","source":"from seqeval.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score\nimport numpy as np\n\n# Align argmax predictions with true labels, dropping ignored (-100) positions\ndef align_predictions(predictions, label_ids):\n preds = np.argmax(predictions, axis=2)\n batch_size, seq_len = preds.shape\n \n # Convert predictions and labels to list format to align with token-level evaluation\n out_preds = [[] for _ in range(batch_size)]\n out_labels = [[] for _ in range(batch_size)]\n \n for i in range(batch_size):\n for j in range(seq_len):\n # Skip padding tokens (-100)\n if label_ids[i, j] != -100:\n out_preds[i].append(preds[i][j])\n out_labels[i].append(label_ids[i][j])\n \n return out_preds, out_labels\n\n# compute_metrics hook for the Trainer\ndef compute_metrics(p):\n predictions, label_ids = p\n preds, labels = align_predictions(predictions, label_ids)\n\n # Convert integer ids back to string tags for seqeval\n true_labels = [[labelled[l] for l in label] for label in labels]\n true_preds = [[labelled[pr] for pr in pred] for pred in preds]\n\n # Use seqeval to calculate entity-level precision, recall, and F1\n precision = precision_score(true_labels, true_preds)\n recall = recall_score(true_labels, true_preds)\n f1 = f1_score(true_labels, true_preds)\n accuracy = accuracy_score(true_labels, true_preds)\n\n return {\n \"precision\": precision,\n \"recall\": recall,\n \"f1\": f1,\n \"accuracy\": accuracy\n }","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from transformers import BertForTokenClassification\n\n# Number of labels (B-ENTITY, I-ENTITY, O, B-GROUP_ID, B-ENTITY_TYPE)\nnum_labels = len(label_mapping)\n\n# Initialize the model\nmodel = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)\n\n# Resize token embeddings to accommodate new special tokens\nmodel.resize_token_embeddings(len(tokenizer))","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from transformers import TrainingArguments\n\ntraining_args = TrainingArguments(\n output_dir='/kaggle/working/results/', # Output directory\n num_train_epochs=3, # Number of training epochs\n per_device_train_batch_size=32, # Batch size for training\n per_device_eval_batch_size=32, # Batch size for evaluation\n warmup_steps=500, # Warmup steps\n weight_decay=0.01, # Weight decay\n logging_dir='/kaggle/working/logs', # Directory for storing logs\n logging_steps=10, # Log every 10 steps\n evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n save_total_limit=2, # Only save 2 model checkpoints\n learning_rate=3e-5, # Learning rate\n report_to=\"none\", # This disables W&B, TensorBoard, etc.\n run_name=\"local_run\", # Optional, if you want a specific run name\n)\n","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from transformers import Trainer, DataCollatorForTokenClassification\n\n# Define data collator\ndata_collator = DataCollatorForTokenClassification(tokenizer)\n\n# Initialize the Trainer\ntrainer = Trainer(\n model=model,\n args=training_args,\n train_dataset=train_dataset,\n eval_dataset=test_dataset,\n tokenizer=tokenizer,\n data_collator=data_collator,\n compute_metrics=compute_metrics,\n)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"trainer.train()","metadata":{"trusted":true},"execution_count":null,"outputs":[]}
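,{"cell_type":"markdown","source":"With training done, a minimal inference sketch (not part of the original run; the greedy argmax decoding below is one reasonable choice, not the notebook's): run one specialContext through the fine-tuned model and collect the tokens tagged B-ENTITY/I-ENTITY.","metadata":{}},{"cell_type":"code","source":"# Minimal inference sketch (assumes the training cell above has completed)\n# The decoding here is illustrative, not from the original notebook\nsample = df['specialContext'].iloc[0]\ninputs = tokenizer(sample, return_tensors='pt', truncation=True, max_length=512).to(model.device)\nwith torch.no_grad():\n logits = model(**inputs).logits\npred_ids = logits.argmax(dim=-1)[0].tolist()\ntokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])\nentity_tokens = [t for t, i in zip(tokens, pred_ids) if labelled[i] in ('B-ENTITY', 'I-ENTITY')]\nprint(tokenizer.convert_tokens_to_string(entity_tokens))","metadata":{"trusted":true},"execution_count":null,"outputs":[]}]}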