Source code for dea_tools.app.imageexport

# -*- coding: utf-8 -*-
"""
Image export widget, which can be used to interactively select and 
export satellite imagery from multiple DEA products.
"""

# Import required packages
import fiona
import sys
import datacube
import warnings
import matplotlib.pyplot as plt
from datacube.utils.geometry import CRS
from ipyleaflet import (
    WMSLayer,
    basemaps,
    basemap_to_tiles,
    Map,
    DrawControl,
    WidgetControl,
    SearchControl,
    Marker,
    LayerGroup,
    LayersControl,
    GeoData,
)
from traitlets import Unicode
from ipywidgets import (
    GridspecLayout,
    Button,
    Layout,
    HBox,
    VBox,
    HTML,
    Output,
)
import json
import itertools
import numpy as np
import geopandas as gpd
from io import BytesIO
import ipywidgets as widgets
import datetime
from skimage import exposure
from skimage.filters import unsharp_mask

from datacube.utils import masking
from datacube.utils.geometry import Geometry
import dea_tools.app.widgetconstructors as deawidgets
from dea_tools.dask import create_local_dask_cluster
from dea_tools.spatial import reverse_geocode
from dea_tools.datahandling import xr_pansharpen


# WMS params and satellite style bands
# Per-collection configuration, keyed by the DEA WMS layer / collection name.
#   "products": individual products searched when loading data
#   "styles":   user-facing style name -> (WMS style name, [red, green, blue] bands)
sat_params = {
    # Landsat collection
    "ga_ls_ard_3": {
        "products": [
            "ga_ls5t_ard_3",
            "ga_ls7e_ard_3",
            "ga_ls8c_ard_3",
            "ga_ls9c_ard_3",
        ],
        "styles": {
            "True colour": (
                "true_colour",
                ["nbart_red", "nbart_green", "nbart_blue"],
            ),
            "False colour": (
                "false_colour",
                ["nbart_swir_1", "nbart_nir", "nbart_green"],
            ),
        },
    },
    # Sentinel-2 collection
    "ga_s2m_ard_3": {
        "products": [
            "ga_s2am_ard_3",
            "ga_s2bm_ard_3",
        ],
        "styles": {
            "True colour": (
                "simple_rgb",
                ["nbart_red", "nbart_green", "nbart_blue"],
            ),
            "False colour": (
                "infrared_green",
                ["nbart_swir_2", "nbart_nir_1", "nbart_green"],
            ),
        },
    },
}


def make_box_layout():
    """Return a fresh ``Layout`` used for the app's boxed panels.

    The layout fills its container (100% width and height), with a
    10 px right/bottom margin and 5 px padding on every side.
    """
    box_style = dict(
        margin="0px 10px 10px 0px",
        padding="5px 5px 5px 5px",
        width="100%",
        height="100%",
    )
    return Layout(**box_style)


def create_expanded_button(description, button_style):
    """Create a ``Button`` whose width and height stretch to fill its cell.

    Parameters
    ----------
    description : str
        Text shown on the button.
    button_style : str
        ipywidgets button style name (e.g. ``"info"``).
    """
    auto_size = Layout(width="auto", height="auto")
    return Button(
        description=description, button_style=button_style, layout=auto_size
    )


def update_map_layers(self):
    """
    Updates map to add new DEA layers, styles or basemap when selected
    using menu options. Triggers data reload by resetting load params
    and output arrays.

    Parameters
    ----------
    self : imageexport_app
        App instance. Reads ``basemap``, ``dealayer``, ``style`` and
        ``date``; redraws ``map_layers`` and clears the cached
        ``rgb_array``, ``sensor``, ``load_params`` and ``query_params``.
    """
    # Clear data load params to trigger data re-load
    self.rgb_array = None
    self.sensor = None
    self.load_params = None
    self.query_params = None

    # Clear all layers and add basemap
    self.map_layers.clear_layers()
    self.map_layers.add_layer(self.basemap)

    # Get WMS style name for the selected satellite collection and style
    style = sat_params[self.dealayer]["styles"][self.style][0]

    # Add DEA layers over the top of the basemap
    dea_layer = deawidgets.create_dea_wms_layer(self.dealayer, self.date, styles=style)
    self.map_layers.add_layer(dea_layer)
