Commit 2c4ba911 by zqh

upload files

.vscode/
__pycache__/
logs/
[submodule "mathlib4"]
	path = mathlib4
	url = https://github.com/xinhjBrant/mathlib4.git
	branch = deepseek
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->
<!-- markdownlint-disable no-duplicate-header -->
<div align="center">
<img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek-V2" />
</div>
<hr>
<div align="center" style="line-height: 1;">
<a href="https://www.deepseek.com/" target="_blank" style="margin: 2px;">
<img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://chat.deepseek.com/" target="_blank" style="margin: 2px;">
<img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-DeepSeek%20V2-536af5?color=536af5&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://huggingface.co/deepseek-ai" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>
<div align="center" style="line-height: 1;">
<a href="https://discord.gg/Tc7c45Zzu5" target="_blank" style="margin: 2px;">
<img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/qr.jpeg?raw=true" target="_blank" style="margin: 2px;">
<img alt="Wechat" src="https://img.shields.io/badge/WeChat-DeepSeek%20AI-brightgreen?logo=wechat&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://twitter.com/deepseek_ai" target="_blank" style="margin: 2px;">
<img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>
<div align="center" style="line-height: 1;">
<a href="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/LICENSE-CODE" style="margin: 2px;">
<img alt="Code License" src="https://img.shields.io/badge/Code_License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/LICENSE-MODEL" style="margin: 2px;">
<img alt="Model License" src="https://img.shields.io/badge/Model_License-Model_Agreement-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>
<p align="center">
<a href="#2-evaluation-results">Evaluation Results</a> |
<a href="#3-model-downloads">Model Download</a> |
<a href="#4-setup-environment">Setup Environment</a> |
<a href="#5-quick-start">Quick Start</a> |
<a href="#6-questions-and-bugs">Questions and Bugs</a> |
<a href="#7-license">License</a> |
<a href="#8-citation">Citation</a> |
<a href="#9-contact">Contact</a>
</p>
<p align="center">
<a href="https://arxiv.org/abs/2408.08152"><b>Paper Link</b>👁️</a>
</p>
# DeepSeek-Prover-V1.5: Harnessing Proof Assistant Feedback for Reinforcement Learning and Monte-Carlo Tree Search
## 1. Introduction
We introduce DeepSeek-Prover-V1.5, an open-source language model designed for theorem proving in Lean 4, which enhances DeepSeek-Prover-V1 by optimizing both training and inference processes. Pre-trained on DeepSeekMath-Base with specialization in formal mathematical languages, the model undergoes supervised fine-tuning using an enhanced formal theorem proving dataset derived from DeepSeek-Prover-V1. Further refinement is achieved through reinforcement learning from proof assistant feedback (RLPAF). Beyond the single-pass whole-proof generation approach of DeepSeek-Prover-V1, we propose RMaxTS, a variant of Monte-Carlo tree search that employs an intrinsic-reward-driven exploration strategy to generate diverse proof paths. DeepSeek-Prover-V1.5 demonstrates significant improvements over DeepSeek-Prover-V1, achieving new state-of-the-art results on the test set of the high school level miniF2F benchmark (63.5%) and the undergraduate level ProofNet benchmark (25.3%).
<p align="center">
<img width="100%" src="figures/performance.png">
</p>
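As a rough illustration of the intrinsic-reward idea behind RMaxTS (a sketch only, not the paper's exact algorithm), the search can grant a reward of 1 the first time a rollout reaches a previously unseen tactic state and 0 afterwards, so that a UCB-style selection rule is biased toward unexplored proof paths:

```python
import math

class IntrinsicRewardTree:
    """Toy sketch of RMax-style intrinsic rewards: a rollout earns a reward
    of 1 the first time it discovers a tactic state, 0 on every revisit,
    which steers UCB selection toward unexplored branches.
    (Illustrative only -- not the repo's actual RMaxTS implementation.)"""

    def __init__(self, exploration_c=2.0):
        self.seen_states = set()
        self.visits = {}   # action -> visit count
        self.returns = {}  # action -> cumulative intrinsic reward
        self.c = exploration_c

    def intrinsic_reward(self, state):
        # Reward novelty: 1.0 for a never-seen tactic state, else 0.0.
        if state in self.seen_states:
            return 0.0
        self.seen_states.add(state)
        return 1.0

    def update(self, action, reward):
        self.visits[action] = self.visits.get(action, 0) + 1
        self.returns[action] = self.returns.get(action, 0.0) + reward

    def select(self, actions):
        # Standard UCB over the accumulated intrinsic rewards.
        total = sum(self.visits.get(a, 0) for a in actions) + 1
        def ucb(a):
            n = self.visits.get(a, 0)
            if n == 0:
                return float('inf')  # always try unvisited actions first
            return self.returns[a] / n + self.c * math.sqrt(math.log(total) / n)
        return max(actions, key=ucb)

tree = IntrinsicRewardTree()
print(tree.intrinsic_reward('⊢ a + b = b + a'))  # 1.0: novel state
print(tree.intrinsic_reward('⊢ a + b = b + a'))  # 0.0: already seen
```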
## 2. Evaluation Results
<div align="center">
| | miniF2F-test | ProofNet |
|--------|------------------|------------------|
| **ReProver** | 26.5% | 13.8% |
| **GPT-f** | 36.6% | - |
| **Hypertree Proof Search** | 41.0% | - |
| **InternLM2-StepProver** | 54.5% | 18.1% |
| **DeepSeek-Prover-V1** | 50.0% | - |
| **DeepSeek-Prover-V1.5-Base** | 42.2% | 13.2% |
| **DeepSeek-Prover-V1.5-SFT** | 57.4% | 22.9% |
| **DeepSeek-Prover-V1.5-RL** | 60.2% | 22.6% |
| **DeepSeek-Prover-V1.5-RL + RMaxTS** | **63.5%** | **25.3%** |
</div>
## 3. Model Downloads
We release DeepSeek-Prover-V1.5 with 7B parameters, including the base, SFT, and RL models, to the public.
<div align="center">
| **Model** | **Download** |
| :-----------------------------: | :----------------------------------------------------------: |
| DeepSeek-Prover-V1.5-Base | [🤗 HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V1.5-Base) |
| DeepSeek-Prover-V1.5-SFT | [🤗 HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V1.5-SFT) |
| DeepSeek-Prover-V1.5-RL | [🤗 HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V1.5-RL) |
</div>
## 4. Setup Environment
### Requirements
* Supported platform: Linux
* Python 3.10
### Installation
1. **Install Lean 4**
Follow the instructions on the [Lean 4 installation page](https://leanprover.github.io/lean4/doc/quickstart.html) to set up Lean 4.
2. **Clone the repository**
```sh
git clone --recurse-submodules git@github.com:deepseek-ai/DeepSeek-Prover-V1.5.git
cd DeepSeek-Prover-V1.5
```
3. **Install dependencies**
```sh
pip install -r requirements.txt
```
4. **Build Mathlib4**
```sh
cd mathlib4
lake build
```
## 5. Quick Start
You can directly use [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference. A simple example of generating a proof for a problem from miniF2F and verifying it can be found in [quick_start.py](https://github.com/deepseek-ai/DeepSeek-Prover-V1.5/blob/master/quick_start.py).
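As a sketch of what such a generation request can look like, the snippet below assembles a whole-proof prompt for a miniF2F-style statement. The instruction wording and the helper name `build_prover_prompt` are illustrative assumptions, not the repo's exact template; see `quick_start.py` for the authoritative version.

```python
def build_prover_prompt(formal_statement, informal_prefix=''):
    """Assemble a whole-proof generation prompt for a Lean 4 theorem.
    The instruction wording here is an assumption modeled on the quick-start
    example, not a verbatim copy of the repo's template."""
    header = (
        "Complete the following Lean 4 code:\n\n"
        "```lean4\n"
    )
    # The model is expected to continue the proof after `:= by`.
    return header + informal_prefix + formal_statement

statement = (
    "theorem mathd_algebra_10 : "
    "abs ((120 : ℝ) / 100 * 30 - 130 / 100 * 20) = 10 := by\n"
)
print(build_prover_prompt(statement))
```

The resulting string can then be passed to any Hugging Face `generate` call; the generated continuation is concatenated back onto the statement and sent to the Lean verifier.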
To reproduce the paper's experiments, you can use the following script to launch an RMaxTS proof search agent:
```sh
python -m prover.launch --config=configs/RMaxTS.py --log_dir=logs/RMaxTS_results
```
You can use `CUDA_VISIBLE_DEVICES=0,1,···` to specify the GPU devices. The experiment results can be gathered using the following script:
```sh
python -m prover.summarize --config=configs/RMaxTS.py --log_dir=logs/RMaxTS_results
```
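The files passed via `--config` are ordinary Python modules whose top-level names serve as configuration fields. A minimal sketch of such a module-as-config loader, assuming behavior similar to the repo's `prover.utils.load_config` (the helper name `load_py_config` is hypothetical):

```python
import importlib.util

def load_py_config(path):
    """Load a Python file and expose its top-level names as a dict.
    A minimal stand-in for a loader like prover.utils.load_config (assumed)."""
    spec = importlib.util.spec_from_file_location('exp_config', path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Keep only public top-level names; dunders like __builtins__ are dropped.
    return {k: v for k, v in vars(module).items() if not k.startswith('_')}

# Example: write a tiny config file and load it back.
import tempfile, os
with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
    f.write('batch_size = 512\nlean_timeout = 300\n')
    cfg_path = f.name
cfg = load_py_config(cfg_path)
print(cfg['batch_size'], cfg['lean_timeout'])  # 512 300
os.unlink(cfg_path)
```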
## 6. Questions and Bugs
* For general questions and discussions, please use [GitHub Discussions](https://github.com/deepseek-ai/DeepSeek-Prover-V1.5/discussions).
* To report a potential bug, please open an issue.
## 7. License
This code repository is licensed under the MIT License. The use of DeepSeek-Prover-V1.5 models is subject to the Model License. DeepSeek-Prover-V1.5 supports commercial use.
See the [LICENSE-CODE](LICENSE-CODE) and [LICENSE-MODEL](LICENSE-MODEL) for more details.
## 8. Citation
```latex
@article{xin2024deepseekproverv15harnessingproofassistant,
title={DeepSeek-Prover-V1.5: Harnessing Proof Assistant Feedback for Reinforcement Learning and Monte-Carlo Tree Search},
author={Huajian Xin and Z. Z. Ren and Junxiao Song and Zhihong Shao and Wanjia Zhao and Haocheng Wang and Bo Liu and Liyue Zhang and Xuan Lu and Qiushi Du and Wenjun Gao and Qihao Zhu and Dejian Yang and Zhibin Gou and Z. F. Wu and Fuli Luo and Chong Ruan},
year={2024},
eprint={2408.08152},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2408.08152},
}
```
## 9. Contact
If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).
from prover.utils import AttrDict
from prover.algorithms import RMaxTS
# dataset
data_path = 'datasets/minif2f.jsonl'
data_split = 'test'
data_repeat = 16  # total budget: 16 * 6400 samples per problem
# verifier
lean_max_concurrent_requests = 64
lean_memory_limit = 10
lean_timeout = 300
# model
batch_size = 512
model_path = 'deepseek-ai/DeepSeek-Prover-V1.5-RL'
model_args = AttrDict(
    mode='cot',  # `cot` or `non-cot`
    temperature=1,
    max_tokens=2048,
    top_p=0.95,
)

# algorithm
n_search_procs = 256
sampler = dict(
    algorithm=RMaxTS,
    gamma=0.99,
    sample_num=6400,
    concurrent_num=32,
    tactic_state_comment=True,
    ckpt_interval=128,
    log_interval=32,
)
from prover.utils import AttrDict
from prover.algorithms import Sampling
# dataset
data_path = 'datasets/minif2f.jsonl'
data_split = ['valid', 'test']
data_repeat = 1
# verifier
lean_max_concurrent_requests = 64
lean_memory_limit = 10
lean_timeout = 300
# model
batch_size = 32
model_path = 'deepseek-ai/DeepSeek-Prover-V1.5-RL'
model_args = AttrDict(
    mode='cot',  # `cot` or `non-cot`
    temperature=1,
    max_tokens=2048,
    top_p=0.95,
)

# algorithm
n_search_procs = 64
sampler = dict(
    algorithm=Sampling,
    sample_num=128,
    log_interval=32,
)
from prover.utils import AttrDict
from prover.algorithms import Sampling
# dataset
data_path = 'datasets/minif2f.jsonl'
data_split = ['valid', 'test']
data_repeat = 1
# verifier
lean_max_concurrent_requests = 64
lean_memory_limit = 10
lean_timeout = 300
# model
batch_size = 32
model_path = 'deepseek-ai/DeepSeek-Prover-V1.5-Base'
model_args = AttrDict(
    mode='cot',  # `cot` or `non-cot`
    temperature=1,
    max_tokens=2048,
    top_p=0.95,
)

# algorithm
n_search_procs = 64
sampler = dict(
    algorithm=Sampling,
    sample_num=128,
    log_interval=32,
    few_shot_dataset='datasets/minif2f_valid_few_shot.jsonl',
    few_shot_num=3,
)
Subproject commit 2f65ba7f1a9144b20c8e7358513548e317d26de1
from .sampling import Sampling
from .rmax_tree_search import RMaxTS
import os
import numpy as np
from transformers import AutoTokenizer
from prover.utils import get_datetime, load_jsonl_objects, MODEL_FORMAT
class SamplingAlgorithmBase(object):
    def __init__(self, scheduler, tokenizer_path, process_print, cfg, **kwargs):
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.scheduler = scheduler
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.process_print = process_print
        self.cfg = cfg
        self.max_tokens = cfg.max_tokens
        self.few_shot_dataset = cfg.get('few_shot_dataset', None)
        if self.few_shot_dataset is not None:
            self.few_shot_dataset = load_jsonl_objects(self.few_shot_dataset)
        self.few_shot_num = cfg.get('few_shot_num', 3)
        self.few_shot_func = MODEL_FORMAT[cfg.mode]['few_shot']
        self.log_interval = cfg.get('log_interval', 32)

    @property
    def algorithm_name(self):
        return self.__class__.__name__

    def _post_sample_info(self, **kwargs):
        return dict(
            algorithm=self.algorithm_name,
            datetime=get_datetime(),
            **kwargs,
        )

    def _encode_length(self, code):
        return len(self.tokenizer.encode(code))

    def _preprocess_data(self, input_data):
        if self.few_shot_dataset is None or self.few_shot_num == 0:
            return input_data
        return {
            **input_data,
            '_extra_header': ''.join([
                self.few_shot_func(self.few_shot_dataset[idx])
                for idx in np.random.choice([
                    _idx for _idx, _data in enumerate(self.few_shot_dataset)
                    if _data['name'] != input_data['name']
                ], size=self.few_shot_num, replace=False)
            ] + [input_data.get('_extra_header', str())]),
        }

    def sample(self, **kwargs):
        raise NotImplementedError
from .base import SamplingAlgorithmBase
class Sampling(SamplingAlgorithmBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.sample_num = self.cfg.get('sample_num', 32)

    def sample(self, data, **kwargs):
        request_id_list = [
            self.scheduler.generator_submit_request(
                # add few-shot prompts
                self._preprocess_data(data),
            ) for _ in range(self.sample_num)
        ]
        for _idx, request_id in enumerate(request_id_list):
            outputs = self.scheduler.generator_get_request_outputs(request_id)
            yield outputs, self._post_sample_info(cost=_idx + 1)
            if _idx + 1 < self.sample_num and (_idx + 1) % self.log_interval == 0:
                self.process_print('Progress: {} / {}'.format(
                    _idx + 1, self.sample_num
                ))
import os
import copy
import time
import warnings
import argparse
import torch
from prover.workers import DataLoader, Scheduler, ProcessScheduler, GeneratorProcess, SearchProcess
from prover.lean.verifier import Lean4ServerScheduler
from prover.utils import get_datetime, load_config, AttrDict
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument("--log_dir", type=str, default=f'logs/{get_datetime()}')
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    args = parser.parse_args()

    cfg = load_config(args.config)
    os.makedirs(args.log_dir, exist_ok=True)
    ngpus = torch.cuda.device_count()
    assert ngpus >= 1

    # create data loader
    data_loader = DataLoader(
        data_path=cfg.data_path,
        data_split=cfg.get('data_split', None),
        data_repeat=cfg.get('data_repeat', 1),
        node_rank=args.node_rank,
        world_size=args.world_size,
        log_dir=args.log_dir,
    )

    # build Lean verifier
    verifier_scheduler = Lean4ServerScheduler(
        max_concurrent_requests=cfg.lean_max_concurrent_requests,
        memory_limit=cfg.lean_memory_limit,
        timeout=cfg.lean_timeout,
        name='verifier',
    )

    # load LLM models on gpus
    generator_scheduler = ProcessScheduler(batch_size=cfg.batch_size, name='generator')
    llm_processes = [
        GeneratorProcess(
            local_rank=local_rank,
            node_rank=args.node_rank,
            model_path=cfg.model_path,
            task_queue=generator_scheduler.task_queue,
            request_statuses=generator_scheduler.request_statuses,
            lock=generator_scheduler.lock,
            args=cfg.model_args,
        )
        for local_rank in range(ngpus)
    ]

    # create a unified scheduler interface
    scheduler = Scheduler(dict(
        verifier=verifier_scheduler,
        generator=generator_scheduler,
    ))

    # launch search processes
    search_processes = [
        SearchProcess(
            idx=i + args.node_rank * cfg.n_search_procs,
            log_dir=args.log_dir,
            tokenizer_path=cfg.model_path,
            scheduler=scheduler,
            data_loader=data_loader,
            cfg=cfg,
        )
        for i in range(min(cfg.n_search_procs, data_loader.size()))
    ]
    for p in search_processes:
        p.start()
    print(f'Complete launching {len(search_processes)} SearchProcesses')
    for p in llm_processes:
        p.start()
    print(f'Complete launching {len(llm_processes)} LLMProcesses')

    for p in search_processes:
        p.join()
    print(f'All {len(search_processes)} SearchProcesses stopped')
    scheduler.close()
    for p in llm_processes:
        p.join()
    print(f'All {len(llm_processes)} LLMProcesses stopped')
import re
import json
import numpy as np
from prover.utils import AttrDict, LEAN4_DEFAULT_HEADER
class Proof(object):
    def __init__(self, full_code, _args, _result_backup=None, **kwargs):
        self._kwargs_backup = kwargs
        for key, val in kwargs.items():
            self.__setattr__(key, val)
        self._args = _args
        self._update_full_code(full_code, _result_backup=_result_backup)

    @property
    def result(self):
        if self._verifier_request_id is not None:
            self._result = self._scheduler.verifier_get_request_outputs(self._verifier_request_id)
            self._verifier_request_id = None
        return self._result

    def is_result_ready(self):
        if self._verifier_request_id is None:
            return True
        status = self._scheduler.verifier_get_request_status(self._verifier_request_id)
        if status is not None:
            self._result = status
            self._verifier_request_id = None
        return self._result is not None

    @property
    def cleaned_code(self):
        return self.full_code[len(self.header) + len(self.formal_statement): len(self.full_code) - len(self.tailer)]

    def _update_full_code(self, full_code, _result_backup=None):
        self.full_code = full_code
        self._verifier_request_id, self._result = None, None
        if _result_backup is not None:
            self._result = _result_backup
        elif self._args.require_verification:  # need to call verification server
            self._verifier_request_id = self._scheduler.verifier_submit_request(dict(
                code=self.full_code,
                ast=True, tactics=True,
            ))
        self._parse_full_code_lines()

    def _parse_full_code_lines(self):
        self._full_code_lines = self.full_code.split('\n')
        self._line_offset, _offset = [], -1
        for _line in self._full_code_lines:
            _offset += 1  # '\n'
            self._line_offset.append(_offset)
            _offset += len(_line)

    def _get_idx(self, pos_info):
        return self._line_offset[pos_info['line'] - 1] + pos_info['column']

    def segmentation(self, result=None):
        if result is None:
            result = self.result
        if 'errors' not in result:
            # compiler timeout
            return []

        _prefix_len = len(self.header) + len(self.formal_statement)
        truncate_pos = len(self.full_code) - len(self.tailer)
        for info in result['sorries'] + result['errors']:
            info_pos = self._get_idx(info['pos'])
            if info_pos >= _prefix_len and not info.get('data', str()).lstrip().startswith('unsolved goals'):
                truncate_pos = min(truncate_pos, info_pos)
        partial_code = self.full_code[:truncate_pos]
        if len(partial_code) <= _prefix_len:
            # all proof lines are invalid
            return []

        code_lines = partial_code.split('\n')
        pos_last, segments = _prefix_len, []
        for line_idx in range(len(code_lines)):
            if self._line_offset[line_idx] >= _prefix_len:
                def compute_last_valid_char_pos(line):
                    idx, last_non_blank = 0, len(line) + 1
                    while idx < len(line):
                        if line[idx: idx+2] == '--':
                            return last_non_blank
                        elif line[idx: idx+2] == '/-':
                            if '-/' not in line[idx+2:]:
                                # cannot split in this line
                                return len(line) + 1
                            idx = line.find('-/', idx+2) + 1
                        elif line[idx] != ' ':
                            last_non_blank = idx
                        idx += 1
                    return last_non_blank
                line_lastChar = self._line_offset[line_idx] + compute_last_valid_char_pos(code_lines[line_idx])
                line_endPos = self._line_offset[line_idx] + len(code_lines[line_idx])

                pos_min, goal = 1e9, None
                for tactic_info in result['ast']['tactics']:
                    pos, endPos = tactic_info['pos'], tactic_info['endPos']
                    if line_lastChar <= endPos and endPos <= line_endPos and pos < pos_min:
                        pos_min = pos
                        goal = tactic_info['stateAfter']
                if goal is not None:
                    for tactic_info in result['ast']['tactics']:
                        pos, endPos = tactic_info['pos'], tactic_info['endPos']
                        if pos_last < endPos and endPos <= line_endPos and pos < pos_min:
                            pos_min = pos
                    while pos_min > 0 and partial_code[pos_min - 1] != '\n':
                        pos_min -= 1
                    indent_len = 0
                    while partial_code[pos_min + indent_len] == ' ':
                        indent_len += 1
                    newline_with_indent = '\n' + ' ' * indent_len
                    segments.append(AttrDict(
                        tactic_code=partial_code[pos_last: line_endPos] + '\n',
                        state_comment=newline_with_indent.join([
                            ' ' * indent_len + '/- tactic state:',
                            ' ' + goal.replace('\n', newline_with_indent + ' '),
                            '-/\n'
                        ]),
                        goal=goal,
                        indent=indent_len,
                    ))
                    pos_last = line_endPos + 1
        if result['complete'] and (len(segments) == 0 or segments[-1].goal != 'no goals' or segments[-1].indent != segments[0].indent):
            indent_len = 2 if len(segments) == 0 else segments[0].indent
            newline_with_indent = '\n' + ' ' * indent_len
            segments.append(AttrDict(
                tactic_code=partial_code[pos_last:].rstrip(' \n') + '\n',
                state_comment=newline_with_indent.join([
                    ' ' * indent_len + '/- tactic state:',
                    ' no goals',
                    '-/\n'
                ]),
                goal='no goals',
                indent=indent_len,
            ))
        segments = [seg for seg in segments if len(seg.tactic_code.strip(' \n')) > 0]
        return segments
class ProofSummarizer(object):
    def __init__(self, data, scheduler=None):
        """
        Inputs:
            data (`dict`): The problem information stored in a `dict` object.
                formal_statement (`str`): The formal statement of the unproved problem;
                header (`str`, *optional*, defaults to ''): The code header required by the compiler;
                tailer (`str`, *optional*, defaults to ''): The code tailer required by the compiler.
            scheduler (prover.workers.scheduler.Scheduler, *optional*, defaults to None):
                An interface to submit requests to models and the verification server.
                If set to None, the downstream tasks may require the verification result as inputs.
        """
        self.formal_statement = data['formal_statement']
        self.header = data.get('header', LEAN4_DEFAULT_HEADER)
        self.tailer = data.get('tailer', str())
        self.scheduler = scheduler

    def analyze(self, code, require_verification=True):
        """
        Inputs:
            code (`str`): The code of the formal proof.
            require_verification (`bool`, *optional*, defaults to True):
                Whether to submit a request to the verification server.
                If set to False, the downstream tasks may require the verification result as inputs.
        Returns:
            A `Proof` object that summarizes the code.
        """
        return Proof(
            full_code=''.join([self.header, self.formal_statement, code.rstrip(' \n'), self.tailer]),
            raw_code=code,
            formal_statement=self.formal_statement,
            header=self.header,
            tailer=self.tailer,
            _scheduler=self.scheduler,
            _args=AttrDict(
                require_verification=require_verification,
            ),
        )
import os
import time
import json
import ctypes
import resource
import tempfile
import traceback
import threading
import subprocess
import multiprocessing as mp
from pprint import pprint
import numpy as np
from prover.lean.ast_parser import lean4_parser
from prover.workers import ProcessScheduler
from prover.utils import AttrDict
HOME_DIR = os.path.expanduser('~')
DEFAULT_LAKE_PATH = f'{HOME_DIR}/.elan/bin/lake'
DEFAULT_LEAN_WORKSPACE = 'mathlib4/'
def verify_lean4_file(code, lake_path=DEFAULT_LAKE_PATH, lean_workspace=DEFAULT_LEAN_WORKSPACE, last_env=None, verbose=False, timeout=300, allTactics=False, ast=False, premises=False, tactics=False):
    command = dict(cmd=code, allTactics=allTactics, ast=ast, tactics=tactics, premises=premises)
    if last_env is not None:
        command.update(env=last_env)
    message_str = json.dumps(command, ensure_ascii=False)
    if verbose:
        print(message_str)
    start_time = time.time()
    system_messages = ''
    try:
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
            temp_file.write(message_str + "\r\n\r\n")
            temp_file.seek(0)
            outputs = subprocess.run([lake_path, "exe", 'repl'], stdin=temp_file, capture_output=True, text=True, cwd=lean_workspace, timeout=timeout)
        result = json.loads(outputs.stdout)
        ast_results = lean4_parser(code, result['ast']) if 'ast' in result and result['ast'] else {}
        result = {
            "sorries": result.get('sorries', []),
            "tactics": result.get('tactics', []),
            "errors": [m for m in result.get('messages', []) if m['severity'] == 'error'],
            "warnings": [m for m in result.get('messages', []) if m['severity'] == 'warning'],
            "infos": [m for m in result.get('messages', []) if m['severity'] == 'info'],
            "system_messages": system_messages,
            "system_errors": None,
            "ast": ast_results,
            "verified_code": code,
        }
        result['pass'] = not result['errors']
        result['complete'] = result['pass'] and not result['sorries'] and not any(
            "declaration uses 'sorry'" in warning['data'] or 'failed' in warning['data']
            for warning in result['warnings']
        )
    except:
        result = {
            "pass": False,
            "complete": False,
            "system_errors": traceback.format_exc(),
            "system_messages": system_messages,
        }
    result['verify_time'] = time.time() - start_time
    return result
class Lean4ServerProcess(mp.Process):
def __init__(self, idx, task_queue, request_statuses, lock, extra_args=AttrDict()):
super().__init__()
self.idx = idx
self.task_queue = task_queue
self.request_statuses = request_statuses
self.lock = lock
self.extra_args = extra_args
self.timeout = extra_args.get('timeout', 300)
self.memory_limit = extra_args.get('memory_limit', -1)
self.last_output_time = mp.Value(ctypes.c_double, time.time())
self.complete_count = mp.Value(ctypes.c_int, 0)
def run(self):
if self.memory_limit > 0:
resource.setrlimit(
resource.RLIMIT_AS,
(self.memory_limit * (1000 ** 3), self.memory_limit * (1000 ** 3))
)
while True:
inputs = self.task_queue.get()
if inputs is None: # Terminate when receiving None
break
for _, request_id, task in inputs:
if isinstance(task, str):
task = dict(code=task)
if 'timeout' not in task:
task['timeout'] = self.timeout
result = verify_lean4_file(**task)
if len(result['system_messages']) > 0:
retry_start_time = time.time()
while ('lean::exception: failed to create thread' in result['system_messages'] or
'std::bad_alloc: std::bad_alloc' in result['system_messages'] or
'Cannot allocate memory' in result['system_messages']) \
and time.time() - retry_start_time < self.timeout:
time.sleep(0.1)
result = verify_lean4_file(**task)
with self.lock:
self.request_statuses[request_id] = result
self.last_output_time.value = time.time()
self.complete_count.value += 1
class Lean4ServerScheduler(ProcessScheduler):
def __init__(self, max_concurrent_requests=64, timeout=300, memory_limit=-1, name='verifier'):
super().__init__(batch_size=1, name=name)
self.processes = [
Lean4ServerProcess(
idx=idx,
task_queue=self.task_queue,
request_statuses=self.request_statuses,
lock=self.lock,
extra_args=AttrDict(
timeout=timeout,
memory_limit=memory_limit,
)
)
for idx in range(max_concurrent_requests)
]
for p in self.processes:
p.start()
print(f'Launched {len(self.processes)} Lean4ServerProcess workers')
self.timeout = timeout
self._running_monitor = mp.Value(ctypes.c_bool, True)
self._last_complete_count = mp.Value(ctypes.c_int, 0)
self._monitor_process = mp.Process(target=self._monitor)
self._monitor_process.start()
def _monitor(self):
# watchdog: kill any `repl` process that has outlived the timeout (relies on the psmisc killall --older-than option)
while self._running_monitor.value:
time.sleep(1.0)
subprocess.run(['killall', 'repl', f'--older-than={int(self.timeout) + 10}s'], capture_output=True)
def close(self):
super().close()
for p in self.processes:
p.join()
self._running_monitor.value = False
self._monitor_process.join()
print(f'All {len(self.processes)} LeanServerProcesses stopped')
if __name__ == '__main__':
with open('mathlib4/.lake/packages/REPL/test/aime_1983_p9.code.in') as f:
code = f.read()
lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
request_id_list = lean4_scheduler.submit_all_request([dict(code=code, ast=True, tactics=True)])
outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
lean4_scheduler.close()
pprint(outputs_list)
import os
import argparse
import pandas as pd
from termcolor import colored
from prover.utils import get_datetime, load_config, load_jsonl_objects
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str)
parser.add_argument("--log_dir", type=str)
args = parser.parse_args()
cfg = load_config(args.config)
dataset = load_jsonl_objects(cfg.data_path)
log_dir_dict = {
os.path.basename(args.log_dir): args.log_dir,
}
for data in dataset:
data['success'] = dict()
for runname, log_dir in log_dir_dict.items():
for prob_idx, data in enumerate(dataset):
res_dir = os.path.join(log_dir, f'{prob_idx}_{dataset[prob_idx]["name"]}')
_success_flag = False
if os.path.exists(res_dir):
for filename in os.listdir(res_dir):
if filename.startswith('success'):
_success_flag = True
data['success'][runname] = _success_flag
def make_inner_list(info):
return {key: [val] for key, val in info.items()}
def add_color(info):
return {key: colored(val, 'cyan', attrs=['bold']) for key, val in info.items()} if info['prob_type'] == '<all>' else info
def aggregate(split, prob_type):
info = dict(split=split, prob_type=prob_type)
for runname in log_dir_dict:
success_count, total_count = 0, 0
for prob_idx, data in enumerate(dataset):
if data['split'] == split and (data['name'].startswith(prob_type) or prob_type == '<all>'):
total_count += 1
success_count += int(data['success'][runname])
info[runname] = '{:3d} / {:3d} = {:.3f}'.format(success_count, total_count, success_count / max(total_count, 1))  # guard against empty splits
return pd.DataFrame(make_inner_list(add_color(info)))
summary = pd.concat([
aggregate(split, '<all>')
for split in set([data['split'] for data in dataset])
])
print('DateTime:', get_datetime(readable=True))
print(summary.to_markdown(index=False, tablefmt="github", colalign=["left"] * 2 + ["right"] * len(log_dir_dict)))
import os
import json
import pytz
from pathlib import Path
from datetime import datetime
from collections import UserDict
from importlib.machinery import SourceFileLoader
from easydict import EasyDict as AttrDict
LEAN4_DEFAULT_HEADER = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"
def non_cot_prompt(data):
return "Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}".format(
header=data.get('header', LEAN4_DEFAULT_HEADER),
informal_prefix=data.get('informal_prefix', str()),
formal_statement=data['formal_statement'],
)
def non_cot_few_shot_prompt(data):
return "Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}{formal_proof}\n```\n\n\n".format(
header=data.get('header', LEAN4_DEFAULT_HEADER),
informal_prefix=data.get('informal_prefix', str()),
formal_statement=data['formal_statement'],
formal_proof=data['formal_proof'],
)
def cot_prompt(data):
return "Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}".format(
header=data.get('header', LEAN4_DEFAULT_HEADER),
informal_prefix=data.get('informal_prefix', str()),
formal_statement=data['formal_statement'],
)
def cot_few_shot_prompt(data):
return "Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}{formal_proof}\n```\n\n\n".format(
header=data.get('header', LEAN4_DEFAULT_HEADER),
informal_prefix=data.get('informal_prefix', str()),
formal_statement=data['formal_statement'],
formal_proof=data['formal_proof'],
)
def post_process_output(output):
_find_idx = output.find("```")
return output[:_find_idx] if _find_idx >= 0 else output
MODEL_FORMAT = dict(
non_cot=dict(prompt=non_cot_prompt, output=post_process_output, few_shot=non_cot_few_shot_prompt),
cot=dict(prompt=cot_prompt, output=post_process_output, few_shot=cot_few_shot_prompt),
)
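Both prompt modes share `post_process_output`, which truncates a completion at the first closing code fence so only the Lean proof body survives; a standalone sketch of that truncation:

```python
def truncate_at_fence(output: str) -> str:
    # Mirror of post_process_output above: keep only the text before the
    # first ``` fence, or the whole string if no fence is present.
    idx = output.find("```")
    return output[:idx] if idx >= 0 else output

sample = "  nlinarith\n```\n\nThis completes the proof."
print(repr(truncate_at_fence(sample)))  # '  nlinarith\n'
```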
def get_datetime(readable=False):
if readable:
return datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y/%m/%d %H:%M:%S")
return datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y%m%d_%H%M%S")
def load_config(fname):
name = Path(fname).stem
mod = SourceFileLoader(name, fname).load_module()  # note: load_module() is deprecated in favor of importlib.util; kept here for simplicity
config = {}
for n in dir(mod):
if not n.startswith("__"):
config[n] = getattr(mod, n)
config = AttrDict(config)
return config
def load_jsonl_objects(input_path):
objects = []
with open(input_path, 'r', encoding='utf-8') as fr:
for line in fr:
objects.append(json.loads(line))
return objects
# Multi-stage non-blocking job: stage 0 launches the work; later stages are polled via get_status(), which returns None while pending and the final status once every stage has finished.
class ConcurrentJob(object):
def __init__(self, stage_list):
assert len(stage_list) > 1
self.stage_list = stage_list
self.reset()
def is_idle(self):
return self._stage_idx is None
def reset(self):
self._stage_idx = None
self._stage_cache = None
def start(self, **kwargs):
self._stage_idx = 1
self._stage_cache = self.stage_list[0](**kwargs)
def get_status(self):
assert not self.is_idle()
while True:
status = self.stage_list[self._stage_idx](**self._stage_cache)
if status is None:
return None
self._stage_idx += 1
if self._stage_idx == len(self.stage_list):
self.reset()
return status
self._stage_cache = status
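`ConcurrentJob` chains stage callables: each stage receives the previous stage's output as kwargs, returns `None` while still pending, and returns a kwargs dict for the next stage once done. A self-contained toy run of the same protocol (the `submit`/`poll` stages are made up for illustration):

```python
class MiniConcurrentJob:
    """Minimal mirror of the ConcurrentJob protocol above, not the real class."""
    def __init__(self, stage_list):
        assert len(stage_list) > 1
        self.stage_list = stage_list
        self._stage_idx = None
        self._stage_cache = None
    def is_idle(self):
        return self._stage_idx is None
    def start(self, **kwargs):
        # stage 0 launches the work and returns kwargs for stage 1
        self._stage_idx = 1
        self._stage_cache = self.stage_list[0](**kwargs)
    def get_status(self):
        while True:
            status = self.stage_list[self._stage_idx](**self._stage_cache)
            if status is None:      # current stage still pending
                return None
            self._stage_idx += 1
            if self._stage_idx == len(self.stage_list):
                self._stage_idx = self._stage_cache = None
                return status       # all stages done; job is idle again
            self._stage_cache = status

# Toy stages: 'submit' records the input; 'poll' succeeds on the third check.
polls = {'count': 0}
def submit(x):
    return dict(x=x)
def poll(x):
    polls['count'] += 1
    return None if polls['count'] < 3 else dict(result=x * 2)

job = MiniConcurrentJob([submit, poll])
job.start(x=21)
out = None
while out is None:
    out = job.get_status()
print(out['result'], job.is_idle())  # 42 True
```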
from .data_loader import DataLoader
from .scheduler import Scheduler, ProcessScheduler
from .search import SearchProcess
from .generator import GeneratorProcess
import os
import copy
import torch
import torch.multiprocessing as mp
from prover.utils import load_jsonl_objects
# Shards the problem set across nodes and serves (prob_idx, prob_runname, problem) tuples from a shared queue, skipping runs already marked finished in log_dir.
class DataLoader(object):
def __init__(self, data_path, data_split, data_repeat, node_rank, world_size, log_dir):
self.manager = mp.Manager()
self.queue = self.manager.Queue()
self.lock = mp.Lock()
self.finished_flag_filename = 'finished_running.txt'
done_set = set()
for dirname in os.listdir(log_dir):
run_dir = os.path.join(log_dir, dirname)
if os.path.isdir(run_dir):
for subdirname in os.listdir(run_dir):
if subdirname.startswith('run') and os.path.exists(os.path.join(run_dir, subdirname, self.finished_flag_filename)):
done_set.add(os.path.join(dirname, subdirname))
todo_count = 0
if isinstance(data_split, str):
data_split = [data_split]
dataset = load_jsonl_objects(data_path)
for _repeat in range(data_repeat):
for prob_idx, prob in enumerate(dataset):
prob_runname = os.path.join(prob['name'], f'run{_repeat}')
if f'{prob_idx}_{prob_runname}' in done_set:
continue
if data_split is not None and prob['split'] not in data_split:
continue
if todo_count % world_size == node_rank:
self.queue.put((prob_idx, prob_runname, copy.deepcopy(prob)))
todo_count += 1
print('Number of TODO Problems: {}'.format(self.queue.qsize()))
def size(self):
return self.queue.qsize()
def get(self):
with self.lock:
if self.queue.qsize() > 0:
return self.queue.get()
return None, None, None
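The `todo_count % world_size == node_rank` check above is a round-robin shard over the eligible problems: each node takes every `world_size`-th pending item, offset by its rank. In miniature (the done/split skip filters are omitted):

```python
def shard(problems, node_rank, world_size):
    # Every world_size-th eligible problem goes to this node, as in
    # DataLoader.__init__ above.
    picked = []
    for todo_count, prob in enumerate(problems):
        if todo_count % world_size == node_rank:
            picked.append(prob)
    return picked

print(shard(list(range(10)), node_rank=1, world_size=4))  # [1, 5, 9]
```

Across all ranks the shards partition the work, so no problem is attempted twice.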
import os
import time
import torch
import torch.multiprocessing as mp
from vllm import LLM, SamplingParams
from prover.utils import AttrDict, MODEL_FORMAT
# Worker process hosting one vLLM instance: batches prompts from the task queue and writes post-processed generations back keyed by request id.
class GeneratorProcess(mp.Process):
def __init__(self, local_rank, node_rank, model_path, task_queue, request_statuses, lock, args):
super().__init__()
self.local_rank = local_rank
self.node_rank = node_rank
self.model_path = model_path
self.task_queue = task_queue
self.request_statuses = request_statuses
self.lock = lock
self.sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
top_p=args.top_p,
n=1,
)
self.prompt_func = MODEL_FORMAT[args.mode]['prompt']
self.output_func = MODEL_FORMAT[args.mode]['output']
def run(self):
seed = int(time.time()) % 1000 + (self.node_rank * 8 + self.local_rank) * 1000
os.environ['LOCAL_RANK'] = str(self.local_rank)
llm = LLM(model=self.model_path, max_num_batched_tokens=8192, seed=seed, trust_remote_code=True)
while True:
inputs = self.task_queue.get()
if inputs is None: # Terminate when receiving None
break
model_inputs = [
''.join([
item.get('_extra_header', str()),
self.prompt_func(item),
item.get('_extra_prompt', str()),
]) for _, _, item in inputs
]
model_outputs = llm.generate(
model_inputs,
self.sampling_params,
use_tqdm=False,
)
outputs = [self.output_func(_output.outputs[0].text) for _output in model_outputs]
with self.lock:
for (_, request_id, _), output in zip(inputs, outputs):
self.request_statuses[request_id] = output
import os
import time
import ctypes
import subprocess
import threading
import multiprocessing as mp
import numpy as np
from prover.utils import AttrDict
# Multiprocessing-safe FIFO that hands out tasks in batches of up to batch_size and logs throughput once a minute.
class TaskQueue(object):
def __init__(self, batch_size=512, name='test'):
self.name = name
self.batch_size = batch_size
self.manager = mp.Manager()
self.waiting_list = self.manager.list()
self.all_tasks_done = mp.Event()
self.lock = mp.Lock()
self._monitor_log = self.manager.list()
self._monitor_thread = threading.Thread(target=self._monitor)
self._monitor_thread.start()
def _monitor(self):
last_log_time = time.time()
while not self.all_tasks_done.is_set():
if time.time() - last_log_time >= 60.0:
with self.lock:
if len(self._monitor_log) > 0:
print('TaskQueue-{}: {} requests popped with avg batch_size {:.1f} in last period; {} waiting in queue'.format(
self.name, np.sum(self._monitor_log), np.mean(self._monitor_log), len(self.waiting_list),
))
self._monitor_log[:] = []
last_log_time = time.time()
time.sleep(1.0)
def __len__(self):
return len(self.waiting_list)
def put(self, item):
with self.lock:
self.waiting_list.append(item)
def get(self, no_wait=False):
while not self.all_tasks_done.is_set():
with self.lock:
if len(self.waiting_list) > 0:
tasks = self.waiting_list[:self.batch_size]
self.waiting_list[:self.batch_size] = []
self._monitor_log.append(len(tasks))
return tasks
if no_wait:
break
time.sleep(0.1)
return None
def close(self):
self.all_tasks_done.set()
self._monitor_thread.join()
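`TaskQueue.get` drains up to `batch_size` items per call by slicing and then deleting from the front of the shared list while holding the lock; the same batching in miniature:

```python
# Standalone sketch of the batched pop in TaskQueue.get above.
waiting = list(range(7))
batch_size, batches = 3, []
while waiting:
    tasks = waiting[:batch_size]   # take up to batch_size items from the front...
    waiting[:batch_size] = []      # ...and remove them in one slice assignment
    batches.append(tasks)
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```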
# Request/response layer over TaskQueue: submit_request returns an id, workers post results into request_statuses, and get_request_outputs polls until the result arrives.
class ProcessScheduler(object):
def __init__(self, batch_size=512, name='test'):
self.name = name
self.manager = mp.Manager()
self.batch_size = batch_size
self.task_queue = TaskQueue(batch_size=batch_size, name=name)
self.request_statuses = self.manager.dict()
self.request_counter = mp.Value(ctypes.c_int32, 0)
self.lock = mp.Lock()
def submit_request(self, data):
with self.lock:
self.request_counter.value += 1
request_id = self.request_counter.value
self.request_statuses[request_id] = None
self.task_queue.put((time.time(), request_id, data))
return request_id
def submit_all_request(self, data_list):
request_id_list = [self.submit_request(data) for data in data_list]
return request_id_list
def get_request_status(self, request_id):
with self.lock:
response = self.request_statuses.get(request_id, None)
if response is not None:
self.request_statuses.pop(request_id)
return response
def get_request_outputs(self, request_id):
while True:
outputs = self.get_request_status(request_id)
if outputs is not None:
return outputs
time.sleep(1.0)
def get_all_request_outputs(self, request_id_list):
outputs_list = []
for request_id in request_id_list:
outputs_list.append(self.get_request_outputs(request_id))
return outputs_list
def close(self):
self.task_queue.close()
# Aggregates named sub-schedulers, re-exporting each public method as '<name>_<method>' (e.g. verifier_submit_request).
class Scheduler(object):
def __init__(self, scheduler_dict):
self._scheduler_dict = scheduler_dict
for name, scheduler in scheduler_dict.items():
self.__setattr__(name, scheduler)
for key in dir(scheduler):
if not key.startswith('_'):
self.__setattr__(f'{name}_{key}', scheduler.__getattribute__(key))
def close(self):
for _, scheduler in self._scheduler_dict.items():
scheduler.close()
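`Scheduler` flattens its sub-schedulers' public methods onto itself, which is what makes calls like `scheduler.verifier_submit_request(...)` in `SearchProcess` resolve. A toy sketch of the same delegation (`DummyVerifier` is illustrative):

```python
class MiniScheduler:
    # Mirror of the delegation in Scheduler above: each sub-scheduler's
    # public attribute is re-exported as '<name>_<attribute>'.
    def __init__(self, scheduler_dict):
        self._scheduler_dict = scheduler_dict
        for name, scheduler in scheduler_dict.items():
            setattr(self, name, scheduler)
            for key in dir(scheduler):
                if not key.startswith('_'):
                    setattr(self, f'{name}_{key}', getattr(scheduler, key))

class DummyVerifier:
    def submit_request(self, data):
        return f'submitted {data}'

sched = MiniScheduler(dict(verifier=DummyVerifier()))
print(sched.verifier_submit_request('theorem foo'))  # submitted theorem foo
```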
import os
import time
import copy
import json
import pickle
from pathlib import Path
import torch
import torch.multiprocessing as mp
import numpy as np
from prover.utils import AttrDict, get_datetime
# Driver process: draws problems from the DataLoader, samples candidate proofs via the configured sampler, submits them for Lean verification, and logs success/failure summaries.
class SearchProcess(mp.Process):
def __init__(self, idx, log_dir, tokenizer_path, scheduler, data_loader, cfg):
self.idx = idx
self.log_dir = Path(log_dir)
self.scheduler = scheduler
self.data_loader = data_loader
super().__init__()
self._current_prob = None
sampler_cls = cfg.sampler['algorithm']
self.sampler = sampler_cls(
scheduler=self.scheduler,
tokenizer_path=tokenizer_path,
process_print=self.process_print,
cfg=AttrDict({
**cfg.sampler,
'mode': cfg.model_args.mode,
'max_tokens': cfg.model_args.max_tokens,
})
)
def _post_process(self, data: dict, proof_code: str):
header = data.get('header', str())
tailer = data.get('tailer', str())
formal_statement = data['formal_statement']
return dict(
statement_proposal=f'{header}{formal_statement}{proof_code}{tailer}',
proof_code=proof_code,
)
def process_print(self, logs, **kwargs):
print('Process ID: {:3d} Problem ID: {} {}'.format(self.idx, self._current_prob, logs), **kwargs)
def run(self):
while True:
prob_idx, prob_runname, data = self.data_loader.get()
if prob_idx is None: break
sample_start_time = time.time()
# build a yield-iterator object to generate samples
self._current_prob = f'{prob_idx}_{prob_runname}'
prob_log_dir = self.log_dir / self._current_prob
os.makedirs(prob_log_dir, exist_ok=True)
sample_generator = self.sampler.sample(
data=data,
prob_log_dir=prob_log_dir,
)
# submit requests to the verification server when receiving from the generator
candidate_list, info_list, request_id_list = [], [], []
for sample, info in sample_generator:
candidate = self._post_process(data, sample)
candidate_list.append(candidate)
info_list.append(copy.deepcopy(info))
request_id = self.scheduler.verifier_submit_request(candidate['statement_proposal'])
request_id_list.append(request_id)
sample_timecost = time.time() - sample_start_time
verification_start_wait_time = time.time()
result_list = self.scheduler.verifier_get_all_request_outputs(request_id_list)
verification_timecost = time.time() - verification_start_wait_time
success_count = sum([int(result['complete']) for result in result_list])
self.process_print('Success: {} / {} Generation: {:.2f} secs Verification: {:.2f} secs'.format(
success_count, len(candidate_list), sample_timecost, verification_timecost,
))
summary_dict = dict(success=[], failure=[])
for _idx, (candidate, result, info) in enumerate(zip(candidate_list, result_list, info_list)):
success_flag = 'success' if result['complete'] else 'failure'
summary_dict[success_flag].append(dict(
problem_name=data['name'],
sample_info=info,
formal_statement=data['formal_statement'],
proof_code=candidate['proof_code'],
result=result,
))
prob_name, run_id = prob_runname.split('/')
prob_log_basedir = self.log_dir / f'{prob_idx}_{data["name"]}'
log_tag = f'{self.sampler.algorithm_name}-{run_id}'
# separately save success and failure results
for success_flag, summary_list in summary_dict.items():
if len(summary_list) > 0:
with open(prob_log_basedir / f'{success_flag}-{log_tag}-{get_datetime()}.pkl', 'wb') as pkl_f:
pickle.dump(summary_list, pkl_f)
# create a 'finished' placeholder
with open(prob_log_dir / self.data_loader.finished_flag_filename, 'w') as f:
print('finished', file=f)
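`_post_process` builds the full program sent to the verifier by sandwiching the generated proof between the problem's header/statement and an optional tailer; a standalone sketch of that assembly:

```python
def assemble_proposal(data, proof_code):
    # Mirror of SearchProcess._post_process above: the verifier receives
    # header + formal statement + generated proof + optional tailer.
    header = data.get('header', '')
    tailer = data.get('tailer', '')
    return f"{header}{data['formal_statement']}{proof_code}{tailer}"

data = dict(header='import Mathlib\n', formal_statement='theorem t : 1 = 1 := by\n')
print(assemble_proposal(data, '  rfl'))
```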
import re
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from prover.lean.verifier import Lean4ServerScheduler
model_name = "deepseek-ai/DeepSeek-Prover-V1.5-RL"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LLM(model=model_name, max_num_batched_tokens=8192, seed=1, trust_remote_code=True)
lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
prompt = r'''Complete the following Lean 4 code:
```lean4
'''
code_prefix = r'''import Mathlib
import Aesop
set_option maxHeartbeats 0
open BigOperators Real Nat Topology Rat
/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
(h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by
'''
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=2048,
top_p=0.95,
n=1,
)
model_inputs = [prompt + code_prefix]
model_outputs = model.generate(
model_inputs,
sampling_params,
use_tqdm=False,
)
result = prompt + code_prefix + model_outputs[0].outputs[0].text
print(result)
# Expected output:
''' simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.add_succ, Nat.add_zero,
Nat.succ_add]
have h₁' : a * r = 2 := by simpa [h₀] using h₁
have h₂' : a * r ^ 3 = 6 := by simpa [h₀] using h₂
have h₃ : r ^ 2 = 3 := by
nlinarith
have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by
apply eq_or_eq_neg_of_sq_eq_sq <;>
field_simp <;>
nlinarith
simpa [h₀] using h₄
```
'''
request_id_list = lean4_scheduler.submit_all_request([re.search(r'```lean4\n(.*?)\n```', result, re.DOTALL).group(1)])
outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
print(outputs_list[0])
# Expected output (verify_time may vary):
'''{'sorries': [], 'tactics': [], 'errors': [], 'warnings': [{'severity': 'warning', 'pos': {'line': 14, 'column': 7}, 'endPos': {'line': 14, 'column': 10}, 'data': "unused variable `h₁'`\nnote: this linter can be disabled with `set_option linter.unusedVariables false`"}, {'severity': 'warning', 'pos': {'line': 15, 'column': 7}, 'endPos': {'line': 15, 'column': 10}, 'data': "unused variable `h₂'`\nnote: this linter can be disabled with `set_option linter.unusedVariables false`"}, {'severity': 'warning', 'pos': {'line': 19, 'column': 35}, 'endPos': {'line': 19, 'column': 38}, 'data': 'Used `tac1 <;> tac2` where `(tac1; tac2)` would suffice\nnote: this linter can be disabled with `set_option linter.unnecessarySeqFocus false`'}, {'severity': 'warning', 'pos': {'line': 20, 'column': 15}, 'endPos': {'line': 20, 'column': 18}, 'data': 'Used `tac1 <;> tac2` where `(tac1; tac2)` would suffice\nnote: this linter can be disabled with `set_option linter.unnecessarySeqFocus false`'}], 'infos': [], 'system_messages': '', 'system_errors': None, 'ast': {}, 'verified_code': "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?\nShow that it is $\x0crac{2\\sqrt{3}}{3}$.-/\ntheorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)\n (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by\n simp_all only [Nat.one_eq_succ_zero, Nat.zero_eq, zero_add, Nat.add_succ, Nat.add_zero,\n Nat.succ_add]\n have h₁' : a * r = 2 := by simpa [h₀] using h₁\n have h₂' : a * r ^ 3 = 6 := by simpa [h₀] using h₂\n have h₃ : r ^ 2 = 3 := by\n nlinarith\n have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by\n apply eq_or_eq_neg_of_sq_eq_sq <;>\n field_simp <;>\n nlinarith\n simpa [h₀] using h₄", 'pass': True, 'complete': True, 'verify_time': 23.28123140335083}'''
lean4_scheduler.close()
pytz==2022.1
easydict==1.13
torch==2.2.1
transformers==4.40.1
vllm==0.4.1
numpy==1.26.4
pandas==1.4.3
tabulate==0.9.0
termcolor==2.4.0
accelerate==0.33.0
flash_attn==2.6.3