verl 解读 - Hybrid controller、WorkerGroup colocate 设计及源码分析 (part2)

RLHF 算法流程回顾

在介绍 verl 设计和理解源码前,先回顾一下 RLHF 算法。抑或直接跳至源码阅读部分。

PPO 算法

PPO 使用梯度上升优化的目标函数为:

上式中:

  • policy model 即对应 actor model
  • :question,即输入的 prompt;
  • :即 经过 actor model 后的 response 输出(采样的一条轨迹的 response 部分);
  • :即第 t 时间步的输出 token;
  • :一条轨迹的总时间步数,即
  • :在时间步 输出 token 的概率;
  • :优势函数。

关于优势函数。现考虑从时间步 开始往后 步的优势估计:

  • 开始的 1 步,即采用 TD-Error

  • 开始的 2 步,

  • 开始的 n 步,以此类推:

代表的步数比较小时,会导致高偏差, 比较大时,会导致高方差。GAE 引入了 参数进行偏差-方差的 trade-off。即:

则当 取值分别为 0,1 时:

GAE 原理与推导可参考 GAE 论文

整体流程图解如下:

下面从代码角度(只关注主体部分,或有疏漏),对照公式和流程图解梳理 PPO。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# ---------------------
# ------ 初始状态 ------
# ---------------------
# 根据上图中流程,初始化 4 个 model
actor_model = ...
ref_model = ...
critic_model = ...
reward_model = ...

# 输入的一个 batch 的 prompts
prompts = ... # (bsz, prompt_len)

# ----------------------
# ------ generate ------
# ----------------------
# generate 部分;rollout

# 不同长度的 prompt 和生成的 response 在一个 batch 内会被 padding 对齐

# prompts_completions 进行了 padding 操作,形如:
# prompts | responses
# x x x x x x o o o o|o o o o x x x x
# x x x o o o o o o o|o o o o o x x x
# x x x x x x o o o o|o o o x x x x x
# 对应的 attention_mask:
# 0 0 0 0 0 0 1 1 1 1|1 1 1 1 0 0 0 0
# 0 0 0 1 1 1 1 1 1 1|1 1 1 1 1 0 0 0
# 0 0 0 0 0 0 1 1 1 1|1 1 1 0 0 0 0 0

seq = actor_model.generate(prompts) # (bsz, prompt_len+completion_len+1)

# ----------------------
# ----- inference ------
# ----------------------
logits = actor_model(seq).logits # (bsz, prompt_len+completion_len+1, vocab_size)
ref_logits = ref_model(seq).logits # (bsz, prompt_len+completion_len+1, vocab_size)
# (bsz, prompt_len+completion_len)
log_probs = gather_log_probs(logits[:, :-1, :], seq[:, 1:])
# (bsz, prompt_len+completion_len)
ref_log_probs = gather_log_probs(ref_logits[:, :-1, :], seq[:, 1:])

# 当一条 prompt 生成完成后,可以计算这条轨迹的 reward。
reward_scores = reward_model(seq) # (bsz, )
# (bsz, prompt_len+completion_len)
values = critic_model(seq)[:, :-1]

# ---------------------------
# per token 计算 rewards,具体实现参见下方函数定义。
# (bsz, prompt_len+completion_len)
# ---------------------------
old_rewards = compute_rewards(
log_probs, # (bsz, prompt_len+completion_len)
ref_log_probs, # (bsz, prompt_len+completion_len)
reward_scores, # (bsz, )
)

# ---------------------
# 计算优势,具体实现参见下方函数定义。
# ---------------------
# (bsz, completion_len), (bsz, completion_len)
advantages, returns = get_advantages_and_returns(
values, old_rewards
)

# ----------------------
# -------- train -------
# ----------------------
# actor loss
actor_probs = actor_model(seq).logits
# (bsz, prompt_len+completion_len)
actor_log_probs = gather_log_probs(actor_probs[:, :-1, :], seq[:, 1:])
# loss 实现见下方定义
actor_loss = actor_loss_fn(
actor_log_probs[:, -completion_len:], # 作为 \pi_{\theta}
log_probs[:, -completion_len:], # 作为 \pi_{\theta_{old}}
advantages
)

# critic loss
critic_values = critic_model(seq)[:, :-1]
# loss 实现见下方定义
critic_loss = critic_loss_fn(
critic_values, # 作为 values
values, # 作为 old_values
returns
)

