Skip to content

Commit

Permalink
Fix missing spatial reference when generating index geom.
Browse files Browse the repository at this point in the history
Fix SR transformations in MosaicGeometry.
Add NDVI and Clip expressions.
Fix projection issues in RST_Tessellate.
  • Loading branch information
milos.colic committed Sep 8, 2023
1 parent f220d75 commit cd47656
Show file tree
Hide file tree
Showing 16 changed files with 404 additions and 59 deletions.
57 changes: 57 additions & 0 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

__all__ = [
"rst_bandmetadata",
"rst_clip",
"rst_georeference",
"rst_height",
"rst_isempty",
Expand All @@ -18,6 +19,7 @@
"rst_merge",
"rst_mergebands",
"rst_numbands",
"rst_ndvi",
"rst_pixelheight",
"rst_pixelwidth",
"rst_rastertogridavg",
Expand Down Expand Up @@ -73,6 +75,32 @@ def rst_bandmetadata(raster: ColumnOrName, band: ColumnOrName) -> Column:
)


def rst_clip(raster: ColumnOrName, geometry: ColumnOrName) -> Column:
"""
Clips the raster to the given geometry.
The result is the path to the clipped raster.
The result is stored in the checkpoint directory.
Parameters
----------
raster : Column (StringType)
Path to the raster file.
geometry : Column (StringType)
The geometry to clip the raster to.
Returns
-------
Column (StringType)
The path to the clipped raster.
"""
return config.mosaic_context.invoke_function(
"rst_clip",
pyspark_to_java_column(raster),
pyspark_to_java_column(geometry)
)


