Unverified commit 6d96fda3 by Xingyao Wang, committed by GitHub

[SFT] feat: Add LoRA support for SFT (#127)

This PR adds LoRA (Low-Rank Adaptation) support for efficient model
fine-tuning.

### Changes

1. Added LoRA configuration support in trainer config
2. Modified FSDP wrapping policy to handle LoRA modules
3. Integrated with existing FSDP training infrastructure
4. Added peft dependency
5. Removed unused ring_attn_utils.py

### Features

- Configurable LoRA rank and alpha parameters
- Target module specification for selective adaptation
- Compatible with FSDP sharding strategy

### Testing

Tested with Qwen2.5-0.5B-Instruct model on GSM8K dataset using the
provided example script.

### Dependencies

- Added `peft` package to requirements.txt

This PR is based on commit 902ddbe6 and has been merged with the latest
upstream main branch.

---------

Co-authored-by: Jiayi Pan <i@jiayipan.me>
Co-authored-by: openhands <openhands@all-hands.dev>
parent 22e93114
.github/workflows/e2e_lora.yml:

name: e2e_lora

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_lora.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_lora.yml
      - "tests/e2e/*.sh"

jobs:
  e2e_lora:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer peft
          pip3 install -e .[test]
      - name: Prepare gsm8k dataset
        run: |
          ray stop --force
          python3 examples/data_preprocess/gsm8k.py
      - name: Running gsm8k e2e training tests with LoRA
        run: |
          ray stop --force
          bash examples/sft/gsm8k/run_qwen_05_peft.sh 8 $HOME/ckpts/
examples/sft/gsm8k/run_qwen_05_peft.sh:

# Tested with 2 & 4 GPUs

set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen_05_peft.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    optim.lr=1e-4 \
    +data.prompt_dict_keys=['question'] \
    +data.response_dict_keys=['answer'] \
    data.micro_batch_size=32 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=$save_path \
    trainer.project_name=gsm8k-sft \
    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \
    trainer.logger=['console'] \
    trainer.total_training_steps=1 \
    trainer.default_hdfs_dir=null $@ \
    model.lora_rank=32 \
    model.lora_alpha=16 \
    model.target_modules=all-linear

# Or pass an explicit module list instead:
# model.target_modules=[q_proj,v_proj]
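The `model.target_modules=all-linear` override asks peft to adapt every linear projection instead of only `q_proj`/`v_proj`. The sketch below is illustrative only (it is not part of this PR, assumes the same Qwen/Qwen2.5-0.5B-Instruct checkpoint is downloadable, and relies on recent peft releases expanding `"all-linear"` to every `nn.Linear` except the output head); it shows which module names such a selection would cover.

```python
# Illustrative only: inspect the linear module names that "all-linear" would target
# versus an explicit [q_proj, v_proj] selection.
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

linear_names = {name.split(".")[-1] for name, module in model.named_modules()
                if isinstance(module, nn.Linear)}
print("linear modules:", sorted(linear_names))   # includes lm_head, which "all-linear" skips
print("explicit subset:", sorted(linear_names & {"q_proj", "v_proj"}))
```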
@@ -41,6 +41,7 @@ dependencies = [
     "tensordict",
     "transformers<4.48",
     "vllm<=0.6.3",
+    "peft",
 ]
 # Optional dependencies (extras_require in setup.py)
......
@@ -19,6 +19,9 @@ model:
   external_lib: null
   enable_gradient_checkpointing: False
   trust_remote_code: False
+  lora_rank: 0  # Set to positive value to enable LoRA (e.g., 32)
+  lora_alpha: 16  # LoRA scaling factor
+  target_modules: [q_proj, v_proj]  # Target modules for LoRA adaptation
 optim:
   lr: 1e-5
   betas: [0.9, 0.95]
......
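With the default `lora_rank: 0`, LoRA stays disabled; the trainer only builds a PEFT model when the rank is positive, which is usually done through a command-line override such as `model.lora_rank=32`. A minimal sketch of that override mechanism follows (using OmegaConf directly rather than the trainer's actual Hydra entry point):

```python
# Sketch only: a dotlist override like model.lora_rank=32 replaces the YAML default.
from omegaconf import OmegaConf

defaults = OmegaConf.create({
    "model": {"lora_rank": 0, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
})
overrides = OmegaConf.from_dotlist(["model.lora_rank=32"])
cfg = OmegaConf.merge(defaults, overrides)

# Mirrors the trainer's gate: LoRA is applied only when lora_rank is positive.
print(cfg.model.lora_rank, cfg.model.get("lora_rank", 0) > 0)  # 32 True
```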
@@ -43,6 +43,7 @@ from torch.distributed.device_mesh import DeviceMesh
 import verl.utils.hdfs_io as hdfs_io
 from verl.utils.debug import log_gpu_memory_usage
+from peft import LoraConfig, TaskType, get_peft_model

 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
@@ -55,6 +56,18 @@ def extract_step(path):
     return None


+def convert_to_regular_types(obj):
+    """Convert Hydra configs and other special types to regular Python types."""
+    from omegaconf import ListConfig, DictConfig
+    if isinstance(obj, (ListConfig, DictConfig)):
+        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
+    elif isinstance(obj, (list, tuple)):
+        return [convert_to_regular_types(x) for x in obj]
+    elif isinstance(obj, dict):
+        return {k: convert_to_regular_types(v) for k, v in obj.items()}
+    return obj
+
+
 class FSDPSFTTrainer(object):

     def __init__(self, config, device_mesh: DeviceMesh):
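As a quick illustration of why this helper exists (a standalone sketch, assuming `convert_to_regular_types` from the hunk above is in scope): Hydra/OmegaConf wraps lists in `ListConfig`, which peft's `LoraConfig` may not handle cleanly, hence the conversion to plain Python containers before building the config.

```python
# Sketch: OmegaConf containers vs. plain Python types.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"target_modules": ["q_proj", "v_proj"]})
print(type(cfg.target_modules).__name__)             # ListConfig, not list
print(convert_to_regular_types(cfg.target_modules))  # ['q_proj', 'v_proj'] -- a plain list
```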
@@ -163,6 +176,18 @@ class FSDPSFTTrainer(object):
                                                                                attn_implementation='flash_attention_2',
                                                                                trust_remote_code=trust_remote_code)

+            if self.config.model.get('lora_rank', 0) > 0:
+                self.model.enable_input_require_grads()
+                # Convert config to regular Python types before creating PEFT model
+                lora_config = {
+                    'task_type': TaskType.CAUSAL_LM,
+                    'r': self.config.model.lora_rank,
+                    'lora_alpha': self.config.model.lora_alpha,
+                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
+                    'bias': "none"
+                }
+                self.model = get_peft_model(self.model, LoraConfig(**lora_config))
+
             if self.config.model.enable_gradient_checkpointing:
                 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
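Outside the trainer, the same wrapping step looks roughly like the sketch below (illustrative model name taken from the example script; this is not the trainer's code path). `enable_input_require_grads()` keeps gradients flowing through the frozen embedding layer so gradient checkpointing still works, and `get_peft_model` freezes the base weights so only the LoRA adapters are trainable.

```python
# Sketch: apply LoRA to a causal LM and check how few parameters remain trainable.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
base.enable_input_require_grads()  # needed when training adapters with gradient checkpointing

lora = LoraConfig(task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=16,
                  target_modules=["q_proj", "v_proj"], bias="none")
model = get_peft_model(base, lora)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
```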
@@ -172,7 +197,9 @@ class FSDPSFTTrainer(object):
                                         reduce_dtype=torch.float32,
                                         buffer_dtype=torch.float32)

-        auto_wrap_policy = get_fsdp_wrap_policy(self.model, config=self.config.model.fsdp_config.wrap_policy)
+        auto_wrap_policy = get_fsdp_wrap_policy(self.model,
+                                                config=self.config.model.fsdp_config.wrap_policy,
+                                                is_lora=self.config.model.get('lora_rank', 0) > 0)

         if self.device_mesh.get_rank() == 0:
             print(auto_wrap_policy)
......
@@ -45,7 +45,14 @@ def get_init_weight_context_manager(use_meta_tensor=True):

 # Copyright 2020-present the HuggingFace Inc. team.
 # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py
-def get_fsdp_wrap_policy(module, config=None):
+def get_fsdp_wrap_policy(module, config=None, is_lora=False):
+    """Get FSDP wrap policy for the module.
+
+    Args:
+        module: The module to get wrap policy for
+        config: Configuration for wrap policy
+        is_lora: Whether to enable lambda policy for LoRA modules
+    """
     if config is None:
         config = {}
@@ -57,8 +64,26 @@ def get_fsdp_wrap_policy(module, config=None):
                                                     default_transformer_cls_names_to_wrap)
     min_num_params = config.get('min_num_params', 0)
     auto_wrap_policy = None
+
+    policies = []
+
+    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+
+    # Add lambda policy for LoRA modules if is_lora is True
+    if is_lora:
+
+        def lambda_policy_fn(module):
+            if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and
+                    module.weight.requires_grad):
+                return True
+            return False
+
+        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+        policies.append(lambda_policy)
+
     if min_num_params > 0:
-        auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
+        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
+        policies.append(size_policy)
     elif fsdp_transformer_layer_cls_to_wrap is not None:
         transformer_cls_to_wrap = set()
         for layer_class in fsdp_transformer_layer_cls_to_wrap:
@@ -68,11 +93,15 @@ def get_fsdp_wrap_policy(module, config=None):
             else:
                 transformer_cls_to_wrap.add(transformer_cls)

-        auto_wrap_policy = functools.partial(
+        transformer_policy = functools.partial(
             transformer_auto_wrap_policy,
             # Transformer layer class to wrap
             transformer_layer_cls=transformer_cls_to_wrap,
         )
+        policies.append(transformer_policy)
+
+    if len(policies) > 0:
+        auto_wrap_policy = functools.partial(_or_policy, policies=policies)
+
     return auto_wrap_policy
......
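The effect of the `is_lora` branch is easiest to see on a tiny module. The sketch below is a standalone toy (not verl code) that mimics a LoRA-adapted layer: the frozen base weight is ignored by `lambda_policy_fn`, while the small trainable adapter matrices are exactly the leaf modules the lambda policy tells FSDP to wrap separately, so they are not flattened together with frozen parameters.

```python
# Toy illustration of the lambda policy's selection rule.
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """A frozen base Linear plus two small trainable adapter matrices (lora_A, lora_B)."""
    def __init__(self, dim=16, rank=4):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.base.weight.requires_grad_(False)
        self.base.bias.requires_grad_(False)
        self.lora_A = nn.Linear(dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, dim, bias=False)

def lambda_policy_fn(module):
    # Same rule as in get_fsdp_wrap_policy: leaf modules with a trainable weight.
    return (len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad)

layer = ToyLoRALinear()
print([name for name, m in layer.named_modules() if lambda_policy_fn(m)])  # ['lora_A', 'lora_B']
```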