rewards 的计算函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 参考 DeepSpeed-Chat,为了便于理解作了调整
def compute_rewards(
log_probs,
ref_log_probs,
reward_scores
):
# kl_ctl default is 0.1
# (bsz, prompt_len+completion_len)
kl_divergence_estimate = -kl_ctl * (log_probs - ref_log_probs)
row_idx = torch.arange(batch_size).unsqueeze(1)

# 每条 seq 最后一个 token 的 index
end_idx = ... # (bsz, 1)
# --------------------
# 返回的 reward 是 per token 计算的,分别计算每个 response 中的每个 token 的
# KL 值,另外,**最后一个 token** 除了 KL 值外还叠加了奖励分数(reward_score)。
# --------------------
kl_divergence_estimate[row_idx, end_idx] += reward_scores.unsqueeze(1)
return kl_divergence_estimate

依照上述 GAE 公式,

位置的 token 对应的优势估计为:

可得: 。由此递推关系,可以采用动态规划计算生成的每个 token 对应的优势估计。结合下面的代码理解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 参考 DeepSpeed-Chat
def get_advantages_and_returns(values, rewards):
"""
Args:
values (batch_size, seq_length):
critic model 的输出。
rewards (batch_size, seq_length):
上述 compute_rewards 的返回值。
"""
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
# ----------------------
# 每个时间步,即 per token
# 从上述 A_t 和 A_{t+1} 的递推关系,从最后一个 token 的优势往前计算
# ----------------------
for t in reversed(range(start, length)):
# (batch_size, )
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + gamma * nextvalues - values[:, t]
# (batch_size, ) 上述递推关系的公式
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
# -----------------------------
# (batch_size, completion_length)
# 因为上述 for-loop 从最后 token 往前计算,所以 advantages_reversed
# 需要再 reversed 一次。
# -----------------------------
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns

actor model loss 的计算,对照上述 的公式,比较清晰。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 参考 DeepSpeed-Chat
def actor_loss_fn(logprobs, old_logprobs, advantages, mask):
"""
Args:
mask: 只会计算生成部分的 token,prompt 部分被 mask 掉。
"""
## policy gradient loss
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(
ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss

critic model loss 计算如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def critic_loss_fn(values, old_values, returns, mask):
"""
Args:
mask: 只会计算生成部分的 token,prompt 部分被 mask 掉。
"""
## value loss
values_clipped = torch.clamp(
values,
old_values - cliprange_value,
old_values + cliprange_value,
)

vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss

至此,PPO 主体部分梳理完成。

verl 架构设计

回顾

根据 RLHF 算法流程的回顾,其基本可以归纳为三个阶段:

  • stage 1 (generation): actor model 从一个 batch 的 prompts 生成 responses。
  • stage 2 (inference): critic model 从 prompts+responses 推理计算 values;ref model 从 prompts+responses 推理计算 tokens 的概率;reward model 从 prompts+responses 推理计算奖励值。
  • stage 3 (training): 计算 actor modelcritic model loss,计算梯度,更新模型参数。

不同阶段的 bound 不同,比如,actor model 的 training 阶段是 computation-bound,在 generation 阶段是 memory-bound。在 training 阶段为了高效利用 GPU 需要更大的 model-parallel size (PP, TP),而在 generation 阶段需要更大的 DP size 来提高 GPU 利用率,这就需要灵活地对 actor model 的 weights 进行 resharding,同时不因此引入更多的内存和通信开销。

既往的框架,采用 Single-Controller 或者 Multi-Controller 方式都各有优劣。

Single-Controller:

  • 由一个中心控制器驱动所有 worker;
  • 控制器向所有 worker 发送数据、计算指令;
  • worker 可以执行不同的程序(MPMD);
  • 优点是能在 inter-node (节点之间) 进行灵活的数据、指令传输;但是随之而来的缺点就是 intra-node (节点内部) 之间的执行控制流程开销很大。

Multi-Controller: (pytorch 等训练框架属于此类)

  • 无中心控制器,每个 worker 自驱;
  • worker 各自执行计算,由通信原语同步;
  • worker 运行相同程序(SPMD);
  • 优点是能够高效执行 intra-node 的分布式计算指令;缺点是 inter-node 的统一协调,数据分发开销很大。而且 multi-controller 从编程实现的扩展性上也比 Single-Controller 繁琐。

有鉴于此,verl 实际是采用 hybrid 的方式,结合了 Single-ControllerMulti-Controller 的优点。

verl 设计

verl 核心设计思想:

  • 通过 single-controller 实现 RL 数据流;中心控制器在 inter-component 以 RPC 交换传输数据。
  • 通过 multi-controller 实现各个 LLM component (Worker);对 component 的一组 SPMD 进程抽象,使得 single-controller 可以像单进程一样调用。intra-component 以 SPMD 工作,以通信原语通过 NCCL 交换传输数据。

verl 不同层级的 API 如图:

如上图所示,3DParallelWorker 类负责 model 的 intra-node3D parallel (DP,PP,TP) 计算。通过 WorkerGroup 对一组 Worker 进行 SPMD 的进程抽象。当需要收集和分发数据时,Driver 进程(Single-Controller)通过 Ray object store,从 WorkerGroup 收集数据,并分发给其他 WorkerGroup,这个过程可以通过 @register(transfer_mode=3D_PROTO) 实现。

而关于资源的调度和分配,WorkerGroup colocate 使用 Ray PlacementGroup 调度,通过设置 Worker 的 placement_group,placement_group_bundle_index,num_gpus 灵活控制 WorkerGroup 是否共享同一个 placement_group。

有了 hybrid 的架构设计,自然就可以很方便地在 single-controller 的层级进行 RLHF 的流程定义:

verl 不同阶段的 backend:

  • generation:
    • actor/rollout: vLLM / TensorRT-LLM / SGLang
  • training:
    • actor/critic: Megatron-LM / Deepspeed / FSDP
  • inference:
    • reward/reference: Megatron-LM / Deepspeed / FSDP

verl 源码解读

注:基于 verl==0.4.1.dev0 版本。

关于源码解读部分:

  • 核心逻辑的梳理;
  • GRPO 为主;
  • FSDP 为主。

引入

先通过一个例子引入来看核心类。借鉴了 verl 官方示例

示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import warnings
import os

import torch
from tensordict import TensorDict
import ray
from verl import DataProto
from verl.single_controller import Worker
from verl.single_controller.base.decorator import Dispatch, Execute, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import parallel_put


warnings.filterwarnings("ignore")

ray.init()


@ray.remote
class GPUWorker(Worker):

def __init__(self):
super().__init__()

def env_info(self):
return (
f"[{self.get_name()}] "
f"rank: {self.rank}, "
f"world_size: {self.world_size}, "
f"local_world_size: {os.environ.get('LOCAL_WORLD_SIZE')}, "
f"local_rank: {os.environ.get('LOCAL_RANK')}, "
f"master_addr: {os.environ.get('MASTER_ADDR')}, "
f"master_port: {os.environ.get('MASTER_PORT')}, "
f"cuda_visible_devices: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

def get_name(self):
ctx = ray.runtime_context.get_runtime_context()
return ctx.get_actor_name()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.ALL)
def add(self, x, y):
return f"[{self.get_name()}] {x + y}"

@register(dispatch_mode=Dispatch.DP_COMPUTE, execute_mode=Execute.ALL, blocking=False)
def dummy_compute(self, data):
for key in data.batch.keys():
data.batch[key] += self.rank
return data


if __name__ == "__main__":
resource_pool = RayResourcePool([2], use_gpu=True, max_colocate_count=1)
class_with_args = RayClassWithInitArgs(cls=GPUWorker)
worker_group = RayWorkerGroup(resource_pool, class_with_args)
worker_names = worker_group.worker_names
workers = worker_group.workers

print(ray.get([worker.env_info.remote() for worker in workers]))

print(worker_group.add(x=1, y=2))
ray.shutdown()

运行上述程序:CUDA_VISIBLE_DEVICES=0,2 python <this file>。基于上述代码,重点关注 RayResourcePoolRayClassWithInitArgsRayWorkerGroupregister 等类别和方法的功能与实现。

核心类解析

DataProto

类的 doc 说明已经很清晰了,不赘述。

1
2
3
4
5
6
7
8
9
10
11
12
13
# verl/protocol.py
@dataclass
class DataProto:
"""
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
same batch size should be put inside batch.
"""

batch: TensorDict = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)

RayResourcePool

  • 核心是提供一个灵活构造 Ray 的 placement_group 的功能接口。
1
2
3
4
5
6
# verl/single_controller/ray/base.py
class RayResourcePool(ResourcePool):

# 核心方法,构造 placement_group
def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"):
...

假设有 2 个 node,每个 node 上有 4 个 GPU,则:

