Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public ProxyTransport(
var handler = new HttpClientHandler
{
ServerCertificateCustomValidationCallback = (_, certificate, _, _) => certificate?.Issuer == certIssuer,
AllowAutoRedirect = false
AllowAutoRedirect = false,
UseCookies = false
};
_innerTransport = new HttpClientPipelineTransport(new HttpClient(handler));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
Expand Down Expand Up @@ -69,7 +70,8 @@ public void Intercept(IInvocation invocation)
}

Type returnType = methodInfo.ReturnType;
bool returnsSyncCollection = returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(CollectionResult<>);
bool returnsSyncCollection = (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(CollectionResult<>)) ||
returnType == typeof(CollectionResult);

try
{
Expand All @@ -93,10 +95,25 @@ public void Intercept(IInvocation invocation)
// Map IEnumerable to IAsyncEnumerable
if (returnsSyncCollection)
{
Type[] modelType = returnType.GenericTypeArguments;
Type wrapperType = typeof(SyncPageableWrapper<>).MakeGenericType(modelType);

invocation.ReturnValue = Activator.CreateInstance(wrapperType, new[] { result });
if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(CollectionResult<>))
{
// Handle generic CollectionResult<T>
Type[] modelType = returnType.GenericTypeArguments;
Type wrapperType = typeof(SyncPageableWrapper<>).MakeGenericType(modelType);
invocation.ReturnValue = Activator.CreateInstance(wrapperType, new[] { result });
}
else if (returnType == typeof(CollectionResult))
{
var collectionResult = result as CollectionResult;

if (collectionResult == null)
{
throw new InvalidOperationException("Expected CollectionResult from sync protocol method");
}

// Handle non-generic CollectionResult
invocation.ReturnValue = new SyncPageableWrapper(collectionResult);
}
}
else
{
Expand Down Expand Up @@ -134,6 +151,19 @@ private void SetAsyncResult(IInvocation invocation, Type returnType, object? res
}
}

// Handle non-generic AsyncCollectionResult case
if (methodReturnType == typeof(Task<AsyncCollectionResult>) && result is AsyncCollectionResult)
{
invocation.ReturnValue = _taskFromResultMethod?.MakeGenericMethod(typeof(AsyncCollectionResult)).Invoke(null, new[] { result });
return;
}

if (methodReturnType == typeof(ValueTask<AsyncCollectionResult>) && result is AsyncCollectionResult)
{
invocation.ReturnValue = new ValueTask<AsyncCollectionResult>((AsyncCollectionResult)result);
return;
}

throw new NotSupportedException();
}

Expand All @@ -156,7 +186,6 @@ private void SetAsyncException(IInvocation invocation, Type returnType, Exceptio
return;
}
}

throw new NotSupportedException();
}

Expand Down Expand Up @@ -307,4 +336,50 @@ public override async IAsyncEnumerable<ClientResult> GetRawPagesAsync()
return _enumerable.GetContinuationToken(page);
}
}

/// <summary>
/// Wraps a synchronous CollectionResult to provide an asynchronous
/// AsyncCollectionResult interface for testing scenarios where
/// sync methods need to be called from async method signatures.
/// </summary>
public class SyncPageableWrapper : AsyncCollectionResult
{
private readonly CollectionResult _enumerable;

/// <summary>
/// Initializes a new instance of <see cref="SyncPageableWrapper"/> for mocking scenarios.
/// </summary>
protected SyncPageableWrapper()
{
_enumerable = default!;
}

/// <summary>
/// Initializes a new instance of <see cref="SyncPageableWrapper"/> that wraps
/// the specified synchronous collection result.
/// </summary>
/// <param name="enumerable">The synchronous collection result to wrap.</param>
public SyncPageableWrapper(CollectionResult enumerable)
{
_enumerable = enumerable ?? throw new ArgumentNullException(nameof(enumerable));
}

/// <inheritdoc/>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public override async IAsyncEnumerable<ClientResult> GetRawPagesAsync()
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
foreach (ClientResult page in _enumerable.GetRawPages())
{
yield return page;
}
}

/// <inheritdoc/>
public override ContinuationToken? GetContinuationToken(ClientResult page)
{
// Delegate directly to the wrapped sync collection
return _enumerable.GetContinuationToken(page);
}
}
}
Loading