diff --git a/setup.py b/setup.py index 62b1d798..3f55499e 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ from setuptools import setup, find_packages setup( - name='target-postgres', + name='target-sql', version="0.0.1", - description='Singer.io target for loading data into postgres', + description='Singer.io targets for loading data into SQL databases', classifiers=['Programming Language :: Python :: 3 :: Only'], py_modules=['target_postgres'], install_requires=[ @@ -16,7 +16,8 @@ ], entry_points=''' [console_scripts] - target-postgres=target_postgres:main + target-postgres=target_sql:target_postgres_main + target-redshift=target_sql:target_redshift_main ''', packages=find_packages() ) diff --git a/target_postgres/__init__.py b/target_sql/__init__.py similarity index 70% rename from target_postgres/__init__.py rename to target_sql/__init__.py index 27ab218b..cf87d914 100644 --- a/target_postgres/__init__.py +++ b/target_sql/__init__.py @@ -9,13 +9,14 @@ from singer import utils, metadata, metrics import psycopg2 -from target_postgres.postgres import PostgresTarget -from target_postgres.singer_stream import BufferedSingerStream +from target_sql.target_postgres import TargetPostgres +from target_sql.target_redshift import TargetRedshift +from target_sql.singer_stream import BufferedSingerStream LOGGER = singer.get_logger() REQUIRED_CONFIG_KEYS = [ - 'postgres_database' + 'target_connection' ] def flush_stream(target, stream_buffer): @@ -87,47 +88,40 @@ def line_handler(streams, target, max_batch_rows, max_batch_size, line): line_data['type'], line)) -def main(config, input_stream=None): +def target_sql(target_class, config, input_stream=None): try: - connection = psycopg2.connect( - host=config.get('postgres_host', 'localhost'), - port=config.get('postgres_port', 5432), - dbname=config.get('postgres_database'), - user=config.get('postgres_username'), - password=config.get('postgres_password')) - - streams = {} - postgres_target = PostgresTarget( - connection, - LOGGER, - postgres_schema=config.get('postgres_schema', 'public')) - - max_batch_rows = config.get('max_batch_rows') - max_batch_size = config.get('max_batch_size') - batch_detection_threshold = config.get('batch_detection_threshold', 5000) - - if not input_stream: - input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') - - line_count = 0 - for line in input_stream: - line_handler(streams, postgres_target, max_batch_rows, max_batch_size, line) - if line_count > 0 and line_count % batch_detection_threshold == 0: - flush_streams(streams, postgres_target) - line_count += 1 - - flush_streams(streams, postgres_target, force=True) - - connection.close() + with target_class(config, LOGGER) as target: + max_batch_rows = config.get('max_batch_rows') + max_batch_size = config.get('max_batch_size') + batch_detection_threshold = config.get('batch_detection_threshold', 5000) + + if not input_stream: + input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') + + line_count = 0 + streams = {} + for line in input_stream: + line_handler(streams, target, max_batch_rows, max_batch_size, line) + if line_count > 0 and line_count % batch_detection_threshold == 0: + flush_streams(streams, target) + line_count += 1 + + flush_streams(streams, target, force=True) except Exception as e: LOGGER.critical(e) raise e -if __name__ == "__main__": +def main(target_class): try: args = utils.parse_args(REQUIRED_CONFIG_KEYS) - main(args.config) + target_sql(target_class, args.config) except Exception as e: LOGGER.critical(e) raise e + +def 
target_postgres_main(): + main(TargetPostgres) + +def target_redshift_main(): + main(TargetRedshift) diff --git a/target_postgres/pysize.py b/target_sql/pysize.py similarity index 100% rename from target_postgres/pysize.py rename to target_sql/pysize.py diff --git a/target_postgres/singer_stream.py b/target_sql/singer_stream.py similarity index 97% rename from target_postgres/singer_stream.py rename to target_sql/singer_stream.py index c5d156c3..bb74cb87 100644 --- a/target_postgres/singer_stream.py +++ b/target_sql/singer_stream.py @@ -1,7 +1,7 @@ from jsonschema import Draft4Validator, FormatChecker import arrow -from target_postgres.pysize import get_size +from target_sql.pysize import get_size SINGER_RECEIVED_AT = '_sdc_received_at' SINGER_BATCHED_AT = '_sdc_batched_at' diff --git a/target_sql/target_postgres.py b/target_sql/target_postgres.py new file mode 100644 index 00000000..c91c875d --- /dev/null +++ b/target_sql/target_postgres.py @@ -0,0 +1,86 @@ +import psycopg2 +from psycopg2 import sql + +from target_sql.target_sql import TargetSQL, TransformStream + +class TargetPostgres(TargetSQL): + def create_connection(self, config): + connection = config.get('target_connection') + + self.conn = psycopg2.connect( + host=connection.get('host', 'localhost'), + port=connection.get('port', 5432), + dbname=connection.get('database'), + user=connection.get('username'), + password=connection.get('password')) + + def destroy_connection(self): + self.conn.close() + + def sql_to_json_schema(self, sql_type, nullable): + _format = None + if sql_type == 'timestamp with time zone': + json_type = 'string' + _format = 'date-time' + elif sql_type == 'bigint': + json_type = 'integer' + elif sql_type == 'double precision': + json_type = 'number' + elif sql_type == 'boolean': + json_type = 'boolean' + elif sql_type == 'text': + json_type = 'string' + else: + raise Exception('Unsupported type `{}` in existing target table'.format(sql_type)) + + if nullable: + json_type = ['null', json_type] + + json_schema = {'type': json_type} + if _format: + json_schema['format'] = _format + + return json_schema + + def json_schema_to_sql(self, json_schema): + _type = json_schema['type'] + not_null = True + if isinstance(_type, list): + ln = len(_type) + if ln == 1: + _type = _type[0] + if ln == 2 and 'null' in _type: + not_null = False + if _type.index('null') == 0: + _type = _type[1] + else: + _type = _type[0] + elif ln > 2: + raise Exception('Multiple types per column not supported') + + sql_type = 'text' + + if 'format' in json_schema and \ + json_schema['format'] == 'date-time' and \ + _type == 'string': + sql_type = 'timestamp with time zone' + elif _type == 'boolean': + sql_type = 'boolean' + elif _type == 'integer': + sql_type = 'bigint' + elif _type == 'number': + sql_type = 'double precision' + + if not_null: + sql_type += ' NOT NULL' + + return sql_type + + def copy_rows(self, cur, table_name, headers, row_fn): + rows = TransformStream(row_fn) + + copy = sql.SQL('COPY {}.{} ({}) FROM STDIN CSV').format( + sql.Identifier(self.catalog), + sql.Identifier(table_name), + sql.SQL(', ').join(map(sql.Identifier, headers))) + cur.copy_expert(copy, rows) diff --git a/target_sql/target_redshift.py b/target_sql/target_redshift.py new file mode 100644 index 00000000..d2e1acbe --- /dev/null +++ b/target_sql/target_redshift.py @@ -0,0 +1,113 @@ +import uuid + +import boto3 +from psycopg2 import sql + +from target_sql.target_postgres import TargetPostgres, TransformStream + +class TargetRedshift(TargetPostgres): + MAX_VARCHAR 
= 65535 + + def __init__(self, config, *args, **kwargs): + s3_config = config.get('target_s3') + if not s3_config: + raise Exception('`target_s3` required') + self.s3_config = s3_config + + super(TargetRedshift, self).__init__(config, *args, **kwargs) + + def sql_to_json_schema(self, sql_type, nullable): + _format = None + if sql_type == 'timestamp with time zone': + json_type = 'string' + _format = 'date-time' + elif sql_type == 'bigint': + json_type = 'integer' + elif sql_type == 'double precision': + json_type = 'number' + elif sql_type == 'boolean': + json_type = 'boolean' + elif sql_type == 'character varying': + json_type = 'string' + else: + raise Exception('Unsupported type `{}` in existing target table'.format(sql_type)) + + if nullable: + json_type = ['null', json_type] + + json_schema = {'type': json_type} + if _format: + json_schema['format'] = _format + + return json_schema + + def json_schema_to_sql(self, json_schema): + _type = json_schema['type'] + not_null = True + if isinstance(_type, list): + ln = len(_type) + if ln == 1: + _type = _type[0] + if ln == 2 and 'null' in _type: + not_null = False + if _type.index('null') == 0: + _type = _type[1] + else: + _type = _type[0] + elif ln > 2: + raise Exception('Multiple types per column not supported') + + max_length = json_schema.get('maxLength', self.MAX_VARCHAR) + if max_length > self.MAX_VARCHAR: + max_length = self.MAX_VARCHAR + + sql_type = 'varchar({})'.format(max_length) + + if 'format' in json_schema and \ + json_schema['format'] == 'date-time' and \ + _type == 'string': + sql_type = 'timestamp with time zone' + elif _type == 'boolean': + sql_type = 'boolean' + elif _type == 'integer': + sql_type = 'bigint' + elif _type == 'number': + sql_type = 'double precision' + + if not_null: + sql_type += ' NOT NULL' + + return sql_type + + def copy_rows(self, cur, table_name, headers, row_fn): + s3_client = boto3.client( + 's3', + aws_access_key_id=self.s3_config.get('aws_access_key_id'), + aws_secret_access_key=self.s3_config.get('aws_secret_access_key')) + + bucket = self.s3_config.get('bucket') + if not bucket: + raise Exception('`target_s3.bucket` required') + prefix = self.s3_config.get('key_prefix', '') + key = prefix + table_name + self.NESTED_SEPARATOR + str(uuid.uuid4()).replace('-', '') + + rows = TransformStream(row_fn, binary=True) + + s3_client.upload_fileobj( + rows, + bucket, + key) + + source = 's3://{}/{}'.format(bucket, key) + credentials = 'aws_access_key_id={};aws_secret_access_key={}'.format( + self.s3_config.get('aws_access_key_id'), + self.s3_config.get('aws_secret_access_key')) + + copy_sql = sql.SQL('COPY {}.{} ({}) FROM {} CREDENTIALS {} FORMAT AS CSV').format( + sql.Identifier(self.catalog), + sql.Identifier(table_name), + sql.SQL(', ').join(map(sql.Identifier, headers)), + sql.Literal(source), + sql.Literal(credentials)) + + cur.execute(copy_sql) diff --git a/target_postgres/postgres.py b/target_sql/target_sql.py similarity index 84% rename from target_postgres/postgres.py rename to target_sql/target_sql.py index 8a42f76a..dfe8377a 100644 --- a/target_postgres/postgres.py +++ b/target_sql/target_sql.py @@ -9,7 +9,7 @@ import arrow from psycopg2 import sql -from target_postgres.singer_stream import ( +from target_sql.singer_stream import ( SINGER_RECEIVED_AT, SINGER_BATCHED_AT, SINGER_SEQUENCE, @@ -20,31 +20,62 @@ ) class TransformStream(object): - def __init__(self, fun): + def __init__(self, fun, binary=False): self.fun = fun + self.binary = binary def read(self, *args, **kwargs): - return self.fun() + 
if self.binary:
+            if len(args) > 0:
+                max_bytes = args[0]
+            else:
+                max_bytes = None
+            output = b''
+            while max_bytes is None or len(output) < max_bytes:  ## may return slightly more than max_bytes
+                line = self.fun()
+                if line == '':
+                    return output
+                output += line.encode('utf-8')
+            return output
+        else:
+            return self.fun()
 
-class PostgresTarget(object):
+class TargetSQL(object):
     NESTED_SEPARATOR = '__'
 
-    def __init__(self, connection, logger, *args, postgres_schema='public', **kwargs):
-        self.conn = connection
+    def __init__(self, config, logger, *args, **kwargs):
         self.logger = logger
-        self.postgres_schema = postgres_schema
+        self.catalog = config.get('target_catalog', 'public')
+
+        self.create_connection(config)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.destroy_connection()
+
+    def create_connection(self, config):
+        raise NotImplementedError()
+
+    def destroy_connection(self):
+        raise NotImplementedError()
 
     def write_batch(self, stream_buffer):
         if stream_buffer.count == 0:
             return
 
+        self.logger.info('{} - Writing batch ({})'.format(
+            stream_buffer.stream,
+            stream_buffer.count))
+
         with self.conn.cursor() as cur:
             try:
                 cur.execute('BEGIN;')
 
                 processed_records = map(partial(self.process_record_message,
                                                 stream_buffer.use_uuid_pk,
-                                                self.get_postgres_datetime()),
+                                                self.get_sql_datetime()),
                                         stream_buffer.peek_buffer())
                 versions = set()
                 max_version = None
@@ -58,7 +89,7 @@ def write_batch(self, stream_buffer):
                     records_all_versions.append(record)
 
                 table_metadata = self.get_table_metadata(cur,
-                                                         self.postgres_schema,
+                                                         self.catalog,
                                                          stream_buffer.stream)
 
                 ## TODO: check if PK has changed. Fail on PK change? Just update and log on PK change?
@@ -75,7 +106,7 @@ def write_batch(self, stream_buffer):
 
                 if current_table_version is not None and \
                    min(versions) < current_table_version:
-                    self.logger.warn('{} - Records from an earlier table vesion detected.'
+                    self.logger.warn('{} - Records from an earlier table version detected.'
                                      .format(stream_buffer.stream))
                 if len(versions) > 1:
                     self.logger.warn('{} - Multiple table versions in stream, only using the latest.'
@@ -83,7 +114,15 @@ def write_batch(self, stream_buffer): if current_table_version is not None and \ target_table_version > current_table_version: + self.logger.info('{} - New table version: {}'.format( + stream_buffer.stream, + target_table_version)) root_table_name = stream_buffer.stream + self.NESTED_SEPARATOR + str(target_table_version) + elif current_table_version is not None and \ + target_table_version < current_table_version: + message = '{} - Previous table version encountered in stream'.format(stream_buffer.stream) + self.logger.error(message) + raise Exception(message) else: root_table_name = stream_buffer.stream @@ -156,7 +195,7 @@ def activate_version(self, stream_buffer, version): cur.execute('BEGIN;') table_metadata = self.get_table_metadata(cur, - self.postgres_schema, + self.catalog, stream_buffer.stream) if not table_metadata: @@ -174,7 +213,7 @@ def activate_version(self, stream_buffer, version): SELECT tablename FROM pg_tables WHERE schemaname = {} AND tablename like {}; ''').format( - sql.Literal(self.postgres_schema), + sql.Literal(self.catalog), sql.Literal(versioned_root_table + '%'))) for versioned_table_name in map(lambda x: x[0], cur.fetchall()): @@ -185,7 +224,7 @@ def activate_version(self, stream_buffer, version): ALTER TABLE {table_schema}.{version_table} RENAME TO {stream_table}; DROP TABLE {table_schema}.{stream_table_old}; COMMIT;''').format( - table_schema=sql.Identifier(self.postgres_schema), + table_schema=sql.Identifier(self.catalog), stream_table_old=sql.Identifier(table_name + self.NESTED_SEPARATOR + 'old'), @@ -409,11 +448,11 @@ def denest_records(self, table_name, records, records_map, key_properties, pk_fk self.denest_record(table_name, None, record, records_map, key_properties, record_pk_fks, level) def upsert_table_schema(self, cur, table_name, schema, key_properties, table_version): - existing_table_schema = self.get_schema(cur, self.postgres_schema, table_name) + existing_table_schema = self.get_schema(cur, self.catalog, table_name) if existing_table_schema: schema = self.merge_put_schemas(cur, - self.postgres_schema, + self.catalog, table_name, existing_table_schema, schema) @@ -421,15 +460,15 @@ def upsert_table_schema(self, cur, table_name, schema, key_properties, table_ver else: schema = schema self.create_table(cur, - self.postgres_schema, - table_name, - schema, - key_properties, - table_version) + self.catalog, + table_name, + schema, + key_properties, + table_version) target_table_name = self.get_temp_table_name(table_name) self.create_table(cur, - self.postgres_schema, + self.catalog, target_table_name, schema, key_properties, @@ -437,12 +476,12 @@ def upsert_table_schema(self, cur, table_name, schema, key_properties, table_ver return target_table_name - def get_update_sql(self, target_table_name, temp_table_name, key_properties, subkeys): + def get_update_sql(self, target_table_name, temp_table_name, key_properties, columns, subkeys): full_table_name = sql.SQL('{}.{}').format( - sql.Identifier(self.postgres_schema), + sql.Identifier(self.catalog), sql.Identifier(target_table_name)) full_temp_table_name = sql.SQL('{}.{}').format( - sql.Identifier(self.postgres_schema), + sql.Identifier(self.catalog), sql.Identifier(temp_table_name)) pk_temp_select_list = [] @@ -455,7 +494,7 @@ def get_update_sql(self, target_table_name, temp_table_name, key_properties, sub pk_identifier)) pk_where_list.append( - sql.SQL('{table}.{pk} = {temp_table}.{pk}').format( + sql.SQL('{table}.{pk} = "dedupped".{pk}').format( table=full_table_name, 
                    temp_table=full_temp_table_name,
                     pk=pk_identifier))
 
@@ -474,8 +513,7 @@ def get_update_sql(self, target_table_name, temp_table_name, key_properties, sub
         pk_null = sql.SQL(' AND ').join(pk_null_list)
         cxt_where = sql.SQL(' AND ').join(cxt_where_list)
 
-        sequence_join = sql.SQL(' AND {}.{} >= {}.{}').format(
-            full_temp_table_name,
+        sequence_join = sql.SQL(' AND "dedupped".{} >= {}.{}').format(
             sql.Identifier(SINGER_SEQUENCE),
             full_table_name,
             sql.Identifier(SINGER_SEQUENCE))
@@ -500,19 +538,34 @@ def get_update_sql(self, target_table_name, temp_table_name, key_properties, sub
             insert_distinct_on = pk_temp_select
             insert_distinct_order_by = distinct_order_by
 
+        insert_columns_list = []
+        for column in columns:
+            insert_columns_list.append(sql.SQL('{}.{}').format(sql.Identifier('dedupped'),
+                                                               sql.Identifier(column)))
+        insert_columns = sql.SQL(', ').join(insert_columns_list)
+
         return sql.SQL('''
-            WITH "pks" AS (
-                SELECT DISTINCT ON ({pk_temp_select}) {pk_temp_select}
-                FROM {temp_table}
-                JOIN {table} ON {pk_where}{sequence_join}{distinct_order_by}
-            )
-            DELETE FROM {table} USING "pks" WHERE {cxt_where};
+            DELETE FROM {table} USING (
+                    SELECT "dedupped".*
+                    FROM (
+                        SELECT *,
+                               ROW_NUMBER() OVER (PARTITION BY {pk_temp_select}
+                                                  {distinct_order_by}) AS "pk_ranked"
+                        FROM {temp_table}
+                        {distinct_order_by}) AS "dedupped"
+                    JOIN {table} ON {pk_where}{sequence_join}
+                    WHERE pk_ranked = 1
+                ) AS "pks" WHERE {cxt_where};
             INSERT INTO {table} (
-                SELECT DISTINCT ON ({insert_distinct_on}) {temp_table}.*
-                FROM {temp_table}
+                SELECT {insert_columns}
+                FROM (
+                    SELECT *,
+                           ROW_NUMBER() OVER (PARTITION BY {insert_distinct_on}
+                                              {insert_distinct_order_by}) AS "pk_ranked"
+                    FROM {temp_table}
+                    {insert_distinct_order_by}) AS "dedupped"
                 LEFT JOIN {table} ON {pk_where}
-                WHERE {pk_null}
-                {insert_distinct_order_by}
+                WHERE pk_ranked = 1 AND {pk_null}
             );
             DROP TABLE {temp_table};
            ''').format(table=full_table_name,
@@ -524,7 +577,11 @@ def get_update_sql(self, target_table_name, temp_table_name, key_properties, sub
                        distinct_order_by=distinct_order_by,
                        pk_null=pk_null,
                        insert_distinct_on=insert_distinct_on,
-                       insert_distinct_order_by=insert_distinct_order_by)
+                       insert_distinct_order_by=insert_distinct_order_by,
+                       insert_columns=insert_columns)
+
+    def copy_rows(self, cur, table_name, headers, rows):
+        raise NotImplementedError()
 
     def persist_rows(self,
                      cur,
@@ -544,23 +601,17 @@ def transform():
             try:
                 row = next(rows)
                 with io.StringIO() as out:
-                    ## Serialize datetime to postgres compatible format
+                    ## Serialize datetime to sql compatible format
                     for prop in datetime_fields:
                         if prop in row:
-                            row[prop] = self.get_postgres_datetime(row[prop])
+                            row[prop] = self.get_sql_datetime(row[prop])
 
                     writer = csv.DictWriter(out, headers)
                     writer.writerow(row)
                     return out.getvalue()
             except StopIteration:
                 return ''
 
-        csv_rows = TransformStream(transform)
-
-        copy = sql.SQL('COPY {}.{} ({}) FROM STDIN CSV').format(
-            sql.Identifier(self.postgres_schema),
-            sql.Identifier(temp_table_name),
-            sql.SQL(', ').join(map(sql.Identifier, headers)))
-        cur.copy_expert(copy, csv_rows)
+        self.copy_rows(cur, temp_table_name, headers, transform)
 
         pattern = re.compile(SINGER_LEVEL.format('[0-9]+'))
         subkeys = list(filter(lambda header: re.match(pattern, header) is not None, headers))
@@ -568,10 +619,14 @@ def transform():
         update_sql = self.get_update_sql(target_table_name,
                                          temp_table_name,
                                          key_properties,
+                                         headers,
                                          subkeys)
+        cur.execute(update_sql)
 
-    def get_postgres_datetime(self, *args):
+        ## TODO: delete s3 / cleanup?
+ + def get_sql_datetime(self, *args): if len(args) > 0: parsed_datetime = arrow.get(args[0]) else: @@ -605,68 +660,28 @@ def get_temp_table_name(self, stream_name): return stream_name + self.NESTED_SEPARATOR + str(uuid.uuid4()).replace('-', '') def sql_to_json_schema(self, sql_type, nullable): - _format = None - if sql_type == 'timestamp with time zone': - json_type = 'string' - _format = 'date-time' - elif sql_type == 'bigint': - json_type = 'integer' - elif sql_type == 'double precision': - json_type = 'number' - elif sql_type == 'boolean': - json_type = 'boolean' - elif sql_type == 'text': - json_type = 'string' - else: - raise Exception('Unsupported type `{}` in existing target table'.format(sql_type)) - - if nullable: - json_type = ['null', json_type] - - json_schema = {'type': json_type} - if _format: - json_schema['format'] = _format - - return json_schema + raise NotImplementedError() def json_schema_to_sql(self, json_schema): - _type = json_schema['type'] - not_null = True - if isinstance(_type, list): - ln = len(_type) - if ln == 1: - _type = _type[0] - if ln == 2 and 'null' in _type: - not_null = False - if _type.index('null') == 0: - _type = _type[1] - else: - _type = _type[0] - elif ln > 2: - raise Exception('Multiple types per column not supported') - - sql_type = 'text' - - if 'format' in json_schema and \ - json_schema['format'] == 'date-time' and \ - _type == 'string': - sql_type = 'timestamp with time zone' - elif _type == 'boolean': - sql_type = 'boolean' - elif _type == 'integer': - sql_type = 'bigint' - elif _type == 'number': - sql_type = 'double precision' - - if not_null: - sql_type += ' NOT NULL' - - return sql_type + raise NotImplementedError() def get_table_metadata(self, cur, table_schema, table_name): + cur.execute(sql.SQL(''' + SELECT EXISTS ( + SELECT 1 FROM pg_tables + WHERE schemaname = {} AND + tablename = {});''').format( + sql.Literal(table_schema), + sql.Literal(table_name))) + table_exists = cur.fetchone()[0] + + if not table_exists: + return None + cur.execute( - sql.SQL('SELECT obj_description(to_regclass({}));').format( - sql.Literal('{}.{}'.format(table_schema, table_name)))) + sql.SQL('SELECT description FROM pg_description WHERE objoid = {}::regclass;').format( + sql.Literal( + '"{}"."{}"'.format(table_schema, table_name)))) comment = cur.fetchone()[0] if comment: diff --git a/tests/fixtures.py b/tests/fixtures.py index e6649c76..de0b77fc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -8,16 +8,18 @@ from faker import Faker from chance import chance -from target_postgres.singer_stream import SINGER_SEQUENCE +from target_sql.singer_stream import SINGER_SEQUENCE CONFIG = { - 'postgres_database': 'target_postgres_test' + 'target_connection': { + 'database': 'target_postgres_test' + } } TEST_DB = { 'host': 'localhost', 'port': 5432, - 'dbname': CONFIG['postgres_database'], + 'dbname': CONFIG['target_connection']['database'], 'user': None, 'password': None } diff --git a/tests/test_postgres.py b/tests/test_postgres.py index aabebb14..f725bdf4 100644 --- a/tests/test_postgres.py +++ b/tests/test_postgres.py @@ -5,9 +5,9 @@ import psycopg2 import psycopg2.extras -from target_postgres import main -from target_postgres import singer_stream -from target_postgres import postgres +from target_sql import target_sql +from target_sql import TargetPostgres +from target_sql import singer_stream from fixtures import CatStream, CONFIG, TEST_DB, db_cleanup ## TODO: create and test more fake streams @@ -130,7 +130,7 @@ def assert_records(conn, records, 
table_name, pks, match_pks=False): def test_loading_simple(db_cleanup): stream = CatStream(100) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -168,7 +168,7 @@ def test_loading_simple(db_cleanup): def test_upsert(db_cleanup): stream = CatStream(100) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -177,7 +177,7 @@ def test_upsert(db_cleanup): assert_records(conn, stream.records, 'cats', 'id') stream = CatStream(100) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -186,7 +186,7 @@ def test_upsert(db_cleanup): assert_records(conn, stream.records, 'cats', 'id') stream = CatStream(200) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -196,7 +196,7 @@ def test_upsert(db_cleanup): def test_nested_delete_on_parent(db_cleanup): stream = CatStream(100, nested_count=3) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -205,7 +205,7 @@ def test_nested_delete_on_parent(db_cleanup): assert_records(conn, stream.records, 'cats', 'id') stream = CatStream(100, nested_count=2) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -217,7 +217,7 @@ def test_nested_delete_on_parent(db_cleanup): def test_full_table_replication(db_cleanup): stream = CatStream(110, version=0, nested_count=3) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -231,7 +231,7 @@ def test_full_table_replication(db_cleanup): assert version_0_sub_count == 330 stream = CatStream(100, version=1, nested_count=3) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -245,7 +245,7 @@ def test_full_table_replication(db_cleanup): assert version_1_sub_count == 300 stream = CatStream(120, version=2, nested_count=2) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -260,7 +260,7 @@ def test_full_table_replication(db_cleanup): def test_deduplication_newer_rows(db_cleanup): stream = CatStream(100, nested_count=3, duplicates=2) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -282,7 +282,7 @@ def test_deduplication_newer_rows(db_cleanup): def test_deduplication_older_rows(db_cleanup): stream = CatStream(100, nested_count=2, duplicates=2, duplicate_sequence_delta=-100) - main(CONFIG, input_stream=stream) + target_sql(TargetPostgres, CONFIG, input_stream=stream) with psycopg2.connect(**TEST_DB) as conn: with conn.cursor() as cur: @@ -304,14 +304,14 @@ def test_deduplication_older_rows(db_cleanup): def test_deduplication_existing_new_rows(db_cleanup): stream = CatStream(100, nested_count=2) - main(CONFIG, 
         input_stream=stream)
+    target_sql(TargetPostgres, CONFIG, input_stream=stream)
 
     original_sequence = stream.sequence
 
     stream = CatStream(100,
                        nested_count=2,
                        sequence=original_sequence - 20)
-    main(CONFIG, input_stream=stream)
+    target_sql(TargetPostgres, CONFIG, input_stream=stream)
 
     with psycopg2.connect(**TEST_DB) as conn:
         with conn.cursor() as cur:
@@ -333,7 +333,7 @@ def mocked_mock_write_batch(stream_buffer):
     records = stream_buffer.flush_buffer()
 
 def test_multiple_batches_by_rows(db_cleanup):
-    with patch.object(postgres.PostgresTarget,
+    with patch.object(TargetPostgres,
                       'write_batch',
                       side_effect=mocked_mock_write_batch) as mock_write_batch:
         config = CONFIG.copy()
@@ -341,12 +341,12 @@
         config['batch_detection_threshold'] = 5
 
         stream = CatStream(100)
-        main(config, input_stream=stream)
+        target_sql(TargetPostgres, config, input_stream=stream)
 
         assert mock_write_batch.call_count == 6
 
 def test_multiple_batches_by_memory(db_cleanup):
-    with patch.object(postgres.PostgresTarget,
+    with patch.object(TargetPostgres,
                       'write_batch',
                       side_effect=mocked_mock_write_batch) as mock_write_batch:
         config = CONFIG.copy()
@@ -354,7 +354,7 @@
         config['batch_detection_threshold'] = 5
 
         stream = CatStream(100)
-        main(config, input_stream=stream)
+        target_sql(TargetPostgres, config, input_stream=stream)
 
         assert mock_write_batch.call_count == 21
 
@@ -364,7 +364,7 @@ def test_multiple_batches_upsert(db_cleanup):
     config['batch_detection_threshold'] = 5
 
     stream = CatStream(100, nested_count=2)
-    main(config, input_stream=stream)
+    target_sql(TargetPostgres, config, input_stream=stream)
 
     with psycopg2.connect(**TEST_DB) as conn:
         with conn.cursor() as cur:
@@ -375,7 +375,7 @@
         assert_records(conn, stream.records, 'cats', 'id')
 
     stream = CatStream(100, nested_count=3)
-    main(config, input_stream=stream)
+    target_sql(TargetPostgres, config, input_stream=stream)
 
     with psycopg2.connect(**TEST_DB) as conn:
         with conn.cursor() as cur:
@@ -391,7 +391,7 @@ def test_multiple_batches_by_memory_upsert(db_cleanup):
     config['batch_detection_threshold'] = 5
 
     stream = CatStream(100, nested_count=2)
-    main(config, input_stream=stream)
+    target_sql(TargetPostgres, config, input_stream=stream)
 
    with psycopg2.connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
@@ -402,7 +402,7 @@
        assert_records(conn, stream.records, 'cats', 'id')

    stream = CatStream(100, nested_count=3)
-    main(config, input_stream=stream)
+    target_sql(TargetPostgres, config, input_stream=stream)

    with psycopg2.connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
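
Usage note (not part of the diff): a minimal sketch of the reshaped config and the new target_sql entry point, based only on the keys this changeset reads (target_connection, target_catalog, max_batch_rows, max_batch_size, batch_detection_threshold, and target_s3 for Redshift). Connection values below are illustrative placeholders; the only defaults shown are the ones the code itself supplies.

    # Sketch only: assumes the package layout introduced above, with target_sql
    # exposing target_sql(), TargetPostgres and TargetRedshift. Credentials are placeholders.
    from target_sql import target_sql, TargetPostgres

    config = {
        'target_connection': {             # replaces the flat postgres_* keys
            'host': 'localhost',           # default when omitted
            'port': 5432,                  # default when omitted
            'database': 'example_db',      # placeholder
            'username': 'example_user',    # placeholder
            'password': 'example_pass',    # placeholder
        },
        'target_catalog': 'public',        # replaces postgres_schema; defaults to 'public'
        'batch_detection_threshold': 5000  # default when omitted
    }

    # Reads Singer messages from stdin when input_stream is not given; this is roughly
    # what the `target-postgres` console script does after utils.parse_args() loads --config.
    target_sql(TargetPostgres, config)

    # TargetRedshift additionally requires a `target_s3` block with aws_access_key_id,
    # aws_secret_access_key, bucket, and an optional key_prefix.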