Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,26 @@ public async Task InitializeAsync(
this.container = containerResponse.Container;
}

/// <inheritdoc/>
public override async Task<DataEncryptionKey> FetchDataEncryptionKeyWithoutRawKeyAsync(
string id,
string encryptionAlgorithm,
CancellationToken cancellationToken)
{
return await this.FetchDekAsync(id, encryptionAlgorithm, cancellationToken);
}

/// <inheritdoc/>
public override async Task<DataEncryptionKey> FetchDataEncryptionKeyAsync(
string id,
string encryptionAlgorithm,
CancellationToken cancellationToken)
{
return await this.FetchDekAsync(id, encryptionAlgorithm, cancellationToken, true);

}

private async Task<DataEncryptionKey> FetchDekAsync(string id, string encryptionAlgorithm, CancellationToken cancellationToken, bool withRawKey = false)
{
DataEncryptionKeyProperties dataEncryptionKeyProperties = await this.dataEncryptionKeyContainerCore.FetchDataEncryptionKeyPropertiesAsync(
id,
Expand Down Expand Up @@ -200,7 +215,8 @@ public override async Task<DataEncryptionKey> FetchDataEncryptionKeyAsync(
InMemoryRawDek inMemoryRawDek = await this.dataEncryptionKeyContainerCore.FetchUnwrappedAsync(
dataEncryptionKeyProperties,
diagnosticsContext: CosmosDiagnosticsContext.Create(null),
cancellationToken: cancellationToken);
cancellationToken: cancellationToken,
withRawKey);

return inMemoryRawDek.DataEncryptionKey;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ public override async Task<byte[]> DecryptAsync(
string encryptionAlgorithm,
CancellationToken cancellationToken = default)
{
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);

if (dek == null)
{
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync)}.");
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync)}.");
}

return dek.DecryptData(cipherText);
Expand All @@ -55,14 +55,14 @@ public override async Task<byte[]> EncryptAsync(
string encryptionAlgorithm,
CancellationToken cancellationToken = default)
{
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);

if (dek == null)
{
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync)}.");
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync)}.");
}

return dek.EncryptData(plainText);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ internal async Task<DataEncryptionKey> FetchUnWrappedMdeSupportedLegacyDekAsync(
unwrapResult.DataEncryptionKey);

