
Pretrain a BERT Model from Scratch

by Josh
December 1, 2025
in AI, Analytics and Automation


import dataclasses

import datasets
import torch
import torch.nn as nn
import tqdm

@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    dropout_prob: float = 0.1
    pad_id: int = 0
    max_seq_len: int = 512
    num_types: int = 2
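
The defaults above match the BERT-base dimensions (12 layers, hidden size 768, 12 attention heads, 30,522-token vocabulary). Because it is a plain dataclass, any field can be overridden to get a smaller model for quick experiments; the values below are illustrative and not part of the original setup.

# Illustrative: a scaled-down configuration for smoke tests on modest hardware.
tiny_config = BertConfig(num_layers=2, hidden_size=128, num_heads=2, max_seq_len=128)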

 

 

 

class BertBlock(nn.Module):
    """One transformer block in BERT."""
    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        # self-attention with padding mask and post-norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + attn_output)
        # feed-forward with GELU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x
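
Note that nn.MultiheadAttention interprets key_padding_mask entries that are True as positions to ignore, which is why BertModel below builds the mask as input_ids == pad_id. A minimal sketch of running a single block on random embeddings (the sizes here are illustrative):

# Minimal sketch: one block over random embeddings, with the last two
# positions of each sequence marked as padding (True = ignored by attention).
block = BertBlock(hidden_size=768, num_heads=12, dropout_prob=0.1)
x = torch.randn(2, 16, 768)                      # (batch, seq_len, hidden)
pad_mask = torch.zeros(2, 16, dtype=torch.bool)  # False = real token
pad_mask[:, -2:] = True                          # True = padding position
out = block(x, pad_mask)
print(out.shape)  # torch.Size([2, 16, 768])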

 

 

class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dense(x)
        x = self.activation(x)
        return x

class BertModel(nn.Module):
    """Backbone of the BERT model."""
    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # create attention mask for padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        batch_size, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output
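
A quick forward pass with random token IDs shows the two return values: the per-token hidden states and the pooled [CLS] vector that the NSP head consumes. This is only a shape check under the default configuration; the inputs are made up.

# Illustrative shape check (default BertConfig assumed).
config = BertConfig()
bert = BertModel(config)
input_ids = torch.randint(1, config.vocab_size, (2, 16))  # (batch, seq_len), no padding
token_type_ids = torch.zeros(2, 16, dtype=torch.long)     # single-segment input
hidden_states, pooled = bert(input_ids, token_type_ids, pad_id=config.pad_id)
print(hidden_states.shape, pooled.shape)  # torch.Size([2, 16, 768]) torch.Size([2, 768])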

 

 

class BertPretrainingModel(nn.Module):
    """BERT backbone with the MLM and NSP pretraining heads."""
    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits

# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32

# Load dataset and set up dataloader
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in dataset."""
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch])
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).to(int)
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)
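
The collate function assumes each example in the parquet file provides tokens and segment_ids already padded to a fixed length, an is_random_next label, and variable-length masked_positions / masked_labels from the masking step. A tiny in-memory dataset with that schema (the values below are invented) can be used to smoke-test the batching logic before a full run:

# Hypothetical two-example dataset matching the schema collate_fn expects
# (token IDs and positions are made up, used only to exercise the function).
_toy = datasets.Dataset.from_dict({
    "tokens": [[101, 7, 8, 102, 0, 0], [101, 9, 10, 102, 0, 0]],
    "segment_ids": [[0, 0, 1, 1, 0, 0], [0, 0, 1, 1, 0, 0]],
    "is_random_next": [False, True],
    "masked_positions": [[1], [2]],
    "masked_labels": [[7], [10]],
})
ids, types, nsp, pos, labels = collate_fn([_toy[0], _toy[1]])
print(ids.shape, nsp.tolist(), pos, labels.tolist())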

 

# train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # MLM loss: masked_pos is a list of (batch_index, position) tuples; use them
        # to gather the corresponding logits from mlm_logits of shape (B, S, V)
        batch_indices, token_positions = zip(*masked_pos)
        mlm_logits = mlm_logits[batch_indices, token_positions]
        mlm_loss = loss_fn(mlm_logits, masked_labels)
        # compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    # decay the learning rate once per epoch (StepLR with step_size=1)
    scheduler.step()
    pbar.close()

# Save the model
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")
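
The second checkpoint contains only the backbone, so it can be loaded into a fresh BertModel when fine-tuning on a downstream task. A minimal sketch, assuming the same default BertConfig used above:

# Reload the pretrained backbone into a new BertModel instance (sketch;
# assumes the checkpoint was produced with the same BertConfig defaults).
backbone = BertModel(BertConfig())
backbone.load_state_dict(torch.load("bert_model.pth", map_location="cpu"))
backbone.eval()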


