Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests for composed model and repr
  • Loading branch information
kristapratico committed Sep 25, 2020
commit dc7fe098e3774c0f8360a9585a8960645042aa1a
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def __init__(self, **kwargs):
self.form_type_confidence = kwargs.get('form_type_confidence', None)

def __repr__(self):
return "RecognizedForm(form_type={}, fields={}, page_range={}, pages={}, form_type_confidence={}," \
return "RecognizedForm(form_type={}, fields={}, page_range={}, pages={}, form_type_confidence={}, " \
"model_id={})" \
.format(
self.form_type,
Expand Down Expand Up @@ -717,12 +717,12 @@ def _from_generated_composed(cls, model):
]

def __repr__(self):
return "CustomFormSubmodel(accuracy={}, fields={}, form_type={}, model_id={})" \
return "CustomFormSubmodel(accuracy={}, model_id={}, fields={}, form_type={})" \
.format(
self.accuracy,
self.model_id,
repr(self.fields),
self.form_type,
self.model_id
)[:1024]


Expand Down Expand Up @@ -879,13 +879,19 @@ def __init__(self, **kwargs):
def _from_generated(cls, model, model_id=None):
if model.status == "succeeded": # map copy status to model status
model.status = "ready"

properties, display_name = None, None
if hasattr(model, "attributes"):
properties = CustomFormModelProperties._from_generated(model)
if hasattr(model, "model_name"):
display_name = model.model_name
return cls(
model_id=model_id if model_id else model.model_id,
status=model.status,
training_started_on=model.created_date_time,
training_completed_on=model.last_updated_date_time,
properties=CustomFormModelProperties._from_generated(model),
display_name=model.model_name
properties=properties,
display_name=display_name
)

def __repr__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_custom_form_labeled(self, client, container_sas_url):
poller = fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG)
form = poller.result()

self.assertEqual(form[0].form_type, "form-"+model.model_id)
self.assertEqual(form[0].form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form[0].pages)
for label, field in form[0].fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_custom_form_multipage_labeled(self, client, container_sas_url):
forms = poller.result()

for form in forms:
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form.pages)
for label, field in form.fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -377,7 +377,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -441,7 +441,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -480,5 +480,5 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async def test_custom_form_labeled(self, client, container_sas_url):
poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG)
form = await poller.result()

self.assertEqual(form[0].form_type, "form-"+model.model_id)
self.assertEqual(form[0].form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form[0].pages)
for label, field in form[0].fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -242,7 +242,7 @@ async def test_custom_form_multipage_labeled(self, client, container_sas_url):
forms = await poller.result()

for form in forms:
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form.pages)
for label, field in form.fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -411,7 +411,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -480,7 +480,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -521,5 +521,5 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_custom_form_labeled(self, client, container_sas_url):
poller = fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)
form = poller.result()

self.assertEqual(form[0].form_type, "form-"+model.model_id)
self.assertEqual(form[0].form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form[0].pages)
for label, field in form[0].fields.items():
self.assertIsNotNone(field.confidence)
Expand All @@ -173,7 +173,7 @@ def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url):
forms = poller.result()

for form in forms:
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form.pages)
for label, field in form.fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -319,7 +319,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -379,7 +379,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -415,5 +415,5 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def test_form_labeled(self, client, container_sas_url):
poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)
form = await poller.result()

self.assertEqual(form[0].form_type, "form-"+model.model_id)
self.assertEqual(form[0].form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form[0].pages)
for label, field in form[0].fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -196,7 +196,7 @@ async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_
forms = await poller.result()

for form in forms:
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertFormPagesHasValues(form.pages)
for label, field in form.fields.items():
self.assertIsNotNone(field.confidence)
Expand Down Expand Up @@ -351,7 +351,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -414,7 +414,7 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)

@GlobalFormRecognizerAccountPreparer()
Expand Down Expand Up @@ -452,5 +452,5 @@ def callback(raw_response, _, headers):
for form, actual in zip(recognized_form, document_results):
self.assertEqual(form.page_range.first_page_number, actual.page_range[0])
self.assertEqual(form.page_range.last_page_number, actual.page_range[1])
self.assertEqual(form.form_type, "form-"+model.model_id)
self.assertEqual(form.form_type, "custom:"+model.model_id)
self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results)
38 changes: 25 additions & 13 deletions sdk/formrecognizer/azure-ai-formrecognizer/tests/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def custom_form_model_field():

@pytest.fixture
def custom_form_sub_model(custom_form_model_field):
model = _models.CustomFormSubmodel(accuracy=0.99, fields={"name": custom_form_model_field[0]}, form_type="Itemized")
model_repr = "CustomFormSubmodel(accuracy=0.99, fields={{'name': {}}}, form_type=Itemized)".format(custom_form_model_field[1])[:1024]
model = _models.CustomFormSubmodel(accuracy=0.99, model_id=1, fields={"name": custom_form_model_field[0]}, form_type="Itemized")
model_repr = "CustomFormSubmodel(accuracy=0.99, model_id=1, fields={{'name': {}}}, form_type=Itemized)".format(custom_form_model_field[1])[:1024]
assert repr(model) == model_repr
return model, model_repr

