import dataclasses
from typing import Optional

import datasets
import torch
import torch.nn as nn
import tqdm
@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size`` (multi-head attention splits the hidden width
            evenly across heads).
    """
    vocab_size: int = 30522  # rows of the word-embedding table / MLM output classes
    num_layers: int = 12  # number of stacked transformer blocks
    hidden_size: int = 768  # width of embeddings and hidden states
    num_heads: int = 12  # attention heads per block
    dropout_prob: float = 0.1  # dropout probability used throughout the model
    pad_id: int = 0  # token id used as the embedding padding_idx and for the padding mask
    max_seq_len: int = 512  # rows of the position-embedding table (longest supported input)
    num_types: int = 2  # rows of the token-type (segment) embedding table

    def __post_init__(self):
        # Validate here so a bad config fails at construction with a clear
        # message rather than deep inside nn.MultiheadAttention.
        if self.num_heads <= 0 or self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"a positive num_heads ({self.num_heads})"
            )
class BertBlock(nn.Module):
    """One transformer block in BERT.

    Multi-head self-attention followed by a position-wise feed-forward network
    (Linear -> GELU -> Linear, 4x expansion). Each sub-layer's output goes
    through dropout, is added to its input (residual), and is then
    LayerNorm-ed (post-norm).

    Args:
        hidden_size: Width of the hidden states; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_prob: Dropout probability, used on the attention weights inside
            ``nn.MultiheadAttention`` and on both sub-layer outputs.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Hidden states of shape (batch, seq_len, hidden_size)
                (the attention layer is built with ``batch_first=True``).
            pad_mask: Key padding mask of shape (batch, seq_len); ``True``
                marks padding positions that attention must ignore.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        # Self-attention with padding mask. The attention map is discarded, so
        # need_weights=False skips computing (and averaging) it.
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask,
                                        need_weights=False)
        # Dropout on the sub-layer output before the residual add, matching the
        # feed-forward path below.
        x = self.attn_norm(x + self.dropout(attn_output))
        # Feed-forward with GELU activation, residual and post-norm.
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x
class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output.

    Applies a square dense layer followed by ``tanh``.

    Args:
        hidden_size: Width of the input and output features.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(x))``; the last dimension of ``x`` must be ``hidden_size``."""
        return self.activation(self.dense(x))
class BertModel(nn.Module):
    """Backbone of BERT model.

    Sums word, token-type and learned position embeddings, applies LayerNorm
    and dropout, runs the result through ``config.num_layers`` transformer
    blocks, and pools the hidden state at position 0 (the ``[CLS]`` token).

    Args:
        config: Model hyper-parameters. ``config.pad_id`` is both the
            word-embedding ``padding_idx`` and the default padding id used to
            build the attention mask in ``forward``.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        # Remembered so forward() masks the same id the embedding table treats
        # as padding, unless a caller overrides it explicitly.
        self.pad_id = config.pad_id
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor,
                pad_id: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token sequences.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            token_type_ids: Segment ids indexing the token-type embedding table;
                expected to have the same shape as ``input_ids``.
            pad_id: Token id treated as padding in the attention mask. Defaults
                to ``config.pad_id`` given at construction.

        Returns:
            ``(hidden_states, pooled_output)``: final hidden states of shape
            (batch, seq_len, hidden_size) and the pooled position-0 (``[CLS]``)
            representation of shape (batch, hidden_size).

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.max_seq_len``.
        """
        if pad_id is None:
            pad_id = self.pad_id
        # Unpacking (rather than indexing) keeps the rank-2 requirement on input_ids.
        _, seq_len = input_ids.shape
        max_len = self.position_embeddings.num_embeddings
        if seq_len > max_len:
            # Fail with a clear message instead of an out-of-range embedding
            # lookup (which surfaces as a device-side assert on CUDA).
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")
        # attention mask: True marks padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors; position embeddings have
        # a leading dim of 1 and broadcast over the batch
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output
class BertPretrainingModel(nn.Module):
    """BERT backbone with masked-language-model (MLM) and next-sentence-prediction (NSP) heads.

    Args:
        config: Model hyper-parameters; ``config.pad_id`` is the default padding
            id used for the attention mask.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        # Default padding id for forward(); a plain int, so not part of state_dict.
        self.pad_id = config.pad_id
        self.bert = BertModel(config)
        # MLM head: Linear -> GELU -> LayerNorm -> projection onto the vocabulary.
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        # NSP head: binary classifier on the pooled [CLS] representation.
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor,
                pad_id: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute pretraining logits.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            token_type_ids: Segment ids, expected to match ``input_ids`` in shape.
            pad_id: Padding id for the attention mask; defaults to ``config.pad_id``.

        Returns:
            ``(mlm_logits, nsp_logits)`` of shapes (batch, seq_len, vocab_size)
            and (batch, 2).
        """
        if pad_id is None:
            pad_id = self.pad_id
        # Process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # Predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits
# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32
# Load the pre-tokenized pretraining dataset (the dataloader is built with the training loop)
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")
def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in dataset.

    Every item must provide ``tokens`` and ``segment_ids`` of a common length
    (they are stacked directly with ``torch.tensor``, so ragged lists fail),
    a scalar ``is_random_next`` flag, and variable-length ``masked_positions``
    / ``masked_labels`` lists.

    Returns:
        A tuple ``(input_ids, token_type_ids, is_random_next, masked_pos, masked_labels)``:
        ``input_ids`` and ``token_type_ids`` are LongTensors of shape
        (batch, seq_len) (token types are ``abs(segment_ids)``);
        ``is_random_next`` is a LongTensor of shape (batch,) holding NSP class
        targets; ``masked_pos`` is a list of ``(batch_index, position)`` tuples,
        one per masked token; ``masked_labels`` is a LongTensor of MLM targets
        in the same order as ``masked_pos``.
    """
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch], dtype=torch.long)
    # NOTE(review): .abs() suggests some segment ids are stored negative — confirm with the data pipeline
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch], dtype=torch.long).abs()
    # CrossEntropyLoss needs int64 class indices
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).long()
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]
    # Explicit dtype: with no masked tokens, torch.tensor([]) would be float32,
    # which CrossEntropyLoss rejects as a class-index target.
    masked_labels = torch.tensor(
        [label for item in batch for label in item["masked_labels"]], dtype=torch.long
    )
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels
if __name__ == "__main__":
    # Training only runs when executed as a script. With num_workers > 0 and the
    # "spawn" start method (the default on Windows and macOS), every DataLoader
    # worker re-imports this module; unguarded top-level training code would
    # re-run inside each worker. persistent_workers keeps the workers alive
    # across epochs instead of re-creating them at the start of every epoch.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                             collate_fn=collate_fn, num_workers=8,
                                             persistent_workers=True)

    # train the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertPretrainingModel(BertConfig()).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # StepLR(step_size=1, gamma=0.1) divides the LR by 10 on every scheduler.step().
    # It is stepped once per epoch (below); stepping it per batch would drive the
    # LR to ~0 within the first handful of batches.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        # Iterating the tqdm wrapper advances the bar by itself; no manual update().
        pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            # get batched data
            input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            is_random_next = is_random_next.to(device)
            masked_labels = masked_labels.to(device)
            # extract output from model
            mlm_logits, nsp_logits = model(input_ids, token_type_ids)
            # MLM loss: masked_pos is a list of (batch_index, position) pairs; gather
            # the matching rows of mlm_logits (batch, seq_len, vocab_size).
            if masked_pos:
                batch_indices = torch.tensor([b for b, _ in masked_pos], device=device)
                token_positions = torch.tensor([p for _, p in masked_pos], device=device)
                mlm_loss = loss_fn(mlm_logits[batch_indices, token_positions], masked_labels)
            else:
                # No masked tokens in this batch: zip(*[]) cannot be unpacked and a
                # mean over zero targets is NaN, so the MLM term contributes zero.
                mlm_loss = mlm_logits.new_zeros(())
            # Compute the loss for the NSP task
            nsp_loss = loss_fn(nsp_logits, is_random_next)
            # backward with total loss
            total_loss = mlm_loss + nsp_loss
            pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        scheduler.step()
        pbar.close()

    # Save the model
    torch.save(model.state_dict(), "bert_pretraining_model.pth")
    torch.save(model.bert.state_dict(), "bert_model.pth")














