@@ -648,16 +648,13 @@ def custom_forward(*inputs):
                     return custom_forward

                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
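Across these hunks the change is uniform: the outer torch.utils.checkpoint.checkpoint(create_custom_forward(attn, return_dict=False), ...) wrapper is dropped and the attention module is called directly with keyword arguments. The reentrant checkpoint path only forwards positional arguments, which is why the removed call sites pad with None placeholders for timestep and class_labels; the direct call can use named arguments instead. A minimal sketch of the pattern this presumably relies on, namely a module that checkpoints its own sub-blocks when gradient_checkpointing is set, so wrapping it again at the call site would be redundant (CheckpointedTransformer and _Block are illustrative names, not code from this diff):

import torch
import torch.utils.checkpoint
from torch import nn


class _Block(nn.Module):
    # Toy stand-in for a transformer block.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None):
        return self.proj(hidden_states)


class CheckpointedTransformer(nn.Module):
    # Illustrative module that applies gradient checkpointing internally.
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks
        self.gradient_checkpointing = False  # toggled by the parent model

    def forward(self, hidden_states, encoder_hidden_states=None, return_dict=True):
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                # Recompute this block's activations during backward to save memory,
                # so callers no longer need to wrap the whole module themselves.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block, hidden_states, encoder_hidden_states, use_reentrant=False
                )
            else:
                hidden_states = block(hidden_states, encoder_hidden_states)
        if not return_dict:
            # Tuple return mirrors the `[0]` indexing at the call sites above.
            return (hidden_states,)
        return {"sample": hidden_states}


model = CheckpointedTransformer(nn.ModuleList([_Block(32) for _ in range(2)]))
model.gradient_checkpointing = True
model.train()
sample = model(torch.randn(2, 8, 32, requires_grad=True), return_dict=False)[0]
sample.sum().backward()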
@@ -1035,16 +1032,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1711,13 +1705,12 @@ def custom_forward(*inputs):
                     return custom_forward

                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)

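In the hunk above (and the matching one around line 2872) the module being called is a plain attention layer rather than a transformer wrapper: it returns the hidden-states tensor directly, so the old `[0]` indexing disappears, and the remaining processor arguments are expanded inline with **cross_attention_kwargs. A self-contained sketch of that calling convention (PlainAttention is an illustrative stand-in, not the class used in this file):

import torch
from torch import nn


class PlainAttention(nn.Module):
    # Illustrative attention layer that returns a tensor directly.
    def __init__(self, dim, heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # Self-attention when no encoder states are given, cross-attention otherwise.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        out, _ = self.mha(hidden_states, context, context, key_padding_mask=attention_mask)
        return out  # plain tensor, hence no `[0]` at the call site


attn = PlainAttention(dim=64)
hidden_states = torch.randn(2, 16, 64)
mask = None
cross_attention_kwargs = {}  # any extra processor kwargs get expanded inline
hidden_states = attn(
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=mask,
    **cross_attention_kwargs,
)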
@@ -1912,15 +1905,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2173,16 +2164,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2872,13 +2860,12 @@ def custom_forward(*inputs):
                     return custom_forward

                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)

@@ -3094,16 +3081,14 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
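From the caller's side nothing about enabling checkpointing changes; the rewiring is only internal to the blocks. A short usage sketch, assuming the diffusers library's ModelMixin API (the model id is just an example):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing on the sub-blocks
unet.train()  # checkpointing only takes effect during training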