File tree Expand file tree Collapse file tree 5 files changed +50
-5
lines changed Expand file tree Collapse file tree 5 files changed +50
-5
lines changed Original file line number Diff line number Diff line change @@ -734,7 +734,16 @@ def add_noise(
734
734
schedule_timesteps = self .timesteps .to (original_samples .device )
735
735
timesteps = timesteps .to (original_samples .device )
736
736
737
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
737
+ step_indices = []
738
+ for timestep in timesteps :
739
+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
740
+ if len (index_candidates ) == 0 :
741
+ step_index = len (schedule_timesteps ) - 1
742
+ elif len (index_candidates ) > 1 :
743
+ step_index = index_candidates [1 ].item ()
744
+ else :
745
+ step_index = index_candidates [0 ].item ()
746
+ step_indices .append (step_index )
738
747
739
748
sigma = sigmas [step_indices ].flatten ()
740
749
while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -896,7 +896,16 @@ def add_noise(
896
896
schedule_timesteps = self .timesteps .to (original_samples .device )
897
897
timesteps = timesteps .to (original_samples .device )
898
898
899
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
899
+ step_indices = []
900
+ for timestep in timesteps :
901
+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
902
+ if len (index_candidates ) == 0 :
903
+ step_index = len (schedule_timesteps ) - 1
904
+ elif len (index_candidates ) > 1 :
905
+ step_index = index_candidates [1 ].item ()
906
+ else :
907
+ step_index = index_candidates [0 ].item ()
908
+ step_indices .append (step_index )
900
909
901
910
sigma = sigmas [step_indices ].flatten ()
902
911
while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -891,7 +891,16 @@ def add_noise(
891
891
schedule_timesteps = self .timesteps .to (original_samples .device )
892
892
timesteps = timesteps .to (original_samples .device )
893
893
894
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
894
+ step_indices = []
895
+ for timestep in timesteps :
896
+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
897
+ if len (index_candidates ) == 0 :
898
+ step_index = len (schedule_timesteps ) - 1
899
+ elif len (index_candidates ) > 1 :
900
+ step_index = index_candidates [1 ].item ()
901
+ else :
902
+ step_index = index_candidates [0 ].item ()
903
+ step_indices .append (step_index )
895
904
896
905
sigma = sigmas [step_indices ].flatten ()
897
906
while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -897,7 +897,16 @@ def add_noise(
897
897
schedule_timesteps = self .timesteps .to (original_samples .device )
898
898
timesteps = timesteps .to (original_samples .device )
899
899
900
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
900
+ step_indices = []
901
+ for timestep in timesteps :
902
+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
903
+ if len (index_candidates ) == 0 :
904
+ step_index = len (schedule_timesteps ) - 1
905
+ elif len (index_candidates ) > 1 :
906
+ step_index = index_candidates [1 ].item ()
907
+ else :
908
+ step_index = index_candidates [0 ].item ()
909
+ step_indices .append (step_index )
901
910
902
911
sigma = sigmas [step_indices ].flatten ()
903
912
while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -828,7 +828,16 @@ def add_noise(
828
828
schedule_timesteps = self .timesteps .to (original_samples .device )
829
829
timesteps = timesteps .to (original_samples .device )
830
830
831
- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
831
+ step_indices = []
832
+ for timestep in timesteps :
833
+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
834
+ if len (index_candidates ) == 0 :
835
+ step_index = len (schedule_timesteps ) - 1
836
+ elif len (index_candidates ) > 1 :
837
+ step_index = index_candidates [1 ].item ()
838
+ else :
839
+ step_index = index_candidates [0 ].item ()
840
+ step_indices .append (step_index )
832
841
833
842
sigma = sigmas [step_indices ].flatten ()
834
843
while len (sigma .shape ) < len (original_samples .shape ):
You can’t perform that action at this time.
0 commit comments