1
2
resource_pool = RayResourcePool([4, 4], use_gpu=True, max_colocate_count=1)
pgs = resource_pool.get_placement_groups()

最终 pgs 会是包含了各自 CPU、GPU 资源的 bundles 组成的 placement_group。形如:

1
2
3
4
[
placement_group(bundles=[{"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}], strategy="STRICT_PACK", name=..., ),
placement_group(bundles=[{"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}, {"CPU": 1, "GPU": 1}], strategy="STRICT_PACK", name=..., ),
]

RayClassWithInitArgs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# verl/single_controller/ray/base.py
class RayClassWithInitArgs(ClassWithInitArgs):
...
def __call__(
self,
placement_group,
placement_group_bundle_idx,
use_gpu: bool = True,
num_gpus=1,
sharing_with=None,
device_name="cuda",
) -> Any:
"""Create and return a Ray actor with the configured options.

Args:
placement_group: Ray placement group for scheduling
placement_group_bundle_idx: Index of the bundle in the placement group
use_gpu: Whether to use GPU resources
num_gpus: Number of GPUs to allocate
sharing_with: Actor to share resources with
device_name: Device for training

Returns:
A Ray actor handle with the configured options
"""
...
# Ray Actor 的资源分配和初始化
self.cls.options(**options).remote(*self.args, **self.kwargs)

示例代码中的调用:

1
class_with_args = RayClassWithInitArgs(cls=GPUWorker)

GPUWorker 是一个 Ray Actor 类,上面 RayClassWithInitArgs 的初始化操作,并不会立即初始化 Ray Actor 即 GPUWorkerGPUWorker 的初始化发生在 RayClassWithInitArgs 的实例发生调用行为时。此时可以根据调用时的传入参数,将 GPUWorker (Ray Actor) 绑定到特定的 CPU、GPU 资源上。代码示例如下:

1
2
3
4
5
6
7
8
9
# -------------------------
# 调用 RayClassWithInitArgs::__call__ 函数
# 借用上述 2 node,每个 node 4 GPU 的例子,给 GPUWorker 分配在第 1 个 node 的第 1 个 bundle 的资源上,
# 同时初始化 `GPUWorker`。
# -------------------------
class_with_args(
placement_group=pgs[0],
placement_group_bundle_idx=0 # {"CPU": 1, "GPU": 1}
)

RayClassWithInitArgsRay Actor 的类提供了统一的资源分配接口和封装。

RayWorkerGroup

现在看看 RayWorkerGroup 的初始化过程中,涉及的核心逻辑。重点看:

  • _init_with_resource_pool
  • _bind_worker_method

RayWorkerGroup 初始化过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# verl/single_controller/ray/base.py
class RayWorkerGroup(WorkerGroup):

def __init__(
self,
resource_pool: RayResourcePool = None,
ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True,
name_prefix: str = None,
detached=False,
worker_names=None,
worker_handles: List[ray.actor.ActorHandle] = None,
ray_wait_register_center_timeout: int = 300,
device_name="cuda",
**kwargs,
) -> None:
"""
Args:
resource_pool: 上述 RayResourcePool 的 instance
ray_cls_with_init: 上述 RayClassWithInitArgs 的 instance
"""
super().__init__(resource_pool=resource_pool, **kwargs)
self.ray_cls_with_init = ray_cls_with_init
self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix
...
if self._is_init_with_detached_workers:
# resource_pool 如果为 None
self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles)
else:
# 如果提供了 resource_pool
self._init_with_resource_pool(
resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached
)

if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)

self.wg_dict = None
self.method_names = []

1. _init_with_resource_pool

接着重点看 _init_with_resource_pool 方法的核心部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# verl/single_controller/ray/base.py
class RayWorkerGroup(WorkerGroup):
...
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
...
# 参考上述 RayResourcePool 类的解析
pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)
...
rank = -1
# 第一个 node 上 GPU 数量
local_world_size = resource_pool.store[0]

# per node,node 层面循环
for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
# per device (GPU) 层面循环
for local_rank in range(local_world_size):
rank += 1

# we pass in environment variable at option so that Worker can use environment variable to set
# ---------------------------
# 作为 Ray actor 的环境变量
# 这些环境变量有些会用作分布式进程组的初始化参数(torch.distributed.init_process_group)
# ---------------------------
env_vars = {
"WORLD_SIZE": str(world_size),
"RANK": str(rank),
"WG_PREFIX": self.name_prefix,
"WG_BACKEND": "ray",
"RAY_LOCAL_WORLD_SIZE": str(local_world_size),
"RAY_LOCAL_RANK": str(local_rank),
}
if rank != 0:
env_vars["MASTER_ADDR"] = self._master_addr
env_vars["MASTER_PORT"] = self._master_port

import re

cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5

if self.profile_steps:
ray_cls_with_init.update_options(
{
"runtime_env": {
"env_vars": env_vars,
"nsight": self.worker_nsight_options,
},
"name": name,
}
)
else:
ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name})

if detached:
ray_cls_with_init.update_options({"lifetime": "detached"})

# create a worker
# --------------------------
# 参考上述 RayClassWithInitArgs 类的解析
# 初始化 Ray actor 并为其分配资源
# --------------------------
worker = ray_cls_with_init(
placement_group=pg,
placement_group_bundle_idx=local_rank,
use_gpu=use_gpu,
num_gpus=num_gpus,
device_name=self.device_name,
)
self._workers.append(worker)
self._worker_names.append(name)

结合示例代码:

1
2
3
4
5
6
# 1 node, 2 GPU
resource_pool = RayResourcePool([2], use_gpu=True, max_colocate_count=1)
class_with_args = RayClassWithInitArgs(cls=GPUWorker)
worker_group = RayWorkerGroup(resource_pool, class_with_args)
workers = worker_group.workers
print(ray.get([worker.env_info.remote() for worker in workers]))

经过 _init_with_resource_pool 后,GPUWorker 会分配给 2 个 GPU,并完成初始化。此时,Worker Group 的实例中有 2 个 GPUWorker 的实例。分别调用 2 个实例的 env_info 函数,输出为:

1
2
3
4
5
# 注意:运行时,指定的 CUDA_VISIBLE_DEVICES 为 "0,2"
[
'[zmZJ9cGPUWorker_0:0] rank: 0, world_size: 2, local_world_size: 1, local_rank: 0, master_addr: 172.17.0.3, master_port: 55185, cuda_visible_devices: 0',
'[zmZJ9cGPUWorker_0:1] rank: 1, world_size: 2, local_world_size: 1, local_rank: 0, master_addr: 172.17.0.3, master_port: 55185, cuda_visible_devices: 2'
]

如此,则 GPUWorker 在 2 个 GPU 上分别初始化了一个实例,每个实例分配了 {"CPU": 1, "GPU": 1} 的资源。

RayWorkerGroup 通过属性 _workers 列表管理着所有的 worker。但是如果需要调用 worker 列表中某个或者某些 worker 的方法,有没有更加统一简洁的实现?比如像上述 env_info 方法的调用,是通过显式地循环 _workers 列表调用的,且传输给不同 worker 的数据也需要挨个处理,这样既繁琐也缺乏灵活性。这就涉及到 _bind_worker_method 方法。

2. _bind_worker_method

在介绍 _bind_worker_method 方法之前先看看 register 装饰器函数的作用。

1
2
3
4
5
6
@ray.remote
class GPUWorker(Worker):
...
@register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.ALL)
def add(self, x, y):
return f"[{self.get_name()}] {x + y}"

