
Pretrain a BERT Model from Scratch

by Josh
December 1, 2025
in AI, Analytics and Automation

The script below builds a BERT encoder from scratch in PyTorch and pretrains it with the two original BERT objectives: masked language modeling (MLM) and next sentence prediction (NSP).


import dataclasses

import datasets
import torch
import torch.nn as nn
import tqdm

@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    dropout_prob: float = 0.1
    pad_id: int = 0
    max_seq_len: int = 512
    num_types: int = 2

class BertBlock(nn.Module):
    """One transformer block in BERT."""
    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        # self-attention with padding mask and post-norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + attn_output)
        # feed-forward with GELU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x
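
As a quick sanity check (my addition, not part of the original script), a block should map a (batch, seq, hidden) tensor to a tensor of the same shape; True entries in the key padding mask mark positions the attention should ignore:

# illustrative shape check: a block preserves (batch, seq, hidden)
block = BertBlock(hidden_size=128, num_heads=2, dropout_prob=0.1)
x = torch.randn(2, 8, 128)
pad_mask = torch.zeros(2, 8, dtype=torch.bool)  # True would mark padding positions
print(block(x, pad_mask).shape)  # torch.Size([2, 8, 128])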

 

 

class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dense(x)
        x = self.activation(x)
        return x

class BertModel(nn.Module):
    """Backbone of BERT model."""
    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # create attention mask for padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        batch_size, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output
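
With the default configuration, this backbone sits in the familiar BERT-base range of roughly 110M parameters. A quick count (my addition, not in the original script):

# illustrative: count trainable parameters of the backbone with default config
bert_base = BertModel(BertConfig())
num_params = sum(p.numel() for p in bert_base.parameters() if p.requires_grad)
print(f"{num_params / 1e6:.1f}M parameters")  # roughly 110M for these defaults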

 

 

class BertPretrainingModel(nn.Module):
    """BERT backbone plus the MLM and NSP pretraining heads."""
    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits
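
Before a long training run, it is worth confirming the wiring with random inputs. This smoke test is a sketch of mine (the shrunken config values are arbitrary), not part of the original script:

# smoke test (illustrative): verify logits shapes on random token ids
cfg = BertConfig(num_layers=2, hidden_size=128, num_heads=2, max_seq_len=64)
test_model = BertPretrainingModel(cfg)
ids = torch.randint(1, cfg.vocab_size, (2, 16))       # avoid pad id 0
types = torch.zeros(2, 16, dtype=torch.long)
mlm_logits, nsp_logits = test_model(ids, types)
print(mlm_logits.shape, nsp_logits.shape)  # torch.Size([2, 16, 30522]) torch.Size([2, 2])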

 

 

# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32

# Load dataset and set up dataloader
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in dataset."""
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch])
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).long()
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)
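
The parquet file is expected to provide the fields the collate function reads: tokens, segment_ids, is_random_next, masked_positions, and masked_labels. To see what the collated output looks like, you can pass a couple of hand-built records; the token ids and labels below are made up purely for illustration:

# illustrative: two fake records, already padded to a common length of 8
fake_batch = [
    {"tokens": [101, 7, 9, 102, 4, 102, 0, 0], "segment_ids": [0, 0, 0, 0, 1, 1, 0, 0],
     "is_random_next": False, "masked_positions": [2], "masked_labels": [15]},
    {"tokens": [101, 5, 102, 8, 3, 102, 0, 0], "segment_ids": [0, 0, 0, 1, 1, 1, 0, 0],
     "is_random_next": True, "masked_positions": [1, 4], "masked_labels": [22, 6]},
]
ids, types, nsp_targets, positions, labels = collate_fn(fake_batch)
print(ids.shape)      # torch.Size([2, 8])
print(nsp_targets)    # tensor([0, 1])
print(positions)      # [(0, 2), (1, 1), (1, 4)]
print(labels)         # tensor([15, 22,  6])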

 

# train the model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # MLM loss: masked_pos is a list of (batch_index, token_position) tuples;
        # use it to gather the masked-token logits from mlm_logits of shape (B, S, V)
        batch_indices, token_positions = zip(*masked_pos)
        mlm_logits = mlm_logits[batch_indices, token_positions]
        mlm_loss = loss_fn(mlm_logits, masked_labels)
        # compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    # StepLR is epoch-based: step once per epoch, since with step_size=1 and
    # gamma=0.1 each step cuts the learning rate by 10x
    scheduler.step()
    pbar.close()

# Save the model
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")
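
The second checkpoint stores only the backbone weights, so they can later be reloaded into a bare BertModel for downstream fine-tuning. A minimal sketch:

# minimal sketch: reload the pretrained backbone for fine-tuning
backbone = BertModel(BertConfig())
backbone.load_state_dict(torch.load("bert_model.pth", map_location="cpu"))
backbone.eval()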


