The Overfitting Test
Example of a deliberately simple task where DPO fails:
DPO loss:
losses = -F.logsigmoid(self.beta * (policy_chosen_logps - policy_rejected_logps - ref_chosen_logps + ref_rejected_logps))
SFT loss:
losses = -policy_chosen_logps
(I will talk about KTO later. For now, focus on DPO and IPO)
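For concreteness, a minimal runnable sketch of the two losses side by side (my own framing, not the original code: the *_logps tensors are assumed to hold per-sequence summed log-probs, and beta=0.1 is an arbitrary default):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Margin of the policy's implicit reward over the reference model's.
    logits = (policy_chosen_logps - policy_rejected_logps
              - ref_chosen_logps + ref_rejected_logps)
    return -F.logsigmoid(beta * logits)

def sft_loss(policy_chosen_logps):
    # Plain negative log-likelihood of the chosen completion.
    return -policy_chosen_logps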
Overfits quickly:
losses = -F.logsigmoid(policy_chosen_logps)
Overfits slowly, with a ceiling because of the reference model (as expected): since policy_chosen_logps <= 0, the sigmoid's argument is capped at -self.beta * ref_chosen_logps, so the loss cannot reach zero
losses = -F.logsigmoid(self.beta * (policy_chosen_logps - ref_chosen_logps))
Fails to overfit:
losses = -policy_chosen_logps + policy_rejected_logps
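These behaviors are easy to reproduce in a toy rig (a hypothetical stand-in for the real model, not the original experiment: a tiny next-token model trained on one hand-made preference pair):

import torch
import torch.nn.functional as F

def run(loss_fn, steps=500, lr=1e-2, vocab=16, dim=32):
    torch.manual_seed(0)
    emb = torch.nn.Embedding(vocab, dim)
    head = torch.nn.Linear(dim, vocab)

    def seq_logps(tokens):
        # Summed log-prob of a token sequence under the toy model.
        logps = F.log_softmax(head(emb(tokens[:-1])), dim=-1)
        return logps.gather(-1, tokens[1:, None]).squeeze(-1).sum()

    chosen = torch.tensor([1, 2, 3, 4])
    rejected = torch.tensor([1, 2, 5, 6])
    opt = torch.optim.Adam(list(emb.parameters()) + list(head.parameters()), lr=lr)
    for _ in range(steps):
        loss = loss_fn(seq_logps(chosen), seq_logps(rejected))
        opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

print(run(lambda c, r: -F.logsigmoid(c)))  # floors near log(2) as prob(chosen) -> 1
print(run(lambda c, r: -c + r))            # unbounded below: minimized by sinking the rejected log-prob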
DPO and IPO
logits = policy_chosen_logps - policy_rejected_logps - reference_chosen_logps + reference_rejected_logps
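Both methods consume the same logits and differ only in the loss head; a sketch of both (the IPO head is the standard squared loss with target 1/(2*beta); beta=0.1 is an arbitrary default, not from the slide):

import torch
import torch.nn.functional as F

def dpo_and_ipo_losses(logits: torch.Tensor, beta: float = 0.1):
    # DPO: logistic loss, which keeps rewarding an ever-larger margin.
    dpo = -F.logsigmoid(beta * logits)
    # IPO: squared loss regressing the margin onto 1 / (2 * beta),
    # so the optimum is finite instead of unbounded.
    ipo = (logits - 1.0 / (2.0 * beta)) ** 2
    return dpo, ipo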
KTO
term1 = sigmoid(policy_chosen_logps - ref_chosen_logps - (policy_rejected_logps - ref_rejected_logps).clamp(min=0))
term2 = sigmoid((policy_chosen_logps - ref_chosen_logps).clamp(min=0) - (policy_rejected_logps - ref_rejected_logps))
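How these terms typically fold into a loss, as a sketch following common implementations such as TRL's kto_pair variant (the 1 - sigmoid form and the beta scaling are assumptions beyond the two lines above):

import torch

def kto_pair_losses(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Each side is anchored on the other side's log-ratio, clamped at zero
    # so the anchor never drops below the reference model.
    term1 = 1 - torch.sigmoid(beta * (chosen_logratios - rejected_logratios.clamp(min=0)))
    term2 = 1 - torch.sigmoid(beta * (chosen_logratios.clamp(min=0) - rejected_logratios))
    return torch.cat((term1, term2), dim=0)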
- Possible that all the “overfitting” in DPO is a manifestation of the degenerate solution caused by the log(prob(rejected_tokens)) term
- and that all the DPO derivatives and the various “constraining” methods are band-aids to prevent it
- when in fact the solution should be to compute the rejected log-probs differently
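A quick numeric check of that claim (hypothetical numbers): the DPO loss can be driven essentially to zero purely by sinking the rejected log-prob, with the chosen completion never becoming more likely.

import torch
import torch.nn.functional as F

beta = 0.1
chosen_logp, ref_chosen_logp, ref_rejected_logp = -20.0, -20.0, -25.0
for rejected_logp in [-25.0, -50.0, -100.0, -500.0]:
    logits = torch.tensor(chosen_logp - rejected_logp - ref_chosen_logp + ref_rejected_logp)
    # Loss falls toward 0 even though chosen_logp never moves.
    print(rejected_logp, (-F.logsigmoid(beta * logits)).item())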
What’s wrong with the DPO math that leads to this?
A: The Bradley-Terry model:
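In the standard form used by the DPO derivation:

p^*(y_w \succ y_l \mid x)
  = \sigma\big(r^*(x, y_w) - r^*(x, y_l)\big)
  = \frac{\exp\big(r^*(x, y_w)\big)}{\exp\big(r^*(x, y_w)\big) + \exp\big(r^*(x, y_l)\big)}

Only the reward difference enters, so the objective is equally well satisfied by raising the chosen log-prob or by sinking the rejected one without bound; the latter is exactly the degenerate direction above.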
Observations and Open questions