register 会对被装饰的函数设置一个名为 {MAGIC_ATTR} 的属性,属性值为:{"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking, "materialize_futures": materialize_futures},具体到上述示例,则 GPUWorkeradd 函数会被设置一个属性 {MAGIC_ATTR},对应的值为:{"dispatch_mode": Dispatch.ONE_TO_ALL, "execute_mode": Execute.ALL, "blocking": True, "materialize_futures": True}

_bind_worker_method 方法,会对 Worker 中经过 @register 装饰的方法作特殊处理。具体看下面代码中的注释。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# verl/single_controller/base/worker_group.py
class WorkerGroup:
...
def _bind_worker_method(self, user_defined_cls, func_generator):
method_names = []
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
except Exception:
# if it is a property, it will fail because Class doesn't have instance property
continue

# --------------------
# 经过 @register 装饰过的函数会被设置 {MAGIC_ATTR} 属性
# --------------------
if hasattr(method, MAGIC_ATTR):
# this method is decorated by register
# --------------------
# {"dispatch_mode": ..., "execute_mode": ..., "blocking": ..., "materialize_futures": ...}
# --------------------
attribute = getattr(method, MAGIC_ATTR)
assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"

dispatch_mode = attribute["dispatch_mode"]
execute_mode = attribute["execute_mode"]
blocking = attribute["blocking"]

# -----------------------
# 根据 "dispatch_mode" 获取对应的
# 输入参数或者数据的分派函数
# 返回数据的收集函数。
#
# 相当于,多个 Worker 分布在不同的资源 bundle 上(如:GPU),
# 通过 dispatch_fn / collect_fn 可以在外层使用统一的接口自动进行
# 数据在不同 Worker 上的分发和回收聚合。
# -----------------------
# get dispatch fn
if isinstance(dispatch_mode, Dispatch):
# get default dispatch fn
fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
dispatch_fn = fn["dispatch_fn"]
collect_fn = fn["collect_fn"]
else:
assert isinstance(dispatch_mode, dict)
assert "dispatch_fn" in dispatch_mode
assert "collect_fn" in dispatch_mode
dispatch_fn = dispatch_mode["dispatch_fn"]
collect_fn = dispatch_mode["collect_fn"]

# -----------------------------
# 根据 "execute_mode" 获取具体的执行函数,
# 这些函数定义在 RayWorkerGroup 中
# -----------------------------
# get execute_fn_name
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
wg_execute_fn_name = execute_mode["execute_fn_name"]

# get execute_fn from string
try:
execute_fn = getattr(self, wg_execute_fn_name)
assert callable(execute_fn), "execute_fn must be callable"
except Exception:
print(f"execute_fn {wg_execute_fn_name} is invalid")
raise

# bind a new method to the RayWorkerGroup
# ----------------------
# 对 @register 装饰的函数注入 dispatch_fn collect_fn execute_fn 逻辑。
# 参考下方介绍。
# ----------------------
func = func_generator(
self,
method_name,
dispatch_fn=dispatch_fn,
collect_fn=collect_fn,
execute_fn=execute_fn,
blocking=blocking,
)

try:
# -------------------------
# 将 Worker 中使用 @register 装饰的函数
# 绑定到 WorkerGroup 中。
# 以上述示例代码为例,即将 `GPUWorker` 中的 `add` 和 `dummy_compute`
# 绑定到 `RayWorkerGroup` 上。
# -------------------------
setattr(self, method_name, func)
method_names.append(method_name)
except Exception as e:
raise ValueError(f"Fail to set method_name {method_name}") from e

return method_names

func_generator 定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# verl/single_controller/ray/base.py
def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
class Functor:
def __call__(this, *args, **kwargs):
args, kwargs = dispatch_fn(self, *args, **kwargs)
padding_count = kwargs.pop(_padding_size_key, 0)
output = execute_fn(method_name, *args, **kwargs)
if blocking:
output = ray.get(output)
output = collect_fn(self, output)
if padding_count > 0:
if isinstance(output, DataProto):
indices = [i for i in range(len(output))][:-padding_count]
output = output.select_idxs(indices)
elif isinstance(output, list):
output = output[:-padding_count]
return output

# use class type to pass the method_name to get a better observability
return type(method_name, (Functor,), {})()
  • 类似于装饰器的作用,在 method_name 对应的函数(即被 @register 装饰的函数)基础上,注入 dispatch_fn,execute_fn,collect_fn 的逻辑;
  • execute_fn(method_name, *args, **kwargs) 部分实际调用的 RayWorkerGroupexecute_* 相关函数;
    • "execute_mode": Execute.ALL 为例,调用 execute_all -> execute_all_async,循环 _workers 列表中的 Worker 实例,并调用实例函数名 method_name 对应的函数。

综上,RayWorkerGroup

  • 根据 resource_pool 分配资源给 Worker 并初始化之;
  • Workerregister 装饰的函数注入参数和数据的分派(dispatch_fn)、数据收集聚合(collect_fn)及不同的执行模式(execute_fn)。并将其同名函数绑定到 RayWorkerGroup 上。

如此,在 RayWorkerGroup 层就可以直接调用 Worker 中被 register 装饰的函数,进而执行不同资源上的 Worker 对应的函数。

可以参考示例代码中的 add 函数:

1
worker_group.add(x=1, y=2)

下图示意了这个流程:

Single-Controller & Multi-Controller

经由上述示例,可以看到,RayWorkerGroup(WorkerGroup) 其实可以对应为 Single-Controller,而 GPUWorker(Worker) 对应 Multi-Controller 中的 Worker。

WorkerGroup 以 single-controller 方式,分配资源 bundles 给 Worker,并以注册的方式实现在 WorkerGroup 层面灵活分派数据给 Worker,从 Worker 聚合结果数据等功能。分派数据可以通过 Ray object store 高效实现。

作为实际进行计算密集操作(forward,backward)的 Worker,通过 NCCL 通信(梯度同步等)。即大家熟悉的分布式训练。

WorkerGroup colocate

RLHF 中涉及多种 Worker,比如,actor modelcritic modelreference modelreward model。如果对每个每个类型的 model 都各自使用一个 WorkerGroup 管理显然不灵活,扩展性差。verl 以 colocated worker 的方式将不同的 model 进行统一的管理。下面直接切入 verl 中的实现。

RayPPOTrainer::init_workers() 中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# verl/trainer/ppo/ray_trainer.py
class RayPPOTrainer:
...
def init_workers(self):
...
# ------------------------------------
# resource_pool: ResourcePool;参考上述核心类的解析
# class_dict 形如:{
# "actor_rollout": RayClassWithInitArgs(ray.remote(<actor_rollout Worker>)),
# "critic": RayClassWithInitArgs(ray.remote(<critic Worker>)),
# "ref": RayClassWithInitArgs(ray.remote(<ref Worker>)),
# "rm": RayClassWithInitArgs(ray.remote(<rm Worker>))
# }
# RayClassWithInitArgs 参考上述核心类的解析
# ------------------------------------
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
# --------------------------------------
# 这个参考上述对 `RayWorkerGroup` 的解析,相对清晰。
# worker_dict_cls 同时包含了多个 role (actor_rollout,ref,critic,rm) 的 Worker。
# --------------------------------------
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
device_name=self.device_name,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)

