Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/sdksCommon.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ list(APPEND SDK_TEST_PROJECT_LIST "s3control:tests/aws-cpp-sdk-s3control-integra
list(APPEND SDK_TEST_PROJECT_LIST "sns:tests/aws-cpp-sdk-sns-integration-tests")
list(APPEND SDK_TEST_PROJECT_LIST "sqs:tests/aws-cpp-sdk-sqs-integration-tests")
list(APPEND SDK_TEST_PROJECT_LIST "sqs:tests/aws-cpp-sdk-sqs-unit-tests")
list(APPEND SDK_TEST_PROJECT_LIST "transfer:tests/aws-cpp-sdk-transfer-tests")
list(APPEND SDK_TEST_PROJECT_LIST "transfer:tests/aws-cpp-sdk-transfer-tests,tests/aws-cpp-sdk-transfer-unit-tests")
list(APPEND SDK_TEST_PROJECT_LIST "text-to-speech:tests/aws-cpp-sdk-text-to-speech-tests,tests/aws-cpp-sdk-polly-sample")
list(APPEND SDK_TEST_PROJECT_LIST "timestream-query:tests/aws-cpp-sdk-timestream-query-unit-tests")
list(APPEND SDK_TEST_PROJECT_LIST "transcribestreaming:tests/aws-cpp-sdk-transcribestreaming-integ-tests")
Expand Down
55 changes: 54 additions & 1 deletion src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,33 @@ namespace Aws
return rangeStream.str();
}

static bool VerifyContentRange(const Aws::String& requestedRange, const Aws::String& responseContentRange)
{
if (requestedRange.empty() || responseContentRange.empty())
{
return false;
}

if (requestedRange.find("bytes=") != 0)
{
return false;
}
Aws::String requestRange = requestedRange.substr(6);

if (responseContentRange.find("bytes ") != 0)
{
return false;
}
Aws::String responseRange = responseContentRange.substr(6);
size_t slashPos = responseRange.find('/');
if (slashPos != Aws::String::npos)
{
responseRange = responseRange.substr(0, slashPos);
}

return requestRange == responseRange;
}

void TransferManager::DoSinglePartDownload(const std::shared_ptr<TransferHandle>& handle)
{
auto queuedParts = handle->GetQueuedParts();
Expand Down Expand Up @@ -1091,7 +1118,6 @@ namespace Aws
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context)
{
AWS_UNREFERENCED_PARAM(client);
AWS_UNREFERENCED_PARAM(request);

std::shared_ptr<TransferHandleAsyncContext> transferContext =
std::const_pointer_cast<TransferHandleAsyncContext>(std::static_pointer_cast<const TransferHandleAsyncContext>(context));
Expand All @@ -1110,6 +1136,33 @@ namespace Aws
}
else
{
if (request.RangeHasBeenSet())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: else if(request.RangeHasBeenSet()) { ... } is preferred to else { if { ... }}

{
const auto& requestedRange = request.GetRange();
const auto& responseContentRange = outcome.GetResult().GetContentRange();

if (responseContentRange.empty() or !VerifyContentRange(requestedRange, responseContentRange)) {
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ContentRangeMismatch",
"ContentRange in response does not match requested range",
false);
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] ContentRange mismatch. Requested: [" << requestedRange
<< "] Received: [" << responseContentRange << "]");
handle->ChangePartToFailed(partState);
handle->SetError(error);
TriggerErrorCallback(handle, error);
handle->Cancel();

if(partState->GetDownloadBuffer())
{
m_bufferManager.Release(partState->GetDownloadBuffer());
partState->SetDownloadBuffer(nullptr);
}
return;
}
}

if(handle->ShouldContinue())
{
Aws::IOStream* bufferStream = partState->GetDownloadPartStream();
Expand Down
34 changes: 34 additions & 0 deletions tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,40 @@ TEST_P(TransferTests, TransferManager_TestRelativePrefix)
}
}

TEST_P(TransferTests, TransferManager_ContentRangeVerificationTest)
{
const Aws::String RandomFileName = Aws::Utils::UUID::RandomUUID();
Aws::String testFileName = MakeFilePath(RandomFileName.c_str());
ScopedTestFile testFile(testFileName, MEDIUM_TEST_SIZE, testString);

TransferManagerConfiguration transferManagerConfig(m_executor.get());
transferManagerConfig.s3Client = m_s3Clients[GetParam()];
auto transferManager = TransferManager::Create(transferManagerConfig);

std::shared_ptr<TransferHandle> uploadPtr = transferManager->UploadFile(testFileName, GetTestBucketName(), RandomFileName, "text/plain", Aws::Map<Aws::String, Aws::String>());
uploadPtr->WaitUntilFinished();
ASSERT_EQ(TransferStatus::COMPLETED, uploadPtr->GetStatus());
ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str()));

