LG | Jan 12, 2025

DRDT3: Diffusion-Refined Decision Test-Time Training Model

arXiv:2501.06718v2 | 1 citation | h-index: 5 | Trans. Mach. Learn. Res.
Originality: Incremental advance
AI Analysis

This work addresses a limitation of Decision Transformer (DT) in offline reinforcement learning and offers an incremental improvement over existing DT-based approaches.

DT struggles to learn optimal policies from suboptimal trajectories. The paper proposes DRDT3, a unified framework in which a Decision TTT (DT3) module makes coarse action predictions and a diffusion model iteratively refines them. On the D4RL benchmark, the authors report superior results compared to state-of-the-art DT-based and offline RL methods.

Decision Transformer (DT), a trajectory modelling method, has shown competitive performance compared to traditional offline reinforcement learning (RL) approaches on various classic control tasks. However, it struggles to learn optimal policies from suboptimal, reward-labelled trajectories. In this study, we explore the use of conditional generative modelling to facilitate trajectory stitching given its high-quality data generation ability. Additionally, recent advancements in Recurrent Neural Networks (RNNs) have shown their linear complexity and competitive sequence modelling performance over Transformers. We leverage the Test-Time Training (TTT) layer, an RNN that updates hidden states during testing, to model trajectories in the form of DT. We introduce a unified framework, called Diffusion-Refined Decision TTT (DRDT3), to achieve performance beyond DT models. Specifically, we propose the Decision TTT (DT3) module, which harnesses the sequence modelling strengths of both self-attention and the TTT layer to capture recent contextual information and make coarse action predictions. DRDT3 iteratively refines the coarse action predictions through the generative diffusion model, progressively moving closer to the optimal actions. We further integrate DT3 with the diffusion model using a unified optimization objective. With experiments on multiple tasks in the D4RL benchmark, our DT3 model without diffusion refinement demonstrates improved performance over standard DT, while DRDT3 further achieves superior results compared to state-of-the-art DT-based and offline RL methods.
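As a rough illustration of what "an RNN that updates hidden states during testing" can mean, the sketch below runs a simplified TTT-style linear layer over a sequence. It is not the paper's DT3 module. It assumes the common TTT formulation in which the hidden state is a weight matrix updated by one gradient step per token on a self-supervised reconstruction loss. The function name `ttt_linear_forward`, the projection matrices `theta_k`, `theta_v`, `theta_q`, the zero initial state, and the learning rate are all hypothetical choices made for this example.

```python
import numpy as np

def ttt_linear_forward(tokens, theta_k, theta_v, theta_q, lr=0.1):
    """Run a simplified TTT-style linear layer over a sequence.

    The hidden state is a weight matrix W. For each token we take one
    gradient step on a self-supervised reconstruction loss, then use the
    updated W to produce the output for that token.
    """
    dim = theta_k.shape[0]
    W = np.zeros((dim, dim))  # hidden state starts at zero (assumption)
    outputs = []
    for x in tokens:
        # Three views of the same token (illustrative projections).
        k, v, q = theta_k @ x, theta_v @ x, theta_q @ x
        # Loss: 0.5 * ||W k - v||^2, whose gradient w.r.t. W is (W k - v) k^T.
        grad = np.outer(W @ k - v, k)
        W = W - lr * grad      # hidden-state update, applied at test time too
        outputs.append(W @ q)  # the output uses the freshly updated state
    return np.stack(outputs), W
```

Because the state update is a gradient step rather than a fixed recurrence, the layer keeps adapting to the sequence it is reading, and the cost grows linearly with sequence length. Feeding the same token repeatedly with identity projections makes the output move steadily toward that token, which is a quick way to check the update direction.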

Foundations

The foundational work for this paper's niche, ranked by how specifically the neighbourhood builds on it — not by global fame.