def extract_data(self):
    """
    Load imagery for the drawn rectangle and selected date.

    Searches every product of the selected collection
    (``sat_params[self.dealayer]["products"]``) within the drawn polygon,
    loads the bands for the selected style (pansharpening Landsat 7/8/9
    true colour imagery if requested), and returns the first timestep.

    Side effects: sets ``self.query_params``, ``self.sensor`` and
    ``self.load_params``. If pansharpening was requested but is not
    possible, unticks the pansharpen checkbox and resets the resolution
    widget to 30 m *before* loading, so that the loaded data matches the
    resolution later written into the output file name.

    Returns
    -------
    numpy.ndarray or None
        Band-last image array, or ``None`` if no datasets were found.
    """
    # Connect to datacube database
    dc = datacube.Datacube(app="Exporting satellite images")

    # Configure local dask cluster; always closed in the ``finally`` below
    client = create_local_dask_cluster(return_client=True, display_client=True)

    try:
        # Convert to geopolygon
        geopolygon = Geometry(geom=self.gdf_drawn.geometry[0], crs=self.gdf_drawn.crs)

        # Create query after adjusting interval time to UTC by adding a UTC
        # offset of -10 hours. This avoids issues on the east coast of
        # Australia where satellite overpasses can occur on either side of
        # 24:00 hours UTC
        start_date = np.datetime64(self.date) - np.timedelta64(10, "h")
        end_date = np.datetime64(self.date) + np.timedelta64(14, "h")
        query_params = {
            "time": (str(start_date), str(end_date)),
            "geopolygon": geopolygon,
        }
        self.query_params = query_params

        # Find matching datasets across every product in the collection
        dss = list(
            itertools.chain.from_iterable(
                dc.find_datasets(product=product, **query_params)
                for product in sat_params[self.dealayer]["products"]
            )
        )

        # If no data is returned, return None
        if not dss:
            return None

        # Get CRS
        crs = str(dss[0].crs)

        # Get sensor (try/except to account for different S2 NRT metadata)
        metadata = dss[0].metadata_doc
        try:
            sensor = metadata["properties"]["eo:platform"].capitalize()
        except KeyError:
            sensor = metadata["platform"]["code"].capitalize()
        sensor = sensor[0:-1].replace("_", "-") + sensor[-1].capitalize()
        self.sensor = sensor

        # Meets pansharpening requirements
        can_pansharpen = self.style == "True colour" and sensor in [
            "Landsat-7",
            "Landsat-8",
            "Landsat-9",
        ]
        pansharpen = self.pansharpen and can_pansharpen

        # If pansharpening is requested but not possible, deactivate
        # pansharpening and reset to 30 m resolution before loading, so the
        # data is loaded at the resolution reported in the file name
        if self.pansharpen and not can_pansharpen:
            print("\nUnable to pansharpen; reverting to 30 m resolution")
            self.checkbox_pansharpen.value = False
            self.text_resolution.disabled = False
            self.text_resolution.value = 30

            # The widget observers triggered above clear cached load state
            # (including sensor and query params); restore what was computed
            self.sensor = sensor
            self.query_params = query_params

        # Set up load params
        measurements = sat_params[self.dealayer]["styles"][self.style][1]
        if pansharpen:
            load_params = {
                "measurements": measurements + ["nbart_panchromatic"],
                "resolution": (-self.resolution, self.resolution),
                "align": (7.5, 7.5),
                "output_crs": crs,
            }
        else:
            load_params = {
                "measurements": measurements,
                "resolution": (-self.resolution, self.resolution),
                "output_crs": crs,
                "skip_broken_datasets": True,
            }
        self.load_params = load_params

        # Load data from datasets
        print(f"Loading {sensor} satellite data")
        ds = dc.load(
            datasets=dss,
            resampling="bilinear",
            group_by="solar_day",
            dask_chunks={"time": 1, "x": 2048, "y": 2048},
            **load_params,
            **query_params,
        )
        ds = masking.mask_invalid_data(ds)

        # Create plain numpy array, optionally after pansharpening
        if pansharpen:
            # Perform Brovey pan-sharpening and return numpy.array
            print(f"Pansharpening {sensor} image to 15 m resolution")
            rgb_array = (
                xr_pansharpen(ds, transform="brovey").to_array().squeeze("time").values
            )
        else:
            rgb_array = ds.isel(time=0).to_array().values

        # Move the band axis from first to last for plotting
        return np.transpose(rgb_array, axes=[1, 2, 0])

    finally:
        # Close down the dask client, even if loading failed
        client.close()


