{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "bert.ipynb", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "KGXmWGTuLrGp" }, "source": [ "# BERT" ] }, { "cell_type": "markdown", "metadata": { "id": "mNjLjTRRLyK4" }, "source": [ "[BERT](https://arxiv.org/pdf/1810.04805.pdf) is a Transformer encoder pretrained on a large quantity of unlabeled text by masked language modeling (MLM). The motivation for pretraining by MLM, rather than conventional left-to-right language modeling, is that the model \"sees\" the entire input text rather than being limited to only the left side, so that when we take the pretrained model and finetune it for some classification task, there's less discrepancy between pretraining and finetuning in input. \n", "\n", "
\n", "\n", "
\n", "\n", "Pretraining an NLP model on unlabeled text and transferring the model to various downstream tasks is certainly not new. There are many successful and influential earlier works like [word2vec](https://arxiv.org/pdf/1310.4546.pdf), which only trains non-contextual word embeddings, [ELMo](https://arxiv.org/pdf/1802.05365.pdf), which trains LSTMs by bidirectional language modeling, and [GPT-1](https://www.cs.ubc.ca/~amuham01/LING530/papers/radford2018improving.pdf), which trains a Transformer decoder by left-to-right language modeling. All of these were quite successful.\n", "\n", "BERT, and all these large-scale pretrained models, improve not just a single task but a whole array of language understanding tasks, introducing the notion of **general language understanding** capabilities. It's not so surprising in retrospect that language modeling is key to general language understanding, if we reflect on how we learn language ourselves. " ] }, { "cell_type": "markdown", "metadata": { "id": "BJwNaswNKbsl" }, "source": [ "# Data: SST-2 Revisited" ] }, { "cell_type": "markdown", "metadata": { "id": "A05adnn9KTa_" }, "source": [ "First, import some useful packages and set everything up." ] }, { "cell_type": "code", "metadata": { "id": "mnV6Qmu4pPAC" }, "source": [ "%matplotlib inline\n", "!pip install transformers \n", "\n", "import copy\n", "import math\n", "import numpy as np\n", "import os\n", "import random\n", "import torch\n", "import torch.nn as nn\n", "\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import BertModel, BertTokenizer\n", "from transformers import AdamW, get_linear_schedule_with_warmup\n", "\n", "def set_seed(seed): # For reproducibility, fix random seeds.\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "-0Pg_BcJKhpk" }, "source": [ "We will again use the binary sentiment classification task in the SST-2 dataset prepared by the [GLUE Benchmark](https://gluebenchmark.com/). If you need to download the dataset again here is a [download link](https://dl.fbaipublicfiles.com/glue/data/SST-2.zip). We will assume that we have the directory data/SST-2/ in our Google Drive account. " ] }, { "cell_type": "code", "metadata": { "id": "UiWakwWapf3z" }, "source": [ "# Load the Drive helper and mount. You will have to authorize this operation. 
\n", "from google.colab import drive\n", "drive.mount('/content/drive')" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "OJ9e2U2Spsy5" }, "source": [ "def read_sst2_file(path):\n", " examples = []\n", " with open(path) as f:\n", " f.readline() \n", " for line in f:\n", " sent, label = line.split('\\t')\n", " examples.append([sent, int(label)])\n", " return examples\n", "\n", "class SST2Dataset(Dataset):\n", " def __init__(self, filename):\n", " path = os.path.join('/content/drive/My Drive/data/SST-2/', filename)\n", " self.examples = read_sst2_file(path)\n", "\n", " def __len__(self):\n", " return len(self.examples)\n", "\n", " def __getitem__(self, index):\n", " return self.examples[index]\n", "\n", "dataset_train = SST2Dataset('train.tsv')\n", "dataset_val = SST2Dataset('dev.tsv')\n", "print('{} train sents, {} val sents'.format(len(dataset_train), len(dataset_val)))\n", "print(dataset_train[42])" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "9FTe8yJON3zF" }, "source": [ "# Tokenization " ] }, { "cell_type": "markdown", "metadata": { "id": "dahqHafmREYD" }, "source": [ "Since we're taking a pretrained BERT, we must use the same tokenization that it was pretrained with when we finetune it. [Hugging Face](https://huggingface.co/transformers/model_doc/bert.html) has made importing and using BERT trivial, and it provides the associated tokenizer along pre-built text preprocessing like padding. There are many versions of BERT (uncased vs cased, model size, etc.). We'll use [bert-base-uncased](https://huggingface.co/bert-base-uncased)." ] }, { "cell_type": "code", "metadata": { "id": "K46rxBhMtEFX" }, "source": [ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "print('The pretrained tokenizer in bert-base-uncased has vocab size {:d}\\n'.format(tokenizer.vocab_size))\n", "\n", "# We need to provide a custom collate function to DataLoader because we're handling sentences of varying lengths. \n", "def collate_fn(batch):\n", " # https://huggingface.co/transformers/preprocessing.html\n", " sents, labels = zip(*batch)\n", " labels = torch.FloatTensor(labels)\n", " encoded = tokenizer(sents, padding=True, return_tensors='pt')\n", " return encoded['input_ids'], encoded['attention_mask'], labels\n", "\n", "set_seed(42)\n", "dataloader_val = DataLoader(dataset_val, batch_size=2, shuffle=False, num_workers=2, collate_fn=collate_fn)\n", "\n", "for i, (input_ids, attention_mask, labels) in enumerate(dataloader_val):\n", " if i == 0:\n", " print(input_ids)\n", " print(tokenizer.convert_ids_to_tokens(input_ids[0]))\n", " print(tokenizer.convert_ids_to_tokens(input_ids[1]))\n", " print(tokenizer.decode(input_ids[0]))\n", " print(tokenizer.decode(input_ids[1]))\n", " print(attention_mask)\n", " print(labels)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "a8zegTIfeObX" }, "source": [ "Note that the tokenizer has special symbols used for pretraining on sentence pairs. [SEP] is used as a sentence separator, [PAD] is used for padding, and [CLS] is used at the beginning of input to predict whether the pair is consecutive or not. We will be using the final embedding of [CLS] for finetuning. 
" ] }, { "cell_type": "markdown", "metadata": { "id": "S-FdBSH2e-oB" }, "source": [ "# Model" ] }, { "cell_type": "markdown", "metadata": { "id": "lzHg_XugfCtb" }, "source": [ "Our sentiment classifier will import BERT as an encoder and have 1 extra score layer to convert [CLS] embedding into a logit. Whenever we import a pretrained model it's important that we make anything we introduce be consistent with the pretrained model, in particular **initialization**. So we'll hunt down how BERT was initialized and apply the same initialization scheme to the extra layer that we introduce. " ] }, { "cell_type": "code", "metadata": { "id": "jE78irrhu227" }, "source": [ "def get_init_transformer(transformer):\n", " \"\"\"\n", " Initialization scheme used for transformers:\n", " https://huggingface.co/transformers/_modules/transformers/modeling_bert.html\n", " \"\"\"\n", " def init_transformer(module):\n", " if isinstance(module, (nn.Linear, nn.Embedding)):\n", " module.weight.data.normal_(mean=0.0, std=transformer.config.initializer_range)\n", " elif isinstance(module, nn.LayerNorm):\n", " module.bias.data.zero_()\n", " module.weight.data.fill_(1.0)\n", " if isinstance(module, nn.Linear) and module.bias is not None:\n", " module.bias.data.zero_()\n", "\n", " return init_transformer\n", "\n", "\n", "class BertClassifier(nn.Module):\n", "\n", " def __init__(self, drop=0.1):\n", " super().__init__()\n", " self.encoder = BertModel.from_pretrained('bert-base-uncased')\n", " self.score = nn.Sequential(nn.Dropout(drop), \n", " nn.Linear(self.encoder.config.hidden_size, 1))\n", " self.score.apply(get_init_transformer(self.encoder)) # Important to initialize any additional weights the same way as pretrained encoder. \n", " self.loss = nn.BCEWithLogitsLoss(reduction='sum') \n", "\n", " def forward(self, input_ids, attention_mask, labels):\n", " hiddens_last = self.encoder(input_ids, attention_mask=attention_mask)[0] # (batch_size, length, dim), these are last layer embeddings\n", " embs = hiddens_last[:,0,:] # [CLS] token embeddings\n", " logits = self.score(embs).squeeze(1) # batch_size\n", " loss_total = self.loss(logits, labels)\n", " return logits, loss_total" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "bqaCrlJAgSii" }, "source": [ "def count_params(model):\n", " return sum(p.numel() for p in model.parameters())\n", "\n", "model = BertClassifier()\n", "print('Model has {} parameters\\n'.format(count_params(model)))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "m-K_5aXrgjGj" }, "source": [ "# Optimization " ] }, { "cell_type": "markdown", "metadata": { "id": "afpMVJiggl0l" }, "source": [ "Training a Transformer from scratch can be tricky and one must be quite careful about the choice of optimizer, learning rate schedule, warmup steps, and so on. But *finetuning* a pretrained Transformer is more forgiving. In fact, if we use a simple Adam with default values and no scheduling it'll still work pretty well (try it if you want). \n", "\n", "But it never seems to hurt to follow a similar optimization scheme used in pretraining. Sometimes it helps slightly. So we'll use Adam with weight decay with a linear schedule. 
" ] }, { "cell_type": "code", "metadata": { "id": "Y0YyBFN74K4F" }, "source": [ "def configure_optimization(model, num_train_steps, num_warmup_steps, lr, weight_decay=0.01): \n", " # Copied from: https://huggingface.co/transformers/training.html\n", " no_decay = ['bias', 'LayerNorm.weight']\n", " optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], \n", " 'weight_decay': weight_decay},\n", " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", " 'weight_decay': 0.}]\n", " optimizer = AdamW(optimizer_grouped_parameters, lr=lr) \n", " scheduler = get_linear_schedule_with_warmup(optimizer, num_training_steps=num_train_steps, num_warmup_steps=num_warmup_steps) \n", " return optimizer, scheduler" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "8r6AS1Gxh1pJ" }, "source": [ "Code for evaluation." ] }, { "cell_type": "code", "metadata": { "id": "5clksA-Lz-67" }, "source": [ "def get_acc_val(model, device):\n", " num_correct_val = 0\n", " model.eval() \n", " with torch.no_grad(): \n", " for input_ids, attention_mask, labels in dataloader_val:\n", " input_ids = input_ids.to(device)\n", " attention_mask = attention_mask.to(device)\n", " labels = labels.to(device)\n", " logits, _ = model(input_ids, attention_mask, labels) \n", " preds = torch.where(logits > 0., 1, 0) # 1 if p(1|x) > 0.5, 0 else\n", " num_correct_val += (preds == labels).sum()\n", " acc_val = num_correct_val / len(dataloader_val.dataset) * 100.\n", " return acc_val" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "YE5Ap79Mh4DZ" }, "source": [ "Finetuning is straightforward. While it's known that it has high variance (i.e., different random seeds may produce nontrivial performance differences when finetuning), it often works easily with hyperparameter values recommended by the BERT paper: \n", "\n", "- Batch size 32 (unless you're short on memory)\n", "- Dropout rate 0.1 (keep in mind that BERT encoder itself has dropout layers that will be used, in addition to our additional classification layer's dropout)\n", "- Learning rate chosen from 5e-5, 4e-5, 3e-5, and 2e-5 based on validation performance\n", "- 3 epochs. But it's known that training longer (up to 10 epochs) can improve performance further. \n", "\n", "There are other optimization-specific hyperparameters like the number of warmup steps. If you're serious about performance, you must tune these variables thoroughly. But here we'll just do minimal work and stick with reasonable-looking values." ] }, { "cell_type": "code", "metadata": { "id": "oPQ7H_Bf04Zc" }, "source": [ "def train(model, batch_size=32, num_warmup_steps=10, lr=0.00005, num_epochs=3, clip=1., verbose=True, device='cuda'):\n", " model = model.to(device) # Move the model to device. 
\n", " dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn) \n", " num_train_steps = len(dataset_train) // batch_size * num_epochs\n", " optimizer, scheduler = configure_optimization(model, num_train_steps, num_warmup_steps, lr)\n", "\n", " loss_avg = float('inf')\n", " acc_train = 0.\n", " best_acc_val = 0.\n", " for epoch in range(num_epochs):\n", " model.train() # This turns on the training mode (e.g., enable dropout).\n", " loss_total = 0.\n", " num_correct_train = 0\n", " for batch_ind, (input_ids, attention_mask, labels) in enumerate(dataloader_train):\n", " if (batch_ind + 1) % 400 == 0: \n", " print(batch_ind + 1, '/', len(dataloader_train), 'batches done')\n", " input_ids = input_ids.to(device).long()\n", " attention_mask = attention_mask.to(device)\n", " labels = labels.to(device) \n", " logits, loss_batch_total = model(input_ids, attention_mask, labels) \n", " preds = torch.where(logits > 0., 1, 0) # 1 if p(1|x) > 0.5, 0 else\n", " num_correct_train += (preds == labels).sum()\n", " loss_total += loss_batch_total.item() \n", " \n", " loss_batch_avg = loss_batch_total / input_ids.size(0) \n", " loss_batch_avg.backward() \n", "\n", " if clip > 0.: # Optional gradient clipping\n", " nn.utils.clip_grad_norm_(model.parameters(), clip)\n", "\n", " optimizer.step() # optimizer updates model weights based on stored gradients\n", " scheduler.step() # Update lr. \n", " optimizer.zero_grad() # Reset gradient slots to zero\n", "\n", " # Useful training information\n", " loss_avg = loss_total / len(dataloader_train.dataset)\n", " acc_train = num_correct_train / len(dataloader_train.dataset) * 100.\n", "\n", " # Check validation performance at the end of every epoch. \n", " acc_val = get_acc_val(model, device)\n", "\n", " if verbose:\n", " print('Epoch {:3d} | avg loss {:8.4f} | train acc {:2.2f} | val acc {:2.2f}'.format(epoch + 1, loss_avg, acc_train, acc_val))\n", "\n", " if acc_val > best_acc_val: \n", " best_acc_val = acc_val\n", " \n", " if verbose: \n", " print('Final avg loss {:8.4f} | final train acc {:2.2f} | best val acc {:2.2f}'.format(loss_avg, acc_train, best_acc_val))\n", "\n", " return best_acc_val" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "cFBS8N2VkINA" }, "source": [ "BERT-base is small enough to finetune on a small data using a single GPU in a reasonable amount of time. But It'll still take some time, even with 3 epochs." ] }, { "cell_type": "code", "metadata": { "id": "HqULePEElRTr" }, "source": [ "!pip install ipython-autotime # Useful library for tracking runtime on Python Notebooks\n", " \n", "%load_ext autotime" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ts-OqAng6NUl" }, "source": [ "if True: # Set True to run. \n", " set_seed(42)\n", " model = BertClassifier()\n", " best_acc_val = train(model, batch_size=32)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "XpQ2eI5uo1CP" }, "source": [ "So this is pretty cool. Recall that the best validation accuracy we could get with models trained from scratch was $80$ to $83$, even though we tried pretty hard with different architectures like CNNs and BiLSTMs. Now we're getting $>91$ accuracy. It wasn't that painful to train either since we have an okay sense of what hyperparameter values to use. 
Finetuning took only $15$ minutes or so because the data is small but computation becomes more challenging with bigger datasets and/or larger models. " ] } ] }