Combining next-token prediction and video diffusion in computer vision and robotics

In the current AI zeitgeist, sequence models have skyrocketed in popularity for their ability to analyze data and predict what to do next. For instance, you’ve likely used next-token prediction models like ChatGPT, which anticipate each word (token) in a sequence to form answers to users’ queries. There are also full-sequence diffusion models like Sora, which convert words into dazzling, realistic visuals by successively “denoising” an entire video sequence. 

Researchers from MIT’s Computer Science and Artificial Intelligence Laboratory (CSAIL) have proposed a simple change to the diffusion training scheme that makes this sequence denoising considerably more flexible.

When applied to fields like computer vision and robotics, next-token and full-sequence diffusion models have capability trade-offs. Next-token models can spit out sequences that vary in length. However, they generate without awareness of desirable states in the far future, such as a goal 10 tokens away toward which the sequence should be steered, and thus require additional mechanisms for long-horizon (long-term) planning. Diffusion models can perform such future-conditioned sampling, but lack the ability of next-token models to generate variable-length sequences.

The CSAIL researchers wanted to combine the strengths of both models, so they created a sequence model training technique called “Diffusion Forcing.” The name comes from “Teacher Forcing,” the conventional training scheme that breaks down full sequence generation into the smaller, easier steps of next-token generation (much like a good teacher simplifying a complex concept).
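To make the teacher-forcing idea concrete, here is a minimal Python sketch (not the researchers' code) that turns one training sequence into next-token examples. The function name `teacher_forcing_pairs` and the toy word tokens are illustrative assumptions; a real model would work with token IDs and predict a probability distribution over the next token.

```python
from typing import List, Sequence, Tuple


def teacher_forcing_pairs(tokens: Sequence[str]) -> List[Tuple[Tuple[str, ...], str]]:
    """Break one full sequence into (ground-truth prefix, next token) examples.

    Hypothetical helper for illustration: under teacher forcing the model is
    always shown the true prefix, never its own earlier predictions.
    """
    pairs = []
    for t in range(1, len(tokens)):
        pairs.append((tuple(tokens[:t]), tokens[t]))
    return pairs


for prefix, target in teacher_forcing_pairs(["the", "cat", "sat", "down"]):
    print(prefix, "->", target)
# ('the',) -> cat
# ('the', 'cat') -> sat
# ('the', 'cat', 'sat') -> down
```

Each example is one small, easy next-token step; the model is never asked to produce the whole sequence at once during training.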

Diffusion Forcing
Video: MIT CSAIL

The researchers found common ground between diffusion models and teacher forcing: Both use training schemes that involve predicting masked (noisy) tokens from unmasked ones. Diffusion models gradually add noise to data, which can be viewed as fractional masking. The MIT researchers’ Diffusion Forcing method trains neural networks to clean up a collection of tokens, removing a different amount of noise from each one while simultaneously predicting the next few tokens. The result is a flexible, reliable sequence model that produced higher-quality artificial videos and more precise decision-making for robots and AI agents.
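The training idea can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the authors' implementation: tokens are rows of a `[T, D]` array, noise levels run from 0 (clean) to a hypothetical `K` (pure noise), and the noise schedule is a simple linear one chosen for readability. What it shows is that every token gets its own independently sampled noise level.

```python
import numpy as np


def noise_tokens(x: np.ndarray, K: int, rng: np.random.Generator):
    """Noise each token (row) of x to its own, independently drawn level.

    Hypothetical helper. Level 0 leaves a token clean; level K replaces it
    with pure Gaussian noise.
    """
    T = x.shape[0]
    k = rng.integers(0, K + 1, size=T)      # one noise level per token
    alpha_bar = 1.0 - k / K                 # assumed linear schedule: signal variance kept
    eps = rng.standard_normal(x.shape)      # the noise a denoiser would learn to predict
    x_noisy = np.sqrt(alpha_bar)[:, None] * x + np.sqrt(1.0 - alpha_bar)[:, None] * eps
    return x_noisy, k, eps


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))             # 8 tokens with 4 features each
x_noisy, k, eps = noise_tokens(x, K=10, rng=rng)
print(k)                                    # a separate noise level for each token
```

A denoising network would then be trained to recover `eps` (or the clean `x`) from `x_noisy` together with the per-token levels `k`. Standard full-sequence diffusion would instead draw a single noise level shared by all tokens.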

By sorting through noisy data and reliably predicting the next steps in a task, Diffusion Forcing can help a robot ignore visual distractions to complete manipulation tasks. It can also generate stable and consistent video sequences and even guide an AI agent through digital mazes. The method could potentially enable household and factory robots to generalize to new tasks and improve AI-generated entertainment.

“Sequence models aim to condition on the known past and predict the unknown future, a type of binary masking. However, masking doesn’t need to be binary,” says lead author, MIT electrical engineering and computer science (EECS) PhD student, and CSAIL member Boyuan Chen. “With Diffusion Forcing, we add different levels of noise to each token, effectively serving as a type of fractional masking. At test time, our system can ‘unmask’ a collection of tokens and diffuse a sequence in the near future at a lower noise level. It knows what to trust within its data to overcome out-of-distribution inputs.”
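Chen's point that binary masking is a special case of fractional masking can be illustrated numerically. The sketch below assumes a hypothetical linear schedule as a simplification: a token at noise level `k` out of `K` keeps a fraction `sqrt(1 - k/K)` of its clean signal. Putting past tokens at level 0 and the next token at level `K` recovers ordinary next-token (binary) masking; `signal_weight` is an illustrative name, not part of any published code.

```python
import numpy as np

K = 10  # hypothetical number of noise levels


def signal_weight(levels, K=K):
    """Fraction of each token's clean signal that survives, under a linear schedule."""
    alpha_bar = 1.0 - np.asarray(levels, dtype=float) / K
    return np.sqrt(alpha_bar)


# Binary masking (teacher forcing): past tokens fully visible, next token fully hidden.
binary = signal_weight([0, 0, 0, K])
# Fractional masking: every token partially visible, noisier the further ahead it is.
fractional = signal_weight([0, 2, 5, 8])

print(binary)  # [1. 1. 1. 0.]
print(fractional)
```

Intermediate levels let the model treat recent context as largely trustworthy and the further future as increasingly uncertain, rather than drawing a hard line between the two.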

In several experiments, Diffusion Forcing excelled at ignoring misleading data to execute tasks while anticipating future actions.

When deployed on a robotic arm, for example, it helped swap two toy fruits across three circular mats, a minimal example of a family of long-horizon tasks that require memory. The researchers trained the robot by controlling it from a distance (teleoperating it) in virtual reality, and the robot learned to mimic the user’s movements using input from its camera. Despite starting from random positions and encountering distractions like a shopping bag blocking the markers, it placed the objects in their target spots.

For video generation, the researchers trained Diffusion Forcing on “Minecraft” gameplay and colorful digital environments created within Google’s DeepMind Lab simulator. When given a single frame of footage, the method produced more stable, higher-resolution videos than comparable baselines such as a Sora-like full-sequence diffusion model and ChatGPT-like next-token models. The baselines produced videos that appeared inconsistent, and the next-token models sometimes failed to generate working video past just 72 frames.

Beyond generating impressive videos, Diffusion Forcing can also serve as a motion planner that steers toward desired outcomes or rewards. Thanks to its flexibility, Diffusion Forcing can uniquely generate plans with varying horizons, perform tree search, and incorporate the intuition that the distant future is more uncertain than the near future. In the task of solving a 2D maze, Diffusion Forcing outperformed six baselines by generating faster plans leading to the goal location, indicating that it could be an effective planner for robots in the future.
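One way to encode “the distant future is more uncertain” at sampling time is a noise schedule in which tokens further ahead stay noisier for longer. The NumPy sketch below builds such a staircase-shaped schedule. It is an illustrative assumption, not necessarily the exact schedule used in the paper; `pyramid_schedule`, `T` (planning horizon in tokens), and `K` (number of noise levels) are hypothetical names. Because `T` is just an argument, the schedule itself places no constraint on the horizon length.

```python
import numpy as np


def pyramid_schedule(T: int, K: int) -> np.ndarray:
    """Noise level of each of T future tokens at each denoising round.

    Row m is round m; column t is the token t steps ahead. 0 = clean, K = pure noise.
    """
    rounds = np.arange(K + T)[:, None]  # one row per denoising round
    offsets = np.arange(T)[None, :]     # distance of each token into the future
    # Every token starts at level K and drops one level per round, but token t
    # starts dropping t rounds later, so farther tokens stay noisier for longer.
    return np.clip(K - rounds + offsets, 0, K)


print(pyramid_schedule(T=4, K=3))
# [[3 3 3 3]
#  [2 3 3 3]
#  [1 2 3 3]
#  [0 1 2 3]
#  [0 0 1 2]
#  [0 0 0 1]
#  [0 0 0 0]]
```

At each round, a denoiser would lower every token to the level given in that row, so the near-term steps of a plan become clean first while later steps remain tentative.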

Across the demos, Diffusion Forcing acted as a full-sequence model, a next-token prediction model, or both. According to Chen, this versatile approach could potentially serve as a powerful backbone for a “world model,” an AI system that can simulate the dynamics of the world by training on billions of internet videos. This would allow robots to perform novel tasks by imagining what they need to do based on their surroundings. For example, if you asked a robot to open a door it had never been trained to open, the model could produce a video showing the machine how to do it.

The team is currently looking to scale up their method to larger datasets and the latest transformer models to improve performance. They intend to broaden their work to build a ChatGPT-like robot brain that helps robots perform tasks in new environments without human demonstration.

“With Diffusion Forcing, we are taking a step to bringing video generation and robotics closer together,” says senior author Vincent Sitzmann, MIT assistant professor and member of CSAIL, where he leads the Scene Representation group. “In the end, we hope that we can use all the knowledge stored in videos on the internet to enable robots to help in everyday life. Many more exciting research challenges remain, like how robots can learn to imitate humans by watching them even when their own bodies are so different from our own!”

Chen and Sitzmann wrote the paper alongside recent MIT visiting researcher Diego Martí Monsó, and CSAIL affiliates: Yilun Du, an EECS graduate student; Max Simchowitz, former postdoc and incoming Carnegie Mellon University assistant professor; and Russ Tedrake, the Toyota Professor of EECS, Aeronautics and Astronautics, and Mechanical Engineering at MIT, vice president of robotics research at the Toyota Research Institute, and CSAIL member. Their work was supported, in part, by the U.S. National Science Foundation, the Singapore Defence Science and Technology Agency, Intelligence Advanced Research Projects Activity via the U.S. Department of the Interior, and the Amazon Science Hub. They will present their research at NeurIPS in December.