Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Support Web Identity in STSProfileCredentialsProvider.
  • Loading branch information
teo-tsirpanis committed Jan 26, 2024
commit f4ef98abaa7928709f9f0a29b466fc2fcfaba560
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ namespace Aws
* Returns the assumed role credentials or empty credentials on error.
*/
AWSCredentials GetCredentialsFromSTS(const AWSCredentials& credentials, const Aws::String& roleARN);
AWSCredentials GetCredentialsFromWebIdentity(const Config::Profile& profile);
private:
AWSCredentials GetCredentialsFromSTSInternal(const Aws::String& roleArn, Aws::STS::STSClient* client);
AWSCredentials GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client);

Aws::String m_profileName;
AWSCredentials m_credentials;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

#include <aws/identity-management/auth/STSProfileCredentialsProvider.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/model/AssumeRoleWithWebIdentityRequest.h>
#include <aws/sts/STSClient.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/UUID.h>

#include <fstream>
#include <utility>

using namespace Aws;
Expand Down Expand Up @@ -88,25 +90,27 @@ enum class ProfileState
Process,
SourceProfile,
SelfReferencing, // special case of SourceProfile.
RoleARNWebIdentity
};

/*
* A valid profile can be in one of the following states. Any other state is considered invalid.
+---------+-----------+-----------+--------------+
| | | | |
| Role | Source | Process | Static |
| ARN | Profile | | Credentials |
+------------------------------------------------+
| | | | |
| false | false | false | TRUE |
| | | | |
| false | false | TRUE | false |
| | | | |
| TRUE | TRUE | false | false |
| | | | |
| TRUE | TRUE | false | TRUE |
| | | | |
+---------+-----------+-----------+--------------+
+---------+-----------+-----------+--------------+------------+
| | | | | |
| Role | Source | Process | Static | Web |
| ARN | Profile | | Credentials | Identity |
+------------------------------------------------+------------+
| | | | | |
| false | false | false | TRUE | false |
| | | | | |
| false | false | TRUE | false | false |
| | | | | |
| TRUE | TRUE | false | false | false |
| | | | | |
| TRUE | TRUE | false | TRUE | false |
| | | | | |
| TRUE | false | false | false | TRUE |
+---------+-----------+-----------+--------------+------------+

*/
static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLevelProfile)
Expand All @@ -115,6 +119,7 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
constexpr int PROCESS_CREDENTIALS = 2;
constexpr int SOURCE_PROFILE = 4;
constexpr int ROLE_ARN = 8;
constexpr int WEB_IDENTITY_TOKEN_FILE = 16;

int state = 0;

Expand All @@ -138,6 +143,11 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
state += ROLE_ARN;
}

if (!profile.GetValue("web_identity_token_file").empty())
{
state += WEB_IDENTITY_TOKEN_FILE;
}

if (topLevelProfile)
{
switch(state)
Expand All @@ -155,6 +165,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
}
// source-profile over-rule static credentials in top-level profiles (except when self-referencing)
return ProfileState::SourceProfile;
case 24: // role arn && web identity
return ProfileState::RoleARNWebIdentity;
default:
// All other cases are considered malformed configuration.
return ProfileState::Invalid;
Expand All @@ -176,6 +188,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe
return ProfileState::SelfReferencing;
}
return ProfileState::Static; // static credentials over-rule source-profile (except when self-referencing)
case 24: // role arn && web identity
return ProfileState::RoleARNWebIdentity;
default:
// All other cases are considered malformed configuration.
return ProfileState::Invalid;
Expand Down Expand Up @@ -302,10 +316,14 @@ void STSProfileCredentialsProvider::Reload()

while (sourceProfiles.size() > 1)
{
const auto profile = sourceProfiles.back()->second;
const auto& profile = sourceProfiles.back()->second;
sourceProfiles.pop_back();
AWSCredentials stsCreds;
if (profile.GetCredentialProcess().empty())
if (CheckProfile(profile, false /*topLevelProfile*/) == ProfileState::RoleARNWebIdentity)
{
stsCreds = GetCredentialsFromWebIdentity(profile);
}
else if (profile.GetCredentialProcess().empty())
{
assert(!profile.GetCredentials().IsEmpty());
stsCreds = profile.GetCredentials();
Expand All @@ -316,7 +334,7 @@ void STSProfileCredentialsProvider::Reload()
}

// get the role arn from the profile at the top of the stack (which hasn't been popped out yet)
const auto arn = sourceProfiles.back()->second.GetRoleArn();
const auto& arn = sourceProfiles.back()->second.GetRoleArn();
const auto& assumedCreds = GetCredentialsFromSTS(stsCreds, arn);
sourceProfiles.back()->second.SetCredentials(assumedCreds);
}
Expand Down Expand Up @@ -366,3 +384,61 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre
Aws::STS::STSClient stsClient {credentials};
return GetCredentialsFromSTSInternal(roleArn, &stsClient);
}

AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client)
{
Aws::String roleSessionName = profile.GetValue("role_session_name");
if (roleSessionName.empty())
{
roleSessionName = Aws::Utils::UUID::PseudoRandomUUID();
}

Aws::String token;
{
auto& tokenPath = profile.GetValue("web_identity_token_file");
Aws::IFStream tokenFile(tokenPath);
if (tokenFile) {
token = Aws::String(
(std::istreambuf_iterator<char>(tokenFile)),
std::istreambuf_iterator<char>());
}
else {
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Can't open token file: " << tokenPath);
return {};
}
}

using namespace Aws::STS::Model;
AssumeRoleWithWebIdentityRequest assumeRoleRequest;
assumeRoleRequest
.WithRoleArn(profile.GetRoleArn())
.WithRoleSessionName(roleSessionName)
.WithWebIdentityToken(token)
.WithDurationSeconds(static_cast<int>(std::chrono::seconds(m_duration).count()));
auto outcome = client->AssumeRoleWithWebIdentity(assumeRoleRequest);
if (outcome.IsSuccess())
{
const auto& modelCredentials = outcome.GetResult().GetCredentials();
return {modelCredentials.GetAccessKeyId(),
modelCredentials.GetSecretAccessKey(),
modelCredentials.GetSessionToken(),
modelCredentials.GetExpiration()};
}
else
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Failed to assume role " << profile.GetRoleArn());
}
return {};
}

AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentity(const Config::Profile& profile)
{
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
auto client = m_stsClientFactory({});
return GetCredentialsFromWebIdentityInternal(profile, client.get());
}

Aws::STS::STSClient stsClient{AWSCredentials{}};
return GetCredentialsFromWebIdentityInternal(profile, &stsClient);
}