def rst_georeference(raster: ColumnOrName) -> Column:
"""
Returns GeoTransform of the raster as a GT array of doubles.
Expand Down Expand Up @@ -242,6 +270,35 @@ def rst_numbands(raster: ColumnOrName) -> Column:
)


def rst_ndvi(raster: ColumnOrName, band1: ColumnOrName, band2: ColumnOrName) -> Column:
"""
Computes the NDVI of the raster.
The result is the path to the NDVI raster.
The result is stored in the checkpoint directory.
Parameters
----------
raster : Column (StringType)
Path to the raster file.
band1 : Column (IntegerType)
The first band index.
band2 : Column (IntegerType)
The second band index.
Returns
-------
Column (StringType)
The path to the NDVI raster.
"""
return config.mosaic_context.invoke_function(
"rst_ndvi",
pyspark_to_java_column(raster),
pyspark_to_java_column(band1),
pyspark_to_java_column(band2)
)


def rst_pixelheight(raster: ColumnOrName) -> Column:
"""
Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable {
def osrTransformCRS(srcSR: SpatialReference, destSR: SpatialReference, geometryAPI: GeometryAPI): MosaicGeometry = {
if (srcSR.IsSame(destSR) == 1) return this
val ogcGeometry = ogr.CreateGeometryFromWkb(this.toWKB)
val transform = new CoordinateTransformation(srcSR, destSR)
ogcGeometry.Transform(transform)
ogcGeometry.AssignSpatialReference(srcSR)
ogcGeometry.TransformTo(destSR)
val mosaicGeometry = geometryAPI.geometry(ogcGeometry.ExportToWkb, "WKB")
mosaicGeometry
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,20 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable {
*/
val letterMap: Seq[Seq[String]] =
Seq(
Seq("SV", "SW", "SX", "SY", "SZ", "TV", "TW"),
Seq("SQ", "SR", "SS", "ST", "SU", "TQ", "TR"),
Seq("SL", "SM", "SN", "SO", "SP", "TL", "TM"),
Seq("SF", "SG", "SH", "SJ", "SK", "TF", "TG"),
Seq("SA", "SB", "SC", "SD", "SE", "TA", "TB"),
Seq("NV", "NW", "NX", "NY", "NZ", "OV", "OW"),
Seq("NQ", "NR", "NS", "NT", "NU", "OQ", "OR"),
Seq("NL", "NM", "NN", "NO", "NP", "OL", "OM"),
Seq("NF", "NG", "NH", "NJ", "NK", "OF", "OG"),
Seq("NA", "NB", "NC", "ND", "NE", "OA", "OB"),
Seq("HV", "HW", "HX", "HY", "SZ", "JV", "JW"),
Seq("HQ", "HR", "HS", "HT", "HU", "JQ", "JR"),
Seq("HL", "HM", "HN", "HO", "HP", "JL", "JM")
Seq("SV", "SW", "SX", "SY", "SZ", "TV", "TW", "TX"),
Seq("SQ", "SR", "SS", "ST", "SU", "TQ", "TR", "TS"),
Seq("SL", "SM", "SN", "SO", "SP", "TL", "TM", "TN"),
Seq("SF", "SG", "SH", "SJ", "SK", "TF", "TG", "TH"),
Seq("SA", "SB", "SC", "SD", "SE", "TA", "TB", "TC"),
Seq("NV", "NW", "NX", "NY", "NZ", "OV", "OW", "OX"),
Seq("NQ", "NR", "NS", "NT", "NU", "OQ", "OR", "OS"),
Seq("NL", "NM", "NN", "NO", "NP", "OL", "OM", "ON"),
Seq("NF", "NG", "NH", "NJ", "NK", "OF", "OG", "OH"),
Seq("NA", "NB", "NC", "ND", "NE", "OA", "OB", "OC"),
Seq("HV", "HW", "HX", "HY", "SZ", "JV", "JW", "JX"),
Seq("HQ", "HR", "HS", "HT", "HU", "JQ", "JR", "JS"),
Seq("HL", "HM", "HN", "HO", "HP", "JL", "JM", "JN"),
Seq("HF", "HG", "HH", "HJ", "HK", "JF", "JG", "JH")
)

/**
Expand Down Expand Up @@ -257,7 +258,7 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable {
* @return
* Boolean representing validity.
*/
def isValid(index: Long): Boolean = {
override def isValid(index: Long): Boolean = {
val digits = indexDigits(index)
val xLetterIndex = digits.slice(3, 5).mkString.toInt
val yLetterIndex = digits.slice(1, 3).mkString.toInt
Expand Down Expand Up @@ -437,7 +438,9 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable {
val p2 = geometryAPI.fromCoords(Seq(x + edgeSize, y))
val p3 = geometryAPI.fromCoords(Seq(x + edgeSize, y + edgeSize))
val p4 = geometryAPI.fromCoords(Seq(x, y + edgeSize))
geometryAPI.geometry(Seq(p1, p2, p3, p4, p1), POLYGON)
val geom = geometryAPI.geometry(Seq(p1, p2, p3, p4, p1), POLYGON)
geom.setSpatialReference(this.crsID)
geom
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,12 @@ object H3IndexSystem extends IndexSystem(LongType) with Serializable {
override def indexToGeometry(index: String, geometryAPI: GeometryAPI): MosaicGeometry = {
val boundary = h3.h3ToGeoBoundary(index).asScala
val extended = boundary ++ List(boundary.head)
geometryAPI.geometry(
val geom = geometryAPI.geometry(
extended.map(p => geometryAPI.fromGeoCoord(Coordinates(p.lat, p.lng))),
POLYGON
)
geom.setSpatialReference(crsID)
geom
}

override def format(id: Long): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import org.gdal.osr.SpatialReference
*/
abstract class IndexSystem(var cellIdType: DataType) extends Serializable {

// Passthrough if not redefined
def isValid(cellID: Long): Boolean = true

def crsID: Int

def osrSpatialRef: SpatialReference = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ abstract class MosaicRaster(
def asTemp: MosaicRaster

def flushCache(): MosaicRaster = {
this.getRaster.FlushCache()
if (Option(getRaster).isDefined) {
getRaster.FlushCache()
}
this.destroy()
this.refresh()
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ abstract class RasterAPI(reader: RasterReader) extends Serializable {

def writeRasters(generatedRasters: Seq[MosaicRaster], checkpointPath: String, rasterDT: DataType): Seq[Any] = {
generatedRasters.map(raster =>
rasterDT match {
case StringType =>
val extension = raster.getRaster.GetDriver().GetMetadataItem("DMD_EXTENSION")
val writePath = s"$checkpointPath/${raster.uuid}.$extension"
val outPath = raster.writeToPath(writePath)
RasterCleaner.dispose(raster)
UTF8String.fromString(outPath)
case BinaryType =>
val bytes = raster.writeToBytes()
RasterCleaner.dispose(raster)
bytes
if (Option(raster).isDefined) {
rasterDT match {
case StringType =>
val extension = raster.getRaster.GetDriver().GetMetadataItem("DMD_EXTENSION")
val writePath = s"$checkpointPath/${raster.uuid}.$extension"
val outPath = raster.writeToPath(writePath)
RasterCleaner.dispose(raster)
UTF8String.fromString(outPath)
case BinaryType =>
val bytes = raster.writeToBytes()
RasterCleaner.dispose(raster)
bytes
}
} else {
null
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class MosaicRasterGDAL(_uuid: Long, var raster: Dataset, path: String, isTemp: B
val max = Array.ofDim[Double](1)
val mean = Array.ofDim[Double](1)
val stddev = Array.ofDim[Double](1)
band.GetStatistics(true, false, min, max, mean, stddev)
band.GetStatistics(true, true, min, max, mean, stddev)
i -> Map(
"min" -> min(0),
"max" -> max(0),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.databricks.labs.mosaic.core.raster.operator

import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.gdal_raster.MosaicRasterGDAL
import com.databricks.labs.mosaic.utils.PathUtils

object NDVI {

def emptyCopy(raster: MosaicRaster, path: String): MosaicRaster = {
val driver = raster.getRaster.GetDriver()
val newRaster = driver.Create(path, raster.xSize, raster.ySize, raster.numBands, raster.getRaster.GetRasterBand(1).getDataType)
newRaster.SetGeoTransform(raster.getRaster.GetGeoTransform)
newRaster.SetProjection(raster.getRaster.GetProjection)
MosaicRasterGDAL(newRaster, path, isTemp = true)
}

def compute(raster: MosaicRaster, redIndex: Int, nirIndex: Int): MosaicRaster = {

val redBand = raster.getRaster.GetRasterBand(redIndex)
val nirBand = raster.getRaster.GetRasterBand(nirIndex)

val numLines = redBand.GetYSize
val lineSize = redBand.GetXSize

val ndviPath = PathUtils.createTmpFilePath(raster.uuid.toString, raster.getExtension)
val ndviRaster = emptyCopy(raster, ndviPath)

var outputLine: Array[Double] = null
var redScanline: Array[Double] = null
var nirScanline: Array[Double] = null
val dataType = org.gdal.gdalconst.gdalconstConstants.GDT_Float64
for (line <- Range(0, numLines)) {
redScanline = Array.fill[Double](lineSize)(0.0)
nirScanline = Array.fill[Double](lineSize)(0.0)
redBand.ReadRaster(0, line, lineSize, 1, dataType, redScanline)
nirBand.ReadRaster(0, line, lineSize, 1, dataType, nirScanline)

outputLine = redScanline.zip(nirScanline).map { case (red, nir) =>
if (red + nir == 0) 0.0
else (nir - red) / (red + nir)
}
ndviRaster.getRaster
.GetRasterBand(1)
.WriteRaster(0, line, lineSize, 1, dataType, outputLine.array)
}
outputLine = null
redScanline = null
nirScanline = null

ndviRaster.flushCache()

ndviRaster
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.databricks.labs.mosaic.core.raster.operator.proj

import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp
import com.databricks.labs.mosaic.utils.PathUtils
import org.gdal.osr.SpatialReference

object RasterProject {

def project(raster: MosaicRaster, destCRS: SpatialReference): MosaicRaster = {
val outShortName = raster.getRaster.GetDriver().getShortName

val resultFileName = PathUtils.createTmpFilePath(raster.uuid.toString, raster.getExtension)

// Note that Null is the right value here
val authName = destCRS.GetAuthorityName(null)
val authCode = destCRS.GetAuthorityCode(null)

val result = GDALWarp.executeWarp(
resultFileName,
isTemp = true,
Seq(raster),
command = s"gdalwarp -of $outShortName -t_srs $authName:$authCode -r cubic -overwrite -co COMPRESS=PACKBITS"
)

result
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,33 @@ import com.databricks.labs.mosaic.core.Mosaic
import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.index.IndexSystem
import com.databricks.labs.mosaic.core.raster.MosaicRaster
import com.databricks.labs.mosaic.core.raster.operator.proj.RasterProject
import com.databricks.labs.mosaic.core.types.model.MosaicRasterChip

object RasterTessellate {

def tessellate(raster: MosaicRaster, resolution: Int, indexSystem: IndexSystem, geometryAPI: GeometryAPI): Seq[MosaicRasterChip] = {

Array.fill(10)(1.0)

val indexSR = indexSystem.osrSpatialRef
val bbox = raster.bbox(geometryAPI, indexSR)
val cells = Mosaic.mosaicFill(bbox, resolution, keepCoreGeom = false, indexSystem, geometryAPI)
cells
val tmpRaster = RasterProject.project(raster, indexSR)
val result = cells
.map(cell => {
val cellID = cell.cellIdAsLong(indexSystem)
val cellRaster = raster.getRasterForCell(cellID, indexSystem, geometryAPI)
val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI)
cellRaster.getRaster.FlushCache()
(
cellRaster.getBands.exists { band =>
band.values.count(_ != band.noDataValue) > 0 &&
band.maskValues.count(_ > 0) > 0
} && !cellRaster.isEmpty,
cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty,
MosaicRasterChip(cell.index, cellRaster)
)
})
.filter(_._1)
.map(_._2)

tmpRaster.destroy()
result
}

}
Loading

0 comments on commit cd47656

Please sign in to comment.