Thursday, November 13, 2025
mGrowTech

A Coding Implementation to Build and Train Advanced Architectures with Residual Connections, Self-Attention, and Adaptive Optimization Using JAX, Flax, and Optax

by Josh
November 12, 2025
in AI, Analytics and Automation


In this tutorial, we build and train a neural network with JAX, Flax, and Optax in a modular way. We begin by designing a deep architecture that combines residual connections with a self-attention layer. We then set up the optimizer: a learning rate schedule with warmup and cosine decay, gradient clipping, and AdamW's decoupled weight decay. Throughout, we rely on JAX transformations, chiefly jit and value_and_grad, to compile the training and evaluation steps and run them on whatever device is available.

!pip install jax jaxlib flax optax matplotlib


import jax
import jax.numpy as jnp
from jax import random, jit, vmap, grad
import flax.linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
from typing import Any, Callable


print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

We begin by installing and importing JAX, Flax, and Optax, along with utilities for numerical operations and visualization. We print the JAX version and the list of available devices to confirm which hardware JAX will run on. This setup forms the foundation for the entire training pipeline.
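The imports above bring in jit, vmap, and grad. As a minimal sketch of how these three transformations compose, the toy example below computes per-example gradients of a squared-error loss; the function `squared_error` and the sample arrays are illustrative assumptions and are not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Loss for a single example under a linear model w . x
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates with respect to the first argument (w) by default.
grad_fn = jax.grad(squared_error)

# vmap maps over the leading axis of x and y while sharing the same w.
per_example_grads = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# jit compiles the composed function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = fast_per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient vector per example
```

The transformations are ordinary higher-order functions, so they can be stacked in any order that makes sense for the computation.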

class SelfAttention(nn.Module):
   num_heads: int
   dim: int
   @nn.compact
   def __call__(self, x):
       B, L, D = x.shape
       head_dim = D // self.num_heads
       qkv = nn.Dense(3 * D)(x)
       qkv = qkv.reshape(B, L, 3, self.num_heads, head_dim)
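       q, k, v = jnp.split(qkv, 3, axis=2)
       # After the squeeze, q, k, and v each have shape (B, L, num_heads, head_dim).
       q, k, v = q.squeeze(2), k.squeeze(2), v.squeeze(2)
       # Per-head scores over sequence positions: (B, num_heads, L, L).
       attn_scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(head_dim)
       attn_weights = jax.nn.softmax(attn_scores, axis=-1)
       # Weighted sum of values over key positions: (B, L, num_heads, head_dim).
       attn_output = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, v)
       attn_output = attn_output.reshape(B, L, D)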
       return nn.Dense(D)(attn_output)


class ResidualBlock(nn.Module):
   features: int
   @nn.compact
   def __call__(self, x, training: bool = True):
       residual = x
       x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
       x = nn.BatchNorm(use_running_average=not training)(x)
       x = nn.relu(x)
       x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
       x = nn.BatchNorm(use_running_average=not training)(x)
       if residual.shape[-1] != self.features:
           residual = nn.Conv(self.features, (1, 1))(residual)
       return nn.relu(x + residual)


class AdvancedCNN(nn.Module):
   num_classes: int = 10
   @nn.compact
   def __call__(self, x, training: bool = True):
       x = nn.Conv(32, (3, 3), padding='SAME')(x)
       x = nn.relu(x)
       x = ResidualBlock(64)(x, training)
       x = ResidualBlock(64)(x, training)
       x = nn.max_pool(x, (2, 2), strides=(2, 2))
       x = ResidualBlock(128)(x, training)
       x = ResidualBlock(128)(x, training)
       x = jnp.mean(x, axis=(1, 2))
       x = x[:, None, :]
       x = SelfAttention(num_heads=4, dim=128)(x)
       x = x.squeeze(1)
       x = nn.Dense(256)(x)
       x = nn.relu(x)
       x = nn.Dropout(0.5, deterministic=not training)(x)
       x = nn.Dense(self.num_classes)(x)
       return x

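We define a deep neural network that combines residual blocks and a self-attention layer. The residual blocks extract spatial features with 3x3 convolutions and batch normalization, and they apply a 1x1 convolution on the skip path when the channel count changes. After global average pooling, the feature vector is passed to the self-attention layer as a sequence of length 1, which leaves it no other positions to attend to; feeding it the flattened spatial positions instead would let it relate different image regions. A dense head with dropout produces the class logits.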
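To make the tensor shapes in multi-head attention concrete, here is a standalone sketch on tensors laid out as (batch, length, heads, head_dim). The function name `multi_head_attention` and the sizes are assumptions chosen for illustration; the sketch omits the learned projections and shows only the score, softmax, and weighted-sum steps.

```python
import jax
import jax.numpy as jnp

def multi_head_attention(q, k, v):
    # q, k, v: (batch, length, heads, head_dim)
    head_dim = q.shape[-1]
    scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)  # normalize over key positions
    out = jnp.einsum('bhqk,bkhd->bqhd', weights, v)
    return out, weights

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
B, L, H, Dh = 2, 5, 4, 32
q = jax.random.normal(kq, (B, L, H, Dh))
k = jax.random.normal(kk, (B, L, H, Dh))
v = jax.random.normal(kv, (B, L, H, Dh))

out, weights = multi_head_attention(q, k, v)
print(out.shape, weights.shape)  # (2, 5, 4, 32) (2, 4, 5, 5)

# With a single token, the softmax over one key is exactly 1,
# so the attention output equals the value tensor.
out1, weights1 = multi_head_attention(q[:, :1], k[:, :1], v[:, :1])
print(bool(jnp.allclose(out1, v[:, :1])))
```

The single-token case illustrates why attention needs a sequence longer than 1 to mix information across positions.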

class TrainState(train_state.TrainState):
   batch_stats: Any


def create_learning_rate_schedule(base_lr: float = 1e-3, warmup_steps: int = 100, decay_steps: int = 1000) -> optax.Schedule:
   warmup_fn = optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=warmup_steps)
   decay_fn = optax.cosine_decay_schedule(init_value=base_lr, decay_steps=decay_steps, alpha=0.1)
   return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])


def create_optimizer(learning_rate_schedule: optax.Schedule) -> optax.GradientTransformation:
   return optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=learning_rate_schedule, weight_decay=1e-4))

We create a custom training state that extends Flax's TrainState with a batch_stats field for the BatchNorm running statistics. We also define a learning rate schedule with linear warmup followed by cosine decay, paired with an AdamW optimizer that uses a weight decay of 1e-4 and gradient clipping at a global norm of 1.0. Warmup and clipping help keep the early updates stable.

@jit
def compute_metrics(logits, labels):
   loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
   accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
   return {'loss': loss, 'accuracy': accuracy}


def create_train_state(rng, model, input_shape, learning_rate_schedule):
   variables = model.init(rng, jnp.ones(input_shape), training=False)
   params = variables['params']
   batch_stats = variables.get('batch_stats', {})
   tx = create_optimizer(learning_rate_schedule)
   return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)


@jit
def train_step(state, batch, dropout_rng):
   images, labels = batch
   def loss_fn(params):
       variables = {'params': params, 'batch_stats': state.batch_stats}
       logits, new_model_state = state.apply_fn(variables, images, training=True, mutable=['batch_stats'], rngs={'dropout': dropout_rng})
       loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
       return loss, (logits, new_model_state)
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
   (loss, (logits, new_model_state)), grads = grad_fn(state.params)
   state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
   metrics = compute_metrics(logits, labels)
   return state, metrics


@jit
def eval_step(state, batch):
   images, labels = batch
   variables = {'params': state.params, 'batch_stats': state.batch_stats}
   logits = state.apply_fn(variables, images, training=False)
   return compute_metrics(logits, labels)

We implement JIT-compiled training and evaluation functions. The training step computes gradients with value_and_grad, updates the parameters, and stores the new batch statistics returned through the mutable batch_stats collection. We also define a metrics function that reports loss and accuracy so we can monitor training.
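The training step relies on jax.value_and_grad with has_aux=True so that the loss function can return extra outputs alongside the loss. A minimal sketch of that pattern on a toy linear model follows; the parameter names and data are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    preds = x @ params['w'] + params['b']
    loss = jnp.mean((preds - y) ** 2)
    # Only the first return value is differentiated; the rest is auxiliary.
    return loss, {'preds': preds}

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = {'w': jnp.zeros(2), 'b': jnp.array(0.0)}

# Returns ((loss, aux), grads); grads has the same tree structure as params.
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
print(float(loss), aux['preds'].shape, grads['w'].shape)
```

In the tutorial's train_step the auxiliary output carries the logits and the updated batch statistics, which would otherwise require a second forward pass to obtain.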

def generate_synthetic_data(rng, num_samples=1000, img_size=32):
   rng_x, rng_y = random.split(rng)
   images = random.normal(rng_x, (num_samples, img_size, img_size, 3))
   labels = random.randint(rng_y, (num_samples,), 0, 10)
   return images, labels


def create_batches(images, labels, batch_size=32):
   num_batches = len(images) // batch_size
   for i in range(num_batches):
       idx = slice(i * batch_size, (i + 1) * batch_size)
       yield images[idx], labels[idx]

We generate synthetic data to simulate an image classification task, which lets us train the model without external datasets. The batching helper yields fixed-size batches in order and drops any trailing partial batch. This lets us test the full pipeline quickly and verify that all components work together.
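For real data one would usually shuffle each epoch. The sketch below shows one possible variant of the batching helper that permutes the sample indices with a PRNG key before slicing; `create_shuffled_batches` is a hypothetical name and is not part of the original code.

```python
import jax
import jax.numpy as jnp

def create_shuffled_batches(rng, images, labels, batch_size=32):
    # Hypothetical helper: permute indices once, then slice fixed-size batches.
    perm = jax.random.permutation(rng, len(images))
    num_batches = len(images) // batch_size  # trailing partial batch is dropped
    for i in range(num_batches):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield images[idx], labels[idx]

images = jnp.zeros((1000, 32, 32, 3))
labels = jnp.arange(1000) % 10
batches = list(create_shuffled_batches(jax.random.PRNGKey(0), images, labels))
print(len(batches), batches[0][0].shape)  # 31 (32, 32, 32, 3)
```

With 1000 samples and a batch size of 32, this yields 31 batches and leaves 8 samples unused per epoch. Keeping every batch the same shape also means a jitted step function is traced only once.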

def train_model(num_epochs=5, batch_size=32):
   rng = random.PRNGKey(0)
   # Use separate keys for train and test data so the two sets are independent.
   rng, train_rng, test_rng, model_rng = random.split(rng, 4)
   train_images, train_labels = generate_synthetic_data(train_rng, num_samples=1000)
   test_images, test_labels = generate_synthetic_data(test_rng, num_samples=200)
   model = AdvancedCNN(num_classes=10)
   lr_schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=50, decay_steps=500)
   state = create_train_state(model_rng, model, (1, 32, 32, 3), lr_schedule)
   history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
   print("Starting training...")
   for epoch in range(num_epochs):
       train_metrics = []
       for batch in create_batches(train_images, train_labels, batch_size):
           rng, dropout_rng = random.split(rng)
           state, metrics = train_step(state, batch, dropout_rng)
           train_metrics.append(metrics)
       train_loss = jnp.mean(jnp.array([m['loss'] for m in train_metrics]))
       train_acc = jnp.mean(jnp.array([m['accuracy'] for m in train_metrics]))
       test_metrics = [eval_step(state, batch) for batch in create_batches(test_images, test_labels, batch_size)]
       test_acc = jnp.mean(jnp.array([m['accuracy'] for m in test_metrics]))
       history['train_loss'].append(float(train_loss))
       history['train_acc'].append(float(train_acc))
       history['test_acc'].append(float(test_acc))
       print(f"Epoch {epoch + 1}/{num_epochs}: Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")
   return history, state


history, trained_state = train_model(num_epochs=5)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history['train_loss'], label="Train Loss")
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True)
ax2.plot(history['train_acc'], label="Train Accuracy")
ax2.plot(history['test_acc'], label="Test Accuracy")
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.set_title('Model Accuracy'); ax2.legend(); ax2.grid(True)
plt.tight_layout(); plt.show()


print("\n✅ Tutorial complete! This covers:")
print("- Custom Flax modules (ResNet blocks, Self-Attention)")
print("- Advanced Optax optimizers (AdamW with gradient clipping)")
print("- Learning rate schedules (warmup + cosine decay)")
print("- JAX transformations (@jit for performance)")
print("- Proper state management (batch normalization statistics)")
print("- Complete training pipeline with evaluation")

We bring all components together to train the model over several epochs, track performance metrics, and plot the trends in loss and accuracy. Because the synthetic labels are random, test accuracy should stay near chance (about 10% for 10 classes), so the run verifies that the pipeline executes end to end rather than measuring model quality.
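A useful reference point when reading the loss curve on random labels is the cross-entropy of a model with no class preference, which equals ln(10) for 10 classes. The sketch below computes that baseline directly; the uniform logits and the sample count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

num_classes = 10
labels = jax.random.randint(jax.random.PRNGKey(0), (200,), 0, num_classes)
uniform_logits = jnp.zeros((200, num_classes))  # no preference for any class

# Cross-entropy is the negative log-probability assigned to the true label.
log_probs = jax.nn.log_softmax(uniform_logits, axis=-1)
loss = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))

print(float(loss), float(jnp.log(num_classes)))  # both approximately 2.3026
```

A test loss close to this value is what one would expect from a model evaluated on labels that carry no signal.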

In conclusion, we implemented a complete training pipeline with JAX, Flax, and Optax. We saw how custom Flax modules, an Optax optimizer chain with a learning rate schedule, and explicit handling of batch statistics fit together in a JIT-compiled training loop. The same structure can be adapted to real datasets and larger experiments in research or production.




Asif Razzaq is the CEO of Marktechpost Media Inc. As a visionary entrepreneur and engineer, Asif is committed to harnessing the potential of Artificial Intelligence for social good. His most recent endeavor is the launch of an Artificial Intelligence Media Platform, Marktechpost, which stands out for its in-depth coverage of machine learning and deep learning news that is both technically sound and easily understandable by a wide audience. The platform boasts over 2 million monthly views, illustrating its popularity among audiences.