核心在 create_colocated_worker_cls 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# verl/single_controller/ray/base.py
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
...
# TODO: create a class with customizable name
class WorkerDict(worker_cls):
def __init__(self):
super().__init__()
self.worker_dict = {}
for key, user_defined_cls in cls_dict.items():
# ---------------------
# key 为不同 model 的 role name,即 `actor_rollout` `critic` `ref` `rm`。
#
# 注意此时 user_defined_cls 是正常的 class,而非 ray actor class,
# 比如对于 actor_rollout,是 ActorRolloutRefWorker 类。
# ---------------------
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
# directly instantiate the class without remote
# in worker class, e.g. <verl.single_controller.base.worker.Worker>
# when DISABLE_WORKER_INIT == 1 it will return immediately
with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
self.worker_dict[key] = user_defined_cls(
*init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})
)

# now monkey-patch the methods from inner class to WorkerDict
for key, user_defined_cls in cls_dict.items():
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
_bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)

remote_cls = ray.remote(WorkerDict)
# 把 `WorkerDict` 包进 RayClassWithInitArgs 中,用作后续 `WorkerGroup` 的创建。
remote_cls = RayClassWithInitArgs(cls=remote_cls)
return remote_cls

_bind_workers_method_to_parent 函数做的事情是:

  • actor_rolloutcriticrefrm model 的 Worker 定义中,被 @register 装饰的函数绑定到 WorkerDict 类中。

