diff --git a/.appveyor.yml b/.appveyor.yml deleted file mode 100644 index 37826fb4..00000000 --- a/.appveyor.yml +++ /dev/null @@ -1,26 +0,0 @@ -version: 1.1.1.{build} -image: Visual Studio 2019 -configuration: -- Release -- Debug - -environment: - matrix: - - GENERATOR: Visual Studio 16 2019 - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - ARCH: Win32 - - - GENERATOR: Visual Studio 16 2019 - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - ARCH: x64 - -init: [] - -before_build: -- md build -- cd build -- cmake --config "%CONFIGURATION%" -G "%GENERATOR%" .. -build_script: -- cmake --build . --config "%CONFIGURATION%" -build: - verbosity: minimal \ No newline at end of file diff --git a/.cmake-format b/.cmake-format new file mode 100644 index 00000000..bec68f36 --- /dev/null +++ b/.cmake-format @@ -0,0 +1,53 @@ +# notes for additional commands +# +# nargs: '*' to allow multiple arguments +# kwargs: &fookwargs to definite keyword arguments +# kwargs: *fookwargs to use the same keyword arguments as fookwargs +# NAME: 1 to allow single keyword arguments +# NAME: + to allow multiple keyword arguments +# NAME: * to allow multiple keyword arguments +# spelling: FOO to use foo to FOO spelling + +parse: + additional_commands: + FetchContent_Declare: + pargs: + nargs: '*' + flags: [] + kwargs: + GIT_TAG: 1 + GITHUB_REPOSITORY: 1 + GITLAB_REPOSITORY: 1 + GIT_REPOSITORY: 1 + SVN_REPOSITORY: 1 + SVN_REVISION: 1 + URL: 1 + URL_HASH: 1 + URL_MD5: 1 + FIND_PACKAGE_ARGS: + + FetchContent_MakeAvailable: + pargs: + nargs: '*' + flags: [] + execute_process: + pargs: + nargs: '*' + flags: [] + kwargs: + COMMAND: + + WORKING_DIRECTORY: 1 + set_target_properties: + pargs: + nargs: '*' + flags: [] + kwargs: + PROPERTIES: + + IMPORTED_LOCATION: 1 + INTERFACE_INCLUDE_DIRECTORIES: 1 +format: + tab_size: 2 + line_width: 120 + autosort: true + dangle_parens: true + max_subgroups_hwrap: 2 + max_pargs_hwrap: 3 \ No newline at end of file diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml deleted file mode 100644 index f847d2ba..00000000 --- a/.github/workflows/cmake.yml +++ /dev/null @@ -1,141 +0,0 @@ -name: Build Trantor - -on: - push: - branches: [master] - pull_request: - workflow_dispatch: - -env: - # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) - BUILD_TYPE: Release - -jobs: - windows: - name: 'windows/msvc - ${{matrix.link}}' - runs-on: windows-latest - strategy: - fail-fast: false - matrix: - link: [ 'STATIC', 'SHARED' ] - steps: - - name: Checkout Trantor source code - uses: actions/checkout@v2 - with: - submodules: false - - - name: Install dependencies - run: | - pip install conan - - - name: Create build directory - run: | - mkdir build - - - name: Install conan packages - shell: pwsh - working-directory: ./build - run: | - conan install .. -s compiler="Visual Studio" -s compiler.version=16 -sbuild_type=Debug -g cmake_paths - - - name: Create Build Environment & Configure Cmake - shell: bash - working-directory: ./build - run: | - [[ ${{ matrix.link }} == "SHARED" ]] && shared="ON" || shared="OFF" - cmake .. -DCMAKE_BUILD_TYPE=Debug -DBUILD_TESTING=on -DBUILD_SHARED_LIBS=$shared -DCMAKE_TOOLCHAIN_FILE="conan_paths.cmake" -DCMAKE_INSTALL_PREFIX=../install - - - name: Build - working-directory: ${{env.GITHUB_WORKSPACE}} - shell: bash - run: | - cd build - cmake --build . --target install --parallel - unix: - name: ${{matrix.buildname}} - runs-on: ${{matrix.os}} - strategy: - fail-fast: false - matrix: - include: - - os: ubuntu-20.04 - buildname: 'ubuntu-20.04/gcc' - triplet: x64-linux - compiler: gcc_64 - - os: ubuntu-20.04 - buildname: 'ubuntu-20.04/gcc without openssl' - triplet: x64-linux - compiler: gcc_64 - - os: macos-latest - buildname: 'macos/clang' - triplet: x64-osx - compiler: clang_64 - - steps: - - name: Checkout Trantor source code - uses: actions/checkout@v2 - with: - submodules: true - fetch-depth: 0 - - - name: (macOS) Install dependencies - if: runner.os == 'macOS' - run: | - brew install c-ares openssl - - - name: (Linux) Install dependencies - if: matrix.buildname == 'ubuntu-20.04/gcc' - run: | - # Installing packages might fail as the github image becomes outdated - sudo apt update - # These aren't available or don't work well in vcpkg - sudo apt install openssl libssl-dev - sudo apt install dos2unix - - name: (Linux) Install dependencies - if: matrix.buildname == 'ubuntu-20.04/gcc without openssl' - run: | - # Installing packages might fail as the github image becomes outdated - sudo apt update - # These aren't available or don't work well in vcpkg - sudo apt install dos2unix - - name: install gtest - run: | - wget https://github.com/google/googletest/archive/release-1.10.0.tar.gz - tar xf release-1.10.0.tar.gz - cd googletest-release-1.10.0 - cmake . - make - sudo make install - - - name: Create Build Environment & Configure Cmake - # Some projects don't allow in-source building, so create a separate build directory - # We'll use this as our working directory for all subsequent commands - shell: bash - working-directory: ${{env.GITHUB_WORKSPACE}} - run: | - mkdir build - cd build - cmake .. -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DBUILD_TESTING=on - - - name: Build - working-directory: ${{env.GITHUB_WORKSPACE}} - shell: bash - # Execute the build. You can specify a specific target with "--target " - run: | - cd build - sudo make && sudo make install - - - name: Test - working-directory: ${{env.GITHUB_WORKSPACE}} - shell: bash - # Execute tests defined by the CMake configuration. - # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail - run: | - cd build - make test - - - name: Lint - if: matrix.os == 'ubuntu-20.04' - working-directory: ${{env.GITHUB_WORKSPACE}} - shell: bash - run: ./format.sh && git diff --exit-code diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..c17f2550 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,31 @@ +name: Lint source code + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + unix: + name: Lint + runs-on: ubuntu-latest + strategy: + fail-fast: false + + steps: + - name: Checkout Trantor source code + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + + - name: (Linux) Install dependencies + run: | + # Installing packages might fail as the github image becomes outdated + sudo apt update + sudo apt install dos2unix clang-format + pip install cmake-format + + - name: Lint + run: ./format.sh && git diff --exit-code diff --git a/.github/workflows/macos-clang.yml b/.github/workflows/macos-clang.yml new file mode 100644 index 00000000..6baa1688 --- /dev/null +++ b/.github/workflows/macos-clang.yml @@ -0,0 +1,73 @@ +name: Build macos-clang + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + build: + name: '${{matrix.link}}-${{matrix.build-type}}-${{matrix.tls-provider}}' + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + link: [ 'STATIC', 'SHARED' ] + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + build-type: ['Debug', 'Release'] + # Botan needs std::ranges but clang on macOS doesn't support it yet + #tls-provider: ['', 'openssl', 'botan'] + tls-provider: ['', 'openssl'] + + steps: + - name: Install dependencies + # botan v3 + run: | + brew install botan spdlog + + - name: Install gtest + run: | + wget https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz + tar xf v1.13.0.tar.gz + cd googletest-1.13.0 + cmake . + make && sudo make install + + - name: Checkout Trantor source code + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + + - name: Create build directory + run: | + mkdir build + + - name: Create Build Environment & Configure Cmake + shell: bash + working-directory: ./build + run: | + [[ ${{ matrix.link }} == "SHARED" ]] && shared="ON" || shared="OFF" + cmake .. \ + -DTRANTOR_USE_TLS=${{matrix.tls-provider}} \ + -DCMAKE_BUILD_TYPE=${{matrix.build-type}} \ + -DBUILD_SHARED_LIBS=$shared \ + -DCMAKE_INSTALL_PREFIX=../install \ + -DUSE_SPDLOG=ON \ + -DBUILD_TESTING=ON \ + + - name: Build + shell: bash + working-directory: ./build + # Execute the build. You can specify a specific target with "--target " + run: | + sudo make && sudo make install + + - name: Test + working-directory: ./build + shell: bash + # Execute tests defined by the CMake configuration. + # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail + run: | + make test diff --git a/.github/workflows/rockylinux-gcc.yml b/.github/workflows/rockylinux-gcc.yml new file mode 100644 index 00000000..85de8c26 --- /dev/null +++ b/.github/workflows/rockylinux-gcc.yml @@ -0,0 +1,88 @@ +name: Build rockylinux-gcc + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + build: + name: '${{matrix.link}}-${{matrix.build-type}}-${{matrix.tls-provider}}' + runs-on: ubuntu-latest + container: + image: rockylinux:9.3 + options: --user root + strategy: + fail-fast: false + matrix: + link: [ 'STATIC', 'SHARED' ] + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + build-type: ['Debug', 'Release'] + # TODO: ubuntu botan is v2, v2 support is removed + # tls-provider: ['', 'openssl', 'botan'] + tls-provider: ['', 'openssl'] + + steps: + - name: Install dependencies + run: | + dnf install gcc-c++ cmake git wget -y + + - name: Install dependencies - spdlog + run: | + git clone https://github.com/gabime/spdlog.git + cd spdlog && mkdir build && cd build + cmake .. && make -j + + - name: Install dependencies - OpenSSL + if: matrix.tls-provider == 'openssl' + run: | + dnf install openssl-devel -y + + - name: Install gtest + run: | + wget https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz + tar xf v1.13.0.tar.gz + cd googletest-1.13.0 + cmake . + make -j && make install + + - name: Checkout Trantor source code + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + + - name: Create build directory + run: | + mkdir build + + - name: Create Build Environment & Configure Cmake + shell: bash + working-directory: ./build + if: ${{matrix.link}} == "SHARED" + run: | + [[ ${{ matrix.link }} == "SHARED" ]] && shared="ON" || shared="OFF" + cmake .. \ + -DTRANTOR_USE_TLS=${{matrix.tls-provider}} \ + -DCMAKE_BUILD_TYPE=${{matrix.build-type}} \ + -DBUILD_SHARED_LIBS=$shared \ + -DCMAKE_INSTALL_PREFIX=../install \ + -DUSE_SPDLOG=ON \ + -DBUILD_TESTING=ON + + - name: Build + shell: bash + working-directory: ./build + # Execute the build. You can specify a specific target with "--target " + run: | + make && make install + + - name: Test + working-directory: ./build + shell: bash + # Execute tests defined by the CMake configuration. + # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail + run: | + make test + diff --git a/.github/workflows/ubuntu-gcc.yml b/.github/workflows/ubuntu-gcc.yml new file mode 100644 index 00000000..04b47ff7 --- /dev/null +++ b/.github/workflows/ubuntu-gcc.yml @@ -0,0 +1,81 @@ +name: Build ubuntu-gcc + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + build: + name: '${{matrix.link}}-${{matrix.build-type}}-${{matrix.tls-provider}}' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + link: [ 'STATIC', 'SHARED' ] + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + build-type: ['Debug', 'Release'] + # TODO: ubuntu botan is v2, v2 support is removed + # tls-provider: ['', 'openssl', 'botan'] + tls-provider: ['', 'openssl'] + + steps: + - name: Install dependencies + run: | + # Installing packages might fail as the github image becomes outdated + sudo apt update + sudo apt install libspdlog-dev libfmt-dev + + - name: Install dependencies - OpenSSL + if: matrix.tls-provider == 'openssl' + run: | + sudo apt install openssl libssl-dev + + - name: Install gtest + run: | + wget https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz + tar xf v1.13.0.tar.gz + cd googletest-1.13.0 + cmake . + make -j && sudo make install + + - name: Checkout Trantor source code + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + + - name: Create build directory + run: | + mkdir build + + - name: Create Build Environment & Configure Cmake + shell: bash + working-directory: ./build + if: ${{matrix.link}} == "SHARED" + run: | + [[ ${{ matrix.link }} == "SHARED" ]] && shared="ON" || shared="OFF" + cmake .. \ + -DTRANTOR_USE_TLS=${{matrix.tls-provider}} \ + -DCMAKE_BUILD_TYPE=${{matrix.build-type}} \ + -DBUILD_SHARED_LIBS=$shared \ + -DCMAKE_INSTALL_PREFIX=../install \ + -DUSE_SPDLOG=ON \ + -DBUILD_TESTING=ON + + - name: Build + shell: bash + working-directory: ./build + # Execute the build. You can specify a specific target with "--target " + run: | + sudo make && sudo make install + + - name: Test + working-directory: ./build + shell: bash + # Execute tests defined by the CMake configuration. + # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail + run: | + make test + diff --git a/.github/workflows/windows-msvc.yml b/.github/workflows/windows-msvc.yml new file mode 100644 index 00000000..1cfc7341 --- /dev/null +++ b/.github/workflows/windows-msvc.yml @@ -0,0 +1,62 @@ +name: Build windows-msvc + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + build: + name: '${{matrix.link}}-${{matrix.build-type}}-${{matrix.tls-provider}}' + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + link: [ 'STATIC', 'SHARED' ] + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + build-type: ['Debug', 'Release'] + # TODO: conan botan is v2, v2 support is removed + # tls-provider: ['', 'openssl', 'botan'] + tls-provider: ['', 'openssl'] + + steps: + - name: Checkout Trantor source code + uses: actions/checkout@v4 + with: + submodules: false + + - name: Create build directory + working-directory: ${{env.GITHUB_WORKSPACE}} + run: | + mkdir build + + - name: Install conan packages + shell: bash + working-directory: ./build + run: | + pip install conan + conan profile detect --force + conan install .. --output-folder=. --build=missing --settings=build_type=${{matrix.build-type}} --settings=compiler="msvc" + + - name: Create Build Environment & Configure Cmake + shell: bash + working-directory: ./build + # -DBUILD_TESTING=ON Removed, + # Due to unittest by GTest in windows runner will comes out 'error MSB3073' + run: | + [[ ${{ matrix.link }} == "SHARED" ]] && shared="ON" || shared="OFF" + cmake .. -G "Visual Studio 17 2022" -T host=x64 -A x64 \ + -DTRANTOR_USE_TLS=${{matrix.tls-provider}} \ + -DCMAKE_BUILD_TYPE=${{matrix.build-type}} \ + -DBUILD_SHARED_LIBS=$shared \ + -DCMAKE_INSTALL_PREFIX=../install \ + -DUSE_SPDLOG=ON \ + -DCMAKE_POLICY_DEFAULT_CMP0091=NEW + + - name: Build + working-directory: ./build + shell: bash + # multi config build using --config to switch Release|Debug + run: | + cmake --build . --config ${{matrix.build-type}} --parallel diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 542bc697..00000000 --- a/.travis.yml +++ /dev/null @@ -1,32 +0,0 @@ -matrix: - include: - - os: linux - dist: xenial - - os: osx - osx_image: xcode11 - -sudo: required - -language: cpp - -addons: - apt: - sources: - - xenial - - sourceline: 'deb http://archive.ubuntu.com/ubuntu xenial main' - packages: - - openssl - - libssl-dev - - build-essential - - cmake - - libgtest-dev - - libc-ares-dev - homebrew: - packages: - - openssl - - cmake - - libtool - - gtest - -script: - - ./build.sh -t diff --git a/CMakeLists.txt b/CMakeLists.txt index fe393cbe..42a362da 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,53 +3,84 @@ project(trantor) option(BUILD_DOC "Build Doxygen documentation" OFF) option(BUILD_C-ARES "Build C-ARES" ON) +option(BUILD_TESTING "Build tests" OFF) option(BUILD_SHARED_LIBS "Build trantor as a shared lib" OFF) +option(TRANTOR_USE_TLS + "TLS provider for trantor. Valid options are 'openssl', 'botan' or '' (let the build scripr decide)" "" +) +option(USE_SPDLOG "Allow using the spdlog logging library" OFF) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake_modules/) set(TRANTOR_MAJOR_VERSION 1) set(TRANTOR_MINOR_VERSION 5) -set(TRANTOR_PATCH_VERSION 7) -set(TRANTOR_VERSION - ${TRANTOR_MAJOR_VERSION}.${TRANTOR_MINOR_VERSION}.${TRANTOR_PATCH_VERSION}) +set(TRANTOR_PATCH_VERSION 24) +set(TRANTOR_VERSION ${TRANTOR_MAJOR_VERSION}.${TRANTOR_MINOR_VERSION}.${TRANTOR_PATCH_VERSION}) include(GNUInstallDirs) # Offer the user the choice of overriding the installation directories -set(INSTALL_BIN_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Installation directory for binaries") -set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR + ${CMAKE_INSTALL_BINDIR} + CACHE PATH "Installation directory for binaries" +) +set(INSTALL_LIB_DIR + ${CMAKE_INSTALL_LIBDIR} + CACHE PATH "Installation directory for libraries" +) set(INSTALL_INCLUDE_DIR ${CMAKE_INSTALL_INCLUDEDIR} - CACHE PATH "Installation directory for header files") + CACHE PATH "Installation directory for header files" +) set(DEF_INSTALL_TRANTOR_CMAKE_DIR ${CMAKE_INSTALL_LIBDIR}/cmake/Trantor) set(INSTALL_TRANTOR_CMAKE_DIR ${DEF_INSTALL_TRANTOR_CMAKE_DIR} - CACHE PATH "Installation directory for cmake files") + CACHE PATH "Installation directory for cmake files" +) add_library(${PROJECT_NAME}) if(BUILD_SHARED_LIBS) - list(FIND CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES - "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}" isSystemDir) + list( + FIND + CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES + "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}" + isSystemDir + ) if("${isSystemDir}" STREQUAL "-1") set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") endif("${isSystemDir}" STREQUAL "-1") - set_target_properties(${PROJECT_NAME} PROPERTIES - VERSION ${TRANTOR_VERSION} - SOVERSION ${TRANTOR_MAJOR_VERSION}) + set_target_properties( + ${PROJECT_NAME} + PROPERTIES VERSION + ${TRANTOR_VERSION} + SOVERSION + ${TRANTOR_MAJOR_VERSION} + ) if(CMAKE_CXX_COMPILER_ID MATCHES MSVC) - # Ignore MSVC C4251 and C4275 warning of exporting std objects with no dll export - # We export class to facilitate maintenance, thus if you compile - # drogon on windows as a shared library, you will need to use - # exact same compiler for drogon and your app. + # Ignore MSVC C4251 and C4275 warning of exporting std objects with no dll export We export class to facilitate + # maintenance, thus if you compile drogon on windows as a shared library, you will need to use exact same compiler + # for drogon and your app. target_compile_options(${PROJECT_NAME} PUBLIC /wd4251 /wd4275) endif() endif(BUILD_SHARED_LIBS) -if (NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Windows" AND CMAKE_CXX_COMPILER_ID MATCHES Clang|GNU) - target_compile_options(${PROJECT_NAME} PRIVATE -Wall -Wextra -Werror) +# Tells Visual Studio 2017 (15.7+) and newer to correctly set the value of the standard __cplusplus macro, instead of +# leaving it to 199711L and settings the effective c++ version in _MSVC_LANG Dropping support for older versions of VS +# would allow to only rely on __cplusplus +if(MSVC AND MSVC_VERSION GREATER_EQUAL 1914) + add_compile_options(/Zc:__cplusplus) +endif(MSVC AND MSVC_VERSION GREATER_EQUAL 1914) + +if(NOT + ${CMAKE_SYSTEM_NAME} + STREQUAL + "Windows" + AND CMAKE_CXX_COMPILER_ID MATCHES Clang|GNU +) + target_compile_options(${PROJECT_NAME} PRIVATE -Wall -Wextra -Werror) endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "Haiku") - target_link_libraries(${PROJECT_NAME} PRIVATE network) + target_link_libraries(${PROJECT_NAME} PRIVATE network) endif() include(GenerateExportHeader) @@ -58,19 +89,17 @@ generate_export_header(${PROJECT_NAME} EXPORT_FILE_NAME ${CMAKE_CURRENT_BINARY_D # include directories target_include_directories( ${PROJECT_NAME} - PUBLIC $ - $ + PUBLIC $ $ $ PRIVATE ${PROJECT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/trantor/utils ${PROJECT_SOURCE_DIR}/trantor/net ${PROJECT_SOURCE_DIR}/trantor/net/inner - $) + $ +) if(MINGW) - target_compile_definitions( - ${PROJECT_NAME} - PUBLIC -D_WIN32_WINNT=0x0601) + target_compile_definitions(${PROJECT_NAME} PUBLIC -D_WIN32_WINNT=0x0601) endif(MINGW) set(TRANTOR_SOURCES @@ -94,12 +123,16 @@ set(TRANTOR_SOURCES trantor/net/inner/Connector.cc trantor/net/inner/Poller.cc trantor/net/inner/Socket.cc + trantor/net/inner/MemBufferNode.cc + trantor/net/inner/StreamBufferNode.cc + trantor/net/inner/AsyncStreamBufferNode.cc trantor/net/inner/TcpConnectionImpl.cc trantor/net/inner/Timer.cc trantor/net/inner/TimerQueue.cc trantor/net/inner/poller/EpollPoller.cc trantor/net/inner/poller/KQueue.cc - trantor/net/inner/poller/PollPoller.cc) + trantor/net/inner/poller/PollPoller.cc +) set(private_headers trantor/net/inner/Acceptor.h trantor/net/inner/Connector.h @@ -110,76 +143,154 @@ set(private_headers trantor/net/inner/TimerQueue.h trantor/net/inner/poller/EpollPoller.h trantor/net/inner/poller/KQueue.h - trantor/net/inner/poller/PollPoller.h) + trantor/net/inner/poller/PollPoller.h +) if(WIN32) set(TRANTOR_SOURCES ${TRANTOR_SOURCES} third_party/wepoll/Wepoll.c - trantor/utils/WindowsSupport.cc) - set(private_headers - ${private_headers} - third_party/wepoll/Wepoll.h - trantor/utils/WindowsSupport.h) + trantor/utils/WindowsSupport.cc + trantor/net/inner/FileBufferNodeWin.cc + ) + set(private_headers ${private_headers} third_party/wepoll/Wepoll.h trantor/utils/WindowsSupport.h) +else(WIN32) + set(TRANTOR_SOURCES ${TRANTOR_SOURCES} trantor/net/inner/FileBufferNodeUnix.cc) endif(WIN32) -find_package(OpenSSL) -if(OpenSSL_FOUND) - target_link_libraries(${PROJECT_NAME} PRIVATE OpenSSL::SSL OpenSSL::Crypto) - - # To enable the acceptance of local issued certificates - # Add the flag: ALLOW_SELF_SIGNED_CERTS - # ONLY FOR TESTING PURPOSES - target_compile_definitions(${PROJECT_NAME} PRIVATE USE_OPENSSL) +# Somehow the default value of TRANTOR_USE_TLS is OFF +if(TRANTOR_USE_TLS STREQUAL OFF) + set(TRANTOR_USE_TLS "") +endif() +set(VALID_TLS_PROVIDERS "openssl" "botan" "none") +list( + FIND + VALID_TLS_PROVIDERS + "${TRANTOR_USE_TLS}" + PREFERED_TLS_IDX +) +if(PREFERED_TLS_IDX EQUAL -1 + AND NOT + TRANTOR_USE_TLS + STREQUAL + "" +) + message(FATAL_ERROR "Invalid TLS provider: ${TRANTOR_USE_TLS}\n" "Valid TLS providers are: ${VALID_TLS_PROVIDERS}") endif() -set(HAVE_C-ARES NO) -if (BUILD_C-ARES) - find_package(c-ares) - if(c-ares_FOUND) - message(STATUS "c-ares found!") - set(HAVE_C-ARES TRUE) +set(TRANTOR_TLS_PROVIDER "None") +if(TRANTOR_USE_TLS STREQUAL "openssl" OR TRANTOR_USE_TLS STREQUAL "") + find_package(OpenSSL) + if(OpenSSL_FOUND) + target_link_libraries(${PROJECT_NAME} PRIVATE OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(${PROJECT_NAME} PRIVATE USE_OPENSSL) + set(TRANTOR_TLS_PROVIDER "OpenSSL") + + set(TRANTOR_SOURCES ${TRANTOR_SOURCES} trantor/net/inner/tlsprovider/OpenSSLProvider.cc + trantor/utils/crypto/openssl.cc + ) + elseif(TRANTOR_USE_TLS STREQUAL "openssl") + message(FATAL_ERROR "Requested OpenSSL TLS provider but OpenSSL was not found") + endif() +endif() + +if(TRANTOR_TLS_PROVIDER STREQUAL "None" AND (TRANTOR_USE_TLS STREQUAL "botan" OR TRANTOR_USE_TLS STREQUAL "")) + find_package(Botan) + if(Botan_FOUND) + target_compile_definitions(${PROJECT_NAME} PRIVATE USE_BOTAN) + target_link_libraries(${PROJECT_NAME} PRIVATE Botan::Botan) + if(CMAKE_CXX_COMPILER_ID MATCHES Clang|GNU) + # Trantor uses some features that are deprecated in C++20 but Botan3 needs C++20 + target_compile_options(${PROJECT_NAME} PRIVATE -Wno-deprecated) endif() -endif () + set(TRANTOR_TLS_PROVIDER "Botan") -if(HAVE_C-ARES) - target_link_libraries(${PROJECT_NAME} PRIVATE c-ares_lib) + set(TRANTOR_SOURCES ${TRANTOR_SOURCES} trantor/net/inner/tlsprovider/BotanTLSProvider.cc + trantor/utils/crypto/botan.cc + ) + elseif(TRANTOR_USE_TLS STREQUAL "botan") + message(FATAL_ERROR "Requested Botan TLS provider but Botan was not found") + endif() +endif() + +if(TRANTOR_TLS_PROVIDER STREQUAL "None") set(TRANTOR_SOURCES ${TRANTOR_SOURCES} - trantor/net/inner/AresResolver.cc) + trantor/utils/crypto/sha3.cc + trantor/utils/crypto/md5.cc + trantor/utils/crypto/sha1.cc + trantor/utils/crypto/sha256.cc + trantor/utils/crypto/blake2.cc + ) set(private_headers ${private_headers} - trantor/net/inner/AresResolver.h) + trantor/utils/crypto/sha3.h + trantor/utils/crypto/md5.h + trantor/utils/crypto/sha1.h + trantor/utils/crypto/sha256.h + ) +endif() + +message(STATUS "Trantor using SSL library: ${TRANTOR_TLS_PROVIDER}") +target_compile_definitions(${PROJECT_NAME} PRIVATE TRANTOR_TLS_PROVIDER=${TRANTOR_TLS_PROVIDER}) + +set(HAVE_SPDLOG NO) +if(USE_SPDLOG) + find_package(spdlog CONFIG) + if(spdlog_FOUND) + message(STATUS "spdlog found!") + set(HAVE_SPDLOG TRUE) + endif(spdlog_FOUND) +endif(USE_SPDLOG) +if(HAVE_SPDLOG) + target_link_libraries(${PROJECT_NAME} PUBLIC spdlog::spdlog_header_only) + target_compile_definitions(${PROJECT_NAME} PUBLIC TRANTOR_SPDLOG_SUPPORT SPDLOG_FMT_EXTERNAL FMT_HEADER_ONLY) +endif(HAVE_SPDLOG) + +set(HAVE_C-ARES NO) +if(BUILD_C-ARES) + find_package(c-ares) + if(c-ares_FOUND) + message(STATUS "c-ares found!") + set(HAVE_C-ARES TRUE) + endif() +endif() + +if(HAVE_C-ARES) + if(NOT BUILD_SHARED_LIBS) + target_compile_definitions(${PROJECT_NAME} PRIVATE CARES_STATICLIB) + endif() + target_link_libraries(${PROJECT_NAME} PRIVATE c-ares_lib) + set(TRANTOR_SOURCES ${TRANTOR_SOURCES} trantor/net/inner/AresResolver.cc) + set(private_headers ${private_headers} trantor/net/inner/AresResolver.h) if(APPLE) - target_link_libraries(${PROJECT_NAME} PRIVATE resolv) + target_link_libraries(${PROJECT_NAME} PRIVATE resolv) + elseif(WIN32) + target_link_libraries(${PROJECT_NAME} PRIVATE iphlpapi) endif() else() - set(TRANTOR_SOURCES - ${TRANTOR_SOURCES} - trantor/net/inner/NormalResolver.cc) - set(private_headers - ${private_headers} - trantor/net/inner/NormalResolver.h) + set(TRANTOR_SOURCES ${TRANTOR_SOURCES} trantor/net/inner/NormalResolver.cc) + set(private_headers ${private_headers} trantor/net/inner/NormalResolver.h) endif() find_package(Threads) target_link_libraries(${PROJECT_NAME} PUBLIC Threads::Threads) if(WIN32) - target_link_libraries(${PROJECT_NAME} PRIVATE ws2_32 Rpcrt4) + target_link_libraries(${PROJECT_NAME} PRIVATE ws2_32 rpcrt4) if(OpenSSL_FOUND) - target_link_libraries(${PROJECT_NAME} PRIVATE Crypt32 Secur32) + target_link_libraries(${PROJECT_NAME} PRIVATE crypt32 secur32) endif(OpenSSL_FOUND) -else(WIN32) +elseif(NOT ANDROID) target_link_libraries(${PROJECT_NAME} PRIVATE pthread $<$:socket>) endif(WIN32) -file(WRITE ${CMAKE_BINARY_DIR}/test_atomic.cpp - "#include \n" - "int main() { std::atomic i(0); i++; return 0; }\n") +file(WRITE ${CMAKE_BINARY_DIR}/test_atomic.cpp "#include \n" + "int main() { std::atomic i(0); i++; return 0; }\n" +) try_compile(ATOMIC_WITHOUT_LINKING ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/test_atomic.cpp) -if (NOT ATOMIC_WITHOUT_LINKING) - target_link_libraries(${PROJECT_NAME} PUBLIC atomic) -endif () +if(NOT ATOMIC_WITHOUT_LINKING) + target_link_libraries(${PROJECT_NAME} PUBLIC atomic) +endif() file(REMOVE ${CMAKE_BINARY_DIR}/test_atomic.cpp) set_target_properties(${PROJECT_NAME} PROPERTIES CXX_STANDARD 14) @@ -204,9 +315,13 @@ set(public_net_headers trantor/net/TcpClient.h trantor/net/TcpConnection.h trantor/net/TcpServer.h + trantor/net/AsyncStream.h trantor/net/callbacks.h trantor/net/Resolver.h - trantor/net/Channel.h) + trantor/net/Channel.h + trantor/net/Certificate.h + trantor/net/TLSPolicy.h +) set(public_utils_headers trantor/utils/AsyncFileLogger.h @@ -222,66 +337,68 @@ set(public_utils_headers trantor/utils/SerialTaskQueue.h trantor/utils/TaskQueue.h trantor/utils/TimingWheel.h - trantor/utils/Utilities.h) - -target_sources(${PROJECT_NAME} PRIVATE - ${TRANTOR_SOURCES} - ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h - ${public_net_headers} - ${public_utils_headers} - ${private_headers}) - -source_group("Public API" - FILES - ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h - ${public_net_headers} - ${public_utils_headers}) - -source_group("Private Headers" - FILES - ${private_headers}) - -install(TARGETS trantor - # IMPORTANT: Add the trantor library to the "export-set" - EXPORT TrantorTargets - RUNTIME DESTINATION "${INSTALL_BIN_DIR}" COMPONENT bin - ARCHIVE DESTINATION "${INSTALL_LIB_DIR}" COMPONENT lib - LIBRARY DESTINATION "${INSTALL_LIB_DIR}" COMPONENT lib) -install(FILES ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h - DESTINATION ${INSTALL_INCLUDE_DIR}/trantor) -install(FILES ${public_net_headers} - DESTINATION ${INSTALL_INCLUDE_DIR}/trantor/net) -install(FILES ${public_utils_headers} - DESTINATION ${INSTALL_INCLUDE_DIR}/trantor/utils) + trantor/utils/Utilities.h +) + +target_sources( + ${PROJECT_NAME} + PRIVATE ${TRANTOR_SOURCES} + ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h + ${public_net_headers} + ${public_utils_headers} + ${private_headers} +) + +source_group( + "Public API" FILES ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h ${public_net_headers} + ${public_utils_headers} +) + +source_group("Private Headers" FILES ${private_headers}) + +install( + TARGETS trantor + # IMPORTANT: Add the trantor library to the "export-set" + EXPORT TrantorTargets + RUNTIME DESTINATION "${INSTALL_BIN_DIR}" COMPONENT bin + ARCHIVE DESTINATION "${INSTALL_LIB_DIR}" COMPONENT lib + LIBRARY DESTINATION "${INSTALL_LIB_DIR}" COMPONENT lib +) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/exports/trantor/exports.h DESTINATION ${INSTALL_INCLUDE_DIR}/trantor) +install(FILES ${public_net_headers} DESTINATION ${INSTALL_INCLUDE_DIR}/trantor/net) +install(FILES ${public_utils_headers} DESTINATION ${INSTALL_INCLUDE_DIR}/trantor/utils) include(CMakePackageConfigHelpers) # ... for the install tree configure_package_config_file( - cmake/templates/TrantorConfig.cmake.in - ${CMAKE_CURRENT_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/TrantorConfig.cmake - INSTALL_DESTINATION - ${INSTALL_TRANTOR_CMAKE_DIR}) + cmake/templates/TrantorConfig.cmake.in ${CMAKE_CURRENT_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/TrantorConfig.cmake + INSTALL_DESTINATION ${INSTALL_TRANTOR_CMAKE_DIR} +) # version write_basic_package_version_file( ${CMAKE_CURRENT_BINARY_DIR}/TrantorConfigVersion.cmake VERSION ${TRANTOR_VERSION} - COMPATIBILITY SameMajorVersion) + COMPATIBILITY SameMajorVersion +) # Install the TrantorConfig.cmake and TrantorConfigVersion.cmake install( - FILES - "${CMAKE_CURRENT_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/TrantorConfig.cmake" - "${CMAKE_CURRENT_BINARY_DIR}/TrantorConfigVersion.cmake" - "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules/Findc-ares.cmake" + FILES "${CMAKE_CURRENT_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/TrantorConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/TrantorConfigVersion.cmake" + "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules/Findc-ares.cmake" + "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules/FindBotan.cmake" DESTINATION "${INSTALL_TRANTOR_CMAKE_DIR}" - COMPONENT dev) + COMPONENT dev +) # Install the export set for use with the install-tree -install(EXPORT TrantorTargets - DESTINATION "${INSTALL_TRANTOR_CMAKE_DIR}" - NAMESPACE Trantor:: - COMPONENT dev) +install( + EXPORT TrantorTargets + DESTINATION "${INSTALL_TRANTOR_CMAKE_DIR}" + NAMESPACE Trantor:: + COMPONENT dev +) # Doxygen documentation find_package(Doxygen OPTIONAL_COMPONENTS dot dia) @@ -291,27 +408,29 @@ if(DOXYGEN_FOUND) set(DOXYGEN_GENERATE_LATEX NO) set(DOXYGEN_BUILTIN_STL_SUPPORT YES) set(DOXYGEN_USE_MDFILE_AS_MAINPAGE README.md) - set(DOXYGEN_STRIP_FROM_INC_PATH ${PROJECT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR}/exports) - if (WIN32) + set(DOXYGEN_STRIP_FROM_INC_PATH ${PROJECT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/exports) + if(WIN32) set(DOXYGEN_PREDEFINED _WIN32) endif(WIN32) - doxygen_add_docs(doc_${PROJECT_NAME} - README.md - ChangeLog.md - ${public_net_headers} - ${public_utils_headers} - COMMENT "Generate documentation") + doxygen_add_docs( + doc_${PROJECT_NAME} + README.md + ChangeLog.md + ${public_net_headers} + ${public_utils_headers} + COMMENT "Generate documentation" + ) if(NOT TARGET doc) add_custom_target(doc) endif() add_dependencies(doc doc_${PROJECT_NAME}) - if (BUILD_DOC) + if(BUILD_DOC) add_dependencies(${PROJECT_NAME} doc_${PROJECT_NAME}) # Don't install twice, so limit to Debug (assume developer) - install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/docs/${PROJECT_NAME} - TYPE DOC - CONFIGURATIONS Debug) + install( + DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/docs/${PROJECT_NAME} + TYPE DOC + CONFIGURATIONS Debug + ) endif(BUILD_DOC) endif(DOXYGEN_FOUND) - diff --git a/ChangeLog.md b/ChangeLog.md index 54139364..a1d4ec6a 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -1,8 +1,244 @@ # Changelog + All notable changes to this project will be documented in this file. ## [Unreleased] +## [1.5.24] - 2025-06-20 + +### Changed + +- refactor: replace atomic counter with plain int in RunInLoopTest2. + +### Fixed +- Fix compile errors in some case. + +- Fix a bug when sending streams. + +## [1.5.23] - 2025-02-20 + +### Changed + +- Replace ipv4 inet_ntop with a handrolled function. + +### Fixed + +- Fix some typos. + +## [1.5.22] - 2024-10-27 + +### Fixed + +- Fix a bug in the dtor of EventLoop. + +- Free leaked memory in ares resolver. + +## [1.5.21] - 2024-09-10 + +### API changes list + +- Add a method to reload the SSL certificate and private key on the fly. + +### Changed + +- Keep log level consistency. + +## [1.5.20] - 2024-07-20 + +### Changed + +- Add byte order detection for internal SHA1 implementation for OSX, POWER, RISC-V and s390. + +### Fixed + +- Fix Windows CI build fail by using the latest MSVC. + +- Fix the Botan TLS provider build on Linux. + +- Fix "pthread not found" build error when using Android NDK. + +## [1.5.19] - 2024-06-08 + +### changed + +- show forked repository build status. + +- Add cmake-format. + +- Some spelling corrections. + +## [1.5.18] - 2024-05-04 + +### Fixed + +- Fix data type conflict. + +- Fix build on latest c-ares. + +## [1.5.17] - 2024-02-09 + +### Changed + +- Make FileBufferNodeWin aware of UWP Win32 API. + +- Use ssize_t declared by toolchain when available. + +## [1.5.16] - 2024-01-18 + +### Changed + +- Add build badge for individual OS. + +- deinit libressl. + +- Remove mutex. + +### Fixed + +- Pile of fix for h2. + +- Fix a bug when sending data. + +- Fix c-ares CARES_EXTERN for static builds. + +- Fix header file name issue when cross-compiling on Windows. + +- Fix name issue when cross-compiling. + +## [1.5.15] - 2023-11-27 + +### Changed + +- Feature: Integrate spdlog as logging backend for Trantor Logger. + +### Fixed + +- Fix the botan backend always validating certificate and OpenSSL allowing empty ALPN. + +- Fix build error on OpenBSD. + +- Fix Botan leaking memory if connection force closed. + +- Fix a cmake warning. + +- Workaround botan backend init failure on MacOS. + +- Fix failing wstr conversion if locale is set to C. + +## [1.5.14] - 2023-09-19 + +### [Fixed] + +- Fix OpenSSL: read can be incomplete. + +- Fix botan provider. + +- Fix botan3 not triggering handshake finish event. + +- Fix an compilation error when no STL lib is found. + +## [1.5.13] - 2023-08-23 + +### Fixed + +- Fix an error when sending files. + +- Include <memory> header in TcpConnectionImpl.cc. + +## [1.5.12] - 2023-08-20 + +### API changes list + +- Add NetEndian versions of toIp and toIpPort. + +- Add setsockopt to TcpClient and TcpServer. + +- Support setting max files in AsyncFileLogger. + +- Support returning multiple results for dns parsing. + +### Changed + +- Refactor SSL handling. + +- Add ability to use one log file until the size-limit. + +- Make the std::string_view work on windows. + +- Drop Botan 2 support and support Botan 3. + +- Make the getNextLoop method multi-thread safe. + +- Add fallback when OpenSSL not providing BLAKE2b. + +### Fixed + +- Fix override mark. + +- Add missing <cstdint> header with GCC 13. + +- Fix AresResolver. + +- Fix building built-in hashes on Windows. + +- Fix MSYS2/Cygwin compatibility issues. + +- Fix more build errors on win32/mingw. + +- Fix off_t(on windows off_t defined with long, not longlong). + +- Fix bug with Trantor::Date timeZoneOffset calculation. + +- Fix wrong usage of shared pointer in TcpClient ctor. + +## [1.5.11] - 2023-03-17 + +### API Changes list + +- Add a method to the Logger class to enable local time displaying. + +- TRNANTOR_LOG_COMPACT - compact logs without source code details. + +### Changed + +- Refactor TcpServer I/O loop logic. + +### Fixed + +- Fix a conan issue. + +## [1.5.10] - 2023-01-23 + +### API Changes list + +### Changed + +- Use gtest 1.13 in github actions + +### Fixed + +## [1.5.9] - 2023-01-23 + +### API Changes list + +### Changed + +- Search for \ if under msvc + +### Fixed + +## [1.5.8] - 2022-11-11 + +### API Changes list + +### Changed + +### Fixed + +- Fix Date::timezoneOffset(). + +- Fix socket fd leak if Connector destruct before connection callback is made. + ## [1.5.7] - 2022-09-25 ### API changes list @@ -401,7 +637,6 @@ All notable changes to this project will be documented in this file. - Use the std::chrono::steady_clock for timers - ## [1.0.0-rc7] - 2019-11-21 ### Changed @@ -430,7 +665,7 @@ All notable changes to this project will be documented in this file. - Add the Resolver class that provides high-performance DNS functionality(with c-ares library) - Add some unit tests. - + ## [1.0.0-rc4] - 2019-08-08 ### API changes list @@ -472,7 +707,41 @@ All notable changes to this project will be documented in this file. ## [1.0.0-rc1] - 2019-06-11 -[Unreleased]: https://github.com/an-tao/trantor/compare/v1.5.7...HEAD +[Unreleased]: https://github.com/an-tao/trantor/compare/v1.5.24...HEAD + +[1.5.24]: https://github.com/an-tao/trantor/compare/v1.5.23...v1.5.24 + +[1.5.23]: https://github.com/an-tao/trantor/compare/v1.5.22...v1.5.23 + +[1.5.22]: https://github.com/an-tao/trantor/compare/v1.5.21...v1.5.22 + +[1.5.21]: https://github.com/an-tao/trantor/compare/v1.5.20...v1.5.21 + +[1.5.20]: https://github.com/an-tao/trantor/compare/v1.5.19...v1.5.20 + +[1.5.19]: https://github.com/an-tao/trantor/compare/v1.5.18...v1.5.19 + +[1.5.18]: https://github.com/an-tao/trantor/compare/v1.5.17...v1.5.18 + +[1.5.17]: https://github.com/an-tao/trantor/compare/v1.5.16...v1.5.17 + +[1.5.16]: https://github.com/an-tao/trantor/compare/v1.5.15...v1.5.16 + +[1.5.15]: https://github.com/an-tao/trantor/compare/v1.5.14...v1.5.15 + +[1.5.14]: https://github.com/an-tao/trantor/compare/v1.5.13...v1.5.14 + +[1.5.13]: https://github.com/an-tao/trantor/compare/v1.5.12...v1.5.13 + +[1.5.12]: https://github.com/an-tao/trantor/compare/v1.5.11...v1.5.12 + +[1.5.11]: https://github.com/an-tao/trantor/compare/v1.5.10...v1.5.11 + +[1.5.10]: https://github.com/an-tao/trantor/compare/v1.5.9...v1.5.10 + +[1.5.9]: https://github.com/an-tao/trantor/compare/v1.5.8...v1.5.9 + +[1.5.8]: https://github.com/an-tao/trantor/compare/v1.5.7...v1.5.8 [1.5.7]: https://github.com/an-tao/trantor/compare/v1.5.6...v1.5.7 diff --git a/README.md b/README.md index a76aa4e0..33bb7c2d 100755 --- a/README.md +++ b/README.md @@ -1,22 +1,21 @@ # TRANTOR - -[![Build Status](https://travis-ci.org/an-tao/trantor.svg?branch=master)](https://travis-ci.org/an-tao/trantor) -[![Build status](https://ci.appveyor.com/api/projects/status/yn8xunsubn37pi1u/branch/master?svg=true)](https://ci.appveyor.com/project/an-tao/trantor/branch/master) -[![Language grade: C/C++](https://img.shields.io/lgtm/grade/cpp/g/an-tao/trantor.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/an-tao/trantor/context:cpp) - +[![Build Ubuntu gcc](../../actions/workflows/ubuntu-gcc.yml/badge.svg)](../../actions/workflows/ubuntu-gcc.yml/badge.svg) +[![Build Macos clang](../../actions/workflows/macos-clang.yml/badge.svg)](../../actions/workflows/macos-clang.yml/badge.svg) +[![Build RockyLinux gcc](../../actions/workflows/rockylinux-gcc.yml/badge.svg)](../../actions/workflows/rockylinux-gcc.yml/badge.svg) +[![Build Windows msvc](../../actions/workflows/windows-msvc.yml/badge.svg)](../../actions/workflows/windows-msvc.yml/badge.svg) ## Overview A non-blocking I/O cross-platform TCP network library, using C++14. Drawing on the design of Muduo Library -## suported platforms +## Supported platforms - Linux - MacOS - UNIX(BSD) - Windows ## Feature highlights -- non-blocking I/O +- Non-blocking I/O - cross-platform - Thread pool - Lock free design @@ -28,7 +27,7 @@ Drawing on the design of Muduo Library ```shell git clone https://github.com/an-tao/trantor.git cd trantor -cmake -Bbuild -H. +cmake -B build -H. cd build make -j ``` diff --git a/cmake/templates/TrantorConfig.cmake.in b/cmake/templates/TrantorConfig.cmake.in index e18652de..e9422ed9 100644 --- a/cmake/templates/TrantorConfig.cmake.in +++ b/cmake/templates/TrantorConfig.cmake.in @@ -1,3 +1,4 @@ +#[[ # - Config file for the Trantor package # It defines the following variables # TRANTOR_INCLUDE_DIRS - include directories for Trantor @@ -5,6 +6,7 @@ # Trantor_FOUND # This module defines the following IMPORTED target: # Trantor::Trantor +#]] @PACKAGE_INIT@ @@ -13,10 +15,16 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}) if(@OpenSSL_FOUND@) find_dependency(OpenSSL) endif() +if(@Botan_FOUND@) + find_dependency(Botan) +endif() if(@c-ares_FOUND@) find_dependency(c-ares) endif() find_dependency(Threads) +if(@spdlog_FOUND@) + find_dependency(spdlog) +endif() # Compute paths # Our library dependencies (contains definitions for IMPORTED targets) diff --git a/cmake_modules/FindBotan.cmake b/cmake_modules/FindBotan.cmake new file mode 100644 index 00000000..e09ce039 --- /dev/null +++ b/cmake_modules/FindBotan.cmake @@ -0,0 +1,68 @@ +function(find_botan_pkgconfig package_name botan_ver) + if(TARGET Botan::Botan) + return() + endif() + + pkg_check_modules( + Botan + QUIET + IMPORTED_TARGET + ${package_name} + ) + if(TARGET PkgConfig::Botan) + add_library(Botan::Botan ALIAS PkgConfig::Botan) + + if(botan_ver EQUAL 3) + target_compile_features(PkgConfig::Botan INTERFACE cxx_std_20) + endif() + endif() +endfunction() + +function(find_botan_search package_name botan_ver) + if(TARGET Botan::Botan) + return() + endif() + find_path( + Botan_INCLUDE_DIRS + NAMES botan/botan.h + PATH_SUFFIXES ${package_name} + DOC "The Botan include directory" + ) + + find_library( + Botan_LIBRARIES + NAMES botan ${package_name} + DOC "The Botan library" + ) + + mark_as_advanced(Botan_INCLUDE_DIRS Botan_LIBRARIES) + + add_library(Botan::Botan IMPORTED UNKNOWN) + set_target_properties( + Botan::Botan + PROPERTIES + IMPORTED_LOCATION "${Botan_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES "${Botan_INCLUDE_DIRS}" + ) + if(botan_ver EQUAL 3) + target_compile_features(Botan::Botan INTERFACE cxx_std_20) + endif() + + if(WIN32) + target_compile_definitions(Botan::Botan INTERFACE -DNOMINMAX=1) + endif() +endfunction() + +find_package(PkgConfig) +if(NOT WIN32 AND PKG_CONFIG_FOUND) + # find_botan_pkgconfig(botan-2 2) + find_botan_pkgconfig(botan-3 3) +endif() + +if(NOT TARGET Botan::Botan) + # find_botan_search(botan-2 2) + find_botan_search(botan-3 3) +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Botan REQUIRED_VARS Botan_LIBRARIES Botan_INCLUDE_DIRS) diff --git a/cmake_modules/Findc-ares.cmake b/cmake_modules/Findc-ares.cmake index 98821437..73334b72 100644 --- a/cmake_modules/Findc-ares.cmake +++ b/cmake_modules/Findc-ares.cmake @@ -1,27 +1,30 @@ +#[[ # Try to find c-ares library Once done this will define # -# c-ares_FOUND - system has c-ares -# C-ARES_INCLUDE_DIRS - The c-ares include directory +# c-ares_FOUND - system has c-ares +# C-ARES_INCLUDE_DIRS - The c-ares include directory # C-ARES_LIBRARIES - Link these to use c-ares # c-ares_lib - Imported Targets # # Copyright (c) 2020 antao -# +#]] find_path(C-ARES_INCLUDE_DIRS ares.h) find_library(C-ARES_LIBRARIES NAMES cares) if(C-ARES_INCLUDE_DIRS AND C-ARES_LIBRARIES) add_library(c-ares_lib INTERFACE IMPORTED) - set_target_properties(c-ares_lib - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${C-ARES_INCLUDE_DIRS}" - INTERFACE_LINK_LIBRARIES - "${C-ARES_LIBRARIES}") + set_target_properties( + c-ares_lib + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${C-ARES_INCLUDE_DIRS}" INTERFACE_LINK_LIBRARIES "${C-ARES_LIBRARIES}" + ) endif() include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(c-ares - DEFAULT_MSG - C-ARES_INCLUDE_DIRS - C-ARES_LIBRARIES) +find_package_handle_standard_args( + c-ares + DEFAULT_MSG + C-ARES_INCLUDE_DIRS + C-ARES_LIBRARIES +) mark_as_advanced(C-ARES_INCLUDE_DIRS C-ARES_LIBRARIES) diff --git a/conanfile.txt b/conanfile.txt index 59906b3a..2e6dc442 100644 --- a/conanfile.txt +++ b/conanfile.txt @@ -1,10 +1,11 @@ [requires] gtest/1.10.0 -openssl/1.1.1j +openssl/1.1.1t #c-ares/1.17.1 +spdlog/1.12.0 [generators] -cmake_paths +CMakeToolchain [options] diff --git a/format.sh b/format.sh index a1bde595..9510d04b 100755 --- a/format.sh +++ b/format.sh @@ -4,3 +4,9 @@ clang-format --version find trantor -name *.h -o -name *.cc -exec dos2unix {} \; find trantor -name *.h -o -name *.cc|xargs clang-format -i -style=file + +cmake-format --version +find . -maxdepth 1 -name CMakeLists.txt|xargs cmake-format -i +find trantor -name CMakeLists.txt|xargs cmake-format -i +find cmake -name *.cmake -o -name *.cmake.in|xargs cmake-format -i +find cmake_modules -name *.cmake -o -name *.cmake.in|xargs cmake-format -i \ No newline at end of file diff --git a/third_party/wepoll/Wepoll.c b/third_party/wepoll/Wepoll.c index 5b8ee517..ed37d7f7 100644 --- a/third_party/wepoll/Wepoll.c +++ b/third_party/wepoll/Wepoll.c @@ -902,7 +902,7 @@ int init(void) /* `InitOnceExecuteOnce()` itself is infallible, and it doesn't set any * error code when the once-callback returns FALSE. We return -1 here to * indicate that global initialization failed; the failing init function - * is resposible for setting `errno` and calling `SetLastError()`. */ + * is responsible for setting `errno` and calling `SetLastError()`. */ return -1; return 0; diff --git a/trantor/net/AsyncStream.h b/trantor/net/AsyncStream.h new file mode 100644 index 00000000..c459f9ea --- /dev/null +++ b/trantor/net/AsyncStream.h @@ -0,0 +1,51 @@ +/** + * + * @file AsyncStream.h + * @author An Tao + * + * Public header file in trantor lib. + * + * Copyright 2023, An Tao. All rights reserved. + * Use of this source code is governed by a BSD-style license + * that can be found in the License file. + * + * + */ + +#pragma once + +#include +#include + +namespace trantor +{ +/** + * @brief This class represents a data stream that can be sent asynchronously. + * The data is sent in chunks, and the chunks are sent in order, and all the + * chunks are sent continuously. + */ +class TRANTOR_EXPORT AsyncStream : public NonCopyable +{ + public: + virtual ~AsyncStream() = default; + /** + * @brief Send data asynchronously. + * + * @param data The data to be sent + * @param len The length of the data + * @return true if the data is sent successfully or at least is put in the + * send buffer. + * @return false if the connection is closed. + */ + virtual bool send(const char *data, size_t len) = 0; + bool send(const std::string &data) + { + return send(data.data(), data.length()); + } + /** + * @brief Terminate the stream. + */ + virtual void close() = 0; +}; +using AsyncStreamPtr = std::unique_ptr; +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/Certificate.h b/trantor/net/Certificate.h new file mode 100644 index 00000000..925c6d96 --- /dev/null +++ b/trantor/net/Certificate.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace trantor +{ +struct Certificate +{ + virtual ~Certificate() = default; + virtual std::string sha1Fingerprint() const = 0; + virtual std::string sha256Fingerprint() const = 0; + virtual std::string pem() const = 0; +}; +using CertificatePtr = std::shared_ptr; + +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/EventLoop.cc b/trantor/net/EventLoop.cc index a2231192..17de454c 100644 --- a/trantor/net/EventLoop.cc +++ b/trantor/net/EventLoop.cc @@ -28,7 +28,9 @@ #include #include #include +#ifndef _SSIZE_T_DEFINED using ssize_t = long long; +#endif #else #include #endif @@ -43,6 +45,7 @@ using ssize_t = long long; #include #include #include +#include namespace trantor { @@ -136,7 +139,6 @@ EventLoop::~EventLoop() #endif } - t_loopInThisThread = nullptr; #ifdef __linux__ close(wakeupFd_); #elif defined _WIN32 @@ -241,7 +243,8 @@ void EventLoop::loop() catch (std::exception &e) { LOG_WARN << "Exception thrown from event loop, rethrowing after " - "running functions on quit"; + "running functions on quit: " + << e.what(); loopException = std::current_exception(); } @@ -253,7 +256,7 @@ void EventLoop::loop() { f(); } - + t_loopInThisThread = nullptr; // Throw the exception from the end if (loopException) { diff --git a/trantor/net/EventLoop.h b/trantor/net/EventLoop.h index fd0253d2..00b7ebf7 100644 --- a/trantor/net/EventLoop.h +++ b/trantor/net/EventLoop.h @@ -45,7 +45,7 @@ enum /** * @brief As the name implies, this class represents an event loop that runs in - * a perticular thread. The event loop can handle network I/O events and timers + * a particular thread. The event loop can handle network I/O events and timers * in asynchronous mode. * @note An event loop object always belongs to a separate thread, and there is * one event loop object at most in a thread. We can call an event loop object diff --git a/trantor/net/EventLoopThreadPool.cc b/trantor/net/EventLoopThreadPool.cc index c96927b0..33264d43 100644 --- a/trantor/net/EventLoopThreadPool.cc +++ b/trantor/net/EventLoopThreadPool.cc @@ -47,10 +47,9 @@ EventLoop *EventLoopThreadPool::getNextLoop() { if (loopThreadVector_.size() > 0) { - EventLoop *loop = loopThreadVector_[loopIndex_]->getLoop(); - ++loopIndex_; - if (loopIndex_ >= loopThreadVector_.size()) - loopIndex_ = 0; + size_t index = loopIndex_.fetch_add(1, std::memory_order_relaxed); + EventLoop *loop = + loopThreadVector_[index % loopThreadVector_.size()]->getLoop(); return loop; } return nullptr; diff --git a/trantor/net/EventLoopThreadPool.h b/trantor/net/EventLoopThreadPool.h index 6332169b..8ca9776b 100644 --- a/trantor/net/EventLoopThreadPool.h +++ b/trantor/net/EventLoopThreadPool.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace trantor { @@ -87,6 +88,6 @@ class TRANTOR_EXPORT EventLoopThreadPool : NonCopyable private: std::vector> loopThreadVector_; - size_t loopIndex_; + std::atomic loopIndex_{0}; }; } // namespace trantor diff --git a/trantor/net/InetAddress.cc b/trantor/net/InetAddress.cc index dc1aa2fb..08a0aa42 100644 --- a/trantor/net/InetAddress.cc +++ b/trantor/net/InetAddress.cc @@ -9,7 +9,8 @@ #include #include -//#include +#include +// #include #ifdef _WIN32 struct in6_addr_uint @@ -117,6 +118,18 @@ std::string InetAddress::toIpPort() const snprintf(buf, sizeof(buf), ":%u", port); return toIp() + std::string(buf); } +std::string InetAddress::toIpPortNetEndian() const +{ + std::string buf; + static constexpr auto bytes = sizeof(addr_.sin_port); + buf.resize(bytes); +#if defined _WIN32 + std::memcpy((PVOID)&buf[0], (PVOID)&addr_.sin_port, bytes); +#else + std::memcpy(&buf[0], &addr_.sin_port, bytes); +#endif + return toIpNetEndian() + buf; +} bool InetAddress::isIntranetIp() const { if (addr_.sin_family == AF_INET) @@ -183,20 +196,44 @@ bool InetAddress::isLoopbackIp() const return false; } +static void byteToChars(std::string::iterator &dst, unsigned char byte) +{ + *dst = byte / 100 + '0'; + dst += byte >= 100; + *dst = byte % 100 / 10 + '0'; + dst += byte >= 10; + *dst = byte % 10 + '0'; + ++dst; +} + +static std::string iptos(unsigned inet_addr) +{ + // Initialize with a static buffer to force the constructor of string to get + // fully inlined + constexpr char stringInitBuffer[15]{}; + std::string out(stringInitBuffer, 15); + std::string::iterator dst = out.begin(); + byteToChars(dst, inet_addr >> 0 & 0xff); + *(dst++) = '.'; + byteToChars(dst, inet_addr >> 8 & 0xff); + *(dst++) = '.'; + byteToChars(dst, inet_addr >> 16 & 0xff); + *(dst++) = '.'; + byteToChars(dst, inet_addr >> 24 & 0xff); + out.erase(dst, out.end()); + return out; +} + std::string InetAddress::toIp() const { - char buf[64]; + char buf[INET6_ADDRSTRLEN]{}; if (addr_.sin_family == AF_INET) { -#if defined _MSC_VER && _MSC_VER >= 1900 - ::inet_ntop(AF_INET, (PVOID)&addr_.sin_addr, buf, sizeof(buf)); -#else - ::inet_ntop(AF_INET, &addr_.sin_addr, buf, sizeof(buf)); -#endif + return iptos(addr_.sin_addr.s_addr); } else if (addr_.sin_family == AF_INET6) { -#if defined _MSC_VER && _MSC_VER >= 1900 +#if defined _WIN32 ::inet_ntop(AF_INET6, (PVOID)&addr6_.sin6_addr, buf, sizeof(buf)); #else ::inet_ntop(AF_INET6, &addr6_.sin6_addr, buf, sizeof(buf)); @@ -206,6 +243,33 @@ std::string InetAddress::toIp() const return buf; } +std::string InetAddress::toIpNetEndian() const +{ + std::string buf; + if (addr_.sin_family == AF_INET) + { + static constexpr auto bytes = sizeof(addr_.sin_addr.s_addr); + buf.resize(bytes); +#if defined _WIN32 + std::memcpy((PVOID)&buf[0], (PVOID)&addr_.sin_addr.s_addr, bytes); +#else + std::memcpy(&buf[0], &addr_.sin_addr.s_addr, bytes); +#endif + } + else if (addr_.sin_family == AF_INET6) + { + static constexpr auto bytes = sizeof(addr6_.sin6_addr); + buf.resize(bytes); +#if defined _WIN32 + std::memcpy((PVOID)&buf[0], (PVOID)ip6NetEndian(), bytes); +#else + std::memcpy(&buf[0], ip6NetEndian(), bytes); +#endif + } + + return buf; +} + uint32_t InetAddress::ipNetEndian() const { // assert(family() == AF_INET); diff --git a/trantor/net/InetAddress.h b/trantor/net/InetAddress.h index 1750bb5f..a8f7eb66 100644 --- a/trantor/net/InetAddress.h +++ b/trantor/net/InetAddress.h @@ -112,6 +112,21 @@ class TRANTOR_EXPORT InetAddress */ std::string toIpPort() const; + /** + * @brief Return the IP bytes of the endpoint in net endian byte order + * + * @return std::string + */ + std::string toIpNetEndian() const; + + /** + * @brief Return the IP and port bytes of the endpoint in net endian byte + * order + * + * @return std::string + */ + std::string toIpPortNetEndian() const; + /** * @brief Return the port number of the endpoint. * @@ -204,7 +219,7 @@ class TRANTOR_EXPORT InetAddress } /** - * @brief Return true if the address is not initalized. + * @brief Return true if the address is not initialized. */ inline bool isUnspecified() const { diff --git a/trantor/net/Resolver.h b/trantor/net/Resolver.h index 478887bf..2da40f8b 100644 --- a/trantor/net/Resolver.h +++ b/trantor/net/Resolver.h @@ -22,6 +22,8 @@ class TRANTOR_EXPORT Resolver { public: using Callback = std::function; + using ResolverResultsCallback = + std::function&)>; /** * @brief Create a new DNS resolver. @@ -42,6 +44,15 @@ class TRANTOR_EXPORT Resolver virtual void resolve(const std::string& hostname, const Callback& callback) = 0; + /** + * @brief Resolve an address array asynchronously. + * + * @param hostname + * @param callback + */ + virtual void resolve(const std::string& hostname, + const ResolverResultsCallback& callback) = 0; + virtual ~Resolver() { } diff --git a/trantor/net/TLSPolicy.h b/trantor/net/TLSPolicy.h new file mode 100644 index 00000000..e59aca03 --- /dev/null +++ b/trantor/net/TLSPolicy.h @@ -0,0 +1,221 @@ +#pragma once +#include + +#include +#include +#include +#include + +namespace trantor +{ +struct TRANTOR_EXPORT TLSPolicy final +{ + /** + * @brief set the ssl configuration commands. The commands will be passed + * to the ssl library. The commands are in the form of {{key, value}}. + * for example, {"SSL_OP_NO_SSLv2", "1"}. Not all TLS providers support + * this feature AND the meaning of the commands may vary between TLS + * providers. + * + * As of 2023-03 Only OpenSSL supports this feature. LibreSSL does not + * nor Botan. + */ + TLSPolicy &setConfCmds( + const std::vector> &sslConfCmds) + { + sslConfCmds_ = sslConfCmds; + return *this; + } + /** + * @brief set the hostname to be used for SNI and certificate validation. + */ + TLSPolicy &setHostname(const std::string &hostname) + { + hostname_ = hostname; + return *this; + } + + /** + * @brief set the path to the certificate file. The file must be in PEM + * format. + */ + TLSPolicy &setCertPath(const std::string &certPath) + { + certPath_ = certPath; + return *this; + } + + /** + * @brief set the path to the private key file. The file must be in PEM + * format. + */ + TLSPolicy &setKeyPath(const std::string &keyPath) + { + keyPath_ = keyPath; + return *this; + } + + /** + * @brief set the path to the CA file or directory. The file must be in + * PEM format. + */ + TLSPolicy &setCaPath(const std::string &caPath) + { + caPath_ = caPath; + return *this; + } + + /** + * @brief enables the use of the old TLS protocol (old meaning < TLS 1.2). + * TLS providers may not support old protocols even if this option is set + */ + TLSPolicy &setUseOldTLS(bool useOldTLS) + { + useOldTLS_ = useOldTLS; + return *this; + } + + /** + * @brief set the list of protocols to be used for ALPN. + * + * @note for servers, it selects matching protocol against the client's + * list. And the first matching protocol supplied in the parameter will be + * selected. If no matching protocol is found, the connection will be + * closed. + * + * @note for clients, it sends the list of protocols to the server. + */ + TLSPolicy &setAlpnProtocols(const std::vector &alpnProtocols) + { + alpnProtocols_ = alpnProtocols; + return *this; + } + TLSPolicy &setAlpnProtocols(std::vector &&alpnProtocols) + { + alpnProtocols_ = std::move(alpnProtocols); + return *this; + } + + /** + * @brief Weather to use the system's certificate store. + * + * @note setting both not to use the system's certificate store and to + * supply a CA path WILL LEAD TO NO CERTIFICATE VALIDATION AT ALL. + */ + TLSPolicy &setUseSystemCertStore(bool useSystemCertStore) + { + useSystemCertStore_ = useSystemCertStore; + return *this; + } + + /** + * @brief Enable certificate validation. + */ + TLSPolicy &setValidate(bool enable) + { + validate_ = enable; + return *this; + } + + /** + * @brief Allow broken chain (self-signed certificate, root CA not in + * allowed list, etc..) but still validate the domain name and date. This + * option has no effect if validate is false. + * + * @note IMPORTANT: This option makes more then self signed certificates + * valid. It also allows certificates that are not signed by a trusted CA, + * the CA gets revoked. But the underlying implementation may still check + * for the type of certificate, date and hostname, etc.. To disable all + * certificate validation, use setValidate(false). + */ + TLSPolicy &setAllowBrokenChain(bool allow) + { + allowBrokenChain_ = allow; + return *this; + } + + // The getters + const std::vector> &getConfCmds() const + { + return sslConfCmds_; + } + const std::string &getHostname() const + { + return hostname_; + } + const std::string &getCertPath() const + { + return certPath_; + } + const std::string &getKeyPath() const + { + return keyPath_; + } + const std::string &getCaPath() const + { + return caPath_; + } + bool getUseOldTLS() const + { + return useOldTLS_; + } + bool getValidate() const + { + return validate_; + } + bool getAllowBrokenChain() const + { + return allowBrokenChain_; + } + const std::vector &getAlpnProtocols() const + { + return alpnProtocols_; + } + const std::vector &getAlpnProtocols() + { + return alpnProtocols_; + } + + bool getUseSystemCertStore() const + { + return useSystemCertStore_; + } + + static std::shared_ptr defaultServerPolicy( + const std::string &certPath, + const std::string &keyPath) + { + auto policy = std::make_shared(); + policy->setValidate(false) + .setUseOldTLS(false) + .setUseSystemCertStore(false) + .setCertPath(certPath) + .setKeyPath(keyPath); + return policy; + } + + static std::shared_ptr defaultClientPolicy( + const std::string &hostname = "") + { + auto policy = std::make_shared(); + policy->setValidate(true) + .setUseOldTLS(false) + .setUseSystemCertStore(true) + .setHostname(hostname); + return policy; + } + + protected: + std::vector> sslConfCmds_ = {}; + std::string hostname_ = ""; + std::string certPath_ = ""; + std::string keyPath_ = ""; + std::string caPath_ = ""; + std::vector alpnProtocols_ = {}; + bool useOldTLS_ = false; // turn into specific version + bool validate_ = true; + bool allowBrokenChain_ = false; + bool useSystemCertStore_ = true; +}; +using TLSPolicyPtr = std::shared_ptr; +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/TcpClient.cc b/trantor/net/TcpClient.cc index 7b28fb03..1acd73bb 100644 --- a/trantor/net/TcpClient.cc +++ b/trantor/net/TcpClient.cc @@ -10,6 +10,7 @@ // Taken from muduo and modified by an tao #include +#include #include #include "Connector.h" @@ -18,6 +19,8 @@ #include #include +#include +#include #include "Socket.h" @@ -64,44 +67,28 @@ TcpClient::TcpClient(EventLoop *loop, connect_(true) { (void)validateCert_; - connector_->setNewConnectionCallback( - std::bind(&TcpClient::newConnection, this, _1)); - connector_->setErrorCallback([this]() { - if (connectionErrorCallback_) - { - connectionErrorCallback_(); - } - }); LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector "; } TcpClient::~TcpClient() { LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector "; - TcpConnectionImplPtr conn; + std::lock_guard lock(mutex_); + if (connection_ == nullptr) { - std::lock_guard lock(mutex_); - conn = std::dynamic_pointer_cast(connection_); - if (conn) - { - assert(loop_ == conn->getLoop()); - // TODO: not 100% safe, if we are in different thread - auto loop = loop_; - loop_->runInLoop([conn, loop]() { - conn->setCloseCallback([loop](const TcpConnectionPtr &connPtr) { - loop->queueInLoop([connPtr]() { - static_cast(connPtr.get()) - ->connectDestroyed(); - }); - }); - }); - conn->forceClose(); - } - else - { - connector_->stop(); - } + connector_->stop(); + return; } + assert(loop_ == connection_->getLoop()); + auto conn = + std::atomic_load_explicit(&connection_, std::memory_order_relaxed); + loop_->runInLoop([conn = std::move(conn)]() { + conn->setCloseCallback([](const TcpConnectionPtr &connPtr) mutable { + connPtr->getLoop()->queueInLoop( + [connPtr] { connPtr->connectDestroyed(); }); + }); + }); + connection_->forceClose(); } void TcpClient::connect() @@ -109,6 +96,23 @@ void TcpClient::connect() // TODO: check state LOG_TRACE << "TcpClient::connect[" << name_ << "] - connecting to " << connector_->serverAddress().toIpPort(); + + auto weakPtr = std::weak_ptr(shared_from_this()); + connector_->setNewConnectionCallback([weakPtr](int sockfd) { + auto ptr = weakPtr.lock(); + if (ptr) + { + ptr->newConnection(sockfd); + } + }); + // WORKAROUND: somehow we got use-after-free error + connector_->setErrorCallback([weakPtr]() { + auto ptr = weakPtr.lock(); + if (ptr && ptr->connectionErrorCallback_) + { + ptr->connectionErrorCallback_(); + } + }); connect_ = true; connector_->start(); } @@ -132,6 +136,16 @@ void TcpClient::stop() connector_->stop(); } +void TcpClient::setSockOptCallback(SockOptCallback &&cb) +{ + connector_->setSockOptCallback(std::move(cb)); +} + +void TcpClient::setSockOptCallback(const SockOptCallback &cb) +{ + connector_->setSockOptCallback(cb); +} + void TcpClient::newConnection(int sockfd) { loop_->assertInLoopThread(); @@ -139,22 +153,13 @@ void TcpClient::newConnection(int sockfd) InetAddress localAddr(Socket::getLocalAddr(sockfd)); // TODO poll with zero timeout to double confirm the new connection // TODO use make_shared if necessary - std::shared_ptr conn; - if (sslCtxPtr_) + TcpConnectionPtr conn; + LOG_TRACE << "SSL enabled: " << (tlsPolicyPtr_ ? "true" : "false"); + if (tlsPolicyPtr_) { -#ifdef USE_OPENSSL - conn = std::make_shared(loop_, - sockfd, - localAddr, - peerAddr, - sslCtxPtr_, - false, - validateCert_, - SSLHostName_); -#else - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#endif + assert(sslContextPtr_); + conn = std::make_shared( + loop_, sockfd, localAddr, peerAddr, tlsPolicyPtr_, sslContextPtr_); } else { @@ -179,9 +184,7 @@ void TcpClient::newConnection(int sockfd) { LOG_TRACE << "TcpClient::removeConnection was skipped because " "TcpClient instanced already freed"; - c->getLoop()->queueInLoop( - std::bind(&TcpConnectionImpl::connectDestroyed, - std::dynamic_pointer_cast(c))); + c->getLoop()->queueInLoop([c] { c->connectDestroyed(); }); } }); conn->setCloseCallback(std::move(closeCb)); @@ -189,11 +192,10 @@ void TcpClient::newConnection(int sockfd) std::lock_guard lock(mutex_); connection_ = conn; } - conn->setSSLErrorCallback([this](SSLError err) { - if (sslErrorCallback_) - { - sslErrorCallback_(err); - } + conn->setSSLErrorCallback([weakSelf = std::move(weakSelf)](SSLError err) { + auto self = weakSelf.lock(); + if (self && self->sslErrorCallback_) + self->sslErrorCallback_(err); }); conn->connectEstablished(); } @@ -209,9 +211,7 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) connection_.reset(); } - loop_->queueInLoop( - std::bind(&TcpConnectionImpl::connectDestroyed, - std::dynamic_pointer_cast(conn))); + loop_->queueInLoop([conn]() { conn->connectDestroyed(); }); if (retry_ && connect_) { LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " @@ -229,32 +229,21 @@ void TcpClient::enableSSL( const std::string &keyPath, const std::string &caPath) { -#ifdef USE_OPENSSL - /* Create a new OpenSSL context */ - sslCtxPtr_ = newSSLClientContext( - useOldTLS, validateCert, certPath, keyPath, sslConfCmds, caPath); - validateCert_ = validateCert; if (!hostname.empty()) { std::transform(hostname.begin(), hostname.end(), hostname.begin(), [](unsigned char c) { return tolower(c); }); - SSLHostName_ = std::move(hostname); } -#else - // When not using OpenSSL, using `void` here will - // work around the unused parameter warnings without overhead. - (void)useOldTLS; - (void)validateCert; - (void)hostname; - (void)sslConfCmds; - (void)certPath; - (void)keyPath; - (void)caPath; - - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#endif + tlsPolicyPtr_ = TLSPolicy::defaultClientPolicy(); + tlsPolicyPtr_->setValidate(validateCert) + .setUseOldTLS(useOldTLS) + .setConfCmds(sslConfCmds) + .setCertPath(certPath) + .setKeyPath(keyPath) + .setHostname(hostname) + .setCaPath(caPath); + sslContextPtr_ = newSSLContext(*tlsPolicyPtr_, false); } diff --git a/trantor/net/TcpClient.h b/trantor/net/TcpClient.h index 9349063c..8d885da9 100644 --- a/trantor/net/TcpClient.h +++ b/trantor/net/TcpClient.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -29,7 +30,6 @@ namespace trantor { class Connector; using ConnectorPtr = std::shared_ptr; -class SSLContext; /** * @brief This class represents a TCP client. * @@ -190,6 +190,13 @@ class TRANTOR_EXPORT TcpClient : NonCopyable, sslErrorCallback_ = std::move(cb); } + /** + * @brief Set the callback for set socket option + * @param cb The callback is called, before connect + */ + void setSockOptCallback(const SockOptCallback &cb); + void setSockOptCallback(SockOptCallback &&cb); + /** * @brief Enable SSL encryption. * @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the @@ -202,17 +209,27 @@ class TRANTOR_EXPORT TcpClient : NonCopyable, * OpenSSL. * @param certPath The path of the certificate file. * @param keyPath The path of the private key file. + * @param caPath The path of the certificate authority file. * @note It's well known that TLS 1.0 and 1.1 are not considered secure in * 2020. And it's a good practice to only use TLS 1.2 and above. */ - void enableSSL(bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector> - &sslConfCmds = {}, - const std::string &certPath = "", - const std::string &keyPath = "", - const std::string &caPath = ""); + [[deprecated("Use enableSSL(TLSPolicyPtr policy) instead")]] void enableSSL( + bool useOldTLS = false, + bool validateCert = true, + std::string hostname = "", + const std::vector> &sslConfCmds = + {}, + const std::string &certPath = "", + const std::string &keyPath = "", + const std::string &caPath = ""); + /** + * @brief Enable SSL encryption. + */ + void enableSSL(TLSPolicyPtr policy) + { + tlsPolicyPtr_ = std::move(policy); + sslContextPtr_ = newSSLContext(*tlsPolicyPtr_, false); + } private: /// Not thread safe, but in loop @@ -233,8 +250,8 @@ class TRANTOR_EXPORT TcpClient : NonCopyable, // always in loop thread mutable std::mutex mutex_; TcpConnectionPtr connection_; // @GuardedBy mutex_ - std::shared_ptr sslCtxPtr_; - std::string SSLHostName_; + TLSPolicyPtr tlsPolicyPtr_; + SSLContextPtr sslContextPtr_; bool validateCert_{false}; #ifndef _WIN32 diff --git a/trantor/net/TcpConnection.h b/trantor/net/TcpConnection.h index 479874c5..41626bcd 100644 --- a/trantor/net/TcpConnection.h +++ b/trantor/net/TcpConnection.h @@ -19,19 +19,20 @@ #include #include #include +#include +#include +#include #include #include #include namespace trantor { -class SSLContext; -TRANTOR_EXPORT std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS = false, - const std::vector> &sslConfCmds = {}, - const std::string &caPath = ""); +class TimingWheel; + +struct SSLContext; +using SSLContextPtr = std::shared_ptr; + /** * @brief This class represents a TCP connection. * @@ -39,6 +40,10 @@ TRANTOR_EXPORT std::shared_ptr newSSLServerContext( class TRANTOR_EXPORT TcpConnection { public: + friend class TcpServer; + friend class TcpConnectionImpl; + friend class TcpClient; + TcpConnection() = default; virtual ~TcpConnection(){}; @@ -65,8 +70,8 @@ class TRANTOR_EXPORT TcpConnection * @param length */ virtual void sendFile(const char *fileName, - size_t offset = 0, - size_t length = 0) = 0; + long long offset = 0, + long long length = 0) = 0; /** * @brief Send a file to the peer. * @@ -75,8 +80,8 @@ class TRANTOR_EXPORT TcpConnection * @param length */ virtual void sendFile(const wchar_t *fileName, - size_t offset = 0, - size_t length = 0) = 0; + long long offset = 0, + long long length = 0) = 0; /** * @brief Send a stream to the peer. * @@ -91,11 +96,20 @@ class TRANTOR_EXPORT TcpConnection callback) = 0; // (buffer, buffer size) -> size // of data put in buffer + /** + * @brief Send a stream to the peer asynchronously. + * @param disableKickoff Disable the kickoff mechanism. If this parameter is + * enabled, the connection will not be closed after the inactive timeout. + * @note The subsequent data sent after the async stream will be sent after + * the stream is closed. + */ + virtual AsyncStreamPtr sendAsyncStream(bool disableKickoff = false) = 0; /** * @brief Get the local address of the connection. * * @return const InetAddress& */ + virtual const InetAddress &localAddr() const = 0; /** @@ -121,12 +135,12 @@ class TRANTOR_EXPORT TcpConnection */ virtual bool disconnected() const = 0; - /** + /* * * @brief Get the buffer in which the received data stored. * * @return MsgBuffer* */ - virtual MsgBuffer *getRecvBuffer() = 0; + // virtual MsgBuffer *getRecvBuffer() = 0; /** * @brief Set the high water mark callback @@ -177,6 +191,7 @@ class TRANTOR_EXPORT TcpConnection { contextPtr_ = std::move(context); } + virtual std::string applicationProtocol() const = 0; /** * @brief Get the custom data from the connection. @@ -248,38 +263,124 @@ class TRANTOR_EXPORT TcpConnection virtual bool isSSLConnection() const = 0; /** - * @brief Start the SSL encryption on the connection (as a client). + * @brief Get buffer of unprompted data. + */ + virtual MsgBuffer *getRecvBuffer() = 0; + + /** + * @brief Get peer certificate (if any). * - * @param callback The callback is called when the SSL connection is - * established. - * @param hostname The server hostname for SNI. If it is empty, the SNI is - * not used. - * @param sslConfCmds The commands used to call the SSL_CONF_cmd function in - * OpenSSL. + * @return pointer to Certificate object or nullptr if no certificate was + * provided */ - virtual void startClientEncryption( - std::function callback, - bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector> &sslConfCmds = - {}) = 0; + virtual CertificatePtr peerCertificate() const = 0; /** - * @brief Start the SSL encryption on the connection (as a server). + * @brief Get the SNI name (for server connections only) * - * @param ctx The SSL context. - * @param callback The callback is called when the SSL connection is - * established. + * @return Empty string if no SNI name was provided (not an SSL connection + * or peer did not provide SNI) + */ + virtual std::string sniName() const = 0; + + /** + * @brief Start TLS. If the connection is specified as a server, the + * connection will be upgraded to a TLS server connection. If the connection + * is specified as a client, the connection will be upgraded to a TLS client + * @note This method is only available for non-SSL connections. + */ + virtual void startEncryption(TLSPolicyPtr policy, + bool isServer, + std::function + upgradeCallback = nullptr) = 0; + /** + * @brief Start TLS as a client. + * @note This method is only available for non-SSL connections. */ - virtual void startServerEncryption(const std::shared_ptr &ctx, - std::function callback) = 0; + [[deprecated("Use startEncryption(TLSPolicyPtr) instead")]] void + startClientEncryption( + std::function &&callback, + bool useOldTLS = false, + bool validateCert = true, + const std::string &hostname = "", + const std::vector> &sslConfCmds = + {}) + { + auto policy = TLSPolicy::defaultClientPolicy(); + policy->setUseOldTLS(useOldTLS) + .setValidate(validateCert) + .setHostname(hostname) + .setConfCmds(sslConfCmds); + startEncryption(std::move(policy), false, std::move(callback)); + } + + void setValidationPolicy(TLSPolicy &&policy) + { + tlsPolicy_ = std::move(policy); + } + + void setRecvMsgCallback(const RecvMessageCallback &cb) + { + recvMsgCallback_ = cb; + } + void setRecvMsgCallback(RecvMessageCallback &&cb) + { + recvMsgCallback_ = std::move(cb); + } + void setConnectionCallback(const ConnectionCallback &cb) + { + connectionCallback_ = cb; + } + void setConnectionCallback(ConnectionCallback &&cb) + { + connectionCallback_ = std::move(cb); + } + void setWriteCompleteCallback(const WriteCompleteCallback &cb) + { + writeCompleteCallback_ = cb; + } + void setWriteCompleteCallback(WriteCompleteCallback &&cb) + { + writeCompleteCallback_ = std::move(cb); + } + void setCloseCallback(const CloseCallback &cb) + { + closeCallback_ = cb; + } + void setCloseCallback(CloseCallback &&cb) + { + closeCallback_ = std::move(cb); + } + void setSSLErrorCallback(const SSLErrorCallback &cb) + { + sslErrorCallback_ = cb; + } + void setSSLErrorCallback(SSLErrorCallback &&cb) + { + sslErrorCallback_ = std::move(cb); + } + + // TODO: These should be internal APIs + virtual void connectEstablished() = 0; + virtual void connectDestroyed() = 0; + virtual void enableKickingOff( + size_t timeout, + const std::shared_ptr &timingWheel) = 0; protected: - bool validateCert_ = false; + // callbacks + RecvMessageCallback recvMsgCallback_; + ConnectionCallback connectionCallback_; + CloseCallback closeCallback_; + WriteCompleteCallback writeCompleteCallback_; + HighWaterMarkCallback highWaterMarkCallback_; + SSLErrorCallback sslErrorCallback_; + TLSPolicy tlsPolicy_; private: std::shared_ptr contextPtr_; }; +TRANTOR_EXPORT SSLContextPtr newSSLContext(const TLSPolicy &policy, + bool server); } // namespace trantor diff --git a/trantor/net/TcpServer.cc b/trantor/net/TcpServer.cc index 23b57d48..d256b9c1 100644 --- a/trantor/net/TcpServer.cc +++ b/trantor/net/TcpServer.cc @@ -12,31 +12,33 @@ * */ -#include "Acceptor.h" -#include "inner/TcpConnectionImpl.h" #include #include #include #include +#include "Acceptor.h" +#include "inner/TcpConnectionImpl.h" using namespace trantor; using namespace std::placeholders; TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, - const std::string &name, + std::string name, bool reUseAddr, bool reUsePort) : loop_(loop), acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)), - serverName_(name), + serverName_(std::move(name)), recvMessageCallback_([](const TcpConnectionPtr &, MsgBuffer *buffer) { LOG_ERROR << "unhandled recv message [" << buffer->readableBytes() << " bytes]"; buffer->retrieveAll(); - }) + }), + ioLoops_({loop}), + numIoLoops_(1) { acceptorPtr_->setNewConnectionCallback( - std::bind(&TcpServer::newConnection, this, _1, _2)); + [this](int fd, const InetAddress &peer) { newConnection(fd, peer); }); } TcpServer::~TcpServer() @@ -45,39 +47,37 @@ TcpServer::~TcpServer() LOG_TRACE << "TcpServer::~TcpServer [" << serverName_ << "] destructing"; } +void TcpServer::setBeforeListenSockOptCallback(SockOptCallback cb) +{ + acceptorPtr_->setBeforeListenSockOptCallback(std::move(cb)); +} + +void TcpServer::setAfterAcceptSockOptCallback(SockOptCallback cb) +{ + acceptorPtr_->setAfterAcceptSockOptCallback(std::move(cb)); +} + void TcpServer::newConnection(int sockfd, const InetAddress &peer) { LOG_TRACE << "new connection:fd=" << sockfd << " address=" << peer.toIpPort(); - // test code for blocking or nonblocking - // std::vector str(1024*1024*100); - // for(int i=0;iassertInLoopThread(); - EventLoop *ioLoop = NULL; - if (loopPoolPtr_ && loopPoolPtr_->size() > 0) + EventLoop *ioLoop = ioLoops_[nextLoopIdx_]; + if (++nextLoopIdx_ >= numIoLoops_) { - ioLoop = loopPoolPtr_->getNextLoop(); + nextLoopIdx_ = 0; } - if (ioLoop == NULL) - ioLoop = loop_; - std::shared_ptr newPtr; - if (sslCtxPtr_) + TcpConnectionPtr newPtr; + if (policyPtr_) { -#ifdef USE_OPENSSL + assert(sslContextPtr_); newPtr = std::make_shared( ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer, - sslCtxPtr_); -#else - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#endif + policyPtr_, + sslContextPtr_); } else { @@ -102,7 +102,10 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) if (writeCompleteCallback_) writeCompleteCallback_(connectionPtr); }); - newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1)); + + newPtr->setCloseCallback([this](const TcpConnectionPtr &closeConnPtr) { + connectionClosed(closeConnPtr); + }); connSet_.insert(newPtr); newPtr->connectEstablished(); } @@ -114,29 +117,15 @@ void TcpServer::start() started_ = true; if (idleTimeout_ > 0) { - timingWheelMap_[loop_] = - std::make_shared(loop_, - idleTimeout_, - 1.0F, - idleTimeout_ < 500 - ? idleTimeout_ + 1 - : 100); - if (loopPoolPtr_) + for (EventLoop *loop : ioLoops_) { - auto loopNum = loopPoolPtr_->size(); - while (loopNum > 0) - { - // LOG_TRACE << "new Wheel loopNum=" << loopNum; - auto poolLoop = loopPoolPtr_->getNextLoop(); - timingWheelMap_[poolLoop] = - std::make_shared(poolLoop, - idleTimeout_, - 1.0F, - idleTimeout_ < 500 - ? idleTimeout_ + 1 - : 100); - --loopNum; - } + timingWheelMap_[loop] = + std::make_shared(loop, + idleTimeout_, + 1.0F, + idleTimeout_ < 500 + ? idleTimeout_ + 1 + : 100); } } LOG_TRACE << "map size=" << timingWheelMap_.size(); @@ -156,7 +145,7 @@ void TcpServer::stop() { connPtrs.push_back(conn); } - for (auto connection : connPtrs) + for (auto &connection : connPtrs) { connection->forceClose(); } @@ -173,7 +162,7 @@ void TcpServer::stop() { connPtrs.push_back(conn); } - for (auto connection : connPtrs) + for (auto &connection : connPtrs) { connection->forceClose(); } @@ -204,10 +193,8 @@ void TcpServer::handleCloseInLoop(const TcpConnectionPtr &connectionPtr) // may be in loop_'s current active channels, waiting to be processed. // If `connectDestroyed()` is called here, we will be using an wild pointer // later. - connLoop->queueInLoop([connectionPtr]() { - static_cast(connectionPtr.get()) - ->connectDestroyed(); - }); + connLoop->queueInLoop( + [connectionPtr]() { connectionPtr->connectDestroyed(); }); } void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) { @@ -223,7 +210,7 @@ void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) } } -const std::string TcpServer::ipPort() const +std::string TcpServer::ipPort() const { return acceptorPtr_->addr().toIpPort(); } @@ -240,20 +227,30 @@ void TcpServer::enableSSL( const std::vector> &sslConfCmds, const std::string &caPath) { -#ifdef USE_OPENSSL - /* Create a new OpenSSL context */ - sslCtxPtr_ = - newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds, caPath); -#else - // When not using OpenSSL, using `void` here will - // work around the unused parameter warnings without overhead. - (void)certPath; - (void)keyPath; - (void)useOldTLS; - (void)sslConfCmds; - (void)caPath; - - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#endif + policyPtr_ = TLSPolicy::defaultServerPolicy(certPath, keyPath); + policyPtr_->setUseOldTLS(useOldTLS) + .setConfCmds(sslConfCmds) + .setCaPath(caPath) + .setValidate(caPath.empty() ? false : true); + sslContextPtr_ = newSSLContext(*policyPtr_, true); } + +void TcpServer::reloadSSL() +{ + if (loop_->isInLoopThread()) + { + if (policyPtr_) + { + sslContextPtr_ = newSSLContext(*policyPtr_, true); + } + } + else + { + loop_->queueInLoop([this]() { + if (policyPtr_) + { + sslContextPtr_ = newSSLContext(*policyPtr_, true); + } + }); + } +} \ No newline at end of file diff --git a/trantor/net/TcpServer.h b/trantor/net/TcpServer.h index d5b77062..33b041b8 100644 --- a/trantor/net/TcpServer.h +++ b/trantor/net/TcpServer.h @@ -13,22 +13,22 @@ */ #pragma once -#include -#include -#include +#include #include #include #include +#include +#include +#include #include -#include -#include +#include #include #include -#include +#include + namespace trantor { class Acceptor; -class SSLContext; /** * @brief This class represents a TCP server. * @@ -48,7 +48,7 @@ class TRANTOR_EXPORT TcpServer : NonCopyable */ TcpServer(EventLoop *loop, const InetAddress &address, - const std::string &name, + std::string name, bool reUseAddr = true, bool reUsePort = true); ~TcpServer(); @@ -68,6 +68,7 @@ class TRANTOR_EXPORT TcpServer : NonCopyable /** * @brief Set the number of event loops in which the I/O of connections to * the server is handled. + * An EventLoopThreadPool is created and managed by TcpServer. * * @param num */ @@ -76,11 +77,14 @@ class TRANTOR_EXPORT TcpServer : NonCopyable assert(!started_); loopPoolPtr_ = std::make_shared(num); loopPoolPtr_->start(); + ioLoops_ = loopPoolPtr_->getLoops(); + numIoLoops_ = ioLoops_.size(); } /** * @brief Set the event loops pool in which the I/O of connections to * the server is handled. + * A shared_ptr of EventLoopThreadPool is copied. * * @param pool */ @@ -89,7 +93,26 @@ class TRANTOR_EXPORT TcpServer : NonCopyable assert(pool->size() > 0); assert(!started_); loopPoolPtr_ = pool; - loopPoolPtr_->start(); + loopPoolPtr_->start(); // TODO: should not start by TcpServer + ioLoops_ = loopPoolPtr_->getLoops(); + numIoLoops_ = ioLoops_.size(); + } + + /** + * @brief Set the event loops in which the I/O of connections to + * the server is handled. + * The loops are managed by caller. Caller should ensure that ioLoops + * lives longer than TcpServer. + * + * @param ioLoops + */ + void setIoLoops(const std::vector &ioLoops) + { + assert(!ioLoops.empty()); + assert(!started_); + ioLoops_ = ioLoops; + numIoLoops_ = ioLoops_.size(); + loopPoolPtr_.reset(); } /** @@ -137,6 +160,19 @@ class TRANTOR_EXPORT TcpServer : NonCopyable writeCompleteCallback_ = std::move(cb); } + /** + * @brief Set the before listen setsockopt callback. + * + * @param cb This callback will be called before the listen + */ + void setBeforeListenSockOptCallback(SockOptCallback cb); + /** + * @brief Set the after accept setsockopt callback. + * + * @param cb This callback will be called after accept + */ + void setAfterAcceptSockOptCallback(SockOptCallback cb); + /** * @brief Get the name of the server. * @@ -152,7 +188,7 @@ class TRANTOR_EXPORT TcpServer : NonCopyable * * @return const std::string */ - const std::string ipPort() const; + std::string ipPort() const; /** * @brief Get the address of the server. @@ -178,7 +214,7 @@ class TRANTOR_EXPORT TcpServer : NonCopyable */ std::vector getIoLoops() const { - return loopPoolPtr_->getLoops(); + return ioLoops_; } /** @@ -204,21 +240,41 @@ class TRANTOR_EXPORT TcpServer : NonCopyable * server. * @param sslConfCmds The commands used to call the SSL_CONF_cmd function in * OpenSSL. + * @param caPath The path of the certificate authority file. * @note It's well known that TLS 1.0 and 1.1 are not considered secure in * 2020. And it's a good practice to only use TLS 1.2 and above. */ - void enableSSL(const std::string &certPath, - const std::string &keyPath, - bool useOldTLS = false, - const std::vector> - &sslConfCmds = {}, - const std::string &caPath = ""); + [[deprecated("Use enableSSL(TLSPolicyPtr) instead")]] void enableSSL( + const std::string &certPath, + const std::string &keyPath, + bool useOldTLS = false, + const std::vector> &sslConfCmds = + {}, + const std::string &caPath = ""); + /** + * @brief Enable SSL encryption. + */ + void enableSSL(TLSPolicyPtr policy) + { + policyPtr_ = std::move(policy); + sslContextPtr_ = newSSLContext(*policyPtr_, true); + } + + /** + * @brief Reload the SSL context. + * @note Call this function when the certificate or private key is updated. + * The server will reload the SSL context and use the new certificate and + * private key. new connections will use the new SSL context. + */ + void reloadSSL(); private: - EventLoop *loop_; void handleCloseInLoop(const TcpConnectionPtr &connectionPtr); - std::unique_ptr acceptorPtr_; void newConnection(int fd, const InetAddress &peer); + void connectionClosed(const TcpConnectionPtr &connectionPtr); + + EventLoop *loop_; + std::unique_ptr acceptorPtr_; std::string serverName_; std::set connSet_; @@ -228,8 +284,18 @@ class TRANTOR_EXPORT TcpServer : NonCopyable size_t idleTimeout_{0}; std::map> timingWheelMap_; - void connectionClosed(const TcpConnectionPtr &connectionPtr); + + // `loopPoolPtr_` may and may not hold the internal thread pool. + // We should not access it directly in codes. + // Instead, we should use its delegation variable `ioLoops_`. std::shared_ptr loopPoolPtr_; + // If one of `setIoLoopNum()`, `setIoLoopThreadPool()` and `setIoLoops()` is + // called, `ioLoops_` will hold the loops passed in. + // Otherwise, it should contain only one element, which is `loop_`. + std::vector ioLoops_; + size_t nextLoopIdx_{0}; + size_t numIoLoops_{0}; + #ifndef _WIN32 class IgnoreSigPipe { @@ -244,9 +310,8 @@ class TRANTOR_EXPORT TcpServer : NonCopyable IgnoreSigPipe initObj; #endif bool started_{false}; - - // OpenSSL SSL context Object; - std::shared_ptr sslCtxPtr_; + TLSPolicyPtr policyPtr_{nullptr}; + SSLContextPtr sslContextPtr_{nullptr}; }; } // namespace trantor diff --git a/trantor/net/callbacks.h b/trantor/net/callbacks.h index f7378706..b81f96ee 100644 --- a/trantor/net/callbacks.h +++ b/trantor/net/callbacks.h @@ -21,7 +21,8 @@ namespace trantor enum class SSLError { kSSLHandshakeError, - kSSLInvalidCertificate + kSSLInvalidCertificate, + kSSLProtocolError }; using TimerCallback = std::function; @@ -39,5 +40,6 @@ using WriteCompleteCallback = std::function; using HighWaterMarkCallback = std::function; using SSLErrorCallback = std::function; +using SockOptCallback = std::function; } // namespace trantor diff --git a/trantor/net/inner/Acceptor.cc b/trantor/net/inner/Acceptor.cc index 6186b94b..0b8249cb 100644 --- a/trantor/net/inner/Acceptor.cc +++ b/trantor/net/inner/Acceptor.cc @@ -54,6 +54,8 @@ Acceptor::~Acceptor() void Acceptor::listen() { loop_->assertInLoopThread(); + if (beforeListenSetSockOptCallback_) + beforeListenSetSockOptCallback_(sock_.fd()); sock_.listen(); acceptChannel_.enableReading(); } @@ -64,6 +66,8 @@ void Acceptor::readCallback() int newsock = sock_.accept(&peer); if (newsock >= 0) { + if (afterAcceptSetSockOptCallback_) + afterAcceptSetSockOptCallback_(newsock); if (newConnectionCallback_) { newConnectionCallback_(newsock, peer); @@ -79,7 +83,7 @@ void Acceptor::readCallback() } else { - LOG_SYSERR << "Accpetor::readCallback"; + LOG_SYSERR << "Acceptor::readCallback"; // Read the section named "The special problem of // accept()ing when you can't" in libev's doc. // By Marc Lehmann, author of libev. diff --git a/trantor/net/inner/Acceptor.h b/trantor/net/inner/Acceptor.h index 62ec4218..22c7e6c0 100644 --- a/trantor/net/inner/Acceptor.h +++ b/trantor/net/inner/Acceptor.h @@ -24,6 +24,7 @@ namespace trantor { using NewConnectionCallback = std::function; +using AcceptorSockOptCallback = std::function; class Acceptor : NonCopyable { public: @@ -42,6 +43,16 @@ class Acceptor : NonCopyable }; void listen(); + void setBeforeListenSockOptCallback(AcceptorSockOptCallback cb) + { + beforeListenSetSockOptCallback_ = std::move(cb); + } + + void setAfterAcceptSockOptCallback(AcceptorSockOptCallback cb) + { + afterAcceptSetSockOptCallback_ = std::move(cb); + } + protected: #ifndef _WIN32 int idleFd_; @@ -52,5 +63,7 @@ class Acceptor : NonCopyable NewConnectionCallback newConnectionCallback_; Channel acceptChannel_; void readCallback(); + AcceptorSockOptCallback beforeListenSetSockOptCallback_; + AcceptorSockOptCallback afterAcceptSetSockOptCallback_; }; } // namespace trantor diff --git a/trantor/net/inner/AresResolver.cc b/trantor/net/inner/AresResolver.cc index ec844614..d950a575 100644 --- a/trantor/net/inner/AresResolver.cc +++ b/trantor/net/inner/AresResolver.cc @@ -52,10 +52,17 @@ bool Resolver::isCAresUsed() AresResolver::LibraryInitializer::LibraryInitializer() { ares_library_init(ARES_LIB_INIT_ALL); + + hints_ = new ares_addrinfo_hints; + hints_->ai_flags = 0; + hints_->ai_family = AF_INET; + hints_->ai_socktype = 0; + hints_->ai_protocol = 0; } AresResolver::LibraryInitializer::~LibraryInitializer() { ares_library_cleanup(); + delete hints_; } AresResolver::LibraryInitializer AresResolver::libraryInitializer_; @@ -110,24 +117,26 @@ AresResolver::~AresResolver() } void AresResolver::resolveInLoop(const std::string& hostname, - const Callback& cb) + const ResolverResultsCallback& cb) { loop_->assertInLoopThread(); #ifdef _WIN32 if (hostname == "localhost") { - const static trantor::InetAddress localhost_{"127.0.0.1", 0}; + const static std::vector localhost_{ + trantor::InetAddress{"127.0.0.1", 0}}; cb(localhost_); return; } #endif init(); QueryData* queryData = new QueryData(this, cb, hostname); - ares_gethostbyname(ctx_, - hostname.c_str(), - AF_INET, - &AresResolver::ares_hostcallback_, - queryData); + ares_getaddrinfo(ctx_, + hostname.c_str(), + NULL, + libraryInitializer_.hints_, + &AresResolver::ares_hostcallback_, + queryData); struct timeval tv; struct timeval* tvp = ares_timeout(ctx_, NULL, &tv); double timeout = getSeconds(tvp); @@ -165,27 +174,52 @@ void AresResolver::onTimer() } void AresResolver::onQueryResult(int status, - struct hostent* result, + struct ares_addrinfo* result, const std::string& hostname, - const Callback& callback) + const ResolverResultsCallback& callback) { LOG_TRACE << "onQueryResult " << status; - struct sockaddr_in addr; - memset(&addr, 0, sizeof addr); - addr.sin_family = AF_INET; - addr.sin_port = 0; + auto inets_ptr = std::make_shared>(); if (result) { - addr.sin_addr = *reinterpret_cast(result->h_addr); + auto pptr = (struct ares_addrinfo_node*)result->nodes; + for (; pptr != NULL; pptr = pptr->ai_next) + { + trantor::InetAddress inet; + if (pptr->ai_family == AF_INET) + { + struct sockaddr_in* addr4 = (struct sockaddr_in*)pptr->ai_addr; + inets_ptr->emplace_back(trantor::InetAddress{*addr4}); + } + else if (pptr->ai_family == AF_INET6) + { + struct sockaddr_in6* addr6 = + (struct sockaddr_in6*)pptr->ai_addr; + inets_ptr->emplace_back(trantor::InetAddress{*addr6}); + } + else + { + // TODO: Handle unknown family? + } + } + ares_freeaddrinfo(result); + } + if (inets_ptr->empty()) + { + struct sockaddr_in addr; + memset(&addr, 0, sizeof addr); + addr.sin_family = AF_INET; + addr.sin_port = 0; + InetAddress inet(addr); + inets_ptr->emplace_back(std::move(inet)); } - InetAddress inet(addr); { std::lock_guard lock(globalMutex()); auto& addrItem = globalCache()[hostname]; - addrItem.first = addr.sin_addr; + addrItem.first = inets_ptr; addrItem.second = trantor::Date::date(); } - callback(inet); + callback(*inets_ptr); } void AresResolver::onSockCreate(int sockfd, int type) @@ -202,9 +236,6 @@ void AresResolver::onSockCreate(int sockfd, int type) void AresResolver::onSockStateChange(int sockfd, bool read, bool write) { (void)write; - loop_->assertInLoopThread(); - ChannelList::iterator it = channels_.find(sockfd); - assert(it != channels_.end()); if (read) { // update @@ -212,6 +243,9 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) } else if (*loopValid_) { + loop_->assertInLoopThread(); + ChannelList::iterator it = channels_.find(sockfd); + assert(it != channels_.end()); // remove it->second->disableAll(); it->second->remove(); @@ -222,7 +256,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) void AresResolver::ares_hostcallback_(void* data, int status, int timeouts, - struct hostent* hostent) + struct ares_addrinfo* hostent) { (void)timeouts; QueryData* query = static_cast(data); diff --git a/trantor/net/inner/AresResolver.h b/trantor/net/inner/AresResolver.h index e47ef7bb..52515771 100644 --- a/trantor/net/inner/AresResolver.h +++ b/trantor/net/inner/AresResolver.h @@ -15,8 +15,9 @@ extern "C" { - struct hostent; + struct ares_addrinfo; struct ares_channeldata; + struct ares_addrinfo_hints; using ares_channel = struct ares_channeldata*; } namespace trantor @@ -43,12 +44,7 @@ class AresResolver : public Resolver, if (timeout_ == 0 || cachedAddr.second.after(timeout_) > trantor::Date::date()) { - struct sockaddr_in addr; - memset(&addr, 0, sizeof addr); - addr.sin_family = AF_INET; - addr.sin_port = 0; - addr.sin_addr = cachedAddr.first; - inet = InetAddress(addr); + inet = (*cachedAddr.first)[0]; cached = true; } } @@ -59,6 +55,47 @@ class AresResolver : public Resolver, return; } if (loop_->isInLoopThread()) + { + resolveInLoop(hostname, + [cb](const std::vector& inets) { + cb(inets[0]); + }); + } + else + { + loop_->queueInLoop([thisPtr = shared_from_this(), hostname, cb]() { + thisPtr->resolveInLoop( + hostname, + [cb](const std::vector& inets) { + cb(inets[0]); + }); + }); + } + } + + virtual void resolve(const std::string& hostname, + const ResolverResultsCallback& cb) override + { + std::shared_ptr> inets_ptr{nullptr}; + { + std::lock_guard lock(globalMutex()); + auto iter = globalCache().find(hostname); + if (iter != globalCache().end()) + { + auto& cachedAddr = iter->second; + if (timeout_ == 0 || + cachedAddr.second.after(timeout_) > trantor::Date::date()) + { + inets_ptr = cachedAddr.first; + } + } + } + if (inets_ptr) + { + cb(*inets_ptr); + return; + } + if (loop_->isInLoopThread()) { resolveInLoop(hostname, cb); } @@ -74,16 +111,17 @@ class AresResolver : public Resolver, struct QueryData { AresResolver* owner_; - Callback callback_; + ResolverResultsCallback callback_; std::string hostname_; QueryData(AresResolver* o, - const Callback& cb, + const ResolverResultsCallback& cb, const std::string& hostname) : owner_(o), callback_(cb), hostname_(hostname) { } }; - void resolveInLoop(const std::string& hostname, const Callback& cb); + void resolveInLoop(const std::string& hostname, + const ResolverResultsCallback& cb); void init(); trantor::EventLoop* loop_; std::shared_ptr loopValid_; @@ -91,12 +129,16 @@ class AresResolver : public Resolver, bool timerActive_{false}; using ChannelList = std::map>; ChannelList channels_; - static std::unordered_map>& + static std::unordered_map< + std::string, + std::pair>, + trantor::Date>>& globalCache() { - static std::unordered_map> + static std::unordered_map< + std::string, + std::pair>, + trantor::Date>> dnsCache; return dnsCache; } @@ -116,16 +158,16 @@ class AresResolver : public Resolver, void onRead(int sockfd); void onTimer(); void onQueryResult(int status, - struct hostent* result, + struct ares_addrinfo* result, const std::string& hostname, - const Callback& callback); + const ResolverResultsCallback& callback); void onSockCreate(int sockfd, int type); void onSockStateChange(int sockfd, bool read, bool write); static void ares_hostcallback_(void* data, int status, int timeouts, - struct hostent* hostent); + struct ares_addrinfo* hostent); #ifdef _WIN32 static int ares_sock_createcallback_(SOCKET sockfd, int type, void* data); #else @@ -143,6 +185,7 @@ class AresResolver : public Resolver, { LibraryInitializer(); ~LibraryInitializer(); + ares_addrinfo_hints* hints_; }; static LibraryInitializer libraryInitializer_; }; diff --git a/trantor/net/inner/AsyncStreamBufferNode.cc b/trantor/net/inner/AsyncStreamBufferNode.cc new file mode 100644 index 00000000..c30db1fa --- /dev/null +++ b/trantor/net/inner/AsyncStreamBufferNode.cc @@ -0,0 +1,65 @@ +#include + +namespace trantor +{ +class AsyncBufferNode : public BufferNode +{ + public: + AsyncBufferNode() = default; + ~AsyncBufferNode() override = default; + bool isAsync() const override + { + return true; + } + bool isStream() const override + { + return true; + } + long long remainingBytes() const override + { + if (msgBufferPtr_) + return static_cast(msgBufferPtr_->readableBytes()); + return 0; + } + bool available() const override + { + return !isDone_; + } + void getData(const char *&data, size_t &len) override + { + if (msgBufferPtr_) + { + data = msgBufferPtr_->peek(); + len = msgBufferPtr_->readableBytes(); + } + else + { + data = nullptr; + len = 0; + } + } + void retrieve(size_t len) override + { + assert(msgBufferPtr_); + if (msgBufferPtr_) + { + msgBufferPtr_->retrieve(len); + } + } + void append(const char *data, size_t len) override + { + if (!msgBufferPtr_) + { + msgBufferPtr_ = std::make_unique(len); + } + msgBufferPtr_->append(data, len); + } + + private: + std::unique_ptr msgBufferPtr_; +}; +BufferNodePtr BufferNode::newAsyncStreamBufferNode() +{ + return std::make_shared(); +} +} // namespace trantor diff --git a/trantor/net/inner/BufferNode.h b/trantor/net/inner/BufferNode.h new file mode 100644 index 00000000..bbce40e8 --- /dev/null +++ b/trantor/net/inner/BufferNode.h @@ -0,0 +1,86 @@ +/** + * + * @file BufferNode.h + * @author An Tao + * + * Public header file in trantor lib. + * + * Copyright 2018, An Tao. All rights reserved. + * Use of this source code is governed by a BSD-style license + * that can be found in the License file. + * + * + */ + +#pragma once +#ifdef _WIN32 +#include +#endif +#include +#include +#include +#include +#include +#include + +namespace trantor +{ +class BufferNode; +using BufferNodePtr = std::shared_ptr; +using StreamCallback = std::function; +class BufferNode : public NonCopyable +{ + public: + virtual bool isFile() const + { + return false; + } + virtual ~BufferNode() = default; + virtual bool isStream() const + { + return false; + } + virtual void getData(const char *&data, size_t &len) = 0; + virtual void append(const char *, size_t) + { + LOG_FATAL << "Not a memory buffer node"; + } + virtual void retrieve(size_t len) = 0; + virtual long long remainingBytes() const = 0; + virtual int getFd() const + { + LOG_FATAL << "Not a file buffer node"; + return -1; + } + virtual bool available() const + { + return true; + } + virtual bool isAsync() const + { + return false; + } + + void done() + { + isDone_ = true; + } + static BufferNodePtr newMemBufferNode(); + + static BufferNodePtr newStreamBufferNode(StreamCallback &&cb); +#ifdef _WIN32 + static BufferNodePtr newFileBufferNode(const wchar_t *fileName, + long long offset, + long long length); +#else + static BufferNodePtr newFileBufferNode(const char *fileName, + long long offset, + long long length); +#endif + static BufferNodePtr newAsyncStreamBufferNode(); + + protected: + bool isDone_{false}; +}; + +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/inner/Connector.cc b/trantor/net/inner/Connector.cc index 1f139042..adbf1f09 100644 --- a/trantor/net/inner/Connector.cc +++ b/trantor/net/inner/Connector.cc @@ -79,6 +79,8 @@ void Connector::connect() { socketHanded_ = false; fd_ = Socket::createNonblockingSocketOrDie(serverAddr_.family()); + if (sockOptCallback_) + sockOptCallback_(fd_); errno = 0; int ret = Socket::connect(fd_, serverAddr_); int savedErrno = (ret == 0) ? 0 : errno; diff --git a/trantor/net/inner/Connector.h b/trantor/net/inner/Connector.h index 91bce8e8..3dca4847 100644 --- a/trantor/net/inner/Connector.h +++ b/trantor/net/inner/Connector.h @@ -28,6 +28,7 @@ class Connector : public NonCopyable, public: using NewConnectionCallback = std::function; using ConnectionErrorCallback = std::function; + using SockOptCallback = std::function; Connector(EventLoop *loop, const InetAddress &addr, bool retry = true); Connector(EventLoop *loop, InetAddress &&addr, bool retry = true); ~Connector(); @@ -47,6 +48,14 @@ class Connector : public NonCopyable, { errorCallback_ = std::move(cb); } + void setSockOptCallback(const SockOptCallback &cb) + { + sockOptCallback_ = cb; + } + void setSockOptCallback(SockOptCallback &&cb) + { + sockOptCallback_ = std::move(cb); + } const InetAddress &serverAddress() const { return serverAddr_; @@ -58,6 +67,7 @@ class Connector : public NonCopyable, private: NewConnectionCallback newConnectionCallback_; ConnectionErrorCallback errorCallback_; + SockOptCallback sockOptCallback_; enum class Status { Disconnected, diff --git a/trantor/net/inner/FileBufferNodeUnix.cc b/trantor/net/inner/FileBufferNodeUnix.cc new file mode 100644 index 00000000..9850f3f9 --- /dev/null +++ b/trantor/net/inner/FileBufferNodeUnix.cc @@ -0,0 +1,150 @@ +#include +#include +#include +#include +#include + +namespace trantor +{ +static const size_t kMaxSendFileBufferSize = 16 * 1024; +class FileBufferNode : public BufferNode +{ + public: + FileBufferNode(const char *fileName, long long offset, long long length) + { + assert(offset >= 0); + if (offset < 0) + { + LOG_ERROR << "offset must be greater than or equal to 0"; + isDone_ = true; + return; + } + sendFd_ = open(fileName, O_RDONLY); + + if (sendFd_ < 0) + { + LOG_SYSERR << fileName << " open error"; + isDone_ = true; + return; + } + struct stat filestat; + if (stat(fileName, &filestat) < 0) + { + LOG_SYSERR << fileName << " stat error"; + close(sendFd_); + sendFd_ = -1; + isDone_ = true; + return; + } + if (length == 0) + { + if (offset >= filestat.st_size) + { + LOG_ERROR << "The file size is " << filestat.st_size + << " bytes, but the offset is " << offset + << " bytes and the length is " << length << " bytes"; + close(sendFd_); + sendFd_ = -1; + isDone_ = true; + return; + } + fileBytesToSend_ = filestat.st_size - offset; + } + else + { + if (length > filestat.st_size - offset) + { + LOG_ERROR << "The file size is " << filestat.st_size + << " bytes, but the offset is " << offset + << " bytes and the length is " << length << " bytes"; + close(sendFd_); + sendFd_ = -1; + isDone_ = true; + return; + } + fileBytesToSend_ = length; + } + lseek(sendFd_, offset, SEEK_SET); + } + bool isFile() const override + { + return true; + } + int getFd() const override + { + return sendFd_; + } + void getData(const char *&data, size_t &len) override + { + if (msgBufferPtr_ == nullptr) + { + msgBufferPtr_ = std::make_unique( + (std::min)(kMaxSendFileBufferSize, + static_cast(fileBytesToSend_))); + } + if (msgBufferPtr_->readableBytes() == 0 && fileBytesToSend_ > 0 && + sendFd_ >= 0) + { + msgBufferPtr_->ensureWritableBytes( + (std::min)(kMaxSendFileBufferSize, + static_cast(fileBytesToSend_))); + auto n = read(sendFd_, + msgBufferPtr_->beginWrite(), + msgBufferPtr_->writableBytes()); + if (n > 0) + { + msgBufferPtr_->hasWritten(n); + } + else if (n == 0) + { + LOG_TRACE << "Read the end of file."; + } + else + { + LOG_SYSERR << "FileBufferNode::getData()"; + } + } + data = msgBufferPtr_->peek(); + len = msgBufferPtr_->readableBytes(); + } + void retrieve(size_t len) override + { + if (msgBufferPtr_) + { + msgBufferPtr_->retrieve(len); + } + fileBytesToSend_ -= static_cast(len); + if (fileBytesToSend_ < 0) + fileBytesToSend_ = 0; + } + long long remainingBytes() const override + { + if (isDone_) + return 0; + return fileBytesToSend_; + } + ~FileBufferNode() override + { + if (sendFd_ >= 0) + { + close(sendFd_); + } + } + bool available() const override + { + return sendFd_ >= 0; + } + + private: + int sendFd_{-1}; + long long fileBytesToSend_{0}; + std::unique_ptr msgBufferPtr_; +}; + +BufferNodePtr BufferNode::newFileBufferNode(const char *fileName, + long long offset, + long long length) +{ + return std::make_shared(fileName, offset, length); +} +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/inner/FileBufferNodeWin.cc b/trantor/net/inner/FileBufferNodeWin.cc new file mode 100644 index 00000000..d99b707a --- /dev/null +++ b/trantor/net/inner/FileBufferNodeWin.cc @@ -0,0 +1,171 @@ +#include +#include +#include +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define UWP 1 +#else +#define UWP 0 +#endif + +namespace trantor +{ +static const size_t kMaxSendFileBufferSize = 16 * 1024; +class FileBufferNode : public BufferNode +{ + public: + FileBufferNode(const wchar_t *fileName, long long offset, long long length) + { +#if UWP + sendHandle_ = CreateFile2( + fileName, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, nullptr); + +#else + sendHandle_ = CreateFileW(fileName, + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + nullptr); +#endif + if (sendHandle_ == INVALID_HANDLE_VALUE) + { + LOG_SYSERR << fileName << " open error"; + isDone_ = true; + return; + } + LARGE_INTEGER fileSize; + if (!GetFileSizeEx(sendHandle_, &fileSize)) + { + LOG_SYSERR << fileName << " stat error"; + CloseHandle(sendHandle_); + sendHandle_ = INVALID_HANDLE_VALUE; + isDone_ = true; + return; + } + + if (length == 0) + { + if (offset >= fileSize.QuadPart) + { + LOG_ERROR << "The file size is " << fileSize.QuadPart + << " bytes, but the offset is " << offset + << " bytes and the length is " << length << " bytes"; + CloseHandle(sendHandle_); + sendHandle_ = INVALID_HANDLE_VALUE; + isDone_ = true; + return; + } + fileBytesToSend_ = fileSize.QuadPart - offset; + } + else + { + if (length + offset > fileSize.QuadPart) + { + LOG_ERROR << "The file size is " << fileSize.QuadPart + << " bytes, but the offset is " << offset + << " bytes and the length is " << length << " bytes"; + CloseHandle(sendHandle_); + sendHandle_ = INVALID_HANDLE_VALUE; + isDone_ = true; + return; + } + + fileBytesToSend_ = length; + } + LARGE_INTEGER li; + li.QuadPart = offset; + if (!SetFilePointerEx(sendHandle_, li, nullptr, FILE_BEGIN)) + { + LOG_SYSERR << fileName << " seek error"; + CloseHandle(sendHandle_); + sendHandle_ = INVALID_HANDLE_VALUE; + isDone_ = true; + return; + } + msgBufferPtr_ = std::make_unique( + kMaxSendFileBufferSize < fileBytesToSend_ ? kMaxSendFileBufferSize + : fileBytesToSend_); + } + + bool isFile() const override + { + return true; + } + + void getData(const char *&data, size_t &len) override + { + if (msgBufferPtr_->readableBytes() == 0 && fileBytesToSend_ > 0 && + sendHandle_ != INVALID_HANDLE_VALUE) + { + msgBufferPtr_->ensureWritableBytes(kMaxSendFileBufferSize < + fileBytesToSend_ + ? kMaxSendFileBufferSize + : fileBytesToSend_); + DWORD n = 0; + if (!ReadFile(sendHandle_, + msgBufferPtr_->beginWrite(), + (uint32_t)msgBufferPtr_->writableBytes(), + &n, + nullptr)) + { + LOG_SYSERR << "FileBufferNode::getData()"; + } + if (n > 0) + { + msgBufferPtr_->hasWritten(n); + } + else if (n == 0) + { + LOG_TRACE << "Read the end of file."; + } + else + { + LOG_SYSERR << "FileBufferNode::getData()"; + } + } + data = msgBufferPtr_->peek(); + len = msgBufferPtr_->readableBytes(); + } + void retrieve(size_t len) override + { + msgBufferPtr_->retrieve(len); + fileBytesToSend_ -= static_cast(len); + if (fileBytesToSend_ < 0) + fileBytesToSend_ = 0; + } + long long remainingBytes() const override + { + if (isDone_) + return 0; + return fileBytesToSend_; + } + ~FileBufferNode() override + { + if (sendHandle_ != INVALID_HANDLE_VALUE) + { + CloseHandle(sendHandle_); + } + } + int getFd() const override + { + LOG_ERROR << "getFd() is not supported on Windows"; + return 0; + } + bool available() const override + { + return sendHandle_ != INVALID_HANDLE_VALUE; + } + + private: + HANDLE sendHandle_{INVALID_HANDLE_VALUE}; + long long fileBytesToSend_{0}; + std::unique_ptr msgBufferPtr_; +}; +BufferNodePtr BufferNode::newFileBufferNode(const wchar_t *fileName, + long long offset, + long long length) +{ + return std::make_shared(fileName, offset, length); +} +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/inner/MemBufferNode.cc b/trantor/net/inner/MemBufferNode.cc new file mode 100644 index 00000000..c6cd2479 --- /dev/null +++ b/trantor/net/inner/MemBufferNode.cc @@ -0,0 +1,36 @@ +#include +namespace trantor +{ +class MemBufferNode : public BufferNode +{ + public: + MemBufferNode() = default; + + void getData(const char *&data, size_t &len) override + { + data = buffer_.peek(); + len = buffer_.readableBytes(); + } + void retrieve(size_t len) override + { + buffer_.retrieve(len); + } + long long remainingBytes() const override + { + if (isDone_) + return 0; + return static_cast(buffer_.readableBytes()); + } + void append(const char *data, size_t len) override + { + buffer_.append(data, len); + } + + private: + trantor::MsgBuffer buffer_; +}; +BufferNodePtr BufferNode::newMemBufferNode() +{ + return std::make_shared(); +} +} // namespace trantor diff --git a/trantor/net/inner/NormalResolver.h b/trantor/net/inner/NormalResolver.h index bc19e0f9..1a5a26f5 100644 --- a/trantor/net/inner/NormalResolver.h +++ b/trantor/net/inner/NormalResolver.h @@ -23,6 +23,13 @@ class NormalResolver : public Resolver, public: virtual void resolve(const std::string& hostname, const Callback& callback) override; + virtual void resolve(const std::string& hostname, + const ResolverResultsCallback& callback) override + { + resolve(hostname, [callback](const trantor::InetAddress& inet) { + callback(std::vector{inet}); + }); + } explicit NormalResolver(size_t timeout) : timeout_(timeout), resolveBuffer_(kResolveBufferLength) { diff --git a/trantor/net/inner/StreamBufferNode.cc b/trantor/net/inner/StreamBufferNode.cc new file mode 100644 index 00000000..0376ac9c --- /dev/null +++ b/trantor/net/inner/StreamBufferNode.cc @@ -0,0 +1,67 @@ +#include +namespace trantor +{ +static const size_t kMaxSendFileBufferSize = 16 * 1024; +class StreamBufferNode : public BufferNode +{ + public: + StreamBufferNode(std::function &&callback) + : streamCallback_(std::move(callback)) + { + } + bool isStream() const override + { + return true; + } + void getData(const char *&data, size_t &len) override + { + if (msgBuffer_.readableBytes() == 0) + { + msgBuffer_.ensureWritableBytes(kMaxSendFileBufferSize); + auto n = streamCallback_(msgBuffer_.beginWrite(), + msgBuffer_.writableBytes()); + if (n > 0) + { + msgBuffer_.hasWritten(n); + } + else + { + isDone_ = true; + } + } + data = msgBuffer_.peek(); + len = msgBuffer_.readableBytes(); + } + void retrieve(size_t len) override + { + msgBuffer_.retrieve(len); +#ifndef NDEBUG + dataWritten_ += len; + LOG_TRACE << "send stream in loop: bytes written: " << dataWritten_ + << " / total bytes written: " << dataWritten_; +#endif + } + long long remainingBytes() const override + { + if (isDone_) + return 0; + return 1; + } + ~StreamBufferNode() override + { + if (streamCallback_) + streamCallback_(nullptr, 0); // cleanup callback internals + } + + private: + std::function streamCallback_; +#ifndef NDEBUG // defined by CMake for release build + std::size_t dataWritten_{0}; +#endif + MsgBuffer msgBuffer_; +}; +BufferNodePtr BufferNode::newStreamBufferNode(StreamCallback &&callback) +{ + return std::make_shared(std::move(callback)); +} +} // namespace trantor \ No newline at end of file diff --git a/trantor/net/inner/TLSProvider.h b/trantor/net/inner/TLSProvider.h new file mode 100644 index 00000000..7d6899ad --- /dev/null +++ b/trantor/net/inner/TLSProvider.h @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace trantor +{ +struct TLSProvider +{ + TLSProvider(TcpConnection* conn, TLSPolicyPtr policy, SSLContextPtr ctx) + : conn_(conn), + policyPtr_(std::move(policy)), + contextPtr_(std::move(ctx)), + loop_(conn_->getLoop()) + { + } + virtual ~TLSProvider() = default; + using WriteCallback = ssize_t (*)(TcpConnection*, + const void* data, + size_t len); + using ErrorCallback = void (*)(TcpConnection*, SSLError err); + using HandshakeCallback = void (*)(TcpConnection*); + using MessageCallback = void (*)(TcpConnection*, MsgBuffer* buffer); + using CloseCallback = void (*)(TcpConnection*); + + /** + * @brief Sends data to the TLSProvider to process handshake and decrypt + * data + */ + virtual void recvData(MsgBuffer* buffer) = 0; + + /** + * @brief Encrypt and send data via TLS + * @return the number of bytes sent, or -1 on error, or 0 if EAGAIN or + * EWOULDBLOCK. + */ + virtual ssize_t sendData(const char* ptr, size_t size) = 0; + + /** + * @brief Close the TLS connection + */ + virtual void close() = 0; + + virtual void startEncryption() = 0; + + bool sendBufferedData() + { + if (writeBuffer_.readableBytes() == 0) + return true; + + auto n = writeCallback_(conn_, + writeBuffer_.peek(), + writeBuffer_.readableBytes()); + if (n == -1) + { + LOG_ERROR << "WTF! Failed to send buffered data. Error: " + << strerror(errno); + return false; + } + else if ((size_t)n != writeBuffer_.readableBytes()) + { + writeBuffer_.retrieve(n); + return false; + } + + writeBuffer_.retrieveAll(); + return true; + } + + MsgBuffer& getBufferedData() + { + return writeBuffer_; + } + + void appendToWriteBuffer(const char* ptr, size_t size) + { + writeBuffer_.ensureWritableBytes(size); + writeBuffer_.append(ptr, size); + } + + /** + * @brief Set a function to be called when the TLSProvider wants to send + * data + * + * @note The caller MUST guarantee that it will not make the TLSProvider + * send data after caller is destroyed. std::function used due to + * performance reasons. + */ + void setWriteCallback(WriteCallback cb) + { + writeCallback_ = cb; + } + + void setErrorCallback(ErrorCallback cb) + { + errorCallback_ = cb; + } + + void setHandshakeCallback(HandshakeCallback cb) + { + handshakeCallback_ = cb; + } + + void setMessageCallback(MessageCallback cb) + { + messageCallback_ = cb; + } + void setCloseCallback(CloseCallback cb) + { + closeCallback_ = cb; + } + + MsgBuffer& getRecvBuffer() + { + return recvBuffer_; + } + + const CertificatePtr& peerCertificate() const + { + return peerCertificate_; + } + + const std::string& applicationProtocol() const + { + return applicationProtocol_; + } + + const std::string& sniName() const + { + return sniName_; + } + + protected: + void setPeerCertificate(CertificatePtr cert) + { + peerCertificate_ = std::move(cert); + } + + void setApplicationProtocol(std::string protocol) + { + applicationProtocol_ = std::move(protocol); + } + + void setSniName(std::string name) + { + sniName_ = std::move(name); + } + + WriteCallback writeCallback_ = nullptr; + ErrorCallback errorCallback_ = nullptr; + HandshakeCallback handshakeCallback_ = nullptr; + MessageCallback messageCallback_ = nullptr; + CloseCallback closeCallback_ = nullptr; + TcpConnection* conn_ = nullptr; + const TLSPolicyPtr policyPtr_; + const SSLContextPtr contextPtr_; + MsgBuffer recvBuffer_; + EventLoop* loop_ = nullptr; + CertificatePtr peerCertificate_; + std::string applicationProtocol_; + std::string sniName_; + MsgBuffer writeBuffer_; +}; + +std::shared_ptr newTLSProvider(TcpConnection* conn, + TLSPolicyPtr policy, + SSLContextPtr ctx); +} // namespace trantor diff --git a/trantor/net/inner/TcpConnectionImpl.cc b/trantor/net/inner/TcpConnectionImpl.cc index 22f90c93..b786c762 100644 --- a/trantor/net/inner/TcpConnectionImpl.cc +++ b/trantor/net/inner/TcpConnectionImpl.cc @@ -18,22 +18,13 @@ #include #ifdef __linux__ #include +#include #endif #include #ifndef _WIN32 #include -#else -#include -#include -#include -#endif -#include -#include -#ifdef USE_OPENSSL -#include -#include -#include #endif + using namespace trantor; #ifdef _WIN32 @@ -46,399 +37,38 @@ using namespace trantor; #undef ECONNRESET #define ECONNRESET WSAECONNRESET #endif - -#ifdef USE_OPENSSL -namespace trantor -{ -namespace internal -{ -#ifdef _WIN32 -// Code yanked from stackoverflow -// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store -inline bool loadWindowsSystemCert(X509_STORE *store) -{ - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - - if (!hStore) - { - return false; - } - - PCCERT_CONTEXT pContext = NULL; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) - { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); - - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) - { - X509_STORE_add_cert(store, x509); - X509_free(x509); - } - } - - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); - - return true; -} -#endif - -inline bool verifyCommonName(X509 *cert, const std::string &hostname) +static inline bool isEAGAIN() { - X509_NAME *subjectName = X509_get_subject_name(cert); - - if (subjectName != nullptr) + if (errno == EWOULDBLOCK || errno == EAGAIN || errno == 0) { - std::array name; - auto length = X509_NAME_get_text_by_NID(subjectName, - NID_commonName, - name.data(), - (int)name.size()); - if (length == -1) - return false; - - return utils::verifySslName(std::string(name.begin(), - name.begin() + length), - hostname); - } - - return false; -} - -inline bool verifyAltName(X509 *cert, const std::string &hostname) -{ - bool good = false; - auto altNames = static_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); - - if (altNames) - { - int numNames = sk_GENERAL_NAME_num(altNames); - - for (int i = 0; i < numNames && !good; i++) - { - auto val = sk_GENERAL_NAME_value(altNames, i); - if (val->type != GEN_DNS) - { - LOG_WARN << "Name using IP addresses are not supported. Open " - "an issue if you need that feature"; - continue; - } -#if (OPENSSL_VERSION_NUMBER >= 0x10100000L) - auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); -#else - auto name = (const char *)ASN1_STRING_data(val->d.ia5); -#endif - auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); - good = utils::verifySslName(std::string(name, name + name_len), - hostname); - } + LOG_TRACE << "write buffer is full"; + return true; } - - GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)altNames); - return good; -} - -} // namespace internal - -void initOpenSSL() -{ -#if (OPENSSL_VERSION_NUMBER < 0x10100000L) || \ - (defined(LIBRESSL_VERSION_NUMBER) && \ - LIBRESSL_VERSION_NUMBER < 0x20700000L) - // Initialize OpenSSL once; - static std::once_flag once; - std::call_once(once, []() { - SSL_library_init(); - ERR_load_crypto_strings(); - SSL_load_error_strings(); - OpenSSL_add_all_algorithms(); - }); -#endif -} - -class SSLContext -{ - public: - explicit SSLContext( - bool useOldTLS, - bool enableValidtion, - const std::vector> &sslConfCmds) + else if (errno == EPIPE || errno == ECONNRESET) { -#ifdef LIBRESSL_VERSION_NUMBER - ctxPtr_ = SSL_CTX_new(TLS_method()); - if (sslConfCmds.size() != 0) - { - LOG_WARN << "LibreSSL does not support SSL configuration commands"; - } - if (!useOldTLS) - { - SSL_CTX_set_min_proto_version(ctxPtr_, TLS1_2_VERSION); - } -#elif (OPENSSL_VERSION_NUMBER >= 0x10100000L) - ctxPtr_ = SSL_CTX_new(TLS_method()); - SSL_CONF_CTX *cctx = SSL_CONF_CTX_new(); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_SERVER); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CLIENT); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_FILE); - SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_); - for (const auto &cmd : sslConfCmds) - { - SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data()); - } - SSL_CONF_CTX_finish(cctx); - SSL_CONF_CTX_free(cctx); - if (!useOldTLS) - { - SSL_CTX_set_min_proto_version(ctxPtr_, TLS1_2_VERSION); - } - else - { - LOG_WARN << "TLS 1.0/1.1 are enabled. They are considered " - "obsolete, insecure standards and should only be " - "used for legacy purpose."; - } -#else - ctxPtr_ = SSL_CTX_new(SSLv23_method()); - SSL_CONF_CTX *cctx = SSL_CONF_CTX_new(); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_SERVER); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CLIENT); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE); - SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_FILE); - SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_); - for (const auto &cmd : sslConfCmds) - { - SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data()); - } - SSL_CONF_CTX_finish(cctx); - SSL_CONF_CTX_free(cctx); - if (!useOldTLS) - { - SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1); - } - else - { - LOG_WARN << "TLS 1.0/1.1 are enabled. They are considered " - "obsolete, insecure standards and should only be " - "used for legacy purpose."; - } -#endif #ifdef _WIN32 - if (enableValidtion) - internal::loadWindowsSystemCert(SSL_CTX_get_cert_store(ctxPtr_)); + LOG_TRACE << "WSAENOTCONN or WSAECONNRESET, errno=" << errno; #else - if (enableValidtion) - SSL_CTX_set_default_verify_paths(ctxPtr_); + LOG_TRACE << "EPIPE or ECONNRESET, errno=" << errno; #endif - } - ~SSLContext() - { - if (ctxPtr_) - { - SSL_CTX_free(ctxPtr_); - } - } - - SSL_CTX *get() - { - return ctxPtr_; - } - bool mtlsEnabled = false; - - private: - SSL_CTX *ctxPtr_; -}; -class SSLConn -{ - public: - explicit SSLConn(SSL_CTX *ctx, bool mtlsEnabled_) - { - SSL_ = SSL_new(ctx); - mtlsEnabled = mtlsEnabled_; - } - ~SSLConn() - { - if (SSL_) - { - SSL_free(SSL_); - } - } - SSL *get() - { - return SSL_; - } - bool mtlsEnabled = false; - - private: - SSL *SSL_; -}; - -std::shared_ptr newSSLContext( - bool useOldTLS, - bool validateCert, - const std::vector> &sslConfCmds) -{ // init OpenSSL - initOpenSSL(); - return std::make_shared(useOldTLS, validateCert, sslConfCmds); -} -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector> &sslConfCmds, - const std::string &caPath) -{ - auto ctx = newSSLContext(useOldTLS, false, sslConfCmds); - auto r = SSL_CTX_use_certificate_chain_file(ctx->get(), certPath.c_str()); - char errbuf[BUFSIZ]; - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Reading certificate: " << certPath - << " failed. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_use_certificate_chain_file error."); - } - r = SSL_CTX_use_PrivateKey_file(ctx->get(), - keyPath.c_str(), - SSL_FILETYPE_PEM); - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Reading private key: " << keyPath - << " failed. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_use_PrivateKey_file error"); - } - r = SSL_CTX_check_private_key(ctx->get()); - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Checking private key matches certificate: " << certPath - << " and " << keyPath << " mismatches. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_check_private_key error"); - } - - if (!SSL_CTX_set_ecdh_auto(ctx->get(), 1)) - { - LOG_TRACE << "Failed to set_ecdh_auto, set_ecdh_auto DISABLED"; + LOG_TRACE << "send node in loop: return on connection closed"; + return false; } else { - LOG_TRACE << "set_ecdh_auto ENABLED"; - } - - if (!caPath.empty()) - { - auto checkCA = - SSL_CTX_load_verify_locations(ctx->get(), caPath.c_str(), NULL); - if (checkCA) - { - STACK_OF(X509_NAME) *cert_names = - SSL_load_client_CA_file(caPath.c_str()); - if (cert_names != NULL) - { - SSL_CTX_set_client_CA_list(ctx->get(), cert_names); - } - ctx->mtlsEnabled = true; - LOG_TRACE << "mTLS session ENABLED"; - } - else - { - LOG_FATAL << "caPath location error "; - throw std::runtime_error("SSL_CTX_load_verify_locations error"); - } - } - - return ctx; -} -std::shared_ptr newSSLClientContext( - bool useOldTLS, - bool validateCert, - const std::string &certPath, - const std::string &keyPath, - const std::vector> &sslConfCmds, - const std::string &caPath) -{ - auto ctx = newSSLContext(useOldTLS, validateCert, sslConfCmds); - if (certPath.empty() || keyPath.empty()) - return ctx; - - auto r = SSL_CTX_use_certificate_chain_file(ctx->get(), certPath.c_str()); - char errbuf[BUFSIZ]; - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Reading certificate: " << certPath - << " failed. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_use_certificate_chain_file error."); - } - r = SSL_CTX_use_PrivateKey_file(ctx->get(), - keyPath.c_str(), - SSL_FILETYPE_PEM); - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Reading private key: " << keyPath - << " failed. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_use_PrivateKey_file error"); - } - r = SSL_CTX_check_private_key(ctx->get()); - if (!r) - { - ERR_error_string_n(ERR_get_error(), errbuf, sizeof(errbuf)); - LOG_FATAL << "Checking private key matches certificate: " << certPath - << " and " << keyPath << " mismatches. Error: " << errbuf; - throw std::runtime_error("SSL_CTX_check_private_key error"); - } - - if (!caPath.empty()) - { - auto checkCA = - SSL_CTX_load_verify_locations(ctx->get(), caPath.c_str(), NULL); - LOG_TRACE << "CA CHECK LOC: " << checkCA; - if (checkCA) - { - STACK_OF(X509_NAME) *cert_names = - SSL_load_client_CA_file(caPath.c_str()); - if (cert_names != NULL) - { - SSL_CTX_set_client_CA_list(ctx->get(), cert_names); - } - ctx->mtlsEnabled = true; - } - else - { - LOG_FATAL << "caPath location error "; - throw std::runtime_error("SSL_CTX_load_verify_locations error"); - } + // TODO: any others? + LOG_SYSERR << "send node in loop: return on unexpected error(" << errno + << ")"; + return false; } - - return ctx; } -} // namespace trantor -#else -namespace trantor -{ -std::shared_ptr newSSLServerContext( - const std::string &, - const std::string &, - bool, - const std::vector> &, - const std::string &) -{ - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -} -} // namespace trantor -#endif TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, - const InetAddress &peerAddr) + const InetAddress &peerAddr, + TLSPolicyPtr policy, + SSLContextPtr ctx) : loop_(loop), ioChannelPtr_(new Channel(loop, socketfd)), socketPtr_(new Socket(socketfd)), @@ -447,226 +77,87 @@ TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, { LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" << localAddr.toIpPort(); - ioChannelPtr_->setReadCallback( - std::bind(&TcpConnectionImpl::readCallback, this)); - ioChannelPtr_->setWriteCallback( - std::bind(&TcpConnectionImpl::writeCallback, this)); - ioChannelPtr_->setCloseCallback( - std::bind(&TcpConnectionImpl::handleClose, this)); - ioChannelPtr_->setErrorCallback( - std::bind(&TcpConnectionImpl::handleError, this)); + ioChannelPtr_->setReadCallback([this]() { readCallback(); }); + ioChannelPtr_->setWriteCallback([this]() { writeCallback(); }); + ioChannelPtr_->setCloseCallback([this]() { handleClose(); }); + ioChannelPtr_->setErrorCallback([this]() { handleError(); }); socketPtr_->setKeepAlive(true); name_ = localAddr.toIpPort() + "--" + peerAddr.toIpPort(); -} -TcpConnectionImpl::~TcpConnectionImpl() -{ -} -#ifdef USE_OPENSSL -void TcpConnectionImpl::startClientEncryptionInLoop( - std::function &&callback, - bool useOldTLS, - bool validateCert, - const std::string &hostname, - const std::vector> &sslConfCmds) -{ - validateCert_ = validateCert; - loop_->assertInLoopThread(); - if (isEncrypted_) - { - LOG_WARN << "This connection is already encrypted"; - return; - } - sslEncryptionPtr_ = std::make_unique(); - sslEncryptionPtr_->upgradeCallback_ = std::move(callback); - sslEncryptionPtr_->sslCtxPtr_ = - newSSLContext(useOldTLS, validateCert_, sslConfCmds); - sslEncryptionPtr_->sslPtr_ = - std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get(), - sslEncryptionPtr_->sslCtxPtr_->mtlsEnabled); - if (validateCert || sslEncryptionPtr_->sslPtr_->mtlsEnabled) - { - LOG_TRACE << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled; - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - sslEncryptionPtr_->sslPtr_->mtlsEnabled - ? SSL_VERIFY_PEER - : SSL_VERIFY_NONE, - nullptr); - validateCert_ = validateCert; - } - if (!hostname.empty()) + + if (policy != nullptr) { - SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(), - hostname.data()); - sslEncryptionPtr_->hostname_ = hostname; + tlsProviderPtr_ = + newTLSProvider(this, std::move(policy), std::move(ctx)); + tlsProviderPtr_->setWriteCallback(onSslWrite); + tlsProviderPtr_->setErrorCallback(onSslError); + tlsProviderPtr_->setHandshakeCallback(onHandshakeFinished); + tlsProviderPtr_->setMessageCallback(onSslMessage); + // This is triggered when peer sends a close alert + tlsProviderPtr_->setCloseCallback(onSslCloseAlert); } - isEncrypted_ = true; - sslEncryptionPtr_->isUpgrade_ = true; - auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd()); - (void)r; - assert(r); - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique>(); - LOG_TRACE << "connectEstablished"; - ioChannelPtr_->enableWriting(); - SSL_set_connect_state(sslEncryptionPtr_->sslPtr_->get()); } -void TcpConnectionImpl::startServerEncryptionInLoop( - const std::shared_ptr &ctx, - std::function &&callback) +TcpConnectionImpl::~TcpConnectionImpl() { - loop_->assertInLoopThread(); - if (isEncrypted_) + std::size_t readableTlsBytes = 0; + if (tlsProviderPtr_) { - LOG_WARN << "This connection is already encrypted"; - return; + readableTlsBytes = tlsProviderPtr_->getBufferedData().readableBytes(); } - sslEncryptionPtr_ = std::make_unique(); - sslEncryptionPtr_->upgradeCallback_ = std::move(callback); - sslEncryptionPtr_->sslCtxPtr_ = ctx; - sslEncryptionPtr_->isServer_ = true; - sslEncryptionPtr_->sslPtr_ = - std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get(), - sslEncryptionPtr_->sslCtxPtr_->mtlsEnabled); - isEncrypted_ = true; - sslEncryptionPtr_->isUpgrade_ = true; - if (sslEncryptionPtr_->isServer_ == false || - sslEncryptionPtr_->sslPtr_->mtlsEnabled) + if (!writeBufferList_.empty()) { - LOG_TRACE << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled; - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - sslEncryptionPtr_->sslPtr_->mtlsEnabled - ? SSL_VERIFY_PEER - : SSL_VERIFY_NONE, - nullptr); + LOG_DEBUG << "write node list size: " << writeBufferList_.size() + << " first node is file? " + << writeBufferList_.front()->isFile() + << " first node is stream? " + << writeBufferList_.front()->isStream() + << " first node is async? " + << writeBufferList_.front()->isAsync() << " first node size: " + << writeBufferList_.front()->remainingBytes() + << " buffered TLS data size: " << readableTlsBytes; } - - auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd()); - (void)r; - assert(r); - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique>(); - LOG_TRACE << "upgrade to ssl"; - SSL_set_accept_state(sslEncryptionPtr_->sslPtr_->get()); -} -#endif -void TcpConnectionImpl::startServerEncryption( - const std::shared_ptr &ctx, - std::function callback) -{ -#ifndef USE_OPENSSL - // When not using OpenSSL, using `void` here will - // work around the unused parameter warnings without overhead. - (void)ctx; - (void)callback; - - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#else - if (loop_->isInLoopThread()) - { - startServerEncryptionInLoop(ctx, std::move(callback)); - } - else + else if (readableTlsBytes != 0) { - loop_->queueInLoop([thisPtr = shared_from_this(), - ctx, - callback = std::move(callback)]() mutable { - thisPtr->startServerEncryptionInLoop(ctx, std::move(callback)); - }); + LOG_DEBUG << "write node list size: 0 buffered TLS data size: " + << readableTlsBytes; } - -#endif + // send a close alert to peer if we are still connected + if (tlsProviderPtr_ && status_ == ConnStatus::Connected) + tlsProviderPtr_->close(); } -void TcpConnectionImpl::startClientEncryption( - std::function callback, - bool useOldTLS, - bool validateCert, - std::string hostname, - const std::vector> &sslConfCmds) + +void TcpConnectionImpl::readCallback() { -#ifndef USE_OPENSSL - // When not using OpenSSL, using `void` here will - // work around the unused parameter warnings without overhead. - (void)callback; - (void)useOldTLS; - (void)validateCert; - (void)hostname; - (void)sslConfCmds; + // LOG_TRACE<<"read Callback"; + loop_->assertInLoopThread(); + int ret = 0; - LOG_FATAL << "OpenSSL is not found in your system!"; - throw std::runtime_error("OpenSSL is not found in your system!"); -#else - if (!hostname.empty()) + ssize_t n = readBuffer_.readFd(socketPtr_->fd(), &ret); + // LOG_TRACE<<"read "<hostname_ = hostname; + // socket closed by peer + handleClose(); } - if (loop_->isInLoopThread()) + else if (n < 0) { - startClientEncryptionInLoop(std::move(callback), - useOldTLS, - validateCert, - hostname, - sslConfCmds); - } - else - { - loop_->queueInLoop([thisPtr = shared_from_this(), - callback = std::move(callback), - useOldTLS, - hostname = std::move(hostname), - validateCert, - &sslConfCmds]() mutable { - thisPtr->startClientEncryptionInLoop(std::move(callback), - useOldTLS, - validateCert, - hostname, - sslConfCmds); - }); - } -#endif -} -void TcpConnectionImpl::readCallback() -{ -// LOG_TRACE<<"read Callback"; -#ifdef USE_OPENSSL - if (!isEncrypted_) - { -#endif - loop_->assertInLoopThread(); - int ret = 0; - - ssize_t n = readBuffer_.readFd(socketPtr_->fd(), &ret); - // LOG_TRACE<<"read "< 0) - { - bytesReceived_ += n; - if (recvMsgCallback_) - { - recvMsgCallback_(shared_from_this(), &readBuffer_); - } - } -#ifdef USE_OPENSSL + LOG_SYSERR << "read socket error"; + handleClose(); + return; } - else + extendLife(); + if (n > 0) { - LOG_TRACE << "read Callback"; - loop_->assertInLoopThread(); - if (sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Handshaking) + bytesReceived_ += n; + if (tlsProviderPtr_) { - doHandshaking(); - return; + tlsProviderPtr_->recvData(&readBuffer_); } - else if (sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Connected) + else if (recvMsgCallback_) { - int rd; - bool newDataFlag = false; - size_t readLength; - do - { - readBuffer_.ensureWritableBytes(1024); - readLength = readBuffer_.writableBytes(); - rd = SSL_read(sslEncryptionPtr_->sslPtr_->get(), - readBuffer_.beginWrite(), - static_cast(readLength)); - LOG_TRACE << "ssl read:" << rd << " bytes"; - if (rd <= 0) - { - int sslerr = - SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), rd); - if (sslerr == SSL_ERROR_WANT_READ) - { - break; - } - else - { - LOG_TRACE << "ssl read err:" << sslerr; - sslEncryptionPtr_->statusOfSSL_ = - SSLStatus::DisConnected; - handleClose(); - return; - } - } - readBuffer_.hasWritten(rd); - newDataFlag = true; - } while ((size_t)rd == readLength); - if (newDataFlag) - { - // Run callback function - recvMsgCallback_(shared_from_this(), &readBuffer_); - } + recvMsgCallback_(shared_from_this(), &readBuffer_); } } -#endif } void TcpConnectionImpl::extendLife() { @@ -760,218 +203,83 @@ void TcpConnectionImpl::extendLife() } void TcpConnectionImpl::writeCallback() { -#ifdef USE_OPENSSL - if (!isEncrypted_ || - (sslEncryptionPtr_ && - sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Connected)) + loop_->assertInLoopThread(); + if (ioChannelPtr_->isWriting()) { -#endif - loop_->assertInLoopThread(); - extendLife(); - if (ioChannelPtr_->isWriting()) + if (tlsProviderPtr_) + { + bool sentAll = tlsProviderPtr_->sendBufferedData(); + if (!sentAll) + { + return; + } + } + while (!writeBufferList_.empty()) { - assert(!writeBufferList_.empty()); - auto writeBuffer_ = writeBufferList_.front(); - if (!writeBuffer_->isFile()) + auto &nodePtr = writeBufferList_.front(); + if (nodePtr->remainingBytes() == 0) { - // not a file - if (writeBuffer_->msgBuffer_->readableBytes() <= 0) + if (!nodePtr->isAsync() || !nodePtr->available()) { // finished sending writeBufferList_.pop_front(); - if (writeBufferList_.empty()) - { - // stop writing - ioChannelPtr_->disableWriting(); - if (writeCompleteCallback_) - writeCompleteCallback_(shared_from_this()); - if (status_ == ConnStatus::Disconnecting) - { - socketPtr_->closeWrite(); - } - } - else - { - // send next - // what if the next is not a file??? - auto fileNode = writeBufferList_.front(); - assert(fileNode->isFile()); - sendFileInLoop(fileNode); - } } else { - // continue sending - auto n = - writeInLoop(writeBuffer_->msgBuffer_->peek(), - writeBuffer_->msgBuffer_->readableBytes()); - if (n >= 0) - { - writeBuffer_->msgBuffer_->retrieve(n); - } - else - { -#ifdef _WIN32 - if (errno != 0 && errno != EWOULDBLOCK) -#else - if (errno != EWOULDBLOCK) -#endif - { - // TODO: any others? - if (errno == EPIPE || errno == ECONNRESET) - { -#ifdef _WIN32 - LOG_TRACE - << "WSAENOTCONN or WSAECONNRESET, errno=" - << errno; -#else - LOG_TRACE << "EPIPE or ECONNRESET, errno=" << errno; -#endif - return; - } - LOG_SYSERR << "Unexpected error(" << errno << ")"; - return; - } - } + // the first node is an async node and is available + ioChannelPtr_->disableWriting(); + return; } } else { - // is a file - if (writeBuffer_->fileBytesToSend_ <= 0) - { - // finished sending - writeBufferList_.pop_front(); - if (writeBufferList_.empty()) - { - // stop writing - ioChannelPtr_->disableWriting(); - if (writeCompleteCallback_) - writeCompleteCallback_(shared_from_this()); - if (status_ == ConnStatus::Disconnecting) - { - socketPtr_->closeWrite(); - } - } - else - { - // next is not a file - if (!writeBufferList_.front()->isFile()) - { - // There is data to be sent in the buffer. - auto n = writeInLoop( - writeBufferList_.front()->msgBuffer_->peek(), - writeBufferList_.front() - ->msgBuffer_->readableBytes()); - if (n >= 0) - { - writeBufferList_.front()->msgBuffer_->retrieve( - n); - } - else - { -#ifdef _WIN32 - if (errno != 0 && errno != EWOULDBLOCK) -#else - if (errno != EWOULDBLOCK) -#endif - { - // TODO: any others? - if (errno == EPIPE || errno == ECONNRESET) - { -#ifdef _WIN32 - LOG_TRACE << "WSAENOTCONN or " - "WSAECONNRESET, errno=" - << errno; -#else - LOG_TRACE << "EPIPE or " - "ECONNRESET, erron=" - << errno; -#endif - return; - } - LOG_SYSERR << "Unexpected error(" << errno - << ")"; - return; - } - } - } - else - { - // next is a file - sendFileInLoop(writeBufferList_.front()); - } - } - } - else - { - sendFileInLoop(writeBuffer_); - } + // continue sending + auto n = sendNodeInLoop(nodePtr); + if (nodePtr->remainingBytes() > 0 || n < 0) + return; } } - else + assert(writeBufferList_.empty()); + if (tlsProviderPtr_ == nullptr || + tlsProviderPtr_->getBufferedData().readableBytes() == 0) { - LOG_SYSERR << "no writing but write callback called"; + ioChannelPtr_->disableWriting(); + if (closeOnEmpty_) + { + shutdown(); + } } -#ifdef USE_OPENSSL } else { - LOG_TRACE << "write Callback"; - loop_->assertInLoopThread(); - if (sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Handshaking) - { - doHandshaking(); - return; - } + LOG_SYSERR << "no writing but write callback called"; } -#endif } void TcpConnectionImpl::connectEstablished() { -// loop_->assertInLoopThread(); -#ifdef USE_OPENSSL - if (!isEncrypted_) - { -#endif - auto thisPtr = shared_from_this(); - loop_->runInLoop([thisPtr]() { - LOG_TRACE << "connectEstablished"; - assert(thisPtr->status_ == ConnStatus::Connecting); - thisPtr->ioChannelPtr_->tie(thisPtr); - thisPtr->ioChannelPtr_->enableReading(); - thisPtr->status_ = ConnStatus::Connected; - if (thisPtr->connectionCallback_) - thisPtr->connectionCallback_(thisPtr); - }); -#ifdef USE_OPENSSL - } - else - { - loop_->runInLoop([thisPtr = shared_from_this()]() { - LOG_TRACE << "connectEstablished"; - assert(thisPtr->status_ == ConnStatus::Connecting); - thisPtr->ioChannelPtr_->tie(thisPtr); - thisPtr->ioChannelPtr_->enableReading(); - thisPtr->status_ = ConnStatus::Connected; - if (thisPtr->sslEncryptionPtr_->isServer_) - { - SSL_set_accept_state( - thisPtr->sslEncryptionPtr_->sslPtr_->get()); - } - else - { - thisPtr->ioChannelPtr_->enableWriting(); - SSL_set_connect_state( - thisPtr->sslEncryptionPtr_->sslPtr_->get()); - } - }); - } -#endif + auto thisPtr = shared_from_this(); + loop_->runInLoop([thisPtr]() { + LOG_TRACE << "connectEstablished"; + assert(thisPtr->status_ == ConnStatus::Connecting); + thisPtr->ioChannelPtr_->tie(thisPtr); + thisPtr->ioChannelPtr_->enableReading(); + thisPtr->status_ = ConnStatus::Connected; + + if (thisPtr->tlsProviderPtr_) + thisPtr->tlsProviderPtr_->startEncryption(); + else if (thisPtr->connectionCallback_) + thisPtr->connectionCallback_(thisPtr); + }); } void TcpConnectionImpl::handleClose() { LOG_TRACE << "connection closed, fd=" << socketPtr_->fd(); + LOG_TRACE << "write buffer size: " << writeBufferList_.size(); + if (!writeBufferList_.empty()) + { + LOG_TRACE << writeBufferList_.front()->isFile(); + LOG_TRACE << writeBufferList_.front()->isStream(); + } loop_->assertInLoopThread(); status_ = ConnStatus::Disconnected; ioChannelPtr_->disableAll(); @@ -1027,6 +335,25 @@ void TcpConnectionImpl::shutdown() loop_->runInLoop([thisPtr]() { if (thisPtr->status_ == ConnStatus::Connected) { + if (thisPtr->tlsProviderPtr_) + { + // there's still data to be sent, so we can't close the + // connection just yet + if (thisPtr->tlsProviderPtr_->getBufferedData() + .readableBytes() != 0 || + !thisPtr->writeBufferList_.empty()) + { + thisPtr->closeOnEmpty_ = true; + return; + } + thisPtr->tlsProviderPtr_->close(); + } + if (thisPtr->tlsProviderPtr_ == nullptr && + !thisPtr->writeBufferList_.empty()) + { + thisPtr->closeOnEmpty_ = true; + return; + } thisPtr->status_ = ConnStatus::Disconnecting; if (!thisPtr->ioChannelPtr_->isWriting()) { @@ -1045,6 +372,9 @@ void TcpConnectionImpl::forceClose() { thisPtr->status_ = ConnStatus::Disconnecting; thisPtr->handleClose(); + + if (thisPtr->tlsProviderPtr_) + thisPtr->tlsProviderPtr_->close(); } }); } @@ -1057,11 +387,9 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) loop_->assertInLoopThread(); if (status_ != ConnStatus::Connected) { - LOG_WARN << "Connection is not connected,give up sending"; + LOG_DEBUG << "Connection is not connected,give up sending"; return; } - extendLife(); - size_t remainLen = length; ssize_t sendLen = 0; if (!ioChannelPtr_->isWriting() && writeBufferList_.empty()) { @@ -1069,55 +397,35 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) sendLen = writeInLoop(buffer, length); if (sendLen < 0) { - // error -#ifdef _WIN32 - if (errno != 0 && errno != EWOULDBLOCK) -#else - if (errno != EWOULDBLOCK) -#endif - { - if (errno == EPIPE || errno == ECONNRESET) // TODO: any others? - { -#ifdef _WIN32 - LOG_TRACE << "WSAENOTCONN or WSAECONNRESET, errno=" - << errno; -#else - LOG_TRACE << "EPIPE or ECONNRESET, errno=" << errno; -#endif - return; - } - LOG_SYSERR << "Unexpected error(" << errno << ")"; - return; - } - sendLen = 0; + LOG_TRACE << "write error"; + return; } - remainLen -= sendLen; + length -= sendLen; } - if (remainLen > 0 && status_ == ConnStatus::Connected) + if (length > 0 && status_ == ConnStatus::Connected) { - if (writeBufferList_.empty()) + if (writeBufferList_.empty() || writeBufferList_.back()->isFile() || + writeBufferList_.back()->isStream()) { - BufferNodePtr node = std::make_shared(); - node->msgBuffer_ = std::make_shared(); - writeBufferList_.push_back(std::move(node)); + writeBufferList_.push_back(BufferNode::newMemBufferNode()); } - else if (writeBufferList_.back()->isFile()) + writeBufferList_.back()->append(static_cast(buffer) + + sendLen, + length); + if (highWaterMarkCallback_ && + writeBufferList_.back()->remainingBytes() > + static_cast(highWaterMarkLen_)) { - BufferNodePtr node = std::make_shared(); - node->msgBuffer_ = std::make_shared(); - writeBufferList_.push_back(std::move(node)); + highWaterMarkCallback_(shared_from_this(), + writeBufferList_.back()->remainingBytes()); } - writeBufferList_.back()->msgBuffer_->append( - static_cast(buffer) + sendLen, remainLen); - if (!ioChannelPtr_->isWriting()) - ioChannelPtr_->enableWriting(); - if (highWaterMarkCallback_ && - writeBufferList_.back()->msgBuffer_->readableBytes() > + if (highWaterMarkCallback_ && tlsProviderPtr_ && + tlsProviderPtr_->getBufferedData().readableBytes() > highWaterMarkLen_) { highWaterMarkCallback_( shared_from_this(), - writeBufferList_.back()->msgBuffer_->readableBytes()); + tlsProviderPtr_->getBufferedData().readableBytes()); } } } @@ -1126,31 +434,12 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(msgPtr->data(), msgPtr->length()); - } - else - { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, msgPtr]() { - thisPtr->sendInLoop(msgPtr->data(), msgPtr->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(msgPtr->data(), msgPtr->length()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, msgPtr]() { + loop_->queueInLoop([thisPtr = shared_from_this(), msgPtr]() { thisPtr->sendInLoop(msgPtr->data(), msgPtr->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; }); } } @@ -1159,31 +448,12 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); - } - else - { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, msgPtr]() { - thisPtr->sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, msgPtr]() { + loop_->queueInLoop([thisPtr = shared_from_this(), msgPtr]() { thisPtr->sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; }); } } @@ -1191,106 +461,47 @@ void TcpConnectionImpl::send(const char *msg, size_t len) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(msg, len); - } - else - { - ++sendNum_; - auto buffer = std::make_shared(msg, len); - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, buffer]() { - thisPtr->sendInLoop(buffer->data(), buffer->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(msg, len); } else { auto buffer = std::make_shared(msg, len); - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, buffer]() { - thisPtr->sendInLoop(buffer->data(), buffer->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); + loop_->queueInLoop( + [thisPtr = shared_from_this(), buffer = std::move(buffer)]() { + thisPtr->sendInLoop(buffer->data(), buffer->length()); + }); } } void TcpConnectionImpl::send(const void *msg, size_t len) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { #ifndef _WIN32 - sendInLoop(msg, len); + sendInLoop(msg, len); #else - sendInLoop(static_cast(msg), len); + sendInLoop(static_cast(msg), len); #endif - } - else - { - ++sendNum_; - auto buffer = - std::make_shared(static_cast(msg), - len); - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, buffer]() { - thisPtr->sendInLoop(buffer->data(), buffer->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } } else { auto buffer = std::make_shared(static_cast(msg), len); - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, buffer]() { - thisPtr->sendInLoop(buffer->data(), buffer->length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); + loop_->queueInLoop( + [thisPtr = shared_from_this(), buffer = std::move(buffer)]() { + thisPtr->sendInLoop(buffer->data(), buffer->length()); + }); } } void TcpConnectionImpl::send(const std::string &msg) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(msg.data(), msg.length()); - } - else - { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, msg]() { - thisPtr->sendInLoop(msg.data(), msg.length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(msg.data(), msg.length()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, msg]() { + loop_->queueInLoop([thisPtr = shared_from_this(), msg]() { thisPtr->sendInLoop(msg.data(), msg.length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; }); } } @@ -1298,32 +509,14 @@ void TcpConnectionImpl::send(std::string &&msg) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(msg.data(), msg.length()); - } - else - { - auto thisPtr = shared_from_this(); - ++sendNum_; - loop_->queueInLoop([thisPtr, msg = std::move(msg)]() { - thisPtr->sendInLoop(msg.data(), msg.length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(msg.data(), msg.length()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, msg = std::move(msg)]() { - thisPtr->sendInLoop(msg.data(), msg.length()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); + loop_->queueInLoop( + [thisPtr = shared_from_this(), msg = std::move(msg)]() { + thisPtr->sendInLoop(msg.data(), msg.length()); + }); } } @@ -1331,31 +524,12 @@ void TcpConnectionImpl::send(const MsgBuffer &buffer) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(buffer.peek(), buffer.readableBytes()); - } - else - { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, buffer]() { - thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(buffer.peek(), buffer.readableBytes()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, buffer]() { + loop_->queueInLoop([thisPtr = shared_from_this(), buffer]() { thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; }); } } @@ -1364,168 +538,84 @@ void TcpConnectionImpl::send(MsgBuffer &&buffer) { if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) - { - sendInLoop(buffer.peek(), buffer.readableBytes()); - } - else - { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, buffer = std::move(buffer)]() { - thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); - } + sendInLoop(buffer.peek(), buffer.readableBytes()); } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, buffer = std::move(buffer)]() { - thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - }); + loop_->queueInLoop( + [thisPtr = shared_from_this(), buffer = std::move(buffer)]() { + thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); + }); } } void TcpConnectionImpl::sendFile(const char *fileName, - size_t offset, - size_t length) + long long offset, + long long length) { assert(fileName); #ifdef _WIN32 sendFile(utils::toNativePath(fileName).c_str(), offset, length); #else // _WIN32 - int fd = open(fileName, O_RDONLY); + auto fileNode = BufferNode::newFileBufferNode(fileName, offset, length); - if (fd < 0) + if (!fileNode->available()) { LOG_SYSERR << fileName << " open error"; return; } - if (length == 0) - { - struct stat filestat; - if (stat(fileName, &filestat) < 0) - { - LOG_SYSERR << fileName << " stat error"; - close(fd); - return; - } - length = filestat.st_size; - } - - sendFile(fd, offset, length); + sendFile(std::move(fileNode)); #endif // _WIN32 } void TcpConnectionImpl::sendFile(const wchar_t *fileName, - size_t offset, - size_t length) + long long offset, + long long length) { assert(fileName); #ifndef _WIN32 sendFile(utils::toNativePath(fileName).c_str(), offset, length); -#else // _WIN32 - FILE *fp; -#ifndef _MSC_VER - fp = _wfopen(fileName, L"rb"); -#else // _MSC_VER - if (_wfopen_s(&fp, fileName, L"rb") != 0) - fp = nullptr; -#endif // _MSC_VER - if (fp == nullptr) +#else + auto fileNode = BufferNode::newFileBufferNode(fileName, offset, length); + if (!fileNode->available()) { LOG_SYSERR << fileName << " open error"; return; } - - if (length == 0) - { - struct _stati64 filestat; - if (_wstati64(fileName, &filestat) < 0) - { - LOG_SYSERR << fileName << " stat error"; - fclose(fp); - return; - } - length = filestat.st_size; - } - - sendFile(fp, offset, length); + sendFile(std::move(fileNode)); #endif // _WIN32 } -#ifndef _WIN32 -void TcpConnectionImpl::sendFile(int sfd, size_t offset, size_t length) -#else -void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) -#endif -{ - assert(length > 0); -#ifndef _WIN32 - assert(sfd >= 0); - BufferNodePtr node = std::make_shared(); - node->sendFd_ = sfd; -#else - assert(fp); - BufferNodePtr node = std::make_shared(); - node->sendFp_ = fp; -#endif - node->offset_ = static_cast(offset); - node->fileBytesToSend_ = length; +void TcpConnectionImpl::sendFile(BufferNodePtr &&fileNode) +{ + assert(fileNode->isFile() && fileNode->remainingBytes() > 0); if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) + if (writeBufferList_.empty()) { - writeBufferList_.push_back(node); - if (writeBufferList_.size() == 1) - { - sendFileInLoop(writeBufferList_.front()); - return; - } + auto n = sendNodeInLoop(fileNode); + if (fileNode->remainingBytes() > 0 && n >= 0) + writeBufferList_.push_back(std::move(fileNode)); + return; } else { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, node]() { - thisPtr->writeBufferList_.push_back(node); - { - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - } - - if (thisPtr->writeBufferList_.size() == 1) - { - thisPtr->sendFileInLoop(thisPtr->writeBufferList_.front()); - } - }); + writeBufferList_.push_back(std::move(fileNode)); } } else { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, node]() { - LOG_TRACE << "Push sendfile to list"; - thisPtr->writeBufferList_.push_back(node); - + loop_->queueInLoop([thisPtr = shared_from_this(), + node = std::move(fileNode)]() mutable { + if (thisPtr->writeBufferList_.empty()) { - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; + auto n = thisPtr->sendNodeInLoop(node); + if (node->remainingBytes() > 0 && n >= 0) + thisPtr->writeBufferList_.push_back(std::move(node)); } - - if (thisPtr->writeBufferList_.size() == 1) + else { - thisPtr->sendFileInLoop(thisPtr->writeBufferList_.front()); + thisPtr->writeBufferList_.push_back(std::move(node)); } }); } @@ -1534,562 +624,399 @@ void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) void TcpConnectionImpl::sendStream( std::function callback) { - BufferNodePtr node = std::make_shared(); - node->offset_ = - 0; // not used, the offset should be handled by the callback - node->fileBytesToSend_ = 1; // force to > 0 until stream sent - node->streamCallback_ = std::move(callback); + auto node = BufferNode::newStreamBufferNode(std::move(callback)); if (loop_->isInLoopThread()) { - std::lock_guard guard(sendNumMutex_); - if (sendNum_ == 0) + if (writeBufferList_.empty()) { - writeBufferList_.push_back(node); - if (writeBufferList_.size() == 1) - { - sendFileInLoop(writeBufferList_.front()); - return; - } + auto n = sendNodeInLoop(node); + if (node->remainingBytes() > 0 && n >= 0) + writeBufferList_.push_back(std::move(node)); + return; } else { - ++sendNum_; - auto thisPtr = shared_from_this(); - loop_->queueInLoop([thisPtr, node]() { - thisPtr->writeBufferList_.push_back(node); + writeBufferList_.push_back(std::move(node)); + } + } + else + { + loop_->queueInLoop( + [thisPtr = shared_from_this(), node = std::move(node)]() mutable { + LOG_TRACE << "Push send stream to list"; + if (thisPtr->writeBufferList_.empty()) { - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; + auto n = thisPtr->sendNodeInLoop(node); + if (node->remainingBytes() > 0 && n >= 0) + thisPtr->writeBufferList_.push_back(std::move(node)); } - - if (thisPtr->writeBufferList_.size() == 1) + else { - thisPtr->sendFileInLoop(thisPtr->writeBufferList_.front()); + thisPtr->writeBufferList_.push_back(std::move(node)); } }); - } - } - else - { - auto thisPtr = shared_from_this(); - std::lock_guard guard(sendNumMutex_); - ++sendNum_; - loop_->queueInLoop([thisPtr, node]() { - LOG_TRACE << "Push sendstream to list"; - thisPtr->writeBufferList_.push_back(node); - - { - std::lock_guard guard1(thisPtr->sendNumMutex_); - --thisPtr->sendNum_; - } - - if (thisPtr->writeBufferList_.size() == 1) - { - thisPtr->sendFileInLoop(thisPtr->writeBufferList_.front()); - } - }); } } -void TcpConnectionImpl::sendFileInLoop(const BufferNodePtr &filePtr) +ssize_t TcpConnectionImpl::sendNodeInLoop(const BufferNodePtr &nodePtr) { loop_->assertInLoopThread(); - assert(filePtr->isFile()); #ifdef __linux__ - if (!isEncrypted_ && !filePtr->streamCallback_) + if (nodePtr->isFile() && !tlsProviderPtr_) { + static const long long kMaxSendBytes = 0x7ffff000; LOG_TRACE << "send file in loop using linux kernel sendfile()"; - auto bytesSent = sendfile(socketPtr_->fd(), - filePtr->sendFd_, - &filePtr->offset_, - filePtr->fileBytesToSend_); - if (bytesSent < 0) + auto toSend = nodePtr->remainingBytes(); + if (toSend <= 0) { - if (errno != EAGAIN) - { - LOG_SYSERR << "TcpConnectionImpl::sendFileInLoop"; - if (ioChannelPtr_->isWriting()) - ioChannelPtr_->disableWriting(); - } - return; + LOG_ERROR << "0 or negative bytes to send"; + return -1; } - if (bytesSent < filePtr->fileBytesToSend_) + auto bytesSent = + sendfile(socketPtr_->fd(), + nodePtr->getFd(), + nullptr, + static_cast( + toSend < kMaxSendBytes ? toSend : kMaxSendBytes)); + if (bytesSent > 0) { - if (bytesSent == 0) - { - LOG_SYSERR << "TcpConnectionImpl::sendFileInLoop"; - return; - } + nodePtr->retrieve(bytesSent); + bytesSent_ += bytesSent; } - LOG_TRACE << "sendfile() " << bytesSent << " bytes sent"; - filePtr->fileBytesToSend_ -= bytesSent; - if (!ioChannelPtr_->isWriting()) + else if (!isEAGAIN()) + return -1; + extendLife(); + if (bytesSent < toSend) { - ioChannelPtr_->enableWriting(); + LOG_TRACE << "bytesSent = " << bytesSent << " toSend = " << toSend; + if (!ioChannelPtr_->isWriting()) + ioChannelPtr_->enableWriting(); } - return; + return bytesSent; } #endif // Send stream - if (filePtr->streamCallback_) + + LOG_TRACE << "send node in loop"; + const char *data; + size_t len; + ssize_t hasSent = 0; + while ((nodePtr->remainingBytes() > 0)) { - LOG_TRACE << "send stream in loop"; - if (!fileBufferPtr_) - { - fileBufferPtr_ = std::make_unique>(); - fileBufferPtr_->reserve(16 * 1024); - } - while ((filePtr->fileBytesToSend_ > 0) || !fileBufferPtr_->empty()) + // get next chunk + nodePtr->getData(data, len); + if (len == 0) { - // get next chunk - if (fileBufferPtr_->empty()) - { - // LOG_TRACE << "send stream in loop: fetch data - // on buffer empty"; - fileBufferPtr_->resize(16 * 1024); - std::size_t nData; - nData = filePtr->streamCallback_(fileBufferPtr_->data(), - fileBufferPtr_->size()); - fileBufferPtr_->resize(nData); - if (nData == 0) // no more data! - { - LOG_TRACE << "send stream in loop: no more data"; - filePtr->fileBytesToSend_ = 0; - } - } - if (fileBufferPtr_->empty()) - { - LOG_TRACE << "send stream in loop: break on buffer empty"; - break; - } - auto nToWrite = fileBufferPtr_->size(); - auto nWritten = writeInLoop(fileBufferPtr_->data(), nToWrite); - if (nWritten >= 0) - { -#ifndef NDEBUG // defined by CMake for release build - filePtr->nDataWritten_ += nWritten; - LOG_TRACE << "send stream in loop: bytes written: " << nWritten - << " / total bytes written: " - << filePtr->nDataWritten_; -#endif - if (static_cast(nWritten) < nToWrite) - { - // Partial write - return and wait for next call to continue - fileBufferPtr_->erase(fileBufferPtr_->begin(), - fileBufferPtr_->begin() + nWritten); - if (!ioChannelPtr_->isWriting()) - ioChannelPtr_->enableWriting(); - LOG_TRACE << "send stream in loop: return on partial write " - "(socket buffer full?)"; - return; - } - // LOG_TRACE << "send stream in loop: continue on - // data written"; - fileBufferPtr_->resize(0); - continue; - } - // nWritten < 0 -#ifdef _WIN32 - if (errno != 0 && errno != EWOULDBLOCK) -#else - if (errno != EWOULDBLOCK) -#endif - { - if (errno == EPIPE || errno == ECONNRESET) - { -#ifdef _WIN32 - LOG_TRACE << "WSAENOTCONN or WSAECONNRESET, errno=" - << errno; -#else - LOG_TRACE << "EPIPE or ECONNRESET, errno=" << errno; -#endif - // abort - LOG_TRACE - << "send stream in loop: return on connection closed"; - filePtr->fileBytesToSend_ = 0; - return; - } - // TODO: any others? - LOG_SYSERR << "send stream in loop: return on unexpected error(" - << errno << ")"; - filePtr->fileBytesToSend_ = 0; - return; - } - // Socket buffer full - return and wait for next call - LOG_TRACE << "send stream in loop: break on socket buffer full (?)"; + nodePtr->done(); break; } - if (!ioChannelPtr_->isWriting()) - ioChannelPtr_->enableWriting(); - LOG_TRACE << "send stream in loop: return on loop exit"; - return; - } - // Send file - LOG_TRACE << "send file in loop"; - if (!fileBufferPtr_) - { - fileBufferPtr_ = std::make_unique>(16 * 1024); - } -#ifndef _WIN32 - lseek(filePtr->sendFd_, filePtr->offset_, SEEK_SET); - while (filePtr->fileBytesToSend_ > 0) - { - auto n = read(filePtr->sendFd_, - &(*fileBufferPtr_)[0], - std::min(fileBufferPtr_->size(), - static_castsize())>( - filePtr->fileBytesToSend_))); -#else - _fseeki64(filePtr->sendFp_, filePtr->offset_, SEEK_SET); - while (filePtr->fileBytesToSend_ > 0) - { - // LOG_TRACE << "send file in loop: fetch more remaining data"; - auto bytes = static_castsize())>( - filePtr->fileBytesToSend_); - auto n = fread(&(*fileBufferPtr_)[0], - 1, - (fileBufferPtr_->size() < bytes ? fileBufferPtr_->size() - : bytes), - filePtr->sendFp_); -#endif - if (n > 0) + auto nWritten = writeInLoop(data, len); + if (nWritten >= 0) { - auto nSend = writeInLoop(&(*fileBufferPtr_)[0], n); - if (nSend >= 0) - { - filePtr->fileBytesToSend_ -= nSend; - filePtr->offset_ += static_cast(nSend); - if (static_cast(nSend) < static_cast(n)) - { - if (!ioChannelPtr_->isWriting()) - { - ioChannelPtr_->enableWriting(); - } - LOG_TRACE << "send file in loop: return on partial write " - "(socket buffer full?)"; - return; - } - else if (nSend == n) - { - // LOG_TRACE << "send file in loop: - // continue on data written"; - continue; - } - } - if (nSend < 0) + hasSent += nWritten; + nodePtr->retrieve(nWritten); + if (static_cast(nWritten) < len) { -#ifdef _WIN32 - if (errno != 0 && errno != EWOULDBLOCK) -#else - if (errno != EWOULDBLOCK) -#endif - { - // TODO: any others? - if (errno == EPIPE || errno == ECONNRESET) - { -#ifdef _WIN32 - LOG_TRACE << "WSAENOTCONN or WSAECONNRESET, errno=" - << errno; -#else - LOG_TRACE << "EPIPE or ECONNRESET, errno=" << errno; -#endif - LOG_TRACE - << "send file in loop: return on connection closed"; - return; - } - LOG_SYSERR - << "send file in loop: return on unexpected error(" - << errno << ")"; - return; - } - LOG_TRACE - << "send file in loop: break on socket buffer full (?)"; break; } + continue; } - if (n < 0) - { - LOG_SYSERR << "send file in loop: return on read error"; - if (ioChannelPtr_->isWriting()) - ioChannelPtr_->disableWriting(); - return; - } - if (n == 0) + else { - LOG_SYSERR - << "send file in loop: return on read 0 (file truncated)"; - return; + LOG_TRACE << "error(" << errno << ") on send Node in loop"; + return -1; } } - LOG_TRACE << "send file in loop: return on loop exit"; - if (!ioChannelPtr_->isWriting()) - { - ioChannelPtr_->enableWriting(); - } + return hasSent; } #ifndef _WIN32 -ssize_t TcpConnectionImpl::writeInLoop(const void *buffer, size_t length) +ssize_t TcpConnectionImpl::writeRaw(const void *buffer, size_t length) #else -ssize_t TcpConnectionImpl::writeInLoop(const char *buffer, size_t length) +ssize_t TcpConnectionImpl::writeRaw(const char *buffer, size_t length) #endif { -#ifdef USE_OPENSSL - if (!isEncrypted_) - { -// LOG_TRACE << "write in loop"; -#endif + // TODO: Abstract this away to support io_uring (and IOCP?) #ifndef _WIN32 - int nWritten = write(socketPtr_->fd(), buffer, length); + int nWritten = write(socketPtr_->fd(), buffer, length); #else int nWritten = ::send(socketPtr_->fd(), buffer, static_cast(length), 0); errno = (nWritten < 0) ? ::WSAGetLastError() : 0; #endif - if (nWritten > 0) - bytesSent_ += nWritten; + if (nWritten > 0) + bytesSent_ += nWritten; + else if (!isEAGAIN()) return nWritten; -#ifdef USE_OPENSSL + if (nWritten < 0) + { + nWritten = 0; } - else + if (nWritten < static_cast(length)) { - // LOG_TRACE << "write encrypted in loop"; - loop_->assertInLoopThread(); - if (status_ != ConnStatus::Connected && - status_ != ConnStatus::Disconnecting) - { - LOG_WARN << "Connection is not connected,give up sending"; - return -1; - } - if (sslEncryptionPtr_->statusOfSSL_ != SSLStatus::Connected) - { - LOG_WARN << "SSL is not connected,give up sending"; - return -1; - } - // send directly - size_t sendTotalLen = 0; - while (sendTotalLen < length) - { - auto len = length - sendTotalLen; - if (len > sslEncryptionPtr_->sendBufferPtr_->size()) - { - len = sslEncryptionPtr_->sendBufferPtr_->size(); - } - memcpy(sslEncryptionPtr_->sendBufferPtr_->data(), - static_cast(buffer) + sendTotalLen, - len); - ERR_clear_error(); - auto sendLen = SSL_write(sslEncryptionPtr_->sslPtr_->get(), - sslEncryptionPtr_->sendBufferPtr_->data(), - static_cast(len)); - if (sendLen <= 0) - { - int sslerr = - SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), sendLen); - if (sslerr != SSL_ERROR_WANT_WRITE && - sslerr != SSL_ERROR_WANT_READ) - { - // LOG_ERROR << "ssl write error:" << sslerr; - forceClose(); - return -1; - } - return sendTotalLen; - } - sendTotalLen += sendLen; - } - return sendTotalLen; + LOG_TRACE << "nWritten = " << nWritten << " length = " << length; + if (!ioChannelPtr_->isWriting()) + ioChannelPtr_->enableWriting(); } + extendLife(); + return nWritten; +} + +#ifndef _WIN32 +ssize_t TcpConnectionImpl::writeInLoop(const void *buffer, size_t length) +#else +ssize_t TcpConnectionImpl::writeInLoop(const char *buffer, size_t length) #endif +{ + if (tlsProviderPtr_) + return tlsProviderPtr_->sendData((const char *)buffer, length); + else + return writeRaw(buffer, length); } -#ifdef USE_OPENSSL +#if !(defined(USE_OPENSSL) || defined(USE_BOTAN)) +SSLContextPtr trantor::newSSLContext(const TLSPolicy &policy, bool isServer) +{ + (void)policy; + (void)isServer; + throw std::runtime_error("SSL is not supported"); +} -TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr, - const std::shared_ptr &ctxPtr, - bool isServer, - bool validateCert, - const std::string &hostname) - : isEncrypted_(true), - loop_(loop), - ioChannelPtr_(new Channel(loop, socketfd)), - socketPtr_(new Socket(socketfd)), - localAddr_(localAddr), - peerAddr_(peerAddr) +std::shared_ptr trantor::newTLSProvider(TcpConnection *conn, + TLSPolicyPtr policy, + SSLContextPtr sslContext) { - LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" - << localAddr.toIpPort(); - ioChannelPtr_->setReadCallback( - std::bind(&TcpConnectionImpl::readCallback, this)); - ioChannelPtr_->setWriteCallback( - std::bind(&TcpConnectionImpl::writeCallback, this)); - ioChannelPtr_->setCloseCallback( - std::bind(&TcpConnectionImpl::handleClose, this)); - ioChannelPtr_->setErrorCallback( - std::bind(&TcpConnectionImpl::handleError, this)); - socketPtr_->setKeepAlive(true); - name_ = localAddr.toIpPort() + "--" + peerAddr.toIpPort(); - sslEncryptionPtr_ = std::make_unique(); - sslEncryptionPtr_->sslPtr_ = - std::make_unique(ctxPtr->get(), ctxPtr->mtlsEnabled); - sslEncryptionPtr_->isServer_ = isServer; - validateCert_ = validateCert; - if (isServer == false || sslEncryptionPtr_->sslPtr_->mtlsEnabled) + (void)conn; + (void)policy; + (void)sslContext; + throw std::runtime_error("SSL is not supported"); +} +#endif + +void TcpConnectionImpl::startEncryption( + TLSPolicyPtr policy, + bool isServer, + std::function upgradeCallback) +{ + if (tlsProviderPtr_ || upgradeCallback_) { - LOG_TRACE << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled; - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - sslEncryptionPtr_->sslPtr_->mtlsEnabled - ? SSL_VERIFY_PEER - : SSL_VERIFY_NONE, - nullptr); + LOG_ERROR << "TLS is already started"; + return; } + auto sslContextPtr = newSSLContext(*policy, isServer); + tlsProviderPtr_ = + newTLSProvider(this, std::move(policy), std::move(sslContextPtr)); + tlsProviderPtr_->setWriteCallback(onSslWrite); + tlsProviderPtr_->setErrorCallback(onSslError); + tlsProviderPtr_->setHandshakeCallback(onHandshakeFinished); + tlsProviderPtr_->setMessageCallback(onSslMessage); + // This is triggered when peer sends a close alert + tlsProviderPtr_->setCloseCallback(onSslCloseAlert); + tlsProviderPtr_->startEncryption(); + upgradeCallback_ = std::move(upgradeCallback); +} - if (!isServer && !hostname.empty()) +void TcpConnectionImpl::onSslError(TcpConnection *self, SSLError err) +{ + if (self->sslErrorCallback_) + self->sslErrorCallback_(err); + self->forceClose(); +} +void TcpConnectionImpl::onHandshakeFinished(TcpConnection *self) +{ + auto connPtr = ((TcpConnectionImpl *)self)->shared_from_this(); + if (connPtr->upgradeCallback_) { - SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(), - hostname.data()); - sslEncryptionPtr_->hostname_ = hostname; + connPtr->upgradeCallback_(connPtr); + connPtr->upgradeCallback_ = nullptr; } - assert(sslEncryptionPtr_->sslPtr_); - auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketfd); - (void)r; - assert(r); - isEncrypted_ = true; - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique>(); + else if (self->connectionCallback_) + self->connectionCallback_(connPtr); } - -bool TcpConnectionImpl::validatePeerCertificate() +void TcpConnectionImpl::onSslMessage(TcpConnection *self, MsgBuffer *buffer) { - LOG_TRACE << "Validating peer cerificate"; - assert(sslEncryptionPtr_ != nullptr); - assert(sslEncryptionPtr_->sslPtr_ != nullptr); - SSL *ssl = sslEncryptionPtr_->sslPtr_->get(); - - auto result = SSL_get_verify_result(ssl); - -#ifdef ALLOW_SELF_SIGNED_CERTS - if (result != X509_V_OK && - result != X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT && - result != X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN && - result != X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY) + if (self->recvMsgCallback_) + self->recvMsgCallback_(((TcpConnectionImpl *)self)->shared_from_this(), + buffer); +} +ssize_t TcpConnectionImpl::onSslWrite(TcpConnection *self, + const void *data, + size_t len) +{ + auto connPtr = (TcpConnectionImpl *)self; + return connPtr->writeRaw((const char *)data, len); +} +void TcpConnectionImpl::onSslCloseAlert(TcpConnection *self) +{ + self->shutdown(); +} +class AsyncStreamImpl : public AsyncStream +{ + public: + explicit AsyncStreamImpl(std::function callback) + : callback_(std::move(callback)) { - LOG_TRACE << "cert error code: " << result; - LOG_ERROR << "Server certificate is not valid"; - return false; } -#else - if (result != X509_V_OK && result) + AsyncStreamImpl() = delete; + bool send(const char *data, size_t len) override { - LOG_TRACE << "cert error code: " << result; - LOG_ERROR << "Server certificate is not valid"; - return false; + return callback_(data, len); } -#endif - - X509 *cert = SSL_get_peer_certificate(ssl); - if (cert == nullptr) + void close() override { - LOG_ERROR << "Unable to obtain peer certificate"; - return false; + callback_(nullptr, 0); + callback_ = nullptr; + } + ~AsyncStreamImpl() override + { + if (callback_) + callback_(nullptr, 0); } - bool domainIsValid = - internal::verifyCommonName(cert, sslEncryptionPtr_->hostname_) || - internal::verifyAltName(cert, sslEncryptionPtr_->hostname_); - X509_free(cert); - - LOG_TRACE << "domainIsValid: " << domainIsValid; - - // if mtlsEnabled, ignore domain validation - if (sslEncryptionPtr_->sslPtr_->mtlsEnabled || domainIsValid) + private: + std::function callback_; +}; +AsyncStreamPtr TcpConnectionImpl::sendAsyncStream(bool disableKickoff) +{ + auto asyncStreamNode = BufferNode::newAsyncStreamBufferNode(); + std::weak_ptr weakPtr = shared_from_this(); + auto asyncStream = std::make_unique( + [asyncStreamNode, weakPtr = std::move(weakPtr)](const char *data, + size_t len) -> bool { + auto thisPtr = weakPtr.lock(); + if (!thisPtr) + { + LOG_DEBUG << "Connection is closed,give up sending"; + return false; + } + if (thisPtr->status_ != ConnStatus::Connected) + { + LOG_DEBUG << "Connection is not connected,give up sending"; + return false; + } + if (thisPtr->loop_->isInLoopThread()) + { + thisPtr->sendAsyncDataInLoop(asyncStreamNode, data, len); + } + else + { + if (data) + { + std::string buffer(data, len); + thisPtr->loop_->queueInLoop([thisPtr, + asyncStreamNode, + buffer = std::move(buffer)]() { + thisPtr->sendAsyncDataInLoop(asyncStreamNode, + buffer.data(), + buffer.length()); + }); + } + else + { + thisPtr->loop_->queueInLoop([thisPtr, asyncStreamNode]() { + thisPtr->sendAsyncDataInLoop(asyncStreamNode, + nullptr, + 0); + }); + } + } + return true; + }); + if (loop_->isInLoopThread()) { - return true; + if (disableKickoff) + { + auto entry = kickoffEntry_.lock(); + if (entry) + { + entry->reset(); + kickoffEntry_.reset(); + } + idleTimeoutBackup_ = idleTimeout_; + idleTimeout_ = 0; + } + + writeBufferList_.push_back(asyncStreamNode); } else { - LOG_ERROR << "Domain validation failed"; - return false; - } -} + loop_->queueInLoop([this, + thisPtr = shared_from_this(), + node = std::move(asyncStreamNode), + disableKickoff]() mutable { + if (disableKickoff) + { + auto entry = kickoffEntry_.lock(); + if (entry) + { + entry->reset(); + kickoffEntry_.reset(); + } + idleTimeoutBackup_ = idleTimeout_; + idleTimeout_ = 0; + } -std::string TcpConnectionImpl::getOpenSSLErrorStack() -{ - BIO *bio = BIO_new(BIO_s_mem()); - ERR_print_errors(bio); - char *buf; - size_t len = BIO_get_mem_data(bio, &buf); - std::string ret(buf, len); - BIO_free(bio); - return ret; + if (thisPtr->writeBufferList_.empty() && node->remainingBytes() > 0) + { + auto n = thisPtr->sendNodeInLoop(node); + if (n >= 0 && (node->remainingBytes() > 0 || node->available())) + thisPtr->writeBufferList_.push_back(std::move(node)); + } + else + { + thisPtr->writeBufferList_.push_back(std::move(node)); + } + }); + } + return asyncStream; } - -void TcpConnectionImpl::doHandshaking() +void TcpConnectionImpl::sendAsyncDataInLoop(const BufferNodePtr &node, + const char *data, + size_t len) { - assert(sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Handshaking); - - int r = SSL_do_handshake(sslEncryptionPtr_->sslPtr_->get()); - LOG_TRACE << "hand shaking: " << r; - if (r == 1) + loop_->assertInLoopThread(); + if (data) { - // Clients don't commonly have certificates (except on mTLS). - // So if the SSL session is on server-side and without mTLS enabled, - // let's not validate the client certificate. - if (validateCert_ && (!sslEncryptionPtr_->isServer_ || - sslEncryptionPtr_->sslPtr_->mtlsEnabled)) + if (len > 0) { - if (validatePeerCertificate() == false) + if (!writeBufferList_.empty() && node == writeBufferList_.front() && + node->remainingBytes() == 0) { - LOG_ERROR << "SSL certificate validation failed."; - ioChannelPtr_->disableReading(); - sslEncryptionPtr_->statusOfSSL_ = SSLStatus::DisConnected; - if (sslErrorCallback_) + auto nWritten = writeInLoop(data, len); + if (nWritten < 0) { - sslErrorCallback_(SSLError::kSSLInvalidCertificate); + LOG_TRACE << "write error"; + return; + } + if (static_cast(nWritten) < len) + { + node->append(data + nWritten, len - nWritten); } - forceClose(); - return; + } + else + { + node->append(data, len); } } - sslEncryptionPtr_->statusOfSSL_ = SSLStatus::Connected; - if (sslEncryptionPtr_->isUpgrade_) - { - sslEncryptionPtr_->upgradeCallback_(); - } - else - { - connectionCallback_(shared_from_this()); - } - return; - } - int err = SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), r); - LOG_TRACE << "hand shaking: " << err; - if (err == SSL_ERROR_WANT_WRITE) - { // SSL want writable; - if (!ioChannelPtr_->isWriting()) - ioChannelPtr_->enableWriting(); - // ioChannelPtr_->disableReading(); - } - else if (err == SSL_ERROR_WANT_READ) - { // SSL want readable; - if (!ioChannelPtr_->isReading()) - ioChannelPtr_->enableReading(); - if (ioChannelPtr_->isWriting()) - ioChannelPtr_->disableWriting(); } else { - LOG_TRACE << "SSL handshake err: " << err; - LOG_TRACE << "SSL error stack: " << getOpenSSLErrorStack(); - ioChannelPtr_->disableReading(); - sslEncryptionPtr_->statusOfSSL_ = SSLStatus::DisConnected; - if (sslErrorCallback_) + // stream is closed + node->done(); + if (!writeBufferList_.empty() && node == writeBufferList_.front() && + !ioChannelPtr_->isWriting()) + ioChannelPtr_->enableWriting(); + + if (idleTimeoutBackup_ > 0) { - sslErrorCallback_(SSLError::kSSLHandshakeError); + auto timingWheel = timingWheelWeakPtr_.lock(); + if (timingWheel) + { + auto entry = std::make_shared(shared_from_this()); + kickoffEntry_ = entry; + idleTimeout_ = idleTimeoutBackup_; + idleTimeoutBackup_ = 0; + timingWheel->insertEntry(idleTimeout_, std::move(entry)); + } } - forceClose(); } } - -#endif diff --git a/trantor/net/inner/TcpConnectionImpl.h b/trantor/net/inner/TcpConnectionImpl.h index 72818e52..f19729a0 100644 --- a/trantor/net/inner/TcpConnectionImpl.h +++ b/trantor/net/inner/TcpConnectionImpl.h @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include #ifndef _WIN32 @@ -26,52 +28,15 @@ namespace trantor { -#ifdef USE_OPENSSL -enum class SSLStatus -{ - Handshaking, - Connecting, - Connected, - DisConnecting, - DisConnected -}; -class SSLContext; -class SSLConn; - -std::shared_ptr newSSLContext( - bool useOldTLS, - bool validateCert, - const std::vector> &sslConfCmds); -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector> &sslConfCmds, - const std::string &caPath); -std::shared_ptr newSSLClientContext( - bool useOldTLS, - bool validateCert, - const std::string &certPath = "", - const std::string &keyPath = "", - const std::vector> &sslConfCmds = {}, - const std::string &caPath = ""); - -// void initServerSSLContext(const std::shared_ptr &ctx, -// const std::string &certPath, -// const std::string &keyPath); -#endif class Channel; class Socket; class TcpServer; -void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); class TcpConnectionImpl : public TcpConnection, public NonCopyable, public std::enable_shared_from_this { friend class TcpServer; friend class TcpClient; - friend void trantor::removeConnection(EventLoop *loop, - const TcpConnectionPtr &conn); public: class KickoffEntry @@ -98,70 +63,62 @@ class TcpConnectionImpl : public TcpConnection, std::weak_ptr conn_; }; - TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr); -#ifdef USE_OPENSSL TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, - const std::shared_ptr &ctxPtr, - bool isServer = true, - bool validateCert = true, - const std::string &hostname = ""); -#endif - virtual ~TcpConnectionImpl(); - virtual void send(const char *msg, size_t len) override; - virtual void send(const void *msg, size_t len) override; - virtual void send(const std::string &msg) override; - virtual void send(std::string &&msg) override; - virtual void send(const MsgBuffer &buffer) override; - virtual void send(MsgBuffer &&buffer) override; - virtual void send(const std::shared_ptr &msgPtr) override; - virtual void send(const std::shared_ptr &msgPtr) override; - virtual void sendFile(const char *fileName, - size_t offset = 0, - size_t length = 0) override; - virtual void sendFile(const wchar_t *fileName, - size_t offset = 0, - size_t length = 0) override; - virtual void sendStream( + TLSPolicyPtr policy = nullptr, + SSLContextPtr ctx = nullptr); + ~TcpConnectionImpl() override; + void send(const char *msg, size_t len) override; + void send(const void *msg, size_t len) override; + void send(const std::string &msg) override; + void send(std::string &&msg) override; + void send(const MsgBuffer &buffer) override; + void send(MsgBuffer &&buffer) override; + void send(const std::shared_ptr &msgPtr) override; + void send(const std::shared_ptr &msgPtr) override; + void sendFile(const char *fileName, + long long offset, + long long length) override; + void sendFile(const wchar_t *fileName, + long long offset, + long long length) override; + void sendStream( std::function callback) override; - virtual const InetAddress &localAddr() const override + const InetAddress &localAddr() const override { return localAddr_; } - virtual const InetAddress &peerAddr() const override + const InetAddress &peerAddr() const override { return peerAddr_; } - virtual bool connected() const override + bool connected() const override { return status_ == ConnStatus::Connected; } - virtual bool disconnected() const override + bool disconnected() const override { return status_ == ConnStatus::Disconnected; } // virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;} - virtual MsgBuffer *getRecvBuffer() override - { - return &readBuffer_; - } + // virtual MsgBuffer *getRecvBuffer() override + // { + // return &readBuffer_; + // } // set callbacks - virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, - size_t markLen) override + void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, + size_t markLen) override { highWaterMarkCallback_ = cb; highWaterMarkLen_ = markLen; } - virtual void keepAlive() override + void keepAlive() override { idleTimeout_ = 0; auto entry = kickoffEntry_.lock(); @@ -170,50 +127,71 @@ class TcpConnectionImpl : public TcpConnection, entry->reset(); } } - virtual bool isKeepAlive() override + bool isKeepAlive() override { return idleTimeout_ == 0; } - virtual void setTcpNoDelay(bool on) override; - virtual void shutdown() override; - virtual void forceClose() override; - virtual EventLoop *getLoop() override + void setTcpNoDelay(bool on) override; + void shutdown() override; + void forceClose() override; + EventLoop *getLoop() override { return loop_; } - virtual size_t bytesSent() const override + size_t bytesSent() const override { return bytesSent_; } - virtual size_t bytesReceived() const override + size_t bytesReceived() const override { return bytesReceived_; } - virtual void startClientEncryption( - std::function callback, - bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector> &sslConfCmds = - {}) override; - virtual void startServerEncryption(const std::shared_ptr &ctx, - std::function callback) override; - virtual bool isSSLConnection() const override + + bool isSSLConnection() const override { - return isEncrypted_; + return tlsProviderPtr_ != nullptr; } + void connectEstablished() override; + void connectDestroyed() override; - private: - /// Internal use only. + MsgBuffer *getRecvBuffer() override + { + if (tlsProviderPtr_) + return &tlsProviderPtr_->getRecvBuffer(); + return &readBuffer_; + } - std::weak_ptr kickoffEntry_; - std::weak_ptr timingWheelWeakPtr_; - size_t idleTimeout_{0}; - Date lastTimingWheelUpdateTime_; + std::string applicationProtocol() const override + { + if (tlsProviderPtr_) + return tlsProviderPtr_->applicationProtocol(); + return ""; + } + + CertificatePtr peerCertificate() const override + { + if (tlsProviderPtr_) + return tlsProviderPtr_->peerCertificate(); + return nullptr; + } - void enableKickingOff(size_t timeout, - const std::shared_ptr &timingWheel) + std::string sniName() const override + { + if (tlsProviderPtr_) + return tlsProviderPtr_->sniName(); + return ""; + } + + void startEncryption( + TLSPolicyPtr policy, + bool isServer, + std::function upgradeCallback) override; + AsyncStreamPtr sendAsyncStream(bool disableKickoff) override; + + void enableKickingOff( + size_t timeout, + const std::shared_ptr &timingWheel) override { assert(timingWheel); assert(timingWheel->getLoop() == loop_); @@ -224,101 +202,19 @@ class TcpConnectionImpl : public TcpConnection, idleTimeout_ = timeout; timingWheel->insertEntry(timeout, entry); } + + private: + /// Internal use only. + + std::weak_ptr kickoffEntry_; + std::weak_ptr timingWheelWeakPtr_; + size_t idleTimeout_{0}; + size_t idleTimeoutBackup_{0}; + Date lastTimingWheelUpdateTime_; void extendLife(); -#ifndef _WIN32 - void sendFile(int sfd, size_t offset = 0, size_t length = 0); -#else - void sendFile(FILE *fp, size_t offset = 0, size_t length = 0); -#endif - void setRecvMsgCallback(const RecvMessageCallback &cb) - { - recvMsgCallback_ = cb; - } - void setRecvMsgCallback(RecvMessageCallback &&cb) - { - recvMsgCallback_ = std::move(cb); - } - void setConnectionCallback(const ConnectionCallback &cb) - { - connectionCallback_ = cb; - } - void setConnectionCallback(ConnectionCallback &&cb) - { - connectionCallback_ = std::move(cb); - } - void setWriteCompleteCallback(const WriteCompleteCallback &cb) - { - writeCompleteCallback_ = cb; - } - void setWriteCompleteCallback(WriteCompleteCallback &&cb) - { - writeCompleteCallback_ = std::move(cb); - } - void setCloseCallback(const CloseCallback &cb) - { - closeCallback_ = cb; - } - void setCloseCallback(CloseCallback &&cb) - { - closeCallback_ = std::move(cb); - } - void setSSLErrorCallback(const SSLErrorCallback &cb) - { - sslErrorCallback_ = cb; - } - void setSSLErrorCallback(SSLErrorCallback &&cb) - { - sslErrorCallback_ = std::move(cb); - } - void connectDestroyed(); - virtual void connectEstablished(); + void sendFile(BufferNodePtr &&fileNode); protected: - struct BufferNode - { - // sendFile() specific -#ifndef _WIN32 - int sendFd_{-1}; - off_t offset_{0}; -#else - FILE *sendFp_{nullptr}; - long long offset_{0}; -#endif - ssize_t fileBytesToSend_{0}; - // sendStream() specific - std::function streamCallback_; -#ifndef NDEBUG // defined by CMake for release build - std::size_t nDataWritten_{0}; -#endif - // generic - std::shared_ptr msgBuffer_; - bool isFile() const - { - if (streamCallback_) - return true; -#ifndef _WIN32 - if (sendFd_ >= 0) - return true; -#else - if (sendFp_) - return true; -#endif - return false; - } - ~BufferNode() - { -#ifndef _WIN32 - if (sendFd_ >= 0) - close(sendFd_); -#else - if (sendFp_) - fclose(sendFp_); -#endif - if (streamCallback_) - streamCallback_(nullptr, 0); // cleanup callback internals - } - }; - using BufferNodePtr = std::shared_ptr; enum class ConnStatus { Disconnected, @@ -326,7 +222,6 @@ class TcpConnectionImpl : public TcpConnection, Connected, Disconnecting }; - bool isEncrypted_{false}; EventLoop *loop_; std::unique_ptr ioChannelPtr_; std::unique_ptr socketPtr_; @@ -336,63 +231,44 @@ class TcpConnectionImpl : public TcpConnection, void writeCallback(); InetAddress localAddr_, peerAddr_; ConnStatus status_{ConnStatus::Connecting}; - // callbacks - RecvMessageCallback recvMsgCallback_; - ConnectionCallback connectionCallback_; - CloseCallback closeCallback_; - WriteCompleteCallback writeCompleteCallback_; - HighWaterMarkCallback highWaterMarkCallback_; - SSLErrorCallback sslErrorCallback_; void handleClose(); void handleError(); // virtual void sendInLoop(const std::string &msg); - - void sendFileInLoop(const BufferNodePtr &file); + void sendAsyncDataInLoop(const BufferNodePtr &node, + const char *data, + size_t len); + // -1: error, 0: EAGAIN, >0: bytes sent + ssize_t sendNodeInLoop(const BufferNodePtr &node); #ifndef _WIN32 void sendInLoop(const void *buffer, size_t length); + ssize_t writeRaw(const void *buffer, size_t length); ssize_t writeInLoop(const void *buffer, size_t length); #else void sendInLoop(const char *buffer, size_t length); + // -1: error, 0: EAGAIN, >0: bytes sent + ssize_t writeRaw(const char *buffer, size_t length); + // -1: error, 0: EAGAIN, >0: bytes sent ssize_t writeInLoop(const char *buffer, size_t length); #endif - size_t highWaterMarkLen_; + size_t highWaterMarkLen_{0}; std::string name_; - uint64_t sendNum_{0}; - std::mutex sendNumMutex_; - size_t bytesSent_{0}; size_t bytesReceived_{0}; - std::unique_ptr> fileBufferPtr_; + // std::unique_ptr> fileBufferPtr_; + std::shared_ptr tlsProviderPtr_; + std::function upgradeCallback_; -#ifdef USE_OPENSSL - private: - void doHandshaking(); - bool validatePeerCertificate(); - std::string getOpenSSLErrorStack(); - struct SSLEncryption - { - SSLStatus statusOfSSL_ = SSLStatus::Handshaking; - // OpenSSL - std::shared_ptr sslCtxPtr_; - std::unique_ptr sslPtr_; - std::unique_ptr> sendBufferPtr_; - bool isServer_{false}; - bool isUpgrade_{false}; - std::function upgradeCallback_; - std::string hostname_; - }; - std::unique_ptr sslEncryptionPtr_; - void startClientEncryptionInLoop( - std::function &&callback, - bool useOldTLS, - bool validateCert, - const std::string &hostname, - const std::vector> &sslConfCmds); - void startServerEncryptionInLoop(const std::shared_ptr &ctx, - std::function &&callback); -#endif + bool closeOnEmpty_{false}; + + static void onSslError(TcpConnection *self, SSLError err); + static void onHandshakeFinished(TcpConnection *self); + static void onSslMessage(TcpConnection *self, MsgBuffer *buffer); + static ssize_t onSslWrite(TcpConnection *self, + const void *data, + size_t len); + static void onSslCloseAlert(TcpConnection *self); }; using TcpConnectionImplPtr = std::shared_ptr; diff --git a/trantor/net/inner/poller/EpollPoller.cc b/trantor/net/inner/poller/EpollPoller.cc index 8be13aa5..c9c4f7d4 100644 --- a/trantor/net/inner/poller/EpollPoller.cc +++ b/trantor/net/inner/poller/EpollPoller.cc @@ -53,7 +53,7 @@ const int kDeleted = 2; EpollPoller::EpollPoller(EventLoop *loop) : Poller(loop), #ifdef _WIN32 - // wepoll does not suppor flags + // wepoll does not support flags epollfd_(::epoll_create1(0)), #else epollfd_(::epoll_create1(EPOLL_CLOEXEC)), @@ -85,7 +85,7 @@ void EpollPoller::poll(int timeoutMs, ChannelList *activeChannels) // Timestamp now(Timestamp::now()); if (numEvents > 0) { - // LOG_TRACE << numEvents << " events happended"; + // LOG_TRACE << numEvents << " events happened"; fillActiveChannels(numEvents, activeChannels); if (static_cast(numEvents) == events_.size()) { @@ -94,7 +94,7 @@ void EpollPoller::poll(int timeoutMs, ChannelList *activeChannels) } else if (numEvents == 0) { - // std::cout << "nothing happended" << std::endl; + // std::cout << "nothing happened" << std::endl; } else { diff --git a/trantor/net/inner/poller/KQueue.cc b/trantor/net/inner/poller/KQueue.cc index acf4ae78..98636ef2 100644 --- a/trantor/net/inner/poller/KQueue.cc +++ b/trantor/net/inner/poller/KQueue.cc @@ -59,7 +59,7 @@ void KQueue::poll(int timeoutMs, ChannelList *activeChannels) // Timestamp now(Timestamp::now()); if (numEvents > 0) { - // LOG_TRACE << numEvents << " events happended"; + // LOG_TRACE << numEvents << " events happened"; fillActiveChannels(numEvents, activeChannels); if (static_cast(numEvents) == events_.size()) { @@ -68,7 +68,7 @@ void KQueue::poll(int timeoutMs, ChannelList *activeChannels) } else if (numEvents == 0) { - // std::cout << "nothing happended" << std::endl; + // std::cout << "nothing happened" << std::endl; } else { diff --git a/trantor/net/inner/poller/PollPoller.cc b/trantor/net/inner/poller/PollPoller.cc index 5223433a..c3cb5ab4 100644 --- a/trantor/net/inner/poller/PollPoller.cc +++ b/trantor/net/inner/poller/PollPoller.cc @@ -30,7 +30,7 @@ PollPoller::PollPoller(EventLoop* loop) : Poller(loop) { std::call_once(warning_flag, []() { LOG_WARN << "Creating a PollPoller. This poller is slow and should " - "only be used when no other pollers are avaliable"; + "only be used when no other pollers are available"; }); } diff --git a/trantor/net/inner/tlsprovider/BotanTLSProvider.cc b/trantor/net/inner/tlsprovider/BotanTLSProvider.cc new file mode 100644 index 00000000..93c1ee4d --- /dev/null +++ b/trantor/net/inner/tlsprovider/BotanTLSProvider.cc @@ -0,0 +1,492 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace trantor; +using namespace std::placeholders; + +static std::once_flag sessionManagerInitFlag; +static std::shared_ptr sessionManagerRng; +static std::shared_ptr sessionManager; +static thread_local std::shared_ptr rng; +static std::unique_ptr systemCertStore; +static std::once_flag systemCertStoreInitFlag; + +using namespace trantor; + +static std::string join(const std::vector &vec, + const std::string &delim) +{ + std::string ret; + for (auto const &str : vec) + { + if (ret.empty() == false) + ret += delim; + ret += str; + } + return ret; +} + +class Credentials : public Botan::Credentials_Manager +{ + public: + Credentials(std::shared_ptr key, + Botan::X509_Certificate *cert, + Botan::Certificate_Store *certStore) + : certStore_(certStore), cert_(cert), key_(key) + { + } + std::vector trusted_certificate_authorities( + const std::string &type, + const std::string &context) override + { + (void)type; + (void)context; + if (certStore_ == nullptr) + return {}; + return {certStore_}; + } + + std::vector find_cert_chain( + const std::vector &cert_key_types, + const std::vector &cert_signature_schemes, + const std::vector &acceptable_CAs, + const std::string &type, + const std::string &context) override + { + (void)type; + (void)context; + (void)cert_signature_schemes; + (void)acceptable_CAs; + if (cert_ == nullptr) + return {}; + + auto key_algo = + cert_->subject_public_key_algo().oid().to_formatted_string(); + auto it = + std::find(cert_key_types.begin(), cert_key_types.end(), key_algo); + if (it == cert_key_types.end()) + return {}; + return {*cert_}; + } + + std::shared_ptr private_key_for( + const Botan::X509_Certificate &cert, + const std::string &type, + const std::string &context) override + { + (void)cert; + (void)type; + (void)context; + return key_; + } + Botan::Certificate_Store *certStore_ = nullptr; + Botan::X509_Certificate *cert_ = nullptr; + std::shared_ptr key_ = nullptr; +}; + +struct BotanCertificate : public Certificate +{ + BotanCertificate(const Botan::X509_Certificate &cert) : cert_(cert) + { + } + + virtual std::string sha1Fingerprint() const override + { + return cert_.fingerprint("SHA-1"); + } + + virtual std::string sha256Fingerprint() const override + { + return cert_.fingerprint("SHA-256"); + } + + virtual std::string pem() const override + { + return cert_.PEM_encode(); + } + Botan::X509_Certificate cert_; +}; + +namespace trantor +{ +struct SSLContext +{ + std::shared_ptr key; + std::unique_ptr cert; + std::shared_ptr certStore; + bool isServer = false; + bool requireClientCert = false; +}; +} // namespace trantor + +class TrantorPolicy : public Botan::TLS::Policy +{ + virtual bool require_cert_revocation_info() const override + { + return false; + } + + virtual bool require_client_certificate_authentication() const override + { + return requireClientCert_; + } + + public: + bool requireClientCert_ = false; +}; + +struct BotanTLSProvider : public TLSProvider, + public NonCopyable, + public Botan::TLS::Callbacks, + public std::enable_shared_from_this +{ + public: + BotanTLSProvider(TcpConnection *conn, + TLSPolicyPtr policy, + SSLContextPtr ctx) + : TLSProvider(conn, std::move(policy), std::move(ctx)) + { + validationPolicy_ = std::make_shared(); + } + + virtual void recvData(MsgBuffer *buffer) override + { + LOG_TRACE << "Low level connection received " << buffer->readableBytes() + << " bytes."; + try + { + assert(channel_ != nullptr); + channel_->received_data((const uint8_t *)buffer->peek(), + buffer->readableBytes()); + } + catch (const Botan::TLS::TLS_Exception &e) + { + LOG_ERROR << "Unexpected TLS Exception: " << e.what(); + conn_->shutdown(); + + if (tlsConnected_ == false) + { + if (e.type() == Botan::TLS::Alert::BadCertificate) + handleSSLError(SSLError::kSSLInvalidCertificate); + else + handleSSLError(SSLError::kSSLHandshakeError); + } + else + handleSSLError(SSLError::kSSLProtocolError); + } + catch (const Botan::Exception &e) + { + LOG_ERROR << "Unexpected Botan Exception: " << e.what(); + conn_->shutdown(); + if (tlsConnected_ == false) + handleSSLError(SSLError::kSSLHandshakeError); + else + handleSSLError(SSLError::kSSLProtocolError); + } + catch (const std::exception &e) + { + LOG_ERROR << "Unexpected Generic Exception: " << e.what(); + conn_->shutdown(); + if (tlsConnected_ == false) + handleSSLError(SSLError::kSSLHandshakeError); + else + handleSSLError(SSLError::kSSLProtocolError); + } + buffer->retrieveAll(); + } + + virtual ssize_t sendData(const char *ptr, size_t size) override + { + if (getBufferedData().readableBytes() != 0) + { + errno = EAGAIN; + return 0; + } + + // Limit the size of the data we send in one go to avoid holding massive + // buffers in memory. + constexpr size_t maxSend = 64 * 1024; + size_t hasSent = 0; + while (hasSent < size && getBufferedData().readableBytes() == 0) + { + auto trunkLen = size - hasSent; + if (trunkLen > maxSend) + trunkLen = maxSend; + channel_->send((const uint8_t *)ptr, size); + // HACK: Botan doesn't provide a way to know how much raw data has + // been written to the underlying transport. So we have to assume + // that all data has been written. And cache the unwritten data in + // writeBuffer_. Then "fake" the consumed size in sendData() to make + // the caller think that all data has been written. Then return -1 + // if the underlying socket is not writable at all (i.e. write is + // all or nothing) + if (lastWriteSize_ == -1) + return -1; + hasSent += trunkLen; + } + return static_cast(hasSent); + } + + virtual void close() override + { + if (channel_ && channel_->is_active()) + { + channel_->close(); + } + } + + virtual void startEncryption() override + { + auto certStorePtr = contextPtr_->certStore.get(); + if (certStorePtr == nullptr) + { + std::call_once(systemCertStoreInitFlag, []() { + systemCertStore = + std::make_unique(); + }); + certStorePtr = systemCertStore.get(); + } + credsPtr_ = std::make_shared(contextPtr_->key, + contextPtr_->cert.get(), + certStorePtr); + if (policyPtr_->getConfCmds().empty() == false) + LOG_WARN << "BotanTLSConnectionImpl does not support sslConfCmds."; + + // initialize rng and session manager if we haven't already + std::call_once(sessionManagerInitFlag, []() { + sessionManagerRng = std::make_shared(); + sessionManager = + std::make_shared( + sessionManagerRng); + }); + if (rng == nullptr) + rng = std::make_shared(); + + auto fakeThis = std::shared_ptr(this, [](auto) {}); + if (contextPtr_->isServer) + { + // TODO: Need a more scalable way to manage session validation rules + validationPolicy_->requireClientCert_ = + contextPtr_->requireClientCert; + channel_ = std::make_unique(std::move(fakeThis), + sessionManager, + credsPtr_, + validationPolicy_, + rng); + } + else + { + validationPolicy_->requireClientCert_ = + contextPtr_->requireClientCert; + // technically Botan2 does support TLS 1.0 and 1.1, but Botan3 does + // not. So we just disable them to keep compatibility. + if (policyPtr_->getUseOldTLS()) + LOG_WARN << "Old TLS not supported by Botan (only >= TLS 1.2)"; + channel_ = std::make_unique( + std::move(fakeThis), + sessionManager, + credsPtr_, + validationPolicy_, + rng, + Botan::TLS::Server_Information(policyPtr_->getHostname(), + conn_->peerAddr().toPort()), + Botan::TLS::Protocol_Version::TLS_V12, + policyPtr_->getAlpnProtocols()); + setSniName(policyPtr_->getHostname()); + } + } + + void handleSSLError(SSLError err) + { + if (!errorCallback_) + return; + loop_->queueInLoop([this, err]() { errorCallback_(conn_, err); }); + } + + virtual ~BotanTLSProvider() override = default; + + void tls_emit_data(std::span data) override + { + auto n = writeCallback_(conn_, data.data(), data.size_bytes()); + lastWriteSize_ = n; + + // store the unsent data and send it later + if (n == ssize_t(data.size_bytes())) + return; + if (n == -1) + n = 0; + appendToWriteBuffer((const char *)data.data() + n, + data.size_bytes() - n); + } + + void tls_record_received(uint64_t seq_no, + std::span data) override + { + (void)seq_no; + recvBuffer_.append((const char *)data.data(), data.size_bytes()); + if (messageCallback_) + messageCallback_(conn_, &recvBuffer_); + } + + std::string tls_server_choose_app_protocol( + const std::vector &client_protos) override + { + assert(contextPtr_->isServer); + if (policyPtr_->getAlpnProtocols().empty() || client_protos.empty()) + return ""; + + for (auto const &proto : client_protos) + { + if (std::find(policyPtr_->getAlpnProtocols().begin(), + policyPtr_->getAlpnProtocols().end(), + proto) != policyPtr_->getAlpnProtocols().end()) + return proto; + } + + throw Botan::TLS::TLS_Exception( + Botan::TLS::Alert::NoApplicationProtocol, + "No supported application protocol found. Client offered: " + + join(client_protos, ", ") + " but we support: " + + join(policyPtr_->getAlpnProtocols(), ", ")); + } + + void tls_alert(Botan::TLS::Alert alert) override + { + if (alert.type() == Botan::TLS::Alert::CloseNotify) + { + LOG_TRACE << "TLS close notify received"; + if (closeCallback_) + closeCallback_(conn_); + } + else + { + if (errorCallback_) + errorCallback_(conn_, SSLError::kSSLProtocolError); + } + } + + void tls_session_activated() override + { + LOG_TRACE << "tls_session_activated"; + tlsConnected_ = true; + setApplicationProtocol(channel_->application_protocol()); + if (handshakeCallback_) + handshakeCallback_(conn_); + } + + void tls_verify_cert_chain( + const std::vector &certs, + const std::vector> &ocsp, + const std::vector &trusted_roots, + Botan::Usage_Type usage, + std::string_view hostname, + const Botan::TLS::Policy &policy) override + { + setSniName(std::string(hostname)); + if (policyPtr_->getValidate() && !policyPtr_->getAllowBrokenChain()) + Botan::TLS::Callbacks::tls_verify_cert_chain( + certs, ocsp, trusted_roots, usage, hostname, policy); + else if (policyPtr_->getValidate()) + { + if (certs.size() == 0) + throw Botan::TLS::TLS_Exception( + Botan::TLS::Alert::NoCertificate, + "Certificate validation failed: no certificate"); + // handle self-signed certificate + std::vector selfSigned = {certs[0]}; + + Botan::Path_Validation_Restrictions restrictions( + false, // require revocation + validationPolicy_->minimum_signature_strength()); + + auto now = std::chrono::system_clock::now(); + const auto status = Botan::PKIX::check_chain( + selfSigned, now, hostname, usage, restrictions); + + const auto result = Botan::PKIX::overall_status(status); + + if (result != Botan::Certificate_Status_Code::OK) + throw Botan::TLS::TLS_Exception( + Botan::TLS::Alert::BadCertificate, + std::string("Certificate validation failed: ") + + Botan::to_string(result)); + } + + if (certs.size() > 0) + setPeerCertificate(std::make_shared(certs[0])); + } + + std::shared_ptr validationPolicy_; + std::shared_ptr credsPtr_; + std::unique_ptr channel_; + bool tlsConnected_ = false; + ssize_t lastWriteSize_ = 0; +}; + +std::shared_ptr trantor::newTLSProvider(TcpConnection *conn, + TLSPolicyPtr policy, + SSLContextPtr ctx) +{ + return std::make_shared(conn, + std::move(policy), + std::move(ctx)); +} + +SSLContextPtr trantor::newSSLContext(const TLSPolicy &policy, bool server) +{ + auto ctx = std::make_shared(); + ctx->isServer = server; + if (!policy.getKeyPath().empty()) + { + Botan::DataSource_Stream in(policy.getKeyPath()); + ctx->key = Botan::PKCS8::load_key(in); + } + + if (!policy.getCertPath().empty()) + { + ctx->cert = + std::make_unique(policy.getCertPath()); + } + + if (policy.getValidate() && policy.getAllowBrokenChain()) + { + if (!policy.getCaPath().empty()) + { + ctx->certStore = + std::make_shared( + policy.getCaPath()); + if (server) + ctx->requireClientCert = true; + } + else if (policy.getUseSystemCertStore()) + { + static auto systemCertStore = + std::make_shared(); + ctx->certStore = systemCertStore; + } + } + + if (policy.getUseOldTLS()) + LOG_WARN << "SSLPloicy have set useOldTLS to true. BUt Botan does not " + "support TLS/SSL below TLS 1.2. Ignoring this option."; + return ctx; +} diff --git a/trantor/net/inner/tlsprovider/OpenSSLProvider.cc b/trantor/net/inner/tlsprovider/OpenSSLProvider.cc new file mode 100644 index 00000000..f91e9cde --- /dev/null +++ b/trantor/net/inner/tlsprovider/OpenSSLProvider.cc @@ -0,0 +1,904 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace trantor; + +// Force OpenSSL to initialize before main() is called +static bool sslInitFlag = []() { +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + ERR_load_BIO_strings(); + ERR_load_crypto_strings(); +#elif defined(LIBRESSL_VERSION_NUMBER) + // LibreSSL needs explicit de-init + atexit(OPENSSL_cleanup); +#endif + return true; +}(); + +namespace internal +{ +#ifdef _WIN32 +// Code yanked from stackoverflow +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool loadWindowsSystemCert(X509_STORE *store) +{ + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + + if (!hStore) + { + return false; + } + + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) + { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) + { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return true; +} +#endif + +inline bool verifyCommonName(X509 *cert, const std::string &hostname) +{ + X509_NAME *subjectName = X509_get_subject_name(cert); + + if (subjectName != nullptr) + { + std::array name; + auto length = X509_NAME_get_text_by_NID(subjectName, + NID_commonName, + name.data(), + (int)name.size()); + if (length == -1) + return false; + + return utils::verifySslName(std::string(name.begin(), + name.begin() + length), + hostname); + } + + return false; +} + +inline bool verifyAltName(X509 *cert, const std::string &hostname) +{ + bool good = false; + auto altNames = static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + + if (altNames) + { + int numNames = sk_GENERAL_NAME_num(altNames); + + for (int i = 0; i < numNames && !good; i++) + { + auto val = sk_GENERAL_NAME_value(altNames, i); + if (val->type != GEN_DNS) + { + LOG_WARN << "Name using IP addresses are not supported. Open " + "an issue if you need that feature"; + continue; + } +#if (OPENSSL_VERSION_NUMBER >= 0x10100000L) + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); +#else + auto name = (const char *)ASN1_STRING_data(val->d.ia5); +#endif + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + good = utils::verifySslName(std::string(name, name + name_len), + hostname); + } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)altNames); + return good; +} + +static bool validatePeerCertificate(SSL *ssl, + X509 *cert, + const std::string &hostname, + bool allowBrokenChain, + bool isServer) +{ + assert(ssl != nullptr); + assert(cert != nullptr); + LOG_TRACE << "Validating peer certificate"; + + if (isServer) + { + bool domainIsValid = + verifyCommonName(cert, hostname) || verifyAltName(cert, hostname); + if (!domainIsValid) + return false; + } + + auto result = SSL_get_verify_result(ssl); + if (result == X509_V_ERR_CERT_NOT_YET_VALID || + result == X509_V_ERR_CERT_HAS_EXPIRED) + { + // What happens if cert is self-signed and expired? + LOG_TRACE << "cert error code: " << result + << ", date validation failed"; + return false; + } + + if (result != X509_V_OK && !allowBrokenChain) + { + LOG_TRACE << "cert error code: " << result; + LOG_ERROR << "Peer certificate is not valid"; + return false; + } + + return true; +} + +static int serverSelectProtocol(SSL *ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, + unsigned int inlen, + void *arg) +{ + (void)ssl; + auto protocols = static_cast *>(arg); + if (protocols->empty()) + return SSL_TLSEXT_ERR_NOACK; + + for (auto &protocol : *protocols) + { + const unsigned char *cur = in; + const unsigned char *end = in + inlen; + while (cur < end) + { + unsigned int len = *cur++; + if (cur + len > end) + { + LOG_ERROR << "Client provided invalid protocol list in APLN"; + return SSL_TLSEXT_ERR_NOACK; + } + if (protocol.size() == len && + memcmp(cur, protocol.data(), len) == 0) + { + *out = cur; + *outlen = len; + LOG_TRACE << "Selected protocol: " << protocol; + return SSL_TLSEXT_ERR_OK; + } + } + } + + return SSL_TLSEXT_ERR_NOACK; +} + +} // namespace internal + +namespace trantor +{ +struct SSLContext +{ + SSLContext( + bool useOldTLS, + const std::vector> &sslConfCmds, + bool server) + : isServer(server) + { + // Ungodly amount of preprocessor macros to support older versions of + // OpenSSL and LibreSSL +#if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER) +#define SSL_METHOD SSLv23_method +#else +#define SSL_METHOD TLS_method +#endif + +#ifdef LIBRESSL_VERSION_NUMBER + ctx_ = SSL_CTX_new(SSL_METHOD()); + if (ctx_ == nullptr) + throw std::runtime_error("Failed to create SSL context"); + if (sslConfCmds.size() != 0) + LOG_WARN << "LibreSSL does not support SSL configuration commands"; + + if (!useOldTLS) + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); +#else + ctx_ = SSL_CTX_new(SSL_METHOD()); + if (ctx_ == nullptr) + throw std::runtime_error("Failed to create SSL context"); + SSL_CONF_CTX *cctx = SSL_CONF_CTX_new(); + SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_SERVER); + SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CLIENT); + SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE); + SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_FILE); + SSL_CONF_CTX_set_ssl_ctx(cctx, ctx_); + for (const auto &cmd : sslConfCmds) + SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data()); + SSL_CONF_CTX_finish(cctx); + SSL_CONF_CTX_free(cctx); + if (useOldTLS == false) + { +#if OPENSSL_VERSION_NUMBER >= 0x10101000L + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); +#else + const auto opt = SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1 | + SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3; + SSL_CTX_set_options(ctx_, opt); +#endif + } + else + { + LOG_WARN << "TLS 1.1 and below enabled. They are considered " + "obsolete, insecure standards and should only be " + "used for legacy purpose."; + } +#endif + } + ~SSLContext() + { + if (ctx_) + SSL_CTX_free(ctx_); + } + SSL_CTX *ctx_ = nullptr; + + SSL_CTX *ctx() const + { + return ctx_; + } + + bool isServer{false}; +}; + +struct OpenSSLCertificate : public Certificate +{ + OpenSSLCertificate(X509 *cert) : cert_(cert) + { + assert(cert_); + } + ~OpenSSLCertificate() + { + X509_free(cert_); + } + virtual std::string sha1Fingerprint() const override + { + std::string sha1; + unsigned char md[EVP_MAX_MD_SIZE]; + unsigned int n = 0; + if (X509_digest(cert_, EVP_sha1(), md, &n)) + { + sha1.resize(n * 3); + for (unsigned int i = 0; i < n; i++) + { + snprintf(&sha1[i * 3], 4, "%02X:", md[i]); + } + sha1.resize(sha1.size() - 1); + } + else + { + // handle error + // LOG_ERROR << "X509_digest failed"; + } + return sha1; + } + + virtual std::string sha256Fingerprint() const override + { + std::string sha256; + unsigned char md[EVP_MAX_MD_SIZE]; + unsigned int n = 0; + if (X509_digest(cert_, EVP_sha256(), md, &n)) + { + sha256.resize(n * 3); + for (unsigned int i = 0; i < n; i++) + { + snprintf(&sha256[i * 3], 4, "%02X:", md[i]); + } + sha256.resize(sha256.size() - 1); + } + else + { + // handle error + // LOG_ERROR << "X509_digest failed"; + } + return sha256; + } + + virtual std::string pem() const override + { + std::string pem; + BIO *bio = BIO_new(BIO_s_mem()); + if (bio) + { + PEM_write_bio_X509(bio, cert_); + char *data = nullptr; + long len = BIO_get_mem_data(bio, &data); + if (len > 0) + { + pem.assign(data, len); + } + else + { + // handle error + // LOG_ERROR << "BIO_get_mem_data failed"; + } + BIO_free(bio); + } + else + { + // handle error + // LOG_ERROR << "BIO_new failed"; + } + return pem; + } + X509 *cert_ = nullptr; +}; + +class SessionManager +{ + struct SessionData + { + SSL_SESSION *session = nullptr; + std::string key; + TimerId timerId = 0; + EventLoop *loop = nullptr; + }; + + public: + ~SessionManager() + { + for (auto &session : sessions_) + { + SSL_SESSION_free(session.session); + } + } + + void store(const std::string &hostname, + InetAddress peerAddr, + SSL_SESSION *session, + EventLoop *loop) + { +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + { + std::lock_guard lock(mutex_); + auto key = toKey(hostname, peerAddr); + auto it = sessionMap_.find(key); + if (it != sessionMap_.end()) + { + SSL_SESSION_free(it->second->session); + it->second->loop->invalidateTimer(it->second->timerId); + sessions_.erase(it->second); + sessionMap_.erase(it); + } + + SSL_SESSION_up_ref(session); + TimerId tid = loop->runAfter(sessionTimeout_, [this, key]() { + std::lock_guard lock(mutex_); + auto it = sessionMap_.find(key); + if (it != sessionMap_.end()) + { + SSL_SESSION_free(it->second->session); + sessions_.erase(it->second); + sessionMap_.erase(it); + } + }); + sessions_.push_front(SessionData{session, key, tid, loop}); + sessionMap_[key] = sessions_.begin(); + } + removeExcessSession(); +#else + (void)hostname; + (void)peerAddr; + (void)session; + (void)loop; + assert(false && "not support under ancient openssl"); +#endif + } + + SSL_SESSION *get(const std::string &hostname, InetAddress peerAddr) + { + std::lock_guard lock(mutex_); + auto key = toKey(hostname, peerAddr); + auto it = sessionMap_.find(key); + if (it != sessionMap_.end()) + { + return it->second->session; + } + return nullptr; + } + + void removeExcessSession() + { + std::lock_guard lock(mutex_); + assert(maxSessions_ > 0); + assert(mexExtendSize_ > 0); + if (sessions_.size() < size_t(maxSessions_ + mexExtendSize_)) + return; + if (sessions_.size() > size_t(maxSessions_)) + { + auto it = sessions_.end(); + it--; + SSL_SESSION_free(it->session); + it->loop->invalidateTimer(it->timerId); + sessionMap_.erase(it->key); + sessions_.erase(it); + } + } + + std::string toKey(const std::string &hostname, InetAddress peerAddr) + { + return hostname + peerAddr.toIpPort(); + } + + std::mutex mutex_; + int maxSessions_ = 150; + int mexExtendSize_ = 20; + int sessionTimeout_ = 3600; + std::list sessions_; + std::unordered_map::iterator> + sessionMap_; +}; + +} // namespace trantor + +static SessionManager sessionManager; + +struct OpenSSLProvider : public TLSProvider, public NonCopyable +{ + OpenSSLProvider(TcpConnection *conn, TLSPolicyPtr policy, SSLContextPtr ctx) + : TLSProvider(conn, std::move(policy), std::move(ctx)) + { + rbio_ = BIO_new(BIO_s_mem()); + wbio_ = BIO_new(BIO_s_mem()); + ssl_ = SSL_new(contextPtr_->ctx()); + assert(ssl_); + assert(rbio_); + assert(wbio_); + SSL_set_bio(ssl_, rbio_, wbio_); + if (!policyPtr_->getHostname().empty()) + SSL_set_tlsext_host_name(ssl_, policyPtr_->getHostname().c_str()); + } + + virtual ~OpenSSLProvider() + { + SSL_free(ssl_); + } + + virtual void startEncryption() override + { + if (contextPtr_->isServer) + { + assert(ssl_); + SSL_set_accept_state(ssl_); + } + else + { + assert(ssl_); + + const auto &protocols = policyPtr_->getAlpnProtocols(); + if (!protocols.empty()) + { + std::string alpnList; + alpnList.reserve(24); // some reasonable size + for (const auto &proto : policyPtr_->getAlpnProtocols()) + { + char ch = static_cast(proto.size()); + alpnList.push_back(ch); + alpnList.append(proto); + } + SSL_set_alpn_protos(ssl_, + (const unsigned char *)(alpnList.data()), + (unsigned int)alpnList.size()); + } + + SSL_SESSION *cachedSession = + sessionManager.get(policyPtr_->getHostname(), + conn_->peerAddr()); + if (cachedSession) + { + SSL_set_session(ssl_, cachedSession); + } + SSL_set_connect_state(ssl_); + } + + processHandshake(); + } + + virtual void recvData(MsgBuffer *buffer) override + { + LOG_TRACE << "Received " << buffer->readableBytes() + << " bytes from lower layer"; + if (buffer->readableBytes() == 0) + return; + while (buffer->readableBytes() > 0) + { + int n = + BIO_write(rbio_, buffer->peek(), (int)buffer->readableBytes()); + if (n <= 0) + { + // TODO: make the status code more specific + handleSSLError(SSLError::kSSLHandshakeError); + return; + } + + buffer->retrieve(n); + + if (!SSL_is_init_finished(ssl_)) + { + bool handshakeDone = processHandshake(); + if (handshakeDone) + processApplicationData(); + } + else + { + processApplicationData(); + } + } + } + + virtual void close() override + { + if (!SSL_is_init_finished(ssl_)) + return; + SSL_shutdown(ssl_); + sendTLSData(); + } + + virtual ssize_t sendData(const char *data, size_t len) override + { + if (getBufferedData().readableBytes() != 0) + { + errno = EAGAIN; + return 0; + } + // Limit the size of the data we send in one go to avoid holding massive + // buffers in memory. + constexpr size_t maxSend = 64 * 1024; + size_t hasSent = 0; + while (hasSent < len && getBufferedData().readableBytes() == 0) + { + auto trunkLen = len - hasSent; + if (trunkLen > maxSend) + trunkLen = maxSend; + int n = SSL_write(ssl_, data + hasSent, (int)trunkLen); + if (n <= 0 && len != 0) + { + handleSSLError(SSLError::kSSLProtocolError); + return -1; + } + auto num = sendTLSData(); + if (num == -1) + return -1; + hasSent += trunkLen; + } + return static_cast(hasSent); + } + + bool processHandshake() + { + int ret = SSL_do_handshake(ssl_); + if (ret == 1) + { + LOG_TRACE << "SSL handshake finished"; + if (contextPtr_->isServer) + { + const char *sniName = + SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name); + if (sniName) + setSniName(sniName); + + const unsigned char *alpn = nullptr; + unsigned int alpnlen = 0; + SSL_get0_alpn_selected(ssl_, &alpn, &alpnlen); + if (alpn) + setApplicationProtocol(std::string((char *)alpn, alpnlen)); + } + else + { + setSniName(policyPtr_->getHostname()); + if (policyPtr_->getAlpnProtocols().size() > 0) + { + const unsigned char *alpn = nullptr; + unsigned int alpnlen = 0; + SSL_get0_alpn_selected(ssl_, &alpn, &alpnlen); + if (alpn) + { + assert(alpnlen > 0); + setApplicationProtocol( + std::string((char *)alpn, alpnlen)); + } + } + +#if OPENSSL_VERSION_NUMBER >= 0x10101000L + SSL_SESSION *session = SSL_get0_session(ssl_); + assert(session); + if (SSL_SESSION_is_resumable(session)) + { + auto reused = SSL_session_reused(ssl_); + if (reused == 0) + sessionManager.store(sniName_, + conn_->peerAddr(), + session, + loop_); + } +#endif + } + + auto cert = SSL_get_peer_certificate(ssl_); + bool needCert = policyPtr_->getValidate(); + if (cert) + setPeerCertificate(std::make_shared(cert)); + + if (needCert) + { + if (cert) + { + bool valid = internal::validatePeerCertificate( + ssl_, + cert, + policyPtr_->getHostname(), + policyPtr_->getAllowBrokenChain(), + contextPtr_->isServer); + if (!valid) + { + LOG_TRACE + << "SSL handshake error: invalid peer certificate"; + SSL_shutdown(ssl_); + handleSSLError(SSLError::kSSLInvalidCertificate); + return false; + } + } + else + { + LOG_TRACE + << "SSL handshake error: no peer certificate. Cannot " + "perform validation"; + SSL_shutdown(ssl_); + handleSSLError(SSLError::kSSLInvalidCertificate); + return false; + } + } + + if (handshakeCallback_) + handshakeCallback_(conn_); + sendTLSData(); // Needed to send ChangeCipherSpec + return true; + } + else + { + int err = SSL_get_error(ssl_, ret); + if (err == SSL_ERROR_WANT_READ) + { + LOG_TRACE << "SSL handshake wants to read"; + sendTLSData(); + } + else if (err == SSL_ERROR_WANT_WRITE) + { + LOG_TRACE << "SSL handshake wants to write"; + sendTLSData(); + } + else + { + if (!processedHandshakeError_) + processedHandshakeError_ = true; + else + return false; + LOG_TRACE << "SSL handshake error: " + << ERR_error_string(ERR_get_error(), NULL); + conn_->shutdown(); + handleSSLError(SSLError::kSSLHandshakeError); + } + } + return false; + } + + void processApplicationData() + { + constexpr size_t maxSingleRead = 128 * 1024; + constexpr size_t maxWritibleBytes = (std::numeric_limits::max)(); + while (true) + { + auto pending = BIO_pending(rbio_); + // horrible syntax, because MSVC + pending = (std::max)(1024, pending); + recvBuffer_.ensureWritableBytes( + (std::min)(maxSingleRead, (size_t)pending)); + // clamp to int, because that's what SSL_read accepts + const size_t wrtibleSize = + (std::min)(maxWritibleBytes, recvBuffer_.writableBytes()); + int n = SSL_read(ssl_, recvBuffer_.beginWrite(), (int)wrtibleSize); + int shutdownState = SSL_get_shutdown(ssl_); + if (n == 0 && (shutdownState & SSL_RECEIVED_SHUTDOWN)) + { + LOG_TRACE << "SSL connection closed by peer"; + conn_->shutdown(); + return; + } + else if (n > 0) + { + recvBuffer_.hasWritten(n); + LOG_TRACE << "Received " << n << " bytes from SSL"; + if (messageCallback_) + messageCallback_(conn_, &recvBuffer_); + } + else if (n <= 0) + { + int err = SSL_get_error(ssl_, n); + if (err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) + { + handleSSLError(SSLError::kSSLProtocolError); + } + return; + } + } + } + + ssize_t sendTLSData() + { + void *data = nullptr; + int len = BIO_get_mem_data(wbio_, &data); + if (len < 0 || data == nullptr) + return -1; + if (len == 0) + return 0; + int n = writeCallback_(conn_, data, len); + + if (n >= 0) + { + appendToWriteBuffer((char *)data + n, len - n); + } + BIO_reset(wbio_); + if (n < 0) + return -1; + return len; + } + + void handleSSLError(SSLError error) + { + sendTLSData(); + + if (!processedSslError_) + processedSslError_ = true; + else + return; + if (errorCallback_) + errorCallback_(conn_, error); + } + + SSL *ssl_; + BIO *rbio_; + BIO *wbio_; + bool processedHandshakeError_{false}; + bool processedSslError_{false}; +}; + +std::shared_ptr trantor::newTLSProvider(TcpConnection *conn, + TLSPolicyPtr policy, + SSLContextPtr ctx) +{ + return std::make_shared(conn, + std::move(policy), + std::move(ctx)); +} + +SSLContextPtr trantor::newSSLContext(const TLSPolicy &policy, bool isServer) +{ + auto ctx = std::make_shared(policy.getUseOldTLS(), + policy.getConfCmds(), + isServer); + if (!policy.getCertPath().empty() && !policy.getKeyPath().empty()) + { + if (SSL_CTX_use_certificate_chain_file(ctx->ctx(), + policy.getCertPath().data()) <= + 0) + { + throw std::runtime_error("Failed to load certificate " + + policy.getCertPath()); + } + if (SSL_CTX_use_PrivateKey_file(ctx->ctx(), + policy.getKeyPath().data(), + SSL_FILETYPE_PEM) <= 0) + { + throw std::runtime_error("Failed to load private key"); + } + if (SSL_CTX_check_private_key(ctx->ctx()) == 0) + { + throw std::runtime_error( + "Private key does not match the " + "certificate public key"); + } + } + if (policy.getValidate() && policy.getUseSystemCertStore()) + { +#ifdef _WIN32 + internal::loadWindowsSystemCert(SSL_CTX_get_cert_store(ctx->ctx())); +#else + SSL_CTX_set_default_verify_paths(ctx->ctx()); +#endif + } + + if (!policy.getCaPath().empty()) + { + if (isServer) + { + if (SSL_CTX_load_verify_locations(ctx->ctx(), + policy.getCaPath().data(), + nullptr) <= 0) + { + throw std::runtime_error("Failed to load CA certificate"); + } + + STACK_OF(X509_NAME) *cert_names = + SSL_load_client_CA_file(policy.getCaPath().data()); + if (cert_names == nullptr) + { + throw std::runtime_error("Not CA names found in file"); + } + SSL_CTX_set_client_CA_list(ctx->ctx(), cert_names); + SSL_CTX_set_verify(ctx->ctx(), + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + LOG_TRACE << "Finished loading custom CA"; + } + else + { + auto *store = X509_STORE_new(); + if (!X509_STORE_load_locations(store, + policy.getCaPath().data(), + nullptr)) + { + throw std::runtime_error("Failed to load CA certificate"); + } + SSL_CTX_set_cert_store(ctx->ctx(), store); + } + } + + if (!policy.getAlpnProtocols().empty() && isServer) + { + SSL_CTX_set_alpn_select_cb(ctx->ctx(), + internal::serverSelectProtocol, + (void *)&policy.getAlpnProtocols()); + } + + if (!isServer) + { + // We have our own session cache, so disable OpenSSL's + SSL_CTX_set_session_cache_mode(ctx->ctx(), SSL_SESS_CACHE_OFF); + } + + // Disable weak ciphers. Weak hash and ciphers can die in a fire. + int status = + SSL_CTX_set_cipher_list(ctx->ctx(), "MEDIUM:HIGH:!aNULL!MD5:!RC4!3DES"); + if (status != 1) + throw std::runtime_error("Failed to select secure ciphers"); + + return ctx; +} diff --git a/trantor/tests/CMakeLists.txt b/trantor/tests/CMakeLists.txt index 146f62b2..90f75511 100644 --- a/trantor/tests/CMakeLists.txt +++ b/trantor/tests/CMakeLists.txt @@ -17,11 +17,12 @@ add_executable(sendstream_test SendstreamTest.cc) add_executable(timing_wheel_test TimingWheelTest.cc) add_executable(kickoff_test KickoffTest.cc) add_executable(dns_test DnsTest.cc) -add_executable(delayed_ssl_server_test DelayedSSLServerTest.cc) -add_executable(delayed_ssl_client_test DelayedSSLClientTest.cc) add_executable(run_on_quit_test RunOnQuitTest.cc) add_executable(path_conversion_test PathConversionTest.cc) add_executable(logger_macro_test LoggerMacroTest.cc) +add_executable(delayed_ssl_server_test DelayedSSLServerTest.cc) +add_executable(delayed_ssl_client_test DelayedSSLClientTest.cc) +add_executable(tcp_asyncstream_server_test TcpAsyncStreamServerTest.cc) set(targets_list ssl_server_test ssl_client_test @@ -42,11 +43,18 @@ set(targets_list timing_wheel_test kickoff_test dns_test - delayed_ssl_server_test - delayed_ssl_client_test run_on_quit_test path_conversion_test - logger_macro_test) + logger_macro_test + delayed_ssl_server_test + delayed_ssl_client_test + tcp_asyncstream_server_test +) + +if(HAVE_SPDLOG) + add_executable(spdlogger_test SpdLoggerTest.cc) + list(APPEND targets_list spdlogger_test) +endif(HAVE_SPDLOG) set_property(TARGET ${targets_list} PROPERTY CXX_STANDARD 14) set_property(TARGET ${targets_list} PROPERTY CXX_STANDARD_REQUIRED ON) diff --git a/trantor/tests/DelayedSSLClientTest.cc b/trantor/tests/DelayedSSLClientTest.cc index 72c55309..5abe20cc 100644 --- a/trantor/tests/DelayedSSLClientTest.cc +++ b/trantor/tests/DelayedSSLClientTest.cc @@ -18,8 +18,8 @@ int main() #endif std::shared_ptr client[10]; std::atomic_int connCount; - connCount = 10; - for (int i = 0; i < 10; ++i) + connCount = 1; + for (int i = 0; i < 1; ++i) { client[i] = std::make_shared(&loop, serverAddr, @@ -37,27 +37,27 @@ int main() loop.quit(); } }); - client[i]->setMessageCallback( - [](const TcpConnectionPtr &conn, MsgBuffer *buf) { - auto msg = std::string(buf->peek(), buf->readableBytes()); + client[i]->setMessageCallback([](const TcpConnectionPtr &conn, + MsgBuffer *buf) { + auto msg = std::string(buf->peek(), buf->readableBytes()); - LOG_INFO << msg; - if (msg == "hello") - { - buf->retrieveAll(); - conn->startClientEncryption( - [conn]() { - LOG_INFO << "SSL established"; - conn->send("Hello"); - }, - false, - false); - } - if (conn->isSSLConnection()) - { - buf->retrieveAll(); - } - }); + LOG_INFO << msg; + if (msg == "hello") + { + buf->retrieveAll(); + auto policy = TLSPolicy::defaultClientPolicy(); + policy->setValidate(false); + conn->startEncryption( + policy, false, [](const TcpConnectionPtr &encryptedConn) { + LOG_INFO << "SSL established"; + encryptedConn->send("Hello"); + }); + } + if (conn->isSSLConnection()) + { + buf->retrieveAll(); + } + }); client[i]->connect(); } loop.loop(); diff --git a/trantor/tests/DelayedSSLServerTest.cc b/trantor/tests/DelayedSSLServerTest.cc index 466b7ad1..56caade7 100644 --- a/trantor/tests/DelayedSSLServerTest.cc +++ b/trantor/tests/DelayedSSLServerTest.cc @@ -17,7 +17,7 @@ int main() InetAddress addr(8888); #endif TcpServer server(loopThread.getLoop(), addr, "test"); - auto ctx = newSSLServerContext("server.pem", "server.pem", {}); + // auto ctx = newSSLServerContext("server.pem", "server.pem", {}); LOG_INFO << "start"; server.setRecvMessageCallback( [](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) { @@ -26,14 +26,14 @@ int main() buffer->retrieveAll(); connectionPtr->shutdown(); }); - server.setConnectionCallback([ctx](const TcpConnectionPtr &connPtr) { + server.setConnectionCallback([](const TcpConnectionPtr &connPtr) { if (connPtr->connected()) { LOG_DEBUG << "New connection"; connPtr->send("hello"); - connPtr->startServerEncryption(ctx, [] { - LOG_INFO << "SSL established"; - }); + auto policy = + TLSPolicy::defaultServerPolicy("server.crt", "server.key"); + connPtr->startEncryption(policy, true); } else if (connPtr->disconnected()) { diff --git a/trantor/tests/DnsTest.cc b/trantor/tests/DnsTest.cc index 1a71f53f..996a0a9b 100644 --- a/trantor/tests/DnsTest.cc +++ b/trantor/tests/DnsTest.cc @@ -9,6 +9,15 @@ void dns(const std::shared_ptr &resolver) std::cout << "baidu:" << addr.toIp() << " " << interval / 1000 << "ms" << std::endl; }); + resolver->resolve("www.baidu.com", + [now](const std::vector &addrs) { + auto interval = + trantor::Date::now().microSecondsSinceEpoch() - + now.microSecondsSinceEpoch(); + for (auto &addr : addrs) + std::cout << "baidu:" << addr.toIp() << " " + << interval / 1000 << "ms" << std::endl; + }); resolver->resolve("www.google.com", [now](const trantor::InetAddress &addr) { auto interval = diff --git a/trantor/tests/LoggerTest.cc b/trantor/tests/LoggerTest.cc index 408fcb77..fafd849c 100644 --- a/trantor/tests/LoggerTest.cc +++ b/trantor/tests/LoggerTest.cc @@ -4,6 +4,7 @@ int main() { int i; + LOG_COMPACT_DEBUG << "Hello, world!"; LOG_DEBUG << (float)3.14; LOG_DEBUG << (const char)'8'; LOG_DEBUG << &i; diff --git a/trantor/tests/MTLSClient.cc b/trantor/tests/MTLSClient.cc index 488813fe..f30af86f 100644 --- a/trantor/tests/MTLSClient.cc +++ b/trantor/tests/MTLSClient.cc @@ -79,13 +79,12 @@ int main() // That key is common for client and server // The CA file must be the client CA, for this sample the CA is common // for both - client[i]->enableSSL(false, - false, - "localhost", - sslcmd, - "./client-crt.pem", - "./server-key.pem", - "./ca-crt.pem"); + auto policy = TLSPolicy::defaultClientPolicy(); + policy->setCertPath("./client-crt.pem") + .setKeyPath("./server-key.pem") + .setCaPath("./ca-crt.pem") + .setHostname("localhost"); + client[i]->enableSSL(policy); client[i]->setConnectionCallback( [i, &loop, &connCount](const TcpConnectionPtr &conn) { if (conn->connected()) diff --git a/trantor/tests/MTLSServer.cc b/trantor/tests/MTLSServer.cc index 54a853e2..44e7f1f6 100644 --- a/trantor/tests/MTLSServer.cc +++ b/trantor/tests/MTLSServer.cc @@ -68,8 +68,13 @@ int main() // the CA file must be the client CA, for this sample the CA is common for // both - server.enableSSL( - "server-crt.pem", "server-key.pem", false, sslcmd, "ca-crt.pem"); + auto policy = + TLSPolicy::defaultServerPolicy("server-crt.pem", "server-key.pem"); + policy->setCaPath("ca-crt.pem") + .setValidateChain(true) + .setValidateDate(true) + .setValidateDomain(false); // client's don't have a domain name + server.enableSSL(policy); server.setRecvMessageCallback( [](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) { // LOG_DEBUG<<"recv callback!"; diff --git a/trantor/tests/PathConversionTest.cc b/trantor/tests/PathConversionTest.cc old mode 100755 new mode 100644 diff --git a/trantor/tests/RunInLoopTest2.cc b/trantor/tests/RunInLoopTest2.cc index ce969a37..0b68e845 100644 --- a/trantor/tests/RunInLoopTest2.cc +++ b/trantor/tests/RunInLoopTest2.cc @@ -8,8 +8,8 @@ int main() { - std::atomic counter; - counter = 0; + // Local variable to be used within the loopThread + uint64_t counter = 0; std::promise pro; auto ft = pro.get_future(); trantor::EventLoopThread loopThread; @@ -20,7 +20,7 @@ int main() { loop->queueInLoop([&counter, &pro]() { ++counter; - if (counter.load() == 110000) + if (counter == 110000) pro.set_value(1); }); } @@ -32,7 +32,7 @@ int main() { loop->runInLoop([&counter, &pro]() { ++counter; - if (counter.load() == 110000) + if (counter == 110000) pro.set_value(1); }); } @@ -40,5 +40,5 @@ int main() } loopThread.run(); ft.get(); - std::cout << "counter=" << counter.load() << std::endl; + std::cout << "counter=" << counter << std::endl; } diff --git a/trantor/tests/SSLClientTest.cc b/trantor/tests/SSLClientTest.cc index f38a0f3d..397f92d1 100644 --- a/trantor/tests/SSLClientTest.cc +++ b/trantor/tests/SSLClientTest.cc @@ -24,15 +24,15 @@ int main() client[i] = std::make_shared(&loop, serverAddr, "tcpclienttest"); - client[i]->enableSSL(false, false); + auto policy = TLSPolicy::defaultClientPolicy(); + policy->setValidate(false); + client[i]->enableSSL(std::move(policy)); client[i]->setConnectionCallback( [i, &loop, &connCount](const TcpConnectionPtr &conn) { if (conn->connected()) { LOG_DEBUG << i << " connected!"; - char tmp[20]; - sprintf(tmp, "%d client!!", i); - conn->send(tmp); + conn->send(std::to_string(i) + " client!!"); } else { diff --git a/trantor/tests/SSLServerTest.cc b/trantor/tests/SSLServerTest.cc index 3cf52132..28a6c33c 100644 --- a/trantor/tests/SSLServerTest.cc +++ b/trantor/tests/SSLServerTest.cc @@ -17,14 +17,15 @@ int main() InetAddress addr(8888); #endif TcpServer server(loopThread.getLoop(), addr, "test"); - server.enableSSL("server.pem", "server.pem"); + auto policy = TLSPolicy::defaultServerPolicy("server.crt", "server.key"); + server.enableSSL(std::move(policy)); server.setRecvMessageCallback( [](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) { // LOG_DEBUG<<"recv callback!"; std::cout << std::string(buffer->peek(), buffer->readableBytes()); connectionPtr->send(buffer->peek(), buffer->readableBytes()); buffer->retrieveAll(); - connectionPtr->forceClose(); + // connectionPtr->forceClose(); }); server.setConnectionCallback([](const TcpConnectionPtr &connPtr) { if (connPtr->connected()) diff --git a/trantor/tests/SendfileTest.cc b/trantor/tests/SendfileTest.cc index c0abb5b5..2b158088 100644 --- a/trantor/tests/SendfileTest.cc +++ b/trantor/tests/SendfileTest.cc @@ -63,10 +63,10 @@ int main(int argc, char *argv[]) for (int i = 0; i < 5; ++i) { connPtr->sendFile(argv[1]); - char str[64]; ++counter; - sprintf(str, "\n%d files sent!\n", counter); - connPtr->send(str, strlen(str)); + std::string str = + "\n" + std::to_string(counter) + " files sent!\n"; + connPtr->send(std::move(str)); } }); t.detach(); @@ -74,10 +74,10 @@ int main(int argc, char *argv[]) for (int i = 0; i < 3; ++i) { connPtr->sendFile(argv[1]); - char str[64]; ++counter; - sprintf(str, "\n%d files sent!\n", counter); - connPtr->send(str, strlen(str)); + std::string str = + "\n" + std::to_string(counter) + " files sent!\n"; + connPtr->send(std::move(str)); } } else if (connPtr->disconnected()) diff --git a/trantor/tests/SendstreamTest.cc b/trantor/tests/SendstreamTest.cc index 144adc71..49eacb1b 100644 --- a/trantor/tests/SendstreamTest.cc +++ b/trantor/tests/SendstreamTest.cc @@ -79,10 +79,10 @@ int main(int argc, char *argv[]) std::placeholders::_1, std::placeholders::_2); connPtr->sendStream(callback); - char str[64]; ++counter; - sprintf(str, "\n%d streams sent!\n", counter); - connPtr->send(str, strlen(str)); + std::string str = + "\n" + std::to_string(counter) + " streams sent!\n"; + connPtr->send(std::move(str)); } }); t.detach(); @@ -101,10 +101,10 @@ int main(int argc, char *argv[]) std::placeholders::_1, std::placeholders::_2); connPtr->sendStream(callback); - char str[64]; ++counter; - sprintf(str, "\n%d streams sent!\n", counter); - connPtr->send(str, strlen(str)); + std::string str = + "\n" + std::to_string(counter) + " streams sent!\n"; + connPtr->send(std::move(str)); } } else if (connPtr->disconnected()) diff --git a/trantor/tests/SpdLoggerTest.cc b/trantor/tests/SpdLoggerTest.cc new file mode 100644 index 00000000..c2053202 --- /dev/null +++ b/trantor/tests/SpdLoggerTest.cc @@ -0,0 +1,1145 @@ +#include +#include +#include +#include + +int main() +{ + trantor::Logger::enableSpdLog(); + trantor::Logger::enableSpdLog(5); + int i; + LOG_COMPACT_DEBUG << "Hello, world!"; + LOG_DEBUG << (float)3.14; + LOG_DEBUG << (const char)'8'; + LOG_DEBUG << &i; + LOG_DEBUG << (long double)3.1415; + LOG_DEBUG << trantor::Fmt("%.3g", 3.1415926); + LOG_DEBUG << "debug log!" << 1; + LOG_TRACE << "trace log!" << 2; + LOG_INFO << "info log!" << 3; + LOG_WARN << "warning log!" << 4; + if (1) + LOG_ERROR << "error log!" << 5; + std::thread thread_([]() { LOG_FATAL << "fatal log!" << 6; }); + + FILE *fp = fopen("/notexistfile", "rb"); + if (fp == NULL) + { + LOG_SYSERR << "syserr log!" << 7; + } + LOG_DEBUG << "long message test:"; + LOG_DEBUG + << "Applications\n" + "Developer\n" + "Library\n" + "Network\n" + "System\n" + "Users\n" + "Volumes\n" + "bin\n" + "cores\n" + "dev\n" + "etc\n" + "home\n" + "installer.failurerequests\n" + "net\n" + "opt\n" + "private\n" + "sbin\n" + "tmp\n" + "usr\n" + "var\n" + "vm\n" + "\n" + "/Applications:\n" + "Adobe\n" + "Adobe Creative Cloud\n" + "Adobe Photoshop CC\n" + "AirPlayer Pro.app\n" + "AliWangwang.app\n" + "Android Studio.app\n" + "App Store.app\n" + "Autodesk\n" + "Automator.app\n" + "Axure RP Pro 7.0.app\n" + "BaiduNetdisk_mac.app\n" + "CLion.app\n" + "Calculator.app\n" + "Calendar.app\n" + "Chess.app\n" + "CleanApp.app\n" + "Cocos\n" + "Contacts.app\n" + "DVD Player.app\n" + "Dashboard.app\n" + "Dictionary.app\n" + "Docs for Xcode.app\n" + "FaceTime.app\n" + "FinalShell\n" + "Firefox.app\n" + "Font Book.app\n" + "GitHub.app\n" + "Google Chrome.app\n" + "Image Capture.app\n" + "Lantern.app\n" + "Launchpad.app\n" + "License.rtf\n" + "MacPorts\n" + "Mail.app\n" + "Maps.app\n" + "Messages.app\n" + "Microsoft Excel.app\n" + "Microsoft Office 2011\n" + "Microsoft OneNote.app\n" + "Microsoft Outlook.app\n" + "Microsoft PowerPoint.app\n" + "Microsoft Word.app\n" + "Mindjet MindManager.app\n" + "Mission Control.app\n" + "Mockplus.app\n" + "MyEclipse 2015\n" + "Notes.app\n" + "Numbers.app\n" + "OmniGraffle.app\n" + "Pages.app\n" + "Photo Booth.app\n" + "Photos.app\n" + "Preview.app\n" + "QJVPN.app\n" + "QQ.app\n" + "QQMusic.app\n" + "QuickTime Player.app\n" + "RAR Extractor Lite.app\n" + "Reminders.app\n" + "Remote Desktop Connection.app\n" + "Renee Undeleter.app\n" + "Sabaki.app\n" + "Safari.app\n" + "ShadowsocksX.app\n" + "Siri.app\n" + "SogouCharacterViewer.app\n" + "SogouInputPad.app\n" + "Stickies.app\n" + "SupremePlayer Lite.app\n" + "System Preferences.app\n" + "TeX\n" + "Telegram.app\n" + "Telnet Lite.app\n" + "Termius.app\n" + "Tesumego - How to Make a Professional Go Player.app\n" + "TextEdit.app\n" + "Thunder.app\n" + "Time Machine.app\n" + "Tunnelblick.app\n" + "Utilities\n" + "VPN Shield.appdownload\n" + "WeChat.app\n" + "WinOnX2.app\n" + "Wireshark.app\n" + "Xcode.app\n" + "Yose.app\n" + "YoudaoNote.localized\n" + "finalshelldata\n" + "iBooks.app\n" + "iHex.app\n" + "iPhoto.app\n" + "iTools.app\n" + "iTunes.app\n" + "pgAdmin 4.app\n" + "vSSH Lite.app\n" + "wechatwebdevtools.app\n" + "\n" + "/Applications/Adobe:\n" + "Flash Player\n" + "\n" + "/Applications/Adobe/Flash Player:\n" + "AddIns\n" + "\n" + "/Applications/Adobe/Flash Player/AddIns:\n" + "airappinstaller\n" + "\n" + "/Applications/Adobe/Flash Player/AddIns/airappinstaller:\n" + "airappinstaller\n" + "digest.s\n" + "\n" + "/Applications/Adobe Creative Cloud:\n" + "Adobe Creative Cloud\n" + "Icon\n" + "Uninstall Adobe Creative Cloud\n" + "\n" + "/Applications/Adobe Photoshop CC:\n" + "Adobe Photoshop CC.app\n" + "Configuration\n" + "Icon\n" + "Legal\n" + "LegalNotices.pdf\n" + "Locales\n" + "Plug-ins\n" + "Presets\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop CC.app:\n" + "Contents\n" + "Linguistics\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop CC.app/Contents:\n" + "Application Data\n" + "Frameworks\n" + "Info.plist\n" + "MacOS\n" + "PkgInfo\n" + "Required\n" + "Resources\n" + "_CodeSignature\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data:\n" + "Custom File Info Panels\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info Panels:\n" + "4.0\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info Panels/4.0:\n" + "bin\n" + "custom\n" + "panels\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/bin:\n" + "FileInfoFoundation.swf\n" + "FileInfoUI.swf\n" + "framework.swf\n" + "loc\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/bin/loc:\n" + "FileInfo_ar_AE.dat\n" + "FileInfo_bg_BG.dat\n" + "FileInfo_cs_CZ.dat\n" + "FileInfo_da_DK.dat\n" + "FileInfo_de_DE.dat\n" + "FileInfo_el_GR.dat\n" + "FileInfo_en_US.dat\n" + "FileInfo_es_ES.dat\n" + "FileInfo_et_EE.dat\n" + "FileInfo_fi_FI.dat\n" + "FileInfo_fr_FR.dat\n" + "FileInfo_he_IL.dat\n" + "FileInfo_hr_HR.dat\n" + "FileInfo_hu_HU.dat\n" + "FileInfo_it_IT.dat\n" + "FileInfo_ja_JP.dat\n" + "FileInfo_ko_KR.dat\n" + "FileInfo_lt_LT.dat\n" + "FileInfo_lv_LV.dat\n" + "FileInfo_nb_NO.dat\n" + "FileInfo_nl_NL.dat\n" + "FileInfo_pl_PL.dat\n" + "FileInfo_pt_BR.dat\n" + "FileInfo_ro_RO.dat\n" + "FileInfo_ru_RU.dat\n" + "FileInfo_sk_SK.dat\n" + "FileInfo_sl_SI.dat\n" + "FileInfo_sv_SE.dat\n" + "FileInfo_tr_TR.dat\n" + "FileInfo_uk_UA.dat\n" + "FileInfo_zh_CN.dat\n" + "FileInfo_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/custom:\n" + "DICOM.xml\n" + "Mobile.xml\n" + "loc\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/custom/loc:\n" + "DICOM_ar_AE.dat\n" + "DICOM_bg_BG.dat\n" + "DICOM_cs_CZ.dat\n" + "DICOM_da_DK.dat\n" + "DICOM_de_DE.dat\n" + "DICOM_el_GR.dat\n" + "DICOM_en_US.dat\n" + "DICOM_es_ES.dat\n" + "DICOM_et_EE.dat\n" + "DICOM_fi_FI.dat\n" + "DICOM_fr_FR.dat\n" + "DICOM_he_IL.dat\n" + "DICOM_hr_HR.dat\n" + "DICOM_hu_HU.dat\n" + "DICOM_it_IT.dat\n" + "DICOM_ja_JP.dat\n" + "DICOM_ko_KR.dat\n" + "DICOM_lt_LT.dat\n" + "DICOM_lv_LV.dat\n" + "DICOM_nb_NO.dat\n" + "DICOM_nl_NL.dat\n" + "DICOM_pl_PL.dat\n" + "DICOM_pt_BR.dat\n" + "DICOM_ro_RO.dat\n" + "DICOM_ru_RU.dat\n" + "DICOM_sk_SK.dat\n" + "DICOM_sl_SI.dat\n" + "DICOM_sv_SE.dat\n" + "DICOM_tr_TR.dat\n" + "DICOM_uk_UA.dat\n" + "DICOM_zh_CN.dat\n" + "DICOM_zh_TW.dat\n" + "Mobile_ar_AE.dat\n" + "Mobile_bg_BG.dat\n" + "Mobile_cs_CZ.dat\n" + "Mobile_da_DK.dat\n" + "Mobile_de_DE.dat\n" + "Mobile_el_GR.dat\n" + "Mobile_en_US.dat\n" + "Mobile_es_ES.dat\n" + "Mobile_et_EE.dat\n" + "Mobile_fi_FI.dat\n" + "Mobile_fr_FR.dat\n" + "Mobile_he_IL.dat\n" + "Mobile_hr_HR.dat\n" + "Mobile_hu_HU.dat\n" + "Mobile_it_IT.dat\n" + "Mobile_ja_JP.dat\n" + "Mobile_ko_KR.dat\n" + "Mobile_lt_LT.dat\n" + "Mobile_lv_LV.dat\n" + "Mobile_nb_NO.dat\n" + "Mobile_nl_NL.dat\n" + "Mobile_pl_PL.dat\n" + "Mobile_pt_BR.dat\n" + "Mobile_ro_RO.dat\n" + "Mobile_ru_RU.dat\n" + "Mobile_sk_SK.dat\n" + "Mobile_sl_SI.dat\n" + "Mobile_sv_SE.dat\n" + "Mobile_tr_TR.dat\n" + "Mobile_uk_UA.dat\n" + "Mobile_zh_CN.dat\n" + "Mobile_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels:\n" + "IPTC\n" + "IPTCExt\n" + "advanced\n" + "audioData\n" + "camera\n" + "categories\n" + "description\n" + "dicom\n" + "gpsData\n" + "history\n" + "mobile\n" + "origin\n" + "rawpacket\n" + "videoData\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTC:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTC/bin:\n" + "iptc.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTC/loc:\n" + "IPTC_ar_AE.dat\n" + "IPTC_bg_BG.dat\n" + "IPTC_cs_CZ.dat\n" + "IPTC_da_DK.dat\n" + "IPTC_de_DE.dat\n" + "IPTC_el_GR.dat\n" + "IPTC_en_US.dat\n" + "IPTC_es_ES.dat\n" + "IPTC_et_EE.dat\n" + "IPTC_fi_FI.dat\n" + "IPTC_fr_FR.dat\n" + "IPTC_he_IL.dat\n" + "IPTC_hr_HR.dat\n" + "IPTC_hu_HU.dat\n" + "IPTC_it_IT.dat\n" + "IPTC_ja_JP.dat\n" + "IPTC_ko_KR.dat\n" + "IPTC_lt_LT.dat\n" + "IPTC_lv_LV.dat\n" + "IPTC_nb_NO.dat\n" + "IPTC_nl_NL.dat\n" + "IPTC_pl_PL.dat\n" + "IPTC_pt_BR.dat\n" + "IPTC_ro_RO.dat\n" + "IPTC_ru_RU.dat\n" + "IPTC_sk_SK.dat\n" + "IPTC_sl_SI.dat\n" + "IPTC_sv_SE.dat\n" + "IPTC_tr_TR.dat\n" + "IPTC_uk_UA.dat\n" + "IPTC_zh_CN.dat\n" + "IPTC_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTCExt:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTCExt/bin:\n" + "iptcExt.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/IPTCExt/loc:\n" + "IPTCExt_bg_BG.dat\n" + "IPTCExt_cs_CZ.dat\n" + "IPTCExt_da_DK.dat\n" + "IPTCExt_de_DE.dat\n" + "IPTCExt_en_US.dat\n" + "IPTCExt_es_ES.dat\n" + "IPTCExt_et_EE.dat\n" + "IPTCExt_fi_FI.dat\n" + "IPTCExt_fr_FR.dat\n" + "IPTCExt_hr_HR.dat\n" + "IPTCExt_hu_HU.dat\n" + "IPTCExt_it_IT.dat\n" + "IPTCExt_ja_JP.dat\n" + "IPTCExt_ko_KR.dat\n" + "IPTCExt_lt_LT.dat\n" + "IPTCExt_lv_LV.dat\n" + "IPTCExt_nb_NO.dat\n" + "IPTCExt_nl_NL.dat\n" + "IPTCExt_pl_PL.dat\n" + "IPTCExt_pt_BR.dat\n" + "IPTCExt_ro_RO.dat\n" + "IPTCExt_ru_RU.dat\n" + "IPTCExt_sk_SK.dat\n" + "IPTCExt_sl_SI.dat\n" + "IPTCExt_sv_SE.dat\n" + "IPTCExt_tr_TR.dat\n" + "IPTCExt_uk_UA.dat\n" + "IPTCExt_zh_CN.dat\n" + "IPTCExt_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/advanced:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/advanced/bin:\n" + "advanced.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/advanced/loc:\n" + "Advanced_ar_AE.dat\n" + "Advanced_bg_BG.dat\n" + "Advanced_cs_CZ.dat\n" + "Advanced_da_DK.dat\n" + "Advanced_de_DE.dat\n" + "Advanced_el_GR.dat\n" + "Advanced_en_US.dat\n" + "Advanced_es_ES.dat\n" + "Advanced_et_EE.dat\n" + "Advanced_fi_FI.dat\n" + "Advanced_fr_FR.dat\n" + "Advanced_he_IL.dat\n" + "Advanced_hr_HR.dat\n" + "Advanced_hu_HU.dat\n" + "Advanced_it_IT.dat\n" + "Advanced_ja_JP.dat\n" + "Advanced_ko_KR.dat\n" + "Advanced_lt_LT.dat\n" + "Advanced_lv_LV.dat\n" + "Advanced_nb_NO.dat\n" + "Advanced_nl_NL.dat\n" + "Advanced_pl_PL.dat\n" + "Advanced_pt_BR.dat\n" + "Advanced_ro_RO.dat\n" + "Advanced_ru_RU.dat\n" + "Advanced_sk_SK.dat\n" + "Advanced_sl_SI.dat\n" + "Advanced_sv_SE.dat\n" + "Advanced_tr_TR.dat\n" + "Advanced_uk_UA.dat\n" + "Advanced_zh_CN.dat\n" + "Advanced_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/audioData:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/audioData/bin:\n" + "audioData.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/audioData/loc:\n" + "AudioData_ar_AE.dat\n" + "AudioData_bg_BG.dat\n" + "AudioData_cs_CZ.dat\n" + "AudioData_da_DK.dat\n" + "AudioData_de_DE.dat\n" + "AudioData_el_GR.dat\n" + "AudioData_en_US.dat\n" + "AudioData_es_ES.dat\n" + "AudioData_et_EE.dat\n" + "AudioData_fi_FI.dat\n" + "AudioData_fr_FR.dat\n" + "AudioData_he_IL.dat\n" + "AudioData_hr_HR.dat\n" + "AudioData_hu_HU.dat\n" + "AudioData_it_IT.dat\n" + "AudioData_ja_JP.dat\n" + "AudioData_ko_KR.dat\n" + "AudioData_lt_LT.dat\n" + "AudioData_lv_LV.dat\n" + "AudioData_nb_NO.dat\n" + "AudioData_nl_NL.dat\n" + "AudioData_pl_PL.dat\n" + "AudioData_pt_BR.dat\n" + "AudioData_ro_RO.dat\n" + "AudioData_ru_RU.dat\n" + "AudioData_sk_SK.dat\n" + "AudioData_sl_SI.dat\n" + "AudioData_sv_SE.dat\n" + "AudioData_tr_TR.dat\n" + "AudioData_uk_UA.dat\n" + "AudioData_zh_CN.dat\n" + "AudioData_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/camera:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/camera/bin:\n" + "camera.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/camera/loc:\n" + "Camera_ar_AE.dat\n" + "Camera_bg_BG.dat\n" + "Camera_cs_CZ.dat\n" + "Camera_da_DK.dat\n" + "Camera_de_DE.dat\n" + "Camera_el_GR.dat\n" + "Camera_en_US.dat\n" + "Camera_es_ES.dat\n" + "Camera_et_EE.dat\n" + "Camera_fi_FI.dat\n" + "Camera_fr_FR.dat\n" + "Camera_he_IL.dat\n" + "Camera_hr_HR.dat\n" + "Camera_hu_HU.dat\n" + "Camera_it_IT.dat\n" + "Camera_ja_JP.dat\n" + "Camera_ko_KR.dat\n" + "Camera_lt_LT.dat\n" + "Camera_lv_LV.dat\n" + "Camera_nb_NO.dat\n" + "Camera_nl_NL.dat\n" + "Camera_pl_PL.dat\n" + "Camera_pt_BR.dat\n" + "Camera_ro_RO.dat\n" + "Camera_ru_RU.dat\n" + "Camera_sk_SK.dat\n" + "Camera_sl_SI.dat\n" + "Camera_sv_SE.dat\n" + "Camera_tr_TR.dat\n" + "Camera_uk_UA.dat\n" + "Camera_zh_CN.dat\n" + "Camera_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/categories:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/categories/bin:\n" + "categories.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/categories/loc:\n" + "Categories_ar_AE.dat\n" + "Categories_bg_BG.dat\n" + "Categories_cs_CZ.dat\n" + "Categories_da_DK.dat\n" + "Categories_de_DE.dat\n" + "Categories_el_GR.dat\n" + "Categories_en_US.dat\n" + "Categories_es_ES.dat\n" + "Categories_et_EE.dat\n" + "Categories_fi_FI.dat\n" + "Categories_fr_FR.dat\n" + "Categories_he_IL.dat\n" + "Categories_hr_HR.dat\n" + "Categories_hu_HU.dat\n" + "Categories_it_IT.dat\n" + "Categories_ja_JP.dat\n" + "Categories_ko_KR.dat\n" + "Categories_lt_LT.dat\n" + "Categories_lv_LV.dat\n" + "Categories_nb_NO.dat\n" + "Categories_nl_NL.dat\n" + "Categories_pl_PL.dat\n" + "Categories_pt_BR.dat\n" + "Categories_ro_RO.dat\n" + "Categories_ru_RU.dat\n" + "Categories_sk_SK.dat\n" + "Categories_sl_SI.dat\n" + "Categories_sv_SE.dat\n" + "Categories_tr_TR.dat\n" + "Categories_uk_UA.dat\n" + "Categories_zh_CN.dat\n" + "Categories_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/description:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/description/bin:\n" + "description.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/description/loc:\n" + "description_ar_AE.dat\n" + "description_bg_BG.dat\n" + "description_cs_CZ.dat\n" + "description_da_DK.dat\n" + "description_de_DE.dat\n" + "description_el_GR.dat\n" + "description_en_US.dat\n" + "description_es_ES.dat\n" + "description_et_EE.dat\n" + "description_fi_FI.dat\n" + "description_fr_FR.dat\n" + "description_he_IL.dat\n" + "description_hr_HR.dat\n" + "description_hu_HU.dat\n" + "description_it_IT.dat\n" + "description_ja_JP.dat\n" + "description_ko_KR.dat\n" + "description_lt_LT.dat\n" + "description_lv_LV.dat\n" + "description_nb_NO.dat\n" + "description_nl_NL.dat\n" + "description_pl_PL.dat\n" + "description_pt_BR.dat\n" + "description_ro_RO.dat\n" + "description_ru_RU.dat\n" + "description_sk_SK.dat\n" + "description_sl_SI.dat\n" + "description_sv_SE.dat\n" + "description_tr_TR.dat\n" + "description_uk_UA.dat\n" + "description_zh_CN.dat\n" + "description_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/dicom:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/dicom/bin:\n" + "dicom.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/dicom/loc:\n" + "DICOM_ar_AE.dat\n" + "DICOM_bg_BG.dat\n" + "DICOM_cs_CZ.dat\n" + "DICOM_da_DK.dat\n" + "DICOM_de_DE.dat\n" + "DICOM_el_GR.dat\n" + "DICOM_en_US.dat\n" + "DICOM_es_ES.dat\n" + "DICOM_et_EE.dat\n" + "DICOM_fi_FI.dat\n" + "DICOM_fr_FR.dat\n" + "DICOM_he_IL.dat\n" + "DICOM_hr_HR.dat\n" + "DICOM_hu_HU.dat\n" + "DICOM_it_IT.dat\n" + "DICOM_ja_JP.dat\n" + "DICOM_ko_KR.dat\n" + "DICOM_lt_LT.dat\n" + "DICOM_lv_LV.dat\n" + "DICOM_nb_NO.dat\n" + "DICOM_nl_NL.dat\n" + "DICOM_pl_PL.dat\n" + "DICOM_pt_BR.dat\n" + "DICOM_ro_RO.dat\n" + "DICOM_ru_RU.dat\n" + "DICOM_sk_SK.dat\n" + "DICOM_sl_SI.dat\n" + "DICOM_sv_SE.dat\n" + "DICOM_tr_TR.dat\n" + "DICOM_uk_UA.dat\n" + "DICOM_zh_CN.dat\n" + "DICOM_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/gpsData:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/gpsData/bin:\n" + "gpsData.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/gpsData/loc:\n" + "GPSData_ar_AE.dat\n" + "GPSData_bg_BG.dat\n" + "GPSData_cs_CZ.dat\n" + "GPSData_da_DK.dat\n" + "GPSData_de_DE.dat\n" + "GPSData_el_GR.dat\n" + "GPSData_en_US.dat\n" + "GPSData_es_ES.dat\n" + "GPSData_et_EE.dat\n" + "GPSData_fi_FI.dat\n" + "GPSData_fr_FR.dat\n" + "GPSData_he_IL.dat\n" + "GPSData_hr_HR.dat\n" + "GPSData_hu_HU.dat\n" + "GPSData_it_IT.dat\n" + "GPSData_ja_JP.dat\n" + "GPSData_ko_KR.dat\n" + "GPSData_lt_LT.dat\n" + "GPSData_lv_LV.dat\n" + "GPSData_nb_NO.dat\n" + "GPSData_nl_NL.dat\n" + "GPSData_pl_PL.dat\n" + "GPSData_pt_BR.dat\n" + "GPSData_ro_RO.dat\n" + "GPSData_ru_RU.dat\n" + "GPSData_sk_SK.dat\n" + "GPSData_sl_SI.dat\n" + "GPSData_sv_SE.dat\n" + "GPSData_tr_TR.dat\n" + "GPSData_uk_UA.dat\n" + "GPSData_zh_CN.dat\n" + "GPSData_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/history:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/history/bin:\n" + "history.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/history/loc:\n" + "History_ar_AE.dat\n" + "History_bg_BG.dat\n" + "History_cs_CZ.dat\n" + "History_da_DK.dat\n" + "History_de_DE.dat\n" + "History_el_GR.dat\n" + "History_en_US.dat\n" + "History_es_ES.dat\n" + "History_et_EE.dat\n" + "History_fi_FI.dat\n" + "History_fr_FR.dat\n" + "History_he_IL.dat\n" + "History_hr_HR.dat\n" + "History_hu_HU.dat\n" + "History_it_IT.dat\n" + "History_ja_JP.dat\n" + "History_ko_KR.dat\n" + "History_lt_LT.dat\n" + "History_lv_LV.dat\n" + "History_nb_NO.dat\n" + "History_nl_NL.dat\n" + "History_pl_PL.dat\n" + "History_pt_BR.dat\n" + "History_ro_RO.dat\n" + "History_ru_RU.dat\n" + "History_sk_SK.dat\n" + "History_sl_SI.dat\n" + "History_sv_SE.dat\n" + "History_tr_TR.dat\n" + "History_uk_UA.dat\n" + "History_zh_CN.dat\n" + "History_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/mobile:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/mobile/bin:\n" + "mobile.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/mobile/loc:\n" + "Mobile_ar_AE.dat\n" + "Mobile_bg_BG.dat\n" + "Mobile_cs_CZ.dat\n" + "Mobile_da_DK.dat\n" + "Mobile_de_DE.dat\n" + "Mobile_el_GR.dat\n" + "Mobile_en_US.dat\n" + "Mobile_es_ES.dat\n" + "Mobile_et_EE.dat\n" + "Mobile_fi_FI.dat\n" + "Mobile_fr_FR.dat\n" + "Mobile_he_IL.dat\n" + "Mobile_hr_HR.dat\n" + "Mobile_hu_HU.dat\n" + "Mobile_it_IT.dat\n" + "Mobile_ja_JP.dat\n" + "Mobile_ko_KR.dat\n" + "Mobile_lt_LT.dat\n" + "Mobile_lv_LV.dat\n" + "Mobile_nb_NO.dat\n" + "Mobile_nl_NL.dat\n" + "Mobile_pl_PL.dat\n" + "Mobile_pt_BR.dat\n" + "Mobile_ro_RO.dat\n" + "Mobile_ru_RU.dat\n" + "Mobile_sk_SK.dat\n" + "Mobile_sl_SI.dat\n" + "Mobile_sv_SE.dat\n" + "Mobile_tr_TR.dat\n" + "Mobile_uk_UA.dat\n" + "Mobile_zh_CN.dat\n" + "Mobile_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/origin:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/origin/bin:\n" + "origin.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/origin/loc:\n" + "origin_ar_AE.dat\n" + "origin_bg_BG.dat\n" + "origin_cs_CZ.dat\n" + "origin_da_DK.dat\n" + "origin_de_DE.dat\n" + "origin_el_GR.dat\n" + "origin_en_US.dat\n" + "origin_es_ES.dat\n" + "origin_et_EE.dat\n" + "origin_fi_FI.dat\n" + "origin_fr_FR.dat\n" + "origin_he_IL.dat\n" + "origin_hr_HR.dat\n" + "origin_hu_HU.dat\n" + "origin_it_IT.dat\n" + "origin_ja_JP.dat\n" + "origin_ko_KR.dat\n" + "origin_lt_LT.dat\n" + "origin_lv_LV.dat\n" + "origin_nb_NO.dat\n" + "origin_nl_NL.dat\n" + "origin_pl_PL.dat\n" + "origin_pt_BR.dat\n" + "origin_ro_RO.dat\n" + "origin_ru_RU.dat\n" + "origin_sk_SK.dat\n" + "origin_sl_SI.dat\n" + "origin_sv_SE.dat\n" + "origin_tr_TR.dat\n" + "origin_uk_UA.dat\n" + "origin_zh_CN.dat\n" + "origin_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/rawpacket:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/rawpacket/bin:\n" + "rawpacket.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/rawpacket/loc:\n" + "Rawpacket_ar_AE.dat\n" + "Rawpacket_bg_BG.dat\n" + "Rawpacket_cs_CZ.dat\n" + "Rawpacket_da_DK.dat\n" + "Rawpacket_de_DE.dat\n" + "Rawpacket_el_GR.dat\n" + "Rawpacket_en_US.dat\n" + "Rawpacket_es_ES.dat\n" + "Rawpacket_et_EE.dat\n" + "Rawpacket_fi_FI.dat\n" + "Rawpacket_fr_FR.dat\n" + "Rawpacket_he_IL.dat\n" + "Rawpacket_hr_HR.dat\n" + "Rawpacket_hu_HU.dat\n" + "Rawpacket_it_IT.dat\n" + "Rawpacket_ja_JP.dat\n" + "Rawpacket_ko_KR.dat\n" + "Rawpacket_lt_LT.dat\n" + "Rawpacket_lv_LV.dat\n" + "Rawpacket_nb_NO.dat\n" + "Rawpacket_nl_NL.dat\n" + "Rawpacket_pl_PL.dat\n" + "Rawpacket_pt_BR.dat\n" + "Rawpacket_ro_RO.dat\n" + "Rawpacket_ru_RU.dat\n" + "Rawpacket_sk_SK.dat\n" + "Rawpacket_sl_SI.dat\n" + "Rawpacket_sv_SE.dat\n" + "Rawpacket_tr_TR.dat\n" + "Rawpacket_uk_UA.dat\n" + "Rawpacket_zh_CN.dat\n" + "Rawpacket_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/videoData:\n" + "bin\n" + "loc\n" + "manifest.xml\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/videoData/bin:\n" + "videoData.swf\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Application Data/Custom File Info " + "Panels/4.0/panels/videoData/loc:\n" + "VideoData_ar_AE.dat\n" + "VideoData_bg_BG.dat\n" + "VideoData_cs_CZ.dat\n" + "VideoData_da_DK.dat\n" + "VideoData_de_DE.dat\n" + "VideoData_el_GR.dat\n" + "VideoData_en_US.dat\n" + "VideoData_es_ES.dat\n" + "VideoData_et_EE.dat\n" + "VideoData_fi_FI.dat\n" + "VideoData_fr_FR.dat\n" + "VideoData_he_IL.dat\n" + "VideoData_hr_HR.dat\n" + "VideoData_hu_HU.dat\n" + "VideoData_it_IT.dat\n" + "VideoData_ja_JP.dat\n" + "VideoData_ko_KR.dat\n" + "VideoData_lt_LT.dat\n" + "VideoData_lv_LV.dat\n" + "VideoData_nb_NO.dat\n" + "VideoData_nl_NL.dat\n" + "VideoData_pl_PL.dat\n" + "VideoData_pt_BR.dat\n" + "VideoData_ro_RO.dat\n" + "VideoData_ru_RU.dat\n" + "VideoData_sk_SK.dat\n" + "VideoData_sl_SI.dat\n" + "VideoData_sv_SE.dat\n" + "VideoData_tr_TR.dat\n" + "VideoData_uk_UA.dat\n" + "VideoData_zh_CN.dat\n" + "VideoData_zh_TW.dat\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks:\n" + "AdbeScriptUIFlex.framework\n" + "AdobeACE.framework\n" + "AdobeAGM.framework\n" + "AdobeAXE8SharedExpat.framework\n" + "AdobeAXEDOMCore.framework\n" + "AdobeBIB.framework\n" + "AdobeBIBUtils.framework\n" + "AdobeCoolType.framework\n" + "AdobeCrashReporter.framework\n" + "AdobeExtendScript.framework\n" + "AdobeLinguistic.framework\n" + "AdobeMPS.framework\n" + "AdobeOwl.framework\n" + "AdobePDFSettings.framework\n" + "AdobePIP.framework\n" + "AdobeScCore.framework\n" + "AdobeUpdater.framework\n" + "AdobeXMP.framework\n" + "AdobeXMPFiles.framework\n" + "AdobeXMPScript.framework\n" + "AlignmentLib.framework\n" + "CIT\n" + "CIT.framework\n" + "CITThreading.framework\n" + "Cg.framework\n" + "FileInfo.framework\n" + "ICUConverter.framework\n" + "ICUData.framework\n" + "IMSLib.dylib\n" + "LogSession.framework\n" + "PlugPlugOwl.framework\n" + "WRServices.framework\n" + "adbeape.framework\n" + "adobe_caps.framework\n" + "adobejp2k.framework\n" + "adobepdfl.framework\n" + "ahclient.framework\n" + "aif_core.framework\n" + "aif_ocl.framework\n" + "aif_ogl.framework\n" + "amtlib.framework\n" + "boost_date_time.framework\n" + "boost_signals.framework\n" + "boost_system.framework\n" + "boost_threads.framework\n" + "dvaaudiodevice.framework\n" + "dvacore.framework\n" + "dvamarshal.framework\n" + "dvamediatypes.framework\n" + "dvaplayer.framework\n" + "dvatransport.framework\n" + "dvaunittesting.framework\n" + "dynamiclink.framework\n" + "filter_graph.framework\n" + "libtbb.dylib\n" + "libtbbmalloc.dylib\n" + "mediacoreif.framework\n" + "patchmatch.framework\n" + "updaternotifications.framework\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework:\n" + "AdbeScriptUIFlex\n" + "Resources\n" + "Versions\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework/Versions:\n" + "A\n" + "Current\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework/Versions/A:\n" + "AdbeScriptUIFlex\n" + "CodeResources\n" + "_CodeSignature\n" + "resources\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework/Versions/A/" + "_CodeSignature:\n" + "CodeResources\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework/Versions/A/" + "resources:\n" + "Info.plist\n" + "english.lproj\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdbeScriptUIFlex.framework/Versions/A/" + "resources/english.lproj:\n" + "InfoPlist.strings\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdobeACE.framework:\n" + "AdobeACE\n" + "Versions\n" + "resources\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdobeACE.framework/Versions:\n" + "A\n" + "current\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdobeACE.framework/Versions/A:\n" + "AdobeACE\n" + "CodeResources\n" + "_CodeSignature\n" + "resources\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdobeACE.framework/Versions/A/" + "_CodeSignature:\n" + "CodeResources\n" + "\n" + "/Applications/Adobe Photoshop CC/Adobe Photoshop " + "CC.app/Contents/Frameworks/AdobeACE.framework/Versions/A/" + "resources:\n" + "Info.plist\n" + "english.lproj\n" + "" + << 123 << 123.123 << "haha" << '\n' + << std::string("12356"); + LOG_RAW << "Testing finished\n"; + LOG_RAW_TO(5) << "Testing finished\n"; + thread_.join(); + spdlog::shutdown(); +} diff --git a/trantor/tests/TcpAsyncStreamServerTest.cc b/trantor/tests/TcpAsyncStreamServerTest.cc new file mode 100644 index 00000000..d9761763 --- /dev/null +++ b/trantor/tests/TcpAsyncStreamServerTest.cc @@ -0,0 +1,53 @@ +#include +#include +#include +#include +#include +using namespace trantor; +#define USE_IPV6 0 +int main() +{ + LOG_DEBUG << "test start"; + Logger::setLogLevel(Logger::kTrace); + EventLoopThread loopThread; + loopThread.run(); +#if USE_IPV6 + InetAddress addr(8888, true, true); +#else + InetAddress addr(8888); +#endif + TcpServer server(loopThread.getLoop(), addr, "test"); + + server.setRecvMessageCallback( + [](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) { + // LOG_DEBUG<<"recv callback!"; + std::cout << std::string(buffer->peek(), buffer->readableBytes()); + connectionPtr->send(buffer->peek(), buffer->readableBytes()); + buffer->retrieveAll(); + // connectionPtr->forceClose(); + }); + server.setConnectionCallback([](const TcpConnectionPtr &connPtr) { + if (connPtr->connected()) + { + LOG_DEBUG << "New connection"; + auto stream = connPtr->sendAsyncStream(); + stream->send("hello world 1..."); + std::thread([stream = std::move(stream)] { + for (int i = 2; i < 10; i++) + { + std::this_thread::sleep_for(std::chrono::seconds(1)); + stream->send("hello world " + std::to_string(i) + "..."); + } + stream->close(); + }).detach(); + connPtr->send("hello world"); + } + else if (connPtr->disconnected()) + { + LOG_DEBUG << "connection disconnected"; + } + }); + server.setIoLoopNum(3); + server.start(); + loopThread.wait(); +} diff --git a/trantor/tests/TcpClientTest.cc b/trantor/tests/TcpClientTest.cc index 68d6f899..bbbb33f1 100644 --- a/trantor/tests/TcpClientTest.cc +++ b/trantor/tests/TcpClientTest.cc @@ -4,6 +4,13 @@ #include #include #include +#ifdef _WIN32 +#include +#else +#include +#include +#endif + using namespace trantor; #define USE_IPV6 0 int main() @@ -24,13 +31,35 @@ int main() client[i] = std::make_shared(&loop, serverAddr, "tcpclienttest"); + client[i]->setSockOptCallback([](int fd) { + LOG_DEBUG << "setSockOptCallback!"; +#ifdef _WIN32 +#elif __linux__ + int optval = 10; + ::setsockopt(fd, + SOL_TCP, + TCP_KEEPCNT, + &optval, + static_cast(sizeof optval)); + ::setsockopt(fd, + SOL_TCP, + TCP_KEEPIDLE, + &optval, + static_cast(sizeof optval)); + ::setsockopt(fd, + SOL_TCP, + TCP_KEEPINTVL, + &optval, + static_cast(sizeof optval)); +#else +#endif + }); client[i]->setConnectionCallback( [i, &loop, &connCount](const TcpConnectionPtr &conn) { if (conn->connected()) { LOG_DEBUG << i << " connected!"; - char tmp[20]; - sprintf(tmp, "%d client!!", i); + std::string tmp = std::to_string(i) + " client!!"; conn->send(tmp); } else diff --git a/trantor/tests/TcpServerTest.cc b/trantor/tests/TcpServerTest.cc index dbd8b7c3..e89b7f1b 100644 --- a/trantor/tests/TcpServerTest.cc +++ b/trantor/tests/TcpServerTest.cc @@ -17,6 +17,12 @@ int main() InetAddress addr(8888); #endif TcpServer server(loopThread.getLoop(), addr, "test"); + server.setBeforeListenSockOptCallback([](int fd) { + std::cout << "setBeforeListenSockOptCallback:" << fd << std::endl; + }); + server.setAfterAcceptSockOptCallback([](int fd) { + std::cout << "afterAcceptSockOptCallback:" << fd << std::endl; + }); server.setRecvMessageCallback( [](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) { // LOG_DEBUG<<"recv callback!"; diff --git a/trantor/tests/server.crt b/trantor/tests/server.crt new file mode 100644 index 00000000..e7ea6b77 --- /dev/null +++ b/trantor/tests/server.crt @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFrTCCA5WgAwIBAgIUDWd3k9cUXe271g10ep5vAZzvoL8wDQYJKoZIhvcNAQEL +BQAwYjELMAkGA1UEBhMCVFcxDzANBgNVBAgMBlRhaXBlaTEPMA0GA1UEBwwGVGFp +cGVpMQ8wDQYDVQQKDAZNYXJ0aW4xDzANBgNVBAsMBk1hcnRpbjEPMA0GA1UEAwwG +TWFydGluMCAXDTIzMDIwOTEyMzczNFoYDzIxMjMwMTE2MTIzNzM0WjBiMQswCQYD +VQQGEwJUVzEPMA0GA1UECAwGVGFpcGVpMQ8wDQYDVQQHDAZUYWlwZWkxDzANBgNV +BAoMBk1hcnRpbjEPMA0GA1UECwwGTWFydGluMQ8wDQYDVQQDDAZNYXJ0aW4wggIi +MA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQC/jq9cDERsVm6MrflAuoSYfdRf +dxLQlSUMLcX6w67RPvZHk71CBB8ycjStftcDrVMoXDPVZU8yboHhusmqeWGek8kD +1m0kKyOkMJ8IsWAobReoD/nNnNAMmnN785QqhhtUwfWEAIF9YJHoCkM9/Q/ysvTg +INtqtZauPi3TVo5019/UJTEKbixp33d8US1JzsPY8ZQ3WPSp3+8joPlXOgqIGZNi +9n1UWe5lq+/HODVHKG9o9aT6CSQzsKU02Uu/DlkNmK702dQUSnwbevy7qCr/Gms9 +MZ/dkY6IfTIPy8lb2iL31XCeMkh8/A9FA+uuInKurH7/DqsxbEZ3NcBXKnijd6+l +/Y2rwttJdEmMKjxVp3TE9b0cvT6MWVe21qI6Q+SqGYpK1sRNHm0OX5xZC2q3NTPi +i4T2IoN8628h80Xt8Pls5fig71t+Mw7/9m97W46WbruLmM945t+Q4rDijqU8gvVW +JYvSjRKO8RGKVj4drvqdzbJASwhyoadJWyD3QhkJhZ/2B1/7Ix153KNgnQ+q8yUB +4tIxeq3fRlvX5UCjHf/j6A1OCbHlvDUGiKBMt6vwKBhFoFj24Gf7o2rCx/jomduj +XcsX9maZbKnn7gvY+VcNVSydxSGsoBI3a8OBD5K0PrG59Qt/JNQtkAPe4gZP6KsW +vcRKC9nYxT9+2JIGaQIDAQABo1kwVzAUBgNVHREEDTALgglsb2NhbGhvc3QwCwYD +VR0PBAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMBMB0GA1UdDgQWBBRWDkJqcfm9 +wvz9w5tDpYTe0SIW0zANBgkqhkiG9w0BAQsFAAOCAgEArsjorf9iq2byicXC/tz+ +ml3ZZc20XjZfmM79yBwM2WpGJjEe+Rm1tZq7fWHO+gohXtQ7qX/5/RoQoBQaaAti +BVqYS+r9ab5Nqyo74wgFyRlCrvqhNgdqLm/itZuDjvAJSEMtQWLety5GhJPEFbOR +WbA4gHNSoh96OYs4s9p6PtJrnZMGflFEtvpXfdKi3Q9gK6/ib1f2zSNQ5vAYuBR4 +AzGDUvpEpRRu83OSlpYnPc3pITMyjL8B4EssOZBDYq4Yx2ZZO2xpjXM3Ns77g9Z7 +z/YqVDVAO8Fpw+byCBrtJtg9an+LnXuxji/MOYXkNnZqQJMonHuUZM2b3mPtQIgG +pRi0F4tSr3OLYNrffww8I3g5o9O19Bcw4mpNuI1tuLrlcrq67R3ddpi8wgxxzl92 +ghNknQJ/T2dpyBptdoh3U7qo/6pluBTH1tv3+FLDvcd7wTT2/lnvmmtG0te2fdvd +cucG+S8/I+OuH6OW0UuGHBdpB+vvMYbepLCri8FVs4x9/g8Wxpxujvd2fLmyKx+w +EGOIVjG5gIKwl0ohlCKOJZ02Oo7EgHoCB1EiHyNQ6zTq0N7P99uL23OSoeoSIPeU +WEPSP2YlVKKSB6+JMBFPFo2oj55mB8cu7y2rffp6G29CORA4FGPyA0So0vhpa32v +QTEL1RJLiGdqE980bxcP/+4= +-----END CERTIFICATE----- diff --git a/trantor/tests/server.key b/trantor/tests/server.key new file mode 100644 index 00000000..fe0bdabc --- /dev/null +++ b/trantor/tests/server.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQC/jq9cDERsVm6M +rflAuoSYfdRfdxLQlSUMLcX6w67RPvZHk71CBB8ycjStftcDrVMoXDPVZU8yboHh +usmqeWGek8kD1m0kKyOkMJ8IsWAobReoD/nNnNAMmnN785QqhhtUwfWEAIF9YJHo +CkM9/Q/ysvTgINtqtZauPi3TVo5019/UJTEKbixp33d8US1JzsPY8ZQ3WPSp3+8j +oPlXOgqIGZNi9n1UWe5lq+/HODVHKG9o9aT6CSQzsKU02Uu/DlkNmK702dQUSnwb +evy7qCr/Gms9MZ/dkY6IfTIPy8lb2iL31XCeMkh8/A9FA+uuInKurH7/DqsxbEZ3 +NcBXKnijd6+l/Y2rwttJdEmMKjxVp3TE9b0cvT6MWVe21qI6Q+SqGYpK1sRNHm0O +X5xZC2q3NTPii4T2IoN8628h80Xt8Pls5fig71t+Mw7/9m97W46WbruLmM945t+Q +4rDijqU8gvVWJYvSjRKO8RGKVj4drvqdzbJASwhyoadJWyD3QhkJhZ/2B1/7Ix15 +3KNgnQ+q8yUB4tIxeq3fRlvX5UCjHf/j6A1OCbHlvDUGiKBMt6vwKBhFoFj24Gf7 +o2rCx/jomdujXcsX9maZbKnn7gvY+VcNVSydxSGsoBI3a8OBD5K0PrG59Qt/JNQt +kAPe4gZP6KsWvcRKC9nYxT9+2JIGaQIDAQABAoICAEni4OvZxXiePATiQ/5+Ew/4 +lPZ/qM+wf3o7m542ZVNLfFYue7UffuMH3x6+inPeInGyYsHgUlRrAIkPcaLiL8+p +RENJLY7iXtyBbo49UJA3SAUoqFtxLWR3HK1GTjO6x4cBS1Bvm4K/QXgloTsjRcgA +0+gxdECsKyMpU6atP8R80dZzw/84cMQjkGRwsU3DRZKD1/4jPzfY6tYszJAjEJXf +e5ST69Kh34zy7UlD+nToeVScT1asOP0BGTAR0qAuihXu+yjxblanRkiZPyuo2XDN +gWi4n+eoMbiexbUHDzNxJ8S9XLOARKqE1OTzdrATlHWgjlmWEF0/XMy1fGuCs8X6 +I1TiKB2zbVT9vQ1aoTcGjHCJUDjWu5VU3xU/ARPz/IxE6okWdsYfAno+gWvLr3w9 +VqHbD5JsIhaKr03PIMblCvB2jYzrVOo5tWd+W5d/354a3zh8BA7n8dTHLeugJw0U +HP7kL+FgcOlHTx++Gm8Fl6my4YVRkKiIEEErurfNsvPJK5LlUoPTD5G/3LXUspyP +K/8v/RMSEREonmCPqIDJRbqB64egA9f8F0uCcIAUsbiG+TeSC91tGvfgj2CxJHNz +ieyPx9kJC6xWECEexQrD3YSUe26p3Bgf4RJ8FG5LEEhOQQm5jAI6QWd2LJFPWwW+ +XE1+CM4v8Y/n+txSOU07AoIBAQDyrIyXxfBNvCoRK88ZHmFvUXvDnWqQYdfp1nJl +kmy3exXscwiac+QehFPnsPJoj8+8D7lb/rA7oDWPNamlKQWwmyndGZp+Z8G5cXXN +T5dHUZY9tuAoyn8JAjDDTwsV1h6p2K5JuErHPhglRFp25h0Oc9GnogTGIB6zCBDZ +trA6yDUqNTrUxB+f9Ul4SE3k5pd3tw1AWtqCs3ra699NX01cKRFyL+bOWOmoScT4 +q3pyO5BufWB7Lv/8EG1wqr6gNxiSMVrKlKupC/B7UBVnXnX4W5nfPfSoUK/nCciW +Mgj/v82QqPP+i4j/VfI5Kt6zv0WFxDUyH1GtNqhOHJihtgzbAoIBAQDKE4z0NC+K +N1ZoTlAJKAYdzW1C3dJ7PgQq1qS0Z3q85CfYlNxv6ds7PPqao4+J7JeSTwCAtwFW +u3japhQ4B5wKiKYsvyE/ws+syWdOJ1x2E2ibn7FnuENvgFmYc15mlA1yCWYYkkpX +1KOYtfrjyGHtw8PlA6hXBsyxA/F9IzWpStl9htVgpJC7gXZKX/e4oxNfeQ0lyvbe +0h6GaSihl26z5NSKQCCs36ahX2TEYS1Nck5UYaKz/5Uaq5x7ex0HB5aEjL8d4UU/ +bVnRE20snAsoEf6jeBHeyOVPGw2bT/4zk7vuv6FC2KhcqNrbVhOe5LrdIvJkydiZ +InWK3xrYKzsLAoIBAQC9anVq2fNRmbd0I6/IuW/wBbgG3c4Z2GVBfkNYiMwXAxn/ +r2JdvGuobj1XsUPk3auV7OgPqGJCiDCGEarS4YwxZ0tr6touJCqP5sG+eYto/YO5 +tA6PiE9T5sPNDttmNfVFOX4AyLqFfjA2ln3OJJs1dq2EnPAA/X043OjaJsCzgSYO +RfIftN3CayDno/g43MwJg3Xyb3fzYMhaLJXlvKeTcfLOIBmVoszusHXwa1ht5ZQ8 +ydwPCoaAZwolUQDt6VNieOeXDChZEJqqhb3PK2oFaupV1/QplKFYQsiwg2mGxl1b +tqSMYLmUI6+nc5DU2E0ZtiaXct67xtfj8GoqfwDVAoIBAQCTLsY9oDz4GPIwqsmU +wbgiwNtSFqsV5Me4Q/pXA//b0PpMv7AHO3fYn8OQGo2T0eVcRXqCRckN2SJfbxPO +84vuCDWw5c1b2ZLVsSQzQmwP/Hb20suuVgGYFw4rAezCHhfk9X+NahAIBPLbacDB +Y9QgD7SA+7cDHAq+67ZahOiy07exvCFycKqSR+tWpKuTqgOUSGERI9HH3ZcqIzHa +8KdLE+LSh37FK2j8pLSKbJVIkXcH8s1E+WUqtdAWCEfONPKmvLT/GHMNjaIbrGCa +W1Ws695iRjQN5plOks/ITe1Ct9nsPVtBivil9L7jfsBvvP11z9xpGLNQZk7ixTmS +NXqdAoIBAQCw87D8u+2j0Gcy3xTRAgL6I50hAqrqv0LcclR+d6M4hDe/VssUJ721 +Xs+lVX8JDO6CjzISMZ5uRWFXWFFquPmNav6eZf9SXb9mEXy5qFaoSrU7JgHKXyKS +1F0fMkSv3E3s2Iw8qQkHCI6D6Fza5fWiXvDaJBRRbxcxSakMB3PB//my1/h6rZY/ +i8PatPOMBY180PC2RtkWWjN4XeH0Ra8j7NXdtvZaFLYVJVrLYjUhzX2e6TstyGbH +k1Ert/B0+RxXCBMj5n36fGSw9lBFh1UiiE3Pfra7JOTs0PJG40T90UfJnQ93NUlT +AyBDBVOh3ubATMFsBx5bwp26rFaTcdg5 +-----END PRIVATE KEY----- diff --git a/trantor/tests/server.pem b/trantor/tests/server.pem deleted file mode 100644 index 507f700b..00000000 --- a/trantor/tests/server.pem +++ /dev/null @@ -1,49 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/vLeZ4B80b22y -JSF9WoAEGo5D0GLlzSxL7rRZSd81h3YMKpsgFT4WkZxs9sqoRW/OTaqMGuEzGt6a -54RMwycQwJL+GGt/stZ+7AYHoLE6cQzP8iFcwy4mLFsX+gQjoE0UHr0RykKRl3Vw -qZST74AhAMlFvEDCfoMxTy9dy5ZjbLTFrH1kvo5CT9MEbyBLpxYrYYFjBAMIZdvV -82Mz3JftDLbx7t+96PaGLpDMLk9NgrV1kbUpzck/aCbZ5DgV7enTU33tVIcfMU5e -fZii044OkOvFNXibi6tdPORP8TBnrmEZjmDnFAgEijzpIcJ4VVPTrlCuxLUKcSDY -wYLro7gbAgMBAAECggEAesztcnoewkCTq0MovdZWo0o2z6wJi1DrC/7oNz+e2/PU -YVpwXA3+5AmCfC9cAIXoY+NOVclpbofJBsE89MUQoiQUgPU29GSgCE42VnBO0jVR -lWVohLblOcGy3hpcyEyE0VwWj+xQ0lqE9xFFfbIpB/ou7qDxgR/x+oTSu2oG+cmr -xv+Dj+k46QW0X2A3XHnlZkk0fQGb+Tf3fIJ4BBwF/UFQTDyZPZrsZfbpoEk1BnIf -VPcfM1/ca6h1t/oEnpD66H2TjxYzZC58mqZSZIZ+Pk+wocSsF5UcQLedcM137LDH -OR+fdscinr1xgMhV30RmAUP5tug9Km3oq8wf4Xl1IQKBgQDmDtYbwN7V1eA6Dbeq -HCk07t/llK+j4mvyuU+pVKfNHmnUUd+NPpDwREFrfoR6ilrZMJ4Obl7z53rja+9s -JqECBThhRzHxnUQgsu3DQeY7e2KxBBGGvhQBjRGtVRID3aFTTyMKvaOkTc7kL+x1 -5DbtoVAA1kbYnjgZJyZosx+bmQKBgQDVW6lXxoLzC7npXdXOTMJifExcCIzXskc9 -5b+rJTXgkqaPQSQ7nNmRReUY9r9UBsZKLduuYJZ3TLMCaU91/wBwc1ygkQvRz+8A -2QGMovlVOqwPAOEc1lMCXzQki9PcqmVX/e0oFmTtUgyF4RKEwL/q8A9fWy587rkQ -57MDNJ3h0wKBgQDNJgPFweq0EsGd4yeZuP0B59Wee0VYxgru6lLgM85iujEzFUNd -R6KlrqgLvElUoNW8gX8gbUmdBBlwfYqGDbhb/d212W/u/geHhSdCjBxLhI6QPYmH -dy6N54cQ4yBqdBNtH8+mv08SsBPDJf0db8GPi960sF+CwSxTObclfD2+WQKBgBgi -HxyLmsJNIEFSWN3V9uLW9ngui2fWhZJty2lbcyWs0ORBVQzdKArzof9Z4bhqb8Fy -QHgP+tURuunZ6aAKMQ2HLwIGhhS8dWdeJHu474UBdvbXfZ8aaxdIl4hOvK8oIwB5 -+3peVho1/q6iD8suVkcH0mVR1gdRpWNRIgGJ0RX7AoGBAMzp0u3X3SEMWTlqFR7D -BJqmpiMvJJJSqex1i4PJoy+uJpqfl3b1jUMnFIS4GiscVCIEJkPpZaz3eSTt7aVQ -+7v0b/Cusv1qYX2c5v9j9x7nmUlfdwT5rY2g6RXI/xlhPD9elb1dxlzmhiO4VQtr -AdiFmu0poyPpEGwjpcLEZq6K ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIIDajCCAlICCQCXmD7IkkDArzANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJD -TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEPMA0GA1UECgwG -YW4tdGFvMQ8wDQYDVQQLDAZhbi10YW8xIjAgBgkqhkiG9w0BCQEWE2FudGFvMjAw -MkBnbWFpbC5jb20wHhcNMTgwNzIxMDEyMjEyWhcNMTkwNzIxMDEyMjEyWjB3MQsw -CQYDVQQGEwJDTjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEP -MA0GA1UECgwGYW4tdGFvMQ8wDQYDVQQLDAZhbi10YW8xIjAgBgkqhkiG9w0BCQEW -E2FudGFvMjAwMkBnbWFpbC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK -AoIBAQC/vLeZ4B80b22yJSF9WoAEGo5D0GLlzSxL7rRZSd81h3YMKpsgFT4WkZxs -9sqoRW/OTaqMGuEzGt6a54RMwycQwJL+GGt/stZ+7AYHoLE6cQzP8iFcwy4mLFsX -+gQjoE0UHr0RykKRl3VwqZST74AhAMlFvEDCfoMxTy9dy5ZjbLTFrH1kvo5CT9ME -byBLpxYrYYFjBAMIZdvV82Mz3JftDLbx7t+96PaGLpDMLk9NgrV1kbUpzck/aCbZ -5DgV7enTU33tVIcfMU5efZii044OkOvFNXibi6tdPORP8TBnrmEZjmDnFAgEijzp -IcJ4VVPTrlCuxLUKcSDYwYLro7gbAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAClA -Z6JHessY1mo0ObCp8Sh83gIiGu9M77xz6Gy2vUpUGE9N1hJwYH3uZPLbpnIuxZU0 -uDXd9w5vCJmtqo/4X6sedqKA6Upi8DROawtCUJaGrHrb2vJONyL7A3OWpHyYgOAa -z+sl4jbrGGv4azSg9ef1mQZTDcEcdsYiKj9U4zoyDXteyt6NQK3x5BtziB0mJfeM -gjrltVcc59s8Z2jf3PI3qmWtMuJm7KGHlan2DUz12AZxY3sJrFyew3Vi1V2uvIZB -YwgcXbVa3r6PQU3RwFwkazxAe80ZgiQlVsEiBl8mre4zznt5XHF6mzAqvpHyXPCI -LfEO3waLkgB94mFha5A= ------END CERTIFICATE----- diff --git a/trantor/unittests/CMakeLists.txt b/trantor/unittests/CMakeLists.txt index 4814b566..cdbc9ff9 100644 --- a/trantor/unittests/CMakeLists.txt +++ b/trantor/unittests/CMakeLists.txt @@ -5,13 +5,16 @@ add_executable(date_unittest DateUnittest.cc) add_executable(split_string_unittest splitStringUnittest.cc) add_executable(string_encoding_unittest stringEncodingUnittest.cc) add_executable(ssl_name_verify_unittest sslNameVerifyUnittest.cc) +add_executable(hash_unittest HashUnittest.cc) set(UNITTEST_TARGETS msgbuffer_unittest inetaddress_unittest date_unittest split_string_unittest string_encoding_unittest - ssl_name_verify_unittest) + ssl_name_verify_unittest + hash_unittest +) set_property(TARGET ${UNITTEST_TARGETS} PROPERTY CXX_STANDARD 14) set_property(TARGET ${UNITTEST_TARGETS} PROPERTY CXX_STANDARD_REQUIRED ON) set_property(TARGET ${UNITTEST_TARGETS} PROPERTY CXX_EXTENSIONS OFF) diff --git a/trantor/unittests/DateUnittest.cc b/trantor/unittests/DateUnittest.cc index 255ce285..9d516a75 100644 --- a/trantor/unittests/DateUnittest.cc +++ b/trantor/unittests/DateUnittest.cc @@ -7,37 +7,113 @@ TEST(Date, constructorTest) { EXPECT_STREQ("1985-01-01 00:00:00", trantor::Date(1985, 1, 1) - .toCustomedFormattedStringLocal("%Y-%m-%d %H:%M:%S") + .toCustomFormattedStringLocal("%Y-%m-%d %H:%M:%S") .c_str()); EXPECT_STREQ("2004-02-29 00:00:00.000000", trantor::Date(2004, 2, 29) - .toCustomedFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) + .toCustomFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) .c_str()); EXPECT_STRNE("2001-02-29 00:00:00.000000", trantor::Date(2001, 2, 29) - .toCustomedFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) + .toCustomFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) .c_str()); EXPECT_STREQ("2018-01-01 00:00:00.000000", trantor::Date(2018, 1, 1, 12, 12, 12, 2321) .roundDay() - .toCustomedFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) + .toCustomFormattedStringLocal("%Y-%m-%d %H:%M:%S", true) .c_str()); } TEST(Date, DatabaseStringTest) { auto now = trantor::Date::now(); EXPECT_EQ(now, trantor::Date::fromDbStringLocal(now.toDbStringLocal())); + EXPECT_EQ(now, trantor::Date::fromDbString(now.toDbString())); std::string dbString = "2018-01-01 00:00:00.123"; auto dbDate = trantor::Date::fromDbStringLocal(dbString); auto ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; EXPECT_EQ(ms, 123); + EXPECT_EQ(dbDate, + trantor::Date::fromDbStringLocal(dbDate.toDbStringLocal())); + EXPECT_EQ(dbDate, trantor::Date::fromDbString(dbDate.toDbString())); + dbString = "2018-01-01 00:00:00.023"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; + EXPECT_EQ(ms, 23); + EXPECT_EQ(dbDate, + trantor::Date::fromDbStringLocal(dbDate.toDbStringLocal())); + EXPECT_EQ(dbDate, trantor::Date::fromDbString(dbDate.toDbString())); + dbString = "2018-01-01 00:00:00.003"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; + EXPECT_EQ(ms, 3); + EXPECT_EQ(dbDate, + trantor::Date::fromDbStringLocal(dbDate.toDbStringLocal())); + EXPECT_EQ(dbDate, trantor::Date::fromDbString(dbDate.toDbString())); + dbString = "2018-01-01 00:00:00.000123"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + auto us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 123); + EXPECT_EQ(dbDate, + trantor::Date::fromDbStringLocal(dbDate.toDbStringLocal())); + EXPECT_EQ(dbDate, trantor::Date::fromDbString(dbDate.toDbString())); + dbString = "2018-01-01 00:00:00.000023"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 23); + EXPECT_EQ(dbDate, + trantor::Date::fromDbStringLocal(dbDate.toDbStringLocal())); + EXPECT_EQ(dbDate, trantor::Date::fromDbString(dbDate.toDbString())); + dbString = "2018-01-01 00:00:00.000003"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 3); + dbString = "2018-01-01 00:00:00"; dbDate = trantor::Date::fromDbStringLocal(dbString); ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; EXPECT_EQ(ms, 0); + + dbString = "2018-01-01 00:00:00"; + dbDate = trantor::Date::fromDbStringLocal(dbString); + auto dbDateGMT = trantor::Date::fromDbString(dbString); + auto secLocal = (dbDate.microSecondsSinceEpoch() / 1000000); + auto secGMT = (dbDateGMT.microSecondsSinceEpoch() / 1000000); + // timeZone at least 1 minute (can be >=1 hour, 30 min, 15 min. Error if + // difference less then minute) + auto timeZoneOffsetMinutePart = (secLocal - secGMT) % 60; + EXPECT_EQ(timeZoneOffsetMinutePart, 0); + dbString = "2018-01-01 00:00:00.123"; + dbDate = trantor::Date::fromDbString(dbString); + ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; + EXPECT_EQ(ms, 123); + dbString = "2018-01-01 00:00:00.023"; + dbDate = trantor::Date::fromDbString(dbString); + ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; + EXPECT_EQ(ms, 23); + dbString = "2018-01-01 00:00:00.003"; + dbDate = trantor::Date::fromDbString(dbString); + ms = (dbDate.microSecondsSinceEpoch() % 1000000) / 1000; + EXPECT_EQ(ms, 3); + dbString = "2018-01-01 00:00:00.000123"; + dbDate = trantor::Date::fromDbString(dbString); + us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 123); + dbString = "2018-01-01 00:00:00.000023"; + dbDate = trantor::Date::fromDbString(dbString); + us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 23); + dbString = "2018-01-01 00:00:00.000003"; + dbDate = trantor::Date::fromDbString(dbString); + us = (dbDate.microSecondsSinceEpoch() % 1000000); + EXPECT_EQ(us, 3); + + dbString = "1970-01-01"; + dbDateGMT = trantor::Date::fromDbString(dbString); + auto epoch = dbDateGMT.microSecondsSinceEpoch(); + EXPECT_EQ(epoch, 0); } int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +} diff --git a/trantor/unittests/HashUnittest.cc b/trantor/unittests/HashUnittest.cc new file mode 100644 index 00000000..47313df9 --- /dev/null +++ b/trantor/unittests/HashUnittest.cc @@ -0,0 +1,58 @@ +#include + +#include + +#include +#include +using namespace trantor; +using namespace trantor::utils; + +TEST(Hash, MD5) +{ + EXPECT_EQ(toHexString(md5("hello")), "5D41402ABC4B2A76B9719D911017C592"); + EXPECT_EQ(toHexString(md5("trantor")), "95FC641C9E629D2854B0B60F5A51E1FD"); +} + +TEST(Hash, SHA1) +{ + EXPECT_EQ(toHexString(sha1("hello")), + "AAF4C61DDCC5E8A2DABEDE0F3B482CD9AEA9434D"); + EXPECT_EQ(toHexString(sha1("trantor")), + "A9E084054D439FCD87D2438FB5FE4DDD7D8CC204"); +} + +TEST(Hash, SHA256) +{ + EXPECT_EQ( + toHexString(sha256("hello")), + "2CF24DBA5FB0A30E26E83B2AC5B9E29E1B161E5C1FA7425E73043362938B9824"); + EXPECT_EQ( + toHexString(sha256("trantor")), + "C72002E712A3BA6D60125D4B3D0B816758FBDCA98F2A892077BD4182E71CF6F5"); +} + +TEST(Hash, SHA3) +{ + EXPECT_EQ( + toHexString(sha3("hello")), + "3338BE694F50C5F338814986CDF0686453A888B84F424D792AF4B9202398F392"); + EXPECT_EQ( + toHexString(sha3("trantor")), + "135E1D2372F0A48525E09D47C6FFCA14077D8C5A0905410FA81C30ED9AFF696A"); +} + +TEST(Hash, BLAKE2b) +{ + EXPECT_EQ( + toHexString(blake2b("hello")), + "324DCF027DD4A30A932C441F365A25E86B173DEFA4B8E58948253471B81B72CF"); + EXPECT_EQ( + toHexString(blake2b("trantor")), + "2D03B3D7E76C52DD7A32689ADE4406798B50BC5B09428E3F90F56182898873C8"); +} + +int main(int argc, char **argv) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/trantor/unittests/InetAddressUnittest.cc b/trantor/unittests/InetAddressUnittest.cc index 487ee37d..7bf57efe 100644 --- a/trantor/unittests/InetAddressUnittest.cc +++ b/trantor/unittests/InetAddressUnittest.cc @@ -15,8 +15,33 @@ TEST(InetAddress, innerIpTest) EXPECT_EQ(false, InetAddress("127.0.0.2", 0).isUnspecified()); EXPECT_EQ(false, InetAddress("0.0.0.0", 0).isUnspecified()); } +TEST(InetAddress, toIpPortNetEndianTest) +{ + EXPECT_EQ(std::string({char(192), char(168), 0, 1, 0, 80}), + InetAddress("192.168.0.1", 80).toIpPortNetEndian()); + EXPECT_EQ(std::string({0x20, + 0x01, + 0x0d, + char(0xb8), + 0x33, + 0x33, + 0x44, + 0x44, + 0x55, + 0x55, + 0x66, + 0x66, + 0x77, + 0x77, + char(0x88), + char(0x88), + 1, + char(187)}), + InetAddress("2001:0db8:3333:4444:5555:6666:7777:8888", 443, true) + .toIpPortNetEndian()); +} int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +} diff --git a/trantor/unittests/stringEncodingUnittest.cc b/trantor/unittests/stringEncodingUnittest.cc old mode 100755 new mode 100644 diff --git a/trantor/utils/AsyncFileLogger.cc b/trantor/utils/AsyncFileLogger.cc index 3b97dae0..8aa034d6 100755 --- a/trantor/utils/AsyncFileLogger.cc +++ b/trantor/utils/AsyncFileLogger.cc @@ -14,15 +14,18 @@ #include #include -#ifndef _WIN32 +#if !defined(_WIN32) || defined(__MINGW32__) #include +#include +#include #ifdef __linux__ #include #endif #else -#include +#include #endif #include +#include #include #include #include @@ -61,7 +64,7 @@ AsyncFileLogger::~AsyncFileLogger() } while (!writeBuffers_.empty()) { - StringPtr tmpPtr = (StringPtr &&) writeBuffers_.front(); + StringPtr tmpPtr = (StringPtr &&)writeBuffers_.front(); writeBuffers_.pop(); writeLogToFile(tmpPtr); } @@ -119,13 +122,17 @@ void AsyncFileLogger::writeLogToFile(const StringPtr buf) { if (!loggerFilePtr_) { - loggerFilePtr_ = std::unique_ptr( - new LoggerFile(filePath_, fileBaseName_, fileExtName_)); + loggerFilePtr_ = + std::unique_ptr(new LoggerFile(filePath_, + fileBaseName_, + fileExtName_, + switchOnLimitOnly_, + maxFiles_)); } loggerFilePtr_->writeLog(buf); if (loggerFilePtr_->getLength() > sizeLimit_) { - loggerFilePtr_.reset(); + loggerFilePtr_->switchLog(true); } } @@ -155,7 +162,7 @@ void AsyncFileLogger::logThreadFunc() while (!tmpBuffers_.empty()) { - StringPtr tmpPtr = (StringPtr &&) tmpBuffers_.front(); + StringPtr tmpPtr = (StringPtr &&)tmpBuffers_.front(); tmpBuffers_.pop(); writeLogToFile(tmpPtr); tmpPtr->clear(); @@ -177,13 +184,31 @@ void AsyncFileLogger::startLogging() AsyncFileLogger::LoggerFile::LoggerFile(const std::string &filePath, const std::string &fileBaseName, - const std::string &fileExtName) + const std::string &fileExtName, + bool switchOnLimitOnly, + size_t maxFiles) : creationDate_(Date::date()), filePath_(filePath), fileBaseName_(fileBaseName), - fileExtName_(fileExtName) + fileExtName_(fileExtName), + switchOnLimitOnly_(switchOnLimitOnly), + maxFiles_(maxFiles) +{ + open(); + + if (maxFiles_ > 0) + { + initFilenameQueue(); + } +} + +/** + * Open file for append logs + * Always write to file with base name. + */ +void AsyncFileLogger::LoggerFile::open() { - fileFullName_ = filePath + fileBaseName + fileExtName; + fileFullName_ = filePath_ + fileBaseName_ + fileExtName_; #ifndef _MSC_VER fp_ = fopen(fileFullName_.c_str(), "a"); #else @@ -222,20 +247,29 @@ uint64_t AsyncFileLogger::LoggerFile::getLength() return 0; } -AsyncFileLogger::LoggerFile::~LoggerFile() +/** + * Force store the current file (with the base name) + * with the newly generated name by adding time point. + * + * @param openNewOne - true for keeping log file opened and continuing logging + */ +void AsyncFileLogger::LoggerFile::switchLog(bool openNewOne) { if (fp_) { fclose(fp_); + fp_ = nullptr; + char seq[12]; snprintf(seq, sizeof(seq), ".%06llu", static_cast(fileSeq_ % 1000000)); ++fileSeq_; + // NOTE: Remember to update initFilenameQueue() if name format changes std::string newName = filePath_ + fileBaseName_ + "." + - creationDate_.toCustomedFormattedString("%y%m%d-%H%M%S") + + creationDate_.toCustomFormattedString("%y%m%d-%H%M%S") + std::string(seq) + fileExtName_; #if !defined(_WIN32) || defined(__MINGW32__) rename(fileFullName_.c_str(), newName.c_str()); @@ -245,6 +279,119 @@ AsyncFileLogger::LoggerFile::~LoggerFile() auto wNewName{utils::toNativePath(newName)}; _wrename(wFullName.c_str(), wNewName.c_str()); #endif + if (maxFiles_ > 0) + { + filenameQueue_.push_back(newName); + if (filenameQueue_.size() > maxFiles_) + { + deleteOldFiles(); + } + } + if (openNewOne) + open(); // continue logging with base name until next renaming will + // be required + } +} + +AsyncFileLogger::LoggerFile::~LoggerFile() +{ + if (!switchOnLimitOnly_) // rename on each destroy + switchLog(false); + if (fp_) + fclose(fp_); +} + +void AsyncFileLogger::LoggerFile::initFilenameQueue() +{ + if (maxFiles_ <= 0) + { + return; + } + + // walk through the directory and file all files +#if !defined(_WIN32) || defined(__MINGW32__) + DIR *dp; + struct dirent *dirp; + struct stat st; + + if ((dp = opendir(filePath_.c_str())) == nullptr) + { + fprintf(stderr, + "Can't open dir %s: %s\n", + filePath_.c_str(), + strerror_tl(errno)); + return; + } + + while ((dirp = readdir(dp)) != nullptr) + { + std::string name = dirp->d_name; + // .yymmdd-hhmmss.000000 + // NOTE: magic number 21: the length of middle part of generated name + if (name.size() != fileBaseName_.size() + 21 + fileExtName_.size() || + name.compare(0, fileBaseName_.size(), fileBaseName_) != 0 || + name.compare(name.size() - fileExtName_.size(), + fileExtName_.size(), + fileExtName_) != 0) + { + continue; + } + std::string fullname = filePath_ + name; + if (stat(fullname.c_str(), &st) == -1) + { + fprintf(stderr, + "Can't stat file %s: %s\n", + fullname.c_str(), + strerror_tl(errno)); + continue; + } + if (!S_ISREG(st.st_mode)) + { + continue; + } + filenameQueue_.push_back(fullname); + std::push_heap(filenameQueue_.begin(), + filenameQueue_.end(), + std::greater<>()); + if (filenameQueue_.size() > maxFiles_) + { + std::pop_heap(filenameQueue_.begin(), + filenameQueue_.end(), + std::greater<>()); + auto fileToRemove = std::move(filenameQueue_.back()); + filenameQueue_.pop_back(); + remove(fileToRemove.c_str()); + } + } + closedir(dp); +#else + // TODO: windows implementation +#endif + + std::sort(filenameQueue_.begin(), filenameQueue_.end(), std::less<>()); +} + +void AsyncFileLogger::LoggerFile::deleteOldFiles() +{ + while (filenameQueue_.size() > maxFiles_) + { + std::string filename = std::move(filenameQueue_.front()); + filenameQueue_.pop_front(); + +#if !defined(_WIN32) || defined(__MINGW32__) + int r = remove(filename.c_str()); +#else + // Convert UTF-8 file to UCS-2 + auto wName{utils::toNativePath(filename)}; + int r = _wremove(wName.c_str()); +#endif + if (r != 0) + { + fprintf(stderr, + "Failed to remove file %s: %s\n", + filename.c_str(), + strerror_tl(errno)); + } } } diff --git a/trantor/utils/AsyncFileLogger.h b/trantor/utils/AsyncFileLogger.h index 0f54c679..af0952f4 100644 --- a/trantor/utils/AsyncFileLogger.h +++ b/trantor/utils/AsyncFileLogger.h @@ -69,6 +69,29 @@ class TRANTOR_EXPORT AsyncFileLogger : NonCopyable sizeLimit_ = limit; } + /** + * @brief Set the max number of log files. When the number exceeds the + * limit, the oldest log file will be deleted. + * + * @param maxFiles + */ + void setMaxFiles(size_t maxFiles) + { + maxFiles_ = maxFiles; + } + + /** + * @brief Set whether to switch the log file when the AsyncFileLogger object + * is destroyed. If this flag is set to true, the log file is not switched + * when the AsyncFileLogger object is destroyed. + * + * @param flag + */ + void setSwitchOnLimitOnly(bool flag = true) + { + switchOnLimitOnly_ = flag; + } + /** * @brief Set the log file name. * @@ -107,14 +130,22 @@ class TRANTOR_EXPORT AsyncFileLogger : NonCopyable std::string fileBaseName_{"trantor"}; std::string fileExtName_{".log"}; uint64_t sizeLimit_{20 * 1024 * 1024}; + bool switchOnLimitOnly_{false}; // by default false, will generate new + // file name on each destroy. + size_t maxFiles_{0}; + class LoggerFile : NonCopyable { public: LoggerFile(const std::string &filePath, const std::string &fileBaseName, - const std::string &fileExtName); + const std::string &fileExtName, + bool switchOnLimitOnly = false, + size_t maxFiles = 0); ~LoggerFile(); void writeLog(const StringPtr buf); + void open(); + void switchLog(bool openNewOne); uint64_t getLength(); explicit operator bool() const { @@ -123,6 +154,9 @@ class TRANTOR_EXPORT AsyncFileLogger : NonCopyable void flush(); protected: + void initFilenameQueue(); + void deleteOldFiles(); + FILE *fp_{nullptr}; Date creationDate_; std::string fileFullName_; @@ -130,6 +164,12 @@ class TRANTOR_EXPORT AsyncFileLogger : NonCopyable std::string fileBaseName_; std::string fileExtName_; static uint64_t fileSeq_; + bool switchOnLimitOnly_{false}; // by default false, will generate new + // file name on each destroy + + size_t maxFiles_{0}; + // store generated filenames + std::deque filenameQueue_; }; std::unique_ptr loggerFilePtr_; diff --git a/trantor/utils/Date.cc b/trantor/utils/Date.cc index dd9c284f..26acf794 100644 --- a/trantor/utils/Date.cc +++ b/trantor/utils/Date.cc @@ -21,7 +21,7 @@ #include #include #ifdef _WIN32 -#include +#include #include #endif @@ -55,29 +55,29 @@ const Date Date::date() struct timeval tv; gettimeofday(&tv, NULL); int64_t seconds = tv.tv_sec; - return Date(seconds * MICRO_SECONDS_PRE_SEC + tv.tv_usec); + return Date(seconds * MICRO_SECONDS_PER_SEC + tv.tv_usec); #else timeval tv; gettimeofday(&tv, NULL); int64_t seconds = tv.tv_sec; - return Date(seconds * MICRO_SECONDS_PRE_SEC + tv.tv_usec); + return Date(seconds * MICRO_SECONDS_PER_SEC + tv.tv_usec); #endif } const Date Date::after(double second) const { return Date(static_cast(microSecondsSinceEpoch_ + - second * MICRO_SECONDS_PRE_SEC)); + second * MICRO_SECONDS_PER_SEC)); } const Date Date::roundSecond() const { return Date(microSecondsSinceEpoch_ - - (microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC)); + (microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC)); } const Date Date::roundDay() const { struct tm t; time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); #ifndef _WIN32 localtime_r(&seconds, &t); #else @@ -86,12 +86,12 @@ const Date Date::roundDay() const t.tm_hour = 0; t.tm_min = 0; t.tm_sec = 0; - return Date(mktime(&t) * MICRO_SECONDS_PRE_SEC); + return Date(mktime(&t) * MICRO_SECONDS_PER_SEC); } struct tm Date::tmStruct() const { time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 gmtime_r(&seconds, &tm_time); @@ -105,7 +105,7 @@ std::string Date::toFormattedString(bool showMicroseconds) const // std::cout<<"toFormattedString"<(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 gmtime_r(&seconds, &tm_time); @@ -116,7 +116,7 @@ std::string Date::toFormattedString(bool showMicroseconds) const if (showMicroseconds) { int microseconds = - static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC); snprintf(buf, sizeof(buf), "%4d%02d%02d %02d:%02d:%02d.%06d", @@ -142,12 +142,12 @@ std::string Date::toFormattedString(bool showMicroseconds) const } return buf; } -std::string Date::toCustomedFormattedString(const std::string &fmtStr, - bool showMicroseconds) const +std::string Date::toCustomFormattedString(const std::string &fmtStr, + bool showMicroseconds) const { char buf[256] = {0}; time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 gmtime_r(&seconds, &tm_time); @@ -159,17 +159,17 @@ std::string Date::toCustomedFormattedString(const std::string &fmtStr, return std::string(buf); char decimals[12] = {0}; int microseconds = - static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC); snprintf(decimals, sizeof(decimals), ".%06d", microseconds); return std::string(buf) + decimals; } -void Date::toCustomedFormattedString(const std::string &fmtStr, - char *str, - size_t len) const +void Date::toCustomFormattedString(const std::string &fmtStr, + char *str, + size_t len) const { // not safe time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 gmtime_r(&seconds, &tm_time); @@ -183,7 +183,7 @@ std::string Date::toFormattedStringLocal(bool showMicroseconds) const // std::cout<<"toFormattedString"<(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 localtime_r(&seconds, &tm_time); @@ -194,7 +194,7 @@ std::string Date::toFormattedStringLocal(bool showMicroseconds) const if (showMicroseconds) { int microseconds = - static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC); snprintf(buf, sizeof(buf), "%4d%02d%02d %02d:%02d:%02d.%06d", @@ -224,7 +224,7 @@ std::string Date::toDbStringLocal() const { char buf[128] = {0}; time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 localtime_r(&seconds, &tm_time); @@ -232,11 +232,11 @@ std::string Date::toDbStringLocal() const localtime_s(&tm_time, &seconds); #endif bool showMicroseconds = - (microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC != 0); + (microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC != 0); if (showMicroseconds) { int microseconds = - static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC); snprintf(buf, sizeof(buf), "%4d-%02d-%02d %02d:%02d:%02d.%06d", @@ -284,11 +284,36 @@ Date Date::fromDbStringLocal(const std::string &datetime) unsigned int year = {0}, month = {0}, day = {0}, hour = {0}, minute = {0}, second = {0}, microSecond = {0}; std::vector &&v = splitString(datetime, " "); - if (2 == v.size()) + + if (v.size() == 0) + { + throw std::invalid_argument("Invalid date string: " + datetime); + } + const std::vector date = splitString(v[0], "-"); + if (date.size() != 3) + { + throw std::invalid_argument("Invalid date string: " + datetime); + } + if (v.size() == 1) { - // date - std::vector date = splitString(v[0], "-"); - if (3 == date.size()) + // Fromat YYYY-MM-DD is given + try + { + year = std::stol(date[0]); + month = std::stol(date[1]); + day = std::stol(date[2]); + } + catch (...) + { + throw std::invalid_argument("Invalid date string: " + datetime); + } + return Date(year, month, day, hour, minute, second, microSecond); + } + + if (v.size() == 2) + { + // Format YYYY-MM-DD HH:MM:SS[.UUUUUU] is given + try { year = std::stol(date[0]); month = std::stol(date[1]); @@ -314,21 +339,28 @@ Date Date::fromDbStringLocal(const std::string &datetime) } } } + catch (...) + { + throw std::invalid_argument("Invalid date string: " + datetime); + } + return Date(year, month, day, hour, minute, second, microSecond); } - return trantor::Date(year, month, day, hour, minute, second, microSecond); + + throw std::invalid_argument("Invalid date string: " + datetime); } + Date Date::fromDbString(const std::string &datetime) { return fromDbStringLocal(datetime).after( static_cast(timezoneOffset())); } -std::string Date::toCustomedFormattedStringLocal(const std::string &fmtStr, - bool showMicroseconds) const +std::string Date::toCustomFormattedStringLocal(const std::string &fmtStr, + bool showMicroseconds) const { char buf[256] = {0}; time_t seconds = - static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC); struct tm tm_time; #ifndef _WIN32 localtime_r(&seconds, &tm_time); @@ -340,7 +372,7 @@ std::string Date::toCustomedFormattedStringLocal(const std::string &fmtStr, return std::string(buf); char decimals[12] = {0}; int microseconds = - static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PRE_SEC); + static_cast(microSecondsSinceEpoch_ % MICRO_SECONDS_PER_SEC); snprintf(decimals, sizeof(decimals), ".%06d", microseconds); return std::string(buf) + decimals; } @@ -364,7 +396,7 @@ Date::Date(unsigned int year, tm.tm_sec = second; epoch = mktime(&tm); microSecondsSinceEpoch_ = - static_cast(epoch) * MICRO_SECONDS_PRE_SEC + microSecond; + static_cast(epoch) * MICRO_SECONDS_PER_SEC + microSecond; } } // namespace trantor diff --git a/trantor/utils/Date.h b/trantor/utils/Date.h index 73386c4c..4684227d 100644 --- a/trantor/utils/Date.h +++ b/trantor/utils/Date.h @@ -18,8 +18,6 @@ #include #include -#define MICRO_SECONDS_PRE_SEC 1000000LL - namespace trantor { /** @@ -76,8 +74,9 @@ class TRANTOR_EXPORT Date static int64_t timezoneOffset() { - static int64_t offset = - -Date::fromDbStringLocal("1970-01-01 00:00:00").secondsSinceEpoch(); + static int64_t offset = -( + Date::fromDbStringLocal("1970-01-03 00:00:00").secondsSinceEpoch() - + 2LL * 3600LL * 24LL); return offset; } @@ -182,7 +181,7 @@ class TRANTOR_EXPORT Date */ int64_t secondsSinceEpoch() const { - return microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC; + return microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC; } /** @@ -194,50 +193,86 @@ class TRANTOR_EXPORT Date /** * @brief Generate a UTC time string - * @example: - * 20180101 10:10:25 //If the @param showMicroseconds is false - * 20180101 10:10:25:102414 //If the @param showMicroseconds is true + * @param showMicroseconds whether the microseconds are returned. + * @note Examples: + * - "20180101 10:10:25" if the @p showMicroseconds is false + * - "20180101 10:10:25:102414" if the @p showMicroseconds is true */ std::string toFormattedString(bool showMicroseconds) const; + /* clang-format off */ /** - * @brief Generate a UTC time string formated by the @param fmtStr - * The @param fmtStr is the format string for the function strftime() - * @example: - * 2018-01-01 10:10:25 //If the @param fmtStr is "%Y-%m-%d - * %H:%M:%S" and the @param showMicroseconds is false 2018-01-01 - * 10:10:25:102414 //If the @param fmtStr is "%Y-%m-%d %H:%M:%S" and the - * @param showMicroseconds is true + * @brief Generate a UTC time string formatted by the @p fmtStr + * @param fmtStr is the format string for the function strftime() + * @param showMicroseconds whether the microseconds are returned. + * @note Examples: + * - "2018-01-01 10:10:25" if the @p fmtStr is "%Y-%m-%d %H:%M:%S" and the + * @p showMicroseconds is false + * - "2018-01-01 10:10:25:102414" if the @p fmtStr is "%Y-%m-%d %H:%M:%S" + * and the @p showMicroseconds is true + * @deprecated Replaced by toCustomFormattedString */ + [[deprecated("Replaced by toCustomFormattedString")]] std::string toCustomedFormattedString(const std::string &fmtStr, - bool showMicroseconds = false) const; - + bool showMicroseconds = false) const + { + return toCustomFormattedString(fmtStr, showMicroseconds); + }; + /* clang-format on */ + /** + * @brief Generate a UTC time string formatted by the @p fmtStr + * @param fmtStr is the format string for the function strftime() + * @param showMicroseconds whether the microseconds are returned. + * @note Examples: + * - "2018-01-01 10:10:25" if the @p fmtStr is "%Y-%m-%d %H:%M:%S" and the + * @p showMicroseconds is false + * - "2018-01-01 10:10:25:102414" if the @p fmtStr is "%Y-%m-%d %H:%M:%S" + * and the @p showMicroseconds is true + */ + std::string toCustomFormattedString(const std::string &fmtStr, + bool showMicroseconds = false) const; /** * @brief Generate a local time zone string, the format of the string is - * same as the mothed toFormattedString + * same as the method toFormattedString * * @param showMicroseconds * @return std::string */ std::string toFormattedStringLocal(bool showMicroseconds) const; + /* clang-format off */ + /** + * @brief Generate a local time zone string formatted by the @p fmtStr + * + * @param fmtStr + * @param showMicroseconds + * @return std::string + * @deprecated Replaced by toCustomFormattedString + */ + [[deprecated("Replaced by toCustomFormattedStringLocal")]] + std::string toCustomedFormattedStringLocal(const std::string &fmtStr, + bool showMicroseconds = false) const + { + return toCustomFormattedStringLocal(fmtStr, showMicroseconds); + } + /* clang-format on */ /** - * @brief Generate a local time zone string formated by the @param fmtStr + * @brief Generate a local time zone string formatted by the @p fmtStr * * @param fmtStr * @param showMicroseconds * @return std::string */ - std::string toCustomedFormattedStringLocal( + std::string toCustomFormattedStringLocal( const std::string &fmtStr, bool showMicroseconds = false) const; /** * @brief Generate a local time zone string for database. - * @example: - * 2018-01-01 //If hours, minutes, seconds and - * microseconds are zero 2018-01-01 10:10:25 //If the microsecond - * is zero 2018-01-01 10:10:25:102414 //If the microsecond is not zero + * @note Examples: + * - "2018-01-01" if hours, minutes, seconds and microseconds are zero + * - "2018-01-01 10:10:25" if the microsecond is zero + * - "2018-01-01 10:10:25:102414" if the microsecond is not zero */ std::string toDbStringLocal() const; /** @@ -258,16 +293,34 @@ class TRANTOR_EXPORT Date */ static Date fromDbString(const std::string &datetime); + /* clang-format off */ /** * @brief Generate a UTC time string. * * @param fmtStr The format string. * @param str The string buffer for the generated time string. * @param len The length of the string buffer. + * @deprecated Replaced by toCustomFormattedString */ + [[deprecated("Replaced by toCustomFormattedString")]] void toCustomedFormattedString(const std::string &fmtStr, char *str, - size_t len) const; // UTC + size_t len) const + { + toCustomFormattedString(fmtStr, str, len); + } + /* clang-format on */ + + /** + * @brief Generate a UTC time string. + * + * @param fmtStr The format string. + * @param str The string buffer for the generated time string. + * @param len The length of the string buffer. + */ + void toCustomFormattedString(const std::string &fmtStr, + char *str, + size_t len) const; // UTC /** * @brief Return true if the time point is in a same second as another. @@ -278,8 +331,8 @@ class TRANTOR_EXPORT Date */ bool isSameSecond(const Date &date) const { - return microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC == - date.microSecondsSinceEpoch_ / MICRO_SECONDS_PRE_SEC; + return microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC == + date.microSecondsSinceEpoch_ / MICRO_SECONDS_PER_SEC; } /** @@ -292,6 +345,8 @@ class TRANTOR_EXPORT Date std::swap(microSecondsSinceEpoch_, that.microSecondsSinceEpoch_); } + static constexpr long MICRO_SECONDS_PER_SEC = 1000000LL; + private: int64_t microSecondsSinceEpoch_{0}; }; diff --git a/trantor/utils/Funcs.h b/trantor/utils/Funcs.h index 7a13273b..79f7a473 100644 --- a/trantor/utils/Funcs.h +++ b/trantor/utils/Funcs.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include namespace trantor { inline uint64_t hton64(uint64_t n) diff --git a/trantor/utils/Logger.cc b/trantor/utils/Logger.cc index b86d8ccc..56f0349a 100644 --- a/trantor/utils/Logger.cc +++ b/trantor/utils/Logger.cc @@ -25,7 +25,28 @@ #elif defined _WIN32 #include #endif - +#ifdef TRANTOR_SPDLOG_SUPPORT +#include +#include +#include +#endif +#if (__cplusplus >= 201703L) || \ + (defined(_MSVC_LANG) && \ + (_MSVC_LANG >= \ + 201703L)) // c++17 - the _MSVC_LANG extra check can be + // removed if the support for VS2015 & 2017 is dropped +#include +#else +namespace std +{ +inline trantor::Logger::LogLevel clamp(trantor::Logger::LogLevel v, + trantor::Logger::LogLevel min, + trantor::Logger::LogLevel max) +{ + return (v < min) ? min : (v > max) ? max : v; +} +} // namespace std +#endif #if defined __FreeBSD__ #include #endif @@ -67,6 +88,7 @@ inline LogStream &operator<<(LogStream &s, const Logger::SourceFile &v) s.append(v.data_, v.size_); return s; } + } // namespace trantor using namespace trantor; @@ -88,24 +110,51 @@ void Logger::formatTime() if (now != lastSecond_) { lastSecond_ = now; + if (displayLocalTime_()) + { #ifndef _MSC_VER - strncpy(lastTimeString_, - date_.toFormattedString(false).c_str(), + strncpy(lastTimeString_, + date_.toFormattedStringLocal(false).c_str(), + sizeof(lastTimeString_) - 1); +#else + strncpy_s( + lastTimeString_, + date_.toFormattedStringLocal(false).c_str(), sizeof(lastTimeString_) - 1); +#endif + } + else + { +#ifndef _MSC_VER + strncpy(lastTimeString_, + date_.toFormattedString(false).c_str(), + sizeof(lastTimeString_) - 1); #else - strncpy_s( - lastTimeString_, - date_.toFormattedString(false).c_str(), - sizeof(lastTimeString_) - 1); + strncpy_s( + lastTimeString_, + date_.toFormattedString(false).c_str(), + sizeof(lastTimeString_) - 1); #endif + } } logStream_ << T(lastTimeString_, 17); char tmp[32]; - snprintf(tmp, - sizeof(tmp), - ".%06llu UTC ", - static_cast(microSec)); - logStream_ << T(tmp, 12); + if (displayLocalTime_()) + { + snprintf(tmp, + sizeof(tmp), + ".%06llu ", + static_cast(microSec)); + logStream_ << T(tmp, 8); + } + else + { + snprintf(tmp, + sizeof(tmp), + ".%06llu UTC ", + static_cast(microSec)); + logStream_ << T(tmp, 12); + } #ifdef __linux__ if (threadId_ == 0) threadId_ = static_cast(::syscall(SYS_gettid)); @@ -147,36 +196,245 @@ static const char *logLevelStr[Logger::LogLevel::kNumberOfLogLevels] = { " ERROR ", " FATAL ", }; + Logger::Logger(SourceFile file, int line) : sourceFile_(file), fileLine_(line), level_(kInfo) { formatTime(); logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif } Logger::Logger(SourceFile file, int line, LogLevel level) - : sourceFile_(file), fileLine_(line), level_(level) + : sourceFile_(file), + fileLine_(line), + level_(std::clamp(level, kTrace, kFatal)) { formatTime(); logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif } Logger::Logger(SourceFile file, int line, LogLevel level, const char *func) - : sourceFile_(file), fileLine_(line), level_(level) + : sourceFile_(file), + fileLine_(line), + level_(std::clamp(level, kTrace, kFatal)) +#ifdef TRANTOR_SPDLOG_SUPPORT + , + func_(func) +#endif { formatTime(); logStream_ << T(logLevelStr[level_], 7) << "[" << func << "] "; +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif } Logger::Logger(SourceFile file, int line, bool) : sourceFile_(file), fileLine_(line), level_(kFatal) { formatTime(); logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif if (errno != 0) { logStream_ << strerror_tl(errno) << " (errno=" << errno << ") "; } } + +// LOG_COMPACT +Logger::Logger() : level_(kInfo) +{ + formatTime(); + logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif +} +Logger::Logger(LogLevel level) : level_(std::clamp(level, kTrace, kFatal)) +{ + formatTime(); + logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif +} +Logger::Logger(bool) : level_(kFatal) +{ + formatTime(); + logStream_ << T(logLevelStr[level_], 7); +#ifdef TRANTOR_SPDLOG_SUPPORT + spdLogMessageOffset_ = logStream_.bufferLength(); +#endif + if (errno != 0) + { + logStream_ << strerror_tl(errno) << " (errno=" << errno << ") "; + } +} + +bool Logger::hasSpdLogSupport() +{ +#ifdef TRANTOR_SPDLOG_SUPPORT + return true; +#else + return false; +#endif +} + +#ifdef TRANTOR_SPDLOG_SUPPORT +// Helper for uniform naming +static std::string defaultSpdLoggerName(int index) +{ + using namespace std::literals::string_literals; + std::string loggerName = "trantor"s; + if (index >= 0) + loggerName.append(std::to_string(index)); + return loggerName; +} +// a map with int keys is more efficient than spdlog internal registry based on +// strings (logger name) +static std::map> spdLoggers; +// same sinks, but the format pattern is only "%v", for LOG_RAW[_TO] +static std::map> rawSpdLoggers; +static std::mutex spdLoggersMtx; +#endif // TRANTOR_SPDLOG_SUPPORT + +std::shared_ptr Logger::getDefaultSpdLogger(int index) +{ +#ifdef TRANTOR_SPDLOG_SUPPORT + auto loggerName = defaultSpdLoggerName(index); + auto logger = spdlog::get(loggerName); + if (logger) + return logger; + // Create a new spdlog logger with the same sinks as the current default + // Logger or spdlog logger + auto &sinks = + ((spdLoggers.begin() != spdLoggers.end() ? spdLoggers.begin()->second + : spdlog::default_logger())) + ->sinks(); + logger = std::make_shared(loggerName, + sinks.begin(), + sinks.end()); + // keep a log format similar to the existing one, but with coloured + // level on console since it's nice :) + // see reference: https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + logger->set_pattern("%Y%m%d %T.%f %6t %^%=8l%$ [%!] %v - %s:%#"); + // the filtering is done at Logger level, so no need to filter here + logger->set_level(spdlog::level::trace); + logger->flush_on(spdlog::level::err); + spdlog::register_logger(logger); + + return logger; +#else + (void)index; + return {}; +#endif // TRANTOR_SPDLOG_SUPPORT +} + +std::shared_ptr Logger::getSpdLogger(int index) +{ +#ifdef TRANTOR_SPDLOG_SUPPORT + std::lock_guard lck(spdLoggersMtx); + auto it = spdLoggers.find((index < 0) ? -1 : index); + return (it == spdLoggers.end()) ? std::shared_ptr() + : it->second; +#else + (void)index; + return {}; +#endif // TRANTOR_SPDLOG_SUPPORT +} + +#ifdef TRANTOR_SPDLOG_SUPPORT +static std::shared_ptr getRawSpdLogger(int index) +{ + // Create/delete RAW logger on-the fly + // drawback: changes to the main logger's level or sinks won't be + // reflected in the raw logger once it's created + if (index < -1) + index = -1; + std::lock_guard lck(spdLoggersMtx); + auto itMain = spdLoggers.find(index); + auto itRaw = rawSpdLoggers.find(index); + if (itMain == spdLoggers.end()) + { + if (itRaw != rawSpdLoggers.end()) + { + spdlog::drop(itRaw->second->name()); + rawSpdLoggers.erase(itRaw); + } + return {}; + } + auto mainLogger = itMain->second; + if (itRaw != rawSpdLoggers.end()) + return itRaw->second; + auto rawLogger = + std::make_shared(mainLogger->name() + "_raw", + mainLogger->sinks().begin(), + mainLogger->sinks().end()); + rawLogger->set_pattern("%v"); + rawLogger->set_level(mainLogger->level()); + rawLogger->flush_on(mainLogger->flush_level()); + rawSpdLoggers[index] = rawLogger; + spdlog::register_logger(rawLogger); + return rawLogger; +} +#endif // TRANTOR_SPDLOG_SUPPORT + +void Logger::enableSpdLog(int index, std::shared_ptr logger) +{ +#ifdef TRANTOR_SPDLOG_SUPPORT + if (index < -1) + index = -1; + std::lock_guard lck(spdLoggersMtx); + spdLoggers[index] = logger ? logger : getDefaultSpdLogger(index); +#else + (void)index; + (void)logger; +#endif // TRANTOR_SPDLOG_SUPPORT +} + +void Logger::disableSpdLog(int index) +{ +#ifdef TRANTOR_SPDLOG_SUPPORT + std::lock_guard lck(spdLoggersMtx); + if (index < -1) + index = -1; + auto it = spdLoggers.find(index); + if (it == spdLoggers.end()) + return; + // auto-unregister + if (it->second->name() == defaultSpdLoggerName(index)) + spdlog::drop(it->second->name()); + spdLoggers.erase(it); +#else + (void)index; +#endif // TRANTOR_SPDLOG_SUPPORT +} + RawLogger::~RawLogger() { +#ifdef TRANTOR_SPDLOG_SUPPORT + auto logger = getRawSpdLogger(index_); + if (logger) + { + // The only way to be fully compatible with the existing non-spdlog RAW + // mode (dumping raw without adding a '\n') would be to store the + // concatenated messages along the logger, and pass the complete message + // to spdlog only when it ends with '\n'. + // But it's overkill... + // For now, just remove the trailing '\n', if any, since spdlog + // automatically adds one. + auto msglen = logStream_.bufferLength(); + if ((msglen > 0) && (logStream_.bufferData()[msglen - 1] == '\n')) + msglen--; + logger->info(spdlog::string_view_t(logStream_.bufferData(), msglen)); + return; + } +#endif if (index_ < 0) { auto &oFunc = Logger::outputFunc_(); @@ -192,9 +450,36 @@ RawLogger::~RawLogger() oFunc(logStream_.bufferData(), logStream_.bufferLength()); } } + Logger::~Logger() { - logStream_ << T(" - ", 3) << sourceFile_ << ':' << fileLine_ << '\n'; +#ifdef TRANTOR_SPDLOG_SUPPORT + auto spdLogger = getSpdLogger(index_); + if (spdLogger) + { + spdlog::source_loc spdLocation; + if (sourceFile_.data_) + spdLocation = {sourceFile_.data_, fileLine_, func_ ? func_ : ""}; + spdlog::string_view_t message(logStream_.bufferData(), + logStream_.bufferLength()); + message.remove_prefix(spdLogMessageOffset_); +#if defined(SPDLOG_VERSION) && (SPDLOG_VERSION >= 10600) + spdLogger->log(std::chrono::system_clock::time_point( + std::chrono::duration( + date_.microSecondsSinceEpoch())), + spdLocation, + spdlog::level::level_enum(level_), + message); +#else // very old version, cannot specify time + spdLogger->log(spdLocation, spdlog::level::level_enum(level_), message); +#endif + return; + } +#endif // TRANTOR_SPDLOG_SUPPORT + if (sourceFile_.data_) + logStream_ << T(" - ", 3) << sourceFile_ << ':' << fileLine_ << '\n'; + else + logStream_ << '\n'; if (index_ < 0) { auto &oFunc = Logger::outputFunc_(); @@ -213,8 +498,6 @@ Logger::~Logger() if (level_ >= kError) Logger::flushFunc_(index_)(); } - - // logStream_.resetBuffer(); } LogStream &Logger::stream() { diff --git a/trantor/utils/Logger.h b/trantor/utils/Logger.h index 862186cd..16d040aa 100644 --- a/trantor/utils/Logger.h +++ b/trantor/utils/Logger.h @@ -22,6 +22,11 @@ #include #include #include +namespace spdlog +{ +class logger; +} +#include #define TRANTOR_IF_(cond) for (int _r = 0; _r == 0 && (cond); _r = 1) @@ -56,7 +61,11 @@ class TRANTOR_EXPORT Logger : public NonCopyable inline SourceFile(const char (&arr)[N]) : data_(arr), size_(N - 1) { // std::cout< + Logger(); + Logger(LogLevel level); + Logger(bool isSysErr); + ~Logger(); Logger &setIndex(int index) { @@ -94,6 +118,7 @@ class TRANTOR_EXPORT Logger : public NonCopyable * * @param outputFunc The function to output a log message. * @param flushFunc The function to flush. + * @param index The channel index. * @note Logs are output to the standard output by default. */ static void setOutputFunction( @@ -133,6 +158,111 @@ class TRANTOR_EXPORT Logger : public NonCopyable return logLevel_(); } + /** + * @brief Check whether it shows local time or UTC time. + */ + static bool displayLocalTime() + { + return displayLocalTime_(); + } + + /** + * @brief Set whether it shows local time or UTC time. the default is UTC. + */ + static void setDisplayLocalTime(bool showLocalTime) + { + displayLocalTime_() = showLocalTime; + } + + /** + * @brief Check whether trantor was build with spdlog support + * @retval true if yes + * @retval false if not - in this case, all the spdlog functions are noop + * functions + */ + static bool hasSpdLogSupport(); + /** + * @brief Enable logging with spdlog for the specified channel. + * @param index channel index (-1 = default channel). + * @param logger spdlog::logger object to use. + * If none given, defaults to getDefaultSpdLogger(@p index). + * @remarks If provided, it is not registered with the spdlog logger + * registry, it's up to you to register/drop it. + */ + static void enableSpdLog(int index, + std::shared_ptr logger = {}); + /** + * @brief Enable logging with spdlog for the default channel. + * @param logger spdlog::logger object to use. + * If none given, defaults to getDefaultSpdLogger(). + * @remarks If provided, it is not registered with the spdlog logger + * registry, it's up to you to register/drop it. + */ + inline static void enableSpdLog(std::shared_ptr logger = {}) + { + enableSpdLog(-1, logger); + } + /** + * @brief Disable logging with spdlog for the specified channel. + * @param[in] channel index (-1 = default channel). + * @remarks The spdlog::logger object is unregistered and + * destroyed only if it was created by + * getDefaultSpdLogger(@p index). + * Custom loggers are only unset. + */ + static void disableSpdLog(int index); + /** + * @brief Disable logging with spdlog for the default channel + * @remarks The spdlog::logger object is unregistered and + * destroyed only if it was created by getDefaultSpdLogger(). + * Custom loggers are only unset. + */ + static void disableSpdLog() + { + disableSpdLog(-1); + } + /** + * @brief Get the spdlog::logger set on the specified channel. + * @param[in] channel index (-1 = default channel). + * @return the logger, if set, else a null pointer. + */ + static std::shared_ptr getSpdLogger(int index = -1); + /** + * @brief Get a default spdlog::logger for the specified channel. + * @details This helper function provides a default spdlog::logger with a + * similar output format as the existing non-spdlog trantor::Logger + * format. + * + * If a default logger was already created for the channel, it is + * returned as-is. + * + * Otherwise, a new spdlog::logger object named "trantor" (for + * index < 0) or "trantor" is created, registered with + * spdlog, and configured as follows: + * - it has the same sinks as the lowest (index) enabled channel, + * or those of the spdlog::default_logger(), which by defaults + * outputs to stdout (spdlog::sinks::stdout_color_mt), + * - its format pattern is set to resemble to the existing + * non-spdlog trantor::Logger format + * ("%Y%m%d %T.%f %6t %^%=8l%$ [%!] %v - %s:%#"), + * - the logging level is set to unfiltered (spdlog::level::trace) + * since the internal trantor/drogon level filtering is still + * managed by trantor:::Logger, + * - the flush level is set to spdlog::level::error. + * @note To add custom sinks to all the channels, you can do that this way: + * -# (optional) add your sinks to spdlog::default_logger(), + * -# create the default logger for the default channel using + * getDefaultSpdLogger(-1), + * -# if not done at step 1., add your sinks to this logger, + * -# enable the logger with enableSpdLog(), + * -# for the other channels, invoke enableSpdLog(index). + * @remarks The created spdlog::logger is automatically registered + * with the spdlog logger registry. + * @param[in] channel index (-1 = default channel). + * @return the default spdlog logger for the channel. + */ + static std::shared_ptr getDefaultSpdLogger(int index); + protected: static void defaultOutputFunction(const char *msg, const uint64_t len) { @@ -143,6 +273,12 @@ class TRANTOR_EXPORT Logger : public NonCopyable fflush(stdout); } void formatTime(); + static bool &displayLocalTime_() + { + static bool showLocalTime = false; + return showLocalTime; + } + static LogLevel &logLevel_() { #ifdef RELEASE @@ -152,8 +288,8 @@ class TRANTOR_EXPORT Logger : public NonCopyable #endif return logLevel; } - static std::function - &outputFunc_() + static std::function & + outputFunc_() { static std::function outputFunc = Logger::defaultOutputFunction; @@ -164,8 +300,8 @@ class TRANTOR_EXPORT Logger : public NonCopyable static std::function flushFunc = Logger::defaultFlushFunction; return flushFunc; } - static std::function - &outputFunc_(size_t index) + static std::function & + outputFunc_(size_t index) { static std::vector< std::function> @@ -200,6 +336,8 @@ class TRANTOR_EXPORT Logger : public NonCopyable int fileLine_; LogLevel level_; int index_{-1}; + const char *func_{nullptr}; + std::size_t spdLogMessageOffset_{0}; }; class TRANTOR_EXPORT RawLogger : public NonCopyable { @@ -274,6 +412,33 @@ class TRANTOR_EXPORT RawLogger : public NonCopyable #define LOG_SYSERR_TO(index) \ trantor::Logger(__FILE__, __LINE__, true).setIndex(index).stream() +// LOG_COMPACT_... begin block +#define LOG_COMPACT_DEBUG \ + TRANTOR_IF_(trantor::Logger::logLevel() <= trantor::Logger::kDebug) \ + trantor::Logger(trantor::Logger::kDebug).stream() +#define LOG_COMPACT_DEBUG_TO(index) \ + TRANTOR_IF_(trantor::Logger::logLevel() <= trantor::Logger::kDebug) \ + trantor::Logger(trantor::Logger::kDebug).setIndex(index).stream() +#define LOG_COMPACT_INFO \ + TRANTOR_IF_(trantor::Logger::logLevel() <= trantor::Logger::kInfo) \ + trantor::Logger().stream() +#define LOG_COMPACT_INFO_TO(index) \ + TRANTOR_IF_(trantor::Logger::logLevel() <= trantor::Logger::kInfo) \ + trantor::Logger().setIndex(index).stream() +#define LOG_COMPACT_WARN trantor::Logger(trantor::Logger::kWarn).stream() +#define LOG_COMPACT_WARN_TO(index) \ + trantor::Logger(trantor::Logger::kWarn).setIndex(index).stream() +#define LOG_COMPACT_ERROR trantor::Logger(trantor::Logger::kError).stream() +#define LOG_COMPACT_ERROR_TO(index) \ + trantor::Logger(trantor::Logger::kError).setIndex(index).stream() +#define LOG_COMPACT_FATAL trantor::Logger(trantor::Logger::kFatal).stream() +#define LOG_COMPACT_FATAL_TO(index) \ + trantor::Logger(trantor::Logger::kFatal).setIndex(index).stream() +#define LOG_COMPACT_SYSERR trantor::Logger(true).stream() +#define LOG_COMPACT_SYSERR_TO(index) \ + trantor::Logger(true).setIndex(index).stream() +// LOG_COMPACT_... end block + #define LOG_RAW trantor::RawLogger().stream() #define LOG_RAW_TO(index) trantor::RawLogger().setIndex(index).stream() diff --git a/trantor/utils/MsgBuffer.h b/trantor/utils/MsgBuffer.h index 19d42e27..4e8600cf 100644 --- a/trantor/utils/MsgBuffer.h +++ b/trantor/utils/MsgBuffer.h @@ -21,7 +21,8 @@ #include #include #include -#ifdef _WIN32 +#include +#if defined(_WIN32) && !defined(_SSIZE_T_DEFINED) using ssize_t = std::intptr_t; #endif @@ -206,7 +207,7 @@ class TRANTOR_EXPORT MsgBuffer void appendInt32(const uint32_t i); /** - * @brief Appaend a unsigned int64 value to the end of the buffer. + * @brief Append a unsigned int64 value to the end of the buffer. * * @param l */ diff --git a/trantor/utils/SerialTaskQueue.cc b/trantor/utils/SerialTaskQueue.cc index 4b989ccb..71a7103a 100644 --- a/trantor/utils/SerialTaskQueue.cc +++ b/trantor/utils/SerialTaskQueue.cc @@ -20,7 +20,7 @@ namespace trantor { SerialTaskQueue::SerialTaskQueue(const std::string &name) - : queueName_(name.empty() ? "SerailTaskQueue" : name), + : queueName_(name.empty() ? "SerialTaskQueue" : name), loopThread_(queueName_) { loopThread_.run(); diff --git a/trantor/utils/SerialTaskQueue.h b/trantor/utils/SerialTaskQueue.h index 04742bb2..f0165728 100644 --- a/trantor/utils/SerialTaskQueue.h +++ b/trantor/utils/SerialTaskQueue.h @@ -58,7 +58,7 @@ class TRANTOR_EXPORT SerialTaskQueue : public TaskQueue SerialTaskQueue() = delete; /** - * @brief Construct a new serail task queue instance. + * @brief Construct a new serial task queue instance. * * @param name */ @@ -66,13 +66,28 @@ class TRANTOR_EXPORT SerialTaskQueue : public TaskQueue virtual ~SerialTaskQueue(); + /* clang-format off */ /** * @brief Check whether a task is running in the queue. * * @return true * @return false + * @deprecated Use isRunningTask instead */ + [[deprecated("Use isRunningTask instead")]] bool isRuningTask() + { + return isRunningTask(); + } + /* clang-format on */ + + /** + * @brief Check whether a task is running in the queue. + * + * @return true + * @return false + */ + bool isRunningTask() { return loopThread_.getLoop() ? loopThread_.getLoop()->isCallingFunctions() diff --git a/trantor/utils/TaskQueue.h b/trantor/utils/TaskQueue.h index fb656d5c..b5309718 100644 --- a/trantor/utils/TaskQueue.h +++ b/trantor/utils/TaskQueue.h @@ -36,7 +36,7 @@ class TaskQueue : public NonCopyable }; /** - * @brief Run a task in the queue sychronously. This means that the task is + * @brief Run a task in the queue synchronously. This means that the task is * executed before the method returns. * * @param task diff --git a/trantor/utils/TimingWheel.h b/trantor/utils/TimingWheel.h index 55f7fe03..075e59c4 100644 --- a/trantor/utils/TimingWheel.h +++ b/trantor/utils/TimingWheel.h @@ -70,7 +70,8 @@ class TRANTOR_EXPORT TimingWheel * @param bucketsNumPerWheel The number of buckets per wheel. * @note The max delay of the timing wheel is about * ticksInterval*(bucketsNumPerWheel^wheelsNum) seconds. - * @example Four wheels with 200 buckets per wheel means the timing wheel + * @note + * Example: Four wheels with 200 buckets per wheel means the timing wheel * can work with a timeout up to 200^4 seconds, about 50 years; */ TimingWheel(trantor::EventLoop *loop, diff --git a/trantor/utils/Utilities.cc b/trantor/utils/Utilities.cc index 978dac8c..665253d2 100644 --- a/trantor/utils/Utilities.cc +++ b/trantor/utils/Utilities.cc @@ -14,9 +14,12 @@ #include "Utilities.h" #ifdef _WIN32 -#include +#include +#include #include #else // _WIN32 +#include +#include #if __cplusplus < 201103L || __cplusplus >= 201703L #include #include @@ -26,6 +29,127 @@ #endif // __cplusplus #endif // _WIN32 +#if defined(USE_OPENSSL) +#include +#include +#elif defined(USE_BOTAN) +#include +#else +#include "crypto/md5.h" +#include "crypto/sha1.h" +#include "crypto/sha256.h" +#include "crypto/sha3.h" +#include "crypto/blake2.h" +#include +#include +#include +#endif + +#ifdef _MSC_VER +#include +#else +#if defined(__x86_64__) || defined(__i386__) +#include +#endif +#endif + +#include +#include +#include +#include + +#if __cplusplus < 201103L || __cplusplus >= 201703L +static std::wstring utf8Toutf16(const std::string &utf8Str) +{ + std::wstring utf16Str; + utf16Str.reserve(utf8Str.length()); // Reserve space to avoid reallocations + + for (size_t i = 0; i < utf8Str.length();) + { + wchar_t unicode_char; + + // Check the first byte + if ((utf8Str[i] & 0b10000000) == 0) + { + // Single-byte character (ASCII) + unicode_char = utf8Str[i++]; + } + else if ((utf8Str[i] & 0b11100000) == 0b11000000) + { + if (i + 1 >= utf8Str.length()) + { + // Invalid UTF-8 sequence + // Handle the error as needed + return L""; + } + // Two-byte character + unicode_char = ((utf8Str[i] & 0b00011111) << 6) | + (utf8Str[i + 1] & 0b00111111); + i += 2; + } + else if ((utf8Str[i] & 0b11110000) == 0b11100000) + { + if (i + 2 >= utf8Str.length()) + { + // Invalid UTF-8 sequence + // Handle the error as needed + return L""; + } + // Three-byte character + unicode_char = ((utf8Str[i] & 0b00001111) << 12) | + ((utf8Str[i + 1] & 0b00111111) << 6) | + (utf8Str[i + 2] & 0b00111111); + i += 3; + } + else + { + // Invalid UTF-8 sequence + // Handle the error as needed + return L""; + } + + utf16Str.push_back(unicode_char); + } + + return utf16Str; +} + +static std::string utf16Toutf8(const std::wstring &utf16Str) +{ + std::string utf8Str; + utf8Str.reserve(utf16Str.length() * 3); + + for (size_t i = 0; i < utf16Str.length(); ++i) + { + wchar_t unicode_char = utf16Str[i]; + + if (unicode_char <= 0x7F) + { + // Single-byte character (ASCII) + utf8Str.push_back(static_cast(unicode_char)); + } + else if (unicode_char <= 0x7FF) + { + // Two-byte character + utf8Str.push_back( + static_cast(0xC0 | ((unicode_char >> 6) & 0x1F))); + utf8Str.push_back(static_cast(0x80 | (unicode_char & 0x3F))); + } + else + { + // Three-byte character + utf8Str.push_back( + static_cast(0xE0 | ((unicode_char >> 12) & 0x0F))); + utf8Str.push_back( + static_cast(0x80 | ((unicode_char >> 6) & 0x3F))); + utf8Str.push_back(static_cast(0x80 | (unicode_char & 0x3F))); + } + } + + return utf8Str; +} +#endif // __cplusplus + namespace trantor { namespace utils @@ -48,26 +172,12 @@ std::string toUtf8(const std::wstring &wstr) nSizeNeeded, NULL, NULL); -#else // _WIN32 -#if __cplusplus < 201103L || __cplusplus >= 201703L - // Note: Introduced in c++11 and deprecated with c++17. - // Revert to C99 code since there no replacement yet - strTo.resize(3 * wstr.length(), 0); - locale_t utf8 = newlocale(LC_ALL_MASK, "C.UTF-8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.utf-8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.UTF8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.utf8", NULL); - auto nLen = wcstombs_l(&strTo[0], wstr.c_str(), strTo.length(), utf8); - strTo.resize(nLen); - freelocale(utf8); -#else // c++11 to c++14 +#elif __cplusplus < 201103L || __cplusplus >= 201703L + strTo = utf16Toutf8(wstr); +#else // c++11 to c++14 std::wstring_convert, wchar_t> utf8conv; strTo = utf8conv.to_bytes(wstr); -#endif // __cplusplus -#endif // _WIN32 +#endif return strTo; } std::wstring fromUtf8(const std::string &str) @@ -81,22 +191,9 @@ std::wstring fromUtf8(const std::string &str) wstrTo.resize(nSizeNeeded, 0); ::MultiByteToWideChar( CP_UTF8, 0, &str[0], (int)str.size(), &wstrTo[0], nSizeNeeded); -#else // _WIN32 -#if __cplusplus < 201103L || __cplusplus >= 201703L - // Note: Introduced in c++11 and deprecated with c++17. - // Revert to C99 code since there no replacement yet - wstrTo.resize(str.length(), 0); - locale_t utf8 = newlocale(LC_ALL_MASK, "en_US.UTF-8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.utf-8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.UTF8", NULL); - if (!utf8) - utf8 = newlocale(LC_ALL_MASK, "C.utf8", NULL); - auto nLen = mbstowcs_l(&wstrTo[0], str.c_str(), wstrTo.length(), utf8); - wstrTo.resize(nLen); - freelocale(utf8); -#else // c++11 to c++14 +#elif __cplusplus < 201103L || __cplusplus >= 201703L + wstrTo = utf8Toutf16(str); +#else // c++11 to c++14 std::wstring_convert, wchar_t> utf8conv; try { @@ -105,8 +202,7 @@ std::wstring fromUtf8(const std::string &str) catch (...) // Should never fail if str valid UTF-8 { } -#endif // __cplusplus -#endif // _WIN32 +#endif return wstrTo; } @@ -241,9 +337,241 @@ bool verifySslName(const std::string &certName, const std::string &hostname) return true; } + assert(false && "This line should not be reached in verifySslName"); // should not reach return certName == hostname; } +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) + +std::string tlsBackend() +{ + return TOSTRING(TRANTOR_TLS_PROVIDER); +} +#undef TOSTRING +#undef STRINGIFY + +#if !defined(USE_BOTAN) && !defined(USE_OPENSSL) +Hash128 md5(const void *data, size_t len) +{ + MD5_CTX ctx; + trantor_md5_init(&ctx); + trantor_md5_update(&ctx, (const unsigned char *)data, len); + Hash128 hash; + trantor_md5_final(&ctx, (unsigned char *)&hash); + return hash; +} + +Hash160 sha1(const void *data, size_t len) +{ + SHA1_CTX ctx; + trantor_sha1_init(&ctx); + trantor_sha1_update(&ctx, (const unsigned char *)data, len); + Hash160 hash; + trantor_sha1_final((unsigned char *)&hash, &ctx); + return hash; +} + +Hash256 sha256(const void *data, size_t len) +{ + SHA256_CTX ctx; + trantor_sha256_init(&ctx); + trantor_sha256_update(&ctx, (const unsigned char *)data, len); + Hash256 hash; + trantor_sha256_final(&ctx, (unsigned char *)&hash); + return hash; +} + +Hash256 sha3(const void *data, size_t len) +{ + Hash256 hash; + trantor_sha3((const unsigned char *)data, len, &hash, sizeof(hash)); + return hash; +} + +Hash256 blake2b(const void *data, size_t len) +{ + Hash256 hash; + trantor_blake2b(&hash, sizeof(hash), data, len, NULL, 0); + return hash; +} +#endif + +std::string toHexString(const void *data, size_t len) +{ + std::string str; + str.resize(len * 2); + for (size_t i = 0; i < len; i++) + { + unsigned char c = ((const unsigned char *)data)[i]; + str[i * 2] = "0123456789ABCDEF"[c >> 4]; + str[i * 2 + 1] = "0123456789ABCDEF"[c & 0xf]; + } + return str; +} + +#if !defined(USE_BOTAN) && !defined(USE_OPENSSL) +/** + * @brief Generates `size` random bytes from the systems random source and + * stores them into `ptr`. + * @note We only use this we no TLS backend is available. Thus we can't piggy + * back on the TLS backend's random source. + */ +static bool systemRandomBytes(void *ptr, size_t size) +{ +#if defined(__BSD__) || defined(__APPLE__) + arc4random_buf(ptr, size); + return true; +#elif defined(__linux__) && \ + ((defined(__GLIBC__) && \ + (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)))) + return getentropy(ptr, size) != -1; +#elif defined(_WIN32) // Windows + return RtlGenRandom(ptr, (ULONG)size); +#elif defined(__unix__) || defined(__HAIKU__) + // fallback to /dev/urandom for other/old UNIX + thread_local std::unique_ptr > fptr( + fopen("/dev/urandom", "rb"), [](FILE *ptr) { + if (ptr != nullptr) + fclose(ptr); + }); + if (fptr == nullptr) + { + LOG_FATAL << "Failed to open /dev/urandom for randomness"; + abort(); + } + if (fread(ptr, 1, size, fptr.get()) != 0) + return true; +#endif + return false; +} +#endif + +struct RngState +{ + Hash256 secret; + Hash256 prev; + int64_t time; + uint64_t counter = 0; +}; + +bool secureRandomBytes(void *data, size_t len) +{ +#if defined(USE_OPENSSL) + // OpenSSL's RAND_bytes() uses int as the length parameter + for (size_t i = 0; i < len; i += (std::numeric_limits::max)()) + { + int fillSize = + (int)(std::min)(len - i, (size_t)(std::numeric_limits::max)()); + if (!RAND_bytes((unsigned char *)data + i, fillSize)) + return false; + } + return true; +#elif defined(USE_BOTAN) + thread_local Botan::AutoSeeded_RNG rng; + rng.randomize((unsigned char *)data, len); + return true; +#else + // If no TLS backend is used, we use a CSPRNG of our own. This makes us use + // up LESS system entropy. CSPRNG proposed by Dan Kaminsky in his DEFCON 22 + // talk. With some modifications to make it suitable for trantor's + // codebase. (RIP Dan Kaminsky. That talk was epic.) + // https://youtu.be/xneBjc8z0DE?t=2250 + namespace chrono = std::chrono; + static_assert(sizeof(RngState) < 128, + "RngState must be less then BLAKE2b's chunk size"); + + thread_local int useCount = 0; + thread_local RngState state; + static const int64_t shiftAmount = []() { + int64_t shift = 0; + if (!systemRandomBytes(&shift, sizeof(shift))) + { + // fallback to a random device. Not guaranteed to be secure + // but it's better than nothing. + shift = std::random_device{}(); + } + return shift; + }(); + // Update secret every 1024 calls to this function + if (useCount == 0) + { + if (!systemRandomBytes(&state.secret, sizeof(state.secret))) + return false; + } + useCount = (useCount + 1) % 1024; + + // use the cycle counter register to get a bit more entropy. + // Quote from the talk: "You can at least get a timestamp. And it turns out + // you just needs bits that are different. .... If you integrate time. It + // tuns out impericaly It's a pain in the butt to get two things to happen + // at the exactly the same CPU nanosecond. It's not that it can't. IT'S THAT + // IT WON'T. AND THAT'S A GOOD THING." +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || \ + defined(_M_IX86) + state.time = __rdtsc(); +#elif defined(__aarch64__) || defined(_M_ARM64) + // IMPORTANT! ARMv8 cntvct_el0 is not a cycle counter. It's a free running + // counter that increments at 1~50MHz. 20~40x slower than the CPU. But + // hashing takes more then that. So it's still good enough. +#ifdef _MSC_VER + state.time = _ReadStatusReg(ARM64_CNTVCT_EL0); +#else + asm volatile("mrs %0, cntvct_el0" : "=r"(state.time)); +#endif +#elif defined(__riscv) && __riscv_xlen == 64 + asm volatile("rdtime %0" : "=r"(state.time)); +#elif defined(__riscv) && __riscv_xlen == 32 + uint32_t timeLo, timeHi; + asm volatile("rdtimeh %0" : "=r"(timeHi)); + asm volatile("rdtime %0" : "=r"(timeLo)); + state.time = (uint64_t)timeHi << 32 | timeLo; +#elif defined(__s390__) // both s390 and s390x + asm volatile("stck %0" : "=Q"(state.time)); +#else + auto now = chrono::steady_clock::now(); + // the proposed algorithm uses the time in nanoseconds, but we don't have a + // way to read it (yet) not C++ provided a standard way to do it. Falling + // back to milliseconds. This along with additional entropy is hopefully + // good enough. + state.time = chrono::time_point_cast(now) + .time_since_epoch() + .count(); + // `now` lives on the stack, so address in each call _may_ be different. + // This code works on both 32-bit and 64-bit systems. As well as big-endian + // and little-endian systems. + void *stack_ptr = &now; + uint32_t *stack_ptr32 = (uint32_t *)&stack_ptr; + uint32_t garbage = *stack_ptr32; + static_assert(sizeof(void *) >= sizeof(uint32_t), "pointer size too small"); + for (size_t i = 1; i < sizeof(void *) / sizeof(uint32_t); i++) + garbage ^= stack_ptr32[i]; + state.time ^= garbage; +#endif + state.time += shiftAmount; + + // generate the random data as described in the talk. We use BLAKE2b since + // it's fast and has a good security margin. + for (size_t i = 0; i < len / sizeof(Hash256); i++) + { + auto hash = blake2b(&state, sizeof(state)); + memcpy((char *)data + i * sizeof(hash), &hash, sizeof(hash)); + state.counter++; + state.prev = hash; + } + if (len % sizeof(Hash256) != 0) + { + auto hash = blake2b(&state, sizeof(state)); + memcpy((char *)data + len - len % sizeof(hash), + &hash, + len % sizeof(hash)); + state.counter++; + state.prev = hash; + } + return true; +#endif +} + } // namespace utils } // namespace trantor diff --git a/trantor/utils/Utilities.h b/trantor/utils/Utilities.h index 3e477c06..c21c2aeb 100644 --- a/trantor/utils/Utilities.h +++ b/trantor/utils/Utilities.h @@ -29,7 +29,7 @@ namespace utils * @brief Convert a wide string to a UTF-8. * @details UCS2 on Windows, UTF-32 on Linux & Mac * - * @param str String to convert + * @param wstr String to convert * * @return converted string. */ @@ -172,13 +172,122 @@ inline std::string fromNativePath(const std::wstring &strPath) } /** - * @brief Check if the name supplied by the SSL Cert matchs a FQDN + * @brief Check if the name supplied by the SSL Cert matches a FQDN * @param certName The name supplied by the SSL Cert * @param hostName The FQDN to match * * @return true if matches. false otherwise */ -bool verifySslName(const std::string &certName, const std::string &hostname); +bool verifySslName(const std::string &certName, const std::string &hostName); + +/** + * @brief Returns the TLS backend used by trantor. Could be "None", "OpenSSL" or + * "Botan" + */ +TRANTOR_EXPORT std::string tlsBackend(); + +struct Hash128 +{ + unsigned char bytes[16]; +}; + +struct Hash160 +{ + unsigned char bytes[20]; +}; + +struct Hash256 +{ + unsigned char bytes[32]; +}; + +// provide sane hash functions so users don't have to provide their own + +/** + * @brief Compute the MD5 hash of the given data + * @note don't use MD5 for new applications. It's here only for compatibility + */ +TRANTOR_EXPORT Hash128 md5(const void *data, size_t len); +inline Hash128 md5(const std::string &str) +{ + return md5(str.data(), str.size()); +} + +/** + * @brief Compute the SHA1 hash of the given data + */ +TRANTOR_EXPORT Hash160 sha1(const void *data, size_t len); +inline Hash160 sha1(const std::string &str) +{ + return sha1(str.data(), str.size()); +} + +/** + * @brief Compute the SHA256 hash of the given data + */ +TRANTOR_EXPORT Hash256 sha256(const void *data, size_t len); +inline Hash256 sha256(const std::string &str) +{ + return sha256(str.data(), str.size()); +} + +/** + * @brief Compute the SHA3 hash of the given data + */ +TRANTOR_EXPORT Hash256 sha3(const void *data, size_t len); +inline Hash256 sha3(const std::string &str) +{ + return sha3(str.data(), str.size()); +} + +/** + * @brief Compute the BLAKE2b hash of the given data + * @note When in doubt, use SHA3 or BLAKE2b. Both are safe and SHA3 is faster if + * you are using OpenSSL and it has SHA3 in hardware mode. Otherwise BLAKE2b is + * faster in software. + */ +TRANTOR_EXPORT Hash256 blake2b(const void *data, size_t len); +inline Hash256 blake2b(const std::string &str) +{ + return blake2b(str.data(), str.size()); +} + +/** + * @brief hex encode the given data + * @note When in doubt, use SHA3 or BLAKE2b. Both are safe and SHA3 is faster if + * you are using OpenSSL and it has SHA3 in hardware mode. Otherwise BLAKE2b is + * faster in software. + */ +TRANTOR_EXPORT std::string toHexString(const void *data, size_t len); +inline std::string toHexString(const Hash128 &hash) +{ + return toHexString(hash.bytes, sizeof(hash.bytes)); +} + +inline std::string toHexString(const Hash160 &hash) +{ + return toHexString(hash.bytes, sizeof(hash.bytes)); +} + +inline std::string toHexString(const Hash256 &hash) +{ + return toHexString(hash.bytes, sizeof(hash.bytes)); +} + +/** + * @brief Generates cryptographically secure random bytes + * @param ptr Pointer to the buffer to fill + * @param size Size of the buffer + * @return true if successful, false otherwise + * + * @note This function really shouldn't fail, but it's possible that + * + * - OpenSSL can't access /dev/urandom + * - Compiled with glibc that supports getentropy() but the kernel doesn't + * + * When using Botan or on *BSD/macOS, this function will always succeed. + */ +TRANTOR_EXPORT bool secureRandomBytes(void *ptr, size_t size); } // namespace utils diff --git a/trantor/utils/crypto/blake2.cc b/trantor/utils/crypto/blake2.cc new file mode 100644 index 00000000..ad759ed3 --- /dev/null +++ b/trantor/utils/crypto/blake2.cc @@ -0,0 +1,343 @@ +// Taken from https://github.com/Sachin-A/Blake2 +#include +#include +#include +#include + +/** + * The BLAKE2b initialization vectors + */ +static const uint64_t blake2b_IV[8] = {0x6a09e667f3bcc908ULL, + 0xbb67ae8584caa73bULL, + 0x3c6ef372fe94f82bULL, + 0xa54ff53a5f1d36f1ULL, + 0x510e527fade682d1ULL, + 0x9b05688c2b3e6c1fULL, + 0x1f83d9abfb41bd6bULL, + 0x5be0cd19137e2179ULL}; + +/** + * Table of permutations + */ +static const uint8_t blake2b_sigma[12][16] = { + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3}, + {11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4}, + {7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8}, + {9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13}, + {2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9}, + {12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11}, + {13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10}, + {6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5}, + {10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3}}; + +enum blake2b_constant +{ + BLAKE2B_BLOCKBYTES = 128, + BLAKE2B_OUTBYTES = 64, + BLAKE2B_KEYBYTES = 64, + BLAKE2B_SALTBYTES = 16, + BLAKE2B_PERSONALBYTES = 16 +}; + +typedef struct blake2b_param +{ + uint8_t digest_length; /* 1 */ + uint8_t key_length; /* 2 */ + uint8_t fanout; /* 3 */ + uint8_t depth; /* 4 */ + uint32_t leaf_length; /* 8 */ + uint64_t node_offset; /* 16 */ + uint8_t node_depth; /* 17 */ + uint8_t inner_length; /* 18 */ + uint8_t reserved[14]; /* 32 */ + uint8_t salt[BLAKE2B_SALTBYTES]; /* 48 */ + uint8_t personal[BLAKE2B_PERSONALBYTES]; /* 64 */ +} blake2b_param; + +typedef struct blake2b_state +{ + uint64_t h[8]; /* chained state */ + uint64_t t[2]; /* total number of bytes */ + uint64_t f[2]; /* last block flag */ + uint8_t buf[BLAKE2B_BLOCKBYTES]; /* input buffer */ + size_t buflen; /* size of buffer */ + size_t outlen; /* digest size */ +} blake2b_state; + +/** + * Helper macro to perform rotation in a 64 bit int + * + * @param[in] w original word + * @param[in] c offset to rotate by + */ +#define ROTR64(w, c) ((w) >> (c)) | ((w) << (64 - (c))) + +/** + * Helper macro to load into src 64 bytes at a time + * + * @param[in] dest the destination + * @param[in] src the source + */ +#if defined(NATIVE_LITTLE_ENDIAN) +#define LOAD64(dest, src) memcpy(&(dest), (src), sizeof(dest)) +#else +#define LOAD64(dest, src) \ + do \ + { \ + const uint8_t* load = (const uint8_t*)(src); \ + dest = ((uint64_t)(load[0]) << 0) | ((uint64_t)(load[1]) << 8) | \ + ((uint64_t)(load[2]) << 16) | ((uint64_t)(load[3]) << 24) | \ + ((uint64_t)(load[4]) << 32) | ((uint64_t)(load[5]) << 40) | \ + ((uint64_t)(load[6]) << 48) | ((uint64_t)(load[7]) << 56); \ + } while (0) +#endif + +/** + * Stores w into dst + * + * @param dst the destination + * @param[in] w word to be stored + */ +static void store64(uint8_t* dst, uint64_t w) +{ +#if defined(NATIVE_LITTLE_ENDIAN) + memcpy(dst, &w, sizeof w); +#else + uint8_t* p = dst; + + p[0] = (uint8_t)(w >> 0); + p[1] = (uint8_t)(w >> 8); + p[2] = (uint8_t)(w >> 16); + p[3] = (uint8_t)(w >> 24); + p[4] = (uint8_t)(w >> 32); + p[5] = (uint8_t)(w >> 40); + p[6] = (uint8_t)(w >> 48); + p[7] = (uint8_t)(w >> 56); +#endif +} + +/** + * Increments the blake2b state counter + * + * @param S blake2b_state instance + * @param[in] inc the increment value + */ +static void trantor_blake2b_increment_counter(blake2b_state* state, + const uint64_t inc) +{ + state->t[0] += inc; + state->t[1] += (state->t[0] < inc); +} + +/** + * The blake2b mixing function like macro mixes two 8-byte words from the + * message into the hash state + * + * @params a, b, c, d indices to 8-byte word entries from the work vector V + * @params x, y two 8-byte word entries from padded message v + */ +#define G(a, b, c, d, x, y) \ + do \ + { \ + a = a + b + x; \ + d = ROTR64(d ^ a, 32); \ + c = c + d; \ + b = ROTR64(b ^ c, 24); \ + a = a + b + y; \ + d = ROTR64(d ^ a, 16); \ + c = c + d; \ + b = ROTR64(b ^ c, 63); \ + } while (0) + +/** + * The blake2b compress function which takes a full 128-byte chunk of the + * input message and mixes it into the ongoing state array + * + * @param state blake2b_state instance + * @param block the input block + */ +static void F(blake2b_state* state, const uint8_t block[BLAKE2B_BLOCKBYTES]) +{ + size_t i, j; + uint64_t v[16], m[16], s[16]; + + for (i = 0; i < 16; ++i) + { + LOAD64(m[i], block + i * sizeof(m[i])); + } + + for (i = 0; i < 8; ++i) + { + v[i] = state->h[i]; + v[i + 8] = blake2b_IV[i]; + } + + v[12] ^= state->t[0]; + v[13] ^= state->t[1]; + v[14] ^= state->f[0]; + v[15] ^= state->f[1]; + + for (i = 0; i < 12; i++) + { + for (j = 0; j < 16; j++) + { + s[j] = blake2b_sigma[i][j]; + } + G(v[0], v[4], v[8], v[12], m[s[0]], m[s[1]]); + G(v[1], v[5], v[9], v[13], m[s[2]], m[s[3]]); + G(v[2], v[6], v[10], v[14], m[s[4]], m[s[5]]); + G(v[3], v[7], v[11], v[15], m[s[6]], m[s[7]]); + G(v[0], v[5], v[10], v[15], m[s[8]], m[s[9]]); + G(v[1], v[6], v[11], v[12], m[s[10]], m[s[11]]); + G(v[2], v[7], v[8], v[13], m[s[12]], m[s[13]]); + G(v[3], v[4], v[9], v[14], m[s[14]], m[s[15]]); + } + + for (i = 0; i < 8; i++) + { + state->h[i] = state->h[i] ^ v[i] ^ v[i + 8]; + } +} + +/** + * Updates blake2b state + * + * @param state blake2b state instance + * @param[in] input_buffer the input buffer + * @param[in] inlen the input length + */ +void trantor_blake2b_update(blake2b_state* state, + const unsigned char* input_buffer, + size_t inlen) +{ + const unsigned char* in = input_buffer; + size_t left = state->buflen; + size_t fill = BLAKE2B_BLOCKBYTES - left; + if (inlen > fill) + { + state->buflen = 0; + memcpy(state->buf + left, in, fill); + trantor_blake2b_increment_counter(state, BLAKE2B_BLOCKBYTES); + F(state, state->buf); + in += fill; + inlen -= fill; + + while (inlen > BLAKE2B_BLOCKBYTES) + { + trantor_blake2b_increment_counter(state, BLAKE2B_BLOCKBYTES); + F(state, in); + in += BLAKE2B_BLOCKBYTES; + inlen -= BLAKE2B_BLOCKBYTES; + } + } + memcpy(state->buf + state->buflen, in, inlen); + state->buflen += inlen; +} + +/** + * Initializes blake2b state + * + * @param state blake2b_state instance passed by reference + * @param[in] outlen the hash output length + */ +void trantor_blake2b_init(blake2b_state* state, + size_t outlen, + const void* key, + size_t keylen) +{ + blake2b_param P; + memset(&P, 0, sizeof(P)); + const uint8_t* p; + size_t i; + uint64_t dest; + + P.digest_length = (uint8_t)outlen; + if (keylen > 0) + { + P.key_length = (uint8_t)keylen; + } + P.fanout = 1; + P.depth = 1; + + dest = 0; + p = (const uint8_t*)(&P); + for (i = 0; i < 8; ++i) + { + state->h[i] = blake2b_IV[i]; + } + for (i = 0; i < 8; ++i) + { + LOAD64(dest, p + sizeof(state->h[i]) * i); + state->h[i] ^= dest; + } + state->outlen = P.digest_length; + + if (keylen > 0) + { + uint8_t block[BLAKE2B_BLOCKBYTES] = {0}; + memcpy(block, key, keylen); + trantor_blake2b_update(state, block, BLAKE2B_BLOCKBYTES); + memset(block, 0, BLAKE2B_BLOCKBYTES); + } +} + +/** + * Finalizes state, pads final block and stores hash + * + * @param state blake2b state instance + * @param[in] out the output buffer + * @param[in] inlen the digest size + */ + +void trantor_blake2b_final(blake2b_state* state, void* out, size_t outlen) +{ + (void)(outlen); + uint8_t buffer[BLAKE2B_OUTBYTES] = {0}; + size_t i; + + trantor_blake2b_increment_counter(state, state->buflen); + + /* set last chunk = true */ + state->f[0] = UINT64_MAX; + + /* padding */ + memset(state->buf + state->buflen, 0, BLAKE2B_BLOCKBYTES - state->buflen); + F(state, state->buf); + + /* Store back in little endian */ + for (i = 0; i < 8; ++i) + { + store64(buffer + sizeof(state->h[i]) * i, state->h[i]); + } + + /* Copy first outlen bytes into output buffer */ + memcpy(out, buffer, state->outlen); +} + +/** + * The main blake2b function + * + * @param output the hash output + * @param[in] outlen the hash length + * @param[in] input the message input + * @param[in] inlen the message length + * @param[in] key the key + * @param[in] keylen the key length + */ +void trantor_blake2b(void* output, + size_t outlen, + const void* input, + size_t inlen, + const void* key, + size_t keylen) +{ + blake2b_state state; + memset(&state, 0, sizeof(state)); + + trantor_blake2b_init(&state, outlen, key, keylen); + trantor_blake2b_update(&state, (const uint8_t*)input, inlen); + trantor_blake2b_final(&state, output, outlen); +} diff --git a/trantor/utils/crypto/blake2.h b/trantor/utils/crypto/blake2.h new file mode 100644 index 00000000..844a7e7a --- /dev/null +++ b/trantor/utils/crypto/blake2.h @@ -0,0 +1,6 @@ +void trantor_blake2b(void* output, + size_t outlen, + const void* input, + size_t inlen, + const void* key, + size_t keylen); diff --git a/trantor/utils/crypto/botan.cc b/trantor/utils/crypto/botan.cc new file mode 100644 index 00000000..32a0370d --- /dev/null +++ b/trantor/utils/crypto/botan.cc @@ -0,0 +1,58 @@ +#include +#include + +#include + +namespace trantor +{ +namespace utils +{ +Hash128 md5(const void* data, size_t len) +{ + Hash128 hash; + auto md5 = Botan::HashFunction::create("MD5"); + md5->update((const unsigned char*)data, len); + md5->final((unsigned char*)&hash); + return hash; +} + +Hash160 sha1(const void* data, size_t len) +{ + Hash160 hash; + auto sha1 = Botan::HashFunction::create("SHA-1"); + sha1->update((const unsigned char*)data, len); + sha1->final((unsigned char*)&hash); + return hash; +} + +Hash256 sha256(const void* data, size_t len) +{ + Hash256 hash; + auto sha256 = Botan::HashFunction::create("SHA-256"); + sha256->update((const unsigned char*)data, len); + sha256->final((unsigned char*)&hash); + return hash; +} + +Hash256 sha3(const void* data, size_t len) +{ + Hash256 hash; + auto sha3 = Botan::HashFunction::create("SHA-3(256)"); + assert(sha3 != nullptr); + sha3->update((const unsigned char*)data, len); + sha3->final((unsigned char*)&hash); + return hash; +} + +Hash256 blake2b(const void* data, size_t len) +{ + Hash256 hash; + auto blake2b = Botan::HashFunction::create("BLAKE2b(256)"); + assert(blake2b != nullptr); + blake2b->update((const unsigned char*)data, len); + blake2b->final((unsigned char*)&hash); + return hash; +} + +} // namespace utils +} // namespace trantor \ No newline at end of file diff --git a/trantor/utils/crypto/md5.cc b/trantor/utils/crypto/md5.cc new file mode 100644 index 00000000..d9a72bdc --- /dev/null +++ b/trantor/utils/crypto/md5.cc @@ -0,0 +1,209 @@ +/********************************************************************* +* Filename: md5.c +* Author: Brad Conte (brad AT bradconte.com) +* Copyright: +* Disclaimer: This code is presented "as is" without any guarantees. +* Details: Implementation of the MD5 hashing algorithm. + Algorithm specification can be found here: + * http://tools.ietf.org/html/rfc1321 + This implementation uses little endian byte +order. +*********************************************************************/ + +/*************************** HEADER FILES ***************************/ +#include +#include +#include "md5.h" + +/****************************** MACROS ******************************/ +#define ROTLEFT(a, b) ((a << b) | (a >> (32 - b))) + +#define F(x, y, z) ((x & y) | (~x & z)) +#define G(x, y, z) ((x & z) | (y & ~z)) +#define H(x, y, z) (x ^ y ^ z) +#define I(x, y, z) (y ^ (x | ~z)) + +#define FF(a, b, c, d, m, s, t) \ + { \ + a += F(b, c, d) + m + t; \ + a = b + ROTLEFT(a, s); \ + } +#define GG(a, b, c, d, m, s, t) \ + { \ + a += G(b, c, d) + m + t; \ + a = b + ROTLEFT(a, s); \ + } +#define HH(a, b, c, d, m, s, t) \ + { \ + a += H(b, c, d) + m + t; \ + a = b + ROTLEFT(a, s); \ + } +#define II(a, b, c, d, m, s, t) \ + { \ + a += I(b, c, d) + m + t; \ + a = b + ROTLEFT(a, s); \ + } + +/*********************** FUNCTION DEFINITIONS ***********************/ +void trantor_md5_transform(MD5_CTX *ctx, const uint8_t data[]) +{ + uint32_t a, b, c, d, m[16], i, j; + + // MD5 specifies big endian byte order, but this implementation assumes a + // little endian byte order CPU. Reverse all the bytes upon input, and + // re-reverse them on output (in md5_final()). + for (i = 0, j = 0; i < 16; ++i, j += 4) + m[i] = (data[j]) + (data[j + 1] << 8) + (data[j + 2] << 16) + + (data[j + 3] << 24); + + a = ctx->state[0]; + b = ctx->state[1]; + c = ctx->state[2]; + d = ctx->state[3]; + + FF(a, b, c, d, m[0], 7, 0xd76aa478); + FF(d, a, b, c, m[1], 12, 0xe8c7b756); + FF(c, d, a, b, m[2], 17, 0x242070db); + FF(b, c, d, a, m[3], 22, 0xc1bdceee); + FF(a, b, c, d, m[4], 7, 0xf57c0faf); + FF(d, a, b, c, m[5], 12, 0x4787c62a); + FF(c, d, a, b, m[6], 17, 0xa8304613); + FF(b, c, d, a, m[7], 22, 0xfd469501); + FF(a, b, c, d, m[8], 7, 0x698098d8); + FF(d, a, b, c, m[9], 12, 0x8b44f7af); + FF(c, d, a, b, m[10], 17, 0xffff5bb1); + FF(b, c, d, a, m[11], 22, 0x895cd7be); + FF(a, b, c, d, m[12], 7, 0x6b901122); + FF(d, a, b, c, m[13], 12, 0xfd987193); + FF(c, d, a, b, m[14], 17, 0xa679438e); + FF(b, c, d, a, m[15], 22, 0x49b40821); + + GG(a, b, c, d, m[1], 5, 0xf61e2562); + GG(d, a, b, c, m[6], 9, 0xc040b340); + GG(c, d, a, b, m[11], 14, 0x265e5a51); + GG(b, c, d, a, m[0], 20, 0xe9b6c7aa); + GG(a, b, c, d, m[5], 5, 0xd62f105d); + GG(d, a, b, c, m[10], 9, 0x02441453); + GG(c, d, a, b, m[15], 14, 0xd8a1e681); + GG(b, c, d, a, m[4], 20, 0xe7d3fbc8); + GG(a, b, c, d, m[9], 5, 0x21e1cde6); + GG(d, a, b, c, m[14], 9, 0xc33707d6); + GG(c, d, a, b, m[3], 14, 0xf4d50d87); + GG(b, c, d, a, m[8], 20, 0x455a14ed); + GG(a, b, c, d, m[13], 5, 0xa9e3e905); + GG(d, a, b, c, m[2], 9, 0xfcefa3f8); + GG(c, d, a, b, m[7], 14, 0x676f02d9); + GG(b, c, d, a, m[12], 20, 0x8d2a4c8a); + + HH(a, b, c, d, m[5], 4, 0xfffa3942); + HH(d, a, b, c, m[8], 11, 0x8771f681); + HH(c, d, a, b, m[11], 16, 0x6d9d6122); + HH(b, c, d, a, m[14], 23, 0xfde5380c); + HH(a, b, c, d, m[1], 4, 0xa4beea44); + HH(d, a, b, c, m[4], 11, 0x4bdecfa9); + HH(c, d, a, b, m[7], 16, 0xf6bb4b60); + HH(b, c, d, a, m[10], 23, 0xbebfbc70); + HH(a, b, c, d, m[13], 4, 0x289b7ec6); + HH(d, a, b, c, m[0], 11, 0xeaa127fa); + HH(c, d, a, b, m[3], 16, 0xd4ef3085); + HH(b, c, d, a, m[6], 23, 0x04881d05); + HH(a, b, c, d, m[9], 4, 0xd9d4d039); + HH(d, a, b, c, m[12], 11, 0xe6db99e5); + HH(c, d, a, b, m[15], 16, 0x1fa27cf8); + HH(b, c, d, a, m[2], 23, 0xc4ac5665); + + II(a, b, c, d, m[0], 6, 0xf4292244); + II(d, a, b, c, m[7], 10, 0x432aff97); + II(c, d, a, b, m[14], 15, 0xab9423a7); + II(b, c, d, a, m[5], 21, 0xfc93a039); + II(a, b, c, d, m[12], 6, 0x655b59c3); + II(d, a, b, c, m[3], 10, 0x8f0ccc92); + II(c, d, a, b, m[10], 15, 0xffeff47d); + II(b, c, d, a, m[1], 21, 0x85845dd1); + II(a, b, c, d, m[8], 6, 0x6fa87e4f); + II(d, a, b, c, m[15], 10, 0xfe2ce6e0); + II(c, d, a, b, m[6], 15, 0xa3014314); + II(b, c, d, a, m[13], 21, 0x4e0811a1); + II(a, b, c, d, m[4], 6, 0xf7537e82); + II(d, a, b, c, m[11], 10, 0xbd3af235); + II(c, d, a, b, m[2], 15, 0x2ad7d2bb); + II(b, c, d, a, m[9], 21, 0xeb86d391); + + ctx->state[0] += a; + ctx->state[1] += b; + ctx->state[2] += c; + ctx->state[3] += d; +} + +void trantor_md5_init(MD5_CTX *ctx) +{ + ctx->datalen = 0; + ctx->bitlen = 0; + ctx->state[0] = 0x67452301; + ctx->state[1] = 0xEFCDAB89; + ctx->state[2] = 0x98BADCFE; + ctx->state[3] = 0x10325476; +} + +void trantor_md5_update(MD5_CTX *ctx, const uint8_t data[], size_t len) +{ + size_t i; + + for (i = 0; i < len; ++i) + { + ctx->data[ctx->datalen] = data[i]; + ctx->datalen++; + if (ctx->datalen == 64) + { + trantor_md5_transform(ctx, ctx->data); + ctx->bitlen += 512; + ctx->datalen = 0; + } + } +} + +void trantor_md5_final(MD5_CTX *ctx, uint8_t hash[]) +{ + size_t i; + + i = ctx->datalen; + + // Pad whatever data is left in the buffer. + if (ctx->datalen < 56) + { + ctx->data[i++] = 0x80; + while (i < 56) + ctx->data[i++] = 0x00; + } + else if (ctx->datalen >= 56) + { + ctx->data[i++] = 0x80; + while (i < 64) + ctx->data[i++] = 0x00; + trantor_md5_transform(ctx, ctx->data); + memset(ctx->data, 0, 56); + } + + // Append to the padding the total message's length in bits and transform. + ctx->bitlen += ctx->datalen * 8; + ctx->data[56] = (uint8_t)ctx->bitlen; + ctx->data[57] = (uint8_t)(ctx->bitlen >> 8); + ctx->data[58] = (uint8_t)(ctx->bitlen >> 16); + ctx->data[59] = (uint8_t)(ctx->bitlen >> 24); + ctx->data[60] = (uint8_t)(ctx->bitlen >> 32); + ctx->data[61] = (uint8_t)(ctx->bitlen >> 40); + ctx->data[62] = (uint8_t)(ctx->bitlen >> 48); + ctx->data[63] = (uint8_t)(ctx->bitlen >> 56); + trantor_md5_transform(ctx, ctx->data); + + // Since this implementation uses little endian byte ordering and MD uses + // big endian, reverse all the bytes when copying the final state to the + // output hash. + for (i = 0; i < 4; ++i) + { + hash[i] = (ctx->state[0] >> (i * 8)) & 0x000000ff; + hash[i + 4] = (ctx->state[1] >> (i * 8)) & 0x000000ff; + hash[i + 8] = (ctx->state[2] >> (i * 8)) & 0x000000ff; + hash[i + 12] = (ctx->state[3] >> (i * 8)) & 0x000000ff; + } +} \ No newline at end of file diff --git a/trantor/utils/crypto/md5.h b/trantor/utils/crypto/md5.h new file mode 100644 index 00000000..0979a75c --- /dev/null +++ b/trantor/utils/crypto/md5.h @@ -0,0 +1,32 @@ +/********************************************************************* + * Filename: md5.h + * Author: Brad Conte (brad AT bradconte.com) + * Copyright: + * Disclaimer: This code is presented "as is" without any guarantees. + * Details: Defines the API for the corresponding MD5 implementation. + *********************************************************************/ + +#ifndef MD5_H +#define MD5_H + +/*************************** HEADER FILES ***************************/ +#include +#include + +/****************************** MACROS ******************************/ +#define MD5_BLOCK_SIZE 32 // MD5 outputs a 32 byte digest + +typedef struct +{ + uint8_t data[64]; + uint32_t datalen; + uint64_t bitlen; + uint32_t state[4]; +} MD5_CTX; + +/*********************** FUNCTION DECLARATIONS **********************/ +void trantor_md5_init(MD5_CTX *ctx); +void trantor_md5_update(MD5_CTX *ctx, const uint8_t data[], size_t len); +void trantor_md5_final(MD5_CTX *ctx, uint8_t hash[]); + +#endif // MD5_H \ No newline at end of file diff --git a/trantor/utils/crypto/openssl.cc b/trantor/utils/crypto/openssl.cc new file mode 100644 index 00000000..8ac44a6f --- /dev/null +++ b/trantor/utils/crypto/openssl.cc @@ -0,0 +1,140 @@ +#include + +#include + +#if OPENSSL_VERSION_MAJOR < 3 +#include +#include +#endif + +#include "sha3.h" +#include "sha3.cc" + +// Some OpenSSL installations does not come with BLAKE2b-256 +// We use our own implementation in such case +#include "blake2.h" +#include "blake2.cc" + +namespace trantor +{ +namespace utils +{ +Hash128 md5(const void* data, size_t len) +{ +#if OPENSSL_VERSION_MAJOR >= 3 + Hash128 hash; + auto md5 = EVP_MD_fetch(nullptr, "MD5", nullptr); + auto ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, md5, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + EVP_MD_free(md5); + return hash; +#else + Hash128 hash; + MD5_CTX ctx; + MD5_Init(&ctx); + MD5_Update(&ctx, data, len); + MD5_Final((unsigned char*)&hash, &ctx); + return hash; +#endif +} + +Hash160 sha1(const void* data, size_t len) +{ +#if OPENSSL_VERSION_MAJOR >= 3 + Hash160 hash; + auto sha1 = EVP_MD_fetch(nullptr, "SHA1", nullptr); + auto ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, sha1, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + EVP_MD_free(sha1); + return hash; +#else + Hash160 hash; + SHA_CTX ctx; + SHA1_Init(&ctx); + SHA1_Update(&ctx, data, len); + SHA1_Final((unsigned char*)&hash, &ctx); + return hash; +#endif +} + +Hash256 sha256(const void* data, size_t len) +{ +#if OPENSSL_VERSION_MAJOR >= 3 + Hash256 hash; + auto sha256 = EVP_MD_fetch(nullptr, "SHA256", nullptr); + auto ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, sha256, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + EVP_MD_free(sha256); + return hash; +#else + Hash256 hash; + SHA256_CTX ctx; + SHA256_Init(&ctx); + SHA256_Update(&ctx, data, len); + SHA256_Final((unsigned char*)&hash, &ctx); + return hash; +#endif +} + +Hash256 sha3(const void* data, size_t len) +{ + Hash256 hash; +#if OPENSSL_VERSION_MAJOR >= 3 + auto sha3 = EVP_MD_fetch(nullptr, "SHA3-256", nullptr); + if (sha3 != nullptr) + { + auto ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, sha3, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + EVP_MD_free(sha3); + return hash; + } +#elif !defined(LIBRESSL_VERSION_NUMBER) + auto sha3 = EVP_sha3_256(); + if (sha3 != nullptr) + { + EVP_MD_CTX* ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, sha3, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + return hash; + } +#endif + trantor_sha3((const unsigned char*)data, len, &hash, sizeof(hash)); + return hash; +} + +Hash256 blake2b(const void* data, size_t len) +{ + Hash256 hash; +#if OPENSSL_VERSION_MAJOR >= 3 + auto blake2b = EVP_MD_fetch(nullptr, "BLAKE2b-256", nullptr); + if (blake2b != nullptr) + { + auto ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, blake2b, nullptr); + EVP_DigestUpdate(ctx, data, len); + EVP_DigestFinal_ex(ctx, (unsigned char*)&hash, nullptr); + EVP_MD_CTX_free(ctx); + EVP_MD_free(blake2b); + return hash; + } +#endif + trantor_blake2b(&hash, sizeof(hash), data, len, nullptr, 0); + return hash; +} + +} // namespace utils +} // namespace trantor diff --git a/trantor/utils/crypto/sha1.cc b/trantor/utils/crypto/sha1.cc new file mode 100644 index 00000000..7ddb886f --- /dev/null +++ b/trantor/utils/crypto/sha1.cc @@ -0,0 +1,330 @@ +/* ================ sha1.c ================ */ +/* +SHA-1 in C +By Steve Reid +100% Public Domain +Last modified by Martin Chang for the Trantor project + +Test Vectors (from FIPS PUB 180-1) +"abc" + A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D +"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 +A million repetitions of "a" + 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F +*/ + +/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */ +/* #define SHA1HANDSOFF * Copies data before messing with it. */ + +#define SHA1HANDSOFF + +#include +#include +#if defined(__sun) +#include "solarisfixes.h" +#endif +#include "sha1.h" + +#ifndef BYTE_ORDER +#if (BSD >= 199103) +#include +#else +#if defined(linux) || defined(__linux__) +#include +#else +#define LITTLE_ENDIAN 1234 /* least-significant byte first (vax, pc) */ +#define BIG_ENDIAN 4321 /* most-significant byte first (IBM, net) */ +#define PDP_ENDIAN 3412 /* LSB first in word, MSW first in long (pdp)*/ + +#if defined(vax) || defined(ns32000) || defined(sun386) || \ + defined(__i386__) || defined(MIPSEL) || defined(_MIPSEL) || \ + defined(BIT_ZERO_ON_RIGHT) || defined(__alpha__) || defined(__alpha) || \ + defined(__CYGWIN32__) || defined(_WIN64) || defined(_WIN32) || \ + defined(__arm64e__) || defined(__arm64__) || defined(__aarch64__) || \ + defined(__riscv) || defined(_M_ARM64) +#define BYTE_ORDER LITTLE_ENDIAN +#endif + +#if defined(sel) || defined(pyr) || defined(mc68000) || defined(sparc) || \ + defined(is68k) || defined(tahoe) || defined(ibm032) || defined(ibm370) || \ + defined(MIPSEB) || defined(_MIPSEB) || defined(_IBMR2) || defined(DGUX) || \ + defined(apollo) || defined(__convex__) || defined(_CRAY) || \ + defined(__hppa) || defined(__hp9000) || defined(__hp9000s300) || \ + defined(__hp9000s700) || defined(BIT_ZERO_ON_LEFT) || defined(m68k) || \ + defined(__sparc) || defined(__s390__) || defined(__ppc__) +#define BYTE_ORDER BIG_ENDIAN +#endif +#endif /* linux */ +#endif /* BSD */ +#endif /* BYTE_ORDER */ + +#if defined(__BYTE_ORDER) && !defined(BYTE_ORDER) +#if (__BYTE_ORDER == __LITTLE_ENDIAN) +#define BYTE_ORDER LITTLE_ENDIAN +#else +#define BYTE_ORDER BIG_ENDIAN +#endif +#endif + +#if !defined(BYTE_ORDER) || \ + (BYTE_ORDER != BIG_ENDIAN && BYTE_ORDER != LITTLE_ENDIAN && \ + BYTE_ORDER != PDP_ENDIAN) +/* you must determine what the correct bit order is for + * your compiler - the next line is an intentional error + * which will force your compiles to bomb until you fix + * the above macros. + */ +#error "Undefined or invalid BYTE_ORDER" +#endif + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +/* blk0() and blk() perform the initial expand. */ +/* I got the idea of expanding during the round function from SSLeay */ +#if BYTE_ORDER == LITTLE_ENDIAN +#define blk0(i) \ + (block->l[i] = (rol(block->l[i], 24) & 0xFF00FF00) | \ + (rol(block->l[i], 8) & 0x00FF00FF)) +#elif BYTE_ORDER == BIG_ENDIAN +#define blk0(i) block->l[i] +#else +#error "Endianness not defined!" +#endif +#define blk(i) \ + (block->l[i & 15] = rol(block->l[(i + 13) & 15] ^ block->l[(i + 8) & 15] ^ \ + block->l[(i + 2) & 15] ^ block->l[i & 15], \ + 1)) + +/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */ +#define R0(v, w, x, y, z, i) \ + z += ((w & (x ^ y)) ^ y) + blk0(i) + 0x5A827999 + rol(v, 5); \ + w = rol(w, 30); +#define R1(v, w, x, y, z, i) \ + z += ((w & (x ^ y)) ^ y) + blk(i) + 0x5A827999 + rol(v, 5); \ + w = rol(w, 30); +#define R2(v, w, x, y, z, i) \ + z += (w ^ x ^ y) + blk(i) + 0x6ED9EBA1 + rol(v, 5); \ + w = rol(w, 30); +#define R3(v, w, x, y, z, i) \ + z += (((w | x) & y) | (w & x)) + blk(i) + 0x8F1BBCDC + rol(v, 5); \ + w = rol(w, 30); +#define R4(v, w, x, y, z, i) \ + z += (w ^ x ^ y) + blk(i) + 0xCA62C1D6 + rol(v, 5); \ + w = rol(w, 30); + +/* Hash a single 512-bit block. This is the core of the algorithm. */ + +void trantor_sha1_transform(uint32_t state[5], const unsigned char buffer[64]) +{ + uint32_t a, b, c, d, e; + typedef union + { + unsigned char c[64]; + uint32_t l[16]; + } CHAR64LONG16; +#ifdef SHA1HANDSOFF + CHAR64LONG16 block[1]; /* use array to appear as a pointer */ + memcpy(block, buffer, 64); +#else + /* The following had better never be used because it causes the + * pointer-to-const buffer to be cast into a pointer to non-const. + * And the result is written through. I threw a "const" in, hoping + * this will cause a diagnostic. + */ + CHAR64LONG16* block = (const CHAR64LONG16*)buffer; +#endif + /* Copy context->state[] to working vars */ + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(a, b, c, d, e, 0); + R0(e, a, b, c, d, 1); + R0(d, e, a, b, c, 2); + R0(c, d, e, a, b, 3); + R0(b, c, d, e, a, 4); + R0(a, b, c, d, e, 5); + R0(e, a, b, c, d, 6); + R0(d, e, a, b, c, 7); + R0(c, d, e, a, b, 8); + R0(b, c, d, e, a, 9); + R0(a, b, c, d, e, 10); + R0(e, a, b, c, d, 11); + R0(d, e, a, b, c, 12); + R0(c, d, e, a, b, 13); + R0(b, c, d, e, a, 14); + R0(a, b, c, d, e, 15); + R1(e, a, b, c, d, 16); + R1(d, e, a, b, c, 17); + R1(c, d, e, a, b, 18); + R1(b, c, d, e, a, 19); + R2(a, b, c, d, e, 20); + R2(e, a, b, c, d, 21); + R2(d, e, a, b, c, 22); + R2(c, d, e, a, b, 23); + R2(b, c, d, e, a, 24); + R2(a, b, c, d, e, 25); + R2(e, a, b, c, d, 26); + R2(d, e, a, b, c, 27); + R2(c, d, e, a, b, 28); + R2(b, c, d, e, a, 29); + R2(a, b, c, d, e, 30); + R2(e, a, b, c, d, 31); + R2(d, e, a, b, c, 32); + R2(c, d, e, a, b, 33); + R2(b, c, d, e, a, 34); + R2(a, b, c, d, e, 35); + R2(e, a, b, c, d, 36); + R2(d, e, a, b, c, 37); + R2(c, d, e, a, b, 38); + R2(b, c, d, e, a, 39); + R3(a, b, c, d, e, 40); + R3(e, a, b, c, d, 41); + R3(d, e, a, b, c, 42); + R3(c, d, e, a, b, 43); + R3(b, c, d, e, a, 44); + R3(a, b, c, d, e, 45); + R3(e, a, b, c, d, 46); + R3(d, e, a, b, c, 47); + R3(c, d, e, a, b, 48); + R3(b, c, d, e, a, 49); + R3(a, b, c, d, e, 50); + R3(e, a, b, c, d, 51); + R3(d, e, a, b, c, 52); + R3(c, d, e, a, b, 53); + R3(b, c, d, e, a, 54); + R3(a, b, c, d, e, 55); + R3(e, a, b, c, d, 56); + R3(d, e, a, b, c, 57); + R3(c, d, e, a, b, 58); + R3(b, c, d, e, a, 59); + R4(a, b, c, d, e, 60); + R4(e, a, b, c, d, 61); + R4(d, e, a, b, c, 62); + R4(c, d, e, a, b, 63); + R4(b, c, d, e, a, 64); + R4(a, b, c, d, e, 65); + R4(e, a, b, c, d, 66); + R4(d, e, a, b, c, 67); + R4(c, d, e, a, b, 68); + R4(b, c, d, e, a, 69); + R4(a, b, c, d, e, 70); + R4(e, a, b, c, d, 71); + R4(d, e, a, b, c, 72); + R4(c, d, e, a, b, 73); + R4(b, c, d, e, a, 74); + R4(a, b, c, d, e, 75); + R4(e, a, b, c, d, 76); + R4(d, e, a, b, c, 77); + R4(c, d, e, a, b, 78); + R4(b, c, d, e, a, 79); + /* Add the working vars back into context.state[] */ + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + /* Wipe variables */ + a = b = c = d = e = 0; +#ifdef SHA1HANDSOFF + memset(block, '\0', sizeof(block)); +#endif +} + +/* trantor_sha1_init - Initialize new context */ + +void trantor_sha1_init(SHA1_CTX* context) +{ + /* SHA1 initialization constants */ + context->state[0] = 0x67452301; + context->state[1] = 0xEFCDAB89; + context->state[2] = 0x98BADCFE; + context->state[3] = 0x10325476; + context->state[4] = 0xC3D2E1F0; + context->count[0] = context->count[1] = 0; +} + +/* Run your data through this. */ + +void trantor_sha1_update(SHA1_CTX* context, + const unsigned char* data, + size_t len) +{ + size_t i; + size_t j; + + j = context->count[0]; + if ((context->count[0] += len << 3) < j) + context->count[1]++; + context->count[1] += (len >> 29); + j = (j >> 3) & 63; + if ((j + len) > 63) + { + memcpy(&context->buffer[j], data, (i = 64 - j)); + trantor_sha1_transform(context->state, context->buffer); + for (; i + 63 < len; i += 64) + { + trantor_sha1_transform(context->state, &data[i]); + } + j = 0; + } + else + i = 0; + memcpy(&context->buffer[j], &data[i], len - i); +} + +/* Add padding and return the message digest. */ + +void trantor_sha1_final(unsigned char digest[20], SHA1_CTX* context) +{ + unsigned i; + unsigned char finalcount[8]; + unsigned char c; + +#if 0 /* untested "improvement" by DHR */ + /* Convert context->count to a sequence of bytes + * in finalcount. Second element first, but + * big-endian order within element. + * But we do it all backwards. + */ + unsigned char *fcp = &finalcount[8]; + + for (i = 0; i < 2; i++) + { + uint32_t t = context->count[i]; + int j; + + for (j = 0; j < 4; t >>= 8, j++) + *--fcp = (unsigned char) t + } +#else + for (i = 0; i < 8; i++) + { + finalcount[i] = (unsigned char)((context->count[(i >= 4 ? 0 : 1)] >> + ((3 - (i & 3)) * 8)) & + 255); /* Endian independent */ + } +#endif + c = 0200; + trantor_sha1_update(context, &c, 1); + while ((context->count[0] & 504) != 448) + { + c = 0000; + trantor_sha1_update(context, &c, 1); + } + trantor_sha1_update(context, + finalcount, + 8); /* Should cause a TrantorSHA1Transform() */ + for (i = 0; i < 20; i++) + { + digest[i] = + (unsigned char)((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & + 255); + } + /* Wipe variables */ + memset(context, '\0', sizeof(*context)); + memset(&finalcount, '\0', sizeof(finalcount)); +} diff --git a/trantor/utils/crypto/sha1.h b/trantor/utils/crypto/sha1.h new file mode 100644 index 00000000..f7344cb5 --- /dev/null +++ b/trantor/utils/crypto/sha1.h @@ -0,0 +1,26 @@ +/* ================ sha1.h ================ */ +/* +SHA-1 in C +By Steve Reid +100% Public Domain + +Last modified by Martin Chang for the Trantor project +*/ + +#pragma once + +#include + +typedef struct +{ + uint32_t state[5]; + size_t count[2]; + unsigned char buffer[64]; +} SHA1_CTX; + +void trantor_sha1_transform(uint32_t state[5], const unsigned char buffer[64]); +void trantor_sha1_init(SHA1_CTX* context); +void trantor_sha1_update(SHA1_CTX* context, + const unsigned char* data, + size_t len); +void trantor_sha1_final(unsigned char digest[20], SHA1_CTX* context); diff --git a/trantor/utils/crypto/sha256.cc b/trantor/utils/crypto/sha256.cc new file mode 100644 index 00000000..b5f36515 --- /dev/null +++ b/trantor/utils/crypto/sha256.cc @@ -0,0 +1,168 @@ +/********************************************************************* +* Filename: sha256.c +* Author: Brad Conte (brad AT bradconte.com) +* Copyright: +* Disclaimer: This code is presented "as is" without any guarantees. +* Details: Implementation of the SHA-256 hashing algorithm. + SHA-256 is one of the three algorithms in the SHA2 + specification. The others, SHA-384 and SHA-512, are not + offered in this implementation. + Algorithm specification can be found here: + * http://csrc.nist.gov/publications/fips/fips180-2/fips180-2withchangenotice.pdf + This implementation uses little endian byte order. +*********************************************************************/ + +/*************************** HEADER FILES ***************************/ +#include +#include +#include "sha256.h" + +/****************************** MACROS ******************************/ +#define ROTLEFT(a, b) (((a) << (b)) | ((a) >> (32 - (b)))) +#define ROTRIGHT(a, b) (((a) >> (b)) | ((a) << (32 - (b)))) + +#define CH(x, y, z) (((x) & (y)) ^ (~(x) & (z))) +#define MAJ(x, y, z) (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z))) +#define EP0(x) (ROTRIGHT(x, 2) ^ ROTRIGHT(x, 13) ^ ROTRIGHT(x, 22)) +#define EP1(x) (ROTRIGHT(x, 6) ^ ROTRIGHT(x, 11) ^ ROTRIGHT(x, 25)) +#define SIG0(x) (ROTRIGHT(x, 7) ^ ROTRIGHT(x, 18) ^ ((x) >> 3)) +#define SIG1(x) (ROTRIGHT(x, 17) ^ ROTRIGHT(x, 19) ^ ((x) >> 10)) + +/**************************** VARIABLES *****************************/ +static const uint32_t k[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; + +/*********************** FUNCTION DEFINITIONS ***********************/ +void trantor_sha256_transform(SHA256_CTX *ctx, const uint8_t data[]) +{ + uint32_t a, b, c, d, e, f, g, h, i, j, t1, t2, m[64]; + + for (i = 0, j = 0; i < 16; ++i, j += 4) + m[i] = (data[j] << 24) | (data[j + 1] << 16) | (data[j + 2] << 8) | + (data[j + 3]); + for (; i < 64; ++i) + m[i] = SIG1(m[i - 2]) + m[i - 7] + SIG0(m[i - 15]) + m[i - 16]; + + a = ctx->state[0]; + b = ctx->state[1]; + c = ctx->state[2]; + d = ctx->state[3]; + e = ctx->state[4]; + f = ctx->state[5]; + g = ctx->state[6]; + h = ctx->state[7]; + + for (i = 0; i < 64; ++i) + { + t1 = h + EP1(e) + CH(e, f, g) + k[i] + m[i]; + t2 = EP0(a) + MAJ(a, b, c); + h = g; + g = f; + f = e; + e = d + t1; + d = c; + c = b; + b = a; + a = t1 + t2; + } + + ctx->state[0] += a; + ctx->state[1] += b; + ctx->state[2] += c; + ctx->state[3] += d; + ctx->state[4] += e; + ctx->state[5] += f; + ctx->state[6] += g; + ctx->state[7] += h; +} + +void trantor_sha256_init(SHA256_CTX *ctx) +{ + ctx->datalen = 0; + ctx->bitlen = 0; + ctx->state[0] = 0x6a09e667; + ctx->state[1] = 0xbb67ae85; + ctx->state[2] = 0x3c6ef372; + ctx->state[3] = 0xa54ff53a; + ctx->state[4] = 0x510e527f; + ctx->state[5] = 0x9b05688c; + ctx->state[6] = 0x1f83d9ab; + ctx->state[7] = 0x5be0cd19; +} + +void trantor_sha256_update(SHA256_CTX *ctx, const uint8_t data[], size_t len) +{ + uint32_t i; + + for (i = 0; i < len; ++i) + { + ctx->data[ctx->datalen] = data[i]; + ctx->datalen++; + if (ctx->datalen == 64) + { + trantor_sha256_transform(ctx, ctx->data); + ctx->bitlen += 512; + ctx->datalen = 0; + } + } +} + +void trantor_sha256_final(SHA256_CTX *ctx, uint8_t hash[]) +{ + uint32_t i; + + i = ctx->datalen; + + // Pad whatever data is left in the buffer. + if (ctx->datalen < 56) + { + ctx->data[i++] = 0x80; + while (i < 56) + ctx->data[i++] = 0x00; + } + else + { + ctx->data[i++] = 0x80; + while (i < 64) + ctx->data[i++] = 0x00; + trantor_sha256_transform(ctx, ctx->data); + memset(ctx->data, 0, 56); + } + + // Append to the padding the total message's length in bits and transform. + ctx->bitlen += ctx->datalen * 8; + ctx->data[63] = (uint8_t)ctx->bitlen; + ctx->data[62] = (uint8_t)(ctx->bitlen >> 8); + ctx->data[61] = (uint8_t)(ctx->bitlen >> 16); + ctx->data[60] = (uint8_t)(ctx->bitlen >> 24); + ctx->data[59] = (uint8_t)(ctx->bitlen >> 32); + ctx->data[58] = (uint8_t)(ctx->bitlen >> 40); + ctx->data[57] = (uint8_t)(ctx->bitlen >> 48); + ctx->data[56] = (uint8_t)(ctx->bitlen >> 56); + trantor_sha256_transform(ctx, ctx->data); + + // Since this implementation uses little endian byte ordering and SHA uses + // big endian, reverse all the bytes when copying the final state to the + // output hash. + for (i = 0; i < 4; ++i) + { + hash[i] = (ctx->state[0] >> (24 - i * 8)) & 0x000000ff; + hash[i + 4] = (ctx->state[1] >> (24 - i * 8)) & 0x000000ff; + hash[i + 8] = (ctx->state[2] >> (24 - i * 8)) & 0x000000ff; + hash[i + 12] = (ctx->state[3] >> (24 - i * 8)) & 0x000000ff; + hash[i + 16] = (ctx->state[4] >> (24 - i * 8)) & 0x000000ff; + hash[i + 20] = (ctx->state[5] >> (24 - i * 8)) & 0x000000ff; + hash[i + 24] = (ctx->state[6] >> (24 - i * 8)) & 0x000000ff; + hash[i + 28] = (ctx->state[7] >> (24 - i * 8)) & 0x000000ff; + } +} \ No newline at end of file diff --git a/trantor/utils/crypto/sha256.h b/trantor/utils/crypto/sha256.h new file mode 100644 index 00000000..6a470faa --- /dev/null +++ b/trantor/utils/crypto/sha256.h @@ -0,0 +1,31 @@ +/********************************************************************* + * Filename: sha256.h + * Author: Brad Conte (brad AT bradconte.com) + * Copyright: + * Disclaimer: This code is presented "as is" without any guarantees. + * Details: Defines the API for the corresponding SHA1 implementation. + *********************************************************************/ + +#ifndef SHA256_H +#define SHA256_H + +/*************************** HEADER FILES ***************************/ +#include + +/****************************** MACROS ******************************/ +#define SHA256_BLOCK_SIZE 32 // SHA256 outputs a 32 byte digest + +typedef struct +{ + uint8_t data[64]; + uint32_t datalen; + uint64_t bitlen; + uint32_t state[8]; +} SHA256_CTX; + +/*********************** FUNCTION DECLARATIONS **********************/ +void trantor_sha256_init(SHA256_CTX *ctx); +void trantor_sha256_update(SHA256_CTX *ctx, const uint8_t data[], size_t len); +void trantor_sha256_final(SHA256_CTX *ctx, uint8_t hash[]); + +#endif // SHA256_H \ No newline at end of file diff --git a/trantor/utils/crypto/sha3.cc b/trantor/utils/crypto/sha3.cc new file mode 100644 index 00000000..dcb27b95 --- /dev/null +++ b/trantor/utils/crypto/sha3.cc @@ -0,0 +1,196 @@ +// sha3.c +// 19-Nov-11 Markku-Juhani O. Saarinen + +// Revised 07-Aug-15 to match with official release of FIPS PUB 202 "SHA3" +// Revised 03-Sep-15 for portability + OpenSSL - style API + +#include "sha3.h" + +// update the state with given number of rounds + +void trantor_sha3_keccakf(uint64_t st[25]) +{ + // constants + const uint64_t keccakf_rndc[24] = { + 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, + 0x8000000080008000, 0x000000000000808b, 0x0000000080000001, + 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, + 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, + 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, + 0x8000000000008003, 0x8000000000008002, 0x8000000000000080, + 0x000000000000800a, 0x800000008000000a, 0x8000000080008081, + 0x8000000000008080, 0x0000000080000001, 0x8000000080008008}; + const int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36, + 45, 55, 2, 14, 27, 41, 56, 8, + 25, 43, 62, 18, 39, 61, 20, 44}; + const int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, + 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1}; + + // variables + int i, j, r; + uint64_t t, bc[5]; + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + uint8_t *v; + + // endianness conversion. this is redundant on little-endian targets + for (i = 0; i < 25; i++) + { + v = (uint8_t *)&st[i]; + st[i] = ((uint64_t)v[0]) | (((uint64_t)v[1]) << 8) | + (((uint64_t)v[2]) << 16) | (((uint64_t)v[3]) << 24) | + (((uint64_t)v[4]) << 32) | (((uint64_t)v[5]) << 40) | + (((uint64_t)v[6]) << 48) | (((uint64_t)v[7]) << 56); + } +#endif + + // actual iteration + for (r = 0; r < KECCAKF_ROUNDS; r++) + { + // Theta + for (i = 0; i < 5; i++) + bc[i] = st[i] ^ st[i + 5] ^ st[i + 10] ^ st[i + 15] ^ st[i + 20]; + + for (i = 0; i < 5; i++) + { + t = bc[(i + 4) % 5] ^ ROTL64(bc[(i + 1) % 5], 1); + for (j = 0; j < 25; j += 5) + st[j + i] ^= t; + } + + // Rho Pi + t = st[1]; + for (i = 0; i < 24; i++) + { + j = keccakf_piln[i]; + bc[0] = st[j]; + st[j] = ROTL64(t, keccakf_rotc[i]); + t = bc[0]; + } + + // Chi + for (j = 0; j < 25; j += 5) + { + for (i = 0; i < 5; i++) + bc[i] = st[j + i]; + for (i = 0; i < 5; i++) + st[j + i] ^= (~bc[(i + 1) % 5]) & bc[(i + 2) % 5]; + } + + // Iota + st[0] ^= keccakf_rndc[r]; + } + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + // endianness conversion. this is redundant on little-endian targets + for (i = 0; i < 25; i++) + { + v = (uint8_t *)&st[i]; + t = st[i]; + v[0] = t & 0xFF; + v[1] = (t >> 8) & 0xFF; + v[2] = (t >> 16) & 0xFF; + v[3] = (t >> 24) & 0xFF; + v[4] = (t >> 32) & 0xFF; + v[5] = (t >> 40) & 0xFF; + v[6] = (t >> 48) & 0xFF; + v[7] = (t >> 56) & 0xFF; + } +#endif +} + +// Initialize the context for SHA3 + +int trantor_sha3_init(sha3_ctx_t *c, int mdlen) +{ + int i; + + for (i = 0; i < 25; i++) + c->st.q[i] = 0; + c->mdlen = mdlen; + c->rsiz = 200 - 2 * mdlen; + c->pt = 0; + + return 1; +} + +// update state with more data + +int trantor_sha3_update(sha3_ctx_t *c, const void *data, size_t len) +{ + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) + { + c->st.b[j++] ^= ((const uint8_t *)data)[i]; + if (j >= c->rsiz) + { + trantor_sha3_keccakf(c->st.q); + j = 0; + } + } + c->pt = j; + + return 1; +} + +// finalize and output a hash + +int trantor_sha3_final(void *md, sha3_ctx_t *c) +{ + int i; + + c->st.b[c->pt] ^= 0x06; + c->st.b[c->rsiz - 1] ^= 0x80; + trantor_sha3_keccakf(c->st.q); + + for (i = 0; i < c->mdlen; i++) + { + ((uint8_t *)md)[i] = c->st.b[i]; + } + + return 1; +} + +// compute a SHA-3 hash (md) of given byte length from "in" + +void *trantor_sha3(const void *in, size_t inlen, void *md, int mdlen) +{ + sha3_ctx_t sha3; + + trantor_sha3_init(&sha3, mdlen); + trantor_sha3_update(&sha3, in, inlen); + trantor_sha3_final(md, &sha3); + + return md; +} + +// SHAKE128 and SHAKE256 extensible-output functionality + +void trantor_shake_xof(sha3_ctx_t *c) +{ + c->st.b[c->pt] ^= 0x1F; + c->st.b[c->rsiz - 1] ^= 0x80; + trantor_sha3_keccakf(c->st.q); + c->pt = 0; +} + +void trantor_shake_out(sha3_ctx_t *c, void *out, size_t len) +{ + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) + { + if (j >= c->rsiz) + { + trantor_sha3_keccakf(c->st.q); + j = 0; + } + ((uint8_t *)out)[i] = c->st.b[j++]; + } + c->pt = j; +} diff --git a/trantor/utils/crypto/sha3.h b/trantor/utils/crypto/sha3.h new file mode 100644 index 00000000..a96f1122 --- /dev/null +++ b/trantor/utils/crypto/sha3.h @@ -0,0 +1,49 @@ +// sha3.h +// 19-Nov-11 Markku-Juhani O. Saarinen + +#ifndef SHA3_H +#define SHA3_H + +#include +#include + +#ifndef KECCAKF_ROUNDS +#define KECCAKF_ROUNDS 24 +#endif + +#ifndef ROTL64 +#define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y)))) +#endif + +// state context +typedef struct +{ + union + { // state: + uint8_t b[200]; // 8-bit bytes + uint64_t q[25]; // 64-bit words + } st; + int pt, rsiz, mdlen; // these don't overflow +} sha3_ctx_t; + +// Compression function. +void trantor_sha3_keccakf(uint64_t st[25]); + +// OpenSSL - like interface +int trantor_sha3_init(sha3_ctx_t *c, + int mdlen); // mdlen = hash output in bytes +int trantor_sha3_update(sha3_ctx_t *c, const void *data, size_t len); +int trantor_sha3_final(void *md, sha3_ctx_t *c); // digest goes to md + +// compute a sha3 hash (md) of given byte length from "in" +void *trantor_sha3(const void *in, size_t inlen, void *md, int mdlen); + +// SHAKE128 and SHAKE256 extensible-output functions +#define trantor_shake128_init(c) trantor_sha3_init(c, 16) +#define trantor_shake256_init(c) trantor_sha3_init(c, 32) +#define trantor_shake_update trantor_sha3_update + +void trantor_shake_xof(sha3_ctx_t *c); +void trantor_shake_out(sha3_ctx_t *c, void *out, size_t len); + +#endif