
Train a GPT2 model with JAX on TPU for free

By Josh | August 22, 2025

If you’ve ever been curious about what it takes to build a language model from scratch with JAX, this post is for you. I recently ran a workshop on this topic at Cloud Next 2025 and got some great feedback, so I thought I’d write up a guide for everyone who couldn’t make it.

In this article and code example, you’re going to build and pretrain a GPT-2 model, showing how JAX makes it straightforward to leverage the power of Google TPUs. You can run the entire project for free using the TPUs in Colab or Kaggle, and you can find the full notebook here.

This is a hands-on tutorial, so I’ll assume you’re familiar with general machine learning concepts. If JAX is new to you, the PyTorch developer’s guide to JAX fundamentals is a great place to start.

First, let’s take a quick look at the tools we’ll be using.

JAX ecosystem

Before we start building the model, let’s briefly talk about the JAX ecosystem. The JAX ecosystem takes a modular approach: JAX core provides the numerical processing capabilities, and a rich collection of libraries built on top serves application-specific needs, for example Flax for building neural networks, Orbax for checkpointing and model persistence, and Optax for optimization (we are going to use all three in this article). Built-in function transformations such as autograd, vectorization, and JIT compilation, plus strong performance and easy-to-use APIs, make JAX a great fit for training large language models.
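
As a quick orientation, the imports for this stack typically look something like the following (a minimal sketch; it assumes a recent Flax release that ships the NNX API and the standalone orbax-checkpoint package):

import jax
import jax.numpy as jnp           # core numerics: NumPy-like API on CPU/GPU/TPU
from flax import nnx              # Flax NNX: the neural network API used in this post
import optax                      # optimizers and learning-rate schedules
import orbax.checkpoint as orbax  # checkpointing and model persistence

Python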

Getting started: Build your GPT2 model

OpenAI previously released the GPT2 model code and weights, which are good references, and there are many community efforts, such as nanoGPT, to replicate the model. Here is a high level model architecture diagram for GPT2:

[Image: high-level architecture diagram of the GPT-2 model]

We are going to use NNX (the new Flax interface) to build the GPT2 model. For brevity, let’s focus on the transformer block, which is the key building block of modern large language models. The transformer block captures long-range dependencies in a sequence and builds a rich contextual understanding of it. A GPT2 transformer block consists of 2 LayerNorm layers, 1 Multi-Head Attention (MHA) layer, 2 dropout layers, 2 linear projection layers and 2 residual connections. So we first define these layers in the __init__ function of the TransformerBlock class:

class TransformerBlock(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout_rate: float,
        rngs: nnx.Rngs,
    ):
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6, num_features=embed_dim, rngs=rngs
        )
        self.mha = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=embed_dim, rngs=rngs
        )
        self.dropout1 = nnx.Dropout(rate=dropout_rate)
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6, num_features=embed_dim, rngs=rngs
        )
        self.linear1 = nnx.Linear(
            in_features=embed_dim, out_features=ff_dim, rngs=rngs
        )
        self.linear2 = nnx.Linear(
            in_features=ff_dim, out_features=embed_dim, rngs=rngs
        )
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

Python

Next we assemble these layers in the __call__ function:

class TransformerBlock(nnx.Module):
    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        # Pre-norm self-attention with a causal mask
        attention_output = self.mha(
            inputs_q=self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            decode=False,
        )
        # First residual connection
        x = inputs + self.dropout1(
            attention_output, deterministic=not training
        )

        # MLP (feed-forward) sub-block
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(
            mlp_output, deterministic=not training
        )

        # Second residual connection
        return x + mlp_output

Python
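
The causal_attention_mask helper used above is defined in the full notebook; a common implementation is simply a lower-triangular matrix. With one in hand, and with the two methods above combined into a single TransformerBlock class, you can smoke-test the block on dummy data (the sizes below are illustrative, not the GPT-2 124M hyperparameters):

import jax.numpy as jnp
from flax import nnx

def causal_attention_mask(seq_len):
    # Lower-triangular mask: position i may only attend to positions <= i.
    return jnp.tril(jnp.ones((seq_len, seq_len)))

# Illustrative sizes only.
block = TransformerBlock(
    embed_dim=256, num_heads=8, ff_dim=1024,
    dropout_rate=0.1, rngs=nnx.Rngs(0),
)
dummy = jnp.ones((2, 16, 256))   # (batch, seq_len, embed_dim)
out = block(dummy, training=False)
print(out.shape)                 # (2, 16, 256)

Python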

This code should look very familiar if you have used any other ML framework, like PyTorch or TensorFlow, to train a language model. But one of the things I really like about JAX is that it has the amazing capability to automatically run the code in parallel via SPMD (Single Program Multiple Data), which is needed because we will be running the code on multiple accelerators (multiple TPU cores). Let’s see how it works.

To perform SPMD, first we need to make sure we are using TPUs. Choose the TPU runtime if you are using Colab or Kaggle (you can also use a Cloud TPU VM).

import jax
jax.devices()

# Free-tier Colab offers TPU v2:
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
#  TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
#  TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
#  TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
#  TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
#  TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
#  TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
#  TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Python

Colab and Kaggle offer TPU v2 or v3 runtimes, each of which has 8 separate TPU cores. Here is what a TPU v3 tray looks like:

[Image: a TPU v3 board]

Train your GPT2 model

To train the GPT2 model efficiently, we will run all TPU cores together via SPMD and leverage data parallelism in JAX. To achieve this, we define a hardware mesh:

mesh = jax.make_mesh((8, 1), ('batch', 'model'))

Python

Think of the mesh as a 2D matrix of accelerators. In this case, we define 2 axes for the mesh – the batch axis and the model axis. So in total we have 8 by 1, which is 8 cores. These axes determine how we partition the data and the model parameters. We can change the axes later if we want to experiment with other parallelism schemes.
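
To see concretely what sharding on this mesh means, you can place a toy array on it and visualize the layout (a minimal sketch; jax.debug.visualize_array_sharding prints an ASCII diagram of which device holds which slice):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((8, 1), ('batch', 'model'))

# A toy batch of 16 rows, sharded along the first dimension over the 'batch' axis.
x = jnp.arange(16 * 4).reshape(16, 4)
x = jax.device_put(x, NamedSharding(mesh, P('batch', None)))

jax.debug.visualize_array_sharding(x)  # each of the 8 cores holds 2 rows

Python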

Now we change the __init__ function by telling JAX how we would like to partition the model parameters using the ‘model’ axis. This is done by adding nnx.with_partitioning when initializing the weight tensors: for 1D weight tensors like LayerNorm scale/bias tensors, we directly shard them along the ‘model’ axis; for 2D weight tensors like MHA and Linear kernel tensors, we shard the 2nd dimension along the model axis.

class TransformerBlock(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout_rate: float,
        rngs: nnx.Rngs,
    ):
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6, num_features=embed_dim, rngs=rngs,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(),
                ("model",),
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(),
                ("model",),
            ),
        )
        self.mha = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=embed_dim, rngs=rngs,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(),
                (None, "model"),
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(),
                ("model",),
            ),
        )
        # Other layers in the block are omitted for brevity

Python

We need to partition the other layers in the same way so that we can enable tensor (model) parallelism for the entire GPT2 model. Even though we don’t use tensor parallelism in this tutorial, it’s still a good idea to set it up, because the model may grow and we may need to partition its parameters in the future. With this in place, we can change just one line of code and immediately run bigger models. For example:

mesh = jax.make_mesh((4, 2), ('batch', 'model'))

Python

Next, we define the loss_fn and train_step functions, similar to the previous blog post. The train_step() function computes the gradients of the cross-entropy loss and updates the weights via the optimizer; it will be called in a loop to train the model. To get the best performance, we JIT-compile both functions using the @nnx.jit decorator, since they are compute intensive.

@nnx.jit
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch[1]
    ).mean()
    return loss, logits


@nnx.jit
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)

Python
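
The train_metrics object passed to train_step is created elsewhere in the notebook; a minimal setup that matches the metrics.update(loss=..., logits=..., labels=...) call above could look like this (an assumed sketch using the standard NNX metrics API):

train_metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),   # consumes logits and labels
    loss=nnx.metrics.Average('loss'),  # running average of the loss value
)

Python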

For the optimizer, we are using AdamW from Optax with a cosine decay schedule. You can experiment with other optimizers or schedules in Optax as well.

schedule = optax.cosine_decay_schedule(
    init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
    optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

Python
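
For example, to add a linear warmup phase before the cosine decay, you could swap in Optax’s combined warmup-plus-cosine schedule (a sketch; the warmup_steps value below is illustrative):

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,                 # start the warmup from zero
    peak_value=init_learning_rate,  # ramp up to the base learning rate
    warmup_steps=2_000,             # illustrative value
    decay_steps=max_steps,          # then decay over the remaining steps
)

Python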

Lastly, we create a simple training loop.

step = 0
while True:
    input_batch, target_batch = get_batch("train")

    train_step(
        model,
        optimizer,
        train_metrics,
        jax.device_put(
            (input_batch, target_batch),
            NamedSharding(mesh, P("batch", None)),
        ),
    )

    step += 1
    if step > max_steps:
        break

Python

Note how we partition the input data along the batch axis using the jax.device_put function. In this case JAX enables data parallelism, bringing everything together by inserting communication collectives (AllReduce) automatically and overlapping computation and communication as much as possible. For a more in-depth discussion of parallel computation, please refer to JAX’s Introduction to parallel programming documentation.
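
For reference, the get_batch helper comes from the full notebook; conceptually it follows the usual nanoGPT-style recipe of sampling random windows from a flat array of token ids. A hypothetical sketch (the data dict, batch_size, and seq_len below are assumptions for illustration, not the notebook’s exact code):

import numpy as np

# Hypothetical: `data` maps split name -> 1D array of token ids (GPT-2 vocab size 50257).
data = {"train": np.random.randint(0, 50257, size=1_000_000, dtype=np.int32)}

def get_batch(split, batch_size=64, seq_len=1024):
    tokens = data[split]
    # Sample random starting offsets, then build (input, target) pairs shifted by one token.
    idx = np.random.randint(0, len(tokens) - seq_len - 1, size=batch_size)
    inputs = np.stack([tokens[i : i + seq_len] for i in idx])
    targets = np.stack([tokens[i + 1 : i + seq_len + 1] for i in idx])
    return inputs, targets

Python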

At this point the model should be training and we can observe the training loss if Weights and Biases is used to track the run. Here is a test run for training the GPT2 124M model:

[Image: test run for the GPT-2 124M model]

It takes ~7 hours on a Kaggle TPU v3 (which we can use for 9 hours without interruption), but if we use Trillium, training time goes down to ~1.5 hours (note that Trillium has 32 GB of HBM (High Bandwidth Memory) per chip, so we can double the batch size and halve the training steps).

Final losses are roughly in line with those of nanoGPT, a project I really enjoyed and studied while writing this code example.

[Image: nanoGPT baseline comparison table]

If we use Cloud TPUs, we can also monitor TPU utilization via the ‘tpu-info’ command (part of the Cloud TPU Monitoring Debugging package) or the Weights & Biases dashboard. Our TPUs are going brrr!

[Image: output of the ‘tpu-info’ command]

After the model is trained, we can save it using Orbax:

checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)

Python
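
To load the checkpoint back later, the same checkpointer can restore the saved PyTree of arrays (a minimal sketch; how you merge the restored arrays back into an nnx model depends on your Flax version, so that step is omitted here):

checkpointer = orbax.PyTreeCheckpointer()
restored_state = checkpointer.restore(checkpoint_path)  # PyTree of raw arrays

# Inspect the restored structure as a quick sanity check.
print(jax.tree_util.tree_structure(restored_state))

Python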

Next steps: Explore advanced LLM training and scaling

That’s it. That’s pretty much all we need to train a GPT2 model. You can find additional details, like data loading, hyperparameters, and metrics, in the complete notebook.

Of course GPT2 is a small model today and many frontier labs are training models with hundreds of billions of parameters. But now that you have learned how to build a small language model with JAX and TPU, you are ready to dive into How to scale your model.

In addition, you can either use MaxText to train pre-built, cutting-edge LLMs, or learn to build the latest models from scratch by referencing the JAX LLM examples or the Stanford Marin model.

I can’t wait to see the amazing models you build with JAX and TPUs!