比如,对于 actor_rolloutref Worker 对应的类:

1
2
3
4
5
6
# verl/workers/fsdp_workers.py
class ActorRolloutRefWorker(Worker, DistProfilerExtension):
...
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
...

经过 _bind_workers_method_to_parent 后,WorkerDict 会新 set 几个与 actor_rolloutref 相关的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
class WorkerDict(worker_cls):
...
# 对于 `actor_rollout` role
# 注意同时会将 {MAGIC_ATTR} set 给这个函数
# 实际就是 `actor_rollout` 的 init_model
def actor_rollout_init_model(self):
...

# 对于 `ref` role
# 注意同时会将 {MAGIC_ATTR} set 给这个函数
# 实际就是 `ref` 的 init_model
def ref_init_model(self):
...

最后,RayPPOTrainer::init_workers 中还有 spawn 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class RayWorkerGroup(WorkerGroup):
...
def spawn(self, prefix_set):
"""Spawn to a dictionary of worker groups, each with a subset of method with prefix.

Args:
prefix_set: Set of prefixes to create worker groups for

Returns:
Dictionary of worker groups keyed by prefix
"""
if self.fused_worker_used:
return self.spawn_fused(prefix_set)

def _rebind_actor_methods(worker_group, actor_name):
prefix: str = actor_name + "_"
for method_name in dir(worker_group):
if method_name.startswith(prefix):
# only valid when Python >= 3.9
original_method_name = method_name.removeprefix(prefix)
method = getattr(worker_group, method_name)
setattr(worker_group, original_method_name, method)

