Unverified commit 558fae54 by Guangming Sheng, committed by GitHub

[misc] fix: disable chunked-prefill by default (#259)

Thanks: @HillZhang1999

- Related issue: https://github.com/volcengine/verl/issues/189

`(main_task pid=3523385) ValueError: max_num_batched_tokens (8192) is smaller than max_model_len (9216). This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences. Please increase max_num_batched_tokens or decrease max_model_len.`

When `enable_chunked_prefill` is activated, the issue above is concealed rather than reported. If you enable chunked prefill, please increase `max_num_batched_tokens` or decrease `max_model_len` accordingly.
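A minimal sketch of the constraint behind the error, using the numbers from the log above; the 4096/5120 prompt/response split is an assumed example (not taken from the issue), and the snippet is illustrative rather than verl or vLLM code:

```python
# Illustrative only: the prompt/response split (4096 + 5120 = 9216) is an assumed
# example that reproduces the 9216-token max_model_len from the error above.
prompt_length = 4096
response_length = 5120
max_model_len = prompt_length + response_length   # 9216
max_num_batched_tokens = 8192                     # the default in the rollout config

# Without chunked prefill, vLLM requires a whole sequence to fit into one prefill
# batch, so this configuration is rejected at engine start-up.
if max_num_batched_tokens < max_model_len:
    print(f"max_num_batched_tokens ({max_num_batched_tokens}) < max_model_len ({max_model_len}): "
          f"increase max_num_batched_tokens or decrease max_model_len")
```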
Parent commit: 59643585
@@ -69,6 +69,7 @@ actor_rollout_ref:
     load_format: dummy_megatron
     tensor_model_parallel_size: 2
     max_num_batched_tokens: 8192
+    max_model_len: null
     max_num_seqs: 1024
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
     log_prob_micro_batch_size_per_gpu: null
...
@@ -75,13 +75,14 @@ actor_rollout_ref:
     load_format: dummy_dtensor
     tensor_model_parallel_size: 2
     max_num_batched_tokens: 8192
+    max_model_len: null
     max_num_seqs: 1024
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
     log_prob_micro_batch_size_per_gpu: null
     log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
     log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
     disable_log_stats: True
-    enable_chunked_prefill: True # could get higher throughput
+    enable_chunked_prefill: False # may get higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.
     # for hf rollout
     do_sample: True
     # number of responses (i.e. num sample times)
...
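The Python hunks below add the actual resolution and check; here is a self-contained sketch of the same logic, with a plain dict standing in for verl's OmegaConf rollout config and `resolve_max_model_len` as a hypothetical helper name:

```python
# Sketch of the resolution + validation added below; a plain dict stands in for
# verl's OmegaConf rollout config, and resolve_max_model_len is a hypothetical helper.
def resolve_max_model_len(rollout_cfg: dict, prompt_length: int, response_length: int) -> int:
    # max_model_len: null  ->  fall back to prompt_length + response_length
    max_model_len = int(rollout_cfg.get("max_model_len") or (prompt_length + response_length))
    max_num_batched_tokens = int(rollout_cfg.get("max_num_batched_tokens", 8192))
    if rollout_cfg.get("enable_chunked_prefill", False) and max_num_batched_tokens < max_model_len:
        # fail fast instead of letting chunked prefill hide the misconfiguration
        raise ValueError("please increase max_num_batched_tokens or disable chunked prefill")
    return max_model_len

# With the default 8192 batched tokens and a 9216-token context, enabling
# chunked prefill is now rejected up front rather than silently accepted.
try:
    resolve_max_model_len({"max_num_batched_tokens": 8192, "enable_chunked_prefill": True}, 4096, 5120)
except ValueError as err:
    print(err)
```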
@@ -74,7 +74,7 @@ class vLLMRollout(BaseRollout):
         tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
         assert tensor_parallel_size <= torch.distributed.get_world_size(), \
             "tensor parallel size should be less than or equal to the world size"
-        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)
+        max_num_batched_tokens = int(self.config.get('max_num_batched_tokens', 8192))
         if kwargs.get('train_tp', None) is not None:
             # deployed with megatron
@@ -89,6 +89,15 @@ class vLLMRollout(BaseRollout):
         assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
             "model context length should be greater than total sequence length"
+        max_model_len = self.config.max_model_len if self.config.max_model_len \
+            else config.prompt_length + config.response_length
+        max_model_len = int(max_model_len)
+        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
+            raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \
+                please increase max_num_batched_tokens or disable chunked prefill')
         self.inference_engine = LLM(
             actor_module,
             tokenizer=tokenizer,
@@ -98,7 +107,7 @@ class vLLMRollout(BaseRollout):
             enforce_eager=config.enforce_eager,
             gpu_memory_utilization=config.gpu_memory_utilization,
             skip_tokenizer_init=False,
-            max_model_len=config.prompt_length + config.response_length,
+            max_model_len=max_model_len,
             load_format=config.load_format,
             disable_log_stats=config.disable_log_stats,
             max_num_batched_tokens=max_num_batched_tokens,
...
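For reference, a standalone sketch of how the resolved values reach vLLM; it assumes a recent vLLM release whose `LLM` constructor forwards these keyword arguments to `EngineArgs`, and the model name is a placeholder, not part of this commit:

```python
from vllm import LLM

prompt_length, response_length = 4096, 5120
max_model_len = prompt_length + response_length     # max_model_len: null -> derived from lengths
max_num_batched_tokens = max(8192, max_model_len)   # keep max_num_batched_tokens >= max_model_len

llm = LLM(
    model="Qwen/Qwen2-0.5B-Instruct",                # placeholder model, not from the commit
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=False,                    # the default this commit switches to
    gpu_memory_utilization=0.6,
    enforce_eager=True,
)
```

Because chunked prefill stays off, vLLM itself would raise the ValueError quoted in the commit message if `max_num_batched_tokens` were left below `max_model_len`.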