diff --git a/Makefile.am b/Makefile.am index 17a6c9395..fc174ec0c 100644 --- a/Makefile.am +++ b/Makefile.am @@ -20,5 +20,17 @@ include pyext/Makefile.am include sonic-db-cli/Makefile.am include tests/Makefile.am +# This will set redis_acl to an empty value if not already set +redis_acl ?= + +# If redis_acl is set, append -D$(redis_acl) to DBGFLAGS +ifneq ($(strip $(redis_acl)),) +ifeq ($(redis_acl),y) +REDIS_ACL_G=1 +else ifeq ($(redis_acl),n) +REDIS_ACL_G=0 +endif +AM_CPPFLAGS += -DREDIS_ACL_G +endif ACLOCAL_AMFLAGS = -I m4 diff --git a/common/Makefile.am b/common/Makefile.am index 4c9a08976..84f6657f2 100644 --- a/common/Makefile.am +++ b/common/Makefile.am @@ -85,7 +85,7 @@ common_libswsscommon_la_SOURCES = \ common_libswsscommon_la_CXXFLAGS = $(DBGFLAGS) $(AM_CFLAGS) $(CFLAGS_COMMON) $(LIBNL_CFLAGS) $(CODE_COVERAGE_CXXFLAGS) common_libswsscommon_la_CPPFLAGS = $(DBGFLAGS) $(AM_CFLAGS) $(CFLAGS_COMMON) $(LIBNL_CPPFLAGS) $(CODE_COVERAGE_CPPFLAGS) -common_libswsscommon_la_LIBADD = -lpthread $(LIBNL_LIBS) $(CODE_COVERAGE_LIBS) -lzmq -lboost_serialization -luuid +common_libswsscommon_la_LIBADD = -lpthread $(LIBNL_LIBS) $(CODE_COVERAGE_LIBS) -lzmq -lboost_serialization -luuid -ldl -lssl -lcrypto -lhiredis_ssl common_libswsscommon_la_LDFLAGS = -Wl,-z,now $(LDFLAGS) if YANGMODS diff --git a/common/dbconnector.cpp b/common/dbconnector.cpp index 47fe80d3b..211a98e3d 100755 --- a/common/dbconnector.cpp +++ b/common/dbconnector.cpp @@ -14,10 +14,113 @@ #include "common/redispipeline.h" #include "common/pubsub.h" +#include +#include +#include +#include + using json = nlohmann::json; using namespace std; using namespace swss; +#define SUDO_GID 27 + +// This is a macro to check if the REDIS_ACL_G is defined, which meaning that the feature Redis ACL is enabled. +#ifdef REDIS_ACL_G + +bool isUserInSudoGroup() { + int ngroups = getgroups(0, nullptr); // Get the number of supplementary groups + bool group_sudo_exists = false; + gid_t groupToFind = SUDO_GID; + + if (ngroups < 0) { + SWSS_LOG_ERROR("no ngroups exits"); + return false; + } + + std::vector groups(ngroups); + if (getgroups(ngroups, groups.data()) < 0) { + SWSS_LOG_ERROR("no ngroups exits"); + return false; + } + + for (const auto& group : groups) { + if (group == groupToFind) { + group_sudo_exists = true; + break; + + } + } + return group_sudo_exists; +} + +bool isRootUser() { + uid_t uid = getuid(); // Get the user ID of the current user + + return (uid == 0); +} + +bool is_admin_user() { + bool is_admin_user = false; + + if (isRootUser()) { + is_admin_user = true; + } else if (isUserInSudoGroup()) { + is_admin_user = true; + } + + return is_admin_user; +} + + +// Function to write content to a file +void writeFile(const std::string& filename, const std::string& content) { + std::ofstream outputFile(filename, std::ios_base::app); + + if (!outputFile.is_open()) { + std::cerr << "Failed to open the file for writing." << std::endl; + return; + } + + outputFile << content; + outputFile.close(); +} + +// Function to read the entire content of a file and return as a string +std::string readFileContent_pw(const std::string& filename) { + std::ifstream inputFile(filename); + std::string content; + + if (!inputFile.is_open()) { + std::cerr << "Failed to open the file." << std::endl; + return content; + } + + content.assign((std::istreambuf_iterator(inputFile)), std::istreambuf_iterator()); + + inputFile.close(); + + // Check if the string ends with '\n' and remove it. + if (!content.empty() && content.back() == '\n') { + content.pop_back(); + } + + return content; +} + +std::string get_auth_cmd() { + std::string command = ""; + if (is_admin_user()){ + std::string shadow_redis_admin = readFileContent_pw("/etc/shadow_redis_dir/shadow_redis_admin"); + command = std::string("auth admin ") + shadow_redis_admin; + }else{ + std::string shadow_redis_monitor = readFileContent_pw("/etc/shadow_redis_dir/shadow_redis_monitor"); + command = std::string("auth monitor ") + shadow_redis_monitor; + } + return command; +} + +#endif void SonicDBConfig::parseDatabaseConfig(const string &file, std::map &inst_entry, @@ -541,11 +644,11 @@ RedisContext::RedisContext(const RedisContext &other) const char *unixPath = octx->unix_sock.path; if (unixPath) { - initContext(unixPath, octx->timeout); + initContext(unixPath, octx->connect_timeout); } else { - initContext(octx->tcp.host, octx->tcp.port, octx->timeout); + initContext(octx->tcp.host, octx->tcp.port, octx->connect_timeout); } } @@ -563,6 +666,43 @@ void RedisContext::initContext(const char *host, int port, const timeval *tv) if (m_conn->err) throw system_error(make_error_code(errc::address_not_available), "Unable to connect to redis - " + std::string(m_conn->errstr) + "(" + std::to_string(m_conn->err) + ")"); + +#ifdef REDIS_ACL_G + // Redis SSL configuration + redisSSLContext *ssl; + redisSSLContextError ssl_error = REDIS_SSL_CTX_NONE; + const char *ca = "/etc/shadow_redis_dir/certs_redis/ca.crt"; + redisInitOpenSSL(); + + redisSSLOptions options = { + .cacert_filename = ca, + .capath = NULL, + .cert_filename = NULL, + .private_key_filename = NULL, + .server_name = NULL, + .verify_mode = REDIS_SSL_VERIFY_NONE, + }; + + + ssl = redisCreateSSLContextWithOptions(&options, &ssl_error); + if (!ssl || ssl_error != REDIS_SSL_CTX_NONE) { + SWSS_LOG_ERROR("SSL Context error: %s\n", redisSSLContextGetError(ssl_error)); + exit(1); + } + + // start ssl connection + if (redisInitiateSSLWithContext(m_conn, ssl) != REDIS_OK) { + SWSS_LOG_ERROR("Couldn't initialize SSL!\n"); + redisFree(m_conn); + exit(1); + } + + + // Redis Authentication + std::string command = ""; + command = get_auth_cmd(); + RedisReply r1(this, command, REDIS_REPLY_STATUS); +#endif } void RedisContext::initContext(const char *path, const timeval *tv) @@ -579,6 +719,13 @@ void RedisContext::initContext(const char *path, const timeval *tv) if (m_conn->err) throw system_error(make_error_code(errc::address_not_available), "Unable to connect to redis (unix-socket) - " + std::string(m_conn->errstr) + "(" + std::to_string(m_conn->err) + ")"); + +#ifdef REDIS_ACL_G + // Redis Authentication + std::string command = ""; + command = get_auth_cmd(); + RedisReply r1(this, command, REDIS_REPLY_STATUS); +#endif } redisContext *RedisContext::getContext() const diff --git a/common/dbconnector.h b/common/dbconnector.h index 832983ed9..5153e7b20 100644 --- a/common/dbconnector.h +++ b/common/dbconnector.h @@ -12,6 +12,7 @@ #include #include +#include #include "rediscommand.h" #include "redisreply.h" #define EMPTY_NAMESPACE std::string() diff --git a/common/rediscommand.cpp b/common/rediscommand.cpp index 5cc7422b9..f1e50883a 100644 --- a/common/rediscommand.cpp +++ b/common/rediscommand.cpp @@ -48,8 +48,8 @@ void RedisCommand::formatArgv(int argc, const char **argv, const size_t *argvlen } len = 0; - int ret = redisFormatCommandArgv(&temp, argc, argv, argvlen); - if (ret == -1) { + long long ret = redisFormatCommandArgv(&temp, argc, argv, argvlen); + if (len == -1) { throw std::bad_alloc(); } len = ret; @@ -148,6 +148,7 @@ size_t RedisCommand::length() const { if (len <= 0) return 0; + // TODO review this casting return static_cast(len); } diff --git a/common/rediscommand.h b/common/rediscommand.h index ed6cd846b..33ccf8f3c 100644 --- a/common/rediscommand.h +++ b/common/rediscommand.h @@ -77,7 +77,7 @@ class RedisCommand { private: char *temp; - int len; + long long len; }; template