diff --git a/spring-web/src/main/kotlin/org/springframework/web/server/CoWebExceptionHandler.kt b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebExceptionHandler.kt new file mode 100644 index 000000000000..ac70d4c66630 --- /dev/null +++ b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebExceptionHandler.kt @@ -0,0 +1,15 @@ +package org.springframework.web.server + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.reactor.mono +import reactor.core.publisher.Mono +import kotlin.coroutines.CoroutineContext + +abstract class CoWebExceptionHandler : WebExceptionHandler { + final override fun handle(exchange: ServerWebExchange, ex: Throwable): Mono { + val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext? + return mono(context ?: Dispatchers.Unconfined) { coHandle(exchange, ex) }.then() + } + + protected abstract suspend fun coHandle(exchange: ServerWebExchange, ex: Throwable) +} diff --git a/spring-web/src/test/kotlin/org/springframework/web/server/CoWebExceptionHandlerTests.kt b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebExceptionHandlerTests.kt new file mode 100644 index 000000000000..f5d1f30534c8 --- /dev/null +++ b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebExceptionHandlerTests.kt @@ -0,0 +1,28 @@ +package org.springframework.web.server + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest +import org.springframework.web.testfixture.server.MockServerWebExchange +import reactor.test.StepVerifier + +class CoWebExceptionHandlerTest { + @Test + fun handle() { + val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) + val ex = RuntimeException() + + val handler = MyCoWebExceptionHandler() + val result = handler.handle(exchange, ex) + + StepVerifier.create(result).verifyComplete() + + assertThat(exchange.attributes["foo"]).isEqualTo("bar") + } +} + +private class MyCoWebExceptionHandler : CoWebExceptionHandler() { + override suspend fun coHandle(exchange: ServerWebExchange, ex: Throwable) { + exchange.attributes["foo"] = "bar" + } +}