Skip to content

Commit

Permalink
Close Both Observations
Browse files Browse the repository at this point in the history
Depending on when a request is cancelled, the before and after observation
starts and stops may be called out of order due to the order in
which their doOnCancel handlers are invoked.

To address this, the before filter-wrapper now always closes both the
before observation and the after observation. Since the before filter-
wrapper wraps the entire request, this ensures that either that was
started is stopped, and either that has not been started yet cannot
inadvertently be started by any unexpected ordering of events that
follows.

Closes spring-projectsgh-14031
  • Loading branch information
jzheaux committed Oct 30, 2023
1 parent aa04ee1 commit 68581e5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,13 @@ public Observation error(Throwable ex) {

@Override
public void stop() {
this.currentObservation.get().stop();
this.before.stop();
this.after.stop();
}

private void close() {
this.before.close();
this.after.close();
}

@Override
Expand Down Expand Up @@ -357,11 +363,11 @@ public WebFilter wrap(WebFilter filter) {
start();
// @formatter:off
return filter.filter(exchange, chain)
.doOnSuccess((v) -> stop())
.doOnCancel(this::stop)
.doOnSuccess((v) -> close())
.doOnCancel(this::close)
.doOnError((t) -> {
error(t);
stop();
close();
})
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this));
// @formatter:on
Expand Down Expand Up @@ -433,6 +439,21 @@ private void stop() {
}
}

private void close() {
try {
this.lock.lock();
if (this.state.compareAndSet(1, 3)) {
this.observation.stop();
}
else {
this.state.set(3);
}
}
finally {
this.lock.unlock();
}
}

}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,98 @@ void decorateWhenNoopThenDoesNotObserve() {
verifyNoInteractions(handler);
}

@Test
void decorateWhenTerminatingFilterThenObserves() {
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
WebFilterChain chain = mock(WebFilterChain.class);
given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
WebFilterChain decorated = decorator.decorate(chain,
List.of(new BasicAuthenticationFilter(), new TerminatingFilter()));
Observation http = Observation.start("http", registry).contextualName("http");
try {
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
.block();
}
catch (Exception ex) {
http.error(ex);
}
finally {
http.stop();
}
handler.assertSpanStart(0, "http", null);
handler.assertSpanStart(1, "spring.security.filterchains", "http");
handler.assertSpanStop(2, "security filterchain before");
handler.assertSpanStart(3, "spring.security.filterchains", "http");
handler.assertSpanStop(4, "security filterchain after");
handler.assertSpanStop(5, "http");
}

@Test
void decorateWhenFilterErrorThenStopsObservation() {
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
WebFilterChain chain = mock(WebFilterChain.class);
WebFilterChain decorated = decorator.decorate(chain, List.of(new ErroringFilter()));
Observation http = Observation.start("http", registry).contextualName("http");
try {
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
.block();
}
catch (Exception ex) {
http.error(ex);
}
finally {
http.stop();
}
handler.assertSpanStart(0, "http", null);
handler.assertSpanStart(1, "spring.security.filterchains", "http");
handler.assertSpanError(2);
handler.assertSpanStop(3, "security filterchain before");
handler.assertSpanError(4);
handler.assertSpanStop(5, "http");
}

@Test
void decorateWhenErrorSignalThenStopsObservation() {
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
WebFilterChain chain = mock(WebFilterChain.class);
given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
WebFilterChain decorated = decorator.decorate(chain, List.of(new BasicAuthenticationFilter()));
Observation http = Observation.start("http", registry).contextualName("http");
try {
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
.block();
}
catch (Exception ex) {
http.error(ex);
}
finally {
http.stop();
}
handler.assertSpanStart(0, "http", null);
handler.assertSpanStart(1, "spring.security.filterchains", "http");
handler.assertSpanStop(2, "security filterchain before");
handler.assertSpanStart(3, "secured request", "security filterchain before");
handler.assertSpanError(4);
handler.assertSpanStop(5, "secured request");
handler.assertSpanStart(6, "spring.security.filterchains", "http");
handler.assertSpanError(7);
handler.assertSpanStop(8, "security filterchain after");
handler.assertSpanError(9);
handler.assertSpanStop(10, "http");
}

// gh-12849
@Test
void decorateWhenCustomAfterFilterThenObserves() {
Expand Down Expand Up @@ -171,6 +263,24 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {

}

static class ErroringFilter implements WebFilter {

@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.error(() -> new RuntimeException("ack"));
}

}

static class TerminatingFilter implements WebFilter {

@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.empty();
}

}

static class AccumulatingObservationHandler implements ObservationHandler<Observation.Context> {

List<Event> contexts = new ArrayList<>();
Expand Down Expand Up @@ -246,6 +356,11 @@ private void assertSpanStop(int index, String name) {
}
}

private void assertSpanError(int index) {
Event event = this.contexts.get(index);
assertThat(event.event).isEqualTo("error");
}

static class Event {

String event;
Expand Down

0 comments on commit 68581e5

Please sign in to comment.