Skip to content

Commit

Permalink
Revise all RST_ expressions.
Browse files Browse the repository at this point in the history
Clean up raster disposing.
Clean up duplicated expressions.
  • Loading branch information
milos.colic committed Oct 4, 2023
1 parent ddab8cc commit 43e483f
Show file tree
Hide file tree
Showing 52 changed files with 292 additions and 542 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.geometry.linestring.MosaicLineString
import com.databricks.labs.mosaic.core.geometry.point.MosaicPoint
import org.gdal.ogr.ogr
import org.gdal.osr._
import org.gdal.osr.SpatialReference
import org.gdal.osr.osrConstants._
import org.locationtech.proj4j._

import java.util.Locale
Expand Down Expand Up @@ -65,10 +66,10 @@ trait MosaicGeometry extends GeometryWriter with Serializable {
def extent: (Double, Double, Double, Double) = {
val env = envelope
(
env.minMaxCoord("X", "MIN"),
env.minMaxCoord("Y", "MIN"),
env.minMaxCoord("X", "MAX"),
env.minMaxCoord("Y", "MAX")
env.minMaxCoord("X", "MIN"),
env.minMaxCoord("Y", "MIN"),
env.minMaxCoord("X", "MAX"),
env.minMaxCoord("Y", "MAX")
)
}

Expand Down Expand Up @@ -149,6 +150,21 @@ trait MosaicGeometry extends GeometryWriter with Serializable {

def setSpatialReference(srid: Int): Unit

/**
  * Builds an OSR [[SpatialReference]] for the SRID reported by
  * `getSpatialReference`. An SRID of 0 is mapped to EPSG:4326; any other
  * value is imported as-is. The axis mapping strategy is always set to
  * `OAMS_TRADITIONAL_GIS_ORDER`.
  *
  * @return
  *   A newly created spatial reference for this geometry.
  */
def getSpatialReferenceOSR: SpatialReference = {
    // 0 is handled as "no SRID set" and falls back to EPSG:4326.
    val epsgCode = getSpatialReference match {
        case 0     => 4326
        case other => other
    }
    val reference = new SpatialReference()
    reference.ImportFromEPSG(epsgCode)
    reference.SetAxisMappingStrategy(OAMS_TRADITIONAL_GIS_ORDER)
    reference
}

def hasValidCoords(crsBoundsProvider: CRSBoundsProvider, crsCode: String, which: String): Boolean = {
val crsCodeIn = crsCode.split(":")
val crsBounds = which.toLowerCase(Locale.ROOT) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ object MosaicGeometryCollectionESRI extends GeometryReader {
// POINT by convention, MULTIPOINT are always flattened to POINT in the internal representation
val coordinates = holesRings.head.head.coords
MosaicPointESRI(
new OGCPoint(new Point(coordinates(0), coordinates(1)), spatialReference)
new OGCPoint(new Point(coordinates.head, coordinates(1)), spatialReference)
)
} else {
MosaicGeometryESRI.fromWKT("POINT EMPTY")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ object BalancedSubdivision {

def splitRaster(
mosaicRaster: MosaicRaster,
sizeInMb: Int,
geometryAPI: GeometryAPI,
rasterAPI: RasterAPI
sizeInMb: Int
): immutable.Seq[MosaicRaster] = {
val numSplits = getNumSplits(mosaicRaster, sizeInMb)
val (x, y) = mosaicRaster.getDimensions
val (tileX, tileY) = getTileSize(x, y, numSplits)
ReTile.reTile(mosaicRaster, tileX, tileY, geometryAPI, rasterAPI)
ReTile.reTile(mosaicRaster, tileX, tileY)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.databricks.labs.mosaic.core.raster.operator.retile

import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate
import com.databricks.labs.mosaic.utils.PathUtils

import scala.collection.immutable

/** Retiling of a raster into fixed-size tiles that overlap their neighbours. */
object OverlappingTiles {

    /**
      * Cuts the raster into tiles by issuing one `gdal_translate -srcwin`
      * call per tile. Tile origins advance by the tile size minus the
      * overlap, where the overlap is `overlapPercentage` percent of the tile
      * size, rounded up to a whole pixel.
      *
      * @param raster
      *   The raster to split.
      * @param tileWidth
      *   Tile width in pixels.
      * @param tileHeight
      *   Tile height in pixels.
      * @param overlapPercentage
      *   Overlap between neighbouring tiles, as a percentage of the tile
      *   size. Must leave a positive step in both directions.
      * @return
      *   One result per generated tile, in x-major order.
      * @throws IllegalArgumentException
      *   If the overlap consumes the whole tile (or more) in either
      *   direction, which would make the tiling step zero or negative.
      */
    def reTile(
        raster: MosaicRaster,
        tileWidth: Int,
        tileHeight: Int,
        overlapPercentage: Int
    ): immutable.Seq[MosaicRaster] = {
        val (xSize, ySize) = raster.getDimensions

        val overlapWidth = Math.ceil(tileWidth * overlapPercentage / 100.0).toInt
        val overlapHeight = Math.ceil(tileHeight * overlapPercentage / 100.0).toInt

        val xStep = tileWidth - overlapWidth
        val yStep = tileHeight - overlapHeight

        // A zero step makes Range throw an opaque "step cannot be 0" error and a
        // negative step silently yields no tiles; fail early with a clear message.
        require(
          xStep > 0 && yStep > 0,
          s"Overlap of $overlapPercentage% leaves a non-positive tiling step " +
              s"(x step: $xStep, y step: $yStep) for tile size ${tileWidth}x$tileHeight."
        )

        val tiles = for (i <- 0 until xSize by xStep) yield {
            for (j <- 0 until ySize by yStep) yield {
                // Every window except the ones at the origin starts one pixel earlier,
                // and every window is one pixel larger than the nominal tile size.
                val xOff = if (i == 0) i else i - 1
                val yOff = if (j == 0) j else j - 1
                val width = Math.min(tileWidth, xSize - i) + 1
                val height = Math.min(tileHeight, ySize - j) + 1

                val uuid = java.util.UUID.randomUUID.toString
                val rasterPath = PathUtils.createTmpFilePath(uuid, "tif")

                val result = GDALTranslate.executeTranslate(
                  rasterPath,
                  isTemp = true,
                  raster,
                  command = s"gdal_translate -srcwin $xOff $yOff $width $height"
                )

                result.flushCache()
            }
        }

        tiles.flatten

    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,28 @@ object RasterTessellate {
val bbox = raster.bbox(geometryAPI, indexSR)
val cells = Mosaic.mosaicFill(bbox, resolution, keepCoreGeom = false, indexSystem, geometryAPI)
val tmpRaster = RasterProject.project(raster, indexSR)
val result = cells

val chips = cells
.map(cell => {
val cellID = cell.cellIdAsLong(indexSystem)
val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI)
cellRaster.getRaster.FlushCache()
(
cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty,
MosaicRasterChip(cell.index, cellRaster)
)
val isValidCell = indexSystem.isValid(cellID)
if (!isValidCell) {
(false, MosaicRasterChip(cell.index, null))
} else {
val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI)
val isValidRaster = cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty
(
isValidRaster,
MosaicRasterChip(cell.index, cellRaster)
)
}
})
.filter(_._1)
.map(_._2)

val (result, invalid) = chips.partition(_._1)
invalid.foreach(_._2.raster.destroy())
tmpRaster.destroy()
result

result.map(_._2)
}

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package com.databricks.labs.mosaic.core.raster.operator.retile

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.api.RasterAPI
import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector
import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate
import com.databricks.labs.mosaic.utils.PathUtils

import scala.collection.immutable

Expand All @@ -12,26 +11,28 @@ object ReTile {
def reTile(
raster: MosaicRaster,
tileWidth: Int,
tileHeight: Int,
geometryAPI: GeometryAPI,
rasterAPI: RasterAPI
tileHeight: Int
): immutable.Seq[MosaicRaster] = {
val (xR, yR) = raster.getDimensions
val xTiles = Math.ceil(xR / tileWidth).toInt
val yTiles = Math.ceil(yR / tileHeight).toInt

val tiles = for (x <- 0 until xTiles; y <- 0 until yTiles) yield {
val xMin = x * tileWidth
val yMin = y * tileHeight
val xMin = if (x == 0) x * tileWidth else x * tileWidth - 1
val yMin = if (y == 0) y * tileHeight else y * tileHeight - 1

val bbox = geometryAPI.createBbox(xMin, yMin, xMin + tileWidth, yMin + tileHeight)
.mapXY((x, y) => rasterAPI.toWorldCoord(raster.getGeoTransform, x.toInt, y.toInt))
val rasterUUID = java.util.UUID.randomUUID.toString
val rasterPath = PathUtils.createTmpFilePath(rasterUUID, "tif")

// buffer bbox by the diagonal size of the raster to ensure we get all the pixels in the tile
val bufferR = raster.pixelDiagSize * 1.01
val bufferedBBox = bbox.buffer(bufferR)
val result = GDALTranslate.executeTranslate(
rasterPath,
isTemp = true,
raster,
command = s"gdal_translate -srcwin $xMin $yMin ${tileWidth + 1} ${tileHeight + 1}"
)

RasterClipByVector.clip(raster, bufferedBBox, raster.getRaster.GetSpatialRef(), geometryAPI)
result.flushCache()
result

}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.databricks.labs.mosaic.datasource.gdal

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL
import com.databricks.labs.mosaic.core.raster.gdal_raster.{MosaicRasterGDAL, RasterCleaner}
import com.databricks.labs.mosaic.core.raster.operator.retile.BalancedSubdivision
import com.databricks.labs.mosaic.datasource.Utils
Expand Down Expand Up @@ -43,11 +41,10 @@ object ReTileOnRead extends ReadStrategy {
val localCopy = PathUtils.copyToTmp(status.getPath.toString)
val raster = MosaicRasterGDAL.readRaster(localCopy)
val uuid = getUUID(status)
val geometryAPI = GeometryAPI.apply(options.getOrElse("geometry_api", "JTS"))

val size = status.getLen
val numSplits = Math.ceil(size / MB16).toInt
val tiles = BalancedSubdivision.splitRaster(raster, numSplits, geometryAPI, GDAL)
val tiles = BalancedSubdivision.splitRaster(raster, numSplits)

val rows = tiles.map(tile => {
val trimmedSchema = StructType(requiredSchema.filter(field => field.name != RASTER && field.name != LENGTH))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.gdal_raster.RasterCleaner
import com.databricks.labs.mosaic.core.raster.{MosaicRaster, MosaicRasterBand}
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterBandExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types._

/**
Expand All @@ -22,7 +21,13 @@ import org.apache.spark.sql.types._
* Additional arguments for the expression (expressionConfigs).
*/
case class RST_BandMetaData(raster: Expression, band: Expression, expressionConfig: MosaicExpressionConfig)
extends RasterBandExpression[RST_BandMetaData](raster, band, MapType(StringType, StringType), returnsRaster = false, expressionConfig = expressionConfig)
extends RasterBandExpression[RST_BandMetaData](
raster,
band,
MapType(StringType, StringType),
returnsRaster = false,
expressionConfig = expressionConfig
)
with NullIntolerant
with CodegenFallback {

Expand All @@ -35,10 +40,7 @@ case class RST_BandMetaData(raster: Expression, band: Expression, expressionConf
* The band metadata of the band as a map type result.
*/
override def bandTransform(raster: MosaicRaster, band: MosaicRasterBand): Any = {
val metaData = band.metadata
val result = buildMapString(metaData)
RasterCleaner.dispose(raster)
result
buildMapString(band.metadata)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types._

/**
  * Returns the bounding box of the raster as a WKB-encoded polygon in world
  * coordinates.
  *
  * @param raster
  *   The expression for the raster.
  * @param expressionConfig
  *   Additional arguments for the expression (expressionConfigs).
  */
case class RST_BoundingBox(
    raster: Expression,
    expressionConfig: MosaicExpressionConfig
) extends RasterExpression[RST_BoundingBox](raster, BinaryType, returnsRaster = false, expressionConfig = expressionConfig)
      with NullIntolerant
      with CodegenFallback {

    /**
      * Computes the bounding box polygon of the raster. The pixel (0, 0) and
      * the pixel (xSize, ySize) are converted to world coordinates using the
      * raster's geo transform, and the four corners derived from them form a
      * closed polygon ring.
      *
      * @param raster
      *   The raster to be used.
      * @return
      *   The bounding box polygon serialized as WKB.
      */
    override def rasterTransform(raster: MosaicRaster): Any = {
        val gt = raster.getRaster.GetGeoTransform()
        val (originX, originY) = rasterAPI.toWorldCoord(gt, 0, 0)
        val (endX, endY) = rasterAPI.toWorldCoord(gt, raster.xSize, raster.ySize)
        val geometryAPI = GeometryAPI(expressionConfig.getGeometryAPI)
        // The ring is closed by repeating the first corner as the last point.
        val bboxPolygon = geometryAPI.geometry(
          Seq(
            Seq(originX, originY),
            Seq(originX, endY),
            Seq(endX, endY),
            Seq(endX, originY),
            Seq(originX, originY)
          ).map(geometryAPI.fromCoords),
          GeometryTypeEnum.POLYGON
        )
        bboxPolygon.toWKB
    }

}

/** Expression info required for the expression registration for spark SQL. */
object RST_BoundingBox extends WithExpressionInfo {

    override def name: String = "rst_boundingbox"

    override def usage: String =
        """
          |_FUNC_(expr1) - Returns the bounding box of the raster.
          |""".stripMargin

    // The expression takes a single raster argument and returns a WKB polygon.
    override def example: String =
        """
          | Examples:
          | > SELECT _FUNC_(raster);
          | <WKB of the bounding box polygon>
          | """.stripMargin

    /**
      * Builds the function registered with spark SQL. The builder must target
      * RST_BoundingBox itself and expect exactly one child expression (the
      * raster).
      */
    override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
        GenericExpressionFactory.getBaseBuilder[RST_BoundingBox](1, expressionConfig)
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types.BinaryType
import org.gdal.osr
import org.gdal.osr.SpatialReference

/**
* Returns a set of new rasters with the specified tile size (tileWidth x
Expand All @@ -25,7 +23,7 @@ case class RST_Clip(
rastersExpr,
geometryExpr,
BinaryType,
returnsRaster = false,
returnsRaster = true,
expressionConfig = expressionConfig
)
with NullIntolerant
Expand All @@ -47,22 +45,8 @@ case class RST_Clip(
*/
override def rasterTransform(raster: MosaicRaster, arg1: Any): Any = {
val geometry = geometryAPI.geometry(arg1, geometryExpr.dataType)
val geomCRS =
if (geometry.getSpatialReference == 0) {
val wsg84 = new osr.SpatialReference()
wsg84.ImportFromEPSG(4326)
wsg84.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER)
wsg84
}
else {
val geomCRS = new SpatialReference()
geomCRS.ImportFromEPSG(geometry.getSpatialReference)
// debug for this
geomCRS.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER)
geomCRS
}
val result = RasterClipByVector.clip(raster, geometry, geomCRS, geometryAPI)
rasterAPI.writeRasters(Seq(result), expressionConfig.getRasterCheckpoint, BinaryType).head
val geomCRS = geometry.getSpatialReferenceOSR
RasterClipByVector.clip(raster, geometry, geomCRS, geometryAPI)
}

}
Expand Down
Loading

0 comments on commit 43e483f

Please sign in to comment.