new_worker_group_dict = {}
for prefix in prefix_set:
new_worker_group = self.from_detached(
name_prefix=self.name_prefix,
worker_names=self._worker_names,
worker_handles=self._workers,
ray_cls_with_init=self.ray_cls_with_init,
profile_steps=self.profile_steps,
worker_nsight_options=self.worker_nsight_options,
)

_rebind_actor_methods(new_worker_group, prefix)
new_worker_group_dict[prefix] = new_worker_group
return new_worker_group_dict

这个函数的作用是将不同 role 进行分组,相当于逻辑上(实际上共享一份,即 spawn 的含义)相同的 role 放在在一个 WorkerGroup,然后将上面绑定的在 WorkerDict 上的函数,去掉 role name 前缀后,再绑定到 WorkerGroup 上。

这样,就可以在 WorkerGroup 层面直接调用不同 role 的 @register 装饰的函数了,且函数名和对应 role Worker 中原始的函数名相同。

下面以图来简单示意这个过程最终的结果。

上图展示了各个 PPO 中 各个 role 的 Worker 的 init_model 的函数。其他的 @register 装饰函数同理。

整体流程概览

经过上述核心类和方法的源码阅读,回到 PPO Trainer 的流程,其实就很清晰了。

另外需要注意的是,各个 PPO role 的 Worker (ActorRolloutRefWorker、CriticWorker、RewardModelWorker) 在初始化时,不会马上 build model,只是根据 ray worker group 分配的 resource pool 进行分布式进程组的初始化、device mesh (FSDP) 的创建。而 model 的真正 build (模型初始化、vLLM / pytorch backend 选择等等) 发生在调用 init_model 时。

PPO 各 role Worker 及涉及 @register 的方法(以 FSDP 为例)如下。

  • actor_rollout & ref
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# verl/workers/fsdp_workers.py
class ActorRolloutRefWorker(Worker, DistProfilerExtension):

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="red")
def update_actor(self, data: DataProto):
...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="red")
def generate_sequences(self, prompts: DataProto):
...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="blue")
def compute_log_prob(self, data: DataProto):
...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="olive")
def compute_ref_log_prob(self, data: DataProto):
...

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
...

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
...
  • critic

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    # verl/workers/fsdp_workers.py
    class CriticWorker(Worker, DistProfilerExtension):

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
    ...

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="cyan")
    def compute_values(self, data: DataProto):
    ...

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="pink")
    def update_critic(self, data: DataProto):
    ...

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    ...

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
    ...
  • rm

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    # verl/workers/fsdp_workers.py
    class RewardModelWorker(Worker, DistProfilerExtension):

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
    ...

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="brown")
    def compute_rm_score(self, data: DataProto):
    ...

Trainer 进行 fit() 时,即 PPO 训练流程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
actor_rollout_wg = ...
rm_wg = ...
ref_policy_wg = ...
critic_wg = ...
for prompts, pretrain_batch in dataloader:
# stage 1: generate responses
batch = actor_rollout_wg.generate_sequences(prompts)
# stage 2: prepare experience
batch = rm_wg.compute_rm_score(batch)
batch = actor_rollout_wg.compute_log_prob(batch)
batch = ref_policy_wg.compute_ref_log_prob(batch)
batch = critic_wg.compute_values(batch)
batch = compute_advantage(batch)
# stage 3: actor & critic training
critic_wg.update_critic(batch)
actor_rollout_wg.update_actor(batch)

参考

  • Nathan Lambert. Reinforcement Learning from Human Feedback: A short introduction to RLHF and post-training focused on language models. arxiv
  • Proximal Policy Optimization Algorithms. arxiv
  • Spinning Up in Deep RL. OpenAI. link
  • HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION. arxiv
  • DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. arxiv
  • OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework. arxiv
  • Secrets of RLHF in Large Language Models Part I: PPO. arxiv
  • HybridFlow: A Flexible and Efficient RLHF Framework. arxiv
  • Slide 1, Slide 2, Slide 3
  • DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales. github
  • HybridFlow veRL 原文浅析. link
  • 基于 Ray 的分离式架构:veRL、OpenRLHF 工程设计. 知乎