Toward Self-Improvement of LLMs via Imagination, Searching, and Criticizing
Today’s paper introduces AlphaLLM, a framework that enables large language models (LLMs) to self-improve without additional annotations. It integrates Monte Carlo Tree Search (MCTS) with LLMs in a self-improving loop that strengthens their capabilities on complex reasoning and planning tasks.
Method Overview
AlphaLLM has three key components:
1. An imagination component that synthesizes prompts to enrich the diversity and complexity of the training data. This helps address the challenge of data scarcity when applying MCTS to LLMs.
2. An efficient MCTS approach called ηMCTS that is tailored to language tasks. It uses options (multi-token sequences) rather than individual tokens as search nodes, shrinking the search space, and employs importance-weighted expansion to dynamically adjust each node's branching factor.
3. A trio of critic models that provide precise feedback to guide the MCTS search (a sketch of how they fit together follows this list):
- A value function that predicts the expected future reward of a partial trajectory
- A process reward model (PRM) that estimates the quality of each search node
- An outcome reward model (ORM) that assesses the overall quality of complete trajectories
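To make the interplay concrete, here is a minimal Python sketch of how option-level search nodes and the three critics could fit together. Everything in it (OptionNode, value_fn, process_rm, outcome_rm, score_node, importance, branching_factor) is a hypothetical illustration rather than the authors' code, and the importance measure shown is one plausible choice, not necessarily the paper's exact formula.

```python
# Hedged sketch: option-level MCTS nodes scored by three hypothetical critics.
from dataclasses import dataclass, field

@dataclass
class OptionNode:
    """A search node holding an option: a multi-token span, not a single token."""
    text: str                      # text generated so far along this path
    visits: int = 0
    value_sum: float = 0.0
    children: list = field(default_factory=list)

    @property
    def q(self) -> float:
        """Mean value of this node over its visits."""
        return self.value_sum / self.visits if self.visits else 0.0

def value_fn(state: str) -> float:
    """Hypothetical value model: expected future reward of a partial answer."""
    return 0.0  # placeholder for a learned model

def process_rm(state: str, option: str) -> float:
    """Hypothetical process reward model: quality of one reasoning step."""
    return 0.0  # placeholder for a learned model

def outcome_rm(trajectory: str) -> float:
    """Hypothetical outcome reward model: quality of a complete answer."""
    return 0.0  # placeholder for a learned model

def score_node(node: OptionNode, option: str, terminal: bool) -> float:
    """Blend critic signals: step-level reward plus a value estimate for
    partial trajectories, or the outcome reward for complete ones."""
    reward = process_rm(node.text, option)
    if terminal:
        reward += outcome_rm(node.text + option)
    else:
        reward += value_fn(node.text + option)
    return reward

def importance(node: OptionNode) -> float:
    """One plausible importance measure: how far child values deviate from
    the parent's, so 'pivotal' nodes earn a larger branching factor."""
    if not node.children:
        return 0.0
    return max(abs(c.q - node.q) for c in node.children)

def branching_factor(node: OptionNode, m_min: int = 2, m_max: int = 8) -> int:
    """Importance-weighted expansion: expand important nodes more widely."""
    return min(m_max, m_min + int(importance(node) * (m_max - m_min)))
```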
These components work together to let the LLM efficiently search for high-quality responses and use them to improve itself iteratively. More precisely, the process starts with an initial dataset D0 of expert-generated prompt-response pairs, a policy model (the LLM), and a reward model. Each iteration then proceeds in three steps: synthetic prompts are generated using the previous policy model and dataset; trajectories for those prompts are collected using MCTS guided by the reward model; and a new dataset of the generated prompts and responses is constructed, on which the policy model is updated by minimizing a loss function. The goal is to iteratively refine the LLM to maximize the expected cumulative reward.
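The outer loop can be summarized in a few lines. The sketch below is a hypothetical simplification: synthesize_prompts, mcts_search, and fine_tune are stand-ins for the paper's imagination, searching, and updating stages, and their placeholder bodies would be replaced by real generation, tree search, and training code.

```python
# Hedged sketch of the self-improvement loop; all helpers are stand-ins.

def synthesize_prompts(policy, dataset):
    """Imagination: generate new prompts from existing ones with the policy."""
    return [policy(f"Write a new, harder variant of: {q}") for q, _ in dataset]

def mcts_search(policy, critics, prompt):
    """Searching: run MCTS guided by the critics; return the best trajectory."""
    return policy(prompt)  # placeholder for the full tree search

def fine_tune(policy, dataset):
    """Updating: minimize a supervised loss on the searched trajectories."""
    return policy  # placeholder for a gradient-based update

def self_improve(policy, critics, d0, iterations=2):
    dataset = list(d0)  # D0: expert prompt-response pairs
    for _ in range(iterations):
        prompts = synthesize_prompts(policy, dataset)
        trajectories = [mcts_search(policy, critics, p) for p in prompts]
        dataset = list(zip(prompts, trajectories))  # new synthetic dataset
        policy = fine_tune(policy, dataset)         # refine toward higher reward
    return policy
```

With these stubs the loop runs end to end, which makes the data flow of each iteration, from prompt synthesis through search to the policy update, easy to trace.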
Results
Experiments on mathematical reasoning tasks (the GSM8K and MATH datasets) show that AlphaLLM significantly improves the performance of base LLMs such as LLaMA-2 70B and WizardMath 70B without using additional annotations.
When augmented with the ηMCTS decoding strategy, AlphaLLM achieves accuracy scores of 88.9% on GSM8K and 48.7% on MATH. After two iterations of self-improvement on synthetic prompts, it performs comparably to GPT-4.
Conclusion
AlphaLLM demonstrates an effective approach for LLMs to self-improve on complex reasoning tasks by leveraging MCTS, without relying on extra labeled data. The imagination-searching-criticizing framework addresses key challenges in integrating MCTS with LLMs and shows promising results. For more information, please consult the full paper.
Congrats to the authors for their work!
Tian, Ye, et al. "Toward Self-Improvement of LLMs via Imagination, Searching, and Criticizing." arXiv preprint arXiv:2404.12253 (2024).