Skip to content

Commit

Permalink
Merge pull request #698 from jbouffard/improvement/read-multiplex
Browse files Browse the repository at this point in the history
Read Multiplex Band Support
  • Loading branch information
jbouffard authored Feb 21, 2019
2 parents 36ff624 + cac5647 commit f9ed926
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 37 deletions.
55 changes: 55 additions & 0 deletions docs/guides/core-concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,58 @@ intersects it. The higher the ``zindex``, the more priority it has.
# Will always be selected
feature3 = gps.Feature(geometry=geom3, properties=cell_value3)
SourceInfo
-----------

:class:`~geopyspark.geotrellis.SourceInfo` represents a data source and the
information on how that data should be read in.

Reading from Singleband Data Sources
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For example, suppose that one wants to calculate NDVI for an area, and the bands that
represent the red and near infrared (NIR) values are in two seperate files: ``red_band.tiff``
and ``nir_band.tiff``, respectively. To read in these files as a single ``Tile`` (ie. a
``Tile`` with two bands), we can specify our ``SourceInfo``\s as:

.. code:: python3
source_1 = gps.SourceInfo("/tmp/red_band.tiff", {0: 0})
source_2 = gps.SourceInfo("/tmp/nir_band.tiff" {0: 1})
``source_1`` states that ``Tile``\s created from ``red_band.tiff`` will use the
data from band ``0`` of the source for its band ``0``. Whereas ``Tile``\s created
from ``source_2`` will have the band ``0`` of the source be band ``1`` of the
``Tile``. Thus, when ``red_band.tiff`` and ``nir_band.tiff`` intersect the
same area, the resulting ``Tile``\(s) will have two bands: ``0`` from ``red_band.tiff``
and ``1`` from ``nir_band.tiff``.

Reading from Multiband Data Sources
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It is also possible to read in individual bands from a multiband data source.
Continuing the example above, suppose one wants to calculate NDVI using
Landsat 8 data where each file contains all eleven bands. In this case,
only the red band (band ``3``) and the NIR band (band ``4``) are of interest.
We can read in just those bands by doing:

.. code:: python3
source = gps.SourceInfo("/tmp/all-landsat-bands.tiff", {3: 0, 4: 1})
The above source will have just bands ``3`` and ``4`` read in, and the resulting
``Tile``\s will just have two bands: the first from ``3`` and the second from
``4``, respectively.

A Note on Missing Data
~~~~~~~~~~~~~~~~~~~~~~~

In the event that data of a specified band does not exist in a region, the
resulting ``Tile``\(s) of that area will have that band be composed of ``NoData``
values.

So if the ``nir_band.tiff`` covers a smaller area than the ``red_band.tiff``,
the ``Tile``\(s) of those uncovered regions will have their band ``1`` be
just ``NoData`` values.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package geopyspark.geotrellis.vlm

import geopyspark.geotrellis.{PartitionStrategy, ProjectedRasterLayer, SpatialTiledRasterLayer}
import geopyspark.geotrellis.{LayoutType => GPSLayoutType, LocalLayout => GPSLocalLayout, GlobalLayout => GPSGlobalLayout, SpatialTiledRasterLayer}
import geopyspark.geotrellis.{PartitionStrategy, ProjectedRasterLayer, SpatialTiledRasterLayer, SpatialPartitioner}
import geopyspark.geotrellis.{LayoutType => GPSLayoutType, LocalLayout => GPSLocalLayout, GlobalLayout => GPSGlobalLayout}

import geopyspark.geotrellis.Constants.{GEOTRELLIS, GDAL}

Expand All @@ -17,7 +17,7 @@ import geotrellis.proj4._
import geotrellis.vector._
import geotrellis.util._

import org.apache.spark.SparkContext
import org.apache.spark.{SparkContext, Partitioner}
import org.apache.spark.rdd.RDD

import scala.collection.JavaConverters._
Expand All @@ -29,17 +29,27 @@ object RasterSource {
layerType: String,
paths: java.util.ArrayList[String],
targetCRS: String,
numPartitions: Integer,
resampleMethod: ResampleMethod,
readMethod: String
): ProjectedRasterLayer =
): ProjectedRasterLayer = {
val scalaPaths: Seq[String] = paths.asScala.toSeq

val partitions =
numPartitions match {
case i: Integer => Some(i.toInt)
case null => None
}

read(
sc,
layerType,
sc.parallelize(paths.asScala),
sc.parallelize(scalaPaths, partitions.getOrElse(scalaPaths.size)),
targetCRS,
resampleMethod,
readMethod
)
}