return new MdeEncryptionAlgorithm(
unwrapResult.DataEncryptionKey,
plaintextDataEncryptionKey,
Data.Encryption.Cryptography.EncryptionType.Randomized);
}
Expand Down Expand Up @@ -378,13 +379,14 @@ internal async Task<DataEncryptionKey> FetchUnWrappedLegacySupportedMdeDekAsync(
internal async Task<InMemoryRawDek> FetchUnwrappedAsync(
DataEncryptionKeyProperties dekProperties,
CosmosDiagnosticsContext diagnosticsContext,
CancellationToken cancellationToken)
CancellationToken cancellationToken,
bool withRawKey = false)
{
try
{
if (string.Equals(dekProperties.EncryptionAlgorithm, CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized))
{
DataEncryptionKey dek = this.InitMdeEncryptionAlgorithm(dekProperties);
DataEncryptionKey dek = this.InitMdeEncryptionAlgorithm(dekProperties, withRawKey);

// TTL is not used since DEK is not cached.
return new InMemoryRawDek(dek, TimeSpan.FromMilliseconds(0));
Expand Down Expand Up @@ -564,7 +566,7 @@ private async Task<EncryptionKeyUnwrapResult> UnWrapDekMdeEncAlgoAsync(
return unwrapResult;
}

internal DataEncryptionKey InitMdeEncryptionAlgorithm(DataEncryptionKeyProperties dekProperties)
internal DataEncryptionKey InitMdeEncryptionAlgorithm(DataEncryptionKeyProperties dekProperties, bool withRawKey = false)
{
if (this.DekProvider.MdeKeyWrapProvider == null)
{
Expand All @@ -576,7 +578,8 @@ internal DataEncryptionKey InitMdeEncryptionAlgorithm(DataEncryptionKeyPropertie
dekProperties,
Data.Encryption.Cryptography.EncryptionType.Randomized,
this.DekProvider.MdeKeyWrapProvider.EncryptionKeyStoreProvider,
this.DekProvider.PdekCacheTimeToLive);
this.DekProvider.PdekCacheTimeToLive,
withRawKey);
}

private async Task<DataEncryptionKeyProperties> ReadResourceAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom
public abstract class DataEncryptionKeyProvider
{
/// <summary>
/// Retrieves the data encryption key for the given id.
/// Retrieves the data encryption key for the given id without rawkey. RawKey will be set null.
/// </summary>
/// <param name="id">Identifier of the data encryption key.</param>
/// <param name="encryptionAlgorithm">Encryption algorithm that the retrieved key will be used with.</param>
/// <param name="cancellationToken">Token for request cancellation.</param>
/// <returns>Data encryption key bytes.</returns>
public abstract Task<DataEncryptionKey> FetchDataEncryptionKeyWithoutRawKeyAsync(
string id,
string encryptionAlgorithm,
CancellationToken cancellationToken);

/// <summary>
/// Retrieves the data encryption key for the given id with RawKey value.
/// </summary>
/// <param name="id">Identifier of the data encryption key.</param>
/// <param name="encryptionAlgorithm">Encryption algorithm that the retrieved key will be used with.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ internal sealed class MdeEncryptionAlgorithm : DataEncryptionKey
{
private readonly AeadAes256CbcHmac256EncryptionAlgorithm mdeAeadAes256CbcHmac256EncryptionAlgorithm;

private readonly byte[] unwrapKey;

// unused for MDE Algorithm.
public override byte[] RawKey => null;
public override byte[] RawKey { get; }

public override string EncryptionAlgorithm => CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized;

Expand All @@ -32,7 +34,8 @@ public MdeEncryptionAlgorithm(
DataEncryptionKeyProperties dekProperties,
Data.Encryption.Cryptography.EncryptionType encryptionType,
EncryptionKeyStoreProvider encryptionKeyStoreProvider,
TimeSpan? cacheTimeToLive)
TimeSpan? cacheTimeToLive,
bool withRawKey=false)
{
if (dekProperties == null)
{
Expand All @@ -49,36 +52,39 @@ public MdeEncryptionAlgorithm(
dekProperties.EncryptionKeyWrapMetadata.Value,
encryptionKeyStoreProvider);

ProtectedDataEncryptionKey protectedDataEncryptionKey;
if (cacheTimeToLive.HasValue)
if (!withRawKey)
{
// no caching
if (cacheTimeToLive.Value == TimeSpan.Zero)
{
protectedDataEncryptionKey = new ProtectedDataEncryptionKey(
ProtectedDataEncryptionKey protectedDataEncryptionKey = cacheTimeToLive.HasValue && cacheTimeToLive.Value == TimeSpan.Zero
? new ProtectedDataEncryptionKey(
dekProperties.Id,
keyEncryptionKey,
dekProperties.WrappedDataEncryptionKey)
: ProtectedDataEncryptionKey.GetOrCreate(
dekProperties.Id,
keyEncryptionKey,
dekProperties.WrappedDataEncryptionKey);
}
else
{
protectedDataEncryptionKey = ProtectedDataEncryptionKey.GetOrCreate(
dekProperties.Id,
keyEncryptionKey,
dekProperties.WrappedDataEncryptionKey);
}
this.mdeAeadAes256CbcHmac256EncryptionAlgorithm = AeadAes256CbcHmac256EncryptionAlgorithm.GetOrCreate(
protectedDataEncryptionKey,
encryptionType);
}
else
{
protectedDataEncryptionKey = ProtectedDataEncryptionKey.GetOrCreate(
dekProperties.Id,
keyEncryptionKey,
dekProperties.WrappedDataEncryptionKey);
byte[] rawKey = keyEncryptionKey.DecryptEncryptionKey(dekProperties.WrappedDataEncryptionKey);
PlaintextDataEncryptionKey plaintextDataEncryptionKey = cacheTimeToLive.HasValue && (cacheTimeToLive.Value == TimeSpan.Zero)
? new PlaintextDataEncryptionKey(
dekProperties.Id,
rawKey)
: PlaintextDataEncryptionKey.GetOrCreate(
dekProperties.Id,
rawKey);
this.RawKey = rawKey;
this.mdeAeadAes256CbcHmac256EncryptionAlgorithm = AeadAes256CbcHmac256EncryptionAlgorithm.GetOrCreate(
plaintextDataEncryptionKey,
encryptionType);

}

this.mdeAeadAes256CbcHmac256EncryptionAlgorithm = AeadAes256CbcHmac256EncryptionAlgorithm.GetOrCreate(
protectedDataEncryptionKey,
encryptionType);

}

/// <summary>
Expand All @@ -90,9 +96,11 @@ public MdeEncryptionAlgorithm(
/// <param name="dataEncryptionKey"> Data Encryption Key </param>
/// <param name="encryptionType"> Encryption type </param>
public MdeEncryptionAlgorithm(
byte[] rawkey,
Data.Encryption.Cryptography.DataEncryptionKey dataEncryptionKey,
Data.Encryption.Cryptography.EncryptionType encryptionType)
{
this.RawKey = rawkey;
this.mdeAeadAes256CbcHmac256EncryptionAlgorithm = AeadAes256CbcHmac256EncryptionAlgorithm.GetOrCreate(
dataEncryptionKey,
encryptionType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1747,14 +1747,14 @@ public override async Task<byte[]> DecryptAsync(
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned.");
}

DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);

if (dek == null)
{
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync)}.");
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync)}.");
}

return dek.DecryptData(cipherText);
Expand All @@ -1766,7 +1766,7 @@ public override async Task<byte[]> EncryptAsync(
string encryptionAlgorithm,
CancellationToken cancellationToken = default)
{
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace Microsoft.Azure.Cosmos.Encryption.EmulatorTests
using EncryptionKeyWrapMetadata = Custom.EncryptionKeyWrapMetadata;
using DataEncryptionKey = Custom.DataEncryptionKey;
using Newtonsoft.Json.Linq;
using System.Buffers.Text;

[TestClass]
public class MdeCustomEncryptionTests
Expand Down Expand Up @@ -103,6 +104,47 @@ public async Task EncryptionCreateDek()
Assert.AreEqual(dekProperties, readProperties);
}

[TestMethod]
public async Task FetchDataEncryptionKeyWithRawKey()
{
CosmosDataEncryptionKeyProvider dekProvider = new CosmosDataEncryptionKeyProvider(new TestEncryptionKeyStoreProvider());
await dekProvider.InitializeAsync(MdeCustomEncryptionTests.database, MdeCustomEncryptionTests.keyContainer.Id);
DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, dekProperties.EncryptionAlgorithm, CancellationToken.None);
Assert.IsNotNull(k.RawKey);
}

[TestMethod]
public async Task FetchDataEncryptionKeyWithoutRawKey()
{
CosmosDataEncryptionKeyProvider dekProvider = new CosmosDataEncryptionKeyProvider(new TestEncryptionKeyStoreProvider());
await dekProvider.InitializeAsync(MdeCustomEncryptionTests.database, MdeCustomEncryptionTests.keyContainer.Id);
DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(dekProperties.Id, dekProperties.EncryptionAlgorithm, CancellationToken.None);
Assert.IsNull(k.RawKey);
}

[TestMethod]
[Obsolete]
public async Task FetchDataEncryptionKeyMdeDEKAndLegacyBasedAlgorithm()
{
CosmosDataEncryptionKeyProvider dekProvider = new CosmosDataEncryptionKeyProvider(new TestEncryptionKeyStoreProvider());
await dekProvider.InitializeAsync(MdeCustomEncryptionTests.database, MdeCustomEncryptionTests.keyContainer.Id);
DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized, CancellationToken.None);
Assert.IsNotNull(k.RawKey);
}

[TestMethod]
[Obsolete]
public async Task FetchDataEncryptionKeyLegacyDEKAndMdeBasedAlgorithm()
{
string dekId = "legacyDEK";
DataEncryptionKeyProperties dekProperties = await MdeCustomEncryptionTests.CreateDekAsync(MdeCustomEncryptionTests.dekProvider, dekId, CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized);
// Use different DEK provider to avoid (unintentional) cache impact
CosmosDataEncryptionKeyProvider dekProvider = new CosmosDataEncryptionKeyProvider(new TestKeyWrapProvider(), new TestEncryptionKeyStoreProvider());
await dekProvider.InitializeAsync(MdeCustomEncryptionTests.database, MdeCustomEncryptionTests.keyContainer.Id);
DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, CancellationToken.None);
Assert.IsNotNull(k.RawKey);
}

[TestMethod]
public async Task EncryptionRewrapDek()
{
Expand Down Expand Up @@ -2190,14 +2232,14 @@ public override async Task<byte[]> DecryptAsync(
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned.");
}

DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);

if (dek == null)
{
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync)}.");
throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned from {nameof(this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync)}.");
}

