Skip to content

Commit

Permalink
=str Tweak the stream mapAsyncPartitioned operator
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Oct 9, 2023
1 parent 34815bc commit 1b1f572
Show file tree
Hide file tree
Showing 14 changed files with 421 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ private object MapAsyncPartitionedSpec {
value = i.toString)
}

def extractPartition(e: TestKeyValue): Int =
e.key
val partitioner: TestKeyValue => Int = kv => kv.key

type Operation = TestKeyValue => Future[(Int, String)]

Expand Down Expand Up @@ -125,7 +124,7 @@ class MapAsyncPartitionedSpec

val result =
Source(elements)
.mapAsyncPartitionedUnordered(parallelism = 2)(extractPartition)(blockingOperation)
.mapAsyncPartitionedUnordered(parallelism = 2)(partitioner)(blockingOperation)
.runWith(Sink.seq)
.futureValue
.map(_._2)
Expand All @@ -137,7 +136,7 @@ class MapAsyncPartitionedSpec
forAll(minSuccessful(1000)) { (parallelism: Parallelism, elements: Seq[TestKeyValue]) =>
val result =
Source(elements.toIndexedSeq)
.mapAsyncPartitionedUnordered(parallelism.value)(extractPartition)(asyncOperation)
.mapAsyncPartitionedUnordered(parallelism.value)(partitioner)(asyncOperation)
.runWith(Sink.seq)
.futureValue

Expand All @@ -153,7 +152,7 @@ class MapAsyncPartitionedSpec
val result =
Source
.fromIterator(() => elements.iterator)
.mapAsyncPartitionedUnordered(parallelism = 1)(extractPartition)(asyncOperation)
.mapAsyncPartitionedUnordered(parallelism = 1)(partitioner)(asyncOperation)
.runWith(Sink.seq)
.futureValue

Expand All @@ -169,7 +168,7 @@ class MapAsyncPartitionedSpec
val result =
Source
.fromIterator(() => elements.iterator)
.mapAsyncPartitionedUnordered(parallelism.value)(extractPartition)(blockingOperation)
.mapAsyncPartitionedUnordered(parallelism.value)(partitioner)(blockingOperation)
.runWith(Sink.seq)
.futureValue

Expand Down Expand Up @@ -232,7 +231,7 @@ class MapAsyncPartitionedSpec

val result =
Source(elements)
.mapAsyncPartitionedUnordered(parallelism = 2)(extractPartition)(fun)
.mapAsyncPartitionedUnordered(parallelism = 2)(partitioner)(fun)
.runWith(Sink.seq)
.futureValue

Expand All @@ -244,7 +243,7 @@ class MapAsyncPartitionedSpec
an[IllegalArgumentException] shouldBe thrownBy {
Source(infiniteStream())
.mapAsyncPartitionedUnordered(
parallelism = zeroOrNegativeParallelism)(extractPartition = identity)(f = (_, _) => Future.unit)
parallelism = zeroOrNegativeParallelism)(partitioner = identity)(f = (_, _) => Future.unit)
.runWith(Sink.ignore)
.futureValue
}
Expand Down Expand Up @@ -272,7 +271,7 @@ class MapAsyncPartitionedSpec