def read(
sc: SparkContext,
Expand All @@ -51,7 +61,7 @@ object RasterSource {
): ProjectedRasterLayer = {
val rasterSourceRDD: RDD[RasterSource] =
(readMethod match {
case GEOTRELLIS => rdd.map { GeoTiffRasterSource(_): RasterSource }
case GEOTRELLIS => rdd.map { new GeoTiffRasterSource(_): RasterSource }
case GDAL => rdd.map { GDALRasterSource(_): RasterSource }
}).cache()

Expand Down Expand Up @@ -82,18 +92,28 @@ object RasterSource {
paths: java.util.ArrayList[String],
layoutType: GPSLayoutType,
targetCRS: String,
numPartitions: Integer,
resampleMethod: ResampleMethod,
readMethod: String
): SpatialTiledRasterLayer =
): SpatialTiledRasterLayer = {
val scalaPaths: Seq[String] = paths.asScala.toSeq

val partitions =
numPartitions match {
case i: Integer => Some(i.toInt)
case null => None
}

readToLayout(
sc,
layerType,
sc.parallelize(paths.asScala),
sc.parallelize(scalaPaths, partitions.getOrElse(scalaPaths.size)),
layoutType,
targetCRS,
resampleMethod,
readMethod
)
}

def readToLayout(
sc: SparkContext,
Expand Down Expand Up @@ -147,7 +167,91 @@ object RasterSource {
val contextRDD: MultibandTileLayerRDD[SpatialKey] =
ContextRDD(tiledRDD, tileLayerMetadata)

SpatialTiledRasterLayer(zoom.toInt, contextRDD)
SpatialTiledRasterLayer(zoom, contextRDD)
}

def readOrderedToLayout(
sc: SparkContext,
paths: java.util.ArrayList[SourceInfo],
layoutType: GPSLayoutType,
targetCRS: String,
numPartitions: Integer,
resampleMethod: ResampleMethod,
partitionStrategy: PartitionStrategy,
readMethod: String
): SpatialTiledRasterLayer = {
val scalaSources: Seq[SourceInfo] = paths.asScala.toSeq

val partitions =
numPartitions match {
case i: Integer => Some(i.toInt)
case null => None
}

val partitioner: Option[Partitioner] =
partitionStrategy match {
case ps: PartitionStrategy => ps.producePartitioner(partitions.getOrElse(scalaSources.size))
case null => None
}

val crs: Option[CRS] =
targetCRS match {
case crs: String => Some(CRS.fromString(crs))
case null => None
}

val transformSource: String => RasterSource =
readMethod match {
case GEOTRELLIS =>
crs match {
case Some(projection) =>
(path: String) => GeoTiffRasterSource(path).reproject(projection, resampleMethod)
case None =>
(path: String) => GeoTiffRasterSource(path)
}
case GDAL =>
crs match {
case Some(projection) =>
(path: String) => GDALRasterSource(path).reproject(projection, resampleMethod)
case None =>
(path: String) => GDALRasterSource(path)
}
}

val sourceInfoRDD: RDD[SourceInfo] =
sc.parallelize(scalaSources, partitions.getOrElse(scalaSources.size))

val readingSourcesRDD: RDD[ReadingSource] =
sourceInfoRDD.map { source =>
val rasterSource = transformSource(source.source)

ReadingSource(rasterSource, source.sourceToTargetBand)
}

val sourcesRDD: RDD[RasterSource] = readingSourcesRDD.map { _.source }

val rasterSummary: RasterSummary = geotrellis.contrib.vlm.RasterSummary.fromRDD(sourcesRDD)

val LayoutLevel(zoom, layout) =
layoutType match {
case global: GPSGlobalLayout =>
val scheme = ZoomedLayoutScheme(rasterSummary.crs, global.tileSize)
scheme.levelForZoom(global.zoom)
case local: GPSLocalLayout =>
val scheme = FloatingLayoutScheme(local.tileCols, local.tileRows)
rasterSummary.levelFor(scheme)
}

val resampledSourcesRDD: RDD[ReadingSource] =
readingSourcesRDD.map { source =>
val resampledSource: RasterSource = source.source.resampleToGrid(layout, resampleMethod)

source.copy(source = resampledSource)
}

val result = RasterSourceRDD.read(resampledSourcesRDD, layout, partitioner)(sc)

SpatialTiledRasterLayer(zoom, result)
}

implicit def gps2VLM(layoutType: GPSLayoutType): LayoutType =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package geopyspark.geotrellis.vlm

import scala.collection.JavaConverters._


case class SourceInfo(
source: String,
sourceToTargetBand: Map[Int, Int]
) extends Serializable


object SourceInfo {
def apply(source: String, sourceBand: Int, targetBand: Int): SourceInfo =
SourceInfo(source, Map(sourceBand -> targetBand))

def apply(source: String, targetBand: Int): SourceInfo =
SourceInfo(source, 0, targetBand)

def apply(source: String, javaMap: java.util.HashMap[Int, Int]): SourceInfo =
SourceInfo(source, javaMap.asScala.toMap)
}
3 changes: 2 additions & 1 deletion geopyspark-backend/project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ resolvers ++= Seq(
)

addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.5")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")

addSbtPlugin("org.foundweekends" % "sbt-bintray" % "0.5.2")
addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.1.0-M10")
3 changes: 3 additions & 0 deletions geopyspark-backend/project/protoc.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.15")

libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.7.0"
29 changes: 28 additions & 1 deletion geopyspark/geotrellis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,33 @@ def __new__(cls, time_unit, num_partitions=None, bits=8, time_resolution=None):
return super(cls, SpaceTimePartitionStrategy).__new__(cls, time_unit, num_partitions, bits, time_resolution)


class SourceInfo(namedtuple("SourceInfo", "source source_to_target_band")):
"""Represents a data source and how its bands should be formatted when being read
in.
When two or more sources of data cover the same area, a single ``Tile`` will be created
that contains the bands specified by ``source_to_target_band``.
Args:
source (str): The path to the data source to be read.
source_to_target_band ({int: int}): A ``{int: int}`` that maps each band from the
source to the target band of the output.
For example, ``{0: 2}`` specifies that band ``0`` of the source be band ``2``
for the ``Tile``\s that were created from that source.
Attributes:
source (str): The path to the data source to be read.
source_to_target_band ({int: int}): A ``{int: int}`` that maps each band from the
source to the target band of the output.
For example, ``{0: 2}`` specifies that band ``0`` of the source be band ``2``
for the ``Tile``\s that were created from that source.
"""

__slots__ = []


class Feature(namedtuple("Feature", "geometry properties")):
"""Represents a geometry that is derived from an OSM Element with that Element's associated metadata.
Expand Down Expand Up @@ -834,7 +861,7 @@ def __str__(self):
__all__ = ["Tile", "Extent", "ProjectedExtent", "TemporalProjectedExtent", "SpatialKey", "SpaceTimeKey",
"Metadata", "TileLayout", "GlobalLayout", "LocalLayout", "LayoutDefinition", "Bounds", "RasterizerOptions",
"zfactor_lat_lng_calculator", "zfactor_calculator", "HashPartitionStrategy", "SpatialPartitionStrategy",
"SpaceTimePartitionStrategy", "Feature", "CellValue"]
"SpaceTimePartitionStrategy", "Feature", "CellValue", "SourceInfo"]

from . import catalog
from . import color
Expand Down
14 changes: 13 additions & 1 deletion geopyspark/geotrellis/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
LayoutDefinition,
HashPartitionStrategy,
SpatialPartitionStrategy,
SpaceTimePartitionStrategy)
SpaceTimePartitionStrategy,
SourceInfo)


from geopyspark.geotrellis.constants import ResampleMethod
Expand Down Expand Up @@ -148,6 +149,16 @@ def convert(self, obj, gateway_client):
return ScalaTemporalStrategy.apply(obj.num_partitions, obj.bits, scala_time_unit, scala_time_resolution)


class SourceInfoConverter(object):
def can_convert(self, object):
return isinstance(object, SourceInfo)

def convert(self, obj, gateway_client):
ScalaSourceInfo = JavaClass("geopyspark.geotrellis.vlm.SourceInfo", gateway_client)

return ScalaSourceInfo.apply(obj.source, obj.source_to_target_band)


register_input_converter(CellTypeConverter(), prepend=True)
register_input_converter(RasterizerOptionsConverter(), prepend=True)
register_input_converter(LayoutTypeConverter(), prepend=True)
Expand All @@ -156,3 +167,4 @@ def convert(self, obj, gateway_client):
register_input_converter(HashPartitionStrategyConverter(), prepend=True)
register_input_converter(SpatialPartitionStrategyConverter(), prepend=True)
register_input_converter(SpaceTimePartitionStrategyConverter(), prepend=True)
register_input_converter(SourceInfoConverter(), prepend=True)
Loading

0 comments on commit f9ed926

Please sign in to comment.