Unverified Commit 4a291fa7 by Yusheng (Ethan) Su Committed by GitHub

[Hardware] Support AMD (Rocm kernel) (#360)

parent 75dedb57
**/*.pt
**/checkpoints
**/wget-log
......
# Build the docker in the repo dir:
# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .
# docker images # you can find your built docker
FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
# Set working directory
# WORKDIR $PWD/app
# Set environment variables
ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# Install vllm
RUN pip uninstall -y vllm && \
rm -rf vllm && \
git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \
cd vllm && \
MAX_JOBS=$(nproc) python3 setup.py install && \
cd .. && \
rm -rf vllm
# Copy the entire project directory
COPY . .
# Install dependencies
RUN pip install "tensordict<0.6" --no-deps && \
pip install accelerate \
codetiming \
datasets \
dill \
hydra-core \
liger-kernel \
numpy \
pandas \
peft \
"pyarrow>=15.0.0" \
pylatexenc \
"ray[data,train,tune,serve]" \
torchdata \
transformers \
wandb \
orjson \
pybind11 && \
pip install -e . --no-deps
\ No newline at end of file
# Setup
## Dockerfile.rocm
```bash
# Build the docker in the repo dir:
# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .
# docker images # you can find your built docker
#
FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
# Set working directory
# WORKDIR $PWD/app
# Set environment variables
ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# Install vllm
RUN pip uninstall -y vllm && \
rm -rf vllm && \
git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \
cd vllm && \
MAX_JOBS=$(nproc) python3 setup.py install && \
cd .. && \
rm -rf vllm
# Copy the entire project directory
COPY . .
# Install dependencies
RUN pip install "tensordict<0.6" --no-deps && \
pip install accelerate \
codetiming \
datasets \
dill \
hydra-core \
liger-kernel \
numpy \
pandas \
peft \
"pyarrow>=15.0.0" \
pylatexenc \
"ray[data,train,tune,serve]" \
torchdata \
transformers \
wandb \
orjson \
pybind11 && \
pip install -e . --no-deps
```
## Build the image:
```bash
docker build -t verl-rocm .
```
## Run the container
```bash
docker run --rm -it \
--device /dev/dri \
--device /dev/kfd \
-p 8265:8265 \
--group-add video \
--cap-add SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME/.ssh:/root/.ssh \
-v $HOME:$HOME \
--shm-size 128G \
-w $PWD \
verl-rocm \
/bin/bash
```
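Once inside the container, a quick sanity check (a minimal sketch, assuming the base image's ROCm build of PyTorch) confirms the GPUs are visible to both ROCm and PyTorch:
```bash
rocm-smi   # ROCm's GPU status tool, roughly the counterpart of nvidia-smi
python3 -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
```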
# Example
## PPO
```bash
YOUR_PROJECT_NAME=r1-verl-ppo-upstream
YOUR_RUN_NAME=r1-training_ppo-upstream
# export HYDRA_FULL_ERROR=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
GPUS_PER_NODE=8
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=data/gsm8k/train.parquet \
data.val_files=data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=$MODEL_PATH \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['console'] \
trainer.project_name=$YOUR_PROJECT_NAME \
trainer.experiment_name=$YOUR_RUN_NAME \
+trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$GPUS_PER_NODE \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.total_epochs=15 # append "2>&1 | tee verl_demo.log" to also save the output to a log file
```
## GRPO
```bash
YOUR_PROJECT_NAME=r1-verl-grpo-upstream
YOUR_RUN_NAME=r1-training_grpo-upstream
# export HYDRA_FULL_ERROR=1
# export FSDP_VERBOSE=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
GPUS_PER_NODE=8
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
# MODEL_PATH=Qwen/Qwen2-7B-Instruct
python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=data/gsm8k/train.parquet \
data.val_files=data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=False \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name=$YOUR_PROJECT_NAME \
trainer.experiment_name=$YOUR_RUN_NAME \
trainer.n_gpus_per_node=$GPUS_PER_NODE \
+trainer.val_before_train=False \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15
```
\ No newline at end of file
# Setup
## Docker:
Find the Docker image here: https://hub.docker.com/r/rocm/vllm/tags (tag: rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4)
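If the image is not already present locally, pull it first, then start the container with the command below:
```bash
docker pull rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
```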
```bash
docker run --rm -it \
--device /dev/dri \
--device /dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v /home/yushensu:/home/yushensu \
-v $HOME/.ssh:/root/.ssh \
--shm-size 128G \
--name verl_vllm_upstream \
-w $PWD \
rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 \
/bin/bash
```
## Build ROCm vLLM:
```bash
pip uninstall -y vllm
git clone -b v0.6.3 https://github.com/vllm-project/vllm.git
cd vllm
export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
export MAX_JOBS=$(nproc)
# python3 setup.py develop # develop mode does not copy the sources, so the cloned repo must be kept
python3 setup.py install # install copies the sources into site-packages, so the cloned repo can be deleted
cd ..
rm -rf vllm
```
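An optional import check shows which vLLM build was picked up; on ROCm the reported version may carry a local suffix such as `0.6.3+rocm624`, which verl's version dispatch handles:
```bash
python3 -c "import vllm; print(vllm.__version__)"
```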
## Install the required packages:
```bash
pip install "tensordict<0.6" --no-deps
pip install accelerate \
codetiming \
datasets \
dill \
hydra-core \
liger-kernel \
numpy \
pandas \
peft \
"pyarrow>=15.0.0" \
pylatexenc \
"ray[data,train,tune,serve]" \
torchdata \
transformers \
wandb \
orjson \
pybind11
pip install -e . --no-deps
```
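To confirm the editable install of verl resolves from this repo (optional):
```bash
python3 -c "import verl; print(verl.__file__)"
```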
# Example
## PPO
```bash
YOUR_PROJECT_NAME=r1-verl-ppo-upstream
YOUR_RUN_NAME=r1-training_ppo-upstream
# export HYDRA_FULL_ERROR=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
GPUS_PER_NODE=8
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=data/gsm8k/train.parquet \
data.val_files=data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=$MODEL_PATH \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['console'] \
trainer.project_name=$YOUR_PROJECT_NAME \
trainer.experiment_name=$YOUR_RUN_NAME \
+trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$GPUS_PER_NODE \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.total_epochs=15 # append "2>&1 | tee verl_demo.log" to also save the output to a log file
```
## GRPO
```bash
YOUR_PROJECT_NAME=r1-verl-grpo-upstream
YOUR_RUN_NAME=r1-training_grpo-upstream
# export HYDRA_FULL_ERROR=1
# export FSDP_VERBOSE=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
GPUS_PER_NODE=8
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
# MODEL_PATH=Qwen/Qwen2-7B-Instruct
python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=data/gsm8k/train.parquet \
data.val_files=data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=False \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name=$YOUR_PROJECT_NAME \
trainer.experiment_name=$YOUR_RUN_NAME \
trainer.n_gpus_per_node=$GPUS_PER_NODE \
+trainer.val_before_train=False \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15
```
\ No newline at end of file
......@@ -118,6 +118,19 @@ class Worker(WorkerHelper):
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
import os
###
# [SUPPORT AMD: torch]
import torch
###
###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES')
os.environ['LOCAL_RANK'] = os.environ.get('RAY_LOCAL_RANK')
###
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
self._rank = rank
......@@ -129,6 +142,18 @@ class Worker(WorkerHelper):
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
self.local_rank = int(os.environ['LOCAL_RANK'])
###
###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
cuda_visible_devices = str(local_rank)
###
store = {
'_world_size': world_size,
'_rank': rank,
......@@ -143,6 +168,13 @@ class Worker(WorkerHelper):
meta = WorkerMeta(store=store)
self._configure_with_meta(meta=meta)
###
# [SUPPORT AMD: torch]
# torch.cuda.set_device(local_rank)
if "AMD" in torch.cuda.get_device_name():
torch.cuda.set_device(int(cuda_visible_devices))
###
def _configure_with_meta(self, meta: WorkerMeta):
"""
This function should only be called inside by WorkerGroup
......
......@@ -47,6 +47,11 @@ elif package_version == '0.6.3':
from .vllm_v_0_6_3.llm import LLM
from .vllm_v_0_6_3.llm import LLMEngine
from .vllm_v_0_6_3 import parallel_state
elif package_version == '0.6.3+rocm624':
vllm_version = '0.6.3'
from .vllm_v_0_6_3.llm import LLM
from .vllm_v_0_6_3.llm import LLMEngine
from .vllm_v_0_6_3 import parallel_state
elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
# From 0.6.6.post2 on, vllm supports SPMD inference
# See https://github.com/vllm-project/vllm/pull/12071
......
......@@ -336,6 +336,19 @@ def compute_timing_metrics(batch, timing_raw):
}
def compute_throughout_metrics(batch, timing_raw, n_gpus):
total_num_tokens = sum(batch.meta_info['global_token_num'])
time = timing_raw["step"]
# estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)
# f'Actual TFLOPs/s/GPU': estimated_flops/(n_gpus),
# f'Theoretical TFLOPs/s/GPU': promised_flops,
return {
f'total_num_tokens': total_num_tokens,
f'time_per_step': time,
f'Tokens/Sec/GPU': total_num_tokens / (time * n_gpus),
}
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
......@@ -997,6 +1010,11 @@ class RayPPOTrainer(object):
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
config = self.config
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
# TODO: also implement actual and theoretical TFLOPs metrics
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
......
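For reference, the new throughput metric works out to `total_num_tokens / (time * n_gpus)`; with purely illustrative numbers, a step that processes 1,048,576 tokens in 32 s across 8 GPUs reports 1,048,576 / (32 × 8) = 4096 Tokens/Sec/GPU.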
......@@ -32,7 +32,10 @@ def get_device_flops(unit="T"):
device_name = torch.cuda.get_device_name()
flops = float("inf") # INF flops for unknown GPU type
if "H100" in device_name or "H800" in device_name:
if "MI300X" in device_name:
flops = 1336e12
elif "H100" in device_name or "H800" in device_name:
flops = 989e12
elif "A100" in device_name or "A800" in device_name:
flops = 312e12
......
......@@ -264,11 +264,11 @@ class DataParallelPPOActor(BasePPOActor):
self.actor_optimizer.zero_grad()
for data in micro_batches:
# Support all hardwares
if isinstance(data, DataProto):
data = {**data.batch.cuda(), **data.non_tensor_batch}
data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
else:
data = data.cuda() # actor device is cpu when using offload
data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
......
......@@ -198,11 +198,11 @@ class DataParallelPPOCritic(BasePPOCritic):
self.critic_optimizer.zero_grad()
for data in micro_batches:
#Support all devices
if isinstance(data, DataProto):
data = {**data.batch.cuda(), **data.non_tensor_batch}
data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
else:
data = data.cuda() # critic device is cpu when using offload
data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload
input_ids = data['input_ids']
responses = data['responses']
attention_mask = data['attention_mask']
......
......@@ -416,7 +416,8 @@ class ActorRolloutRefWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
assert self._is_actor
if self._is_offload_param:
......@@ -424,7 +425,8 @@ class ActorRolloutRefWorker(Worker):
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())
data.batch = data.batch.cuda()
# Support all hardwares
data.batch = data.batch.to(torch.cuda.current_device())
log_gpu_memory_usage('Before update policy', logger=logger)
......@@ -459,13 +461,15 @@ class ActorRolloutRefWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
prompts = prompts.to('cuda')
# Support all hardwares
prompts = prompts.to(torch.cuda.current_device())
assert self._is_rollout
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
prompts.batch = prompts.batch.cuda()
# Support all hardwares
prompts.batch = prompts.batch.to(torch.cuda.current_device())
meta_info = {
'eos_token_id':
self.generation_config.eos_token_id
......@@ -504,7 +508,9 @@ class ActorRolloutRefWorker(Worker):
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
......@@ -537,7 +543,8 @@ class ActorRolloutRefWorker(Worker):
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info['micro_batch_size'] = micro_batch_size
......@@ -782,7 +789,9 @@ class CriticWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
......@@ -804,7 +813,8 @@ class CriticWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
......@@ -1103,11 +1113,13 @@ class RewardModelWorker(Worker):
def compute_rm_score(self, data: DataProto):
import itertools
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
data = data.to('cuda')
# Support all hardwares
data = data.to(torch.cuda.current_device())
if self._do_switch_chat_template:
rm_data = self._switch_chat_template(data)
rm_data.batch = rm_data.batch.cuda()
# Support all hardwares
rm_data.batch = rm_data.batch.to(torch.cuda.current_device())
# perform forward computation
with self.ulysses_sharding_manager:
......
......@@ -14,6 +14,11 @@
from importlib.metadata import version, PackageNotFoundError
###
# [SUPPORT AMD:]
import torch
###
def get_version(pkg):
try:
......@@ -25,6 +30,17 @@ def get_version(pkg):
package_name = 'vllm'
package_version = get_version(package_name)
###
# package_version = get_version(package_name)
# [SUPPORT AMD:]
if "AMD" in torch.cuda.get_device_name():
import re
package_version = version(package_name)
package_version = re.match(r'(\d+\.\d+\.?\d*)', package_version).group(1)
else:
package_version = get_version(package_name)
###
if package_version <= '0.6.3':
vllm_mode = 'customized'
from .vllm_rollout import vLLMRollout
......
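The regex added in this hunk strips the ROCm local-version suffix so the existing `package_version <= '0.6.3'` comparison still selects the customized rollout; a quick illustration of what it yields:
```bash
python3 -c "import re; print(re.match(r'(\d+\.\d+\.?\d*)', '0.6.3+rocm624').group(1))"  # prints 0.6.3
```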
......@@ -74,6 +74,7 @@ class FSDPVLLMShardingManager(BaseShardingManager):
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = 'hf' if self.full_params else 'dtensor'
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
self.inference_engine.sync_model_weights(params, load_format=load_format)
else:
......