REINFORCE: Direct Policy Optimization After Deep Q-Learning

Reinforcement Learning
Policy Gradients
PyTorch
Full PyTorch implementation of the REINFORCE policy gradient algorithm with baseline variance reduction on a continuous-state grid environment.
Author

Ravi Sankar Krothapalli

Published

March 30, 2026

Why REINFORCE After DQN?

In Deep Q-Networks: From Tables to Neural Function Approximators, we kept the Bellman target and replaced the Q-table with a neural network. That solved state-space scaling, but it kept one central design choice: the policy is still implicit, obtained by an argmax over Q(s, a).

REINFORCE takes the next logical step: parameterize the policy directly and optimize it by gradient ascent on expected return. Instead of learning values first and extracting actions second, we learn action probabilities in one model:

\pi_\theta(a\mid s)\in[0,1],\quad \sum_a \pi_\theta(a\mid s)=1

This post keeps the same compact continuous-state grid from the DQN article, so the transition is clean: same environment, different learning signal.

The shift is conceptual as much as algorithmic. DQN asks: “what is the value of each action in this state?” REINFORCE asks: “how should the policy parameters move so actions that led to higher returns become more likely next time?” That perspective becomes essential before actor-critic methods, PPO, and modern policy-optimization pipelines.

Post road map: We begin with the value-to-policy gradient transition, then isolate REINFORCE’s core variance challenge and the baseline fix. The implementation sections build a full policy-gradient agent in PyTorch (plus baseline variant) on the same 4x4 continuous-state grid, ending with reward/entropy diagnostics and a learned stochastic-policy map.


From Value Gradients to Policy Gradients

The easiest way to feel the difference is to put both update directions side by side. DQN learns by bootstrapping from a target network; REINFORCE learns from full-episode returns sampled under the current stochastic policy.

DQN update direction (off-policy TD):

\delta_t = r_t + \gamma (1-d_t)\max_{a'} Q(s_{t+1}, a';\theta^-) - Q(s_t, a_t;\theta)

REINFORCE update direction (on-policy Monte Carlo):

\theta \leftarrow \theta + \alpha \sum_{t=0}^{T-1} G_t\,\nabla_\theta \log \pi_\theta(a_t\mid s_t)

where G_t = \sum_{k=t}^{T-1} \gamma^{k-t} r_k is the reward-to-go.

The estimator comes from the score-function (log-derivative) trick used in Williams (1992) and formalized in the policy-gradient framework of Sutton et al. (1999/2000). The key practical implication is simple:

  • High-return trajectories increase the log-probability of sampled actions.
  • Low-return trajectories decrease it.

No bootstrapped target network is required, but the variance of the gradient estimate is much higher than that of TD methods.

That trade-off is the defining signature of vanilla policy gradients: less target bias from bootstrapping, more Monte Carlo noise from episodic returns.
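To make the score-function trick concrete, here is a small standalone check (an illustration, not part of the original post): for a two-action softmax policy, the Monte Carlo estimate of E[f(a) ∇θ log πθ(a)] converges to the analytic gradient of E[f(a)] with respect to the logits.

```python
import numpy as np

rng = np.random.default_rng(0)

theta = np.array([0.2, -0.1])   # logits for a 2-action policy
f = np.array([1.0, 0.0])        # "return" assigned to each action

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

pi = softmax(theta)

# Analytic gradient of E[f(a)] w.r.t. the logits: pi_i * (f_i - E[f])
grad_true = pi * (f - pi @ f)

# Score-function estimate: average f(a) * grad_theta log pi(a) over samples.
# For a softmax policy, grad_theta log pi(a) = onehot(a) - pi.
n = 200_000
actions = rng.choice(2, size=n, p=pi)
onehot = np.eye(2)[actions]
grad_est = (f[actions][:, None] * (onehot - pi)).mean(axis=0)

print(grad_true, grad_est)  # the two agree to roughly two decimal places
```

The Monte Carlo estimate is unbiased, so with enough samples it matches the analytic gradient; with a single trajectory it is the same estimator, just far noisier, which is exactly the variance problem discussed next.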


Baselines and Variance Reduction

The practical bottleneck in REINFORCE is not correctness of the gradient direction; it is the variance of that estimator. Two trajectories that differ only slightly can still produce noticeably different returns, and those fluctuations propagate directly into the policy update.

Plain REINFORCE is unbiased but noisy. A baseline b(s_t) that is independent of the action can be subtracted without changing the expected gradient:

\theta \leftarrow \theta + \alpha \sum_{t=0}^{T-1} \bigl(G_t - b(s_t)\bigr)\,\nabla_\theta \log \pi_\theta(a_t\mid s_t)

because:

\sum_a b(s)\nabla_\theta \pi_\theta(a\mid s) = b(s)\nabla_\theta\sum_a\pi_\theta(a\mid s)=b(s)\nabla_\theta 1=0

In this article we use a simple running scalar baseline b (a special case of b(s)) based on a moving average of reward-to-go statistics from previous episodes. This is not as strong as a learned value baseline, but it demonstrates the variance-reduction idea with minimal code.

Canonical REINFORCE, with or without a baseline, is unbiased. Our implementation additionally normalizes advantages to zero mean and unit variance for numerical stability; this is a practical heuristic that can slightly bias the estimator.

This gives us a useful ladder of ideas: plain REINFORCE (conceptually clean), REINFORCE + baseline (lower variance), then actor-critic (learned state-dependent baseline with bootstrapping).
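The identity above is easy to verify numerically. A minimal sketch (illustrative, not from the post) for a three-action softmax policy: the probability gradients summed over actions cancel exactly, so any action-independent baseline b contributes nothing in expectation.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

theta = np.array([0.5, -0.3, 0.1])
pi = softmax(theta)
b = 7.0  # any action-independent baseline

# grad_theta pi_a for a softmax policy: row a is pi_a * (onehot(a) - pi)
grad_pi = pi[:, None] * (np.eye(3) - pi)

# sum_a b * grad pi_a == b * grad sum_a pi_a == b * grad 1 == 0
print((b * grad_pi).sum(axis=0))  # ~ [0, 0, 0] up to float rounding
```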


Section 1: Setup and Reproducibility

import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import plotly.graph_objects as go

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

GRID_SIZE = 4
N_ACTIONS = 4
OBS_DIM = 2

EPISODES = 500
GAMMA = 0.95
MAX_STEPS = 100

LR = 5e-3
HIDDEN = 64

BASELINE_BETA = 0.90

The environment and reward structure match the DQN article: +1 at goal, -0.01 otherwise. This keeps algorithm comparison fair.

Just like in the DQN post, seeds are fixed up front so the curves are reproducible and differences reflect algorithmic behavior rather than run-to-run randomness.

Section 2: Continuous-State Grid (Same as DQN)

We intentionally reuse the same continuous-observation environment from DQN. That removes environment drift from the comparison and isolates what changed: value-learning machinery is replaced by direct policy optimization.

ACTIONS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}


class ContinuousGridEnv:
    """4x4 grid with normalized float observations in [0,1]^2."""

    def __init__(self, grid_size=GRID_SIZE, max_steps=MAX_STEPS):
        self.grid_size = grid_size
        self.max_steps = max_steps
        self.n = grid_size - 1
        self.goal = (grid_size - 1, grid_size - 1)
        self.pos = (0, 0)
        self.steps = 0

    def _obs(self):
        return np.array([self.pos[0] / self.n, self.pos[1] / self.n], dtype=np.float32)

    def reset(self):
        self.pos = (0, 0)
        self.steps = 0
        return self._obs()

    def step(self, action):
        dr, dc = ACTIONS[action]
        r = min(max(self.pos[0] + dr, 0), self.grid_size - 1)
        c = min(max(self.pos[1] + dc, 0), self.grid_size - 1)
        self.pos = (r, c)
        self.steps += 1

        at_goal = self.pos == self.goal
        reward = 1.0 if at_goal else -0.01
        done = at_goal or self.steps >= self.max_steps
        return self._obs(), reward, done


env = ContinuousGridEnv()
obs = env.reset()
print(f"start obs: {obs}")
start obs: [0. 0.]

As before, observations are normalized float coordinates in [0,1]^2, which keeps the state representation continuous and makes direct parity with the DQN setup straightforward.
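As a sanity check on the reward structure: the shortest path from (0, 0) to (3, 3) takes six moves, so the best achievable episode return is 1.0 - 5 × 0.01 = 0.95. A condensed standalone copy of the environment's step logic (a hypothetical `MiniGrid` helper, written so this snippet runs on its own) confirms the ceiling the training curves should approach:

```python
# Condensed copy of ContinuousGridEnv's step logic, for a standalone check.
class MiniGrid:
    def __init__(self, n=4):
        self.n, self.goal, self.pos = n, (n - 1, n - 1), (0, 0)

    def step(self, dr, dc):
        r = min(max(self.pos[0] + dr, 0), self.n - 1)
        c = min(max(self.pos[1] + dc, 0), self.n - 1)
        self.pos = (r, c)
        return (1.0 if self.pos == self.goal else -0.01), self.pos == self.goal

env = MiniGrid()
total = 0.0
# Scripted shortest path: three moves down, then three moves right.
for dr, dc in [(1, 0)] * 3 + [(0, 1)] * 3:
    reward, done = env.step(dr, dc)
    total += reward
print(round(total, 2))  # 0.95 — the best possible episode return
```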

Section 3: Policy Network

The policy network outputs a probability distribution over actions, not Q-values. The final softmax layer enforces a valid categorical policy, and we sample actions from that distribution during training to preserve on-policy learning.

PyTorch handles gradient computation via autograd, so the update mechanics stay clean: we define the forward pass and loss, and loss.backward() computes the parameter gradients automatically.

class PolicyNetwork(nn.Module):
    """Two-layer tanh policy network with softmax output."""

    def __init__(self, obs_dim=OBS_DIM, n_actions=N_ACTIONS, hidden=HIDDEN):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, n_actions),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.net(x)

    def predict(self, obs_np):
        """Numpy in, numpy out - for environment interaction and visualization."""
        with torch.no_grad():
            probs = self.forward(torch.tensor(obs_np, dtype=torch.float32))
        return probs.numpy()

    def sample_action(self, obs):
        probs = self.predict(obs[np.newaxis, :]).squeeze(0)
        action = int(np.random.choice(len(probs), p=probs))
        entropy = float(-(probs * np.log(probs + 1e-12)).sum())
        return action, entropy

    def update_from_episode(self, optimizer, states, actions, advantages):
        """Policy gradient update using autograd."""
        states_t = torch.tensor(np.array(states, dtype=np.float32))
        actions_t = torch.tensor(np.array(actions, dtype=np.int64))
        advantages_t = torch.tensor(np.array(advantages, dtype=np.float32))

        probs = self.forward(states_t)
        log_probs = torch.log(probs.gather(1, actions_t.unsqueeze(1)).squeeze(1) + 1e-12)

        loss = -(log_probs * advantages_t).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=10.0)
        optimizer.step()

The update minimizes -\mathbb{E}[A_t \log \pi_\theta(a_t\mid s_t)]: when advantage is positive, probability mass moves toward sampled actions; when negative, it moves away. Gradient clipping and the Adam optimizer provide the same numerical stabilization role they played in DQN.
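The sign logic is easy to verify in isolation. The sketch below (a hypothetical one-state, four-action policy, not the post's PolicyNetwork) applies a single update with a positive advantage and checks that the sampled action's probability goes up:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Minimal one-state analogue of the update in update_from_episode above.
policy = nn.Sequential(nn.Linear(2, 4), nn.Softmax(dim=-1))
opt = torch.optim.SGD(policy.parameters(), lr=0.5)

state = torch.tensor([[0.0, 0.0]])
action, advantage = 3, 1.0           # pretend action 3 led to a high return

p_before = policy(state)[0, action].item()

log_prob = torch.log(policy(state)[0, action] + 1e-12)
loss = -advantage * log_prob         # same loss form as the full implementation
opt.zero_grad()
loss.backward()
opt.step()

p_after = policy(state)[0, action].item()
print(p_before < p_after)  # True: positive advantage raises the action's probability
```

With a negative advantage the same step would push probability mass away from the sampled action, which is the entire mechanism of the algorithm.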

Section 4: Trajectory Returns and Baseline

REINFORCE updates only after a full trajectory is collected. That means we first roll out an episode, then compute reward-to-go G_t at each timestep, then apply one policy update from that complete batch.

The baseline variant uses a running scalar estimate from previous episodes, then normalizes advantages for optimization stability.

def discounted_returns(rewards, gamma=GAMMA):
    out = np.zeros(len(rewards), dtype=np.float32)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        out[t] = g
    return out
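A quick worked example of the backward recursion (the function is reproduced so the snippet runs on its own): for a three-step episode ending at the goal with gamma = 0.95, G_2 = 1.0, G_1 = -0.01 + 0.95 · 1.0 = 0.94, and G_0 = -0.01 + 0.95 · 0.94 = 0.883.

```python
import numpy as np

def discounted_returns(rewards, gamma=0.95):
    # Same backward recursion as above: G_t = r_t + gamma * G_{t+1}
    out = np.zeros(len(rewards), dtype=np.float32)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        out[t] = g
    return out

# Two step penalties, then the +1 goal reward.
print(discounted_returns([-0.01, -0.01, 1.0]))  # G_0 = 0.883, G_1 = 0.94, G_2 = 1.0
```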


def run_reinforce(seed=42, use_baseline=False):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    env = ContinuousGridEnv()
    policy = PolicyNetwork()
    optimizer = optim.Adam(policy.parameters(), lr=LR)

    baseline = 0.0
    rewards_ep = []
    entropy_ep = []

    for ep in range(EPISODES):
        obs = env.reset()

        states, actions, rewards = [], [], []
        entropies = []

        while True:
            action, entropy = policy.sample_action(obs)
            obs_next, reward, done = env.step(action)

            states.append(obs)
            actions.append(action)
            rewards.append(reward)
            entropies.append(entropy)

            obs = obs_next
            if done:
                break

        returns = discounted_returns(rewards)

        if use_baseline:
            advantages = returns - baseline
        else:
            advantages = returns

        adv_mean = float(advantages.mean())
        adv_std = float(advantages.std() + 1e-8)
        advantages = (advantages - adv_mean) / adv_std

        policy.update_from_episode(optimizer, states, actions, advantages)

        if use_baseline:
            baseline = BASELINE_BETA * baseline + (1.0 - BASELINE_BETA) * float(returns.mean())

        rewards_ep.append(float(np.sum(rewards)))
        entropy_ep.append(float(np.mean(entropies)))

    return np.array(rewards_ep), np.array(entropy_ep), policy


rewards_plain, entropy_plain, policy_plain = run_reinforce(seed=42, use_baseline=False)
rewards_base, entropy_base, policy_base = run_reinforce(seed=42, use_baseline=True)

print(f"REINFORCE mean reward (last 50): {rewards_plain[-50:].mean():.3f}")
print(f"REINFORCE+baseline mean reward (last 50): {rewards_base[-50:].mean():.3f}")
REINFORCE mean reward (last 50): 0.915
REINFORCE+baseline mean reward (last 50): 0.844

The two printed metrics are the simplest convergence check. On this seed the plain variant happens to finish slightly higher; the baseline's benefit shows up in the smoothness and variance of the full training curves rather than in a single end-point average.


Section 5: Training Diagnostics

The baseline variant typically smooths training and reaches strong returns with less oscillation. On this tiny deterministic grid both variants can converge, but the variance gap is still visible in early and middle episodes.

Entropy is included as a second diagnostic because policy-gradient training can appear to improve reward while collapsing exploration too early. Tracking entropy helps detect that failure mode.
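One way to make the variance gap visible is to smooth both reward curves with a trailing moving average before plotting. A minimal helper (an assumed utility, not from the post) that pairs with the plotly traces:

```python
import numpy as np

def moving_average(x, w=25):
    """Trailing-window mean; smooths the noisy per-episode reward curve."""
    return np.convolve(x, np.ones(w) / w, mode="valid")

# e.g. feed the smoothed arrays into the existing plotly figure:
#   fig.add_trace(go.Scatter(y=moving_average(rewards_base), name="REINFORCE + baseline"))
smoothed = moving_average(np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0]), w=2)
print(smoothed)  # each point is the mean of w consecutive episodes
```

The raw curves overlap heavily; after smoothing, the plain variant's larger episode-to-episode swings become much easier to see.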

Section 6: Learned Stochastic Policy Map

Tabular posts visualized greedy arrows from Q-tables. Here we visualize the learned stochastic policy itself: each cell shows the most likely action, its confidence, and local policy entropy.

Brighter cells indicate stronger confidence in the greedy action under the learned stochastic policy. Entropy labels show remaining exploration uncertainty state-by-state.
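A sketch of how such a map can be extracted, cell by cell (shown with an untrained stand-in network so the snippet is self-contained; in the post this would use the trained policy_base):

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(42)

# Untrained stand-in with the same architecture as PolicyNetwork above.
policy = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 4), nn.Softmax(dim=-1))

GRID_SIZE = 4
labels = ["U", "D", "L", "R"]  # matches ACTIONS 0..3 (up, down, left, right)

cells = {}
with torch.no_grad():
    for r in range(GRID_SIZE):
        for c in range(GRID_SIZE):
            # Same normalized observation encoding the environment produces.
            obs = torch.tensor([[r / (GRID_SIZE - 1), c / (GRID_SIZE - 1)]])
            p = policy(obs).squeeze(0).numpy()
            entropy = float(-(p * np.log(p + 1e-12)).sum())
            cells[(r, c)] = (labels[int(p.argmax())], float(p.max()), entropy)

print(cells[(0, 0)], cells[(3, 3)])  # (greedy action, confidence, entropy) per cell
```

Each tuple is exactly what the heatmap encodes: greedy action as the arrow, confidence as brightness, and entropy as the per-cell label.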


Practical Takeaways

  • REINFORCE is the cleanest direct-policy algorithm and a good conceptual bridge to actor-critic methods.
  • It naturally handles stochastic policies and extends to continuous actions with Gaussian parameterizations.
  • Its main weakness is high gradient variance and sample inefficiency.
  • A baseline helps, and learned state-value baselines lead directly to actor-critic (A2C, PPO-style objectives).
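The continuous-action extension mentioned above swaps the categorical distribution for a Gaussian whose mean (and often log-std) the network outputs; the REINFORCE loss is otherwise unchanged. A minimal sketch, with hypothetical standalone parameters in place of a network:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# In a real agent these would be network outputs; here they are bare
# parameters so the score-function mechanics are visible in isolation.
mean = torch.tensor([0.3], requires_grad=True)
log_std = torch.tensor([0.0], requires_grad=True)

dist = Normal(mean, log_std.exp())
action = dist.sample()               # continuous action, no gradient through it
advantage = 1.0

loss = -advantage * dist.log_prob(action).sum()
loss.backward()                      # same score-function gradient as the discrete case
print(mean.grad, log_std.grad)
```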

If DQN taught us how to stabilize bootstrapped value learning, REINFORCE teaches the opposite direction: optimize the policy itself and treat values as optional variance-control tools.

If this post was useful, you can subscribe for new implementation-first articles without leaving the page.



References

[1] Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8, 229-256. https://webdocs.cs.ualberta.ca/~sutton/williams-92.pdf

[2] Sutton, R. S., McAllester, D. A., Singh, S. P., & Mansour, Y. (2000). Policy gradient methods for reinforcement learning with function approximation. Advances in Neural Information Processing Systems 12 (NeurIPS 1999). https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation

[3] Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction (2nd ed.), Chapter 13. MIT Press. https://mitpress.mit.edu/9780262039246/reinforcement-learning/

[4] OpenAI Spinning Up. Vanilla Policy Gradient (VPG) documentation. https://spinningup.openai.com/en/latest/algorithms/vpg.html

[5] Zhang, J., Kim, J., O’Donoghue, B., & Boyd, S. (2021). Sample Efficient Reinforcement Learning with REINFORCE. AAAI 2021 / arXiv:2010.11364. https://arxiv.org/abs/2010.11364

[6] Wang, T. REINFORCE lecture notes, University of Toronto. https://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf

[7] Jayakody, D. (2023). REINFORCE - A Quick Introduction (with Code). https://dilithjay.com/blog/reinforce-a-quick-introduction-with-code