Skip to content

Commit

Permalink
Add python signatures.
Browse files Browse the repository at this point in the history
Fix index id types.
  • Loading branch information
milos.colic committed Oct 12, 2023
1 parent 30f8b31 commit 2a7c34c
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 22 deletions.
74 changes: 65 additions & 9 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

__all__ = [
"rst_bandmetadata",
"rst_boundingbox",
"rst_clip",
"rst_georeference",
"rst_getsubdataset",
"rst_height",
"rst_isempty",
"rst_memsize",
Expand All @@ -31,7 +33,6 @@
"rst_rastertoworldcoordx",
"rst_rastertoworldcoordy",
"rst_retile",
"rst_gridtiles",
"rst_rotation",
"rst_scalex",
"rst_scaley",
Expand All @@ -42,6 +43,8 @@
"rst_summary",
"rst_subdivide",
"rst_tessellate",
"rst_tile",
"rst_tryopen",
"rst_upperleftx",
"rst_upperlefty",
"rst_width",
Expand Down Expand Up @@ -75,6 +78,27 @@ def rst_bandmetadata(raster: ColumnOrName, band: ColumnOrName) -> Column:
)


def rst_boundingbox(raster: ColumnOrName) -> Column:
"""
Returns the bounding box of the raster as a WKT polygon.
Parameters
----------
raster : Column (StringType)
Path to the raster file.
Returns
-------
Column (StringType)
A WKT polygon representing the bounding box of the raster.
"""
return config.mosaic_context.invoke_function(
"rst_boundingbox",
pyspark_to_java_column(raster)
)


def rst_clip(raster: ColumnOrName, geometry: ColumnOrName) -> Column:
"""
Clips the raster to the given geometry.
Expand Down Expand Up @@ -129,6 +153,31 @@ def rst_georeference(raster: ColumnOrName) -> Column:
)


def rst_getsubdataset(raster: ColumnOrName, subdataset: ColumnOrName) -> Column:
"""
Returns the subdataset of the raster.
The subdataset is the path to the subdataset of the raster.
Parameters
----------
raster : Column (StringType)
Path to the raster file.
subdataset : Column (IntegerType)
The index of the subdataset to get.
Returns
-------
Column (StringType)
The path to the subdataset.
"""
return config.mosaic_context.invoke_function(
"rst_getsubdataset",
pyspark_to_java_column(raster),
pyspark_to_java_column(subdataset)
)


def rst_height(raster: ColumnOrName) -> Column:
"""
Parameters
Expand Down Expand Up @@ -561,14 +610,6 @@ def rst_retile(raster: ColumnOrName, tileWidth: ColumnOrName, tileHeight: Column
)


def rst_gridtiles(raster: ColumnOrName, resolution: ColumnOrName) -> Column:
return config.mosaic_context.invoke_function(
"rst_gridtiles",
pyspark_to_java_column(raster),
pyspark_to_java_column(resolution)
)


def rst_rotation(raster: ColumnOrName) -> Column:
"""
Computes the rotation of the raster in degrees.
Expand Down Expand Up @@ -769,6 +810,21 @@ def rst_tessellate(raster: ColumnOrName, resolution: ColumnOrName) -> Column:
)


def rst_tile(raster: ColumnOrName, sizeInMB: ColumnOrName) -> Column:
"""
Tiles the raster into tiles of the given size.
:param raster:
:param sizeInMB:
:return:
"""

return config.mosaic_context.invoke_function(
"rst_tile",
pyspark_to_java_column(raster),
pyspark_to_java_column(sizeInMB)
)


def rst_tryopen(raster: ColumnOrName) -> Column:
"""
Tries to open the raster and returns a flag indicating if the raster can be opened.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ object RasterTessellate {
val isValidRaster = cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty
(
isValidRaster,
MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName)
MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName)
)
}
})

val (result, invalid) = chips.partition(_._1)
invalid.foreach(_._2.raster.destroy())
invalid.flatMap(t => Option(t._2.raster)).foreach(_.destroy())
tmpRaster.destroy()