auto downloadFileName = MakeDownloadFileName(RandomFileName);
auto createStreamFn = [=](){
#ifdef _MSC_VER
return Aws::New<Aws::FStream>(ALLOCATION_TAG, Aws::Utils::StringUtils::ToWString(downloadFileName.c_str()).c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc);
#else
return Aws::New<Aws::FStream>(ALLOCATION_TAG, downloadFileName.c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc);
#endif
};

uint64_t offset = 1024;
uint64_t partSize = 2048;
std::shared_ptr<TransferHandle> downloadPtr = transferManager->DownloadFile(GetTestBucketName(), RandomFileName, offset, partSize, createStreamFn);

downloadPtr->WaitUntilFinished();
ASSERT_EQ(TransferStatus::COMPLETED, downloadPtr->GetStatus());
ASSERT_EQ(partSize, downloadPtr->GetBytesTotalSize());
ASSERT_EQ(partSize, downloadPtr->GetBytesTransferred());
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there also be a test with a false content range to test for failures?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into creating a mock S3 response with an incorrect ContentRange to test the failure path

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 i think even as going as far to have a "transfer manager unit test" might be good. you can largely copy the S3 unit tests to create the skeleton of it.

INSTANTIATE_TEST_SUITE_P(Https, TransferTests, testing::Values(TestType::Https));
INSTANTIATE_TEST_SUITE_P(Http, TransferTests, testing::Values(TestType::Http));

Expand Down
30 changes: 30 additions & 0 deletions tests/aws-cpp-sdk-transfer-unit-tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
add_project(aws-cpp-sdk-transfer-unit-tests
"Unit Tests for the Transfer Manager"
aws-cpp-sdk-transfer
testing-resources
aws_test_main
aws-cpp-sdk-core)

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/TransferUnitTests.cpp)
else()
add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/TransferUnitTests.cpp)
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

if(MSVC AND BUILD_SHARED_LIBS)
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-transfer.dll /DELAYLOAD:aws-cpp-sdk-core.dll")
target_link_libraries(${PROJECT_NAME} delayimp.lib)
endif()
113 changes: 113 additions & 0 deletions tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#include <gtest/gtest.h>
#include <aws/core/Aws.h>
#include <aws/testing/AwsTestHelpers.h>
#include <aws/testing/MemoryTesting.h>

using namespace Aws;

const char* ALLOCATION_TAG = "TransferUnitTest";

// Copy the VerifyContentRange function for testing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is what we meant by saying write unit tests for this. we dont want to copy and paste code to test it, we want to mock the interfaces that make actual calls such that we can test the code we added. I think the idea here is that you would create a transfer manager with a mocked out s3 client, so that when you run the transfer manager it uses this mocked out s3 client instead of a real one, so you can verify behavior without using a real network connection. something like this

#include <aws/core/Aws.h>
#include <aws/core/utils/threading/PooledThreadExecutor.h>
#include <aws/s3/S3Client.h>
#include <aws/transfer/TransferManager.h>

using namespace Aws;
using namespace Aws::Utils;
using namespace Aws::Utils::Logging;
using namespace Aws::Utils::Threading;
using namespace Aws::Transfer;
using namespace Aws::S3;

namespace {
const char* LOGTAG = "TestApplication";

class MockS3Client : public S3Client {
 public:
  Model::CompleteMultipartUploadOutcome
  CompleteMultipartUpload(const Model::CompleteMultipartUploadRequest& request) const override {
   //TODO: implement mock
  }

  Model::CreateMultipartUploadOutcome
  CreateMultipartUpload(const Model::CreateMultipartUploadRequest& request) const override {
    //TODO: implement mock
  }

  Model::GetObjectOutcome
  GetObject(const Model::GetObjectRequest& request) const override {
    //TODO: implement mock
  }

  Model::HeadObjectOutcome
  HeadObject(const Model::HeadObjectRequest& request) const override {
    //TODO: implement mock
  }

  Model::PutObjectOutcome
  PutObject(const Model::PutObjectRequest& request) const override {
    //TODO: implement mock
  }

  Model::UploadPartOutcome
  UploadPart(const Model::UploadPartRequest& request) const override {
    //TODO: implement mock
  }
};

}

class SdkContext {
 public:
  explicit SdkContext(SDKOptions&& options) : options_(std::move(options)) { InitAPI(options_); }
  ~SdkContext() { ShutdownAPI(options_); }

 private:
  SDKOptions options_;
};

