diff --git a/danesfield/gdal_utils.py b/danesfield/gdal_utils.py index dfafe2522..194435673 100644 --- a/danesfield/gdal_utils.py +++ b/danesfield/gdal_utils.py @@ -11,7 +11,7 @@ import re import ogr import osr - +from vtk.util import numpy_support def gdal_bounding_box(raster, outProj=None): """ @@ -148,3 +148,18 @@ def read_offset(fileName, offset): if match: for i in range(3): offset[i] = float(match.group(1+i)) + +def vtk_to_numpy_order(aFlatVtk, dimensions): + ''' + Convert a 2D array from VTK order to numpy order + ''' + # VTK to numpy + aFlat = numpy_support.vtk_to_numpy(aFlatVtk) + # VTK X,Y corresponds to numpy cols,rows. VTK stores as + # in Fortran order. + aTranspose = numpy.reshape(aFlat, dimensions, "F") + # changes from cols, rows to rows,cols. + a = numpy.transpose(aTranspose) + # numpy rows increase as you go down, Y for VTK images increases as you go up + a = numpy.flip(a, 0) + return a diff --git a/danesfield/ortho.py b/danesfield/ortho.py index 0958e3dbb..869a492df 100644 --- a/danesfield/ortho.py +++ b/danesfield/ortho.py @@ -31,7 +31,7 @@ def circ_structure(n): def orthorectify(args_source_image, args_dsm, args_destination_image, args_occlusion_thresh=1.0, args_denoise_radius=2, - args_raytheon_rpc=None, args_dtm=None): + args_raytheon_rpc=None, args_dtm=None, args_convert_to_latlon=True): """ Orthorectify an image given the DSM @@ -39,12 +39,14 @@ def orthorectify(args_source_image, args_dsm, args_destination_image, source_image: Source image file name dsm: Digital surface model (DSM) image file name destination_image: Orthorectified image file name - occlusion-thresh: Threshold on height difference for detecting + occlusion_thresh: Threshold on height difference for detecting and masking occluded regions (in meters) - denoise-radius: Apply morphological operations with this radius + denoise_radius: Apply morphological operations with this radius to the DSM reduce speckled noise - raytheon-rpc: Raytheon RPC file name. If not provided + raytheon_rpc: Raytheon RPC file name or model object. If not provided the RPC is read from the source_image + dtm: Optional DTM parameter used to replace nodata areas in the + orthorectified image Returns: COMPLETE_DSM_INTERSECTION = 0 @@ -60,9 +62,12 @@ def orthorectify(args_source_image, args_dsm, args_destination_image, sourceBand = sourceImage.GetRasterBand(1) if (args_raytheon_rpc): - # read the RPC from raytheon file - print("Reading RPC from Raytheon file: {}".format(args_raytheon_rpc)) - model = raytheon_rpc.read_raytheon_rpc_file(args_raytheon_rpc) + if (isinstance(args_raytheon_rpc, str)): + # read the RPC from raytheon file + print("Reading RPC from Raytheon file: {}".format(args_raytheon_rpc)) + model = raytheon_rpc.read_raytheon_rpc_file(args_raytheon_rpc) + else: + model = args_raytheon_rpc else: # read the RPC from RPC Metadata in the image file print("Reading RPC Metadata from {}".format(args_source_image)) @@ -157,12 +162,13 @@ def orthorectify(args_source_image, args_dsm, args_destination_image, print("Driver {} does not supports Create().".format(driver)) return ERROR - # convert coordinates to Long/Lat - srs = osr.SpatialReference(wkt=projection) - proj_srs = srs.ExportToProj4() - inProj = pyproj.Proj(proj_srs) - outProj = pyproj.Proj('+proj=longlat +datum=WGS84') - arrayX, arrayY = pyproj.transform(inProj, outProj, arrayX, arrayY) + if (args_convert_to_latlon): + # convert coordinates to Long/Lat + srs = osr.SpatialReference(wkt=projection) + proj_srs = srs.ExportToProj4() + inProj = pyproj.Proj(proj_srs) + outProj = pyproj.Proj('+proj=longlat +datum=WGS84') + arrayX, arrayY = pyproj.transform(inProj, outProj, arrayX, arrayY) # Sort the points by height so that higher points project last if (args_occlusion_thresh > 0): @@ -240,10 +246,11 @@ def orthorectify(args_source_image, args_dsm, args_destination_image, print("Processing band {} ...".format(bandIndex)) sourceBand = sourceImage.GetRasterBand(bandIndex) nodata_value = sourceBand.GetNoDataValue() - # for now use zero as a no-data value if one is not specified + # for now use -9999 as a no-data value if one is not specified # it would probably be better to add a mask (alpha) band instead if nodata_value is None: - nodata_value = 0 + nodata_value = -9999 + print("nodata: {}".format(nodata_value)) if numpy.any(cropSize < 1): # read one value for data type sourceRaster = sourceBand.ReadAsArray( diff --git a/tools/buildings_to_dsm.py b/tools/buildings_to_dsm.py index 94f53ad8e..c91909a11 100644 --- a/tools/buildings_to_dsm.py +++ b/tools/buildings_to_dsm.py @@ -197,10 +197,10 @@ def main(args): buildingsScalarRange = p2cBuildings.GetOutput().GetCellData().GetScalars().GetRange() if (args.debug): - polyWriter = vtk.vtkXMLPolyDataWriter() - polyWriter.SetFileName("p2c.vtp") - polyWriter.SetInputConnection(p2cBuildings.GetOutputPort()) - polyWriter.Write() + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName("p2c.vtp") + writer.SetInputConnection(p2cBuildings.GetOutputPort()) + writer.Write() buildingsMapper = vtk.vtkPolyDataMapper() buildingsMapper.SetInputDataObject(p2cBuildings.GetOutput()) @@ -217,7 +217,26 @@ def main(args): dtmReader = vtk.vtkGDALRasterReader() dtmReader.SetFileName(args.input_dtm) dtmReader.Update() - dtmVtk = dtmReader.GetOutput() + dtmVtkCell = dtmReader.GetOutput() + + # convert from cell to point data + origin = dtmVtkCell.GetOrigin() + spacing = dtmVtkCell.GetSpacing() + dims = list(dtmVtkCell.GetDimensions()) + dims[0] = dims[0] - 1 + dims[1] = dims[1] - 1 + dtmVtk = vtk.vtkUniformGrid() + dtmVtk.SetDimensions(dims) + dtmVtk.SetOrigin(origin) + dtmVtk.SetSpacing(spacing) + data = dtmVtkCell.GetCellData().GetScalars() + dtmVtk.GetPointData().SetScalars(data) + + if (args.debug): + writer = vtk.vtkXMLImageDataWriter() + writer.SetFileName("dtm.vti") + writer.SetInputDataObject(dtmVtk) + writer.Write() # Convert the terrain into a polydata. surface = vtk.vtkImageDataGeometryFilter() @@ -236,6 +255,12 @@ def main(args): warp.Update() dsmScalarRange = warp.GetOutput().GetPointData().GetScalars().GetRange() + if (args.debug): + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName("dtm_warped.vtp") + writer.SetInputConnection(warp.GetOutputPort()) + writer.Write() + dtmMapper = vtk.vtkPolyDataMapper() dtmMapper.SetInputConnection(warp.GetOutputPort()) dtmActor = vtk.vtkActor() @@ -264,6 +289,10 @@ def main(args): ren.RemoveActor(dtmActor) renWin.Render() + # iren = vtk.vtkRenderWindowInteractor() + # iren.SetRenderWindow(renWin) + # iren.Start(); + windowToImageFilter = vtk.vtkWindowToImageFilter() windowToImageFilter.SetInput(renWin) windowToImageFilter.SetInputBufferTypeToRGBA() @@ -310,15 +339,8 @@ def main(args): valuePass.ReleaseGraphicsResources(renWin) print("Writing the DSM ...") - elevationFlat = numpy_support.vtk_to_numpy(elevationFlatVtk) - # VTK X,Y corresponds to numpy cols,rows. VTK stores arrays - # in Fortran order. - elevationTranspose = numpy.reshape( - elevationFlat, [dtm.RasterXSize, dtm.RasterYSize], "F") - # changes from cols, rows to rows,cols. - elevation = numpy.transpose(elevationTranspose) - # numpy rows increase as you go down, Y for VTK images increases as you go up - elevation = numpy.flip(elevation, 0) + elevation = gdal_utils.vtk_to_numpy_order(elevationFlatVtk, + [dtm.RasterXSize, dtm.RasterYSize]) if args.buildings_only: dsmElevation = elevation else: diff --git a/tools/render_shadows.py b/tools/render_shadows.py new file mode 100644 index 000000000..33e357199 --- /dev/null +++ b/tools/render_shadows.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python + +import argparse +from danesfield import gdal_utils +from danesfield import ortho +import gdal +import logging +import numpy +import os +import vtk +from vtk.numpy_interface import dataset_adapter as dsa +from vtk.util import vtkAlgorithm as vta + + +class ParallelProjectionModel(object): + '''Implements a parallel projection: (lon, lat, alt) ==> (lon, lat) + ''' + def __init__(self, m, dims): + '''Initialize the class from a 16 element row order matrix. + ''' + self.mat = numpy.array([[m[0], m[1], m[2], m[3]], + [m[4], m[5], m[6], m[7]], + [m[8], m[9], m[10], m[11]], + [m[12], m[13], m[14], m[15]]]) + self.dims = dims + + def project(self, point): + '''Project a long, lat, elev point into image coordinates using parallel projection + This function can also project an (n,3) matrix where each row of the + matrix is a point to project. The result is an (n,2) matrix of image + coordinates. + ''' + pointCount = point.shape[0] + ones = numpy.repeat(1, pointCount) + ones = numpy.reshape(ones, (pointCount, 1)) + point = numpy.hstack((point, ones)) + viewPoint = numpy.dot(point, self.mat) + viewPoint = numpy.delete(viewPoint, [2, 3], axis=1) + imagePoint = (viewPoint + 1.0) / 2.0 * self.dims + imagePoint[:, 1] = self.dims[1] - imagePoint[:, 1] + return imagePoint + + +class FillNoData(vta.VTKPythonAlgorithmBase): + ''' Algorithm with a DSM and DTM as inputs, DTM is optional. It produces + an output the same as the DSM with the NoData values replaced with correspoinding + values from the DTM. + ''' + def __init__(self): + vta.VTKPythonAlgorithmBase.__init__(self, nInputPorts=2, outputType='vtkUniformGrid') + + def RequestInformation(self, request, inInfo, outInfo): + origin = inInfo[0].GetInformationObject(0).Get(vtk.vtkDataObject.ORIGIN()) + spacing = inInfo[0].GetInformationObject(0).Get(vtk.vtkDataObject.SPACING()) + outInfo.GetInformationObject(0).Set( + vtk.vtkDataObject.ORIGIN(), origin, 3) + outInfo.GetInformationObject(0).Set( + vtk.vtkDataObject.SPACING(), spacing, 3) + return 1 + + def FillInputPortInformation(self, port, info): + """Sets the required input type to InputType.""" + info.Set(vtk.vtkAlgorithm.INPUT_REQUIRED_DATA_TYPE(), self.InputType) + if (port == 1): + info.Set(vtk.vtkAlgorithm.INPUT_IS_OPTIONAL(), 1) + return 1 + + def RequestData(self, request, inInfo, outInfo): + i0 = self.GetInputData(inInfo, 0, 0) + dsm = dsa.WrapDataObject(i0) + dtmInfo = inInfo[1].GetInformationObject(0) + if (dtmInfo): + i1 = self.GetInputData(inInfo, 1, 0) + dtm = dsa.WrapDataObject(i1) + o = self.GetOutputData(outInfo, 0) + o.ShallowCopy(i0) + output = dsa.WrapDataObject(o) + + dsmData = dsm.CellData["Elevation"] + if (dtmInfo): + dtmData = dtm.CellData["Elevation"] + ghostArray = dsm.CellData["vtkGhostType"] + blankedIndex = ghostArray == 32 + dsmData[blankedIndex] = dtmData[blankedIndex] + output.VTKObject.GetCellData().RemoveArray("vtkGhostType") + output.CellData.append(dsmData, "Elevation") + return 1 + + +def main(args): + parser = argparse.ArgumentParser( + description='Render a shadow mask from a sun position (stored in the input_image), ' + 'a DSM and an optional DTM') + parser.add_argument("input_image", help="Source image with sun position") + parser.add_argument("dsm", help="Digital surface model (DSM) image") + parser.add_argument("output_image", help="Image with shadow mask") + parser.add_argument("--dtm", type=str, + help="Optional DTM parameter used to fill nodata areas " + "in the dsm") + parser.add_argument("--render_png", action="store_true", + help="Do not save shadow mask, render a PNG instead.") + parser.add_argument("--debug", action="store_true", + help="Save intermediate results") + args = parser.parse_args(args) + + sourceImage = gdal_utils.gdal_open(args.input_image) + metaData = sourceImage.GetMetadata() + azimuth = float(metaData["NITF_CSEXRA_SUN_AZIMUTH"]) + elevation = float(metaData["NITF_CSEXRA_SUN_ELEVATION"]) + print("azimuth = {}, elevation = {}".format(azimuth, elevation)) + sourceImage = None + + dsm = vtk.vtkGDALRasterReader() + dsm.SetFileName(args.dsm) + + fillNoData = FillNoData() + fillNoData.SetInputConnection(0, dsm.GetOutputPort()) + if (args.dtm): + dtm = vtk.vtkGDALRasterReader() + dtm.SetFileName(args.dtm) + fillNoData.SetInputConnection(1, dtm.GetOutputPort()) + + cellToPoint = vtk.vtkCellDataToPointData() + cellToPoint.SetInputConnection(fillNoData.GetOutputPort()) + + warp = vtk.vtkWarpScalar() + warp.SetInputConnection(cellToPoint.GetOutputPort()) + warp.Update() + warpOutput = warp.GetOutput() + + scalarRange = warpOutput.GetPointData().GetScalars().GetRange() + warpBounds = warpOutput.GetBounds() + dims = warpOutput.GetDimensions() + + if (args.debug): + writerVts = vtk.vtkXMLStructuredGridWriter() + writerVts.SetFileName("warp.vts") + writerVts.SetInputDataObject(warpOutput) + writerVts.Write() + + print("warpBounds: {}".format(warpBounds)) + print("scalarRange: {}".format(scalarRange)) + print("dims: {}".format(dims)) + + # vtkValuePass works only with vtkPolyData + warpSurface = vtk.vtkDataSetSurfaceFilter() + warpSurface.SetInputConnection(warp.GetOutputPort()) + warpSurface.Update() + + ren = vtk.vtkRenderer() + renWin = vtk.vtkRenderWindow() + # VTK specifies dimensions in points, GDAL in cells + renWin.SetSize(dims[0] - 1, dims[1] - 1) + renWin.AddRenderer(ren) + + warpMapper = vtk.vtkPolyDataMapper() + warpMapper.SetInputConnection(warpSurface.GetOutputPort()) + warpActor = vtk.vtkActor() + warpActor.SetMapper(warpMapper) + ren.AddActor(warpActor) + + camera = ren.GetActiveCamera() + camera.SetViewUp(0, 1, 0) + camera.ParallelProjectionOn() + camera.Roll(azimuth - 180) + camera.Elevation(elevation - 90) + ren.ResetCamera() + + if (args.render_png): + print("Render into a PNG ...") + lut = vtk.vtkColorTransferFunction() + lut.AddRGBPoint(scalarRange[0], 0.23, 0.30, 0.75) + lut.AddRGBPoint((scalarRange[0] + scalarRange[1]) / 2, 0.86, 0.86, 0.86) + lut.AddRGBPoint(scalarRange[1], 0.70, 0.02, 0.15) + warpMapper.SetLookupTable(lut) + warpMapper.SetColorModeToMapScalars() + + iren = vtk.vtkRenderWindowInteractor() + iren.SetRenderWindow(renWin) + iren.Initialize() + renWin.Render() + + windowToImageFilter = vtk.vtkWindowToImageFilter() + windowToImageFilter.SetInput(renWin) + windowToImageFilter.SetInputBufferTypeToRGBA() + windowToImageFilter.ReadFrontBufferOff() + windowToImageFilter.Update() + + writerPng = vtk.vtkPNGWriter() + writerPng.SetFileName(args.output_image + ".png") + writerPng.SetInputConnection(windowToImageFilter.GetOutputPort()) + writerPng.Write() + + iren.Start() + else: + print("Render into a floating point buffer ...") + renWin.OffScreenRenderingOn() + arrayName = "Elevation" + valuePass = vtk.vtkValuePass() + valuePass.SetRenderingMode(vtk.vtkValuePass.FLOATING_POINT) + # use the default scalar for point data + valuePass.SetInputComponentToProcess(0) + valuePass.SetInputArrayToProcess(vtk.VTK_SCALAR_MODE_USE_POINT_FIELD_DATA, + arrayName) + passes = vtk.vtkRenderPassCollection() + passes.AddItem(valuePass) + sequence = vtk.vtkSequencePass() + sequence.SetPasses(passes) + cameraPass = vtk.vtkCameraPass() + cameraPass.SetDelegatePass(sequence) + ren.SetPass(cameraPass) + renWin.Render() + elevationFlatVtk = valuePass.GetFloatImageDataArray(ren) + elevationFlatVtk.SetName("Elevation") + valuePass.ReleaseGraphicsResources(renWin) + + # GDAL dataset + heightMapFileName = "_temp_heightMap.tif" + elevation = gdal_utils.vtk_to_numpy_order(elevationFlatVtk, + [dims[0] - 1, dims[1] - 1]) + driver = gdal.GetDriverByName("GTiff") + heightMap = driver.Create( + heightMapFileName, xsize=dims[0] - 1, ysize=dims[1] - 1, + bands=1, eType=gdal.GDT_Float32, options=["COMPRESS=DEFLATE"]) + heightMap.GetRasterBand(1).WriteArray(elevation) + hasNoData = 0 + nodata = dsm.GetInvalidValue(0, hasNoData) + if (not args.dtm and hasNoData): + heightMap.GetRasterBand(1).SetNoDataValue(nodata) + corners = [[0, 0], [dims[0] - 1, 0], + [dims[0] - 1, dims[1] - 1], + [0, dims[1] - 1]] + gcps = [] + for corner in corners: + worldPoint = [0.0, 0.0, 0.0] + warpOutput.GetPoint(corner[0], corner[1], 0, worldPoint) + ren.SetWorldPoint(worldPoint[0], worldPoint[1], worldPoint[2], 1.0) + ren.WorldToView() + viewPoint = numpy.array(ren.GetViewPoint()) + pixelCoord = (viewPoint + 1.0) / 2.0 * numpy.array(dims) + pixelCoord[1] = dims[1] - pixelCoord[1] + gcp = gdal.GCP(worldPoint[0], worldPoint[1], worldPoint[2], + pixelCoord[0], pixelCoord[1]) + gcps.append(gcp) + wkt = dsm.GetProjectionWKT() + heightMap.SetGCPs(gcps, wkt) + heightMap = None + + if (args.debug): + # VTK dataset + heightMapVtk = vtk.vtkImageData() + heightMapVtk.SetDimensions(dims[0], dims[1], 1) + heightMapVtk.SetOrigin(dsm.GetOutput().GetOrigin()) + heightMapVtk.SetSpacing(dsm.GetOutput().GetSpacing()) + heightMapVtk.GetCellData().SetScalars(elevationFlatVtk) + + writerVti = vtk.vtkXMLImageDataWriter() + writerVti.SetFileName("heightMap.vti") + writerVti.SetInputDataObject(heightMapVtk) + writerVti.Write() + + # create the parallel projection model + mat = [ren.GetActiveCamera().GetCompositeProjectionTransformMatrix( + ren.GetTiledAspectRatio(), 0, 1).GetElement(i, j) + for j in range(0, 4) + for i in range(0, 4)] + model = ParallelProjectionModel(mat, numpy.array(dims[0:2])) + ortho.orthorectify(heightMapFileName, args.dsm, args.output_image, + 1.0, 2.0, model, args.dtm, False) + if (not args.debug): + os.remove(heightMapFileName) + + +if __name__ == '__main__': + import sys + try: + main(sys.argv[1:]) + except Exception as e: + logging.exception(e) + sys.exit(1)