From 1bf943b4dad6051d129718411252e82e0626f61a Mon Sep 17 00:00:00 2001 From: Youssef Shoaib Date: Tue, 23 Apr 2024 19:07:50 +0100 Subject: [PATCH] Introduce ReusableRaise --- .../kotlin/arrow/core/raise/Fold.kt | 20 ++-- .../kotlin/arrow/core/raise/Raise.kt | 38 ++++++++ .../arrow/core/raise/RaiseAccumulate.kt | 14 +-- .../raise/CancellationExceptionNoTrace.kt | 2 +- .../raise/CancellationExceptionNoTrace.kt | 2 +- .../kotlin/arrow/fx/coroutines/ParMap.kt | 93 +++++++++++++------ 6 files changed, 122 insertions(+), 47 deletions(-) diff --git a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Fold.kt b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Fold.kt index a3eced1ab62..2133cbf0133 100644 --- a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Fold.kt +++ b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Fold.kt @@ -138,7 +138,7 @@ public inline fun fold( callsInPlace(recover, AT_MOST_ONCE) callsInPlace(transform, AT_MOST_ONCE) } - val raise = DefaultRaise(false) + val raise = DefaultRaise(false) return try { val res = block(raise) raise.complete() @@ -201,7 +201,7 @@ public inline fun Raise.traced( callsInPlace(block, AT_MOST_ONCE) callsInPlace(trace, AT_MOST_ONCE) } - val nested = DefaultRaise(true) + val nested = DefaultRaise(true) return try { block(nested).also { nested.complete() } } catch (e: Traced) { @@ -227,21 +227,20 @@ internal fun Traced.withCause(cause: Traced): Traced = @PublishedApi @DelicateRaiseApi @Suppress("UNCHECKED_CAST") -internal fun CancellationException.raisedOrRethrow(raise: DefaultRaise): R = +internal fun CancellationException.raisedOrRethrow(raise: DefaultRaise): R = when { this is RaiseCancellationException && this.raise === raise -> raised as R else -> throw this } /** Serves as both purposes of a scope-reference token, and a default implementation for Raise. */ -@PublishedApi -internal class DefaultRaise(@PublishedApi internal val isTraced: Boolean) : Raise { +public class DefaultRaise(@PublishedApi internal val isTraced: Boolean) : Raise { private val isActive = AtomicBoolean(true) @PublishedApi internal fun complete(): Boolean = isActive.getAndSet(false) @OptIn(DelicateRaiseApi::class) - override fun raise(r: Any?): Nothing = when { + override fun raise(r: Error): Nothing = when { isActive.value -> throw if (isTraced) Traced(r, this) else NoTrace(r, this) else -> throw RaiseLeakedException() } @@ -259,17 +258,18 @@ public annotation class DelicateRaiseApi @DelicateRaiseApi public sealed class RaiseCancellationException( internal val raised: Any?, - internal val raise: Raise + internal val raise: Raise<*> ) : CancellationException(RaiseCancellationExceptionCaptured) @DelicateRaiseApi @Suppress("EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING") -internal expect class NoTrace(raised: Any?, raise: Raise) : RaiseCancellationException +internal expect class NoTrace(raised: Any?, raise: Raise<*>) : RaiseCancellationException @DelicateRaiseApi -internal class Traced(raised: Any?, raise: Raise, override val cause: Traced? = null): RaiseCancellationException(raised, raise) +internal class Traced(raised: Any?, raise: Raise<*>, override val cause: Traced? = null): RaiseCancellationException(raised, raise) -private class RaiseLeakedException : IllegalStateException( +@PublishedApi +internal class RaiseLeakedException : IllegalStateException( """ 'raise' or 'bind' was leaked outside of its context scope. Make sure all calls to 'raise' and 'bind' occur within the lifecycle of nullable { }, either { } or similar builders. diff --git a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Raise.kt b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Raise.kt index 429d2ddd12d..8268b440a34 100644 --- a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Raise.kt +++ b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/Raise.kt @@ -16,6 +16,7 @@ import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind.AT_MOST_ONCE import kotlin.contracts.contract import kotlin.experimental.ExperimentalTypeInference +import kotlin.jvm.JvmInline import kotlin.jvm.JvmMultifileClass import kotlin.jvm.JvmName @@ -309,6 +310,43 @@ public interface Raise { map { it.bind() }.toNonEmptySet() } +@JvmInline public value class ReusableRaise(public val raise: DefaultRaise) + +@DelicateRaiseApi +@RaiseDSL +public inline fun reusableRaise( + @BuilderInference block: ReusableRaise.() -> A, +): A { + contract { + callsInPlace(block, AT_MOST_ONCE) + } + val raise = DefaultRaise(false) + try { + return block(ReusableRaise(raise)) + } catch(e: RaiseCancellationException) { + e.raisedOrRethrow(raise) + throw RaiseLeakedException() + } finally { + raise.complete() + } +} + +@DelicateRaiseApi +@RaiseDSL +public inline fun ReusableRaise.recoverReused( + @BuilderInference block: Raise.() -> A, + recover: (error: Error) -> A +): A { + contract { + callsInPlace(block, AT_MOST_ONCE) + } + return try { + block(raise) + } catch (e: RaiseCancellationException) { + recover(e.raisedOrRethrow(raise)) + } +} + /** * Execute the [Raise] context function resulting in [A] or any _logical error_ of type [Error], * and recover by providing a transform [Error] into a fallback value of type [A]. diff --git a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/RaiseAccumulate.kt b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/RaiseAccumulate.kt index e5157fe4a4a..829f7666bbd 100644 --- a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/RaiseAccumulate.kt +++ b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/raise/RaiseAccumulate.kt @@ -521,21 +521,22 @@ public inline fun Raise.forEachAccumulating( @BuilderInference block: RaiseAccumulate.(A) -> Unit ): Unit = forEachAccumulatingImpl(iterator, combine) { item, _ -> block(item) } +@OptIn(DelicateRaiseApi::class) @PublishedApi @JvmSynthetic internal inline fun Raise.forEachAccumulatingImpl( iterator: Iterator, combine: (Error, Error) -> Error, @BuilderInference block: RaiseAccumulate.(item: A, hasErrors: Boolean) -> Unit -) { +): Unit = reusableRaise, _> { var error: Any? = EmptyValue for (item in iterator) { - recover({ + recoverReused({ block(RaiseAccumulate(this), item, error != EmptyValue) }) { errors -> error = combine(error, errors.reduce(combine), combine) } } - return if (error === EmptyValue) Unit else raise(unbox(error)) + return if (error === EmptyValue) Unit else this@forEachAccumulatingImpl.raise(unbox(error)) } @RaiseDSL @@ -560,20 +561,21 @@ public inline fun Raise>.forEachAccumulating( * Allows to change what to do once the first error is raised. * Used to provide more performant [mapOrAccumulate]. */ +@OptIn(DelicateRaiseApi::class) @PublishedApi @JvmSynthetic internal inline fun Raise>.forEachAccumulatingImpl( iterator: Iterator, @BuilderInference block: RaiseAccumulate.(item: A, hasErrors: Boolean) -> Unit -) { +): Unit = reusableRaise { val error: MutableList = mutableListOf() for (item in iterator) { - recover({ + recoverReused({ block(RaiseAccumulate(this), item, error.isNotEmpty()) }) { error.addAll(it) } } - error.toNonEmptyListOrNull()?.let(::raise) + error.toNonEmptyListOrNull()?.let(this@forEachAccumulatingImpl::raise) } /** diff --git a/arrow-libs/core/arrow-core/src/jvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt b/arrow-libs/core/arrow-core/src/jvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt index 6b09c1ec66d..5142acb7797 100644 --- a/arrow-libs/core/arrow-core/src/jvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt +++ b/arrow-libs/core/arrow-core/src/jvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt @@ -10,7 +10,7 @@ import kotlin.coroutines.cancellation.CancellationException "EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING", "SEALED_INHERITOR_IN_DIFFERENT_MODULE" ) -internal actual class NoTrace actual constructor(raised: Any?, raise: Raise) : RaiseCancellationException(raised, raise) { +internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<*>) : RaiseCancellationException(raised, raise) { override fun fillInStackTrace(): Throwable { // Prevent Android <= 6.0 bug. stackTrace = emptyArray() diff --git a/arrow-libs/core/arrow-core/src/nonJvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt b/arrow-libs/core/arrow-core/src/nonJvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt index 83b99822ef5..4bf8c30c4c6 100644 --- a/arrow-libs/core/arrow-core/src/nonJvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt +++ b/arrow-libs/core/arrow-core/src/nonJvmMain/kotlin/arrow/core/raise/CancellationExceptionNoTrace.kt @@ -4,4 +4,4 @@ package arrow.core.raise "EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING", "SEALED_INHERITOR_IN_DIFFERENT_MODULE" ) -internal actual class NoTrace actual constructor(raised: Any?, raise: Raise) : RaiseCancellationException(raised, raise) +internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<*>) : RaiseCancellationException(raised, raise) diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/ParMap.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/ParMap.kt index b9511ad96c6..e7b8af66342 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/ParMap.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/ParMap.kt @@ -5,7 +5,10 @@ import arrow.core.Either import arrow.core.NonEmptyList import arrow.core.raise.Raise import arrow.core.raise.either -import arrow.core.flattenOrAccumulate +import arrow.core.raise.DelicateRaiseApi +import arrow.core.raise.mapOrAccumulate +import arrow.core.raise.recoverReused +import arrow.core.raise.reusableRaise import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -52,68 +55,100 @@ public class ScopedRaiseAccumulate( scope: CoroutineScope ) : CoroutineScope by scope, RaiseAccumulate(raise) +@OptIn(DelicateRaiseApi::class) public suspend fun Iterable.parMapOrAccumulate( context: CoroutineContext = EmptyCoroutineContext, concurrency: Int, combine: (Error, Error) -> Error, transform: suspend ScopedRaiseAccumulate.(A) -> B ): Either> = - coroutineScope { - val semaphore = Semaphore(concurrency) - map { - async(context) { - either { - semaphore.withPermit { - transform(ScopedRaiseAccumulate(this, this@coroutineScope), it) + either { + coroutineScope { + reusableRaise, _> { + val semaphore = Semaphore(concurrency) + val asyncs = map { + async(context) { + semaphore.withPermit { + transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it) + } + } + } + mapOrAccumulate(asyncs, combine) { + withNel { + recoverReused({ it.await() }, this::raise) } } } - }.awaitAll().flattenOrAccumulate(combine) + } } +@OptIn(DelicateRaiseApi::class) public suspend fun Iterable.parMapOrAccumulate( context: CoroutineContext = EmptyCoroutineContext, combine: (Error, Error) -> Error, transform: suspend ScopedRaiseAccumulate.(A) -> B ): Either> = - coroutineScope { - map { - async(context) { - either { - transform(ScopedRaiseAccumulate(this, this@coroutineScope), it) + either { + coroutineScope { + reusableRaise, _> { + val asyncs = map { + async(context) { + transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it) + } + } + mapOrAccumulate(asyncs, combine) { + withNel { + recoverReused({ it.await() }, this::raise) + } } } - }.awaitAll().flattenOrAccumulate(combine) + } } +@OptIn(DelicateRaiseApi::class) public suspend fun Iterable.parMapOrAccumulate( context: CoroutineContext = EmptyCoroutineContext, concurrency: Int, transform: suspend ScopedRaiseAccumulate.(A) -> B ): Either, List> = - coroutineScope { - val semaphore = Semaphore(concurrency) - map { - async(context) { - either { - semaphore.withPermit { - transform(ScopedRaiseAccumulate(this, this@coroutineScope), it) + either { + coroutineScope { + reusableRaise, _> { + val semaphore = Semaphore(concurrency) + val asyncs = map { + async(context) { + semaphore.withPermit { + transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it) + } + } + } + mapOrAccumulate(asyncs) { + withNel { + recoverReused({ it.await() }, this::raise) } } } - }.awaitAll().flattenOrAccumulate() + } } +@OptIn(DelicateRaiseApi::class) public suspend fun Iterable.parMapOrAccumulate( context: CoroutineContext = EmptyCoroutineContext, transform: suspend ScopedRaiseAccumulate.(A) -> B ): Either, List> = - coroutineScope { - map { - async(context) { - either { - transform(ScopedRaiseAccumulate(this, this@coroutineScope), it) + either { + coroutineScope { + reusableRaise, _> { + val asyncs = map { + async(context) { + transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it) + } + } + mapOrAccumulate(asyncs) { + withNel { + recoverReused({ it.await() }, this::raise) + } } } - }.awaitAll().flattenOrAccumulate() + } }