Basics of Reinforcement Learning - Understanding GRPO and Its Code Implementation (part 3)

GRPO

Principle

For each question (prompt) $q$, GRPO samples a group of $G$ responses $\{o_1, o_2, \dots, o_G\}$ from the old policy $\pi_{\theta_{old}}$, and then optimizes the policy model $\pi_\theta$ by maximizing the following objective:

$$
\mathcal{J}_{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q),\, \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O \mid q)}
\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}
\left\{
\min\!\left[
\frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})}\hat{A}_{i,t},\;
\operatorname{clip}\!\left(\frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})},\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_{i,t}
\right]
- \beta\, D_{KL}\left[\pi_\theta \,\|\, \pi_{ref}\right]
\right\}
$$

For the KL divergence $D_{KL}[\pi_\theta \| \pi_{ref}]$, the following unbiased estimator is used, which is guaranteed to be positive:

$$
D_{KL}\left[\pi_\theta \,\|\, \pi_{ref}\right]
= \frac{\pi_{ref}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
- \log\frac{\pi_{ref}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
- 1
$$
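
A quick numerical sanity check (toy log-probabilities, not from open-r1): the estimator has the form $e^x - x - 1$, which is non-negative for every $x$ and zero only at $x = 0$, i.e. where the two policies agree.

import torch

logp_ref = torch.tensor([-1.0, -2.0, -0.3])   # log pi_ref of the sampled tokens (toy values)
logp_pi = torch.tensor([-1.5, -1.0, -0.3])    # log pi_theta of the same tokens (toy values)

x = logp_ref - logp_pi
kl_est = torch.exp(x) - x - 1                 # per-token estimate, elementwise >= 0
print(kl_est)                                 # zero only where the two policies agree exactly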

The advantage is defined by normalizing each completion's reward $r_i$ within its group $\{r_1, r_2, \dots, r_G\}$:

$$
\hat{A}_{i,t} = \frac{r_i - \operatorname{mean}(\{r_1, r_2, \dots, r_G\})}{\operatorname{std}(\{r_1, r_2, \dots, r_G\})}
$$
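
A minimal sketch (toy rewards for a single prompt, not from open-r1) of this group normalization:

import torch

# One prompt, G = 4 sampled completions with scalar rewards (toy values).
rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])

# Group-relative advantage: center and scale each reward within its own group.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)   # completions better than the group average get positive advantages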

Code

The walkthrough below follows the GRPO implementation in the open-r1 codebase.

The following hyperparameters serve as the running example:

gradient_accumulation_steps: 4
max_steps: -1
num_generations: 4
per_device_train_batch_size: 4
reward_funcs:
- accuracy
- format
- tag_count
reward_weights:
- 1.0
- 1.0
- 1.0
gpus: 2 # 2 GPUs

With these hyperparameters, global_batch_size = gpus * per_device_train_batch_size = 8, and each prompt produces 4 completions, so each step covers 8 / 4 = 2 distinct prompts. In the comments below, B denotes the micro-batch size, i.e. the number of sequences assigned to each GPU.
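
For concreteness, the batch arithmetic under these settings (plain Python, variable names chosen for illustration):

gpus = 2
per_device_train_batch_size = 4   # sequences per GPU: the micro-batch B in the comments below
num_generations = 4               # completions sampled per prompt

global_batch_size = gpus * per_device_train_batch_size    # 8 completions per step
prompts_per_step = global_batch_size // num_generations   # 2 distinct prompts
print(global_batch_size, prompts_per_step)                # 8 2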

with torch.no_grad():
    # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
    # computation here and use per_token_logps.detach() instead.
    if self.num_iterations > 1:
        old_per_token_logps = self._get_per_token_logps(
            self.model, prompt_completion_ids, attention_mask, logits_to_keep
        )
    else:
        old_per_token_logps = None

    # -----------------------------
    # prompt_completion_ids: (B, P+C)
    # attention_mask: (B, P+C)
    # logits_to_keep: completion_ids_length
    # -----------------------------
    if self.beta == 0.0:
        ref_per_token_logps = None
    elif self.ref_model is not None:
        ref_per_token_logps = self._get_per_token_logps(
            self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
        )
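
A side note on the num_iterations == 1 shortcut: right after sampling, $\pi_{\theta_{old}} = \pi_\theta$, so the importance ratio is identically 1; using per_token_logps.detach() as the old log-probs keeps the ratio's value at 1 while still letting gradients flow through the numerator. A toy sketch (not from open-r1) illustrating why the detach trick preserves the policy-gradient signal:

import torch

logp = torch.tensor([-1.2, -0.7], requires_grad=True)   # stand-in for per_token_logps

ratio = torch.exp(logp - logp.detach())   # value is exactly 1, but still a function of logp
advantages = torch.tensor([0.5, -0.5])

loss = -(ratio * advantages).sum()
loss.backward()
print(ratio)       # tensor([1., 1.]) with a grad_fn attached
print(logp.grad)   # tensor([-0.5000,  0.5000]) == -advantages: the gradient signal survives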

_get_per_token_logps is implemented as follows:

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    # -----------------------------
    # input_ids: (B, P+C), like: (4, 1215)
    # attention_mask: (B, P+C), like: (4, 1215)
    # logits_to_keep: completion_ids_length, like: 1024
    # logits: (B, completion_ids_length+1, vocab_size), like: (4, 1025, 151936)
    # -----------------------------
    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

    input_ids = input_ids[:, -logits_to_keep:]
    # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
    # See https://github.com/huggingface/trl/issues/2770
    logits = logits[:, -logits_to_keep:]
    # Divide logits by sampling temperature.
    # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
    logits = logits / self.temperature

    per_token_logps = []
    for row_logits, row_labels in zip(logits, input_ids):
        # -----------------------------
        # (completion_ids_length, vocab_size), like: (1024, 151936)
        # -----------------------------
        row_logps = F.log_softmax(row_logits, dim=-1)
        # -----------------------------
        # (completion_ids_length,), like: (1024,)
        # -----------------------------
        row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(row_per_token_logps)
    # -----------------------------
    # (B, completion_ids_length), like: (4, 1024)
    # -----------------------------
    per_token_logps = torch.stack(per_token_logps)
    return per_token_logps
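
The key step in the loop is the gather: for each position it selects, out of the full vocabulary distribution, the log-probability of the token that was actually generated. Processing one row at a time keeps peak memory lower than a single log_softmax over the whole (B, completion_ids_length, vocab_size) tensor. A toy-sized sketch (vocab of 5, completion of 3 tokens; not from open-r1):

import torch
import torch.nn.functional as F

vocab_size, completion_len = 5, 3
row_logits = torch.randn(completion_len, vocab_size)   # one row of the (B, C, V) logits (toy values)
row_labels = torch.tensor([2, 0, 4])                   # completion token ids at each position

row_logps = F.log_softmax(row_logits, dim=-1)          # (completion_len, vocab_size)
per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)                           # torch.Size([3]): log p(o_t | prefix) per position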

Computing the advantages:

The implementation requires that global_batch_size % num_generations == 0, so the gathered rewards can be reshaped into groups of num_generations completions per prompt.

# -----------------------------
# rewards_per_func: (B, num_reward_func), like (4, 3): the current rank holds 4 sequences and there are 3 reward functions.
# After gather: rewards_per_func is (8, 3), with 2 GPUs.
# -----------------------------
rewards_per_func = gather(rewards_per_func)
# Apply weights to each reward function's output and sum
# -----------------------------
# (global_batch_size,) like (8,)
# -----------------------------
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

# Compute grouped-wise rewards
# (global_batch_size / num_generations,) like (2,)
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# (global_batch_size / num_generations,) like (2,)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

# Normalize the rewards to compute the advantages
# (global_batch_size,) like (8,)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# (global_batch_size,) like (8,)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
if self.scale_rewards:
    advantages = advantages / (std_grouped_rewards + 1e-4)

# Slice to keep only the local part of the data
process_slice = slice(
    self.accelerator.process_index * len(prompts),
    (self.accelerator.process_index + 1) * len(prompts),
)
# (B,) like (4,)
advantages = advantages[process_slice]
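
Putting the shapes together, a worked toy example (hypothetical reward values, not from open-r1) of the grouping, normalization, and per-rank slicing above, with global_batch_size = 8 and num_generations = 4:

import torch

num_generations = 4
process_index, local_batch = 0, 4   # pretend we are rank 0 holding 4 local sequences

# Gathered, weighted rewards for 2 prompts x 4 generations (toy values).
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,    # group for prompt 0
                        0.5, 0.5, 1.0, 0.0])   # group for prompt 1

mean_g = rewards.view(-1, num_generations).mean(dim=1)   # (2,): tensor([0.5000, 0.5000])
std_g = rewards.view(-1, num_generations).std(dim=1)     # (2,)

mean_g = mean_g.repeat_interleave(num_generations)       # (8,): each group's mean repeated 4x
std_g = std_g.repeat_interleave(num_generations)         # (8,)
advantages = (rewards - mean_g) / (std_g + 1e-4)         # (8,)

# Each rank then keeps only the slice corresponding to its own completions.
local_advantages = advantages[process_index * local_batch:(process_index + 1) * local_batch]
print(local_advantages.shape)                            # torch.Size([4])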

Computing the loss:

# (B, completion_ids_length), like: (4, 1024)
ref_per_token_logps = self._get_per_token_logps(ref_model, input_ids, attention_mask, logits_to_keep)
# (B, completion_ids_length), like: (4, 1024)
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

# self.beta default is 0.04
if self.beta != 0.0:
    # -----------------------------
    # The KL term $D_{KL}[\pi_\theta || \pi_{ref}]$ of the GRPO objective above:
    # the (estimated) KL between the policy model and the reference model.
    # -----------------------------
    per_token_kl = (
        torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    )

# Compute the loss
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation (see
# _generate_and_score_completions) and use per_token_logps.detach() instead.
# -----------------------------
# Corresponds to the part of the GRPO objective that combines $\pi_{\theta_{old}}$, $\pi_\theta$, and the advantages.
# -----------------------------
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

# -----------------------------
# Putting it together: compare with the full GRPO objective above.
# -----------------------------
if self.beta != 0.0:
    per_token_loss = per_token_loss + self.beta * per_token_kl

if self.loss_type == "grpo":
    # completion_mask masks the generated tokens, bounded by the EOS token
    loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
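
For the "grpo" loss type, each sequence's token losses are first averaged over its own valid completion tokens (given by completion_mask), and the per-sequence results are then averaged over the micro-batch. A shape-only toy sketch (not from open-r1) of that masked reduction:

import torch

B, C = 2, 5
per_token_loss = torch.randn(B, C)   # toy per-token losses
# Sequence 0 used all 5 completion tokens; sequence 1 hit EOS after 3.
completion_mask = torch.tensor([[1, 1, 1, 1, 1],
                                [1, 1, 1, 0, 0]], dtype=torch.float)

per_seq_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
loss = per_seq_loss.mean()               # scalar fed to backward()
print(per_seq_loss.shape, loss.shape)    # torch.Size([2]) torch.Size([])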