Skip to content

Commit

Permalink
Introduce ReusableRaise
Browse files Browse the repository at this point in the history
  • Loading branch information
kyay10 committed Apr 23, 2024
1 parent 5ba0e45 commit 1bf943b
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public inline fun <Error, A, B> fold(
callsInPlace(recover, AT_MOST_ONCE)
callsInPlace(transform, AT_MOST_ONCE)
}
val raise = DefaultRaise(false)
val raise = DefaultRaise<Error>(false)
return try {
val res = block(raise)
raise.complete()
Expand Down Expand Up @@ -201,7 +201,7 @@ public inline fun <Error, A> Raise<Error>.traced(
callsInPlace(block, AT_MOST_ONCE)
callsInPlace(trace, AT_MOST_ONCE)
}
val nested = DefaultRaise(true)
val nested = DefaultRaise<Error>(true)
return try {
block(nested).also { nested.complete() }
} catch (e: Traced) {
Expand All @@ -227,21 +227,20 @@ internal fun Traced.withCause(cause: Traced): Traced =
@PublishedApi
@DelicateRaiseApi
@Suppress("UNCHECKED_CAST")
internal fun <R> CancellationException.raisedOrRethrow(raise: DefaultRaise): R =
internal fun <R> CancellationException.raisedOrRethrow(raise: DefaultRaise<R>): R =
when {
this is RaiseCancellationException && this.raise === raise -> raised as R
else -> throw this
}

/** Serves as both purposes of a scope-reference token, and a default implementation for Raise. */
@PublishedApi
internal class DefaultRaise(@PublishedApi internal val isTraced: Boolean) : Raise<Any?> {
public class DefaultRaise<Error>(@PublishedApi internal val isTraced: Boolean) : Raise<Error> {
private val isActive = AtomicBoolean(true)

@PublishedApi
internal fun complete(): Boolean = isActive.getAndSet(false)
@OptIn(DelicateRaiseApi::class)
override fun raise(r: Any?): Nothing = when {
override fun raise(r: Error): Nothing = when {
isActive.value -> throw if (isTraced) Traced(r, this) else NoTrace(r, this)
else -> throw RaiseLeakedException()
}
Expand All @@ -259,17 +258,18 @@ public annotation class DelicateRaiseApi
@DelicateRaiseApi
public sealed class RaiseCancellationException(
internal val raised: Any?,
internal val raise: Raise<Any?>
internal val raise: Raise<*>
) : CancellationException(RaiseCancellationExceptionCaptured)

@DelicateRaiseApi
@Suppress("EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING")
internal expect class NoTrace(raised: Any?, raise: Raise<Any?>) : RaiseCancellationException
internal expect class NoTrace(raised: Any?, raise: Raise<*>) : RaiseCancellationException

@DelicateRaiseApi
internal class Traced(raised: Any?, raise: Raise<Any?>, override val cause: Traced? = null): RaiseCancellationException(raised, raise)
internal class Traced(raised: Any?, raise: Raise<*>, override val cause: Traced? = null): RaiseCancellationException(raised, raise)

private class RaiseLeakedException : IllegalStateException(
@PublishedApi
internal class RaiseLeakedException : IllegalStateException(
"""
'raise' or 'bind' was leaked outside of its context scope.
Make sure all calls to 'raise' and 'bind' occur within the lifecycle of nullable { }, either { } or similar builders.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.AT_MOST_ONCE
import kotlin.contracts.contract
import kotlin.experimental.ExperimentalTypeInference
import kotlin.jvm.JvmInline
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

Expand Down Expand Up @@ -309,6 +310,43 @@ public interface Raise<in Error> {
map { it.bind() }.toNonEmptySet()
}

@JvmInline public value class ReusableRaise<Error>(public val raise: DefaultRaise<Error>)

@DelicateRaiseApi
@RaiseDSL
public inline fun <Error, A> reusableRaise(
@BuilderInference block: ReusableRaise<Error>.() -> A,
): A {
contract {
callsInPlace(block, AT_MOST_ONCE)
}
val raise = DefaultRaise<Error>(false)
try {
return block(ReusableRaise(raise))
} catch(e: RaiseCancellationException) {
e.raisedOrRethrow(raise)
throw RaiseLeakedException()
} finally {
raise.complete()
}
}

@DelicateRaiseApi
@RaiseDSL
public inline fun <Error, A> ReusableRaise<Error>.recoverReused(
@BuilderInference block: Raise<Error>.() -> A,
recover: (error: Error) -> A
): A {
contract {
callsInPlace(block, AT_MOST_ONCE)
}
return try {
block(raise)
} catch (e: RaiseCancellationException) {
recover(e.raisedOrRethrow(raise))
}
}

/**
* Execute the [Raise] context function resulting in [A] or any _logical error_ of type [Error],
* and recover by providing a transform [Error] into a fallback value of type [A].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,21 +521,22 @@ public inline fun <Error, A> Raise<Error>.forEachAccumulating(
@BuilderInference block: RaiseAccumulate<Error>.(A) -> Unit
): Unit = forEachAccumulatingImpl(iterator, combine) { item, _ -> block(item) }

@OptIn(DelicateRaiseApi::class)
@PublishedApi @JvmSynthetic
internal inline fun <Error, A> Raise<Error>.forEachAccumulatingImpl(
iterator: Iterator<A>,
combine: (Error, Error) -> Error,
@BuilderInference block: RaiseAccumulate<Error>.(item: A, hasErrors: Boolean) -> Unit
) {
): Unit = reusableRaise<NonEmptyList<Error>, _> {
var error: Any? = EmptyValue
for (item in iterator) {
recover({
recoverReused({
block(RaiseAccumulate(this), item, error != EmptyValue)
}) { errors ->
error = combine(error, errors.reduce(combine), combine)
}
}
return if (error === EmptyValue) Unit else raise(unbox<Error>(error))
return if (error === EmptyValue) Unit else this@forEachAccumulatingImpl.raise(unbox<Error>(error))
}

@RaiseDSL
Expand All @@ -560,20 +561,21 @@ public inline fun <Error, A> Raise<NonEmptyList<Error>>.forEachAccumulating(
* Allows to change what to do once the first error is raised.
* Used to provide more performant [mapOrAccumulate].
*/
@OptIn(DelicateRaiseApi::class)
@PublishedApi @JvmSynthetic
internal inline fun <Error, A> Raise<NonEmptyList<Error>>.forEachAccumulatingImpl(
iterator: Iterator<A>,
@BuilderInference block: RaiseAccumulate<Error>.(item: A, hasErrors: Boolean) -> Unit
) {
): Unit = reusableRaise {
val error: MutableList<Error> = mutableListOf()
for (item in iterator) {
recover({
recoverReused({
block(RaiseAccumulate(this), item, error.isNotEmpty())
}) {
error.addAll(it)
}
}
error.toNonEmptyListOrNull()?.let(::raise)
error.toNonEmptyListOrNull()?.let(this@forEachAccumulatingImpl::raise)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import kotlin.coroutines.cancellation.CancellationException
"EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING",
"SEALED_INHERITOR_IN_DIFFERENT_MODULE"
)
internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<Any?>) : RaiseCancellationException(raised, raise) {
internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<*>) : RaiseCancellationException(raised, raise) {
override fun fillInStackTrace(): Throwable {
// Prevent Android <= 6.0 bug.
stackTrace = emptyArray()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ package arrow.core.raise
"EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING",
"SEALED_INHERITOR_IN_DIFFERENT_MODULE"
)
internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<Any?>) : RaiseCancellationException(raised, raise)
internal actual class NoTrace actual constructor(raised: Any?, raise: Raise<*>) : RaiseCancellationException(raised, raise)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import arrow.core.Either
import arrow.core.NonEmptyList
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.flattenOrAccumulate
import arrow.core.raise.DelicateRaiseApi
import arrow.core.raise.mapOrAccumulate
import arrow.core.raise.recoverReused
import arrow.core.raise.reusableRaise
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
Expand Down Expand Up @@ -52,68 +55,100 @@ public class ScopedRaiseAccumulate<Error>(
scope: CoroutineScope
) : CoroutineScope by scope, RaiseAccumulate<Error>(raise)

@OptIn(DelicateRaiseApi::class)
public suspend fun <Error, A, B> Iterable<A>.parMapOrAccumulate(
context: CoroutineContext = EmptyCoroutineContext,
concurrency: Int,
combine: (Error, Error) -> Error,
transform: suspend ScopedRaiseAccumulate<Error>.(A) -> B
): Either<Error, List<B>> =
coroutineScope {
val semaphore = Semaphore(concurrency)
map {
async(context) {
either {
semaphore.withPermit {
transform(ScopedRaiseAccumulate(this, this@coroutineScope), it)
either {
coroutineScope {
reusableRaise<NonEmptyList<Error>, _> {
val semaphore = Semaphore(concurrency)
val asyncs = map {
async(context) {
semaphore.withPermit {
transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it)
}
}
}
mapOrAccumulate(asyncs, combine) {
withNel {
recoverReused({ it.await() }, this::raise)
}
}
}
}.awaitAll().flattenOrAccumulate(combine)
}
}

@OptIn(DelicateRaiseApi::class)
public suspend fun <Error, A, B> Iterable<A>.parMapOrAccumulate(
context: CoroutineContext = EmptyCoroutineContext,
combine: (Error, Error) -> Error,
transform: suspend ScopedRaiseAccumulate<Error>.(A) -> B
): Either<Error, List<B>> =
coroutineScope {
map {
async(context) {
either {
transform(ScopedRaiseAccumulate(this, this@coroutineScope), it)
either {
coroutineScope {
reusableRaise<NonEmptyList<Error>, _> {
val asyncs = map {
async(context) {
transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it)
}
}
mapOrAccumulate(asyncs, combine) {
withNel {
recoverReused({ it.await() }, this::raise)
}
}
}
}.awaitAll().flattenOrAccumulate(combine)
}
}

@OptIn(DelicateRaiseApi::class)
public suspend fun <Error, A, B> Iterable<A>.parMapOrAccumulate(
context: CoroutineContext = EmptyCoroutineContext,
concurrency: Int,
transform: suspend ScopedRaiseAccumulate<Error>.(A) -> B
): Either<NonEmptyList<Error>, List<B>> =
coroutineScope {
val semaphore = Semaphore(concurrency)
map {
async(context) {
either {
semaphore.withPermit {
transform(ScopedRaiseAccumulate(this, this@coroutineScope), it)
either {
coroutineScope {
reusableRaise<NonEmptyList<Error>, _> {
val semaphore = Semaphore(concurrency)
val asyncs = map {
async(context) {
semaphore.withPermit {
transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it)
}
}
}
mapOrAccumulate(asyncs) {
withNel {
recoverReused({ it.await() }, this::raise)
}
}
}
}.awaitAll().flattenOrAccumulate()
}
}

@OptIn(DelicateRaiseApi::class)
public suspend fun <Error, A, B> Iterable<A>.parMapOrAccumulate(
context: CoroutineContext = EmptyCoroutineContext,
transform: suspend ScopedRaiseAccumulate<Error>.(A) -> B
): Either<NonEmptyList<Error>, List<B>> =
coroutineScope {
map {
async(context) {
either {
transform(ScopedRaiseAccumulate(this, this@coroutineScope), it)
either {
coroutineScope {
reusableRaise<NonEmptyList<Error>, _> {
val asyncs = map {
async(context) {
transform(ScopedRaiseAccumulate(raise, this@coroutineScope), it)
}
}
mapOrAccumulate(asyncs) {
withNel {
recoverReused({ it.await() }, this::raise)
}
}
}
}.awaitAll().flattenOrAccumulate()
}
}

0 comments on commit 1bf943b

Please sign in to comment.