@@ -312,7 +312,7 @@ def build_weighted_embedding_tensor(self,
312312 token_ids : torch .Tensor ,
313313 per_token_weights : torch .Tensor ,
314314 attention_mask : Optional [torch .Tensor ] = None ,
315- return_pooled : bool = False ,
315+ should_return_pooled : bool = False ,
316316 device : Optional [str ] = None ) -> torch .Tensor :
317317 """
318318 :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary
@@ -373,7 +373,7 @@ def build_weighted_embedding_tensor(self,
373373
374374 chunk_start_index += chunk_size
375375
376- if self . requires_pooled :
376+ if should_return_pooled :
377377 return weighted_z , pooled
378378
379379 return weighted_z
@@ -473,15 +473,15 @@ def __init__(self,
473473 textual_inversion_manager : BaseTextualInversionManager = None ,
474474 dtype_for_device_getter : Callable [[torch .device ], torch .dtype ] = lambda device : torch .float32 ,
475475 hidden_states_types : Union [str , List [str ]] = "final" ,
476- return_pooled : Union [str , List [bool ]] = False ,
476+ requires_pooled : Union [str , List [bool ]] = False ,
477477 ):
478478
479479 hidden_states_types = len (text_encoders ) * [hidden_states_types ] if not isinstance (hidden_states_types , (list , tuple )) else hidden_states_types
480- return_pooled = len (text_encoders ) * [return_pooled ] if not isinstance (return_pooled , (list , tuple )) else return_pooled
480+ requires_pooled = len (text_encoders ) * [requires_pooled ] if not isinstance (requires_pooled , (list , tuple )) else requires_pooled
481481
482482 self .embedding_providers = [
483483 EmbeddingsProvider (tokenizer , text_encoder , textual_inversion_manager , dtype_for_device_getter , hidden_states_type , pooled )
484- for tokenizer , text_encoder , hidden_states_type , pooled in zip (tokenizers , text_encoders , hidden_states_types , return_pooled )
484+ for tokenizer , text_encoder , hidden_states_type , pooled in zip (tokenizers , text_encoders , hidden_states_types , requires_pooled )
485485 ]
486486
487487 @property
0 commit comments