55# license information.
66# --------------------------------------------------------------------------
77# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8- # pyright: reportGeneralTypeIssues=false
98
109import calendar
10+ import decimal
1111import functools
1212import sys
1313import logging
1414import base64
1515import re
1616import copy
1717import typing
18- import email
18+ import enum
19+ import email .utils
1920from datetime import datetime , date , time , timedelta , timezone
2021from json import JSONEncoder
22+ from typing_extensions import Self
2123import isodate
2224from azure .core .exceptions import DeserializationError
2325from azure .core import CaseInsensitiveEnumMeta
3436__all__ = ["SdkJSONEncoder" , "Model" , "rest_field" , "rest_discriminator" ]
3537
3638TZ_UTC = timezone .utc
39+ _T = typing .TypeVar ("_T" )
3740
3841
3942def _timedelta_as_isostr (td : timedelta ) -> str :
@@ -144,6 +147,8 @@ def default(self, o): # pylint: disable=too-many-return-statements
144147 except TypeError :
145148 if isinstance (o , _Null ):
146149 return None
150+ if isinstance (o , decimal .Decimal ):
151+ return float (o )
147152 if isinstance (o , (bytes , bytearray )):
148153 return _serialize_bytes (o , self .format )
149154 try :
@@ -239,7 +244,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
239244 # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
240245 if isinstance (attr , date ):
241246 return attr
242- return isodate .parse_date (attr , defaultmonth = None , defaultday = None )
247+ return isodate .parse_date (attr , defaultmonth = None , defaultday = None ) # type: ignore
243248
244249
245250def _deserialize_time (attr : typing .Union [str , time ]) -> time :
@@ -275,6 +280,12 @@ def _deserialize_duration(attr):
275280 return isodate .parse_duration (attr )
276281
277282
283+ def _deserialize_decimal (attr ):
284+ if isinstance (attr , decimal .Decimal ):
285+ return attr
286+ return decimal .Decimal (str (attr ))
287+
288+
278289_DESERIALIZE_MAPPING = {
279290 datetime : _deserialize_datetime ,
280291 date : _deserialize_date ,
@@ -283,6 +294,7 @@ def _deserialize_duration(attr):
283294 bytearray : _deserialize_bytes ,
284295 timedelta : _deserialize_duration ,
285296 typing .Any : lambda x : x ,
297+ decimal .Decimal : _deserialize_decimal ,
286298}
287299
288300_DESERIALIZE_MAPPING_WITHFORMAT = {
@@ -373,8 +385,12 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
373385 except KeyError :
374386 return default
375387
376- @typing .overload # type: ignore
377- def pop (self , key : str ) -> typing .Any : # pylint: disable=no-member
388+ @typing .overload
389+ def pop (self , key : str ) -> typing .Any :
390+ ...
391+
392+ @typing .overload
393+ def pop (self , key : str , default : _T ) -> _T :
378394 ...
379395
380396 @typing .overload
@@ -395,8 +411,8 @@ def clear(self) -> None:
395411 def update (self , * args : typing .Any , ** kwargs : typing .Any ) -> None :
396412 self ._data .update (* args , ** kwargs )
397413
398- @typing .overload # type: ignore
399- def setdefault (self , key : str ) -> typing . Any :
414+ @typing .overload
415+ def setdefault (self , key : str , default : None = None ) -> None :
400416 ...
401417
402418 @typing .overload
@@ -434,6 +450,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
434450 return tuple (_serialize (x , format ) for x in o )
435451 if isinstance (o , (bytes , bytearray )):
436452 return _serialize_bytes (o , format )
453+ if isinstance (o , decimal .Decimal ):
454+ return float (o )
455+ if isinstance (o , enum .Enum ):
456+ return o .value
437457 try :
438458 # First try datetime.datetime
439459 return _serialize_datetime (o , format )
@@ -458,7 +478,13 @@ def _get_rest_field(
458478
459479
460480def _create_value (rf : typing .Optional ["_RestField" ], value : typing .Any ) -> typing .Any :
461- return _deserialize (rf ._type , value ) if (rf and rf ._is_model ) else _serialize (value , rf ._format if rf else None )
481+ if not rf :
482+ return _serialize (value , None )
483+ if rf ._is_multipart_file_input :
484+ return value
485+ if rf ._is_model :
486+ return _deserialize (rf ._type , value )
487+ return _serialize (value , rf ._format )
462488
463489
464490class Model (_MyMutableMapping ):
@@ -494,7 +520,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
494520 def copy (self ) -> "Model" :
495521 return Model (self .__dict__ )
496522
497- def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> "Model" : # pylint: disable=unused-argument
523+ def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> Self : # pylint: disable=unused-argument
498524 # we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
499525 mros = cls .__mro__ [:- 3 ][::- 1 ] # ignore model, dict, and object parents, and reverse the mro order
500526 attr_to_rest_field : typing .Dict [str , _RestField ] = { # map attribute name to rest_field property
@@ -536,7 +562,7 @@ def _deserialize(cls, data, exist_discriminators):
536562 exist_discriminators .append (discriminator )
537563 mapped_cls = cls .__mapping__ .get (
538564 data .get (discriminator ), cls
539- ) # pylint: disable=no-member
565+ ) # pyright: ignore # pylint: disable=no-member
540566 if mapped_cls == cls :
541567 return cls (data )
542568 return mapped_cls ._deserialize (data , exist_discriminators ) # pylint: disable=protected-access
@@ -553,20 +579,25 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
553579 if exclude_readonly :
554580 readonly_props = [p ._rest_name for p in self ._attr_to_rest_field .values () if _is_readonly (p )]
555581 for k , v in self .items ():
556- if exclude_readonly and k in readonly_props : # pyright: ignore[reportUnboundVariable]
582+ if exclude_readonly and k in readonly_props : # pyright: ignore
557583 continue
558- result [k ] = Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
584+ is_multipart_file_input = False
585+ try :
586+ is_multipart_file_input = next (rf for rf in self ._attr_to_rest_field .values () if rf ._rest_name == k )._is_multipart_file_input
587+ except StopIteration :
588+ pass
589+ result [k ] = v if is_multipart_file_input else Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
559590 return result
560591
561592 @staticmethod
562593 def _as_dict_value (v : typing .Any , exclude_readonly : bool = False ) -> typing .Any :
563594 if v is None or isinstance (v , _Null ):
564595 return None
565596 if isinstance (v , (list , tuple , set )):
566- return [
597+ return type ( v )(
567598 Model ._as_dict_value (x , exclude_readonly = exclude_readonly )
568599 for x in v
569- ]
600+ )
570601 if isinstance (v , dict ):
571602 return {
572603 dk : Model ._as_dict_value (dv , exclude_readonly = exclude_readonly )
@@ -607,29 +638,22 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
607638 return obj
608639 return _deserialize (model_deserializer , obj )
609640
610- return functools .partial (_deserialize_model , annotation )
641+ return functools .partial (_deserialize_model , annotation ) # pyright: ignore
611642 except Exception :
612643 pass
613644
614645 # is it a literal?
615646 try :
616- if sys .version_info >= (3 , 8 ):
617- from typing import (
618- Literal ,
619- ) # pylint: disable=no-name-in-module, ungrouped-imports
620- else :
621- from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports
622-
623- if annotation .__origin__ == Literal :
647+ if annotation .__origin__ is typing .Literal : # pyright: ignore
624648 return None
625649 except AttributeError :
626650 pass
627651
628652 # is it optional?
629653 try :
630- if any (a for a in annotation .__args__ if a == type (None )):
654+ if any (a for a in annotation .__args__ if a == type (None )): # pyright: ignore
631655 if_obj_deserializer = _get_deserialize_callable_from_annotation (
632- next (a for a in annotation .__args__ if a != type (None )), module , rf
656+ next (a for a in annotation .__args__ if a != type (None )), module , rf # pyright: ignore
633657 )
634658
635659 def _deserialize_with_optional (if_obj_deserializer : typing .Optional [typing .Callable ], obj ):
@@ -642,7 +666,13 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla
642666 pass
643667
644668 if getattr (annotation , "__origin__" , None ) is typing .Union :
645- deserializers = [_get_deserialize_callable_from_annotation (arg , module , rf ) for arg in annotation .__args__ ]
669+ # initial ordering is we make `string` the last deserialization option, because it is often them most generic
670+ deserializers = [
671+ _get_deserialize_callable_from_annotation (arg , module , rf )
672+ for arg in sorted (
673+ annotation .__args__ , key = lambda x : hasattr (x , "__name__" ) and x .__name__ == "str" # pyright: ignore
674+ )
675+ ]
646676
647677 def _deserialize_with_union (deserializers , obj ):
648678 for deserializer in deserializers :
@@ -655,32 +685,31 @@ def _deserialize_with_union(deserializers, obj):
655685 return functools .partial (_deserialize_with_union , deserializers )
656686
657687 try :
658- if annotation ._name == "Dict" :
659- key_deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [0 ], module , rf )
660- value_deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [1 ], module , rf )
688+ if annotation ._name == "Dict" : # pyright: ignore
689+ value_deserializer = _get_deserialize_callable_from_annotation (
690+ annotation .__args__ [1 ], module , rf # pyright: ignore
691+ )
661692
662693 def _deserialize_dict (
663- key_deserializer : typing .Optional [typing .Callable ],
664694 value_deserializer : typing .Optional [typing .Callable ],
665695 obj : typing .Dict [typing .Any , typing .Any ],
666696 ):
667697 if obj is None :
668698 return obj
669699 return {
670- _deserialize ( key_deserializer , k , module ) : _deserialize (value_deserializer , v , module )
700+ k : _deserialize (value_deserializer , v , module )
671701 for k , v in obj .items ()
672702 }
673703
674704 return functools .partial (
675705 _deserialize_dict ,
676- key_deserializer ,
677706 value_deserializer ,
678707 )
679708 except (AttributeError , IndexError ):
680709 pass
681710 try :
682- if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]:
683- if len (annotation .__args__ ) > 1 :
711+ if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]: # pyright: ignore
712+ if len (annotation .__args__ ) > 1 : # pyright: ignore
684713
685714 def _deserialize_multiple_sequence (
686715 entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
@@ -694,10 +723,12 @@ def _deserialize_multiple_sequence(
694723 )
695724
696725 entry_deserializers = [
697- _get_deserialize_callable_from_annotation (dt , module , rf ) for dt in annotation .__args__
726+ _get_deserialize_callable_from_annotation (dt , module , rf ) for dt in annotation .__args__ # pyright: ignore
698727 ]
699728 return functools .partial (_deserialize_multiple_sequence , entry_deserializers )
700- deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [0 ], module , rf )
729+ deserializer = _get_deserialize_callable_from_annotation (
730+ annotation .__args__ [0 ], module , rf # pyright: ignore
731+ )
701732
702733 def _deserialize_sequence (
703734 deserializer : typing .Optional [typing .Callable ],
@@ -712,27 +743,29 @@ def _deserialize_sequence(
712743 pass
713744
714745 def _deserialize_default (
715- annotation ,
716- deserializer_from_mapping ,
746+ deserializer ,
717747 obj ,
718748 ):
719749 if obj is None :
720750 return obj
721751 try :
722- return _deserialize_with_callable (annotation , obj )
752+ return _deserialize_with_callable (deserializer , obj )
723753 except Exception :
724754 pass
725- return _deserialize_with_callable ( deserializer_from_mapping , obj )
755+ return obj
726756
727- return functools .partial (_deserialize_default , annotation , get_deserializer (annotation , rf ))
757+ if get_deserializer (annotation , rf ):
758+ return functools .partial (_deserialize_default , get_deserializer (annotation , rf ))
759+
760+ return functools .partial (_deserialize_default , annotation )
728761
729762
730763def _deserialize_with_callable (
731764 deserializer : typing .Optional [typing .Callable [[typing .Any ], typing .Any ]],
732765 value : typing .Any ,
733766):
734767 try :
735- if value is None :
768+ if value is None or isinstance ( value , _Null ) :
736769 return None
737770 if deserializer is None :
738771 return value
@@ -760,7 +793,8 @@ def _deserialize(
760793 value = value .http_response .json ()
761794 if rf is None and format :
762795 rf = _RestField (format = format )
763- deserializer = _get_deserialize_callable_from_annotation (deserializer , module , rf )
796+ if not isinstance (deserializer , functools .partial ):
797+ deserializer = _get_deserialize_callable_from_annotation (deserializer , module , rf )
764798 return _deserialize_with_callable (deserializer , value )
765799
766800
@@ -774,6 +808,7 @@ def __init__(
774808 visibility : typing .Optional [typing .List [str ]] = None ,
775809 default : typing .Any = _UNSET ,
776810 format : typing .Optional [str ] = None ,
811+ is_multipart_file_input : bool = False ,
777812 ):
778813 self ._type = type
779814 self ._rest_name_input = name
@@ -783,6 +818,11 @@ def __init__(
783818 self ._is_model = False
784819 self ._default = default
785820 self ._format = format
821+ self ._is_multipart_file_input = is_multipart_file_input
822+
823+ @property
824+ def _class_type (self ) -> typing .Any :
825+ return getattr (self ._type , "args" , [None ])[0 ]
786826
787827 @property
788828 def _rest_name (self ) -> str :
@@ -828,8 +868,9 @@ def rest_field(
828868 visibility : typing .Optional [typing .List [str ]] = None ,
829869 default : typing .Any = _UNSET ,
830870 format : typing .Optional [str ] = None ,
871+ is_multipart_file_input : bool = False ,
831872) -> typing .Any :
832- return _RestField (name = name , type = type , visibility = visibility , default = default , format = format )
873+ return _RestField (name = name , type = type , visibility = visibility , default = default , format = format , is_multipart_file_input = is_multipart_file_input )
833874
834875
835876def rest_discriminator (
0 commit comments