ICLR 2024 • Top 5% • Spotlight

CrossQ:
Batch Normalization in Deep Reinforcement Learning
for Greater Sample Efficiency and Simplicity

*Equal contribution • ¹Intelligent Autonomous Systems, TU Darmstadt • ²Hessian.AI • ³University of Freiburg • ⁴German Research Center for AI (DFKI) • ⁵Centre for Cognitive Science, TU Darmstadt

Abstract

Sample efficiency is a crucial problem in deep reinforcement learning. Recent algorithms, such as REDQ and DroQ, found a way to improve the sample efficiency by increasing the update-to-data (UTD) ratio to 20 gradient update steps on the critic per environment sample. However, this comes at the expense of a greatly increased computational cost.

To reduce this computational burden, we introduce CrossQ: a lightweight algorithm for continuous control tasks that makes careful use of Batch Normalization and removes target networks to surpass the current state-of-the-art in sample efficiency, while maintaining a low UTD ratio of 1. Notably, CrossQ does not rely on advanced bias-reduction schemes used in current methods.

CrossQ’s contributions are threefold:

  1. it matches or surpasses current state-of-the-art methods in terms of sample efficiency,
  2. it substantially reduces the computational cost compared to REDQ and DroQ,
  3. it is easy to implement, requiring just a few lines of code on top of SAC.

TL;DR

Problem

We want fast + simple + sample-efficient off-policy Deep RL.
  • UTD=1 methods — like SAC and TD3 — are fast, but not sample-efficient enough
  • UTD=20 methods — like REDQ and DroQ — are sample-efficient, but not fast enough
  • High-UTD training requires Q-function bias reduction, making algorithms more complex
Can we accelerate training without resorting to high UTD ratios?

Insight

BatchNorm greatly accelerates convergence in Supervised Learning. Can it similarly accelerate RL?
We show that naively using BatchNorm with Q networks is harmful:
  • TD learning relates Q predictions from the forward passes of two batches — (S, A) and (S', A')
  • The two batches have different statistics: A comes from the replay buffer, while A' comes from the current policy
  • BatchNorm's mismatched running statistics degrade Q predictions and harm training
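
To make the mismatch concrete, a naive integration looks roughly like the sketch below (PyTorch-style pseudocode; the names q_net, q_target, and naive_td_targets are ours, not the paper's implementation):

import torch

def naive_td_targets(q_net, q_target, s, a, s_next, a_next, r, gamma=0.99):
    # Each batch goes through its own forward pass, so BatchNorm inside the
    # critics normalizes (S, A) and (S', A') with different statistics.
    q_pred = q_net(torch.cat([s, a], dim=-1))                   # replay-batch statistics
    with torch.no_grad():                                       # standard target computation
        q_next = q_target(torch.cat([s_next, a_next], dim=-1))  # policy-action statistics
    return q_pred, r + gamma * q_next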

Solution

  1. Delete the target network
  2. Concatenate the batches (S, A) and (S', A') into one, and do a single forward pass

BatchNorm now uses normalization moments from the union of both batches. These moments are not mismatched, as all inputs now belong to the same mixture distribution.
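
A minimal sketch of the resulting update, again in PyTorch-style pseudocode with hypothetical names (and omitting SAC's entropy bonus and twin-critic minimum):

import torch

def crossq_td_targets(q_net, s, a, s_next, a_next, r, gamma=0.99):
    sa      = torch.cat([s, a], dim=-1)            # replay batch (S, A)
    sa_next = torch.cat([s_next, a_next], dim=-1)  # bootstrap batch (S', A')
    # One forward pass over the concatenation: BatchNorm inside q_net now
    # computes its statistics over the union of both batches.
    q_all = q_net(torch.cat([sa, sa_next], dim=0))
    q_pred, q_next = torch.chunk(q_all, 2, dim=0)
    # Gradients are stopped on the bootstrap term, as in standard TD learning.
    return q_pred, r + gamma * q_next.detach()

The critic loss is then the usual squared TD error between q_pred and the detached target; compared with the naive variant above, only the concatenation and the removed target network change.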

CrossQ

To turn SAC into CrossQ, we make 3 key changes:
  1. Delete target nets, simplifying the algorithm
  2. Use Batch Normalization, boosting sample efficiency by an order of magnitude
  3. Widen the critic, further improving performance

These changes take only a few lines of code.
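
For concreteness, a CrossQ-style critic could be assembled as follows; the layer count, width, and BatchNorm placement are illustrative assumptions on our part rather than the paper's exact configuration:

import torch.nn as nn

def make_critic(obs_dim: int, act_dim: int, width: int = 2048) -> nn.Module:
    # A wider MLP critic with BatchNorm after each hidden layer; there is no
    # separate target-network copy of this module.
    return nn.Sequential(
        nn.Linear(obs_dim + act_dim, width),
        nn.BatchNorm1d(width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.BatchNorm1d(width),
        nn.ReLU(),
        nn.Linear(width, 1),
    )

Such a critic serves both the prediction and the bootstrap halves of the joint forward pass shown above, since no target copy is kept.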

Cite


@inproceedings{
  bhatt2024crossq,
  title={CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity},
  author={Aditya Bhatt and Daniel Palenicek and Boris Belousov and Max Argus and Artemij Amiranashvili and Thomas Brox and Jan Peters},
  booktitle={International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=PczQtTsTIX}
}
    
A 2019 arXiv version of this paper was titled CrossNorm: Normalization for Off-Policy TD Reinforcement Learning.