Skip to content

Commit 36a587a

Browse files
yiyixuxuyiyixuxu
authored andcommitted
Fix a bug in add_noise function (huggingface#6085)
* fix * copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 8c558ac commit 36a587a

File tree

5 files changed

+50
-5
lines changed

5 files changed

+50
-5
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,16 @@ def add_noise(
734734
schedule_timesteps = self.timesteps.to(original_samples.device)
735735
timesteps = timesteps.to(original_samples.device)
736736

737-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
737+
step_indices = []
738+
for timestep in timesteps:
739+
index_candidates = (schedule_timesteps == timestep).nonzero()
740+
if len(index_candidates) == 0:
741+
step_index = len(schedule_timesteps) - 1
742+
elif len(index_candidates) > 1:
743+
step_index = index_candidates[1].item()
744+
else:
745+
step_index = index_candidates[0].item()
746+
step_indices.append(step_index)
738747

739748
sigma = sigmas[step_indices].flatten()
740749
while len(sigma.shape) < len(original_samples.shape):

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,16 @@ def add_noise(
896896
schedule_timesteps = self.timesteps.to(original_samples.device)
897897
timesteps = timesteps.to(original_samples.device)
898898

899-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
899+
step_indices = []
900+
for timestep in timesteps:
901+
index_candidates = (schedule_timesteps == timestep).nonzero()
902+
if len(index_candidates) == 0:
903+
step_index = len(schedule_timesteps) - 1
904+
elif len(index_candidates) > 1:
905+
step_index = index_candidates[1].item()
906+
else:
907+
step_index = index_candidates[0].item()
908+
step_indices.append(step_index)
900909

901910
sigma = sigmas[step_indices].flatten()
902911
while len(sigma.shape) < len(original_samples.shape):

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,16 @@ def add_noise(
891891
schedule_timesteps = self.timesteps.to(original_samples.device)
892892
timesteps = timesteps.to(original_samples.device)
893893

894-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
894+
step_indices = []
895+
for timestep in timesteps:
896+
index_candidates = (schedule_timesteps == timestep).nonzero()
897+
if len(index_candidates) == 0:
898+
step_index = len(schedule_timesteps) - 1
899+
elif len(index_candidates) > 1:
900+
step_index = index_candidates[1].item()
901+
else:
902+
step_index = index_candidates[0].item()
903+
step_indices.append(step_index)
895904

896905
sigma = sigmas[step_indices].flatten()
897906
while len(sigma.shape) < len(original_samples.shape):

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,16 @@ def add_noise(
897897
schedule_timesteps = self.timesteps.to(original_samples.device)
898898
timesteps = timesteps.to(original_samples.device)
899899

900-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
900+
step_indices = []
901+
for timestep in timesteps:
902+
index_candidates = (schedule_timesteps == timestep).nonzero()
903+
if len(index_candidates) == 0:
904+
step_index = len(schedule_timesteps) - 1
905+
elif len(index_candidates) > 1:
906+
step_index = index_candidates[1].item()
907+
else:
908+
step_index = index_candidates[0].item()
909+
step_indices.append(step_index)
901910

902911
sigma = sigmas[step_indices].flatten()
903912
while len(sigma.shape) < len(original_samples.shape):

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,16 @@ def add_noise(
828828
schedule_timesteps = self.timesteps.to(original_samples.device)
829829
timesteps = timesteps.to(original_samples.device)
830830

831-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
831+
step_indices = []
832+
for timestep in timesteps:
833+
index_candidates = (schedule_timesteps == timestep).nonzero()
834+
if len(index_candidates) == 0:
835+
step_index = len(schedule_timesteps) - 1
836+
elif len(index_candidates) > 1:
837+
step_index = index_candidates[1].item()
838+
else:
839+
step_index = index_candidates[0].item()
840+
step_indices.append(step_index)
832841

833842
sigma = sigmas[step_indices].flatten()
834843
while len(sigma.shape) < len(original_samples.shape):

0 commit comments

Comments
 (0)