result.map(_._2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ case class MosaicRasterTile(
def isEmpty: Boolean = Option(raster).forall(_.isEmpty)

def formatCellId(indexSystem: IndexSystem): MosaicRasterTile = {
if (Option(index).isEmpty) return this
(indexSystem.getCellIdDataType, index) match {
case (_: LongType, Left(_)) => this
case (_: StringType, Right(_)) => this
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.labs.mosaic.datasource.gdal

import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory}
import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.api.RasterAPI
import com.databricks.labs.mosaic.core.raster.gdal_raster.{MosaicRasterGDAL, RasterCleaner}
import com.databricks.labs.mosaic.core.raster.operator.retile.BalancedSubdivision
Expand Down Expand Up @@ -46,15 +47,10 @@ object ReTileOnRead extends ReadStrategy {
rasterAPI: RasterAPI
): Iterator[InternalRow] = {
val inPath = status.getPath.toString
val localCopy = PathUtils.copyToTmp(inPath)
val driverShortName = MosaicRasterGDAL.indentifyDriver(localCopy)
val raster = MosaicRasterGDAL.readRaster(localCopy, inPath, driverShortName)
val uuid = getUUID(status)

val sizeInMB = options.getOrElse("sizeInMB", "16").toInt

val inTile = MosaicRasterTile(null, raster, inPath, driverShortName)
val tiles = BalancedSubdivision.splitRaster(inTile, sizeInMB)
val (raster, tiles) = localSubdivide(inPath, sizeInMB)

val rows = tiles.map(tile => {
val trimmedSchema = StructType(requiredSchema.filter(field => field.name != TILE))
Expand All @@ -72,7 +68,7 @@ object ReTileOnRead extends ReadStrategy {
case other => throw new RuntimeException(s"Unsupported field name: $other")
}
// Writing to bytes is destructive so we delay reading content and content length until the last possible moment
val row = Utils.createRow(fields ++ Seq(tile.serialize(rasterAPI)))
val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(rasterAPI)))
RasterCleaner.dispose(tile)
row
})
Expand All @@ -81,4 +77,13 @@ object ReTileOnRead extends ReadStrategy {
rows.iterator
}

def localSubdivide(inPath: String, sizeInMB: Int): (MosaicRaster, Seq[MosaicRasterTile]) = {
val localCopy = PathUtils.copyToTmp(inPath)
val driverShortName = MosaicRasterGDAL.indentifyDriver(localCopy)
val raster = MosaicRasterGDAL.readRaster(localCopy, inPath, driverShortName)
val inTile = MosaicRasterTile(null, raster, inPath, driverShortName)
val tiles = BalancedSubdivision.splitRaster(inTile, sizeInMB)
(raster, tiles)
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.index.IndexSystemFactory
import com.databricks.labs.mosaic.core.raster.api.RasterAPI
import com.databricks.labs.mosaic.core.raster.gdal_raster.RasterCleaner
import com.databricks.labs.mosaic.core.raster.operator.merge.MergeRasters
Expand Down Expand Up @@ -78,6 +79,7 @@ case class RST_MergeAgg(
val driver = tiles.head.driver

val result = MosaicRasterTile(idx, merged, parentPath, driver)
.formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem))
.serialize(rasterAPI, BinaryType, expressionConfig.getRasterCheckpoint)

tiles.foreach(RasterCleaner.dispose)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory}
import com.databricks.labs.mosaic.core.raster.api.RasterAPI
import com.databricks.labs.mosaic.core.raster.gdal_raster.RasterCleaner
import com.databricks.labs.mosaic.core.types.RasterTileType
import com.databricks.labs.mosaic.datasource.gdal.ReTileOnRead
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, NullIntolerant}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Returns a set of new rasters with the specified tile size (tileWidth x
* tileHeight).
*/
case class RST_Tile(
rasterPathExpr: Expression,
sizeInMB: Expression,
expressionConfig: MosaicExpressionConfig
) extends CollectionGenerator
with Serializable
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType)

val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_")

protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI)
rasterAPI.enable()
protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI)

protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)

protected val cellIdDataType: DataType = indexSystem.getCellIdDataType

override def position: Boolean = false

override def inline: Boolean = false

override def children: Seq[Expression] = Seq(rasterPathExpr, sizeInMB)

override def elementSchema: StructType = StructType(Array(StructField("tile", dataType)))

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val path = rasterPathExpr.eval(input).asInstanceOf[UTF8String].toString
val targetSize = sizeInMB.eval(input).asInstanceOf[Int]
val (raster, tiles) = ReTileOnRead.localSubdivide(path, targetSize)
val rows = tiles.map(_.formatCellId(indexSystem).serialize(rasterAPI))
tiles.foreach(RasterCleaner.dispose)
RasterCleaner.dispose(raster)
rows.map(row => InternalRow.fromSeq(Seq(row)))
}

override def makeCopy(newArgs: Array[AnyRef]): Expression =
GenericExpressionFactory.makeCopyImpl[RST_Tile](this, newArgs, children.length, expressionConfig)

override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray)

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Tile extends WithExpressionInfo {

override def name: String = "rst_tile"

override def usage: String =
"""
|_FUNC_(expr1) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight).
|""".stripMargin

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(a, b);
| /path/to/raster_tile_1.tif
| /path/to/raster_tile_2.tif
| /path/to/raster_tile_3.tif
| ...
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_ReTile](3, expressionConfig)
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.mosaic.expressions.raster.base

