Sunday, March 22, 2026
mGrowTech

A Coding Implementation to Train Safety-Critical Reinforcement Learning Agents Offline Using Conservative Q-Learning with d3rlpy and Fixed Historical Data

by Josh
February 4, 2026
in AI, Analytics and Automation


In this tutorial, we build a safety-critical reinforcement learning pipeline that learns entirely from fixed, offline data rather than live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and then train both a Behavior Cloning baseline and a Conservative Q-Learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where unsafe exploration is not an option.

!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"


import os
import time
import random
import inspect
import numpy as np
import matplotlib.pyplot as plt


import gymnasium as gym
from gymnasium import spaces


import torch
import d3rlpy




SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)




def pick_device():
   if torch.cuda.is_available():
       return "cuda:0"
   return "cpu"




DEVICE = pick_device()
print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)




def make_config(cls, **kwargs):
   sig = inspect.signature(cls.__init__)
   allowed = set(sig.parameters.keys())
   allowed.discard("self")
   filtered = {k: v for k, v in kwargs.items() if k in allowed}
   return cls(**filtered)

We set up the runtime by installing dependencies, importing libraries, and fixing random seeds for reproducibility. We detect the available computation device (GPU if present, otherwise CPU) so the same code runs across systems. We also define make_config, a utility that drops any keyword argument a config class does not declare, so the same call works across d3rlpy versions whose config classes accept different parameters.
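To make the filtering behavior of make_config concrete, here is a minimal, self-contained sketch. ToyConfig is a hypothetical stand-in for a d3rlpy config class, and unknown_kwargs is an illustrative helper that is not part of the tutorial code.

```python
import inspect
from dataclasses import dataclass


def make_config(cls, **kwargs):
    # Keep only the keyword arguments that cls.__init__ actually declares.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return cls(**{k: v for k, v in kwargs.items() if k in allowed})


def unknown_kwargs(cls, **kwargs):
    # Hypothetical helper: report which keywords make_config would drop.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return sorted(k for k in kwargs if k not in allowed)


@dataclass
class ToyConfig:  # hypothetical stand-in for a d3rlpy config class
    learning_rate: float = 1e-3
    batch_size: int = 32


requested = dict(learning_rate=3e-4, batch_size=256, actor_learning_rate=3e-4)
cfg = make_config(ToyConfig, **requested)
dropped = unknown_kwargs(ToyConfig, **requested)
print(cfg, "| dropped:", dropped)
```

The trade-off is that filtering is silent: a misspelled or renamed hyperparameter is discarded without an error, so printing the dropped keys once is a cheap safeguard.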

class SafetyCriticalGridWorld(gym.Env):
   metadata = {"render_modes": []}


   def __init__(
       self,
       size=15,
       max_steps=80,
       hazard_coords=None,
       start=(0, 0),
       goal=None,
       slip_prob=0.05,
       seed=0,
   ):
       super().__init__()
       self.size = int(size)
       self.max_steps = int(max_steps)
       self.start = tuple(start)
       self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
       self.slip_prob = float(slip_prob)


       if hazard_coords is None:
           hz = set()
           rng = np.random.default_rng(seed)
           for _ in range(max(1, self.size // 2)):
               x = rng.integers(2, self.size - 2)
               y = rng.integers(2, self.size - 2)
               hz.add((int(x), int(y)))
           self.hazards = hz
       else:
           self.hazards = set(tuple(x) for x in hazard_coords)


       self.action_space = spaces.Discrete(4)
       self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)


       self._rng = np.random.default_rng(seed)
       self._pos = None
       self._t = 0


   def reset(self, *, seed=None, options=None):
       if seed is not None:
           self._rng = np.random.default_rng(seed)
       self._pos = [int(self.start[0]), int(self.start[1])]
       self._t = 0
       obs = np.array(self._pos, dtype=np.float32)
       return obs, {}


   def _clip(self):
       self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
       self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))


   def step(self, action):
       self._t += 1


       a = int(action)
       if self._rng.random() < self.slip_prob:
           a = int(self._rng.integers(0, 4))


       if a == 0:
           self._pos[1] += 1
       elif a == 1:
           self._pos[0] += 1
       elif a == 2:
           self._pos[1] -= 1
       elif a == 3:
           self._pos[0] -= 1


       self._clip()


       x, y = int(self._pos[0]), int(self._pos[1])
       terminated = False
       truncated = self._t >= self.max_steps


       reward = -1.0


       if (x, y) in self.hazards:
           reward = -100.0
           terminated = True


       if (x, y) == self.goal:
           reward = +50.0
           terminated = True


       obs = np.array([x, y], dtype=np.float32)
       return obs, float(reward), terminated, truncated, {}

We define a safety-critical GridWorld environment with hazards, terminal states, and stochastic transitions. Every step costs -1, entering a hazard cell costs -100 and ends the episode, and reaching the goal pays +50 and ends the episode. With probability slip_prob the chosen action is replaced by a uniformly random one, and episodes are truncated after max_steps steps, so the agent cannot fully control where it ends up.
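The slip rule has a simple consequence worth checking: because the replacement action is drawn uniformly from all four actions, it can coincide with the intended one. The sketch below is a vectorized re-implementation of that rule only (not the environment itself), and simulate_executed_actions is a name we introduce for illustration.

```python
import numpy as np


def simulate_executed_actions(intended, slip_prob, n, seed=0):
    # With probability slip_prob, replace the intended action by a uniform draw from {0, 1, 2, 3}.
    rng = np.random.default_rng(seed)
    slipped = rng.random(n) < slip_prob
    random_actions = rng.integers(0, 4, size=n)
    return np.where(slipped, random_actions, intended)


slip_prob = 0.05
executed = simulate_executed_actions(intended=1, slip_prob=slip_prob, n=200_000)

# Fraction of steps where the executed action equals the intended one.
empirical = float(np.mean(executed == 1))
# No slip, or a slip that happens to pick the intended action again.
analytic = (1.0 - slip_prob) + slip_prob / 4.0
print(f"empirical={empirical:.4f} analytic={analytic:.4f}")
```

So with slip_prob=0.05 the intended move is executed slightly more often than 95% of the time, and any policy, however cautious, still faces a small per-step chance of an unplanned move next to a hazard.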

def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
   x, y = int(obs[0]), int(obs[1])
   gx, gy = env.goal


   preferred = []
   if gx > x:
       preferred.append(1)
   elif gx < x:
       preferred.append(3)
   if gy > y:
       preferred.append(0)
   elif gy < y:
       preferred.append(2)


   if len(preferred) == 0:
       preferred = [int(env._rng.integers(0, 4))]


   if env._rng.random() < epsilon:
       return int(env._rng.integers(0, 4))


   candidates = []
   for a in preferred:
       nx, ny = x, y
       if a == 0:
           ny += 1
       elif a == 1:
           nx += 1
       elif a == 2:
           ny -= 1
       elif a == 3:
           nx -= 1
       nx = int(np.clip(nx, 0, env.size - 1))
       ny = int(np.clip(ny, 0, env.size - 1))
       if (nx, ny) not in env.hazards:
           candidates.append(a)


   if len(candidates) == 0:
       return preferred[0]
   # Draw from the environment's seeded generator so each episode's data depends only on its seed.
   return int(env._rng.choice(candidates))




def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
   episodes = []
   for i in range(n_episodes):
       obs, _ = env.reset(seed=int(seed + i))
       obs_list = []
       act_list = []
       rew_list = []
       done_list = []


       done = False
       while not done:
           a = safe_behavior_policy(obs, env, epsilon=epsilon)
           nxt, r, terminated, truncated, _ = env.step(a)
           done = bool(terminated or truncated)


           obs_list.append(np.array(obs, dtype=np.float32))
           act_list.append(np.array([a], dtype=np.int64))
           rew_list.append(np.array([r], dtype=np.float32))
           done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))


           obs = nxt


       episodes.append(
           {
               "observations": np.stack(obs_list, axis=0),
               "actions": np.stack(act_list, axis=0),
               "rewards": np.stack(rew_list, axis=0),
               "terminals": np.stack(done_list, axis=0),
           }
       )
   return episodes




def build_mdpdataset(episodes):
   obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
   acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
   rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
   terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)


   if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
       return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)


   raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

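We design a constrained behavior policy that moves toward the goal, avoids preferred moves whose next cell is a hazard when an alternative exists, and takes a uniformly random action with probability epsilon. Because of those random actions and the environment's slips, the dataset still contains some hazard hits, which gives the learner negative examples. We roll out this policy to collect trajectories, structure them into episodes, and convert them into d3rlpy's MDPDataset format for offline learning.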
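One detail in the data collection above deserves a note: the terminals array stores terminated or truncated, so an episode that merely runs out of time is recorded the same way as one that hits a hazard or the goal. Value-based methods usually stop bootstrapping at terminal steps, which can bias value estimates at time limits. The sketch below shows one way to keep the two signals apart; split_done_flags is a hypothetical helper, and as far as we are aware d3rlpy 2.x's MDPDataset accepts an optional timeouts array for this purpose, which should be verified against the installed version.

```python
import numpy as np


def split_done_flags(terminated, truncated):
    # terminated: the environment reached a true end state (hazard or goal).
    # truncated: the episode was cut off by the step limit.
    terminated = np.asarray(terminated, dtype=bool)
    truncated = np.asarray(truncated, dtype=bool)
    terminals = terminated.astype(np.float32)
    # A step that both terminates and hits the limit is counted as terminal only,
    # so the two arrays are never 1 at the same index.
    timeouts = (truncated & ~terminated).astype(np.float32)
    return terminals, timeouts


terminated = [False, False, True, False, False, True]
truncated = [False, False, False, False, True, True]
terminals, timeouts = split_done_flags(terminated, truncated)
print("terminals:", terminals)
print("timeouts: ", timeouts)
```

Wiring this into generate_offline_episodes would mean recording both flags per step and passing the second array when building the dataset; we leave that as an optional change.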

def _get_episodes_from_dataset(dataset):
   if hasattr(dataset, "episodes") and dataset.episodes is not None:
       return dataset.episodes
   if hasattr(dataset, "get_episodes"):
       return dataset.get_episodes()
   raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")




def _iter_all_observations(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       if obs is None:
           continue
       for o in obs:
           yield o




def _iter_all_transitions(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       acts = getattr(ep, "actions", None)
       rews = getattr(ep, "rewards", None)
       if obs is None or acts is None:
           continue
       n = min(len(obs), len(acts))
       for i in range(n):
           o = obs[i]
           a = acts[i]
           r = rews[i] if rews is not None and i < len(rews) else None
           yield o, a, r




def visualize_dataset(dataset, env, title="Offline Dataset"):
   state_visits = np.zeros((env.size, env.size), dtype=np.float32)
   for obs in _iter_all_observations(dataset):
       x, y = int(obs[0]), int(obs[1])
       x = int(np.clip(x, 0, env.size - 1))
       y = int(np.clip(y, 0, env.size - 1))
       state_visits[y, x] += 1


   plt.figure(figsize=(6, 5))
   plt.imshow(state_visits, origin="lower")
   plt.colorbar(label="Visits")
   plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
   plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
   if len(env.hazards) > 0:
       hz = np.array(list(env.hazards), dtype=np.int32)
       plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
   plt.title(f"{title} — State visitation")
   plt.xlabel("x")
   plt.ylabel("y")
   plt.legend()
   plt.show()


   rewards = []
   for _, _, r in _iter_all_transitions(dataset):
       if r is not None:
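           # Rewards may be stored as length-1 arrays; extract the scalar explicitly.
           rewards.append(float(np.asarray(r).reshape(-1)[0]))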
   if len(rewards) > 0:
       plt.figure(figsize=(6, 4))
       plt.hist(rewards, bins=60)
       plt.title(f"{title} — Reward distribution")
       plt.xlabel("reward")
       plt.ylabel("count")
       plt.show()

We implement dataset utilities that correctly iterate through episodes rather than assuming flat arrays. We visualize state visitation to understand coverage and data bias in the offline dataset. We also analyze reward distributions to inspect the learning signal available to the agent.
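The heatmap gives a visual sense of coverage; a single number can complement it. The following sketch computes the same visitation grid with NumPy and reports the fraction of cells visited at least once. The function names and the tiny 3x3 example are ours, chosen only for illustration.

```python
import numpy as np


def visitation_counts(observations, size):
    # Count visits per grid cell; rows index y and columns index x, matching the heatmap layout.
    counts = np.zeros((size, size), dtype=np.int64)
    xy = np.clip(np.asarray(observations).astype(np.int64), 0, size - 1)
    np.add.at(counts, (xy[:, 1], xy[:, 0]), 1)
    return counts


def coverage_fraction(counts):
    # Share of cells that appear at least once in the data.
    return float(np.count_nonzero(counts)) / counts.size


# Toy observations on a 3x3 grid (hypothetical data, not from the tutorial run).
obs = np.array([[0, 0], [1, 0], [1, 1], [1, 1], [2, 1]], dtype=np.float32)
counts = visitation_counts(obs, size=3)
coverage = coverage_fraction(counts)
print(counts)
print("coverage:", round(coverage, 3))
```

A low coverage value signals that large parts of the state space are out of distribution for any policy trained on this data, which is exactly the regime where conservative methods are meant to help.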

def rollout_eval(env, algo, n_episodes=25, seed=0):
   returns = []
   lengths = []
   hazard_hits = 0
   goal_hits = 0


   for i in range(n_episodes):
       obs, _ = env.reset(seed=seed + i)
       done = False
       total = 0.0
       steps = 0
       while not done:
           a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
           obs, r, terminated, truncated, _ = env.step(a)
           total += float(r)
           steps += 1
           done = bool(terminated or truncated)
           if terminated:
               x, y = int(obs[0]), int(obs[1])
               if (x, y) in env.hazards:
                   hazard_hits += 1
               if (x, y) == env.goal:
                   goal_hits += 1


       returns.append(total)
       lengths.append(steps)


   return {
       "return_mean": float(np.mean(returns)),
       "return_std": float(np.std(returns)),
       "len_mean": float(np.mean(lengths)),
       "hazard_rate": float(hazard_hits / max(1, n_episodes)),
       "goal_rate": float(goal_hits / max(1, n_episodes)),
       "returns": np.asarray(returns, dtype=np.float32),
   }




def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
   rng = np.random.default_rng(seed)
   obs_all = []
   act_all = []
   for o, a, _ in _iter_all_transitions(dataset):
       obs_all.append(np.asarray(o, dtype=np.float32))
       act_all.append(int(np.asarray(a).reshape(-1)[0]))
       if len(obs_all) >= 80_000:
           break


   obs_all = np.stack(obs_all, axis=0)
   act_all = np.asarray(act_all, dtype=np.int64)


   idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
   obs_probe = obs_all[idx]
   act_probe_data = act_all[idx]
   act_probe_pi = algo.predict(obs_probe).astype(np.int64)


   mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
   return float(mismatch.mean())




def create_discrete_bc(device):
   if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
       cls = d3rlpy.algos.DiscreteBCConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           batch_size=256,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteBC"):
       return d3rlpy.algos.DiscreteBC()
   raise RuntimeError("DiscreteBC not available in this d3rlpy version.")




def create_discrete_cql(device, conservative_weight=6.0):
   if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
       cls = d3rlpy.algos.DiscreteCQLConfig
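       # make_config drops any keyword the installed config class does not declare.
       # DiscreteCQLConfig in d3rlpy 2.x appears to name the conservative coefficient
       # `alpha`, so we pass the weight under both names; otherwise it is silently ignored.
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           actor_learning_rate=3e-4,
           critic_learning_rate=3e-4,
           temp_learning_rate=3e-4,
           alpha_learning_rate=3e-4,
           batch_size=256,
           alpha=float(conservative_weight),
           conservative_weight=float(conservative_weight),
           n_action_samples=10,
           rollout_interval=0,
       )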
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteCQL"):
       algo = d3rlpy.algos.DiscreteCQL()
       if hasattr(algo, "conservative_weight"):
           try:
               algo.conservative_weight = float(conservative_weight)
           except Exception:
               pass
       return algo
   raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")

We define controlled evaluation routines that measure policy performance over a fixed, small number of seeded rollouts. We compute returns and safety metrics, including hazard and goal rates. We also introduce a mismatch diagnostic that measures how often the learned policy picks a different action from the one recorded in the dataset at the same state, and we define factory functions that build the DiscreteBC and DiscreteCQL agents.
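The conservative weight passed to create_discrete_cql scales a regularizer that is added to the usual temporal-difference loss. In the form commonly described for discrete actions, that regularizer is the gap between the log-sum-exp of Q over all actions and the Q-value of the action actually taken in the data. The NumPy sketch below illustrates this term under that assumption; it is not d3rlpy's implementation, and the Q-values are made up.

```python
import numpy as np


def logsumexp(q, axis=-1):
    # Numerically stable log(sum(exp(q))) along one axis.
    m = np.max(q, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(q - m), axis=axis))


def discrete_cql_penalty(q_values, data_actions):
    # Mean over the batch of: logsumexp_a Q(s, a) - Q(s, a_data).
    q_values = np.asarray(q_values, dtype=np.float64)
    data_actions = np.asarray(data_actions, dtype=np.int64)
    q_data = q_values[np.arange(len(q_values)), data_actions]
    return float(np.mean(logsumexp(q_values, axis=1) - q_data))


# Two states, four actions each (hypothetical Q-values).
q = np.array([
    [1.0, 1.0, 1.0, 1.0],
    [0.0, 0.0, 0.0, 10.0],
])

# Second state: the data action is the one with the large Q-value.
penalty_supported = discrete_cql_penalty(q, [0, 3])
# Second state: the large Q-value sits on an action the data never took.
penalty_unsupported = discrete_cql_penalty(q, [0, 0])
print(penalty_supported, penalty_unsupported)
```

The penalty is small when high Q-values sit on dataset actions and large when they sit on actions the data does not support, so minimizing it pushes the agent away from overestimating unseen actions.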

def main():
   env = SafetyCriticalGridWorld(
       size=15,
       max_steps=80,
       slip_prob=0.05,
       seed=SEED,
   )


   raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
   dataset = build_mdpdataset(raw_eps)


   print("dataset built:", type(dataset).__name__)
   visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")


   bc = create_discrete_bc(DEVICE)
   cql = create_discrete_cql(DEVICE, conservative_weight=6.0)


   print("\nTraining Discrete BC (offline)...")
   t0 = time.time()
   bc.fit(
       dataset,
       n_steps=25_000,
       n_steps_per_epoch=2_500,
       experiment_name="grid_bc_offline",
   )
   print("BC train sec:", round(time.time() - t0, 2))


   print("\nTraining Discrete CQL (offline)...")
   t0 = time.time()
   cql.fit(
       dataset,
       n_steps=80_000,
       n_steps_per_epoch=8_000,
       experiment_name="grid_cql_offline",
   )
   print("CQL train sec:", round(time.time() - t0, 2))


   print("\nControlled online evaluation (small number of rollouts):")
   bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
   cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)


   print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
   print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})


   print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
   bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
   cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
   print("BC mismatch rate :", bc_mismatch)
   print("CQL mismatch rate:", cql_mismatch)


   plt.figure(figsize=(6, 4))
   labels = ["BC", "CQL"]
   means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
   stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
   plt.bar(labels, means, yerr=stds)
   plt.ylabel("Return")
   plt.title("Online Rollout Return (Controlled)")
   plt.show()


   plt.figure(figsize=(6, 4))
   plt.plot(np.sort(bc_metrics["returns"]), label="BC")
   plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
   plt.xlabel("Episode (sorted)")
   plt.ylabel("Return")
   plt.title("Return Distribution (Sorted)")
   plt.legend()
   plt.show()


   out_dir = "/content/offline_rl_artifacts"
   os.makedirs(out_dir, exist_ok=True)
   bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
   cql_path = os.path.join(out_dir, "grid_cql_policy.pt")


   if hasattr(bc, "save_policy"):
       bc.save_policy(bc_path)
       print("Saved BC policy:", bc_path)
   if hasattr(cql, "save_policy"):
       cql.save_policy(cql_path)
       print("Saved CQL policy:", cql_path)


   print("\nDone.")




if __name__ == "__main__":
   main()

We train both Behavior Cloning and Conservative Q-Learning agents purely from offline data. We compare their performance using controlled rollouts and diagnostic metrics. We finalize the workflow by saving trained policies and summarizing safety-aware learning outcomes.
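With only 30 evaluation episodes per agent and a -100 hazard penalty, mean returns can swing a lot between runs, so an interval estimate is useful alongside the bar chart. The sketch below computes a percentile bootstrap interval for the mean; bootstrap_mean_ci is a helper we introduce here, and the returns are hypothetical values, not results from this tutorial.

```python
import numpy as np


def bootstrap_mean_ci(values, n_boot=2000, level=0.95, seed=0):
    # Percentile bootstrap: resample episodes with replacement and collect the resampled means.
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=np.float64)
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    means = values[idx].mean(axis=1)
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(means, [tail, 1.0 - tail])
    return float(values.mean()), float(lo), float(hi)


# Hypothetical per-episode returns: mostly successful runs plus two hazard hits.
example_returns = np.array([22.0, 23.0, -100.0, 21.0, 24.0, 22.0, -104.0, 23.0, 20.0, 22.0])
mean, lo, hi = bootstrap_mean_ci(example_returns)
print(f"mean={mean:.2f}  95% CI=({lo:.2f}, {hi:.2f})")
```

In the tutorial's code, the arrays to pass would be bc_metrics["returns"] and cql_metrics["returns"]; overlapping intervals suggest the difference between the two agents may not be meaningful at this sample size.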

In conclusion, we compared Conservative Q-Learning against simple imitation when learning from historical data in a safety-sensitive environment. By looking at offline training outcomes, controlled online evaluations, and action mismatches against the dataset, we can examine how conservatism affects risky, out-of-distribution behavior. Overall, we presented a complete, reproducible offline RL workflow that we can extend to more complex domains such as robotics, healthcare, or finance, where unsafe exploration is not an option.

