mGrowTech

Training a Model on Multiple GPUs with Data Parallelism

by Josh
December 27, 2025
in AI, Analytics and Automation


import dataclasses
import os

import datasets
import tqdm
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
 

# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000  # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768  # Dimension of hidden layers
    intermediate_size: int = 4*768  # Dimension of MLP's hidden layer
    num_hidden_layers: int = 12  # Number of transformer layers
    num_attention_heads: int = 12  # Number of attention heads
    num_key_value_heads: int = 3  # Number of key-value heads for GQA
 

class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The per-head dimension (head_dim) of the tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # compute a matrix of n\theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        dtype = x.dtype
        # transform the cosine and sine matrices to 4D tensor and the same dtype as x
        cos = self.cos.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        output = (x * cos) + (rotated * sin)
        return output
 

class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size

        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()

        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)

        # Transpose tensors from BSHD to BHSD dimension for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Use PyTorch's optimized attention implementation
        # setting is_causal=True is incompatible with setting explicit attention mask
        # enable_gqa requires PyTorch 2.5 or later
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )

        # Transpose output tensor from BHSD to BSHD dimension, reshape to 3D, and then project output
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output
 

class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SiLU (Swish), the gate activation used by SwiGLU
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU activation: multiply gate and up-projected inputs
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
 

class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: Self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual

        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states
 

class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states
 

class LlamaForPretraining(nn.Module):
    """Llama model with a linear language-modeling head for pretraining."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)
 

def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    batch_size, seq_len = batch.shape
    mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
                .triu(diagonal=1)
    return mask
 

def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        dtype: Data type of the mask

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
                  .masked_fill(batch == padding_token_id, float("-inf"))
    mask = padded[:, :, None] + padded[:, None, :]
    return mask[:, None, :, :]
 

# Dataset that returns fixed-length, padded (input, target) token sequences
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                 seq_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Get a sequence of token ids from the dataset, with [BOT] and [EOT] tokens
        added. The sequence is truncated or padded so that the input and the target
        (the input shifted by one token) each have seq_length tokens.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length+1:
            pad_length = self.seq_length+1 - toklen
            tokens += [self.pad] * pad_length
        # return the sequence
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
        y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64)
        return x, y
 

# Load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")

# Load the dataset
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Initialize the distributed environment
# (torchrun and similar launchers set the RANK, LOCAL_RANK and WORLD_SIZE environment variables)
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
local_rank = int(os.environ["LOCAL_RANK"])
world_size = dist.get_world_size()
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)  # bind this process to its own GPU
print(f"World size: {world_size}, Rank: {rank}, Local rank: {local_rank}. Using device: {device}")

# Create pretraining model with default config, then wrap it in DDP
model_config = LlamaConfig()
# move the model to this process's local GPU; the global rank cannot be used as a
# device index because it exceeds the per-node GPU count on multi-node runs
model = LlamaForPretraining(model_config).to(device)
model = DDP(model, device_ids=[local_rank])
model.train()

# print the model size
print(f"Model parameters size: {sum(p.numel() for p in model.parameters()) / 1024**2:.2f} M")
print(f"Model buffers size: {sum(p.numel() for p in model.buffers()) / 1024**2:.2f} M")
print(f"Model precision(s): {set([x.dtype for x in model.state_dict().values()])}")

# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64  # per-GPU batch size; the global batch is batch_size * world_size
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")

# DataLoader, optimizer, scheduler, and loss function
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(dataset, shuffle=False)  # each rank reads its own shard
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    pin_memory=True,  # optional
    shuffle=False,  # must stay False when a sampler is passed
    num_workers=world_size,  # CPU loader workers for this process
)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1
)
num_training_steps = len(dataloader) * epochs  # len(dataloader) counts this rank's batches only
print(f"Number of training steps: {num_training_steps} = {len(dataloader)} * {epochs}")
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps]
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# start training
for epoch in range(epochs):
    # show the progress bar on rank 0 only, instead of one bar per GPU
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", disable=(rank != 0))
    sampler.set_epoch(epoch)   # required for shuffling only
    for batch_id, batch in enumerate(pbar):
        if batch_id % 1000 == 0 and rank == 0:
            # checkpoint the model and optimizer state, only on rank 0 process
            torch.save({
                "model": model.module.state_dict() if isinstance(model, DDP) else model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "batch": batch_id,
            }, "llama_pretraining_checkpoint.pth")
        # get batched data, move from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # create attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids) + \
                    create_padding_mask(input_ids, PAD_TOKEN_ID)
        # extract output from model
        logits = model(input_ids, attn_mask)
        # compute loss: cross-entropy between logits and target, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # backward with loss and gradient clipping by L2 norm to 1.0
        optimizer.zero_grad()
        loss.backward()  # DDP averages the gradients across all ranks during backward
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        # iterating over pbar already advances it; no extra pbar.update() is needed
        pbar.set_postfix(loss=loss.item())
    pbar.close()

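# Save the model, unwrapping the DDP container so that the state-dict keys
# carry no "module." prefix and the base_model attribute is reachable
if rank == 0:
    torch.save(model.module.state_dict(), "llama_pretraining_model.pth")
    torch.save(model.module.base_model.state_dict(), "llama_model.pth")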

# Clean up the distributed environment

dist.destroy_process_group()
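Most of what `DDP` adds to this script happens inside `loss.backward()`: each process computes gradients on its own shard of the data, and DDP averages those gradients across all ranks, so every replica applies the same update in `optimizer.step()`. The plain-Python sketch below simulates that on a toy least-squares problem. The data, the `grad_mse` and `all_reduce_mean` helpers, and the two-way split are illustrative assumptions, not part of the original script; the point is that averaging the per-shard gradients of a mean loss over equal-sized shards reproduces the full-batch gradient.

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_mean(values: list[float]) -> float:
    """Hypothetical stand-in for DDP's all-reduce: sum per-rank values, divide by world size."""
    return sum(values) / len(values)


# Toy global batch of 8 samples (made-up data)
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.3]
w = 0.5
world_size = 2

# Shard the batch like a sampler would: rank r takes every world_size-th sample
shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
local_grads = [grad_mse(w, sx, sy) for sx, sy in shards]

ddp_grad = all_reduce_mean(local_grads)  # what every rank holds after backward
full_batch_grad = grad_mse(w, xs, ys)    # what a single GPU would compute
print(f"per-rank grads: {local_grads}")
print(f"averaged: {ddp_grad:.6f}, full batch: {full_batch_grad:.6f}")
```

The equality relies on every rank contributing the same number of loss terms. In the training script, `CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)` averages over the non-padding tokens of each local batch, so ranks that happen to hold different amounts of padding are weighted slightly differently than a single-GPU run over the same global batch would weight them.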

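`DistributedSampler` is what keeps the ranks from training on the same examples. A minimal sketch, assuming the sampler's behavior with `shuffle=False` and `drop_last=False`: it pads the index list by wrapping around to the start until its length is divisible by the number of replicas, then gives rank `r` every `num_replicas`-th index starting at `r`. The `shard_indices` function is a hypothetical re-implementation for illustration only; in the real script the sampler does this itself.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Hypothetical re-implementation of DistributedSampler's index split
    (shuffle=False, drop_last=False), for illustration only."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    # pad by wrapping around so that every rank gets the same number of samples
    padding = total_size - len(indices)
    if padding > 0:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, shard_indices(10, num_replicas=4, rank=r))
```

Two consequences for the script follow. Each rank sees `ceil(len(dataset) / world_size)` examples per epoch, so `len(dataloader)`, and therefore `num_training_steps`, is a per-rank count that shrinks as GPUs are added. And when the dataset size is not divisible by the world size, a few examples are visited twice per epoch.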

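The scheduler in the script chains a linear warmup (`LinearLR`, from 0.1x to 1.0x the base rate over `num_warmup_steps` steps) with cosine decay to `eta_min=0` (`CosineAnnealingLR`). The `lr_at_step` function below is a closed-form approximation of that schedule, handy for sanity-checking the rate at a given step. It is a sketch, not PyTorch's implementation, which updates the rate recursively and may differ by floating-point rounding; the value `num_training_steps=30000` is an arbitrary example.

```python
import math


def lr_at_step(step: int, base_lr: float, num_warmup_steps: int,
               num_training_steps: int, start_factor: float = 0.1,
               eta_min: float = 0.0) -> float:
    """Closed-form approximation of LinearLR warmup followed by CosineAnnealingLR."""
    if step < num_warmup_steps:
        # linear ramp of the multiplicative factor from start_factor to 1.0
        factor = start_factor + (1.0 - start_factor) * step / num_warmup_steps
        return base_lr * factor
    # cosine decay from base_lr to eta_min over the remaining steps
    # (steps past the end are clamped here; this sketch does not model later steps)
    t_max = num_training_steps - num_warmup_steps
    progress = min(step - num_warmup_steps, t_max) / t_max
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2


for step in (0, 500, 1000, 15500, 30000):
    print(step, lr_at_step(step, 1e-3, 1000, 30000))
```

Because `num_training_steps` comes from the per-rank `len(dataloader)`, adding GPUs shortens the cosine phase in steps, while each step processes a proportionally larger global batch.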

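The attention mask built in the training loop is additive: `0.0` lets a query attend to a key and `-inf` blocks it, so adding the causal mask to the padding mask blocks a position whenever either mask blocks it. The plain-Python sketch below rebuilds both masks for one right-padded sequence of four tokens so the combined pattern can be inspected; the token IDs and `PAD = 0` are made-up values, and `causal_mask` and `padding_mask` are hypothetical list-based counterparts of `create_causal_mask` and `create_padding_mask`.

```python
NEG_INF = float("-inf")
PAD = 0  # hypothetical padding token ID for this example


def causal_mask(seq_len: int) -> list[list[float]]:
    """-inf above the diagonal: query i may only attend to keys j <= i."""
    return [[NEG_INF if j > i else 0.0 for j in range(seq_len)] for i in range(seq_len)]


def padding_mask(tokens: list[int]) -> list[list[float]]:
    """-inf wherever the query or the key is padding (outer sum of one row vector)."""
    row = [NEG_INF if t == PAD else 0.0 for t in tokens]
    return [[row[i] + row[j] for j in range(len(tokens))] for i in range(len(tokens))]


tokens = [17, 42, 99, PAD]  # hypothetical token IDs, right-padded
combined = [
    [c + p for c, p in zip(c_row, p_row)]
    for c_row, p_row in zip(causal_mask(len(tokens)), padding_mask(tokens))
]
for row in combined:
    print(["-inf" if v == NEG_INF else "0" for v in row])
```

The last row shows that a padding query is blocked from every key, so its softmax has no finite score to normalize. The loss skips those positions through `ignore_index=PAD_TOKEN_ID`, but it is worth watching for NaN losses if the attention backend or PyTorch version changes.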