with torch.no_grad():
    # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
    # computation here and use per_token_logps.detach() instead.
    if self.num_iterations > 1:
        old_per_token_logps = self._get_per_token_logps(
            self.model, prompt_completion_ids, attention_mask, logits_to_keep
        )
    else:
        old_per_token_logps = None
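To see why the else branch is safe, here is a minimal sketch with toy tensors (the values and shapes are hypothetical, not the trainer's real inputs): when the policy has not been updated since sampling, the "old" log-probabilities are just the current ones without gradients, so the importance ratio is exactly 1 for every token while gradients still flow through per_token_logps.

import torch

# Toy per-token log-probabilities from the current policy (hypothetical values).
per_token_logps = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)

# With num_iterations == 1 the policy has not changed since sampling,
# so the "old" log-probs are simply the current ones, detached.
old_per_token_logps = per_token_logps.detach()

# The importance ratio is exactly 1 for every token, but it still
# carries gradients from per_token_logps.
ratio = torch.exp(per_token_logps - old_per_token_logps)
print(ratio)  # tensor([1., 1., 1.], grad_fn=<ExpBackward0>)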
# Normalize the rewards to compute the advantages
# (global_batch_size,) like (8,)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# (global_batch_size,) like (8,)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
if self.scale_rewards:
    advantages = advantages / (std_grouped_rewards + 1e-4)
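A small numeric sketch of this step, assuming two prompts with num_generations = 4 and made-up rewards: repeat_interleave expands each group's mean and std from shape (2,) to shape (8,), so every completion is normalized against its own group's statistics.

import torch

num_generations = 4  # assumed group size for this sketch

# Toy rewards: 2 prompts x 4 completions -> global batch of 8 (made-up values).
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5, 2.0, 2.0, 0.0, 0.0])

# One statistic per prompt group: shape (2,)
grouped = rewards.view(-1, num_generations)
mean_grouped_rewards = grouped.mean(dim=1)  # tensor([0.5000, 1.0000])
std_grouped_rewards = grouped.std(dim=1)    # tensor([0.4082, 1.1547])

# Broadcast each group's statistic to its 4 completions: shape (8,)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations, dim=0)

advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
print(advantages)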
# Slice to keep only the local part of the data
process_slice = slice(
    self.accelerator.process_index * len(prompts),
    (self.accelerator.process_index + 1) * len(prompts),
)
# (B,) like (4,)
advantages = advantages[process_slice]
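As a sketch of the slicing, assume 2 processes that each own 4 prompts of a global batch of 8 advantages (the process count and sizes here are illustrative, not fixed by the trainer): process 0 keeps indices 0..3 and process 1 keeps indices 4..7.

import torch

advantages = torch.arange(8.0)  # toy global advantages, shape (8,)
num_local_prompts = 4           # stand-in for len(prompts) on each process

for process_index in range(2):  # stand-in for self.accelerator.process_index
    process_slice = slice(
        process_index * num_local_prompts,
        (process_index + 1) * num_local_prompts,
    )
    # Process 0 keeps indices 0..3, process 1 keeps indices 4..7.
    print(process_index, advantages[process_slice])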