auto main() -> int {
  SDKOptions options;
  options.loggingOptions.logLevel = LogLevel::Trace;
  SdkContext context(std::move(options));

  const auto executor = Aws::MakeUnique<PooledThreadExecutor>(LOGTAG, std::thread::hardware_concurrency());
  TransferManagerConfiguration configuration{executor.get()};
  configuration.s3Client = Aws::MakeShared<MockS3Client>("MockS3Client");
  // do testing operations
  return 0;
}

but wrapped by gtest

// In production, this would be exposed in a header or made testable
static bool VerifyContentRange(const Aws::String& requestedRange, const Aws::String& responseContentRange)
{
if (requestedRange.empty() || responseContentRange.empty())
{
return false;
}

if (requestedRange.find("bytes=") != 0)
{
return false;
}
Aws::String requestRange = requestedRange.substr(6);

if (responseContentRange.find("bytes ") != 0)
{
return false;
}
Aws::String responseRange = responseContentRange.substr(6);
size_t slashPos = responseRange.find('/');
if (slashPos != Aws::String::npos)
{
responseRange = responseRange.substr(0, slashPos);
}

return requestRange == responseRange;
}

class TransferUnitTest : public testing::Test {
protected:
static void SetUpTestSuite() {
#ifdef USE_AWS_MEMORY_MANAGEMENT
_testMemorySystem.reset(new ExactTestMemorySystem(1024, 128));
_options.memoryManagementOptions.memoryManager = _testMemorySystem.get();
#endif
InitAPI(_options);
}

static void TearDownTestSuite() {
ShutdownAPI(_options);
#ifdef USE_AWS_MEMORY_MANAGEMENT
EXPECT_EQ(_testMemorySystem->GetCurrentOutstandingAllocations(), 0ULL);
EXPECT_EQ(_testMemorySystem->GetCurrentBytesAllocated(), 0ULL);
EXPECT_TRUE(_testMemorySystem->IsClean());
if (_testMemorySystem->GetCurrentOutstandingAllocations() != 0ULL)
FAIL();
if (_testMemorySystem->GetCurrentBytesAllocated() != 0ULL)
FAIL();
if (!_testMemorySystem->IsClean())
FAIL();
_testMemorySystem.reset();
#endif
}

static SDKOptions _options;
#ifdef USE_AWS_MEMORY_MANAGEMENT
static std::unique_ptr<ExactTestMemorySystem> _testMemorySystem;
#endif
};

SDKOptions TransferUnitTest::_options;
#ifdef USE_AWS_MEMORY_MANAGEMENT
std::unique_ptr<ExactTestMemorySystem> TransferUnitTest::_testMemorySystem = nullptr;
#endif

TEST_F(TransferUnitTest, VerifyContentRange_ValidRanges) {
// Test matching ranges
EXPECT_TRUE(VerifyContentRange("bytes=0-1023", "bytes 0-1023/2048"));
EXPECT_TRUE(VerifyContentRange("bytes=1024-2047", "bytes 1024-2047/2048"));
EXPECT_TRUE(VerifyContentRange("bytes=0-499", "bytes 0-499/500"));

// Test without total size in response
EXPECT_TRUE(VerifyContentRange("bytes=0-1023", "bytes 0-1023"));
}

TEST_F(TransferUnitTest, VerifyContentRange_InvalidRanges) {
// Test mismatched ranges - this is what @kai-ion wanted to test!
EXPECT_FALSE(VerifyContentRange("bytes=0-1023", "bytes 0-1022/2048"));
EXPECT_FALSE(VerifyContentRange("bytes=0-1023", "bytes 1024-2047/2048"));
EXPECT_FALSE(VerifyContentRange("bytes=1024-2047", "bytes 0-1023/2048"));

// Test empty inputs
EXPECT_FALSE(VerifyContentRange("", "bytes 0-1023/2048"));
EXPECT_FALSE(VerifyContentRange("bytes=0-1023", ""));
EXPECT_FALSE(VerifyContentRange("", ""));

// Test invalid format
EXPECT_FALSE(VerifyContentRange("0-1023", "bytes 0-1023/2048"));
EXPECT_FALSE(VerifyContentRange("bytes=0-1023", "0-1023/2048"));
EXPECT_FALSE(VerifyContentRange("range=0-1023", "bytes 0-1023/2048"));
}

TEST_F(TransferUnitTest, VerifyContentRange_EdgeCases) {
// Test single byte range
EXPECT_TRUE(VerifyContentRange("bytes=0-0", "bytes 0-0/1"));

// Test large ranges
EXPECT_TRUE(VerifyContentRange("bytes=0-1073741823", "bytes 0-1073741823/1073741824"));

// Test ranges with different total sizes (should still match the range part)
EXPECT_TRUE(VerifyContentRange("bytes=0-1023", "bytes 0-1023/5000"));
EXPECT_TRUE(VerifyContentRange("bytes=0-1023", "bytes 0-1023/1024"));
}
Loading