1. 16 Feb, 2025 3 commits
  2. 15 Feb, 2025 7 commits
  3. 14 Feb, 2025 5 commits
    • [testing][rollout] feat: support integration of vllm>=0.7.0 (spmd-version) (#209) · f8b4d085
This PR aims to integrate vllm>=0.7.0 while preserving:
**Backward compatibility**: 0.3.1, 0.4.2, 0.5.4, and 0.6.3 are still
supported.
**Forward compatibility**: future versions of vllm (>=0.7.0) will be
supported without requiring manual maintenance for each new release.
      
The README for this beta version is located at docs/README_vllm0.7.md,
where users can find the installation method and related features. It is
reproduced below.
      
      ---
# README for the verl (vllm>=0.7) version
      ## Installation
      
      Note: This version of veRL supports **FSDP** for training and **vLLM**
      for rollout. (Megatron-LM is not supported yet.)
      
```bash
# Create the conda environment
conda create -n verl python==3.10
conda activate verl

# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .

# Install vLLM>=0.7
pip3 install vllm==0.7.0

# Install flash-attn
pip3 install flash-attn --no-build-isolation
```
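
Optionally, you can sanity-check which vLLM ended up in the environment (a trivial snippet, not part of the repo):

```python
# Print the installed vLLM version; expect 0.7.x or newer for this setup.
import vllm

print(vllm.__version__)
```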
      
For existing stable vLLM versions (<=0.7.2), you also need to apply a few
small patches manually to vllm (/path/to/site-packages/vllm after
installation) after the above steps; a snippet to locate this directory
follows the patch list:
      
      - vllm/distributed/parallel_state.py: Remove the assertion below:
      
```python
if (world_size
        != tensor_model_parallel_size * pipeline_model_parallel_size):
    raise RuntimeError(
        f"world_size ({world_size}) is not equal to "
        f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
        f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
```
      
      - vllm/executor/uniproc_executor.py: change `local_rank = rank` to
      `local_rank = int(os.environ["LOCAL_RANK"])`
      - vllm/model_executor/model_loader/weight_utils.py: remove the
      `torch.cuda.empty_cache()` in `pt_weights_iterator`
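
As referenced above, one way to locate the installed vllm directory to patch (a small helper snippet, not part of verl):

```python
# Print the site-packages location of the installed vllm, i.e. the
# /path/to/site-packages/vllm directory mentioned above.
import os
import vllm

print(os.path.dirname(vllm.__file__))
```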
      
      These modifications have already been merged into the main branch of
      vLLM. To avoid modifying these files manually, you can directly build
      vLLM from source.
      
      ## Features
      
### Use CUDA graphs
      
After installation, the examples that use FSDP as the training backend can
be run. By default, `enforce_eager` is set to True, which disables CUDA
graphs. To enable CUDA graphs and the sleep mode of vLLM>=0.7, add the
following lines to the bash script:
      
```bash
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
```
      
For a typical job such as examples/ppo_trainer/run_qwen2-7b_seq_balance.sh,
rollout generation takes 115 seconds with vLLM 0.6.3 and 85 seconds with
vLLM 0.7.0. With CUDA graphs enabled, generation time drops further to 62
seconds.
      
**Note:** Currently, if `n` is greater than 1 in `SamplingParams` in
vLLM>=0.7, there is a potential stability issue in rollout generation time
(some iterations see bursts in generation time). We are working with the
vLLM team to investigate this issue.
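
For reference, `n` here is the per-prompt sample count in vLLM's standard `SamplingParams` API; the snippet below only makes the note concrete:

```python
# Requesting n > 1 samples per prompt, the setting the note above flags
# as a potential source of unstable rollout generation time.
from vllm import SamplingParams

sampling_params = SamplingParams(n=4, temperature=1.0, max_tokens=256)
```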
      
      ### Other features in vLLM
      
1. **num_scheduler_step>1:** not supported yet (weight loading has not
been aligned with `MultiStepModelRunner`)
2. **Prefix caching:** not supported yet (vLLM's sleep mode does not
support prefix caching)
3. **Chunked prefill:** supported (see the example after this list)
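
For orientation, chunked prefill is a standard vLLM engine option; at the plain vLLM API level (outside verl's config surface) it can be toggled like this:

```python
# Toggle chunked prefill via vLLM's standard engine argument.
# (Requires a GPU and downloads the model; shown for illustration only.)
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_chunked_prefill=True)
```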
      
      ---------
      
      Co-authored-by: zhangshulai <zhangshulai@bytedance.com>
      ZSL98 committed
    • fix the file lock issue (#255) · 63f75138
The previous FileLock in

https://github.com/volcengine/verl/blob/c46f403479db5d7afca6388800503a3bfe393bf5/verl/utils/checkpoint/checkpoint_manager.py#L75

could fail when the given path is too long, raising errors such as
`FileExistsError: [Errno 17] File exists` or `BlockingIOError: [Errno 11]
Resource temporarily unavailable`. To fix this, the lock file is now named
after a hash of the path, which keeps the name short and avoids the
conflict.

After this modification, the issue is avoided.
      
```python
import os
import tempfile

from filelock import FileLock


@staticmethod
def local_mkdir(path):
    # (Method of the checkpoint manager class; imports added for completeness.)
    if not os.path.isabs(path):
        working_dir = os.getcwd()
        path = os.path.join(working_dir, path)

    # Use a hash of the path as the lock file name to avoid overly long names
    lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock"
    lock_path = os.path.join(tempfile.gettempdir(), lock_filename)

    try:
        with FileLock(lock_path, timeout=60):  # add a timeout
            # make a new dir
            os.makedirs(path, exist_ok=True)
    except Exception as e:
        print(f"Warning: Failed to acquire lock for {path}: {e}")
        # even if the lock is not acquired, try to create the directory
        os.makedirs(path, exist_ok=True)

    return path
```
      Wei Liu committed
    • [misc] Compatibility Issue with Python 3.9 in FSDP Worker for LLaMA Model (#268) · 7346ecf8
      **Fix: Compatibility Issue with Python 3.9 in FSDP Worker for LLaMA
      Model**
      
      When running the LLaMA model in the FSDP worker, an ImportError occurs
      due to the use of the Unpack type from the typing module. This type is
      only available in Python 3.11 and later, but the current environment
      uses Python 3.9, which does not support it.
      
      **Error Details:**
      ```
      File "/project/Logic-RL-main/verl/models/transformers/llama.py", line 17, in <module>
      from typing import Optional, List, Union, Tuple, Unpack, Callable
      ImportError: cannot import name 'Unpack' from 'typing' (/opt/miniconda3/envs/verl/lib/python3.9/typing.py)
      ```
**Solution:**
To resolve this issue, I added conditional imports to handle different
Python versions. For Python versions lower than 3.11, the code now uses a
fallback so it does not rely on `typing.Unpack` directly.
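
A minimal sketch of such a version-guarded import, assuming the common `typing_extensions` fallback (the actual patch may differ):

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    # typing_extensions backports Unpack for Python < 3.11
    from typing_extensions import Unpack
```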
      
      Co-authored-by: Yu Feng <fengyufengyu@didiglobal.com>
      Yu Feng committed
  4. 13 Feb, 2025 1 commit
  5. 12 Feb, 2025 4 commits
  6. 11 Feb, 2025 2 commits
  7. 10 Feb, 2025 5 commits
  8. 09 Feb, 2025 7 commits
  9. 08 Feb, 2025 3 commits
    • [ckpt] feat: integrate checkpoint resume in RL ray trainer (#222) · 5a400bf2
      **Features:**
      - Save actor and critic checkpoint:
        - Model
        - Optimizer
        - lr_scheduler
        - rng_state
        - dataloader
- A complete checkpoint means that the dataloader, actor, and critic (if
any) states are all properly saved
- By default, we do not save the dataset itself, only the dataloader (with
sampler) state
      
**Usage:**
- Supported resume modes: auto, disable, and resume_from_path
  - auto: veRL automatically checks for the latest checkpoint in
`trainer.default_local_dir`
  - disable: veRL always trains from scratch
  - resume_from_path: when `resume_from_path` is set to True, the user only
needs to set `resume_mode` to the checkpoint path to load (a sketch of this
dispatch follows the list)
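
To make the three modes concrete, a hedged sketch of the dispatch they describe (not verl's actual code; `find_latest_checkpoint` is a hypothetical helper defined inline):

```python
import os
from typing import Optional

def find_latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    # Hypothetical helper: newest entry under the checkpoint directory, if any.
    if not os.path.isdir(ckpt_dir):
        return None
    entries = [os.path.join(ckpt_dir, e) for e in os.listdir(ckpt_dir)]
    return max(entries, key=os.path.getmtime) if entries else None

def resolve_resume_path(resume_mode: str, default_local_dir: str) -> Optional[str]:
    if resume_mode == "auto":
        return find_latest_checkpoint(default_local_dir)  # latest checkpoint
    if resume_mode == "disable":
        return None  # always train from scratch
    return resume_mode  # otherwise it carries an explicit checkpoint path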
      
      **TODO:**
      - Support SFT resume in the next PR
      - Support uploader
      
      **Relevant issue:**
      - https://github.com/volcengine/verl/issues/76
      - https://github.com/volcengine/verl/issues/143
      Guangming Sheng committed
    • Fix typo tips in bash sft. (#226) · 62a065b9
      
      Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
      湛露先生 committed
    • Memory efficiency improvement to logprobs_from_logits_v2 (#220) · 4b516249
The existing `logprobs_from_logits_v2` doesn't achieve the memory savings
it claims, because `logsumexp` still allocates a `bs*seqlen*vocab` tensor
internally to hold the element-wise `exp`. However, by applying
`logsumexp` in a loop over smaller slices, we can compute its outputs
iteratively.
      
      Benchmarks show this uses significantly less memory to compute logprobs.
      
The fix is provided, along with a separate memory-efficient approach for
the bfloat16 case.
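
A minimal sketch of the chunked-`logsumexp` idea (illustrative only, not the exact verl change): gather the target-token logits first, then compute `logsumexp` slice by slice so only a `chunk_size x vocab` intermediate is ever materialized.

```python
import torch

def logprobs_from_logits_chunked(logits: torch.Tensor,
                                 labels: torch.Tensor,
                                 chunk_size: int = 1024) -> torch.Tensor:
    # logits: (bs, seqlen, vocab); labels: (bs, seqlen), int64 token ids
    token_logits = torch.gather(logits, -1, labels.unsqueeze(-1)).squeeze(-1)
    flat = logits.reshape(-1, logits.shape[-1])  # (bs*seqlen, vocab)
    # logsumexp over bounded slices: peak extra memory is chunk_size*vocab,
    # not bs*seqlen*vocab as in the one-shot call.
    lse = torch.cat([
        torch.logsumexp(flat[i:i + chunk_size], dim=-1)
        for i in range(0, flat.shape[0], chunk_size)
    ]).view_as(token_logits)
    return token_logits - lse  # log-softmax evaluated at the label tokens
```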
      Tyler Romero committed
  10. 07 Feb, 2025 3 commits
• [TRACKING] feat: Integrate SwanLab for experiment tracking with online/offline mode and local dashboard support (#218) · 958a3267
      
      ### Pull Request Description  
      
      This PR introduces **SwanLab**, a lightweight open-source experiment
      tracking tool, as a new logging option for the training framework. The
      integration provides both online and offline tracking capabilities,
      along with a local dashboard for visualizing results. Below is a
      detailed overview of the changes and usage instructions:
      
      ---
      
      #### **Key Features of SwanLab Integration**
      
1. **Online and Offline Tracking**:
   - **Online Mode**: Track experiments remotely and store data on
SwanLab's cloud platform.
   - **Offline Mode**: Use a local dashboard to visualize training logs
without an internet connection.

2. **Hardware Monitoring**:
   - Automatically tracks GPU usage, power consumption, temperature, and
other hardware metrics.
   - Supports NVIDIA GPUs and Huawei Ascend NPUs.

3. **Remote Access**:
   - View training progress remotely via the SwanLab web interface or
mobile app.

4. **Local Dashboard**:
   - Includes an open-source local dashboard for offline visualization of
training logs (a minimal logging-API sketch follows this list).
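
For orientation, SwanLab's logging API follows the familiar init/log pattern; a standalone sketch independent of verl's integration (project and metric names are arbitrary):

```python
import swanlab

# Initialize a run, log a few scalar metrics, then finish.
run = swanlab.init(project="verl-demo", experiment_name="ppo-gsm8k")
for step in range(3):
    swanlab.log({"reward": 0.1 * step}, step=step)
swanlab.finish()
```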
      
      ---
      
      #### **Usage Instructions**
      
      ##### **Step 1: Set Up Online Tracking (Optional)**
      
      To use SwanLab's online tracking, log in to the [SwanLab
      website](https://swanlab.cn) and obtain your API key from the [Settings
      page](https://swanlab.cn/space/~/settings). Then, authenticate using the
      following command:
      
      ```bash
      swanlab login
      ```
      
      If you prefer offline mode, skip this step.
      
      ---
      
      ##### **Step 2: Configure SwanLab as the Logger**
      
      To enable SwanLab as the experiment tracker, add
      `trainer.logger=['swanlab']` to your training command. For example,
      using the [Post-train a LLM using PPO with GSM8K
      dataset](https://verl.readthedocs.io/en/latest/start/quickstart.html)
      workflow:
      
      ```bash
      PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
       data.train_files=$HOME/data/gsm8k/train.parquet \
       data.val_files=$HOME/data/gsm8k/test.parquet \
       data.train_batch_size=256 \
       data.val_batch_size=1312 \
       data.max_prompt_length=512 \
       data.max_response_length=256 \
       actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
       actor_rollout_ref.actor.optim.lr=1e-6 \
       actor_rollout_ref.actor.ppo_mini_batch_size=64 \
       actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
       actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
       actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
       actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
       actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
       critic.optim.lr=1e-5 \
       critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
       critic.ppo_micro_batch_size_per_gpu=4 \
       algorithm.kl_ctrl.kl_coef=0.001 \
       trainer.logger=['console','swanlab'] \
       +trainer.val_before_train=False \
       trainer.default_hdfs_dir=null \
       trainer.n_gpus_per_node=1 \
       trainer.nnodes=1 \
       trainer.save_freq=10 \
       trainer.test_freq=10 \
       trainer.total_epochs=15 2>&1 | tee verl_demo.log
      ```
      
      If you are not logged in, you will be prompted to choose a tracking
      mode:
      
      1. **Cloud Mode**: Upload logs to SwanLab's cloud platform.
      2. **Cloud-Only Mode**: Upload logs to the cloud but do not save them
      locally.
      3. **Local Mode**: Save logs locally for offline tracking.
      
      <img width="1325" alt="select"
      src="https://github.com/user-attachments/assets/5c55fc45-79a9-4673-ae4e-ea9d0623dd29"
      />
      
      Alternatively, you can configure SwanLab using environment variables:
      
```bash
export SWANLAB_API_KEY=<your_api_key>    # API key for online tracking
export SWANLAB_LOG_DIR=<local_log_path>  # local log directory
export SWANLAB_MODE=<mode>               # cloud (default), cloud-only, local, or disabled
```
      
      ---
      
      ##### **Step 3: View Training Logs**
      
      After logging in, you will see a confirmation message:
      
      <img width="1415" alt="track"
      src="https://github.com/user-attachments/assets/87c4ff2f-c8c4-4e7a-a41e-21afa935cb56"
      />
      
      - **Online Tracking**: View logs on the [SwanLab
      website](https://swanlab.cn).
      
      <img width="1900" alt="remote"
      src="https://github.com/user-attachments/assets/5b44b9f3-948f-4f93-9873-572bce56daf7"
      />
      
      For more details, refer to the [SwanLab Cloud
      Documentation](https://docs.swanlab.cn/guide_cloud/experiment_track/view-result.html).
      
      - **Offline Tracking**: Use the local dashboard to visualize logs:
      
        ```bash
        swanlab watch
        ```
      
      For advanced configurations, such as setting a custom port, refer to the
      [Offline Dashboard
      Documentation](https://docs.swanlab.cn/guide_cloud/self_host/offline-board.html)
      and [CLI
      Documentation](https://docs.swanlab.cn/api/cli-swanlab-watch.html#%E8%AE%BE%E7%BD%AEip%E5%92%8C%E7%AB%AF%E5%8F%A3%E5%8F%B7).
      
      ---
      
      #### **Impact**
      
      - Provides a lightweight, flexible, and user-friendly experiment
      tracking solution.
      - Supports both online and offline use cases, making it suitable for
      environments with restricted internet access.
      - Enhances hardware monitoring capabilities for better resource
      utilization.
      
      ---
      
      This PR is ready for review. Feedback and suggestions are welcome!
      Shaohon Chen committed
    • [rollout]: fix incorrect response_attention_mask in vLLM rollout (#213) · 3140cc2f
      This PR addresses issue https://github.com/volcengine/verl/issues/212.
      
      The changes include:
- Read `eos_token_id` from `generation_config` to ensure alignment with vLLM.
- Modified the `get_eos_mask` function to accept both int and list types
for the `eos_token` parameter (a sketch follows below).
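
A hedged sketch of the described `get_eos_mask` behavior (the real function lives in verl's utilities and may differ in details):

```python
import torch

def get_eos_mask(response_ids: torch.Tensor, eos_token, dtype=torch.int64):
    # Accept a single eos id or a list of ids, per the change above.
    if isinstance(eos_token, int):
        eos_token = [eos_token]
    is_eos = torch.zeros_like(response_ids, dtype=torch.bool)
    for tok in eos_token:
        is_eos |= response_ids.eq(tok)
    # Keep positions up to and including the first EOS; zero out the rest.
    eos_before = torch.cumsum(is_eos.to(torch.int64), dim=-1) - is_eos.to(torch.int64)
    return (eos_before == 0).to(dtype)
```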
      Kinman Lei committed
    • [misc] feat: add ckpt manager in utils (#216) · 27484a7b
      - Support FSDPCheckpointManager
      - Support hdfs_io import if installed
      - Add CI for FSDPCheckpointManager
      
      TODO:
      - Will integrate in the next PR
      Guangming Sheng committed