def apply_debiased_estimation(loss, timesteps, noise_scheduler):
    """Re-weight ``loss`` by ``1 / (SNR + 1)`` for each sampled timestep.

    The SNR of every entry in ``timesteps`` is read from
    ``noise_scheduler.all_snr`` and stacked into one tensor (one value per
    timestep), capped at 1000, and turned into a multiplicative weight.

    Args:
        loss: Tensor to be scaled; it is multiplied by the weight using
            normal broadcasting (presumably one entry per timestep --
            confirm against the caller).
        timesteps: Iterable of indices into ``noise_scheduler.all_snr``.
        noise_scheduler: Object exposing an indexable ``all_snr`` attribute
            whose elements are tensors.

    Returns:
        ``loss`` multiplied by ``1 / (min(SNR, 1000) + 1)``.
    """
    per_step_snr = [noise_scheduler.all_snr[step] for step in timesteps]
    snr = torch.stack(per_step_snr)

    # The SNR at timestep 0 is infinite, so bound it at 1000 to keep the
    # resulting weight strictly positive instead of collapsing to zero.
    capped_snr = torch.clamp(snr, max=1000)

    debias_weight = 1 / (capped_snr + 1)
    return debias_weight * loss