1. 07 Mar, 2025 3 commits
  2. 06 Mar, 2025 6 commits
  3. 05 Mar, 2025 5 commits
  4. 04 Mar, 2025 5 commits
  5. 03 Mar, 2025 5 commits
    • [feat] Initial support for VLMs, add Qwen2.5VL GRPO example (#386) · b46f55ec
      ## What does this PR do?
      
      This PR migrates the RL-on-VLMs feature from our implementation in
      the [EasyR1](https://github.com/hiyouga/EasyR1) fork back to veRL.
      We have validated this feature with the Qwen2.5-VL 7B model on
      8×H100 GPUs. The configuration and data preprocessing script are
      provided with this PR for easy reproduction.
      
      ## How to reproduce?
      
      1. Download and preprocess the dataset
      
      ```bash
      python3 examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k
      ```
      
      2. Start GRPO training
      
      ```bash
      bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh
      ```
      
      ## Dependencies
      
      - vllm>=0.7.3
      - transformers>=4.49.0
      - [qwen-vl-utils](https://pypi.org/project/qwen-vl-utils/)
      - [mathruler](https://pypi.org/project/mathruler/)
      
      ## Major Changes
      
      ### New dataflow for multimodal RL
      
      In this PR, we introduce two new concepts in the dataflow:
      `multi_modal_data` and `multi_modal_inputs`. The former holds the
      multi-modal features required by the **rollout** worker (such as
      vLLM), while the latter holds the multi-modal features required by
      the **actor/critic** worker (such as an HF model). They differ
      because the rollout and actor workers have their own data-format
      requirements.
      
      Taking Qwen2-VL + Hugging Face + vLLM as an example, the data
      structures are:
      
      - **multi_modal_data**: {"image": [PIL.Image, PIL.Image, ...]}
      - **multi_modal_inputs**: {"pixel_values": torch.Tensor,
      "image_grid_thw": torch.Tensor}
      
      Both of them are converted to numpy objects and placed in the non-tensor
      batch in DataProto.
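
      For concreteness, here is a minimal sketch (not the exact veRL
      code; the processor call is indicative only) of how these two
      fields can be built and stored as numpy object arrays in the
      non-tensor batch:

      ```python
      import numpy as np
      from PIL import Image

      image = Image.new("RGB", (224, 224))  # placeholder image

      # Consumed by the rollout worker (e.g. vLLM):
      multi_modal_data = {"image": [image]}

      # Consumed by the actor/critic worker (an HF model). For Qwen2-VL the
      # HF processor returns tensors such as pixel_values and image_grid_thw:
      # multi_modal_inputs = dict(processor(images=[image], text=prompt, return_tensors="pt"))

      # Non-tensor entries are carried per sample as dtype=object numpy arrays:
      non_tensor_batch = {
          "multi_modal_data": np.array([multi_modal_data], dtype=object),
      }
      ```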
      
      This design extends easily to other modalities and VLMs because it
      is model-agnostic.
      
      ### Other changes
      
      - Data
        - Support pre-processing the
          [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
          dataset.
        - Support `config.data.image_key`, which should point to **a list
          of Pillow images** (see the dataset-row sketch after this list).
      
      - Actor/Ref/Critic
        - Support `multi_modal_inputs`.
        - Process position ids to adapt to the m-rope.
      
      - Rollout
        - Update the dtensor weight loader to adapt to the Qwen2-VL
          architecture in vLLM 0.7+.
        - Support `multi_modal_data`.
        - Use `raw_prompt_ids` as the vLLM inputs to **avoid unpadding**
          the input ids.
      
      - Reward Manager
        - Add **mathruler** for more accurate math scores on the
          Geometry3k dataset.
      
      - Models
        - Support calculating the position ids for the m-rope in Qwen2-VL.
        - Support removing padding in flash attention 2 for the m-rope
          (transformers itself **does not support it**).
      
      - Sharding Manager
        - Support all-gathering the non-tensor batch.
      
      - FSDP Workers / Checkpoint Merger
        - Support `AutoModelForVision2Seq` at model initialization (see
          the sketch below).
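
      As a rough illustration of the model-initialization item above,
      here is a hedged sketch of loading a VLM through
      `AutoModelForVision2Seq` (the actual veRL logic may differ; the
      causal-LM fallback is my assumption):

      ```python
      from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

      model_path = "Qwen/Qwen2.5-VL-7B-Instruct"  # example checkpoint
      config = AutoConfig.from_pretrained(model_path)
      try:
          # VLMs such as Qwen2-VL are registered under the Vision2Seq auto class.
          model = AutoModelForVision2Seq.from_pretrained(model_path, config=config)
      except ValueError:
          # Text-only checkpoints fall back to the causal-LM auto class.
          model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
      ```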
      
      Note: Ulysses parallelism is not complete yet. We will support it
      in the next update.
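
      To make the `config.data.image_key` item above concrete, here is a
      hedged sketch of one preprocessed dataset row (the key names,
      including `images`, are illustrative; the image column is whatever
      `config.data.image_key` names):

      ```python
      from PIL import Image

      # One row of the preprocessed dataset; config.data.image_key names the
      # column holding the list of Pillow images ("images" in this sketch).
      row = {
          "prompt": "How large is the shaded region in the figure?",
          "images": [Image.new("RGB", (448, 448))],  # a list of PIL images
          "ground_truth": "16",  # consumed by the reward function
      }
      ```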
      
      ## Performance
      
      We provide the estimated MFU of the language-model part on H100
      GPUs. These values understate the actual utilization because **we
      did not count the FLOPs of the vision tower**.
      
      - `remove_padding=False`: MFU ~7%
      - `remove_padding=True`: MFU ~20%
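
      For readers unfamiliar with the metric, here is a back-of-the-envelope
      sketch of how such an MFU estimate is commonly derived (the standard
      6·N·T training-FLOPs approximation; veRL's exact accounting may differ,
      and the throughput below is purely illustrative):

      ```python
      def estimate_mfu(n_params: float, tokens_per_sec: float, n_gpus: int,
                       peak_flops_per_gpu: float = 989e12) -> float:
          """MFU = achieved training FLOPs/s over aggregate peak FLOPs/s.

          Uses the common 6 * N * T approximation for forward+backward FLOPs
          of a dense transformer; 989e12 is the H100 SXM BF16 dense peak.
          """
          achieved = 6 * n_params * tokens_per_sec
          return achieved / (n_gpus * peak_flops_per_gpu)

      # Hypothetical throughput for a 7B model on 8 GPUs (illustrative only):
      print(f"{estimate_mfu(7e9, 2.6e4, 8):.1%}")  # ~13.8%
      ```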
      
      The training and test reward score curves are shown below.
      
      
      ![image](https://github.com/user-attachments/assets/ecb9fc27-8591-4c5b-ae4b-4ba77c6e30f9)
      
      ## Who can review?
      
      @vermouth1992 @PeterSH6
      hoshi-hiyouga committed
    • [fix] update yaml file for generation (#445) · a0a4d5fa
      Forgot to update params in generation.yaml (#259).
      BearBiscuit committed
    • megatron: Update megatron-lm to `core_r0.11.0` (#392) · 0cfd548c
      # Support Megatron mcore 0.11
      
      ## Description
      This PR introduces official support for Megatron mcore 0.11 with the
      following updates:
      - Upgraded Megatron to version `core_r0.11.0`
      - Applied compatibility patch `patches/mcore_r0.11.patch`
      - Removed legacy version support for cleaner implementation
      
      Special thanks to @chendong-1998 for:
      - Original Megatron upgrade from 0.4 to 0.6 (#93f6a7e)
      
      ## Compatibility Notes
      The current implementation requires careful handling due to
      dependency conflicts:
      - `megatron-core==0.11.0` requires torch>=2.6
      - `vllm==0.6.3` requires torch==2.4
      
      Installation constraints:
      1. Use vllm's torch dependency (2.4) as the baseline.
      2. Do NOT run `pip install -e .` in the mcore directory (doing so
      upgrades torch to 2.6).
      3. Apply the compatibility patch (`patches/mcore_r0.11.patch`)
      manually after installation.
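
      As a small aid for constraint 2, here is a sanity check one could
      run after installation (my own suggestion, not part of this PR) to
      catch a silently upgraded torch:

      ```python
      import torch
      from packaging.version import Version  # packaging ships with pip/setuptools

      # vllm==0.6.3 pins torch 2.4; megatron-core==0.11.0 wants torch>=2.6.
      assert Version(torch.__version__).release[:2] == (2, 4), (
          f"expected the vllm-pinned torch 2.4.x, found {torch.__version__}; "
          "did `pip install -e .` in the mcore directory upgrade torch?"
      )
      ```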
      
      ## Testing
      ### Test with `verl/examples/ppo_trainer/run_deepseek_megatron.sh`
      
      ![image](https://github.com/user-attachments/assets/e053c9b8-fdd7-47fc-aaeb-42cf85070056)
      
      ---------
      
      Signed-off-by: chendong-1998 <chendong136@huawei.com>
      Co-authored-by: chendong-1998 <chendong136@huawei.com>
      Co-authored-by: gaoziyuan <gaoziyuan.955@bytedance.com>
      Co-authored-by: Sion Gao <gaoziyuan19@mails.ucas.ac.cn>
      Yan Bai committed
  6. 02 Mar, 2025 6 commits
  7. 01 Mar, 2025 2 commits
    • fix: 2 typos (#435) · 99fb2dde
      Lumeng Wu committed
    • Update vLLM>=0.7 doc (#432) · cef4c2de
      Because of ongoing changes in vLLM, veRL currently cannot integrate
      directly with the vLLM nightly build. The new DP feature in the
      nightly version can no longer be bypassed by simply adjusting the
      `data_parallel_size` parameter, and resolving this requires further
      investigation.
      
      As a temporary workaround, I recommend a customized installation of
      vLLM if the V1 engine is required. I have updated the relevant
      documentation to reflect this guidance.
      ZSL98 committed
  8. 28 Feb, 2025 3 commits
  9. 27 Feb, 2025 5 commits