
A Coding Implementation to Build and Train Advanced Architectures with Residual Connections, Self-Attention, and Adaptive Optimization Using JAX, Flax, and Optax

by Josh
November 12, 2025
in AI, Analytics and Automation


In this tutorial, we explore how to build and train an advanced neural network using JAX, Flax, and Optax in an efficient and modular way. We begin by designing a deep architecture that integrates residual connections and self-attention mechanisms for expressive feature learning. As we progress, we implement sophisticated optimization strategies with learning rate scheduling, gradient clipping, and adaptive weight decay. Throughout the process, we leverage JAX transformations such as jit, grad, and vmap to accelerate computation and ensure smooth training performance across devices.
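
Before diving into the full pipeline, here is a minimal, self-contained sketch (our own addition, not part of the tutorial's code) of the three transformations named above, jit, grad, and vmap, applied to a toy quadratic loss:

import jax
import jax.numpy as jnp

def toy_loss(w, x):
    # Scalar loss: squared norm of a linear map of x
    return jnp.sum((x @ w) ** 2)

grad_fn = jax.jit(jax.grad(toy_loss))                 # compiled gradient w.r.t. w
batched_loss = jax.vmap(toy_loss, in_axes=(None, 0))  # one loss value per example

w = jnp.ones((4, 2))
xs = jnp.ones((8, 3, 4))
print(grad_fn(w, xs[0]).shape)    # (4, 2), same shape as w
print(batched_loss(w, xs).shape)  # (8,), one value per batch element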

!pip install jax jaxlib flax optax matplotlib


import jax
import jax.numpy as jnp
from jax import random, jit, vmap, grad
import flax.linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
from typing import Any


print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

We begin by installing and importing JAX, Flax, and Optax, along with essential utilities for numerical operations and visualization. We check our device setup to ensure that JAX is running efficiently on available hardware. This setup forms the foundation for the entire training pipeline.
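
As a quick sanity check (a sketch we add here, not from the post), we can time a jitted matrix multiply to confirm the default backend is doing the work; block_until_ready forces JAX's asynchronous dispatch to finish before we stop the clock:

import time

x = random.normal(random.PRNGKey(1), (2048, 2048))
matmul = jit(lambda a: a @ a)
matmul(x).block_until_ready()   # first call triggers compilation
start = time.time()
matmul(x).block_until_ready()   # steady-state timing
print(f"2048x2048 matmul on {jax.default_backend()}: {time.time() - start:.4f}s")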

class SelfAttention(nn.Module):
    num_heads: int
    dim: int

    @nn.compact
    def __call__(self, x):
        B, L, D = x.shape
        head_dim = D // self.num_heads
        qkv = nn.Dense(3 * D)(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, head_dim)
        q, k, v = jnp.split(qkv, 3, axis=2)
        # Drop the split axis and move heads ahead of sequence: (B, H, L, head_dim)
        q = q.squeeze(2).transpose(0, 2, 1, 3)
        k = k.squeeze(2).transpose(0, 2, 1, 3)
        v = v.squeeze(2).transpose(0, 2, 1, 3)
        attn_scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
        attn_weights = jax.nn.softmax(attn_scores, axis=-1)
        attn_output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
        # Back to (B, L, D) before the output projection
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, D)
        return nn.Dense(D)(attn_output)


class ResidualBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, training: bool = True):
        residual = x
        x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        # Project the residual with a 1x1 conv if the channel counts differ
        if residual.shape[-1] != self.features:
            residual = nn.Conv(self.features, (1, 1))(residual)
        return nn.relu(x + residual)


class AdvancedCNN(nn.Module):
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Conv(32, (3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = ResidualBlock(64)(x, training)
        x = ResidualBlock(64)(x, training)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        x = ResidualBlock(128)(x, training)
        x = ResidualBlock(128)(x, training)
        # Global average pooling, then attend over the single pooled token
        x = jnp.mean(x, axis=(1, 2))
        x = x[:, None, :]
        x = SelfAttention(num_heads=4, dim=128)(x)
        x = x.squeeze(1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
        x = nn.Dense(self.num_classes)(x)
        return x

We define a deep neural network that combines residual blocks and a self-attention mechanism for enhanced feature learning. We construct the layers modularly, ensuring that the model can capture both spatial and contextual relationships. This design enables the network to generalize effectively across various types of input data.
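
A quick way to validate the architecture before training (a sketch of our own, not in the original code) is to initialize it on a dummy batch and confirm the output logits have the expected shape:

rng = random.PRNGKey(42)
model = AdvancedCNN(num_classes=10)
dummy = jnp.ones((2, 32, 32, 3))
variables = model.init(rng, dummy, training=False)
logits = model.apply(variables, dummy, training=False)
print(logits.shape)  # (2, 10): one logit vector per image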

class TrainState(train_state.TrainState):
    batch_stats: Any  # carries BatchNorm running statistics alongside params


def create_learning_rate_schedule(base_lr: float = 1e-3, warmup_steps: int = 100,
                                  decay_steps: int = 1000) -> optax.Schedule:
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=base_lr,
                                      transition_steps=warmup_steps)
    decay_fn = optax.cosine_decay_schedule(init_value=base_lr,
                                           decay_steps=decay_steps, alpha=0.1)
    return optax.join_schedules(schedules=[warmup_fn, decay_fn],
                                boundaries=[warmup_steps])


def create_optimizer(learning_rate_schedule: optax.Schedule) -> optax.GradientTransformation:
    return optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=learning_rate_schedule, weight_decay=1e-4),
    )

We create a custom training state that tracks model parameters and batch statistics. We also define a learning rate schedule with warmup and cosine decay, paired with an AdamW optimizer that includes gradient clipping and weight decay. This combination ensures stable and adaptive training.
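
Since Optax schedules are plain functions of the step count, we can inspect the warmup-plus-decay behavior directly (a small check we add here, using the defaults above):

schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=100, decay_steps=1000)
for step in [0, 50, 100, 500, 1000]:
    print(f"step {step:4d}: lr = {float(schedule(step)):.6f}")
# Linear ramp to 1e-3 by step 100, then cosine decay toward the 1e-4 floor (alpha=0.1).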

@jit
def compute_metrics(logits, labels):
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}


def create_train_state(rng, model, input_shape, learning_rate_schedule):
    variables = model.init(rng, jnp.ones(input_shape), training=False)
    params = variables['params']
    batch_stats = variables.get('batch_stats', {})
    tx = create_optimizer(learning_rate_schedule)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx,
                             batch_stats=batch_stats)


@jit
def train_step(state, batch, dropout_rng):
    images, labels = batch

    def loss_fn(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # mutable=['batch_stats'] lets BatchNorm update its running statistics
        logits, new_model_state = state.apply_fn(
            variables, images, training=True,
            mutable=['batch_stats'], rngs={'dropout': dropout_rng})
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, (logits, new_model_state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, new_model_state)), grads = grad_fn(state.params)
    # Apply the gradient update and carry the fresh batch statistics forward
    state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    metrics = compute_metrics(logits, labels)
    return state, metrics


@jit
def eval_step(state, batch):
    images, labels = batch
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, images, training=False)
    return compute_metrics(logits, labels)

We implement JIT-compiled training and evaluation functions to achieve efficient execution. The training step computes gradients, updates parameters, and dynamically maintains batch statistics. We also define evaluation metrics that help us monitor loss and accuracy throughout the training process.
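
Before launching the full loop, it can help to smoke-test a single step on dummy data (our own sketch, reusing the definitions above); the first call also pays the one-time JIT compilation cost:

rng = random.PRNGKey(0)
rng, init_rng, dropout_rng = random.split(rng, 3)
state = create_train_state(init_rng, AdvancedCNN(num_classes=10), (1, 32, 32, 3),
                           create_learning_rate_schedule())
batch = (jnp.ones((8, 32, 32, 3)), jnp.zeros((8,), dtype=jnp.int32))
state, metrics = train_step(state, batch, dropout_rng)
print({k: float(v) for k, v in metrics.items()})  # e.g. loss near ln(10) ≈ 2.303 at init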

def generate_synthetic_data(rng, num_samples=1000, img_size=32):
    rng_x, rng_y = random.split(rng)
    images = random.normal(rng_x, (num_samples, img_size, img_size, 3))
    labels = random.randint(rng_y, (num_samples,), 0, 10)
    return images, labels


def create_batches(images, labels, batch_size=32):
    num_batches = len(images) // batch_size
    for i in range(num_batches):
        idx = slice(i * batch_size, (i + 1) * batch_size)
        yield images[idx], labels[idx]

We generate synthetic data to simulate an image classification task, enabling us to train the model without relying on external datasets. We then batch the data efficiently for iterative updates. This approach allows us to test the full pipeline quickly and verify that all components function correctly.
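
One refinement worth noting (an addition on our part, not in the original code): the generator above always yields batches in the same order, so we could reshuffle with a fresh PRNG key each epoch:

def create_shuffled_batches(rng, images, labels, batch_size=32):
    # Permute the dataset with a JAX key, then batch as before
    perm = random.permutation(rng, len(images))
    images, labels = images[perm], labels[perm]
    num_batches = len(images) // batch_size
    for i in range(num_batches):
        idx = slice(i * batch_size, (i + 1) * batch_size)
        yield images[idx], labels[idx]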

def train_model(num_epochs=5, batch_size=32):
    rng = random.PRNGKey(0)
    # Split separate keys for the train set, test set, and model init so no key is reused
    rng, train_rng, test_rng, model_rng = random.split(rng, 4)
    train_images, train_labels = generate_synthetic_data(train_rng, num_samples=1000)
    test_images, test_labels = generate_synthetic_data(test_rng, num_samples=200)
    model = AdvancedCNN(num_classes=10)
    lr_schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=50, decay_steps=500)
    state = create_train_state(model_rng, model, (1, 32, 32, 3), lr_schedule)
    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    print("Starting training...")
    for epoch in range(num_epochs):
        train_metrics = []
        for batch in create_batches(train_images, train_labels, batch_size):
            rng, dropout_rng = random.split(rng)
            state, metrics = train_step(state, batch, dropout_rng)
            train_metrics.append(metrics)
        train_loss = jnp.mean(jnp.array([m['loss'] for m in train_metrics]))
        train_acc = jnp.mean(jnp.array([m['accuracy'] for m in train_metrics]))
        test_metrics = [eval_step(state, batch)
                        for batch in create_batches(test_images, test_labels, batch_size)]
        test_acc = jnp.mean(jnp.array([m['accuracy'] for m in test_metrics]))
        history['train_loss'].append(float(train_loss))
        history['train_acc'].append(float(train_acc))
        history['test_acc'].append(float(test_acc))
        print(f"Epoch {epoch + 1}/{num_epochs}: Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")
    return history, state


history, trained_state = train_model(num_epochs=5)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history['train_loss'], label="Train Loss")
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True)
ax2.plot(history['train_acc'], label="Train Accuracy")
ax2.plot(history['test_acc'], label="Test Accuracy")
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.set_title('Model Accuracy'); ax2.legend(); ax2.grid(True)
plt.tight_layout(); plt.show()


print("\n✅ Tutorial complete! This covers:")
print("- Custom Flax modules (ResNet blocks, Self-Attention)")
print("- Advanced Optax optimizers (AdamW with gradient clipping)")
print("- Learning rate schedules (warmup + cosine decay)")
print("- JAX transformations (@jit for performance)")
print("- Proper state management (batch normalization statistics)")
print("- Complete training pipeline with evaluation")

We bring all components together to train the model over several epochs, track performance metrics, and visualize the trends in loss and accuracy. We monitor the model’s learning progress and validate its performance on test data. Ultimately, we confirm the stability and effectiveness of our JAX-based training workflow.

In conclusion, we implemented a comprehensive training pipeline using JAX, Flax, and Optax that demonstrates both flexibility and computational efficiency. We observed how custom architectures, advanced optimization strategies, and precise state management come together to form a high-performance deep learning workflow. Through this exercise, we gained a deeper understanding of how to structure scalable experiments in JAX and prepared ourselves to adapt these techniques to real-world machine learning research and production tasks.


Check out the FULL CODES here. Feel free to check out our GitHub Page for Tutorials, Codes and Notebooks.


Asif Razzaq is the CEO of Marktechpost Media Inc. As a visionary entrepreneur and engineer, Asif is committed to harnessing the potential of Artificial Intelligence for social good. His most recent endeavor is the launch of an Artificial Intelligence media platform, Marktechpost, which stands out for its in-depth coverage of machine learning and deep learning news that is both technically sound and easily understandable by a wide audience. The platform boasts over 2 million monthly views.