val result =
Source(elements)
.mapAsyncPartitioned(parallelism = 2)(extractPartition)(processElement)
.mapAsyncPartitioned(parallelism = 2)(partitioner)(processElement)
.runWith(Sink.seq)
.futureValue
.map(_._2)
Expand All @@ -289,7 +288,7 @@ class MapAsyncPartitionedSpec
forAll(minSuccessful(1000)) { (parallelism: Parallelism, elements: Seq[TestKeyValue]) =>
val result =
Source(elements.toIndexedSeq)
.mapAsyncPartitioned(parallelism.value)(extractPartition)(asyncOperation)
.mapAsyncPartitioned(parallelism.value)(partitioner)(asyncOperation)
.runWith(Sink.seq)
.futureValue

Expand All @@ -305,7 +304,7 @@ class MapAsyncPartitionedSpec
val result =
Source
.fromIterator(() => elements.iterator)
.mapAsyncPartitioned(parallelism = 1)(extractPartition)(asyncOperation)
.mapAsyncPartitioned(parallelism = 1)(partitioner)(asyncOperation)
.runWith(Sink.seq)
.futureValue

Expand All @@ -321,7 +320,7 @@ class MapAsyncPartitionedSpec
val result =
Source
.fromIterator(() => elements.iterator)
.mapAsyncPartitioned(parallelism.value)(extractPartition)(blockingOperation)
.mapAsyncPartitioned(parallelism.value)(partitioner)(blockingOperation)
.runWith(Sink.seq)
.futureValue

Expand Down Expand Up @@ -384,7 +383,7 @@ class MapAsyncPartitionedSpec

val result =
Source(elements)
.mapAsyncPartitioned(parallelism = 2)(extractPartition)(fun)
.mapAsyncPartitioned(parallelism = 2)(partitioner)(fun)
.runWith(Sink.seq)
.futureValue

Expand All @@ -396,7 +395,7 @@ class MapAsyncPartitionedSpec
an[IllegalArgumentException] shouldBe thrownBy {
Source(infiniteStream())
.mapAsyncPartitioned(
parallelism = zeroOrNegativeParallelism)(extractPartition = identity)(f = (_, _) => Future.unit)
parallelism = zeroOrNegativeParallelism)(partitioner = identity)(f = (_, _) => Future.unit)
.runWith(Sink.ignore)
.futureValue
}
Expand Down
153 changes: 42 additions & 111 deletions stream/src/main/scala/org/apache/pekko/stream/MapAsyncPartitioned.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,102 +23,29 @@ import scala.util.control.{ NoStackTrace, NonFatal }
import scala.util.{ Failure, Success, Try }

import org.apache.pekko
import pekko.annotation.InternalApi
import pekko.dispatch.ExecutionContexts
import pekko.stream.ActorAttributes.SupervisionStrategy
import pekko.stream.Attributes.{ Name, SourceLocation }
import pekko.stream.MapAsyncPartitioned._
import pekko.stream.scaladsl.{ Flow, FlowWithContext, Source, SourceWithContext }
import pekko.stream.stage._
import pekko.util.OptionVal