def plot_data(self, fname):
    """
    Stretch, optionally enhance, preview and export ``self.rgb_array``.

    Applies either a percentile stretch (``self.percentile_stretch``) or
    fixed ``self.vmin``/``self.vmax`` limits, an optional power transform
    (``self.power < 1``) and optional unsharp masking, then writes the image
    to disk and shows a preview.

    Parameters
    ----------
    fname : str
        Output file name. When a minimum DPI is set and the image is
        narrower than ``dpi * 10`` pixels, the figure is saved at that DPI
        and ", {dpi} DPI" is inserted after "resolution" in the name;
        otherwise the raw array is saved at native resolution.
    """
    # Data to plot
    to_plot = self.rgb_array

    # If percentile stretch is supplied, calculate vmin and vmax
    # from percentiles
    if self.percentile_stretch:
        vmin, vmax = np.nanpercentile(to_plot, self.percentile_stretch)
    else:
        vmin, vmax = self.vmin, self.vmax

    # Raise by power to dampen bright features and enhance dark.
    # Raise vmin and vmax by same amount to ensure proper stretch
    if self.power < 1.0:
        with self.status_info:
            print(f"\nApplying power transformation ({self.power})")
        to_plot = to_plot ** self.power
        vmin, vmax = vmin ** self.power, vmax ** self.power

    # Rescale/stretch imagery between vmin and vmax
    to_plot = exposure.rescale_intensity(
        to_plot.astype(float), in_range=(vmin, vmax), out_range=(0.0, 1.0)
    )

    # Unsharp mask
    if self.unsharp_mask:
        with self.status_info:
            print(
                f"\nApplying unsharp masking with {self.unsharp_mask_radius} "
                f"radius and {self.unsharp_mask_amount} amount"
            )
        to_plot = unsharp_mask(
            to_plot, radius=self.unsharp_mask_radius, amount=self.unsharp_mask_amount
        )

    # Create figure with aspect ratio of data
    fig = plt.figure(dpi=100)
    fig.set_size_inches(10, 10 / (to_plot.shape[1] / to_plot.shape[0]))

    # Remove axes to plot just array data
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)

    # Add data to plot
    ax.imshow(to_plot)

    # If a min DPI is specified and image is less than DPI
    if (self.dpi > 0) and (to_plot.shape[1] < self.dpi * 10):
        # Export figure to file using exact DPI; record the DPI in the file
        # name so the success message below reports the real output path
        fname = fname.replace("resolution", f"resolution, {self.dpi} DPI")
        with self.status_info:
            print(f"\nExporting image at {self.dpi} DPI")
        fig.savefig(fname, dpi=self.dpi)

    # If no minimum DPI is specified, export raw array data in native
    # resolution
    else:
        plt.imsave(
            fname=fname, arr=np.ascontiguousarray(to_plot), format=self.output_format
        )

    # Add plot preview below map and finish
    plt.show()
    with self.status_info:
        print(f"\nImage successfully exported to:\n{fname}.")
