Commit 9fb07b5

redshift S3 working
1 parent ddd77e1 commit 9fb07b5

File tree

9 files changed: 306 additions, 155 deletions


setup.py

Lines changed: 4 additions & 3 deletions
@@ -3,9 +3,9 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='target-postgres',
+    name='target-sql',
     version="0.0.1",
-    description='Singer.io target for loading data into postgres',
+    description='Singer.io targets for loading data into SQL databases',
     classifiers=['Programming Language :: Python :: 3 :: Only'],
     py_modules=['target_postgres'],
     install_requires=[
@@ -16,7 +16,8 @@
     ],
     entry_points='''
         [console_scripts]
-        target-postgres=target_postgres:main
+        target-postgres=target_sql:target_postgres_main
+        target-redshift=target_sql:target_redshift_main
     ''',
     packages=find_packages()
 )
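
After this change, installing the package registers two console scripts that both dispatch into the shared target_sql module. A minimal usage sketch; the tap command and config filename are illustrative, not part of this commit:

    pip install .
    some-singer-tap | target-postgres --config config.json
    some-singer-tap | target-redshift --config config.json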

target_postgres/__init__.py renamed to target_sql/__init__.py

Lines changed: 30 additions & 36 deletions
@@ -9,13 +9,14 @@
 from singer import utils, metadata, metrics
 import psycopg2
 
-from target_postgres.postgres import PostgresTarget
-from target_postgres.singer_stream import BufferedSingerStream
+from target_sql.target_postgres import TargetPostgres
+from target_sql.target_redshift import TargetRedshift
+from target_sql.singer_stream import BufferedSingerStream
 
 LOGGER = singer.get_logger()
 
 REQUIRED_CONFIG_KEYS = [
-    'postgres_database'
+    'target_connection'
 ]
 
 def flush_stream(target, stream_buffer):
@@ -87,47 +88,40 @@ def line_handler(streams, target, max_batch_rows, max_batch_size, line):
             line_data['type'],
             line))
 
-def main(config, input_stream=None):
+def target_sql(target_class, config, input_stream=None):
     try:
-        connection = psycopg2.connect(
-            host=config.get('postgres_host', 'localhost'),
-            port=config.get('postgres_port', 5432),
-            dbname=config.get('postgres_database'),
-            user=config.get('postgres_username'),
-            password=config.get('postgres_password'))
-
-        streams = {}
-        postgres_target = PostgresTarget(
-            connection,
-            LOGGER,
-            postgres_schema=config.get('postgres_schema', 'public'))
-
-        max_batch_rows = config.get('max_batch_rows')
-        max_batch_size = config.get('max_batch_size')
-        batch_detection_threshold = config.get('batch_detection_threshold', 5000)
-
-        if not input_stream:
-            input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
-
-        line_count = 0
-        for line in input_stream:
-            line_handler(streams, postgres_target, max_batch_rows, max_batch_size, line)
-            if line_count > 0 and line_count % batch_detection_threshold == 0:
-                flush_streams(streams, postgres_target)
-            line_count += 1
-
-        flush_streams(streams, postgres_target, force=True)
-
-        connection.close()
+        with target_class(config, LOGGER) as target:
+            max_batch_rows = config.get('max_batch_rows')
+            max_batch_size = config.get('max_batch_size')
+            batch_detection_threshold = config.get('batch_detection_threshold', 5000)
+
+            if not input_stream:
+                input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
+
+            line_count = 0
+            streams = {}
+            for line in input_stream:
+                line_handler(streams, target, max_batch_rows, max_batch_size, line)
+                if line_count > 0 and line_count % batch_detection_threshold == 0:
+                    flush_streams(streams, target)
+                line_count += 1
+
+            flush_streams(streams, target, force=True)
     except Exception as e:
         LOGGER.critical(e)
         raise e
 
-if __name__ == "__main__":
+def main(target_class):
     try:
         args = utils.parse_args(REQUIRED_CONFIG_KEYS)
 
-        main(args.config)
+        target_sql(target_class, args.config)
     except Exception as e:
         LOGGER.critical(e)
         raise e
+
+def target_postgres_main():
+    main(TargetPostgres)
+
+def target_redshift_main():
+    main(TargetRedshift)
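
The required config key changes from postgres_database to a nested target_connection object, and all connection handling moves behind a with block, which implies the target classes implement the context-manager protocol (presumably in the TargetSQL base class, which this diff does not show). A config sketch assembled from the keys read in this commit; every value is a placeholder, and target_s3 is only consulted by the Redshift target:

    {
      "target_connection": {
        "host": "localhost",
        "port": 5432,
        "database": "mydb",
        "username": "singer",
        "password": "secret"
      },
      "target_s3": {
        "aws_access_key_id": "AKIA...",
        "aws_secret_access_key": "...",
        "bucket": "my-bucket",
        "key_prefix": "singer/"
      }
    }

max_batch_rows, max_batch_size, and batch_detection_threshold remain optional top-level keys; batch_detection_threshold defaults to 5000.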
File renamed without changes.

target_postgres/singer_stream.py renamed to target_sql/singer_stream.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from jsonschema import Draft4Validator, FormatChecker
 import arrow
 
-from target_postgres.pysize import get_size
+from target_sql.pysize import get_size
 
 SINGER_RECEIVED_AT = '_sdc_received_at'
 SINGER_BATCHED_AT = '_sdc_batched_at'

target_sql/target_postgres.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import psycopg2
+from psycopg2 import sql
+
+from target_sql.target_sql import TargetSQL, TransformStream
+
+class TargetPostgres(TargetSQL):
+    def create_connection(self, config):
+        connection = config.get('target_connection')
+
+        self.conn = psycopg2.connect(
+            host=connection.get('host', 'localhost'),
+            port=connection.get('port', 5432),
+            dbname=connection.get('database'),
+            user=connection.get('username'),
+            password=connection.get('password'))
+
+    def destroy_connection(self):
+        self.conn.close()
+
+    def sql_to_json_schema(self, sql_type, nullable):
+        _format = None
+        if sql_type == 'timestamp with time zone':
+            json_type = 'string'
+            _format = 'date-time'
+        elif sql_type == 'bigint':
+            json_type = 'integer'
+        elif sql_type == 'double precision':
+            json_type = 'number'
+        elif sql_type == 'boolean':
+            json_type = 'boolean'
+        elif sql_type == 'text':
+            json_type = 'string'
+        else:
+            raise Exception('Unsupported type `{}` in existing target table'.format(sql_type))
+
+        if nullable:
+            json_type = ['null', json_type]
+
+        json_schema = {'type': json_type}
+        if _format:
+            json_schema['format'] = _format
+
+        return json_schema
+
+    def json_schema_to_sql(self, json_schema):
+        _type = json_schema['type']
+        not_null = True
+        if isinstance(_type, list):
+            ln = len(_type)
+            if ln == 1:
+                _type = _type[0]
+            if ln == 2 and 'null' in _type:
+                not_null = False
+                if _type.index('null') == 0:
+                    _type = _type[1]
+                else:
+                    _type = _type[0]
+            elif ln > 2:
+                raise Exception('Multiple types per column not supported')
+
+        sql_type = 'text'
+
+        if 'format' in json_schema and \
+           json_schema['format'] == 'date-time' and \
+           _type == 'string':
+            sql_type = 'timestamp with time zone'
+        elif _type == 'boolean':
+            sql_type = 'boolean'
+        elif _type == 'integer':
+            sql_type = 'bigint'
+        elif _type == 'number':
+            sql_type = 'double precision'
+
+        if not_null:
+            sql_type += ' NOT NULL'
+
+        return sql_type
+
+    def copy_rows(self, cur, table_name, headers, row_fn):
+        rows = TransformStream(row_fn)
+
+        copy = sql.SQL('COPY {}.{} ({}) FROM STDIN CSV').format(
+            sql.Identifier(self.catalog),
+            sql.Identifier(table_name),
+            sql.SQL(', ').join(map(sql.Identifier, headers)))
+        cur.copy_expert(copy, rows)
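
sql_to_json_schema and json_schema_to_sql are inverse mappings over the five supported column types. An illustrative round-trip, assuming pg is an already-initialized TargetPostgres instance (hypothetical setup, connection details omitted):

    # `pg` is a hypothetical, already-initialized TargetPostgres instance.
    pg.json_schema_to_sql({'type': ['null', 'string'], 'format': 'date-time'})
    # -> 'timestamp with time zone' (nullable, so no NOT NULL suffix)
    pg.json_schema_to_sql({'type': 'integer'})
    # -> 'bigint NOT NULL'
    pg.sql_to_json_schema('double precision', True)
    # -> {'type': ['null', 'number']}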

target_sql/target_redshift.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import uuid
+
+import boto3
+from psycopg2 import sql
+
+from target_sql.target_postgres import TargetPostgres, TransformStream
+
+class TargetRedshift(TargetPostgres):
+    def __init__(self, config, *args, **kwargs):
+        s3_config = config.get('target_s3')
+        if not s3_config:
+            raise Exception('`target_s3` required')
+        self.s3_config = s3_config
+
+        super(TargetRedshift, self).__init__(config, *args, **kwargs)
+
+    def sql_to_json_schema(self, sql_type, nullable):
+        _format = None
+        if sql_type == 'timestamp with time zone':
+            json_type = 'string'
+            _format = 'date-time'
+        elif sql_type == 'bigint':
+            json_type = 'integer'
+        elif sql_type == 'double precision':
+            json_type = 'number'
+        elif sql_type == 'boolean':
+            json_type = 'boolean'
+        elif sql_type[:7] == 'varchar':
+            json_type = 'string'
+        else:
+            raise Exception('Unsupported type `{}` in existing target table'.format(sql_type))
+
+        if nullable:
+            json_type = ['null', json_type]
+
+        json_schema = {'type': json_type}
+        if _format:
+            json_schema['format'] = _format
+
+        return json_schema
+
+    def json_schema_to_sql(self, json_schema):
+        _type = json_schema['type']
+        not_null = True
+        if isinstance(_type, list):
+            ln = len(_type)
+            if ln == 1:
+                _type = _type[0]
+            if ln == 2 and 'null' in _type:
+                not_null = False
+                if _type.index('null') == 0:
+                    _type = _type[1]
+                else:
+                    _type = _type[0]
+            elif ln > 2:
+                raise Exception('Multiple types per column not supported')
+
+        sql_type = 'varchar(65535)'
+
+        if 'format' in json_schema and \
+           json_schema['format'] == 'date-time' and \
+           _type == 'string':
+            sql_type = 'timestamp with time zone'
+        elif _type == 'boolean':
+            sql_type = 'boolean'
+        elif _type == 'integer':
+            sql_type = 'bigint'
+        elif _type == 'number':
+            sql_type = 'double precision'
+
+        if not_null:
+            sql_type += ' NOT NULL'
+
+        return sql_type
+
+    def copy_rows(self, cur, table_name, headers, row_fn):
+        s3_client = boto3.client(
+            's3',
+            aws_access_key_id=self.s3_config.get('aws_access_key_id'),
+            aws_secret_access_key=self.s3_config.get('aws_secret_access_key'))
+
+        bucket = self.s3_config.get('bucket')
+        if not bucket:
+            raise Exception('`target_s3.bucket` required')
+        prefix = self.s3_config.get('key_prefix', '')
+        key = prefix + table_name + self.NESTED_SEPARATOR + str(uuid.uuid4()).replace('-', '')
+
+        rows = TransformStream(row_fn, binary=True)
+
+        s3_client.upload_fileobj(
+            rows,
+            bucket,
+            key)
+
+        source = 's3://{}/{}'.format(bucket, key)
+        credentials = 'aws_access_key_id={};aws_secret_access_key={};'.format(
+            self.s3_config.get('aws_access_key_id'),
+            self.s3_config.get('aws_secret_access_key'))
+
+        copy_sql = sql.SQL('COPY {}.{} ({}) FROM {} CREDENTIALS {} FORMAT AS CSV').format(
+            sql.Identifier(self.catalog),
+            sql.Identifier(table_name),
+            sql.SQL(', ').join(map(sql.Identifier, headers)),
+            sql.Literal(source),
+            sql.Literal(credentials))
+
+        cur.execute(copy_sql)
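
Where the Postgres target streams rows directly into COPY ... FROM STDIN, the Redshift target first uploads the CSV stream to S3 under a per-batch key, then points Redshift at that object, passing the S3 credentials inline through the CREDENTIALS clause. The generated statement renders roughly as below; schema, table, columns, bucket, and key are illustrative, and the real key joins table_name to a hex UUID with NESTED_SEPARATOR, a constant inherited from the base class and not shown in this diff:

    COPY "public"."users" ("id", "name")
    FROM 's3://my-bucket/singer/users__3f2c9a...'
    CREDENTIALS 'aws_access_key_id=AKIA...;aws_secret_access_key=...;'
    FORMAT AS CSV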
