Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import uk.gov.hmcts.ccd.domain.model.lau.CaseActionPostResponse;
import uk.gov.hmcts.ccd.domain.model.lau.CaseSearchPostResponse;
import uk.gov.hmcts.ccd.feign.FeignClientConfig;
Expand All @@ -15,17 +14,14 @@
configuration = FeignClientConfig.class)
public interface LogAndAuditFeignClient {


@PostMapping("/audit/caseAction")
ResponseEntity<CaseActionPostResponse> postCaseAction(
@RequestHeader("ServiceAuthorization") String serviceAuthorization,
@RequestBody CaseActionPostRequest caseActionPostRequest
@RequestBody CaseActionPostRequest caseActionPostRequest
);

@PostMapping("/audit/caseSearch")
ResponseEntity<CaseSearchPostResponse> postCaseSearch(
@RequestHeader("ServiceAuthorization") String serviceAuthorization,
@RequestBody CaseSearchPostRequest caseSearchPostRequest
@RequestBody CaseSearchPostRequest caseSearchPostRequest
);

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Lazy;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import uk.gov.hmcts.ccd.AuditCaseRemoteConfiguration;
import uk.gov.hmcts.ccd.auditlog.AuditEntry;
import uk.gov.hmcts.ccd.auditlog.AuditOperationType;
import uk.gov.hmcts.ccd.auditlog.LogAndAuditFeignClient;
import uk.gov.hmcts.ccd.data.SecurityUtils;
import uk.gov.hmcts.ccd.domain.model.lau.ActionLog;
import uk.gov.hmcts.ccd.domain.model.lau.CaseActionPostRequest;
import uk.gov.hmcts.ccd.domain.model.lau.CaseActionPostResponse;
Expand Down Expand Up @@ -40,15 +38,11 @@ public class AuditCaseRemoteOperation implements AuditRemoteOperation {

private final LogAndAuditFeignClient logAndAuditFeignClient;

private final SecurityUtils securityUtils;

private final AuditCaseRemoteConfiguration auditCaseRemoteConfiguration;

@Autowired
public AuditCaseRemoteOperation(@Lazy final SecurityUtils securityUtils,
LogAndAuditFeignClient logAndAuditFeignClient,
public AuditCaseRemoteOperation(LogAndAuditFeignClient logAndAuditFeignClient,
final AuditCaseRemoteConfiguration auditCaseRemoteConfiguration) {
this.securityUtils = securityUtils;
this.logAndAuditFeignClient = logAndAuditFeignClient;
this.auditCaseRemoteConfiguration = auditCaseRemoteConfiguration;
}
Expand Down Expand Up @@ -120,12 +114,12 @@ public void postAsyncAuditRequestAndHandleResponse(
if (LAU_CASE_ACTION_CREATE.equals(activity) || LAU_CASE_ACTION_UPDATE.equals(activity)
|| LAU_CASE_ACTION_VIEW.equals(activity)) {
CompletableFuture<ResponseEntity<CaseActionPostResponse>> future = CompletableFuture.supplyAsync(() ->
logAndAuditFeignClient.postCaseAction(securityUtils.getServiceAuthorization(), capr));
logAndAuditFeignClient.postCaseAction(capr));
future.whenComplete((response, error) ->
handleAuditResponse(response, error, entry.getRequestId(), activity, url, auditLogId));
} else if ("SEARCH".equals(activity)) {
CompletableFuture<ResponseEntity<CaseSearchPostResponse>> future = CompletableFuture.supplyAsync(() ->
logAndAuditFeignClient.postCaseSearch(securityUtils.getServiceAuthorization(), cspr));
logAndAuditFeignClient.postCaseSearch(cspr));
future.whenComplete((response, error) ->
handleAuditResponse(response, error, entry.getRequestId(), activity, url, auditLogId));
}
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/uk/gov/hmcts/ccd/feign/FeignClientConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package uk.gov.hmcts.ccd.feign;

import feign.RequestInterceptor;
import feign.Retryer;
import feign.codec.ErrorDecoder;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -32,4 +33,17 @@ public Retryer feignRetryer() {
public ErrorDecoder errorDecoder(@Lazy SecurityUtils securityUtils) {
return new FeignErrorDecoder(securityUtils);
}

/**
* Injects the ServiceAuthorization header dynamically.
* This is executed again for every retry, ensuring fresh tokens are used.
*/
@Bean
public RequestInterceptor serviceAuthRequestInterceptor(SecurityUtils securityUtils) {
return template -> {
template.removeHeader("ServiceAuthorization");
String token = securityUtils.getServiceAuthorization();
template.header("ServiceAuthorization", token);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will run on every single request going through feign and we need to check if performance is not going to be impacted. I believe this will run on every request, not as comment suggests - on retries only.

};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import org.mockito.MockitoAnnotations;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.bean.override.mockito.MockitoSpyBean;

import uk.gov.hmcts.ccd.AuditCaseRemoteConfiguration;
Expand All @@ -31,6 +34,8 @@
import uk.gov.hmcts.ccd.domain.model.lau.SearchLog;

import jakarta.inject.Inject;
import uk.gov.hmcts.reform.authorisation.generators.AuthTokenGenerator;

import java.io.IOException;
import java.time.Clock;
import java.time.Instant;
Expand All @@ -50,8 +55,10 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.Mockito.doReturn;

@Import(AuditCaseRemoteOperationIT.MockConfig.class)
public class AuditCaseRemoteOperationIT extends WireMockBaseTest {

private static int ASYNC_DELAY_TIMEOUT_MILLISECONDS = 2000;
Expand All @@ -70,8 +77,12 @@ public class AuditCaseRemoteOperationIT extends WireMockBaseTest {
private static final String CASE_ID = "1504259907353529";
private static final String IDAM_ID = "1234";


@Autowired
SecurityUtils securityUtils;

@Autowired
private SecurityUtils securityUtils;
private AuthTokenGenerator authTokenGenerator;

@Autowired
private AuditCaseRemoteConfiguration auditCaseRemoteConfiguration;
Expand Down Expand Up @@ -119,6 +130,15 @@ public class AuditCaseRemoteOperationIT extends WireMockBaseTest {
private static final ZonedDateTime LOG_TIMESTAMP =
ZonedDateTime.of(LocalDateTime.now(fixedClock), ZoneOffset.UTC);

@TestConfiguration
static class MockConfig {

@Bean
public AuthTokenGenerator authTokenGenerator() {
return Mockito.mock(AuthTokenGenerator.class);
}
}

@BeforeEach
public void setUp() throws IOException {
MockitoAnnotations.openMocks(this);
Expand Down Expand Up @@ -250,7 +270,6 @@ public void shouldNotThrowExceptionInAuditServiceIfLauIsDown()
@Test
public void shouldNotThrowExceptionInAuditServiceIfLauSearchIsDownAndRetry()
throws JsonProcessingException, InterruptedException {

final SearchLog searchLog = new SearchLog();
searchLog.setUserId(SEARCH_LOG_USER_ID);
searchLog.setCaseRefs(SEARCH_LOG_CASE_REFS);
Expand Down Expand Up @@ -394,4 +413,45 @@ private long countServeEvents(String pathPrefix) {
.count();
}

@Test
void shouldUseNewTokenOnRetryWithInterceptor() throws Exception {
Mockito.when(authTokenGenerator.generate())
.thenReturn("Bearer originalToken")
.thenReturn("Bearer refreshedToken");

stubFor(WireMock.post(urlEqualTo(ACTION_AUDIT_ENDPOINT))
.willReturn(aResponse().withStatus(AUDIT_UNAUTHORISED_HTTP_STATUS)));

AuditContext auditContext = AuditContext.auditContextWith()
.caseId(CASE_ID)
.auditOperationType(AuditOperationType.CASE_ACCESSED)
.jurisdiction(JURISDICTION)
.caseType(CASE_TYPE)
.httpStatus(200)
.build();

// Act: make call (will retry 3 times)
auditService.audit(auditContext);
waitForPossibleAuditResponse(ACTION_AUDIT_ENDPOINT, 3);

// Assert: all 3 requests
var requests = getAllServeEvents().stream()
.filter(e -> e.getRequest().getUrl().equals(ACTION_AUDIT_ENDPOINT))
.toList();

assertThat(requests.size(), is(3));

var originalRequest = requests.stream()
.filter(r -> r.getRequest().getHeader("ServiceAuthorization").equals("Bearer originalToken"))
.toList();

var retryRequests = requests.stream()
.filter(r -> r.getRequest().getHeader("ServiceAuthorization").equals("Bearer refreshedToken"))
.toList();

assertThat(retryRequests.size(), is(2));
assertThat(originalRequest.size(), is(1));
assertThat(originalRequest.getFirst(), is(notNullValue()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -94,7 +93,7 @@ void setUp() throws JsonProcessingException {
doReturn("Bearer 1234").when(securityUtils).getServiceAuthorization();
doReturn("http://localhost/caseAction").when(auditCaseRemoteConfiguration).getCaseActionAuditUrl();
doReturn("http://localhost/caseSearch").when(auditCaseRemoteConfiguration).getCaseSearchAuditUrl();
auditCaseRemoteOperation = new AuditCaseRemoteOperation(securityUtils, feignClient,
auditCaseRemoteOperation = new AuditCaseRemoteOperation(feignClient,
auditCaseRemoteConfiguration);
}

Expand All @@ -113,10 +112,10 @@ void shouldPostCaseActionRemoteAuditRequest() {
// Verify asyncRequestService interaction
ArgumentCaptor<CaseActionPostRequest> requestCaptor = ArgumentCaptor.forClass(CaseActionPostRequest.class);
await().atMost(200, MILLISECONDS).untilAsserted(() ->
verify(feignClient).postCaseAction(any(String.class), requestCaptor.capture())
verify(feignClient).postCaseAction(requestCaptor.capture())
);
// Verify headers and endpoint
verify(feignClient).postCaseAction(eq("Bearer 1234"), any(CaseActionPostRequest.class));
verify(feignClient).postCaseAction(any(CaseActionPostRequest.class));
assertThat(auditCaseRemoteConfiguration.getCaseActionAuditUrl(), is(equalTo("http://localhost/caseAction")));

// Assert the captured request
Expand Down Expand Up @@ -147,9 +146,9 @@ void shouldPostCaseSearchRemoteAuditRequest() {
// Verify FeignClient interaction
ArgumentCaptor<CaseSearchPostRequest> requestCaptor = ArgumentCaptor.forClass(CaseSearchPostRequest.class);
await().atMost(200, MILLISECONDS).untilAsserted(() ->
verify(feignClient).postCaseSearch(any(String.class), requestCaptor.capture())
verify(feignClient).postCaseSearch(requestCaptor.capture())
);
verify(feignClient).postCaseSearch(eq("Bearer 1234"), any(CaseSearchPostRequest.class));
verify(feignClient).postCaseSearch(any(CaseSearchPostRequest.class));
assertThat(auditCaseRemoteConfiguration.getCaseSearchAuditUrl(), is(equalTo("http://localhost/caseSearch")));

// Assert the captured request
Expand Down Expand Up @@ -186,13 +185,13 @@ void shouldHandleExceptionDuringPostCaseAction() {

// Simulate exception in FeignClient
doThrow(new RuntimeException("FeignClient error")).when(feignClient)
.postCaseAction(any(String.class), any(CaseActionPostRequest.class));
.postCaseAction(any(CaseActionPostRequest.class));

auditCaseRemoteOperation.postCaseAction(entry, fixedDateTime);

// Verify exception is logged and no further interaction occurs
await().atMost(200, MILLISECONDS).untilAsserted(() ->
verify(feignClient).postCaseAction(any(String.class), any(CaseActionPostRequest.class))
verify(feignClient).postCaseAction(any(CaseActionPostRequest.class))
);
}

Expand All @@ -208,13 +207,13 @@ void shouldHandleExceptionDuringPostSearchAction() {

// Simulate exception in FeignClient
doThrow(new RuntimeException("FeignClient error")).when(feignClient)
.postCaseSearch(any(String.class), any(CaseSearchPostRequest.class));
.postCaseSearch(any(CaseSearchPostRequest.class));

auditCaseRemoteOperation.postCaseSearch(entry, fixedDateTime);

// Verify exception is logged and no further interaction occurs
await().atMost(200, MILLISECONDS).untilAsserted(() ->
verify(feignClient).postCaseSearch(any(String.class), any(CaseSearchPostRequest.class))
verify(feignClient).postCaseSearch(any(CaseSearchPostRequest.class))
);
}

Expand Down