A Coding Implementation to Build and Train Advanced Architectures with Residual Connections, Self-Attention, and Adaptive Optimization Using JAX, Flax, and Optax

By Josh
November 12, 2025
In AI, Analytics and Automation


In this tutorial, we explore how to build and train an advanced neural network using JAX, Flax, and Optax in an efficient and modular way. We begin by designing a deep architecture that integrates residual connections and self-attention mechanisms for expressive feature learning. As we progress, we implement sophisticated optimization strategies with learning rate scheduling, gradient clipping, and adaptive weight decay. Throughout the process, we leverage JAX transformations such as jit, grad, and vmap to accelerate computation and ensure smooth training performance across devices. Check out the FULL CODES here.

!pip install jax jaxlib flax optax matplotlib


import jax
import jax.numpy as jnp
from jax import random, jit, vmap, grad
import flax.linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
from typing import Any, Callable


print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

We begin by installing and importing JAX, Flax, and Optax, along with essential utilities for numerical operations and visualization. We check our device setup to ensure that JAX is running efficiently on available hardware. This setup forms the foundation for the entire training pipeline. Check out the FULL CODES here.
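Before moving on, here is a small, optional sanity check (a sketch added for illustration, not part of the original pipeline) that exercises the three transformations mentioned above: we jit-compile the gradient of a toy loss and vmap it over a batch of inputs.

# Optional sanity check (illustrative only): exercise jit, grad, and vmap on a toy loss.
def toy_loss(w, x):
    return jnp.sum((w * x) ** 2)

w = jnp.array(2.0)
xs = jnp.arange(4.0)

grad_fn = jit(grad(toy_loss))                                 # compiled gradient w.r.t. w
per_example_grads = vmap(grad_fn, in_axes=(None, 0))(w, xs)   # one gradient per input
print("Per-example gradients:", per_example_grads)            # expected [0., 4., 16., 36.]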

class SelfAttention(nn.Module):
   num_heads: int
   dim: int
   @nn.compact
   def __call__(self, x):
       B, L, D = x.shape
       head_dim = D // self.num_heads
       qkv = nn.Dense(3 * D)(x)
       qkv = qkv.reshape(B, L, 3, self.num_heads, head_dim)
       q, k, v = jnp.split(qkv, 3, axis=2)
       # Move the head axis in front of the sequence axis: (B, num_heads, L, head_dim)
       q = q.squeeze(2).transpose(0, 2, 1, 3)
       k = k.squeeze(2).transpose(0, 2, 1, 3)
       v = v.squeeze(2).transpose(0, 2, 1, 3)
       attn_scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
       attn_weights = jax.nn.softmax(attn_scores, axis=-1)
       # Weighted sum over the key axis, then restore (B, L, D)
       attn_output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
       attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, D)
       return nn.Dense(D)(attn_output)


class ResidualBlock(nn.Module):
   features: int
   @nn.compact
   def __call__(self, x, training: bool = True):
       residual = x
       x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
       x = nn.BatchNorm(use_running_average=not training)(x)
       x = nn.relu(x)
       x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
       x = nn.BatchNorm(use_running_average=not training)(x)
       if residual.shape[-1] != self.features:
           residual = nn.Conv(self.features, (1, 1))(residual)
       return nn.relu(x + residual)


class AdvancedCNN(nn.Module):
   num_classes: int = 10
   @nn.compact
   def __call__(self, x, training: bool = True):
       x = nn.Conv(32, (3, 3), padding='SAME')(x)
       x = nn.relu(x)
       x = ResidualBlock(64)(x, training)
       x = ResidualBlock(64)(x, training)
       x = nn.max_pool(x, (2, 2), strides=(2, 2))
       x = ResidualBlock(128)(x, training)
       x = ResidualBlock(128)(x, training)
       x = jnp.mean(x, axis=(1, 2))
       x = x[:, None, :]
       x = SelfAttention(num_heads=4, dim=128)(x)
       x = x.squeeze(1)
       x = nn.Dense(256)(x)
       x = nn.relu(x)
       x = nn.Dropout(0.5, deterministic=not training)(x)
       x = nn.Dense(self.num_classes)(x)
       return x

We define a deep neural network that combines residual blocks and a self-attention mechanism for enhanced feature learning. We construct the layers modularly, ensuring that the model can capture both spatial and contextual relationships. This design enables the network to generalize effectively across various types of input data. Check out the FULL CODES here.
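As a quick illustration (an assumed usage example, not part of the original post), we can initialize the network on a dummy batch and confirm that the logits come out with shape (batch, num_classes):

# Illustrative shape check (not from the original article).
check_rng = random.PRNGKey(42)
dummy_batch = jnp.ones((2, 32, 32, 3))
model_check = AdvancedCNN(num_classes=10)
variables_check = model_check.init(check_rng, dummy_batch, training=False)
logits_check = model_check.apply(variables_check, dummy_batch, training=False)
print("Logits shape:", logits_check.shape)  # expected (2, 10)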

class TrainState(train_state.TrainState):
   batch_stats: Any


def create_learning_rate_schedule(base_lr: float = 1e-3, warmup_steps: int = 100, decay_steps: int = 1000) -> optax.Schedule:
   warmup_fn = optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=warmup_steps)
   decay_fn = optax.cosine_decay_schedule(init_value=base_lr, decay_steps=decay_steps, alpha=0.1)
   return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])


def create_optimizer(learning_rate_schedule: optax.Schedule) -> optax.GradientTransformation:
   return optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=learning_rate_schedule, weight_decay=1e-4))

We create a custom training state that tracks model parameters and batch statistics. We also define a learning rate schedule with warmup and cosine decay, paired with an AdamW optimizer that includes gradient clipping and weight decay. This combination ensures stable and adaptive training. Check out the FULL CODES here.
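To see the warmup and decay behavior concretely, we can sample the schedule at a few step counts; this short probe is a sketch added here for illustration (the step values are assumed) and is not part of the original code:

# Illustrative probe of the learning rate schedule (not from the original post).
sched = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=100, decay_steps=1000)
for step in [0, 50, 100, 500, 1000]:
    print(f"step {step:4d}: lr = {float(sched(step)):.6f}")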

@jit
def compute_metrics(logits, labels):
   loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
   accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
   return {'loss': loss, 'accuracy': accuracy}


def create_train_state(rng, model, input_shape, learning_rate_schedule):
   variables = model.init(rng, jnp.ones(input_shape), training=False)
   params = variables['params']
   batch_stats = variables.get('batch_stats', {})
   tx = create_optimizer(learning_rate_schedule)
   return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)


@jit
def train_step(state, batch, dropout_rng):
   images, labels = batch
   def loss_fn(params):
       variables = {'params': params, 'batch_stats': state.batch_stats}
       logits, new_model_state = state.apply_fn(variables, images, training=True, mutable=['batch_stats'], rngs={'dropout': dropout_rng})
       loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
       return loss, (logits, new_model_state)
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
   (loss, (logits, new_model_state)), grads = grad_fn(state.params)
   state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
   metrics = compute_metrics(logits, labels)
   return state, metrics


@jit
def eval_step(state, batch):
   images, labels = batch
   variables = {'params': state.params, 'batch_stats': state.batch_stats}
   logits = state.apply_fn(variables, images, training=False)
   return compute_metrics(logits, labels)

We implement JIT-compiled training and evaluation functions to achieve efficient execution. The training step computes gradients, updates parameters, and dynamically maintains batch statistics. We also define evaluation metrics that help us monitor loss and accuracy throughout the training process. Check out the FULL CODES here.
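As an optional smoke test (a sketch we add here, not part of the original flow), we can run a single train_step on a tiny random batch to confirm that everything compiles and the metrics come back as scalars:

# Optional smoke test (illustrative only): one training step on a tiny random batch.
img_rng, lbl_rng = random.split(random.PRNGKey(1))
smoke_images = random.normal(img_rng, (8, 32, 32, 3))
smoke_labels = random.randint(lbl_rng, (8,), 0, 10)
smoke_schedule = create_learning_rate_schedule()
smoke_state = create_train_state(random.PRNGKey(2), AdvancedCNN(), (1, 32, 32, 3), smoke_schedule)
smoke_state, smoke_metrics = train_step(smoke_state, (smoke_images, smoke_labels), random.PRNGKey(3))
print("Smoke-test metrics:", {k: float(v) for k, v in smoke_metrics.items()})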

def generate_synthetic_data(rng, num_samples=1000, img_size=32):
   rng_x, rng_y = random.split(rng)
   images = random.normal(rng_x, (num_samples, img_size, img_size, 3))
   labels = random.randint(rng_y, (num_samples,), 0, 10)
   return images, labels


def create_batches(images, labels, batch_size=32):
   num_batches = len(images) // batch_size
   for i in range(num_batches):
       idx = slice(i * batch_size, (i + 1) * batch_size)
       yield images[idx], labels[idx]

We generate synthetic data to simulate an image classification task, enabling us to train the model without relying on external datasets. We then batch the data efficiently for iterative updates. This approach allows us to test the full pipeline quickly and verify that all components function correctly. Check out the FULL CODES here.
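For a quick check (illustrative, not in the original code), we can generate a small synthetic set and confirm the batch generator yields the expected shapes:

# Illustrative check of the data helpers (not part of the original pipeline).
sample_images, sample_labels = generate_synthetic_data(random.PRNGKey(7), num_samples=100)
first_images, first_labels = next(create_batches(sample_images, sample_labels, batch_size=32))
print("Batch image shape:", first_images.shape)  # expected (32, 32, 32, 3)
print("Batch label shape:", first_labels.shape)  # expected (32,)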

def train_model(num_epochs=5, batch_size=32):
   rng = random.PRNGKey(0)
   rng, train_rng, test_rng, model_rng = random.split(rng, 4)
   train_images, train_labels = generate_synthetic_data(train_rng, num_samples=1000)
   test_images, test_labels = generate_synthetic_data(test_rng, num_samples=200)
   model = AdvancedCNN(num_classes=10)
   lr_schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=50, decay_steps=500)
   state = create_train_state(model_rng, model, (1, 32, 32, 3), lr_schedule)
   history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
   print("Starting training...")
   for epoch in range(num_epochs):
       train_metrics = []
       for batch in create_batches(train_images, train_labels, batch_size):
           rng, dropout_rng = random.split(rng)
           state, metrics = train_step(state, batch, dropout_rng)
           train_metrics.append(metrics)
       train_loss = jnp.mean(jnp.array([m['loss'] for m in train_metrics]))
       train_acc = jnp.mean(jnp.array([m['accuracy'] for m in train_metrics]))
       test_metrics = [eval_step(state, batch) for batch in create_batches(test_images, test_labels, batch_size)]
       test_acc = jnp.mean(jnp.array([m['accuracy'] for m in test_metrics]))
       history['train_loss'].append(float(train_loss))
       history['train_acc'].append(float(train_acc))
       history['test_acc'].append(float(test_acc))
       print(f"Epoch {epoch + 1}/{num_epochs}: Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")
   return history, state


history, trained_state = train_model(num_epochs=5)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history['train_loss'], label="Train Loss")
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True)
ax2.plot(history['train_acc'], label="Train Accuracy")
ax2.plot(history['test_acc'], label="Test Accuracy")
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.set_title('Model Accuracy'); ax2.legend(); ax2.grid(True)
plt.tight_layout(); plt.show()


print("\n✅ Tutorial complete! This covers:")
print("- Custom Flax modules (ResNet blocks, Self-Attention)")
print("- Advanced Optax optimizers (AdamW with gradient clipping)")
print("- Learning rate schedules (warmup + cosine decay)")
print("- JAX transformations (@jit for performance)")
print("- Proper state management (batch normalization statistics)")
print("- Complete training pipeline with evaluation")

We bring all components together to train the model over several epochs, track performance metrics, and visualize the trends in loss and accuracy. We monitor the model’s learning progress and validate its performance on test data. Ultimately, we confirm the stability and effectiveness of our JAX-based training workflow.
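As a final, optional sketch (assumed usage, not shown in the original tutorial), the trained state can be used directly for inference on fresh synthetic inputs:

# Illustrative inference with the trained state (not part of the original code).
new_images, _ = generate_synthetic_data(random.PRNGKey(123), num_samples=8)
inference_vars = {'params': trained_state.params, 'batch_stats': trained_state.batch_stats}
new_logits = trained_state.apply_fn(inference_vars, new_images, training=False)
print("Predicted classes:", jnp.argmax(new_logits, axis=-1))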

In conclusion, we implemented a comprehensive training pipeline utilizing JAX, Flax, and Optax, which demonstrates both flexibility and computational efficiency. We observed how custom architectures, advanced optimization strategies, and precise state management can come together to form a high-performance deep learning workflow. Through this exercise, we gain a deeper understanding of how to structure scalable experiments in JAX and prepare ourselves to adapt these techniques to real-world machine learning research and production tasks.


Check out the FULL CODES here.

