Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
set form_type on CustomFormSubmodel correctly
  • Loading branch information
kristapratico committed Sep 30, 2020
commit 58cbecb8207bc8fc2a89c9e0b26620d9b4a498b5
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def begin_training(self, training_files_url, use_training_labels, **kwargs):

def callback_v2_0(raw_response):
model = self._deserialize(self._generated_models.Model, raw_response)
return CustomFormModel._from_generated(model)
return CustomFormModel._from_generated(model, api_version=self.api_version)

def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
model = self._deserialize(self._generated_models.Model, raw_response)
return CustomFormModel._from_generated(model)
return CustomFormModel._from_generated(model, api_version=self.api_version)

cls = kwargs.pop("cls", None)
display_name = kwargs.pop("display_name", None)
Expand Down Expand Up @@ -309,7 +309,7 @@ def get_custom_model(self, model_id, **kwargs):
raise ValueError("model_id cannot be None or empty.")

response = self._client.get_custom_model(model_id=model_id, include_keys=True, error_map=error_map, **kwargs)
return CustomFormModel._from_generated(response)
return CustomFormModel._from_generated(response, api_version=self.api_version)

@distributed_trace
def get_copy_authorization(self, resource_id, resource_region, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,14 +617,16 @@ def __init__(self, **kwargs):
self.properties = kwargs.get("properties", None)

@classmethod
def _from_generated(cls, model):
def _from_generated(cls, model, api_version):
return cls(
model_id=model.model_info.model_id,
status=model.model_info.status,
training_started_on=model.model_info.created_date_time,
training_completed_on=model.model_info.last_updated_date_time,
submodels=CustomFormSubmodel._from_generated_unlabeled(model)
if model.keys else CustomFormSubmodel._from_generated_labeled(model),
if model.keys else CustomFormSubmodel._from_generated_labeled(
model, api_version, display_name=model.model_info.model_name
),
errors=FormRecognizerError._from_generated(model.train_result.errors)
if model.train_result else None,
training_documents=TrainingDocumentInfo._from_generated(model.train_result)
Expand Down Expand Up @@ -693,14 +695,21 @@ def _from_generated_unlabeled(cls, model):
]

@classmethod
def _from_generated_labeled(cls, model):
def _from_generated_labeled(cls, model, api_version, display_name):
if api_version == "2.0":
form_type = "form-" + model.model_info.model_id
elif display_name:
form_type = "custom:" + display_name
else:
form_type = "custom:" + model.model_info.model_id

return [
cls(
model_id=model.model_info.model_id,
accuracy=model.train_result.average_model_accuracy,
fields={field.field_name: CustomFormModelField._from_generated_labeled(field)
for field in model.train_result.fields} if model.train_result.fields else None,
form_type="form-" + model.model_info.model_id # FIXME: should match form_type in RecognizedForm
form_type=form_type
)
] if model.train_result else None

Expand All @@ -711,7 +720,7 @@ def _from_generated_composed(cls, model):
accuracy=train_result.average_model_accuracy,
fields={field.field_name: CustomFormModelField._from_generated_labeled(field)
for field in train_result.fields} if train_result.fields else None,
form_type="form-" + train_result.model_id, # FIXME: should match form_type in RecognizedForm
form_type="custom:" + train_result.model_id,
model_id=train_result.model_id
) for train_result in model.composed_train_results
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ async def begin_training(

def callback_v2_0(raw_response):
model = self._deserialize(self._generated_models.Model, raw_response)
return CustomFormModel._from_generated(model)
return CustomFormModel._from_generated(model, api_version=self.api_version)

def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
model = self._deserialize(self._generated_models.Model, raw_response)
return CustomFormModel._from_generated(model)
return CustomFormModel._from_generated(model, api_version=self.api_version)

cls = kwargs.pop("cls", None)
display_name = kwargs.pop("display_name", None)
Expand Down Expand Up @@ -323,7 +323,7 @@ async def get_custom_model(self, model_id: str, **kwargs: Any) -> CustomFormMode
error_map=error_map,
**kwargs
)
return CustomFormModel._from_generated(response)
return CustomFormModel._from_generated(response, api_version=self.api_version)

@distributed_trace_async
async def get_copy_authorization(
Expand Down