Fixes spacing issues.
musically-ut committed May 2, 2019
commit dc79fdab7e20aff873f13425c0467ea52cc3068e
103 changes: 56 additions & 47 deletions load_into_pg.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python

import sys
import time
import argparse
@@ -16,17 +17,19 @@
# part of the file already downloaded
file_part = None


def show_progress(block_num, block_size, total_size):
"""Display the total size of the file to download and the progess in percent"""
"""Display the total size of the file to download and the progress in percent"""
global file_part
if file_part is None:
suffixes=['B','KB','MB','GB','TB']
suffixes = ['B', 'KB', 'MB', 'GB', 'TB']
suffixIndex = 0
pp_size = total_size
while pp_size > 1024:
suffixIndex += 1 #increment the index of the suffix
pp_size = pp_size/1024.0 #apply the division
six.print_('Total file size is: {0:.1f} {1}'.format(pp_size,suffixes[suffixIndex]))
suffixIndex += 1 # Increment the index of the suffix
pp_size = pp_size / 1024.0 # Apply the division
six.print_('Total file size is: {0:.1f} {1}'
.format(pp_size, suffixes[suffixIndex]))
six.print_("0 % of the file downloaded ...\r", end="", flush=True)
file_part = 0

@@ -40,6 +43,7 @@ def show_progress(block_num, block_size, total_size):
file_part = None
six.print_("")


def buildConnectionString(dbname, mbHost, mbPort, mbUsername, mbPassword):
dbConnectionParam = "dbname={}".format(dbname)

@@ -58,20 +62,23 @@ def buildConnectionString(dbname, mbHost, mbPort, mbUsername, mbPassword):
dbConnectionParam += ' password={}'.format(mbPassword)
return dbConnectionParam


def _makeDefValues(keys):
"""Returns a dictionary containing None for all keys."""
return dict(( (k, None) for k in keys ))
return dict(((k, None) for k in keys))


def _createMogrificationTemplate(table, keys, insertJson):
"""Return the template string for mogrification for the given keys."""
table_keys = ', '.join( [ '%(' + k + ')s' if (table, k) not in specialRules
else specialRules[table, k]
for k in keys ])
table_keys = ', '.join(['%(' + k + ')s' if (table, k) not in specialRules
else specialRules[table, k]
for k in keys])
if insertJson:
return ('(' + table_keys + ', %(jsonfield)s' + ')')
else:
return ('(' + table_keys + ')')