import com.databricks.labs.mosaic.core.index.IndexSystemFactory
import com.databricks.labs.mosaic.core.raster.api.RasterAPI
import com.databricks.labs.mosaic.core.raster.gdal_raster.RasterCleaner
import com.databricks.labs.mosaic.core.types.RasterTileType
Expand All @@ -20,7 +21,9 @@ trait RasterExpressionSerialization {
val tile = data.asInstanceOf[MosaicRasterTile]
val checkpoint = expressionConfig.getRasterCheckpoint
val rasterType = outputDataType.asInstanceOf[RasterTileType].rasterType
val result = tile.serialize(rasterAPI, rasterType, checkpoint)
val result = tile
.formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem))
.serialize(rasterAPI, rasterType, checkpoint)
RasterCleaner.dispose(tile)
result
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ abstract class RasterGeneratorExpression[T <: Expression: ClassTag](
val generatedRasters = rasterGenerator(tile)

// Writing rasters disposes of the written raster
val rows = generatedRasters.map(_.serialize(rasterAPI))
val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize(rasterAPI))
generatedRasters.foreach(RasterCleaner.dispose)
RasterCleaner.dispose(tile)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag](
)
val inResolution: Int = indexSystem.getResolution(resolutionExpr.eval(input))
val generatedChips = rasterGenerator(tile, inResolution)
.map(chip => chip.formatCellId(indexSystem))

val rows = generatedChips.map(chip => InternalRow.fromSeq(Seq(chip.serialize(rasterAPI))))
val rows = generatedChips
.map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize(rasterAPI))))

RasterCleaner.dispose(tile)
generatedChips.foreach(chip => RasterCleaner.dispose(chip.raster))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP

/** RasterAPI dependent functions */
mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig)
mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig)
mosaicRegistry.registerExpression[RST_Clip](expressionConfig)
mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig)
mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig)
Expand Down Expand Up @@ -287,6 +288,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP
mosaicRegistry.registerExpression[RST_Subdatasets](expressionConfig)
mosaicRegistry.registerExpression[RST_Summary](expressionConfig)
mosaicRegistry.registerExpression[RST_Tessellate](expressionConfig)
mosaicRegistry.registerExpression[RST_Tile](expressionConfig)
mosaicRegistry.registerExpression[RST_TryOpen](expressionConfig)
mosaicRegistry.registerExpression[RST_Subdivide](expressionConfig)
mosaicRegistry.registerExpression[RST_UpperLeftX](expressionConfig)
Expand Down Expand Up @@ -605,11 +607,14 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP
ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig))
def rst_bandmetadata(raster: String, band: Int): Column =
ColumnAdapter(RST_BandMetaData(lit(raster).expr, lit(band).expr, expressionConfig))
def rst_boundbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig))
def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig))
def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig))
def rst_georeference(raster: String): Column = ColumnAdapter(RST_GeoReference(lit(raster).expr, expressionConfig))
def rst_getsubdataset(raster: Column, subdatasetName: Column): Column =
ColumnAdapter(RST_GetSubdataset(raster.expr, subdatasetName.expr, expressionConfig))
def rst_getsubdataset(raster: Column, subdatasetName: String): Column =
ColumnAdapter(RST_GetSubdataset(raster.expr, lit(subdatasetName).expr, expressionConfig))
def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig))
def rst_height(raster: String): Column = ColumnAdapter(RST_Height(lit(raster).expr, expressionConfig))
def rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig))
Expand Down Expand Up @@ -696,6 +701,12 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP
ColumnAdapter(RST_Tessellate(col(raster).expr, resolution.expr, expressionConfig))
def rst_tessellate(raster: Column, resolution: Int): Column =
ColumnAdapter(RST_Tessellate(raster.expr, lit(resolution).expr, expressionConfig))
def rst_tile(raster: Column, sizeInMB: Column): Column =
ColumnAdapter(RST_Tile(raster.expr, sizeInMB.expr, expressionConfig))
def rst_tile(raster: Column, sizeInMB: Int): Column =
ColumnAdapter(RST_Tile(raster.expr, lit(sizeInMB).expr, expressionConfig))
def rst_tile(raster: String): Column =
ColumnAdapter(RST_Tile(lit(raster).expr, lit(256).expr, expressionConfig))
def rst_tryopen(raster: Column): Column = ColumnAdapter(RST_TryOpen(raster.expr, expressionConfig))
def rst_subdivide(raster: Column, sizeInMB: Column): Column =
ColumnAdapter(RST_Subdivide(raster.expr, sizeInMB.expr, expressionConfig))
Expand Down Expand Up @@ -982,7 +993,7 @@ object MosaicContext extends Logging {
if (!isML && !isPhoton) {
// Print out the warnings both to the log and to the console
logWarning("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime")
logWarning("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.")
logWarning("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.")
logWarning("Please use a Databricks Photon-enabled Runtime (for performance benefits) or Runtime ML (for spatial AI benefits).")
println("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime")
println("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.")
Expand Down

0 comments on commit 2a7c34c

Please sign in to comment.