return dek.DecryptData(cipherText);
Expand All @@ -2209,7 +2251,7 @@ public override async Task<byte[]> EncryptAsync(
string encryptionAlgorithm,
CancellationToken cancellationToken = default)
{
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyAsync(
DataEncryptionKey dek = await this.DataEncryptionKeyProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(
dataEncryptionKeyId,
encryptionAlgorithm,
cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static void ClassInitialize(TestContext testContext)

CosmosEncryptorTests.mockDataEncryptionKeyProvider = new Mock<DataEncryptionKeyProvider>();
CosmosEncryptorTests.mockDataEncryptionKeyProvider
.Setup(m => m.FetchDataEncryptionKeyAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.Setup(m => m.FetchDataEncryptionKeyWithoutRawKeyAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync((string dekId, string algo, CancellationToken cancellationToken) =>
dekId == CosmosEncryptorTests.dekId ? CosmosEncryptorTests.mockDataEncryptionKey.Object : null);

Expand Down Expand Up @@ -82,7 +82,7 @@ public async Task ValidateEncryptDecrypt()
Times.Once);

CosmosEncryptorTests.mockDataEncryptionKeyProvider.Verify(
m => m.FetchDataEncryptionKeyAsync(
m => m.FetchDataEncryptionKeyWithoutRawKeyAsync(
CosmosEncryptorTests.dekId,
CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized,
It.IsAny<CancellationToken>()), Times.Exactly(2));
Expand Down