def _createCmdTuple(cursor, keys, templ, attribs, insertJson):
"""Use the cursor to mogrify a tuple of data.
The passed data in `attribs` is augmented with default data (NULLs) and the
@@ -83,14 +90,14 @@ def _createCmdTuple(cursor, keys, templ, attribs, insertJson):
defs.update(attribs)

if insertJson:
dict_attribs = { }
dict_attribs = {}
for name, value in attribs.items():
dict_attribs[name] = value
defs['jsonfield'] = json.dumps(dict_attribs)

values_to_insert = cursor.mogrify(templ, defs)
return cursor.mogrify(templ, defs)


def _getTableKeys(table):
"""Return an array of the keys for a given table"""
keys = None
@@ -177,25 +184,26 @@ def _getTableKeys(table):
]
elif table == 'PostHistory':
keys = [
'Id',
'PostHistoryTypeId',
'PostId',
'RevisionGUID',
'CreationDate',
'UserId',
'Text'
'Id'
, 'PostHistoryTypeId'
, 'PostId'
, 'RevisionGUID'
, 'CreationDate'
, 'UserId'
, 'Text'
]
elif table == 'Comments':
keys = [
'Id',
'PostId',
'Score',
'Text',
'CreationDate',
'UserId',
'Id'
, 'PostId'
, 'Score'
, 'Text'
, 'CreationDate'
, 'UserId'
]
return keys


def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
"""Handle the table including the post/pre processing."""
keys = _getTableKeys(table)
@@ -228,10 +236,10 @@ def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
six.print_('Processing data ...')
for rows in Processor.batch(Processor.parse(xml), 500):
valuesStr = ',\n'.join(
[ _createCmdTuple(cur, keys, tmpl, row_attribs, insertJson).decode('utf-8')
for row_attribs in rows
]
)
[_createCmdTuple(cur, keys, tmpl, row_attribs, insertJson).decode('utf-8')
for row_attribs in rows
]
)
if len(valuesStr) > 0:
cmd = 'INSERT INTO ' + table + \
' VALUES\n' + valuesStr + ';'
@@ -249,11 +257,11 @@ def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
if createFk:
# fk-processing (creation of foreign keys)
start_time = time.time()
six.print_('fk processing ...')
six.print_('Foreign Key processing ...')
if post != '':
cur.execute(fk)
conn.commit()
six.print_('fk processing took {0:.1f} seconds'.format(time.time() - start_time))
six.print_('Foreign Key processing took {0:.1f} seconds'.format(time.time() - start_time))

except IOError as e:
six.print_("Could not read from file {}.".format(dbFile), file=sys.stderr)
@@ -275,7 +283,7 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
cur.execute('CREATE SCHEMA IF NOT EXISTS ' + schemaName + ';')
conn.commit()
# move the table to the right schema
cur.execute('ALTER TABLE '+table+' SET SCHEMA ' + schemaName + ';')
cur.execute('ALTER TABLE ' + table + ' SET SCHEMA ' + schemaName + ';')
conn.commit()
except pg.Error as e:
six.print_("Error in dealing with the database.", file=sys.stderr)
@@ -288,76 +296,76 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
#############################################################

parser = argparse.ArgumentParser()
parser.add_argument( '-t', '--table'
parser.add_argument('-t', '--table'
, help = 'The table to work on.'
, choices = ['Users', 'Badges', 'Posts', 'Tags', 'Votes', 'PostLinks', 'PostHistory', 'Comments']
, default = None
)

parser.add_argument( '-d', '--dbname'
parser.add_argument('-d', '--dbname'
, help = 'Name of database to create the table in. The database must exist.'
, default = 'stackoverflow'
)

parser.add_argument( '-f', '--file'
parser.add_argument('-f', '--file'
, help = 'Name of the file to extract data from.'
, default = None
)

parser.add_argument( '-s', '--so-project'
parser.add_argument('-s', '--so-project'
, help = 'StackExchange project to load.'
, default = None
)

parser.add_argument( '--archive-url'
parser.add_argument('--archive-url'
, help = 'URL of the archive directory to retrieve.'
, default = 'https://ia800107.us.archive.org/27/items/stackexchange'
)

parser.add_argument( '-k', '--keep-archive'
parser.add_argument('-k', '--keep-archive'
, help = 'Will preserve the downloaded archive instead of deleting it.'
, action = 'store_true'
, default = False
)

parser.add_argument( '-u', '--username'
parser.add_argument('-u', '--username'
, help = 'Username for the database.'
, default = None
)

parser.add_argument( '-p', '--password'
parser.add_argument('-p', '--password'
, help = 'Password for the database.'
, default = None
)

parser.add_argument( '-P', '--port'
parser.add_argument('-P', '--port'
, help = 'Port to connect with the database on.'
, default = None
)

parser.add_argument( '-H', '--host'
parser.add_argument('-H', '--host'
, help = 'Hostname for the database.'
, default = None
)

parser.add_argument( '--with-post-body'
parser.add_argument('--with-post-body'
, help = 'Import the posts with the post body. Only used if importing Posts.xml'
, action = 'store_true'
, default = False
)

parser.add_argument( '-j', '--insert-json'
parser.add_argument('-j', '--insert-json'
, help = 'Insert raw data as JSON.'
, action = 'store_true'
, default = False
)

parser.add_argument( '-n', '--schema-name'
parser.add_argument('-n', '--schema-name'
, help = 'Use specific schema.'
, default = 'public'
)

parser.add_argument( '--foreign-keys'
parser.add_argument('--foreign-keys'
, help = 'Create foreign keys.'
, action = 'store_true'
, default = False
@@ -421,13 +429,14 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
six.print_('Error: impossible to extract the {0} archive ({1})'.format(url, e))
exit(1)

tables = [ 'Tags', 'Users', 'Badges', 'Posts', 'Comments', 'Votes', 'PostLinks', 'PostHistory' ]
tables = ['Tags', 'Users', 'Badges', 'Posts', 'Comments',
'Votes', 'PostLinks', 'PostHistory']

for table in tables:
six.print_('Load {0}.xml file'.format(table))
handleTable(table, args.insert_json, args.foreign_keys, None, dbConnectionParam)
# remove file
os.remove(table+'.xml')
os.remove(table + '.xml')

if not args.keep_archive:
os.remove(filepath)