diff --git a/.gitignore b/.gitignore index ad89c36..b735f8c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ test/log.html test/my_db_test.db test/output.xml test/report.html +.vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c4fc3da --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.linting.pylintEnabled": true, + "python.pythonPath": "/usr/bin/python3" +} \ No newline at end of file diff --git a/src/DatabaseLibrary/__init__.py b/src/DatabaseLibrary/__init__.py index 9ccc368..1bb38c1 100644 --- a/src/DatabaseLibrary/__init__.py +++ b/src/DatabaseLibrary/__init__.py @@ -21,6 +21,7 @@ __version_file_path__ = os.path.join(os.path.dirname(__file__), 'VERSION') __version__ = open(__version_file_path__, 'r').read().strip() + class DatabaseLibrary(ConnectionManager, Query, Assertion): """ Database Library contains utilities meant for Robot Framework's usage. diff --git a/src/DatabaseLibrary/assertion.py b/src/DatabaseLibrary/assertion.py index 382f52b..3104858 100644 --- a/src/DatabaseLibrary/assertion.py +++ b/src/DatabaseLibrary/assertion.py @@ -20,7 +20,7 @@ class Assertion(object): Assertion handles all the assertions of Database Library. """ - def check_if_exists_in_database(self, selectStatement, sansTran=False): + def check_if_exists_in_database(self, selectStatement, sansTran=False, alias=None): """ Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction @@ -41,12 +41,13 @@ def check_if_exists_in_database(self, selectStatement, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Check If Exists In Database | SELECT id FROM person WHERE first_name = 'John' | True | """ - logger.info ('Executing : Check If Exists In Database | %s ' % selectStatement) - if not self.query(selectStatement, sansTran): + logger.info('Executing : Check If Exists In Database | %s | %s ' % (selectStatement,alias)) + + if not self.query(selectStatement=selectStatement, sansTran=sansTran, alias=alias): raise AssertionError("Expected to have have at least one row from '%s' " "but got 0 rows." % selectStatement) - def check_if_not_exists_in_database(self, selectStatement, sansTran=False): + def check_if_not_exists_in_database(self, selectStatement, sansTran=False, alias=None): """ This is the negation of `check_if_exists_in_database`. @@ -69,13 +70,15 @@ def check_if_not_exists_in_database(self, selectStatement, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Check If Not Exists In Database | SELECT id FROM person WHERE first_name = 'John' | True | """ - logger.info('Executing : Check If Not Exists In Database | %s ' % selectStatement) - queryResults = self.query(selectStatement, sansTran) + logger.info( + 'Executing : Check If Not Exists In Database | %s | %s ' % (selectStatement,alias)) + queryResults = self.query( + selectStatement=selectStatement, sansTran=sansTran, alias=alias) if queryResults: raise AssertionError("Expected to have have no rows from '%s' " "but got some rows : %s." % (selectStatement, queryResults)) - def row_count_is_0(self, selectStatement, sansTran=False): + def row_count_is_0(self, selectStatement, sansTran=False, alias=None): """ Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or @@ -96,13 +99,15 @@ def row_count_is_0(self, selectStatement, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Row Count is 0 | SELECT id FROM person WHERE first_name = 'John' | True | """ - logger.info('Executing : Row Count Is 0 | %s ' % selectStatement) - num_rows = self.row_count(selectStatement, sansTran) + logger.info('Executing : Row Count Is 0 | %s | %s ' % (selectStatement,alias)) + logger.info( + 'Connection: Row Count Is 0 | %s' % alias) + num_rows = self.row_count(selectStatement, sansTran, alias) if num_rows > 0: raise AssertionError("Expected zero rows to be returned from '%s' " "but got rows back. Number of rows returned was %s" % (selectStatement, num_rows)) - def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False): + def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False, alias=None): """ Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -124,13 +129,15 @@ def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Row Count Is Equal To X | SELECT id FROM person WHERE first_name = 'John' | 0 | True | """ - logger.info('Executing : Row Count Is Equal To X | %s | %s ' % (selectStatement, numRows)) - num_rows = self.row_count(selectStatement, sansTran) + logger.info('Executing : Row Count Is Equal To X | %s | %s | %s ' % + (selectStatement, numRows,alias)) + + num_rows = self.row_count(selectStatement, sansTran, alias) if num_rows != int(numRows.encode('ascii')): raise AssertionError("Expected same number of rows to be returned from '%s' " "than the returned rows of %s" % (selectStatement, num_rows)) - def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False): + def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False, alias=None): """ Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -152,13 +159,15 @@ def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Row Count Is Greater Than X | SELECT id FROM person | 1 | True | """ - logger.info('Executing : Row Count Is Greater Than X | %s | %s ' % (selectStatement, numRows)) - num_rows = self.row_count(selectStatement, sansTran) + logger.info('Executing : Row Count Is Greater Than X | %s | %s | %s ' % ( + selectStatement, numRows,alias)) + + num_rows = self.row_count(selectStatement, sansTran, alias) if num_rows <= int(numRows.encode('ascii')): raise AssertionError("Expected more rows to be returned from '%s' " "than the returned rows of %s" % (selectStatement, num_rows)) - def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False): + def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False, alias=None): """ Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -180,13 +189,15 @@ def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Row Count Is Less Than X | SELECT id FROM person | 3 | True | """ - logger.info('Executing : Row Count Is Less Than X | %s | %s ' % (selectStatement, numRows)) - num_rows = self.row_count(selectStatement, sansTran) + logger.info('Executing : Row Count Is Less Than X | %s | %s | %s ' % ( + selectStatement, numRows,alias)) + num_rows = self.row_count(selectStatement, sansTran, alias) + logger.info('Row Num: %s ' % str(num_rows)) if num_rows >= int(numRows.encode('ascii')): raise AssertionError("Expected less rows to be returned from '%s' " "than the returned rows of %s" % (selectStatement, num_rows)) - def table_must_exist(self, tableName, sansTran=False): + def table_must_exist(self, tableName, sansTran=False, alias=None): """ Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -203,15 +214,24 @@ def table_must_exist(self, tableName, sansTran=False): Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Table Must Exist | person | True | """ - logger.info('Executing : Table Must Exist | %s ' % tableName) - if self.db_api_module_name in ["cx_Oracle"]: - selectStatement = ("SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND owner = SYS_CONTEXT('USERENV', 'SESSION_USER') AND object_name = UPPER('%s')" % tableName) - elif self.db_api_module_name in ["sqlite3"]: - selectStatement = ("SELECT name FROM sqlite_master WHERE type='table' AND name='%s' COLLATE NOCASE" % tableName) - elif self.db_api_module_name in ["ibm_db", "ibm_db_dbi"]: - selectStatement = ("SELECT name FROM SYSIBM.SYSTABLES WHERE type='T' AND name=UPPER('%s')" % tableName) + logger.info('Executing : Table Must Exist | %s | %s ' % (tableName,alias)) + + connection, module_api = self._get_cache(alias) + + if module_api in ["cx_Oracle"]: + selectStatement = ("SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND owner = SYS_CONTEXT('USERENV', 'SESSION_USER') \ + AND object_name = UPPER('%s')" % tableName) + elif module_api in ["sqlite3"]: + selectStatement = ( + "SELECT name FROM sqlite_master WHERE type='table' AND name='%s' COLLATE NOCASE" % tableName) + elif module_api in ["ibm_db", "ibm_db_dbi"]: + selectStatement = ( + "SELECT name FROM SYSIBM.SYSTABLES WHERE type='T' AND name=UPPER('%s')" % tableName) else: - selectStatement = ("SELECT * FROM information_schema.tables WHERE table_name='%s'" % tableName) - num_rows = self.row_count(selectStatement, sansTran) + selectStatement = ( + "SELECT * FROM information_schema.tables WHERE table_name='%s'" % tableName) + num_rows = self.row_count(selectStatement, sansTran, alias) + logger.info('Row Num: %s ' % str(num_rows)) if num_rows == 0: - raise AssertionError("Table '%s' does not exist in the db" % tableName) + raise AssertionError( + "Table '%s' does not exist in the db" % tableName) diff --git a/src/DatabaseLibrary/connection_manager.py b/src/DatabaseLibrary/connection_manager.py index dc949c3..0963565 100644 --- a/src/DatabaseLibrary/connection_manager.py +++ b/src/DatabaseLibrary/connection_manager.py @@ -13,6 +13,10 @@ # limitations under the License. import importlib +import robot +from robot.libraries.BuiltIn import BuiltIn +from robot.utils.asserts import fail +from urllib.parse import urlparse try: import ConfigParser @@ -30,11 +34,37 @@ class ConnectionManager(object): def __init__(self): """ Initializes _dbconnection to None. + Added cache mode for multi connection use. + Added to all method a new field, alias (Name of connection) """ - self._dbconnection = None - self.db_api_module_name = None + self._cache = robot.utils.ConnectionCache('No connection created') + self.builtin = BuiltIn() - def connect_to_database(self, dbapiModuleName=None, dbName=None, dbUsername=None, dbPassword=None, dbHost=None, dbPort=None, dbCharset=None, dbConfigFile="./resources/db.cfg"): + def _push_cache(self, alias=None, connection=None, db_api_module_name=None): + """ + Overlay _cache.register using dictionary + Create a dictionary that contains the dbconnection and the api_module used + and push it into the cache + """ + logger.info('Connection Name: %s | Db Module: %s ' % + (alias, db_api_module_name)) + obj_dict = {'connection': connection, 'module': db_api_module_name} + self._cache.register(obj_dict, alias=alias) + + def _get_cache(self, alias=None): + """ + Overlay _cache.switch using dictionary + Get from cache the dictionary contain dbconnection and api_module + and return them + """ + obj_dict = self._cache.switch(alias) + dbconnection = obj_dict['connection'] + db_api_module_name = obj_dict['module'] + + return dbconnection, db_api_module_name + + def connect_to_database(self, dbapiModuleName=None, dbName=None, dbUsername=None, dbPassword=None, dbHost=None, + dbPort=None, dbCharset=None, dbConfigFile="./resources/db.cfg", url=None, alias=None): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to connect to the database using `dbName`, `dbUsername`, and `dbPassword`. @@ -49,8 +79,12 @@ def connect_to_database(self, dbapiModuleName=None, dbName=None, dbUsername=None The `dbConfigFile` is useful if you don't want to check into your SCM your database credentials. + Added new field alias + Added _cache.register to register given connection with alias + Example db.cfg file | [default] + | alias=aliasuwant | dbapiModuleName=pymysqlforexample | dbName=yourdbname | dbUsername=yourusername @@ -60,7 +94,7 @@ def connect_to_database(self, dbapiModuleName=None, dbName=None, dbUsername=None Example usage: | # explicitly specifies all db property values | - | Connect To Database | psycopg2 | my_db | postgres | s3cr3t | tiger.foobar.com | 5432 | + | Connect To Database | alias | psycopg2 | my_db | postgres | s3cr3t | tiger.foobar.com | 5432 | | # loads all property values from default.cfg | | Connect To Database | dbConfigFile=default.cfg | @@ -69,112 +103,180 @@ def connect_to_database(self, dbapiModuleName=None, dbName=None, dbUsername=None | Connect To Database | | # uses explicit `dbapiModuleName` and `dbName` but uses the `dbUsername` and `dbPassword` in 'default.cfg' | - | Connect To Database | psycopg2 | my_db_test | dbConfigFile=default.cfg | + | Connect To Database | alias | psycopg2 | my_db_test | dbConfigFile=default.cfg | | # uses explicit `dbapiModuleName` and `dbName` but uses the `dbUsername` and `dbPassword` in './resources/db.cfg' | - | Connect To Database | psycopg2 | my_db_test | + | Connect To Database | alias | psycopg2 | my_db_test | """ + logger.info('Creating Db Connection using : alias=%s,url=%s dbapiModuleName=%s, dbName=%s, \ + dbUsername=%s, dbPassword=%s, dbHost=%s, dbPort=%s, dbCharset=%s, \ + dbConfigFile=%s ' % (alias, url, dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort, + dbCharset, dbConfigFile)) + config = ConfigParser.ConfigParser() config.read([dbConfigFile]) - dbapiModuleName = dbapiModuleName or config.get('default', 'dbapiModuleName') - dbName = dbName or config.get('default', 'dbName') - dbUsername = dbUsername or config.get('default', 'dbUsername') - dbPassword = dbPassword if dbPassword is not None else config.get('default', 'dbPassword') - dbHost = dbHost or config.get('default', 'dbHost') or 'localhost' - dbPort = int(dbPort or config.get('default', 'dbPort')) - - if dbapiModuleName == "excel" or dbapiModuleName == "excelrw": - self.db_api_module_name = "pyodbc" - db_api_2 = importlib.import_module("pyodbc") - else: - self.db_api_module_name = dbapiModuleName - db_api_2 = importlib.import_module(dbapiModuleName) - if dbapiModuleName in ["MySQLdb", "pymysql"]: - dbPort = dbPort or 3306 - logger.info('Connecting using : %s.connect(db=%s, user=%s, passwd=%s, host=%s, port=%s, charset=%s) ' % (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort, dbCharset)) - self._dbconnection = db_api_2.connect(db=dbName, user=dbUsername, passwd=dbPassword, host=dbHost, port=dbPort, charset=dbCharset) - elif dbapiModuleName in ["psycopg2"]: - dbPort = dbPort or 5432 - logger.info('Connecting using : %s.connect(database=%s, user=%s, password=%s, host=%s, port=%s) ' % (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort)) - self._dbconnection = db_api_2.connect(database=dbName, user=dbUsername, password=dbPassword, host=dbHost, port=dbPort) - elif dbapiModuleName in ["pyodbc", "pypyodbc"]: - dbPort = dbPort or 1433 - logger.info('Connecting using : %s.connect(DRIVER={SQL Server};SERVER=%s,%s;DATABASE=%s;UID=%s;PWD=%s)' % (dbapiModuleName, dbHost, dbPort, dbName, dbUsername, dbPassword)) - self._dbconnection = db_api_2.connect('DRIVER={SQL Server};SERVER=%s,%s;DATABASE=%s;UID=%s;PWD=%s' % (dbHost, dbPort, dbName, dbUsername, dbPassword)) - elif dbapiModuleName in ["excel"]: - logger.info( - 'Connecting using : %s.connect(DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=1;Extended Properties="Excel 8.0;HDR=YES";)' % ( - dbapiModuleName, dbName)) - self._dbconnection = db_api_2.connect( - 'DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=1;Extended Properties="Excel 8.0;HDR=YES";)' % ( - dbName), autocommit=True) - elif dbapiModuleName in ["excelrw"]: - logger.info( - 'Connecting using : %s.connect(DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=0;Extended Properties="Excel 8.0;HDR=YES";)' % ( - dbapiModuleName, dbName)) - self._dbconnection = db_api_2.connect( - 'DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=0;Extended Properties="Excel 8.0;HDR=YES";)' % ( - dbName), autocommit=True) - elif dbapiModuleName in ["ibm_db", "ibm_db_dbi"]: - dbPort = dbPort or 50000 - logger.info('Connecting using : %s.connect(DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=%s;) ' % (dbapiModuleName, dbName, dbHost, dbPort, dbUsername, dbPassword)) - self._dbconnection = db_api_2.connect('DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=%s;' % (dbName, dbHost, dbPort, dbUsername, dbPassword), '', '') - elif dbapiModuleName in ["cx_Oracle"]: - dbPort = dbPort or 1521 - oracle_dsn = db_api_2.makedsn(host=dbHost, port=dbPort, service_name=dbName) - logger.info('Connecting using: %s.connect(user=%s, password=%s, dsn=%s) ' % (dbapiModuleName, dbUsername, dbPassword, oracle_dsn)) - self._dbconnection = db_api_2.connect(user=dbUsername, password=dbPassword, dsn=oracle_dsn) - else: - logger.info('Connecting using : %s.connect(database=%s, user=%s, password=%s, host=%s, port=%s) ' % (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort)) - self._dbconnection = db_api_2.connect(database=dbName, user=dbUsername, password=dbPassword, host=dbHost, port=dbPort) - - def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_connect_string=''): + if not (url is None): + dataConnection = urlparse(url) + + dbapiModuleName = dbapiModuleName or config.get( + 'default', 'dbapiModuleName') + dbName = dbName or dataConnection.path[1:] or config.get( + 'default', 'dbName') + dbUsername = dbUsername or dataConnection.username or config.get( + 'default', 'dbUsername') + dbPassword = dbPassword if dbPassword is not None else \ + dataConnection.password if dataConnection.password is not None else \ + config.get('default', 'dbPassword') + dbHost = dbHost or dataConnection.hostname or config.get( + 'default', 'dbHost') or 'localhost' + dbPort = int(dbPort or dataConnection.port or config.get( + 'default', 'dbPort')) + + return self._connect_to_database( + alias, + dbapiModuleName, + dbName, + dbUsername, + dbPassword, + dbHost, + dbPort, + dbCharset, + dbConfigFile) + + def _connect_to_database(self, alias, dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort, dbCharset, dbConfigFile="./resources/db.cfg"): + + try: + + if dbapiModuleName == "excel" or dbapiModuleName == "excelrw": + db_api_module_name = "pyodbc" + db_api_2 = importlib.import_module("pyodbc") + else: + db_api_module_name = dbapiModuleName + db_api_2 = importlib.import_module(dbapiModuleName) + + if dbapiModuleName in ["MySQLdb", "pymysql"]: + dbPort = dbPort or 3306 + logger.info('Connecting using : %s.connect(db=%s, user=%s, passwd=%s, host=%s, port=%s, charset=%s) ' % + (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort, dbCharset)) + dbconnection = db_api_2.connect( + db=dbName, user=dbUsername, passwd=dbPassword, host=dbHost, port=dbPort, charset=dbCharset) + elif dbapiModuleName in ["psycopg2"]: + dbPort = dbPort or 5432 + logger.info('Connecting using : %s.connect(database=%s, user=%s, password=%s, host=%s, port=%s) ' % + (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort)) + dbconnection = db_api_2.connect( + database=dbName, user=dbUsername, password=dbPassword, host=dbHost, port=dbPort) + elif dbapiModuleName in ["pyodbc", "pypyodbc"]: + dbPort = dbPort or 1433 + logger.info('Connecting using : %s.connect(DRIVER={SQL Server};SERVER=%s,%s;DATABASE=%s;UID=%s;PWD=%s)' % + (dbapiModuleName, dbHost, dbPort, dbName, dbUsername, dbPassword)) + dbconnection = db_api_2.connect('DRIVER={SQL Server};SERVER=%s,%s;DATABASE=%s;UID=%s;PWD=%s' % + (dbHost, dbPort, dbName, dbUsername, dbPassword)) + elif dbapiModuleName in ["excel"]: + logger.info( + 'Connecting using : %s.connect(DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=1;' + 'Extended Properties="Excel 8.0;HDR=YES";)' % (dbapiModuleName, dbName)) + dbconnection = db_api_2.connect( + 'DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=1;Extended Properties="Excel 8.0;HDR=YES";)' % ( + dbName), autocommit=True) + elif dbapiModuleName in ["excelrw"]: + logger.info( + 'Connecting using : %s.connect(DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=0;' + 'Extended Properties="Excel 8.0;HDR=YES";)' % (dbapiModuleName, dbName)) + dbconnection = db_api_2.connect( + 'DRIVER={Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)};DBQ=%s;ReadOnly=0;Extended Properties="Excel 8.0;HDR=YES";)' % ( + dbName), autocommit=True) + elif dbapiModuleName in ["ibm_db", "ibm_db_dbi"]: + dbPort = dbPort or 50000 + logger.info('Connecting using : %s.connect(DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=%s;) ' % + (dbapiModuleName, dbName, dbHost, dbPort, dbUsername, dbPassword)) + dbconnection = db_api_2.connect('DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=%s;' % + (dbName, dbHost, dbPort, dbUsername, dbPassword), '', '') + elif dbapiModuleName in ["cx_Oracle"]: + dbPort = dbPort or 1521 + oracle_dsn = db_api_2.makedsn( + host=dbHost, port=dbPort, service_name=dbName) + logger.info('Connecting using: %s.connect(user=%s, password=%s, dsn=%s) ' % ( + dbapiModuleName, dbUsername, dbPassword, oracle_dsn)) + dbconnection = db_api_2.connect( + user=dbUsername, password=dbPassword, dsn=oracle_dsn) + else: + logger.info('Connecting using : %s.connect(database=%s, user=%s, password=%s, host=%s, port=%s) ' % + (dbapiModuleName, dbName, dbUsername, dbPassword, dbHost, dbPort)) + dbconnection = db_api_2.connect( + database=dbName, user=dbUsername, password=dbPassword, host=dbHost, port=dbPort) + + self._push_cache(alias, dbconnection, db_api_module_name) + + except Exception as Err: + err_msg = ('DbConnection : %s : %s' % (alias, Err)) + raise AssertionError(err_msg) + + def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_connect_string='', alias=None): + + logger.info('Creating Db Connection using : alias=%s, dbapiModuleName=%s, db_connect_string=%s' % + (alias, dbapiModuleName, db_connect_string)) + + return self._connect_to_database_using_custom_params(alias, dbapiModuleName, db_connect_string) + + def _connect_to_database_using_custom_params(self, dbapiModuleName=None, db_connect_string='', alias=None): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to connect to the database using the map string `db_custom_param_string`. + Added field alias + Example usage: | # for psycopg2 | - | Connect To Database Using Custom Params | psycopg2 | database='my_db_test', user='postgres', password='s3cr3t', host='tiger.foobar.com', port=5432 | + | Connect To Database Using Custom Params | alias | psycopg2 | database='my_db_test', user='postgres', password='s3cr3t', host='tiger.foobar.com', port=5432 | | # for JayDeBeApi | - | Connect To Database Using Custom Params | JayDeBeApi | 'oracle.jdbc.driver.OracleDriver', 'my_db_test', 'system', 's3cr3t' | + | Connect To Database Using Custom Params | alias | JayDeBeApi | 'oracle.jdbc.driver.OracleDriver', 'my_db_test', 'system', 's3cr3t' | """ + db_api_2 = importlib.import_module(dbapiModuleName) db_connect_string = 'db_api_2.connect(%s)' % db_connect_string - self.db_api_module_name = dbapiModuleName - logger.info('Executing : Connect To Database Using Custom Params : %s.connect(%s) ' % (dbapiModuleName, db_connect_string)) - self._dbconnection = eval(db_connect_string) + logger.info('Executing : Connect To Database Using Custom Params : %s.connect(%s) ' % ( + dbapiModuleName, db_connect_string)) + dbconnection = eval(db_connect_string) + + self._push_cache(alias, dbconnection, dbapiModuleName) - def disconnect_from_database(self): + def disconnect_from_database(self, alias=None): """ Disconnects from the database. + Added field alias For example: - | Disconnect From Database | # disconnects from current connection to the database | + | Disconnect From Database | alias | # disconnects from current connection to the database | """ logger.info('Executing : Disconnect From Database') - self._dbconnection.close() + connection, module_api = self._get_cache(alias) + connection.close() - def set_auto_commit(self, autoCommit=True): + def set_auto_commit(self, autoCommit=True, alias=None): """ Turn the autocommit on the database connection ON or OFF. - + The default behaviour on a newly created database connection is to automatically start a transaction, which means that database actions that won't work if there is an active transaction will fail. Common examples of these actions are creating or deleting a database or database snapshot. By turning on auto commit on the database connection these actions can be performed. - + + Added field alias + Example: | # Default behaviour, sets auto commit to true - | Set Auto Commit + | Set Auto Commit | alias | # Explicitly set the desired state - | Set Auto Commit | False + | Set Auto Commit | alias | False """ logger.info('Executing : Set Auto Commit') - self._dbconnection.autocommit = autoCommit + connection, module_api = self._get_cache(alias) + connection.autocommit = autoCommit + self._push_cache(alias, connection, module_api) diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index 13c269c..7889e03 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -14,6 +14,8 @@ import sys from robot.api import logger +import robot +from robot.libraries.BuiltIn import BuiltIn class Query(object): @@ -21,7 +23,7 @@ class Query(object): Query handles all the querying done by the Database Library. """ - def query(self, selectStatement, sansTran=False, returnAsDict=False): + def query(self, selectStatement, sansTran=False, returnAsDict=False, alias=None): """ Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -56,8 +58,12 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): | @{queryResults} | Query | SELECT * FROM person | True | """ cur = None + logger.info('Connection: Query | %s' % alias) + try: - cur = self._dbconnection.cursor() + connection, module_api = self._get_cache(alias) + logger.info('Module: Query | %s' % module_api) + cur = connection.cursor() logger.info('Executing : Query | %s ' % selectStatement) self.__execute_sql(cur, selectStatement) allRows = cur.fetchall() @@ -65,7 +71,6 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): if returnAsDict: mappedRows = [] col_names = [c[0] for c in cur.description] - for rowIdx in range(len(allRows)): d = {} for colIdx in range(len(allRows[rowIdx])): @@ -77,12 +82,12 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() - def row_count(self, selectStatement, sansTran=False): + def row_count(self, selectStatement, sansTran=False, alias=None): """ - Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set - optional input `sansTran` to True to run command without an explicit transaction commit or rollback. + Uses the input `selectStatement` to query the database and returns the number of rows from the query. + Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. For example, given we have a table `person` with the following data: | id | first_name | last_name | @@ -108,11 +113,13 @@ def row_count(self, selectStatement, sansTran=False): """ cur = None try: - cur = self._dbconnection.cursor() + connection, module_api = self._get_cache(alias) + cur = connection.cursor() logger.info('Executing : Row Count | %s ' % selectStatement) + self.__execute_sql(cur, selectStatement) data = cur.fetchall() - if self.db_api_module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]: + if module_api in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]: rowCount = len(data) else: rowCount = cur.rowcount @@ -120,9 +127,9 @@ def row_count(self, selectStatement, sansTran=False): finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() - def description(self, selectStatement, sansTran=False): + def description(self, selectStatement, sansTran=False, alias=None): """ Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -145,20 +152,22 @@ def description(self, selectStatement, sansTran=False): """ cur = None try: - cur = self._dbconnection.cursor() + connection, module_api = self._get_cache(alias) + cur = connection.cursor() logger.info('Executing : Description | %s ' % selectStatement) self.__execute_sql(cur, selectStatement) description = list(cur.description) if sys.version_info[0] < 3: for row in range(0, len(description)): - description[row] = (description[row][0].encode('utf-8'),) + description[row][1:] + description[row] = (description[row][0].encode( + 'utf-8'),) + description[row][1:] return description finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() - def delete_all_rows_from_table(self, tableName, sansTran=False): + def delete_all_rows_from_table(self, tableName, sansTran=False, alias=None): """ Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -180,21 +189,23 @@ def delete_all_rows_from_table(self, tableName, sansTran=False): cur = None selectStatement = ("DELETE FROM %s;" % tableName) try: - cur = self._dbconnection.cursor() - logger.info('Executing : Delete All Rows From Table | %s ' % selectStatement) + connection, module_api = self._get_cache(alias) + cur = connection.cursor() + logger.info( + 'Executing : Delete All Rows From Table | %s ' % selectStatement) result = self.__execute_sql(cur, selectStatement) if result is not None: if not sansTran: - self._dbconnection.commit() + connection.commit() return result if not sansTran: - self._dbconnection.commit() + connection.commit() finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() - def execute_sql_script(self, sqlScriptFileName, sansTran=False): + def execute_sql_script(self, sqlScriptFileName, sansTran=False, alias=None): """ Executes the content of the `sqlScriptFileName` as SQL commands and returns number of rows affected. Useful for setting the database to @@ -255,8 +266,10 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): cur = None result = 0 try: - cur = self._dbconnection.cursor() - logger.info('Executing : Execute SQL Script | %s ' % sqlScriptFileName) + connection, module_api = self._get_cache(alias) + cur = connection.cursor() + logger.info('Executing : Execute SQL Script | %s ' % + sqlScriptFileName) sqlStatement = '' for line in sqlScriptFile: PY3K = sys.version_info >= (3, 0) @@ -267,7 +280,6 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): continue elif line.startswith('--'): continue - sqlFragments = line.split(';') if len(sqlFragments) == 1: sqlStatement += line + ' ' @@ -276,9 +288,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): sqlFragment = sqlFragment.strip() if len(sqlFragment) == 0: continue - sqlStatement += sqlFragment + ' ' - result = result + self.__execute_sql(cur, sqlStatement) sqlStatement = '' @@ -287,14 +297,14 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): result = self.__execute_sql(cur, sqlStatement) if not sansTran: - self._dbconnection.commit() + connection.commit() finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() return result - def execute_sql_string(self, sqlString, sansTran=False): + def execute_sql_string(self, sqlString, sansTran=False, alias=None): """ Executes the sqlString as SQL commands and returns number of rows affected. Useful to pass arguments to your sql. Set optional input @@ -315,18 +325,19 @@ def execute_sql_string(self, sqlString, sansTran=False): cur = None result = 0 try: - cur = self._dbconnection.cursor() + connection, module_api = self._get_cache(alias) + cur = connection.cursor() logger.info('Executing : Execute SQL String | %s ' % sqlString) result = self.__execute_sql(cur, sqlString) if not sansTran: - self._dbconnection.commit() + connection.commit() finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() return result - def call_stored_procedure(self, spName, spParams=None, sansTran=False): + def call_stored_procedure(self, spName, spParams=None, sansTran=False, alias=None): """ Uses the inputs of `spName` and 'spParams' to call a stored procedure. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -353,27 +364,28 @@ def call_stored_procedure(self, spName, spParams=None, sansTran=False): spParams = [] cur = None try: - if self.db_api_module_name in ["cx_Oracle"]: - cur = self._dbconnection.cursor() + connection, module_api = self._get_cache(alias) + if module_api in ["cx_Oracle"]: + cur = connection.cursor() else: - cur = self._dbconnection.cursor(as_dict=False) + cur = connection.cursor(as_dict=False) PY3K = sys.version_info >= (3, 0) if not PY3K: spName = spName.encode('ascii', 'ignore') - logger.info('Executing : Call Stored Procedure | %s | %s ' % (spName, spParams)) + logger.info( + 'Executing : Call Stored Procedure | %s | %s ' % (spName, spParams)) cur.callproc(spName, spParams) cur.nextset() - retVal=list() + retVal = list() for row in cur: - #logger.info ( ' %s ' % (row)) retVal.append(row) if not sansTran: - self._dbconnection.commit() + connection.commit() return retVal finally: if cur: if not sansTran: - self._dbconnection.rollback() + connection.rollback() def __execute_sql(self, cur, sqlStatement): return cur.execute(sqlStatement)