Junjie Bai

CV
h-index: 31
15 papers
53,066 citations
Novelty: 58%
AI Score: 66

15 Papers

RO · Oct 30, 2025
Alpamayo-R1: Bridging Reasoning and Action Prediction for Generalizable Autonomous Driving in the Long Tail

Yan Wang, Wenjie Luo, Junjie Bai et al. · nvidia

End-to-end architectures trained via imitation learning have advanced autonomous driving by scaling model size and data, yet performance remains brittle in safety-critical long-tail scenarios where supervision is sparse and causal understanding is limited. To address this, we introduce Alpamayo-R1 (AR1), a vision-language-action model (VLA) that integrates Chain of Causation reasoning with trajectory planning to enhance decision-making in complex driving scenarios. Our approach features three key innovations: (1) the Chain of Causation (CoC) dataset, built through a hybrid auto-labeling and human-in-the-loop pipeline producing decision-grounded, causally linked reasoning traces aligned with driving behaviors; (2) a modular VLA architecture combining Cosmos-Reason, a Vision-Language Model pre-trained for Physical AI applications, with a diffusion-based trajectory decoder that generates dynamically feasible plans in real time; (3) a multi-stage training strategy using supervised fine-tuning to elicit reasoning and reinforcement learning (RL) to optimize reasoning quality via large reasoning model feedback and enforce reasoning-action consistency. Evaluation shows AR1 achieves up to a 12% improvement in planning accuracy on challenging cases compared to a trajectory-only baseline, with a 35% reduction in off-road rate and 25% reduction in close encounter rate in closed-loop simulation. RL post-training improves reasoning quality by 45% as measured by a large reasoning model critic and reasoning-action consistency by 37%. Model scaling from 0.5B to 7B parameters shows consistent improvements. On-vehicle road tests confirm real-time performance (99 ms latency) and successful urban deployment. By bridging interpretable reasoning with precise control, AR1 demonstrates a practical path towards Level 4 autonomous driving. We plan to release AR1 models and a subset of the CoC in a future update.

AI · Mar 18, 2025 · Code
Cosmos-Reason1: From Physical Common Sense To Embodied Reasoning

Alisson Azzolini, Junjie Bai, Hannah Brandon et al. · nvidia

Physical AI systems need to perceive, understand, and perform complex actions in the physical world. In this paper, we present the Cosmos-Reason1 models that can understand the physical world and generate appropriate embodied decisions (e.g., next step action) in natural language through long chain-of-thought reasoning processes. We begin by defining key capabilities for Physical AI reasoning, with a focus on physical common sense and embodied reasoning. To represent physical common sense, we use a hierarchical ontology that captures fundamental knowledge about space, time, and physics. For embodied reasoning, we rely on a two-dimensional ontology that generalizes across different physical embodiments. Building on these capabilities, we develop two multimodal large language models, Cosmos-Reason1-7B and Cosmos-Reason1-56B. We curate data and train our models in two stages: Physical AI supervised fine-tuning (SFT) and Physical AI reinforcement learning (RL). To evaluate our models, we build comprehensive benchmarks for physical common sense and embodied reasoning according to our ontologies. Evaluation results show that Physical AI SFT and RL bring significant improvements. To facilitate the development of Physical AI, we make our code and pre-trained models available under the NVIDIA Open Model License at https://github.com/nvidia-cosmos/cosmos-reason1.

CV · Oct 28, 2025 · Code
World Simulation with Video Foundation Models for Physical AI

Arslan Ali, Junjie Bai, Maciej Bala et al. · nvidia

We introduce Cosmos-Predict2.5, the latest generation of the Cosmos World Foundation Models for Physical AI. Built on a flow-based architecture, Cosmos-Predict2.5 unifies Text2World, Image2World, and Video2World generation in a single model and leverages Cosmos-Reason1, a Physical AI vision-language model, to provide richer text grounding and finer control of world simulation. Trained on 200M curated video clips and refined with reinforcement learning-based post-training, Cosmos-Predict2.5 achieves substantial improvements over Cosmos-Predict1 in video quality and instruction alignment, with models released at 2B and 14B scales. These capabilities enable more reliable synthetic data generation, policy evaluation, and closed-loop simulation for robotics and autonomous systems. We further extend the family with Cosmos-Transfer2.5, a control-net style framework for Sim2Real and Real2Real world translation. Despite being 3.5× smaller than Cosmos-Transfer1, it delivers higher fidelity and robust long-horizon video generation. Together, these advances establish Cosmos-Predict2.5 and Cosmos-Transfer2.5 as versatile tools for scaling embodied intelligence. To accelerate research and deployment in Physical AI, we release source code, pretrained checkpoints, and curated benchmarks under the NVIDIA Open Model License at https://github.com/nvidia-cosmos/cosmos-predict2.5 and https://github.com/nvidia-cosmos/cosmos-transfer2.5. We hope these open resources lower the barrier to adoption and foster innovation in building the next generation of embodied intelligence.

99.1 · CV · Jun 1 · Code
Cosmos 3: Omnimodal World Models for Physical AI

Aditi, Niket Agarwal, Arslan Ali et al.

We introduce Cosmos 3, a family of omnimodal world models designed to jointly process and generate language, image, video, audio, and action sequences within a unified mixture-of-transformers architecture. By supporting highly flexible input-output configurations, Cosmos 3 seamlessly unifies critical modalities for Physical AI -- effectively subsuming vision-language models, video generators, world simulators, and world-action models into a single framework. Our evaluation demonstrates that Cosmos 3 establishes a new state-of-the-art across a diverse suite of understanding and generation tasks, demonstrating omnimodal world models as scalable, general-purpose backbones for embodied agents. Our post-trained Cosmos 3 models were ranked as the best open-source Text-to-Image and Image-to-Video models by Artificial Analysis, and the best policy model by RoboArena at the time the technical report was written. To accelerate open research and deployment in Physical AI, we make our code, model checkpoints, curated synthetic datasets, and evaluation benchmark available under the Linux Foundation's OpenMDW-1.1 License (https://openmdw.ai/license/1-1/) at https://github.com/nvidia/cosmos and https://huggingface.co/collections/nvidia/cosmos3. The project website is available at https://research.nvidia.com/labs/cosmos-lab/cosmos3.

AI · May 23, 2022
Parameter-Efficient Sparsity for Large Language Models Fine-Tuning

Yuchao Li, Fuli Luo, Chuanqi Tan et al.

With the dramatically increased number of parameters in language models, sparsity methods have received ever-increasing research focus to compress and accelerate the models. While most research focuses on how to accurately retain appropriate weights while maintaining the performance of the compressed model, there are challenges in the computational overhead and memory footprint of sparse training when compressing large-scale language models. To address this problem, we propose a Parameter-efficient Sparse Training (PST) method to reduce the number of trainable parameters during sparse-aware training in downstream tasks. Specifically, we first combine the data-free and data-driven criteria to efficiently and accurately measure the importance of weights. Then we investigate the intrinsic redundancy of data-driven weight importance and derive two obvious characteristics i.e., low-rankness and structuredness. Based on that, two groups of small matrices are introduced to compute the data-driven importance of weights, instead of using the original large importance score matrix, which therefore makes the sparse training resource-efficient and parameter-efficient. Experiments with diverse networks (i.e., BERT, RoBERTa and GPT-2) on dozens of datasets demonstrate PST performs on par or better than previous sparsity methods, despite only training a small number of parameters. For instance, compared with previous sparsity methods, our PST only requires 1.5% trainable parameters to achieve comparable performance on BERT.
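The abstract above describes scoring weights with a data-free criterion plus a data-driven one that is replaced by two groups of small matrices (low-rank and structured). The following is a minimal sketch of that idea, not the authors' implementation: it assumes the data-free term is the weight magnitude, the low-rank term is a product `A @ B`, and the structured term is a per-row plus per-column score. All function and parameter names are hypothetical.

```python
def low_rank_importance(weights, A, B, row_scores, col_scores):
    """Score each weight as |w| (assumed data-free term) plus a compact
    data-driven term: a rank-r product A @ B and row/column scores."""
    n_out, n_in = len(weights), len(weights[0])
    rank = len(B)
    scores = []
    for i in range(n_out):
        row = []
        for j in range(n_in):
            low_rank = sum(A[i][k] * B[k][j] for k in range(rank))
            row.append(abs(weights[i][j]) + low_rank + row_scores[i] + col_scores[j])
        scores.append(row)
    return scores


def top_k_mask(scores, sparsity):
    """Keep the highest-scoring (1 - sparsity) fraction of weights.
    Ties at the threshold are all kept, so the mask may keep slightly more."""
    flat = sorted((s for row in scores for s in row), reverse=True)
    keep = max(1, round(len(flat) * (1.0 - sparsity)))
    threshold = flat[keep - 1]
    return [[1 if s >= threshold else 0 for s in row] for row in scores]


def trainable_fraction(n_out, n_in, rank):
    """Size of the compact score parameters relative to a dense score matrix."""
    compact = rank * (n_out + n_in) + n_out + n_in
    return compact / (n_out * n_in)
```

Under these assumptions, a 768 × 768 layer with rank 8 needs `trainable_fraction(768, 768, 8)` of the parameters of a full importance matrix, roughly 2.3%. This figure illustrates the sketch only and is not a number reported in the paper.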

LG · Dec 3, 2025
Data-regularized Reinforcement Learning for Diffusion Models at Scale

Haotian Ye, Kaiwen Zheng, Jiashu Xu et al. · gatech

Aligning generative diffusion models with human preferences via reinforcement learning (RL) is critical yet challenging. Most existing algorithms are often vulnerable to reward hacking, such as quality degradation, over-stylization, or reduced diversity. Our analysis demonstrates that this can be attributed to the inherent limitations of their regularization, which provides unreliable penalties. We introduce Data-regularized Diffusion Reinforcement Learning (DDRL), a novel framework that uses the forward KL divergence to anchor the policy to an off-policy data distribution. Theoretically, DDRL enables robust, unbiased integration of RL with standard diffusion training. Empirically, this translates into a simple yet effective algorithm that combines reward maximization with diffusion loss minimization. With over a million GPU hours of experiments and ten thousand double-blind human evaluations, we demonstrate on high-resolution video generation tasks that DDRL significantly improves rewards while alleviating the reward hacking seen in baselines, achieving the highest human preference and establishing a robust and scalable paradigm for diffusion post-training.

CV · Feb 29, 2024 · Code
DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models

Muyang Li, Tianle Cai, Jiaxin Cao et al.

Diffusion models have achieved great success in synthesizing high-quality images. However, generating high-resolution images with diffusion models is still challenging due to the enormous computational costs, resulting in a prohibitive latency for interactive applications. In this paper, we propose DistriFusion to tackle this problem by leveraging parallelism across multiple GPUs. Our method splits the model input into multiple patches and assigns each patch to a GPU. However, naively implementing such an algorithm breaks the interaction between patches and loses fidelity, while incorporating such an interaction will incur tremendous communication overhead. To overcome this dilemma, we observe the high similarity between the input from adjacent diffusion steps and propose displaced patch parallelism, which takes advantage of the sequential nature of the diffusion process by reusing the pre-computed feature maps from the previous timestep to provide context for the current step. Therefore, our method supports asynchronous communication, which can be pipelined by computation. Extensive experiments show that our method can be applied to recent Stable Diffusion XL with no quality degradation and achieve up to a 6.1× speedup on eight NVIDIA A100s compared to one. Our code is publicly available at https://github.com/mit-han-lab/distrifuser.

CV · Nov 30, 2023
OmniMotionGPT: Animal Motion Generation with Limited Data

Zhangsihao Yang, Mingyuan Zhou, Mengyi Shan et al.

Our paper aims to generate diverse and realistic animal motion sequences from textual descriptions, without a large-scale animal text-motion dataset. While the task of text-driven human motion synthesis is already extensively studied and benchmarked, it remains challenging to transfer this success to other skeleton structures with limited data. In this work, we design a model architecture that imitates Generative Pretraining Transformer (GPT), transferring prior knowledge learned from human data to the animal domain. We jointly train motion autoencoders for both animal and human motions and at the same time optimize through the similarity scores among human motion encoding, animal motion encoding, and text CLIP embedding. Presenting the first solution to this problem, we are able to generate animal motions with high diversity and fidelity, quantitatively and qualitatively outperforming the results of training human motion generation baselines on animal data. Additionally, we introduce AnimalML3D, the first text-animal motion dataset with 1240 animation sequences spanning 36 different animal identities. We hope this dataset will mitigate the data scarcity problem in text-driven animal motion generation, providing a new playground for the research community.

99.4 · LG · Apr 8
FP4 Explore, BF16 Train: Diffusion Reinforcement Learning via Efficient Rollout Scaling

Yitong Li, Junsong Chen, Shuchen Xue et al.

Reinforcement-Learning-based post-training has recently emerged as a promising paradigm for aligning text-to-image diffusion models with human preferences. In recent studies, increasing the rollout group size yields pronounced performance improvements, indicating substantial room for further alignment gains. However, scaling rollouts on large-scale foundational diffusion models (e.g., FLUX.1-12B) imposes a heavy computational burden. To alleviate this bottleneck, we explore the integration of FP4 quantization into Diffusion RL rollouts. Yet, we identify that naive quantized pipelines inherently introduce risks of performance degradation. To overcome this dilemma between efficiency and training integrity, we propose Sol-RL (Speed-of-light RL), a novel FP4-empowered Two-stage Reinforcement Learning framework. First, we utilize high-throughput NVFP4 rollouts to generate a massive candidate pool and extract a highly contrastive subset. Second, we regenerate these selected samples in BF16 precision and optimize the policy exclusively on them. By decoupling candidate exploration from policy optimization, Sol-RL integrates the algorithmic mechanisms of rollout scaling with the system-level throughput gains of NVFP4. This synergistic algorithm-hardware design effectively accelerates the rollout phase while reserving high-fidelity samples for optimization. We empirically demonstrate that our framework maintains the training integrity of BF16 precision pipeline while fully exploiting the throughput gains enabled by FP4 arithmetic. Extensive experiments across SANA, FLUX.1, and SD3.5-L substantiate that our approach delivers superior alignment performance across multiple metrics while accelerating training convergence by up to 4.64×, unlocking the power of massive rollout scaling at a fraction of the cost.
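The two-stage procedure in this abstract (cheap rollouts to build a candidate pool, then high-precision regeneration of a contrastive subset) can be sketched roughly as below. This is an illustrative sketch under stated assumptions, not the Sol-RL implementation: it assumes "highly contrastive" means the k highest- and k lowest-reward candidates, and that regeneration reuses the same seeds. `cheap_sample`, `precise_sample`, and `reward_fn` are hypothetical callables standing in for the FP4 sampler, the BF16 sampler, and the reward model.

```python
def select_contrastive_subset(rewards, k):
    """Return indices of the k lowest- and k highest-reward candidates
    (an assumed reading of 'highly contrastive subset')."""
    if k < 1 or 2 * k > len(rewards):
        raise ValueError("k must satisfy 1 <= k <= len(rewards) / 2")
    order = sorted(range(len(rewards)), key=lambda i: rewards[i])
    return order[:k] + order[-k:]


def two_stage_rollout(prompt, seeds, cheap_sample, reward_fn, precise_sample, k):
    """Stage 1: explore with the cheap sampler and score every candidate.
    Stage 2: regenerate only the selected seeds with the precise sampler."""
    candidates = [cheap_sample(prompt, seed) for seed in seeds]
    rewards = [reward_fn(candidate) for candidate in candidates]
    chosen = select_contrastive_subset(rewards, k)
    # Only these regenerated samples would be passed to the policy update.
    return [(seeds[i], precise_sample(prompt, seeds[i])) for i in chosen]
```

The design point the sketch tries to capture is that the expensive sampler runs on only 2k seeds, while the cheap sampler covers the whole pool.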

17.6 · IR · May 14
Efficient Generative Retrieval for E-commerce Search with Semantic Cluster IDs and Expert-Guided RL

Jianbo Zhu, Xing Fang, Jing Wang et al.

Generative retrieval offers a promising alternative by unifying the fragmented multi-stage retrieval process into a single end-to-end model. However, its practical adoption in industrial e-commerce search remains challenging, given the massive and dynamic product catalogs, strict latency requirements, and the need to align retrieval with downstream ranking goals. In this work, we propose a retrieval framework tailored for real-world recall scenarios, positioning generative retrieval as a recall-stage supplement rather than an end-to-end replacement. Our method, CQ-SID (Category-and-Query constrained Semantic ID), employs category-aware and query-item contrastive learning along with Residual Quantized VAEs to encode items into hierarchical semantic cluster identifiers, significantly reducing beam search complexity. Additionally, we develop EG-GRPO (Expert-Guided Group Relative Policy Optimization), a reinforcement learning approach that aligns generative recall with downstream ranking under sparse rewards by injecting ground-truth samples to stabilize training. Offline experiments on TmallAPP search logs show that CQ-SID achieves up to 26.76% and 11.11% relative gains in semantic and personalized click hitrate over RQ-VAE baselines, while halving beam search size. EG-GRPO further improves multi-objective performance. Online A/B tests confirm gains in GMV (+1.15%) and UCTCVR (+0.40%). The generative recall channel now contributes substantially in production, accounting for over 50.25% of exposures, 58.96% of clicks, and 72.63% of purchases, demonstrating a viable path for deploying generative retrieval in real-world e-commerce systems.

LG · Dec 3, 2019
PyTorch: An Imperative Style, High-Performance Deep Learning Library

Adam Paszke, Sam Gross, Francisco Massa et al.

Deep learning frameworks have often focused on either usability or speed, but not both. PyTorch is a machine learning library that shows that these two goals are in fact compatible: it provides an imperative and Pythonic programming style that supports code as a model, makes debugging easy and is consistent with other popular scientific computing libraries, while remaining efficient and supporting hardware accelerators such as GPUs. In this paper, we detail the principles that drove the implementation of PyTorch and how they are reflected in its architecture. We emphasize that every aspect of PyTorch is a regular Python program under the full control of its user. We also explain how the careful and pragmatic implementation of the key components of its runtime enables them to work together to achieve compelling performance. We demonstrate the efficiency of individual subsystems, as well as the overall speed of PyTorch on several common benchmarks.

CV · Mar 25, 2019
DeepCenterline: a Multi-task Fully Convolutional Network for Centerline Extraction

Zhihui Guo, Junjie Bai, Yi Lu et al.

A novel centerline extraction framework is reported which combines an end-to-end trainable multi-task fully convolutional network (FCN) with a minimal path extractor. The FCN simultaneously computes centerline distance maps and detects branch endpoints. The method generates single-pixel-wide centerlines with no spurious branches. It handles arbitrary tree-structured objects with no prior assumption regarding depth of the tree or its bifurcation pattern. It is also robust to substantial scale changes across different parts of the target object and minor imperfections of the object's segmentation mask. To the best of our knowledge, this is the first deep-learning based centerline extraction method that guarantees single-pixel-wide centerlines for a complex tree-structured object. The proposed method is validated in coronary artery centerline extraction on a dataset of 620 patients (400 of which were used as the test set). This application is challenging due to the large number of coronary branches, branch tortuosity, and large variations in length, thickness, shape, etc. The proposed method generates well-positioned centerlines, exhibits a lower number of missing branches, and is more robust in the presence of minor imperfections of the object segmentation mask. Compared to a state-of-the-art traditional minimal path approach, our method improves patient-level success rate of centerline extraction from 54.3% to 88.8% according to independent human expert review.

CV · Jan 29, 2019
Attention-driven Tree-structured Convolutional LSTM for High Dimensional Data Understanding

Bin Kong, Xin Wang, Junjie Bai et al.

Modeling the sequential information of image sequences has been a vital step of various vision tasks and convolutional long short-term memory (ConvLSTM) has demonstrated its superb performance in such spatiotemporal problems. Nevertheless, the hierarchical data structures in a significant amount of tasks (e.g., human body parts and vessel/airway tree in biomedical images) cannot be properly modeled by sequential models. Thus, ConvLSTM is not suitable for tree-structured image data analysis. In order to address these limitations, we present tree-structured ConvLSTM models for tree-structured image analysis tasks which can be trained end-to-end. To demonstrate the effectiveness of the proposed tree-structured ConvLSTM model, we present a tree-structured segmentation framework which consists of a tree-structured ConvLSTM and an attention fully convolutional network (FCN) model. The proposed framework is extensively validated on four large-scale coronary artery datasets. The results demonstrate the effectiveness and efficiency of the proposed method.

CV · Dec 21, 2018
Residual Attention based Network for Hand Bone Age Assessment

Eric Wu, Bin Kong, Xin Wang et al.

Computerized automatic methods have been employed to boost the productivity as well as objectiveness of hand bone age assessment. These approaches make predictions according to the whole X-ray images, which include other objects that may introduce distractions. Instead, our framework is inspired by the clinical workflow (Tanner-Whitehouse) of hand bone age assessment, which focuses on the key components of the hand. The proposed framework is composed of two components: a Mask R-CNN subnet of pixelwise hand segmentation and a residual attention network for hand bone age assessment. The Mask R-CNN subnet segments the hands from X-ray images to avoid the distractions of other objects (e.g., X-ray tags). The hierarchical attention components of the residual attention subnet force our network to focus on the key components of the X-ray images and generate the final predictions as well as the associated visual supports, which is similar to the assessment procedure of clinicians. We evaluate the performance of the proposed pipeline on the RSNA pediatric bone age dataset and the results demonstrate its superiority over the previous methods.

CV · May 22, 2017
Optimal Multi-Object Segmentation with Novel Gradient Vector Flow Based Shape Priors

Junjie Bai, Abhay Shah, Xiaodong Wu

Shape priors have been widely utilized in medical image segmentation to improve segmentation accuracy and robustness. A major way to encode such a prior shape model is to use a mesh representation, which is prone to causing self-intersection or mesh folding. Those problems require complex and expensive algorithms to mitigate. In this paper, we propose a novel shape prior directly embedded in the voxel grid space, based on gradient vector flows of a pre-segmentation. The flexible and powerful prior shape representation is ready to be extended to simultaneously segmenting multiple interacting objects with minimum separation distance constraint. The problem is formulated as a Markov random field problem whose exact solution can be efficiently computed with a single minimum s-t cut in an appropriately constructed graph. The proposed algorithm is validated on two multi-object segmentation applications: the brain tissue segmentation in MRI images, and the bladder/prostate segmentation in CT images. Both sets of experiments show superior or competitive performance of the proposed method to other state-of-the-art methods.