Expand All @@ -118,8 +118,15 @@ def form_recognizer_error():

@pytest.fixture
def training_document_info(form_recognizer_error):
model = _models.TrainingDocumentInfo(name="name", status=_models.TrainingStatus.PARTIALLY_SUCCEEDED, page_count=5, errors=[form_recognizer_error[0]])
model_repr = "TrainingDocumentInfo(name=name, status=partiallySucceeded, page_count=5, errors=[{}])".format(form_recognizer_error[1])[:1024]
model = _models.TrainingDocumentInfo(name="name", status=_models.TrainingStatus.PARTIALLY_SUCCEEDED, page_count=5, errors=[form_recognizer_error[0]], model_id=1)
model_repr = "TrainingDocumentInfo(name=name, status=partiallySucceeded, page_count=5, errors=[{}], model_id=1)".format(form_recognizer_error[1])[:1024]
assert repr(model) == model_repr
return model, model_repr

@pytest.fixture
def custom_form_model_properties():
model = _models.CustomFormModelProperties(is_composed_model=True)
model_repr = "CustomFormModelProperties(is_composed_model=True)"
assert repr(model) == model_repr
return model, model_repr

Expand All @@ -128,35 +135,40 @@ class TestRepr():
# Not inheriting form FormRecognizerTest because that doesn't allow me to define pytest fixtures in the same file
# Not worth moving pytest fixture definitions to conftest since all I would use is assertEqual and I can just use assert
def test_recognized_form(self, form_field_one, page_range, form_page):
model = _models.RecognizedForm(form_type="receipt", fields={"one": form_field_one[0]}, page_range=page_range[0], pages=[form_page[0]])
model_repr = "RecognizedForm(form_type=receipt, fields={{'one': {}}}, page_range={}, pages=[{}])".format(
model = _models.RecognizedForm(form_type="receipt", form_type_confidence=1.0, model_id=1, fields={"one": form_field_one[0]}, page_range=page_range[0], pages=[form_page[0]])
model_repr = "RecognizedForm(form_type=receipt, fields={{'one': {}}}, page_range={}, pages=[{}], form_type_confidence=1.0, model_id=1)".format(
form_field_one[1], page_range[1], form_page[1]
)[:1024]
assert repr(model) == model_repr

def test_custom_form_model(self, custom_form_sub_model, form_recognizer_error, training_document_info):
def test_custom_form_model(self, custom_form_sub_model, custom_form_model_properties, form_recognizer_error, training_document_info):
model = _models.CustomFormModel(
model_id=1,
status=_models.CustomFormModelStatus.CREATING,
training_started_on=datetime.datetime(1, 1, 1),
training_completed_on=datetime.datetime(1, 1, 1),
submodels=[custom_form_sub_model[0], custom_form_sub_model[0]],
errors=[form_recognizer_error[0]],
training_documents=[training_document_info[0], training_document_info[0]]
training_documents=[training_document_info[0], training_document_info[0]],
properties=custom_form_model_properties[0],
display_name="my model"
)

model_repr = "CustomFormModel(model_id=1, status=creating, training_started_on=0001-01-01 00:00:00, " \
"training_completed_on=0001-01-01 00:00:00, submodels=[{}, {}], errors=[{}], training_documents=[{}, {}])".format(
custom_form_sub_model[1], custom_form_sub_model[1], form_recognizer_error[1], training_document_info[1], training_document_info[1]
"training_completed_on=0001-01-01 00:00:00, submodels=[{}, {}], errors=[{}], training_documents=[{}, {}], " \
"display_name=my model, properties={})".format(
custom_form_sub_model[1], custom_form_sub_model[1], form_recognizer_error[1], training_document_info[1], training_document_info[1],
custom_form_model_properties[1]
)[:1024]

assert repr(model) == model_repr

def test_custom_form_model_info(self):
def test_custom_form_model_info(self, custom_form_model_properties):
model = _models.CustomFormModelInfo(
model_id=1, status=_models.CustomFormModelStatus.READY, training_started_on=datetime.datetime(1, 1, 1), training_completed_on=datetime.datetime(1, 1, 1)
model_id=1, status=_models.CustomFormModelStatus.READY, training_started_on=datetime.datetime(1, 1, 1), training_completed_on=datetime.datetime(1, 1, 1),
properties=custom_form_model_properties[0], display_name="my model"
)
model_repr = "CustomFormModelInfo(model_id=1, status=ready, training_started_on=0001-01-01 00:00:00, training_completed_on=0001-01-01 00:00:00)"[:1024]
model_repr = "CustomFormModelInfo(model_id=1, status=ready, training_started_on=0001-01-01 00:00:00, training_completed_on=0001-01-01 00:00:00, properties={}, display_name=my model)".format(custom_form_model_properties[1])[:1024]
assert repr(model) == model_repr

def test_account_properties(self):
Expand Down