class imageexport_app(HBox):
    """
    Interactive app for selecting and exporting DEA satellite imagery.

    Shows a map with DEA WMS imagery for the selected date, satellite
    collection and style. The user draws a rectangle on the map and presses
    "Export imagery" to load the matching data, render it with the selected
    colour stretch and enhancements, and save it to an image file named
    after the sensor, date, location, style and resolution.
    """

    def __init__(self):
        super().__init__()

        ######################
        # INITIAL ATTRIBUTES #
        ######################

        # Basemap
        self.basemap_list = [
            ("Open Street Map", basemap_to_tiles(basemaps.OpenStreetMap.Mapnik)),
            ("ESRI World Imagery", basemap_to_tiles(basemaps.Esri.WorldImagery)),
        ]
        self.basemap = self.basemap_list[0][1]

        # Satellite data using yesterday's date (at midnight). Subtracting a
        # timedelta handles month/year boundaries; decrementing the day
        # number directly raised ValueError on the first of every month
        date = datetime.datetime.combine(
            datetime.date.today() - datetime.timedelta(days=1), datetime.time()
        )
        self.date = date.strftime("%Y-%m-%d")
        self.dealayer_list = [
            ("Landsat", "ga_ls_ard_3"),
            ("Sentinel-2", "ga_s2m_ard_3"),
        ]
        self.dealayer = self.dealayer_list[0][1]

        # Styles
        self.styles_list = ["True colour", "False colour"]
        self.style = self.styles_list[0]

        # Analysis params
        self.resolution = 30
        self.pansharpen = False
        self.standardise_name = False
        self.vmin = 50
        self.vmax = 3000
        self.percentile_stretch = None  # e.g. (1, 99)
        self.power = 1.0
        self.output_list = [("JPG", "jpg"), ("PNG", "png")]
        self.output_format = self.output_list[0][1]
        self.unsharp_mask = False
        self.unsharp_mask_radius = 20
        self.unsharp_mask_amount = 0.3
        self.max_size = False
        self.dpi = 0

        # Drawing params
        self.target = None
        self.action = None
        self.gdf_drawn = None

        # Data load params
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

        ##################
        # HEADER FOR APP #
        ##################

        # Create the Header widget
        header_title_text = "<h3>Digital Earth Australia satellite image export</h3>"
        instruction_text = (
            "<p>Select the desired satellite data, imagery "
            "date and image style, zoom in until satellite "
            "imagery appears on the map, then draw a "
            "rectangle to select an area of imagery to "
            "export as a high-resolution image file.</p>"
        )
        self.header = deawidgets.create_html(f"{header_title_text}{instruction_text}")
        self.header.layout = make_box_layout()

        #####################################
        # HANDLER FUNCTION FOR DRAW CONTROL #
        #####################################

        # Define the action to take once something is drawn on the map
        def update_geojson(target, action, geo_json):
            # Get data from action
            self.action = action

            # Clear data load params to trigger data re-load
            self.rgb_array = None
            self.sensor = None
            self.load_params = None
            self.query_params = None

            # Convert data to geopandas
            json_data = json.dumps(geo_json)
            binary_data = json_data.encode()
            io = BytesIO(binary_data)
            io.seek(0)
            gdf = gpd.read_file(io)
            gdf.crs = "EPSG:4326"

            # Convert to Albers and compute area
            gdf_drawn_albers = gdf.copy().to_crs("EPSG:3577")
            m2_per_km2 = 10 ** 6
            area = gdf_drawn_albers.area.values[0] / m2_per_km2
            polyarea_label = "Total area of satellite data to extract"
            polyarea_text = f"<b>{polyarea_label}</b>: {area:.2f} km<sup>2</sup>"

            # Test area size
            if self.max_size:
                confirmation_text = (
                    '<span style="color: #33cc33"> '
                    "<b>(Overriding maximum size limit; use with caution as may lead to memory issues)</b></span>"
                )
                self.header.value = (
                    header_title_text
                    + instruction_text
                    + polyarea_text
                    + confirmation_text
                )
                self.gdf_drawn = gdf
            elif area <= 10000:
                confirmation_text = (
                    '<span style="color: #33cc33"> '
                    "<b>(Area to extract falls within "
                    "recommended limit)</b></span>"
                )
                self.header.value = (
                    header_title_text
                    + instruction_text
                    + polyarea_text
                    + confirmation_text
                )
                self.gdf_drawn = gdf
            else:
                # Note the closing </sup> tag: previously missing, which
                # left the rest of the header rendered as superscript
                warning_text = (
                    '<span style="color: #ff5050"> '
                    "<b>(Area to extract is too large, "
                    "please select an area less than 10000 "
                    "km<sup>2</sup>)</b></span>"
                )
                self.header.value = (
                    header_title_text
                    + instruction_text
                    + polyarea_text
                    + warning_text
                )
                self.gdf_drawn = None

        ###########################
        # WIDGETS FOR APP OUTPUTS #
        ###########################

        self.status_info = Output(layout=make_box_layout())
        self.output_plot = Output(layout=make_box_layout())

        #########################################
        # MAP WIDGET, DRAWING TOOLS, WMS LAYERS #
        #########################################

        # Create drawing tools
        desired_drawtools = ["rectangle"]
        draw_control = deawidgets.create_drawcontrol(desired_drawtools)

        # Begin by displaying an empty layer group, and update the group
        # with desired WMS on interaction.
        self.map_layers = LayerGroup(layers=())
        self.map_layers.name = "Map Overlays"

        # Create map widget
        self.m = deawidgets.create_map(map_center=(-28, 135), zoom_level=4)
        self.m.layout = make_box_layout()

        # Add tools to map widget
        self.m.add_control(draw_control)
        self.m.add_control(
            SearchControl(
                position="topleft",
                url="https://nominatim.openstreetmap.org/search?format=json&q={s}",
                zoom=13,  # 'Village / Suburb' level zoom
                marker=Marker(draggable=False),
            )
        )
        self.m.add_layer(self.map_layers)

        # Update all maps to starting defaults
        update_map_layers(self)

        ############################
        # WIDGETS FOR APP CONTROLS #
        ############################

        # Create parameter widgets
        dropdown_basemap = deawidgets.create_dropdown(
            self.basemap_list, self.basemap_list[0][1]
        )
        dropdown_dealayer = deawidgets.create_dropdown(
            self.dealayer_list, self.dealayer_list[0][1]
        )
        dropdown_output = deawidgets.create_dropdown(
            self.output_list, self.output_list[0][1]
        )
        date_picker = deawidgets.create_datepicker(value=date)
        dropdown_styles = deawidgets.create_dropdown(
            self.styles_list, self.styles_list[0]
        )
        slider_abs = widgets.IntRangeSlider(
            value=[50, 3000],
            min=0,
            max=10000,
            step=25,
            description="",
            layout={"width": "85%"},
        )
        run_button = create_expanded_button("Export imagery", "info")

        # Expandable advanced section
        text_resolution = widgets.FloatText(
            value=30,
            description="",
            layout={"width": "100%", "margin": "0px", "padding": "0px"},
        )
        checkbox_pansharpen = deawidgets.create_checkbox(
            self.pansharpen, "Pansharpen Landsat"
        )
        slider_power = widgets.FloatSlider(
            value=1.0,
            min=0.01,
            max=1.0,
            step=0.01,
            description="",
            layout={"width": "85%"},
        )
        checkbox_unsharp_mask = deawidgets.create_checkbox(
            self.unsharp_mask, "Enable", layout={"width": "100%"}
        )
        text_unsharp_mask_radius = widgets.FloatText(
            value=20,
            step=1,
            description="Radius",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        text_unsharp_mask_amount = widgets.FloatText(
            value=0.3,
            step=0.1,
            description="Amount",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        checkbox_max_size = deawidgets.create_checkbox(self.max_size, "Enable")
        text_dpi = widgets.IntText(
            value=0, description="", step=50, layout={"width": "85%"}
        )
        html_dpi = HTML(
            "</br>Minimum DPI for image export</br>(100 DPI = 1000 pixels wide):"
        )
        expand_box = widgets.VBox(
            [
                HTML("Resolution (metres):"),
                text_resolution,
                checkbox_pansharpen,
                HTML("</br>Apply power transformation to darken bright features:"),
                slider_power,
                HTML("</br>Apply unsharp masking to sharpen image:"),
                checkbox_unsharp_mask,
                text_unsharp_mask_radius,
                text_unsharp_mask_amount,
                HTML(
                    "</br>Override maximum size limit: (use with caution; may cause memory issues/crashes)"
                ),
                checkbox_max_size,
                html_dpi,
                text_dpi,
            ],
            layout={"overflow": "hidden"},
        )
        expand = widgets.Accordion(children=[expand_box], selected_index=None)
        expand.set_title(0, "Advanced")

        # Add specific dialogs to class so they can be modified
        self.text_resolution = text_resolution
        self.checkbox_pansharpen = checkbox_pansharpen
        self.text_unsharp_mask_radius = text_unsharp_mask_radius
        self.text_unsharp_mask_amount = text_unsharp_mask_amount
        self.html_dpi = html_dpi

        ####################################
        # UPDATE FUNCTIONS FOR EACH WIDGET #
        ####################################

        # Run update functions whenever various widgets are changed.
        date_picker.observe(self.update_date, "value")
        dropdown_basemap.observe(self.update_basemap, "value")
        dropdown_dealayer.observe(self.update_dealayer, "value")
        dropdown_styles.observe(self.update_styles, "value")
        dropdown_output.observe(self.update_output, "value")
        run_button.on_click(self.run_app)
        draw_control.on_draw(update_geojson)
        slider_abs.observe(self.update_slider_abs, "value")

        # Advanced params
        text_resolution.observe(self.update_text_resolution, "value")
        checkbox_pansharpen.observe(self.update_checkbox_pansharpen, "value")
        slider_power.observe(self.update_slider_power, "value")
        checkbox_unsharp_mask.observe(self.update_checkbox_unsharp_mask, "value")
        text_unsharp_mask_radius.observe(self.update_text_unsharp_mask_radius, "value")
        text_unsharp_mask_amount.observe(self.update_text_unsharp_mask_amount, "value")
        checkbox_max_size.observe(self.update_checkbox_max_size, "value")
        text_dpi.observe(self.update_dpi, "value")

        ##################################
        # COLLECTION OF ALL APP CONTROLS #
        ##################################

        parameter_selection = VBox(
            [
                HTML("<b>Date:</b>"),
                date_picker,
                HTML("<b>Satellite imagery:</b>"),
                dropdown_dealayer,
                HTML("<b>Style:</b>"),
                dropdown_styles,
                HTML("<b>Colour stretch:</b>"),
                slider_abs,
                HTML("<b>Output file format:</b>"),
                dropdown_output,
                HTML("</br>"),
                expand,
            ]
        )
        map_selection = VBox(
            [
                HTML("</br><b>Map overlay:</b>"),
                dropdown_basemap,
            ]
        )
        parameter_selection.layout = make_box_layout()
        map_selection.layout = make_box_layout()

        ###############################
        # SPECIFICATION OF APP LAYOUT #
        ###############################

        #      0  1  2  3  4  5  6  7  8  9
        #    ---------------------------------------------
        # 0  | Header                         | Map sel. |
        #    |-------------------------------------------|
        # 1  | Params |                                  |
        # 2  |        |                                  |
        # 3  |        |                                  |
        # 4  |        |               Map                |
        # 5  |        |                                  |
        #    |--------|                                  |
        # 6  |  Run   |                                  |
        #    |-------------------------------------------|
        # 7  |   Status info     |    Figure/output      |
        # 8  |                   |                       |
        # 9  |                   |                       |
        # 10 |                   |                       |
        # 11 ---------------------------------------------

        # Create the layout #[rowspan, colspan]
        grid = GridspecLayout(12, 10, height="1400px", width="auto")

        # Header and controls
        grid[0, :8] = self.header
        grid[0, 8:] = map_selection
        grid[1:6, 0:2] = parameter_selection
        grid[6, 0:2] = run_button

        # Status info, map and plot
        grid[1:7, 2:] = self.m  # map
        grid[7:, 0:4] = self.status_info
        grid[7:, 4:] = self.output_plot

        # Display using HBox children attribute
        self.children = [grid]

    ######################################
    # DEFINITION OF ALL UPDATE FUNCTIONS #
    ######################################

    # Update date
    def update_date(self, change):
        # A cleared date picker reports None; keep the previous date rather
        # than storing the string "None", which cannot be used as a date
        if change.new is None:
            return
        self.date = str(change.new)
        update_map_layers(self)

    # Update colour stretch
    def update_slider_abs(self, change):
        self.vmin, self.vmax = change.new

    # Update power transform
    def update_slider_power(self, change):
        self.power = change.new

    # Enable pansharpening and reset/deactivate resolution
    def update_checkbox_pansharpen(self, change):
        self.pansharpen = change.new

        # Override default resolution if pansharpening is specified;
        # disable input if so
        if change.new:
            self.text_resolution.value = 15
            self.text_resolution.disabled = True
        else:
            self.text_resolution.value = 30
            self.text_resolution.disabled = False

    # Enable unsharp masking and show/hide custom params
    def update_checkbox_unsharp_mask(self, change):
        self.unsharp_mask = change.new

        # Show unsharp masking params in menu if activated
        if change.new:
            self.text_unsharp_mask_radius.layout.display = "block"
            self.text_unsharp_mask_amount.layout.display = "block"
        else:
            self.text_unsharp_mask_radius.layout.display = "none"
            self.text_unsharp_mask_amount.layout.display = "none"

    # Change unsharp masking radius
    def update_text_unsharp_mask_radius(self, change):
        self.unsharp_mask_radius = change.new

    # Change unsharp masking amount
    def update_text_unsharp_mask_amount(self, change):
        self.unsharp_mask_amount = change.new

    # Override max size limit
    def update_checkbox_max_size(self, change):
        self.max_size = change.new

    # Override min DPI
    def update_dpi(self, change):
        self.dpi = change.new

        # Update DPI helper text to give output resolution
        self.html_dpi.value = (
            f"</br>Minimum DPI for image export</br>"
            f"({change.new} DPI = {change.new * 10} "
            f"pixels wide):"
        )

    # Update resolution
    def update_text_resolution(self, change):
        self.resolution = change.new

        # Clear data load params to trigger data re-load
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

    # Change layers shown on the map
    def update_dealayer(self, change):
        self.dealayer = change.new

        if change.new == "ga_ls_ard_3":
            self.text_resolution.value = 30
            self.checkbox_pansharpen.disabled = False
            if self.pansharpen:
                self.text_resolution.value = 15
                self.text_resolution.disabled = True
        else:
            self.text_resolution.value = 10
            self.checkbox_pansharpen.disabled = True
            self.text_resolution.disabled = False

        update_map_layers(self)

    # Update basemap
    def update_basemap(self, change):
        self.basemap = change.new
        update_map_layers(self)

    # Set imagery style
    def update_styles(self, change):
        self.style = change.new
        update_map_layers(self)

    # Set output file format
    def update_output(self, change):
        self.output_format = change.new

    def run_app(self, change):
        """Load (or reuse) imagery for the drawn rectangle and export it."""
        # Clear progress bar and output areas before running
        self.status_info.clear_output()
        self.output_plot.clear_output()

        # Verify that polygon was drawn
        if self.gdf_drawn is not None:
            with self.status_info:
                # Load data and add to attribute
                if self.rgb_array is None:
                    self.rgb_array = extract_data(self)
                else:
                    print("Using previously loaded data")

            if self.rgb_array is not None:
                with self.status_info:
                    # Create unique file name
                    centre_coords = self.gdf_drawn.geometry[0].centroid.coords[0][::-1]
                    site = reverse_geocode(coords=centre_coords)
                    fname = (
                        f"{self.sensor} - {self.date} - {site} - {self.style}, "
                        f"{self.resolution:.0f} m resolution.{self.output_format}"
                    )

                    # Remove spaces and commas if requested
                    if self.standardise_name:
                        fname = (
                            fname.replace(" - ", "_")
                            .replace(", ", "-")
                            .replace(" ", "-")
                            .lower()
                        )

                    print(
                        f"\nExporting image for {site}.\nThis may take several minutes..."
                    )

                ############
                # Plotting #
                ############

                with self.output_plot:
                    plot_data(self, fname)

            else:
                with self.status_info:
                    print(
                        "No satellite data found in the selected area. "
                        "Please select a new rectangle over an area with "
                        "satellite imagery visible on the map."
                    )
        else:
            with self.status_info:
                print(
                    'Please draw a valid rectangle on the map, then press "Export imagery"'
                )