Skip to content
Merged
Prev Previous commit
Next Next commit
add test for multiple data files
  • Loading branch information
kevinjqliu committed Mar 9, 2024
commit ef64c9264cbd3b0ab269d5bcbea63e4ebf8395e4
27 changes: 27 additions & 0 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,33 @@ def get_current_snapshot_id(identifier: str) -> int:
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_multiple_data_files(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
identifier = "default.write_multiple_arrow_data_files"
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [])

def get_data_files_count(identifier: str) -> int:
return spark.sql(
f"""
SELECT *
FROM {identifier}.all_data_files
"""
).count()

# writes to 1 data file since the table is small
tbl.overwrite(arrow_table_with_null)
assert get_data_files_count(identifier) == 1

# writes to 1 data file as long as table is smaller than default target file size
bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
tbl.overwrite(bigger_arrow_tbl)
assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
assert get_data_files_count(identifier) == 1


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
Expand Down