/**
* Internal API
*/
@InternalApi
private[stream] object MapAsyncPartitioned {
private val NotYetThere = Failure(new Exception with NoStackTrace)

private def extractPartitionWithCtx[In, Ctx, Partition](extract: In => Partition)(tuple: (In, Ctx)): Partition =
extract(tuple._1)

private def fWithCtx[In, Out, Ctx, Partition](f: (In, Partition) => Future[Out])(tuple: (In, Ctx),
partition: Partition): Future[(Out, Ctx)] =
f(tuple._1, partition).map(_ -> tuple._2)(ExecutionContexts.parasitic)

def mapSourceOrdered[In, Out, Partition, Mat](source: Source[In, Mat], parallelism: Int)(
extractPartition: In => Partition)(
f: (In, Partition) => Future[Out]): Source[Out, Mat] =
source.via(new MapAsyncPartitioned[In, Out, Partition](orderedOutput = true, parallelism, extractPartition, f))

def mapSourceUnordered[In, Out, Partition, Mat](source: Source[In, Mat], parallelism: Int)(
extractPartition: In => Partition)(
f: (In, Partition) => Future[Out]): Source[Out, Mat] =
source.via(new MapAsyncPartitioned[In, Out, Partition](orderedOutput = false, parallelism, extractPartition, f))

def mapSourceWithContextOrdered[In, Ctx, T, Partition, Mat](flow: SourceWithContext[In, Ctx, Mat], parallelism: Int)(
extractPartition: In => Partition)(
f: (In, Partition) => Future[T]): SourceWithContext[T, Ctx, Mat] =
flow.via(
new MapAsyncPartitioned[(In, Ctx), (T, Ctx), Partition](
orderedOutput = true,
parallelism,
extractPartitionWithCtx(extractPartition),
fWithCtx[In, T, Ctx, Partition](f)))

def mapSourceWithContextUnordered[In, Ctx, T, Partition, Mat](flow: SourceWithContext[In, Ctx, Mat],
parallelism: Int)(extractPartition: In => Partition)(
f: (In, Partition) => Future[T]): SourceWithContext[T, Ctx, Mat] =
flow.via(
new MapAsyncPartitioned[(In, Ctx), (T, Ctx), Partition](
orderedOutput = false,
parallelism,
extractPartitionWithCtx(extractPartition),
fWithCtx[In, T, Ctx, Partition](f)))

def mapFlowOrdered[In, Out, T, Partition, Mat](flow: Flow[In, Out, Mat], parallelism: Int)(
extractPartition: Out => Partition)(
f: (Out, Partition) => Future[T]): Flow[In, T, Mat] =
flow.via(new MapAsyncPartitioned[Out, T, Partition](orderedOutput = true, parallelism, extractPartition,
f))

def mapFlowUnordered[In, Out, T, Partition, Mat](flow: Flow[In, Out, Mat], parallelism: Int)(
extractPartition: Out => Partition)(
f: (Out, Partition) => Future[T]): Flow[In, T, Mat] =
flow.via(new MapAsyncPartitioned[Out, T, Partition](orderedOutput = false, parallelism,
extractPartition, f))

def mapFlowWithContextOrdered[In, Out, CtxIn, CtxOut, T, Partition, Mat](
flow: FlowWithContext[In, CtxIn, Out, CtxOut, Mat], parallelism: Int)(
extractPartition: Out => Partition)(
f: (Out, Partition) => Future[T]): FlowWithContext[In, CtxIn, T, CtxOut, Mat] =
flow.via(
new MapAsyncPartitioned[(Out, CtxOut), (T, CtxOut), Partition](
orderedOutput = true,
parallelism,
extractPartitionWithCtx(extractPartition),
fWithCtx[Out, T, CtxOut, Partition](f)))

def mapFlowWithContextUnordered[In, Out, CtxIn, CtxOut, T, Partition, Mat](
flow: FlowWithContext[In, CtxIn, Out, CtxOut, Mat], parallelism: Int)(extractPartition: Out => Partition)(
f: (Out, Partition) => Future[T]): FlowWithContext[In, CtxIn, T, CtxOut, Mat] =
flow.via(
new MapAsyncPartitioned[(Out, CtxOut), (T, CtxOut), Partition](
orderedOutput = false,
parallelism,
extractPartitionWithCtx(extractPartition),
fWithCtx[Out, T, CtxOut, Partition](f)))

private[stream] val NotYetThere: Failure[Nothing] = Failure(new Exception with NoStackTrace)

private[stream] final class Holder[In, Out](
val in: In,
var out: Try[Out],
callback: AsyncCallback[Holder[In, Out]]) extends (Try[Out] => Unit) {

// To support both fail-fast when the supervision directive is Stop
// and not calling the decider multiple times (#23888) we need to cache the decider result and re-use that
private var cachedSupervisionDirective: Option[Supervision.Directive] = None
private final class Holder[In, Out](val in: In, var out: Try[Out], val cb: AsyncCallback[Holder[In, Out]]) extends (
Try[Out] => Unit) {
private var cachedSupervisionDirective: OptionVal[Supervision.Directive] = OptionVal.None

def supervisionDirectiveFor(decider: Supervision.Decider, ex: Throwable): Supervision.Directive = {
cachedSupervisionDirective match {
case Some(d) => d
case OptionVal.Some(d) => d
case _ =>
val d = decider(ex)
cachedSupervisionDirective = Some(d)
cachedSupervisionDirective = OptionVal.Some(d)
d
}
}
Expand All @@ -128,27 +55,32 @@ private[stream] object MapAsyncPartitioned {

override def apply(t: Try[Out]): Unit = {
setOut(t)
callback.invoke(this)
cb.invoke(this)
}

override def toString = s"Holder($in, $out)"
}
}

private[stream] class MapAsyncPartitioned[In, Out, Partition](
orderedOutput: Boolean,
/**
* Internal API
*/
@InternalApi
private[stream] final class MapAsyncPartitioned[In, Out, Partition](
parallelism: Int,
extractPartition: In => Partition,
orderedOutput: Boolean,
partitioner: In => Partition,
f: (In, Partition) => Future[Out]) extends GraphStage[FlowShape[In, Out]] {
require(parallelism >= 1, "parallelism must be at least 1")
require(partitioner != null, "partitioner function should not be null")
require(f != null, "f function should not be null.")
import MapAsyncPartitioned._

if (parallelism < 1) throw new IllegalArgumentException("parallelism must be at least 1")

private val in = Inlet[In]("MapAsyncPartitionOrdered.in")
private val out = Outlet[Out]("MapAsyncPartitionOrdered.out")
private val in = Inlet[In]("MapAsyncPartitioned.in")
private val out = Outlet[Out]("MapAsyncPartitioned.out")

override val shape: FlowShape[In, Out] = FlowShape(in, out)

override def initialAttributes: Attributes =
Attributes(Name("MapAsyncPartitionOrdered")) and SourceLocation.forLambda(f)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private val contextPropagation = pekko.stream.impl.ContextPropagation()
Expand Down Expand Up @@ -191,13 +123,12 @@ private[stream] class MapAsyncPartitioned[In, Out, Partition](
buffer = mutable.Queue()
}

override def onPull(): Unit =
pushNextIfPossible()
override def onPull(): Unit = pushNextIfPossible()

override def onPush(): Unit = {
try {
val element = grab(in)
val partition = extractPartition(element)
val partition = partitioner(element)

val wrappedInput = new Contextual(
contextPropagation.currentContext(),
Expand All @@ -217,8 +148,7 @@ private[stream] class MapAsyncPartitioned[In, Out, Partition](
pullIfNeeded()
}

override def onUpstreamFinish(): Unit =
if (idle()) completeStage()
override def onUpstreamFinish(): Unit = if (idle()) completeStage()

private def processElement(partition: Partition, wrappedInput: Contextual[Holder[In, Out]]): Unit = {
import wrappedInput.{ element => holder }
Expand Down Expand Up @@ -289,7 +219,7 @@ private[stream] class MapAsyncPartitioned[In, Out, Partition](
buffer = buffer.filter { case (partition, wrappedInput) =>
import wrappedInput.{ element => holder }

if ((holder.out eq MapAsyncPartitioned.NotYetThere) || !isAvailable(out)) {
if ((holder.out eq NotYetThere) || !isAvailable(out)) {
true
} else {
partitionsInProgress -= partition
Expand Down Expand Up @@ -321,12 +251,14 @@ private[stream] class MapAsyncPartitioned[In, Out, Partition](
}

private def drainQueue(): Unit = {
buffer.foreach {
case (partition, wrappedInput) =>
if (canStartNextElement(partition)) {
wrappedInput.resume()
processElement(partition, wrappedInput)
}
if (buffer.nonEmpty) {
buffer.foreach {
case (partition, wrappedInput) =>
if (canStartNextElement(partition)) {
wrappedInput.resume()
processElement(partition, wrappedInput)
}
}
}
}

Expand All @@ -335,11 +267,10 @@ private[stream] class MapAsyncPartitioned[In, Out, Partition](
else if (buffer.size < parallelism && !hasBeenPulled(in)) tryPull(in)
// else already pulled and waiting for next element

private def idle(): Boolean =
buffer.isEmpty
private def idle(): Boolean = buffer.isEmpty

private def canStartNextElement(partition: Partition): Boolean =
!partitionsInProgress(partition) && partitionsInProgress.size < parallelism
!partitionsInProgress.contains(partition) && partitionsInProgress.size < parallelism

setHandlers(in, out, this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import pekko.stream.Attributes._
val mapError = name("mapError")
val mapAsync = name("mapAsync")
val mapAsyncUnordered = name("mapAsyncUnordered")
val mapAsyncPartition = name("mapAsyncPartition")
val mapAsyncPartitionUnordered = name("mapAsyncPartitionUnordered")
val ask = name("ask")
val grouped = name("grouped")
val groupedWithin = name("groupedWithin")
Expand Down
Loading

0 comments on commit 1b1f572

Please sign in to comment.