verl 解读 - ray 相关前置知识 (part1)

概述

verl 强化学习框架构建在分布式计算框架 Ray 之上,因此掌握 Ray 的基础知识是读懂 verl 代码的前提。

本文介绍 Ray Actors 的一些基础用法,主要内容包括:

  • 定义一个 Actor,并在其中初始化一个 torch model;
  • 在各 Actor 中初始化分布式进程组(torch.distributed.init_process_group);
  • 实现一个简易的 DP(Data Parallel)forward 和 backward 计算,并在 backward 过程中通过梯度 hook 在不同 GPU 间同步 grad。

示例代码阐释

完整程序如下,详细说明见代码中的注释。

import math
import os
import random
import socket
import traceback

import numpy as np
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port on this machine.

    Binding to port 0 lets the kernel pick a free ephemeral port. The socket is
    closed before returning, so the port is only *probably* still free: another
    process may take it before torch.distributed binds it. That is acceptable
    for a demo; robust code should retry on ``EADDRINUSE``.

    NOTE: the port is probed on the machine that calls this function, which is
    not necessarily the machine that will later listen on it.

    Returns:
        The port number chosen by the OS.
    """
    # `socket` is imported at module level; no local re-import needed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def set_random_seed(seed=42):
    """Seed every RNG this script uses: Python ``random``, NumPy and PyTorch.

    Each Ray actor is a separate process, so actors must call this themselves
    (``WorkerBase`` does it in its constructor); seeding the driver does not
    propagate to them.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and all CUDA devices; the
    # explicit CUDA call is kept for readers who expect it and is a no-op
    # without CUDA.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed the driver process before anything draws random numbers (the demo
# batch in `main` comes from torch.randn), then start Ray.
set_random_seed(seed=42)
ray.init()


class DummyAttn(nn.Module):
    """Minimal multi-head self-attention model used to demo data parallelism.

    Every trainable parameter gets a tensor hook that all-reduces (SUM) its
    gradient over the default process group during backward. After
    ``loss.backward()``, every rank therefore holds identical gradients.

    SUM (not mean) is deliberate: the workers in this script use
    ``loss = y.sum()``, so the full-batch gradient is exactly the sum of the
    per-shard gradients. If the loss is a per-sample mean, divide by the world
    size instead.

    Args:
        num_heads: Number of attention heads.
        head_dim: Size of each head.
        hidden_size: Input/output feature size.
    """

    def __init__(
        self,
        num_heads: int = 16,
        head_dim: int = 64,
        hidden_size: int = 1024,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_proj = nn.Linear(hidden_size, head_dim * num_heads, bias=False)
        self.k_proj = nn.Linear(hidden_size, head_dim * num_heads, bias=False)
        self.v_proj = nn.Linear(hidden_size, head_dim * num_heads, bias=False)
        self.out_proj = nn.Linear(head_dim * num_heads, hidden_size, bias=False)

        self._init_weights()
        # Attach the gradient-sync hook to each tensor with requires_grad=True.
        # Autograd calls it once that tensor's gradient has been computed.
        # Keep the handles so the hooks can be removed later.
        self._grad_sync_handles = self._register_grad_sync_hooks(self._allreduce_grads)

    def _init_weights(self):
        """Xavier-uniform init for Linear weights, zeros for Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _register_grad_sync_hooks(self, hook):
        """Register ``hook`` on every trainable parameter; return the handles.

        Deliberately NOT named ``register_backward_hook``: overriding that
        ``nn.Module`` method would silently change a PyTorch API (module
        backward hooks) for anyone calling it on this model.
        """
        return [p.register_hook(hook) for p in self.parameters() if p.requires_grad]

    def _allreduce_grads(self, grad: torch.Tensor) -> torch.Tensor:
        """Return ``grad`` summed across all ranks of the default group.

        Tensor hooks must not modify their argument in place, so the reduction
        runs on a clone that is handed back to autograd as the new gradient.
        Without an initialized process group (single-process use), the
        gradient is returned unchanged.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return grad
        synced = grad.clone()
        dist.all_reduce(synced, op=dist.ReduceOp.SUM)
        # For a mean-reduced loss: synced /= dist.get_world_size()
        return synced

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scaled dot-product self-attention (no mask, no dropout).

        Args:
            x: Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        bs, seq_len, _ = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # (bs, seq, heads*dim) -> (bs, heads, seq, dim)
        q = q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K is laid out as (bs, heads, dim, seq) so that q @ k is (bs, heads, seq, seq).
        k = k.view(bs, seq_len, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
        v = v.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn = q @ k / math.sqrt(self.head_dim)  # `@` is torch.matmul
        attn_scores = F.softmax(attn, dim=-1)
        y = attn_scores @ v  # (bs, heads, seq, dim)
        y = y.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        return self.out_proj(y)


class WorkerBase:
    """Plain class turned into a Ray actor via ``ray.remote(WorkerBase)``.

    You can equivalently decorate it with ``@ray.remote``.

    Each instance records its Ray/host identity. Unless ``temp_init`` is set,
    it also seeds its RNGs, joins the default torch.distributed process group
    and builds a ``DummyAttn`` replica on its GPU. The rendezvous reads the
    ``WORLD_SIZE``/``RANK``/``MASTER_ADDR``/``MASTER_PORT`` env vars that the
    caller injects through ``runtime_env``.

    Args:
        temp_init: If True, only collect identity/network info and skip
            seeding, process-group init and model creation.
    """

    def __init__(self, temp_init: bool = False):
        ctx = ray.get_runtime_context()
        self._node_id = ctx.get_node_id()
        self._actor_id = ctx.get_actor_id()
        self._task_id = ctx.get_task_id()
        self._job_id = ctx.get_job_id()
        self._hostname = socket.gethostname()
        # Use the address Ray itself uses for this node. gethostbyname(gethostname())
        # may resolve to a loopback address (e.g. 127.0.1.1 via /etc/hosts),
        # which other nodes cannot reach when used as MASTER_ADDR.
        self._ip_address = ray.util.get_node_ip_address()
        if temp_init:
            return

        # Every actor is its own process, so its RNGs must be seeded here too.
        self._set_seed(42)

        # Same as a regular PyTorch distributed setup. With no init_method,
        # init_process_group uses env:// and reads MASTER_ADDR/MASTER_PORT.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="cpu:gloo,cuda:nccl",
                world_size=int(os.getenv("WORLD_SIZE", "1")),
                rank=int(os.getenv("RANK", "0")),
            )

        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()

        self.model = DummyAttn()
        # With num_gpus=1, Ray restricts CUDA_VISIBLE_DEVICES for this actor to
        # the GPU it was assigned, so plain "cuda" refers to that single GPU.
        # E.g. if the script is launched with CUDA_VISIBLE_DEVICES=0,2,3, each
        # actor sees exactly one of 0/2/3.
        # NOTE(review): which physical GPU a given rank gets is chosen by Ray
        # and need not follow rank order — confirm if a fixed mapping matters.
        self.model.to("cuda")
        print(f"=> Rank {self._rank} init model")

    def get_actor_info(self):
        """Return the identity/network info captured in the constructor."""
        return {
            "node_id": self._node_id,
            "actor_id": self._actor_id,
            "task_id": self._task_id,
            "job_id": self._job_id,
            "hostname": self._hostname,
            "ip_address": self._ip_address,
        }

    def shutdown(self):
        """Destroy the default process group if one exists in this process.

        Safe to call on ``temp_init`` actors, which never create one.
        """
        if dist.is_initialized():
            dist.destroy_process_group()

    def _set_seed(self, seed: int = 42):
        set_random_seed(seed)

    def train_step(self, data):
        """Run one forward/backward pass on this rank's data shard.

        ``DummyAttn``'s hooks sum gradients across ranks during ``backward``.
        No optimizer step or ``zero_grad`` happens here, so gradients
        accumulate across repeated calls.

        Args:
            data: Input tensor for this rank; it is moved to the actor's GPU.

        Returns:
            The local scalar loss as a detached CPU tensor.
        """
        self.model.train()

        x = data.to("cuda")
        y = self.model(x)
        loss = y.sum()
        loss.backward()
        # Detach so only the value, not autograd metadata, travels through Ray.
        return loss.detach().cpu()

    def sample_grads(self):
        """Return ``(name, grad)`` for the first trainable parameter.

        ``grad`` is a detached CPU copy, or ``None`` if no backward pass has
        populated it yet. Returns ``None`` if the model has no trainable
        parameters.
        """
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                grad = None if p.grad is None else p.grad.detach().cpu()
                return name, grad
        return None


def main():
    """Launch one Ray actor per GPU, run one DP step and print synced grads.

    Steps:
      1. Reserve one {CPU: 1, GPU: 1} bundle per GPU in a STRICT_PACK
         placement group, so all bundles land on one node.
      2. Start a throwaway actor on bundle 0 to learn the node IP
         (MASTER_ADDR), then kill it.
      3. Start one actor per bundle with WORLD_SIZE/RANK/MASTER_* env vars.
      4. Split a batch across ranks, run ``train_step`` on each, and print
         every rank's gradient of the first parameter.
    """
    # Number of GPUs Ray can schedule. For a local `ray.init()` this should
    # reflect CUDA_VISIBLE_DEVICES. On a multi-node cluster it is the cluster
    # total, which STRICT_PACK can only satisfy if one node has that many GPUs.
    num_devices = int(ray.cluster_resources().get("GPU", 0))
    if num_devices < 1:
        raise RuntimeError("No GPUs available to Ray; this demo needs at least one.")

    # Placement-group scheduling, see:
    # https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
    pg = placement_group(
        [{"CPU": 1, "GPU": 1} for _ in range(num_devices)],
        strategy="STRICT_PACK",
        name="ray_actor_communication",
    )
    ray.get(pg.ready())
    print(f"=> Placement group {pg.id} is ready, num_devices: {num_devices}")

    # Plain class -> Ray actor class.
    Worker = ray.remote(WorkerBase)

    def _actor_options(bundle_index, env_vars):
        """Options pinning one actor to ``bundle_index`` with ``env_vars``."""
        return {
            "scheduling_strategy": PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_bundle_index=bundle_index,
            ),
            "runtime_env": {"env_vars": env_vars},
            "num_gpus": 1,
            "num_cpus": 1,
        }

    # --- Discover MASTER_ADDR / MASTER_PORT for torch.distributed ----------
    # A temporary actor on bundle 0 (where rank 0 will live) reports its node IP.
    print("========= Get network info ==========")
    temp_env = {
        "WORLD_SIZE": "1",
        "RANK": "0",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(find_free_port()),
    }
    temp_worker = Worker.options(**_actor_options(0, temp_env)).remote(temp_init=True)

    network_info = ray.get(temp_worker.get_actor_info.remote())
    master_addr = network_info["ip_address"]
    # NOTE(review): the port is probed on the driver's host, not on the rank-0
    # node; this is fine when both are the same machine.
    master_port = str(find_free_port())

    print(f"Using master: {master_addr}:{master_port}")
    # Free bundle 0's resources so rank 0 can be scheduled there.
    ray.kill(temp_worker)
    print("========= Terminate temporary worker ==========\n")

    # --- One actor per bundle/rank ------------------------------------------
    workers = []
    for rank in range(num_devices):
        env_vars = {
            "WORLD_SIZE": str(num_devices),
            "RANK": str(rank),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": master_port,
        }
        workers.append(Worker.options(**_actor_options(rank, env_vars)).remote())

    # --- Simulated data parallelism: one batch shard per rank --------------
    datas = torch.randn(4, 128, 1024)
    # tensor_split always returns exactly num_devices shards. chunk() can
    # return fewer (4 rows into 3 -> 2 chunks), which would leave a rank
    # without data and raise IndexError.
    shards = torch.tensor_split(datas, num_devices)
    # Submit every step before waiting: each backward blocks in all_reduce
    # until all ranks arrive. ray.get surfaces any failure instead of
    # silently dropping it.
    ray.get([worker.train_step.remote(shard) for worker, shard in zip(workers, shards)])

    # --- Inspect synced gradients -------------------------------------------
    grads = ray.get([worker.sample_grads.remote() for worker in workers])
    for rank, (name, grad) in enumerate(grads):
        print(f"=> Rank {rank} grad: {name}\n{grad}")

    # Tear down the process groups, then release the reserved resources.
    ray.get([worker.shutdown.remote() for worker in workers])
    remove_placement_group(pg)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        pass
    except Exception:
        # The traceback already includes the exception message.
        traceback.print_exc()
        # Report the failure to the shell instead of exiting with status 0.
        raise SystemExit(1)
    finally:
        ray.shutdown()

使用 2 GPUs 执行,即数据并行度为 2:PYTHONUNBUFFERED=1 CUDA_VISIBLE_DEVICES=0,1 python <this file>

结果为:

=> Rank 0 grad: q_proj.weight
tensor([[-3.8231, 0.4550, 5.1973, ..., -1.1057, 0.7960, 3.6364],
[-6.1473, -6.2373, -0.2407, ..., 2.4793, -4.6591, -3.9527],
[-1.2464, -1.6429, -0.6477, ..., 1.6034, 0.2802, 3.2735],
...,
[-1.1934, 0.6633, -4.8593, ..., -2.4207, 0.7207, -3.9471],
[ 4.1377, 10.6296, -2.8300, ..., -1.6472, 7.8439, -2.1861],
[ 0.5365, -1.7629, 4.4939, ..., -0.6042, -7.0833, -1.4912]])
=> Rank 1 grad: q_proj.weight
tensor([[-3.8231, 0.4550, 5.1973, ..., -1.1057, 0.7960, 3.6364],
[-6.1473, -6.2373, -0.2407, ..., 2.4793, -4.6591, -3.9527],
[-1.2464, -1.6429, -0.6477, ..., 1.6034, 0.2802, 3.2735],
...,
[-1.1934, 0.6633, -4.8593, ..., -2.4207, 0.7207, -3.9471],
[ 4.1377, 10.6296, -2.8300, ..., -1.6472, 7.8439, -2.1861],
[ 0.5365, -1.7629, 4.4939, ..., -0.6042, -7.0833, -1.4912]])

可以看到两张卡上的 model 梯度已经同步了。

然后使用 1 GPU 执行,即数据并行度为 1:PYTHONUNBUFFERED=1 CUDA_VISIBLE_DEVICES=0 python <this file>
结果为:

=> Rank 0 grad: q_proj.weight
tensor([[-3.8231, 0.4550, 5.1973, ..., -1.1057, 0.7960, 3.6364],
[-6.1473, -6.2373, -0.2407, ..., 2.4793, -4.6591, -3.9527],
[-1.2464, -1.6429, -0.6477, ..., 1.6034, 0.2802, 3.2735],
...,
[-1.1934, 0.6633, -4.8593, ..., -2.4207, 0.7207, -3.9471],
[ 4.1377, 10.6296, -2.8300, ..., -1.6472, 7.8439, -2.1861],
[ 0.5365, -1.7629, 4.4939, ..., -0.6042, -7.0833, -1.4912]])

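对比 GPU×1 和 GPU×2 的结果,二者完全一致。原因是 loss 取的是 y.sum(),全 batch 的梯度恰好等于各数据分片梯度之和,而 hook 中用 SUM 做 all_reduce 正是把各卡的分片梯度相加,由此验证了 DP 计算和梯度同步无误。(若 loss 改为 mean,且各分片大小相同,则需在 all_reduce 之后再除以 world_size。)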

参考

  • Ray 官方文档(Ray docs):https://docs.ray.io/
  • HybridFlow: A Flexible and Efficient RLHF Framework. arXiv:2409.19256, 2024.