# Source code for dea_tools.app.imageexport

# -*- coding: utf-8 -*-
"""
Image export widget, which can be used to interactively select and
export satellite imagery from multiple DEA products.
"""

# Import required packages
import datetime
import itertools
import json
from io import BytesIO

import datacube
import geopandas as gpd
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from datacube.utils import masking
from datacube.utils.geometry import Geometry
from ipyleaflet import (
    LayerGroup,
    Marker,
    SearchControl,
    basemap_to_tiles,
    basemaps,
)
from ipywidgets import (
    HTML,
    Button,
    GridspecLayout,
    HBox,
    Layout,
    Output,
    VBox,
)
from skimage import exposure
from skimage.filters import unsharp_mask

from dea_tools.app.widgetconstructors import (
    create_checkbox,
    create_datepicker,
    create_dea_wms_layer,
    create_drawcontrol,
    create_dropdown,
    create_html,
    create_map,
)
from dea_tools.dask import create_local_dask_cluster
from dea_tools.datahandling import xr_pansharpen
from dea_tools.spatial import reverse_geocode

# WMS params and satellite style bands
sat_params = {
    "ga_ls_ard_3": {
        "products": [
            "ga_ls5t_ard_3",
            "ga_ls7e_ard_3",
            "ga_ls8c_ard_3",
            "ga_ls9c_ard_3",
        ],
        "styles": {
            "True colour": ("true_colour", ["nbart_red", "nbart_green", "nbart_blue"]),
            "False colour": (
                "false_colour",
                ["nbart_swir_1", "nbart_nir", "nbart_green"],
            ),
        },
    },
    "ga_s2m_ard_3": {
        "products": ["ga_s2am_ard_3", "ga_s2bm_ard_3", "ga_s2cm_ard_3"],
        "styles": {
            "True colour": ("simple_rgb", ["nbart_red", "nbart_green", "nbart_blue"]),
            "False colour": (
                "infrared_green",
                ["nbart_swir_2", "nbart_nir_1", "nbart_green"],
            ),
        },
    },
}


def make_box_layout():
    """
    Return the ``Layout`` applied to each panel of the app (header, map,
    output areas and control boxes): a small outer margin and padding,
    filling the full width and height of its grid cell.
    """
    box_spacing = dict(
        margin="0px 10px 10px 0px",
        padding="5px 5px 5px 5px",
        width="100%",
        height="100%",
    )
    return Layout(**box_spacing)


def create_expanded_button(description, button_style):
    """
    Create a ``Button`` whose layout auto-sizes so it expands to fill
    the grid cell it is placed in.

    Parameters
    ----------
    description : str
        Text shown on the button.
    button_style : str
        ipywidgets button style name (e.g. ``"info"``).
    """
    auto_layout = Layout(width="auto", height="auto")
    return Button(description=description, button_style=button_style, layout=auto_layout)


def update_map_layers(self):
    """
    Updates map to add new DEA layers, styles or basemap when selected
    using menu options. Triggers data reload by resetting load params and
    output arrays.

    Parameters
    ----------
    self : imageexport_app
        The app instance. Reads ``basemap``, ``dealayer``, ``date`` and
        ``style``; redraws ``map_layers`` and clears the cached data.
    """
    # Clear data load params to trigger data re-load
    self.rgb_array = None
    self.sensor = None
    self.load_params = None
    self.query_params = None

    # Clear all layers and add basemap
    self.map_layers.clear_layers()
    self.map_layers.add_layer(self.basemap)

    # Get WMS style name for the selected layer and style label
    style = sat_params[self.dealayer]["styles"][self.style][0]

    # Add DEA layers over the top of the basemap
    dea_layer = create_dea_wms_layer(self.dealayer, self.date, styles=style)
    self.map_layers.add_layer(dea_layer)
def extract_data(self): # Connect to datacube database dc = datacube.Datacube(app="Exporting satellite images") # Configure local dask cluster client = create_local_dask_cluster(return_client=True, display_client=True) # Convert to geopolygon geopolygon = Geometry(geom=self.gdf_drawn.geometry[0], crs=self.gdf_drawn.crs) # Create query after adjusting interval time to UTC by # adding a UTC offset of -10 hours. This results issues # on the east coast of Australia where satelite overpasses # can occur on either side of 24:00 hours UTC start_date = np.datetime64(self.date) - np.timedelta64(10, "h") end_date = np.datetime64(self.date) + np.timedelta64(14, "h") self.query_params = { "time": (str(start_date), str(end_date)), "geopolygon": geopolygon, } # Find matching datasets dss = [dc.find_datasets(product=i, **self.query_params) for i in sat_params[self.dealayer]["products"]] dss = list(itertools.chain.from_iterable(dss)) # If data is found if len(dss) > 0: # Get CRS crs = str(dss[0].crs) # Get sensor (try/except to account for different S2 NRT metadata) try: sensor = dss[0].metadata_doc["properties"]["eo:platform"].capitalize() except: sensor = dss[0].metadata_doc["platform"]["code"].capitalize() self.sensor = sensor[0:-1].replace("_", "-") + sensor[-1].capitalize() # Meets pansharpening requirements can_pansharpen = self.style == "True colour" and self.sensor in [ "Landsat-7", "Landsat-8", "Landsat-9", ] # Set up load params if self.pansharpen and can_pansharpen: self.load_params = { "measurements": sat_params[self.dealayer]["styles"][self.style][1] + ["nbart_panchromatic"], "resolution": (-self.resolution, self.resolution), "align": (7.5, 7.5), "output_crs": crs, } else: # Use resolution if provided, otherwise use default if self.resolution: sat_params[self.dealayer]["resolution"] = ( -self.resolution, self.resolution, ) self.load_params = { "measurements": sat_params[self.dealayer]["styles"][self.style][1], "resolution": (-self.resolution, self.resolution), 
"output_crs": crs, "skip_broken_datasets": True, } # Load data from datasets print(f"Loading {self.sensor} satellite data") ds = dc.load( datasets=dss, resampling="bilinear", group_by="solar_day", dask_chunks={"time": 1, "x": 2048, "y": 2048}, **self.load_params, **self.query_params, ) ds = masking.mask_invalid_data(ds) # Create plain numpy array, optionally after pansharpening if self.pansharpen and can_pansharpen: # Perform Brovey pan-sharpening and return numpy.array print(f"Pansharpening {self.sensor} image to 15 m resolution") rgb_array = xr_pansharpen(ds, transform="brovey").to_array().squeeze("time").values # If pansharpening is requested but not possible, deactivate # pansharpening and reset to 30 m resolution elif self.pansharpen and not can_pansharpen: print("\nUnable to pansharpen; reverting to 30 m resolution") self.checkbox_pansharpen.value = False self.text_resolution.disabled = False self.text_resolution.value = 30 rgb_array = ds.isel(time=0).to_array().values else: rgb_array = ds.isel(time=0).to_array().values # Transpose numpy array rgb_array = np.transpose(rgb_array, axes=[1, 2, 0]) # Else if no data is returned, return None else: rgb_array = None # Close down the dask client client.close() return rgb_array def plot_data(self, fname): # Data to plot to_plot = self.rgb_array # If percentile stretch is supplied, calculate vmin and vmax # from percentiles if self.percentile_stretch: vmin, vmax = np.nanpercentile(to_plot, self.percentile_stretch) else: vmin, vmax = self.vmin, self.vmax # Raise by power to dampen bright features and enhance dark. 
# Raise vmin and vmax by same amount to ensure proper stretch if self.power < 1.0: with self.status_info: print(f"\nApplying power transformation ({self.power})") to_plot = to_plot**self.power vmin, vmax = vmin**self.power, vmax**self.power # Rescale/stretch imagery between vmin and vmax to_plot = exposure.rescale_intensity(to_plot.astype(float), in_range=(vmin, vmax), out_range=(0.0, 1.0)) # Unsharp mask if self.unsharp_mask: with self.status_info: print( f"\nApplying unsharp masking with {self.unsharp_mask_radius} " f"radius and {self.unsharp_mask_amount} amount" ) to_plot = unsharp_mask(to_plot, radius=self.unsharp_mask_radius, amount=self.unsharp_mask_amount) # Create figure with aspect ratio of data fig = plt.figure(dpi=100) fig.set_size_inches(10, 10 / (to_plot.shape[1] / to_plot.shape[0])) # Remove axes to plot just array data ax = plt.Axes( fig, [0.0, 0.0, 1.0, 1.0], ) ax.set_axis_off() fig.add_axes(ax) # Add data to plot ax.imshow(to_plot) # If a min DPI is specified and image is less than DPI if (self.dpi > 0) and (to_plot.shape[1] < self.dpi * 10): # Export figure to file using exact DPI with self.status_info: print(f"\nExporting image at {self.dpi} DPI") fig.savefig(fname.replace("resolution", f"resolution, {self.dpi} DPI"), dpi=self.dpi) # If no minumum DPI is specified, export raw array data in native # resolution else: plt.imsave(fname=fname, arr=np.ascontiguousarray(to_plot), format=self.output_format) # Add plot preview below map and finish plt.show() with self.status_info: print(f"\nImage successfully exported to:\n{fname}.")
class imageexport_app(HBox):
    """
    Interactive app for selecting and exporting DEA satellite imagery.

    Users pick a date, satellite layer and style, draw a rectangle on the
    map, then press "Export imagery" to load the data, apply the selected
    colour stretch and enhancements, and save it as a JPG or PNG file.
    """

    def __init__(self):
        super().__init__()

        ######################
        # INITIAL ATTRIBUTES #
        ######################

        # Basemap
        self.basemap_list = [
            ("Open Street Map", basemap_to_tiles(basemaps.OpenStreetMap.Mapnik)),
            ("ESRI World Imagery", basemap_to_tiles(basemaps.Esri.WorldImagery)),
        ]
        self.basemap = self.basemap_list[0][1]

        # Satellite data using yesterday's date (midnight). Subtracting a
        # timedelta handles month/year boundaries; decrementing the
        # day-of-month directly fails with day=0 on the 1st of each month
        today = datetime.datetime.today()
        date = datetime.datetime(year=today.year, month=today.month, day=today.day) - datetime.timedelta(days=1)
        self.date = date.strftime("%Y-%m-%d")
        self.dealayer_list = [
            ("Landsat", "ga_ls_ard_3"),
            ("Sentinel-2", "ga_s2m_ard_3"),
        ]
        self.dealayer = self.dealayer_list[0][1]

        # Styles
        self.styles_list = ["True colour", "False colour"]
        self.style = self.styles_list[0]

        # Analysis params
        self.resolution = 30
        self.pansharpen = False
        self.standardise_name = False
        self.vmin = 50
        self.vmax = 3000
        self.percentile_stretch = None  # (1, 99)
        self.power = 1.0
        self.output_list = [("JPG", "jpg"), ("PNG", "png")]
        self.output_format = self.output_list[0][1]
        self.unsharp_mask = False
        self.unsharp_mask_radius = 20
        self.unsharp_mask_amount = 0.3
        self.max_size = False
        self.dpi = 0

        # Drawing params
        self.target = None
        self.action = None
        self.gdf_drawn = None

        # Data load params
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

        ##################
        # HEADER FOR APP #
        ##################

        # Create the Header widget
        header_title_text = "<h3>Digital Earth Australia satellite image export</h3>"
        instruction_text = (
            "<p>Select the desired satellite data, imagery "
            "date and image style, zoom in until satellite "
            "imagery appears on the map, then draw a "
            "rectangle to select an area of imagery to "
            "export as a high-resolution image file.</p>"
        )
        self.header = create_html(f"{header_title_text}{instruction_text}")
        self.header.layout = make_box_layout()

        #####################################
        # HANDLER FUNCTION FOR DRAW CONTROL #
        #####################################

        def update_geojson(target, action, geo_json):
            """
            Handle a shape drawn on the map: report its area in the header
            and keep it as the extraction area if within the size limit.
            """
            # Get data from action
            self.action = action

            # Clear data load params to trigger data re-load
            self.rgb_array = None
            self.sensor = None
            self.load_params = None
            self.query_params = None

            # Convert data to geopandas
            json_data = json.dumps(geo_json)
            buffer = BytesIO(json_data.encode())
            buffer.seek(0)
            gdf = gpd.read_file(buffer)
            gdf = gdf.set_crs("EPSG:4326", allow_override=True)

            # Convert to Albers and compute area
            gdf_drawn_albers = gdf.to_crs("EPSG:3577")
            m2_per_km2 = 10**6
            area = gdf_drawn_albers.area.values[0] / m2_per_km2
            polyarea_label = "Total area of satellite data to extract"
            polyarea_text = f"<b>{polyarea_label}</b>: {area:.2f} km<sup>2</sup>"

            # Test area size
            if self.max_size:
                confirmation_text = (
                    '<span style="color: #33cc33"> '
                    "<b>(Overriding maximum size limit; use with caution as may lead to memory issues)</b></span>"
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + confirmation_text
                self.gdf_drawn = gdf
            elif area <= 10000:
                confirmation_text = (
                    '<span style="color: #33cc33"> <b>(Area to extract falls within recommended limit)</b></span>'
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + confirmation_text
                self.gdf_drawn = gdf
            else:
                warning_text = (
                    '<span style="color: #ff5050"> '
                    "<b>(Area to extract is too large, "
                    "please select an area less than 10000 "
                    "km<sup>2</sup>)</b></span>"
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + warning_text
                self.gdf_drawn = None

        ###########################
        # WIDGETS FOR APP OUTPUTS #
        ###########################

        self.status_info = Output(layout=make_box_layout())
        self.output_plot = Output(layout=make_box_layout())

        #########################################
        # MAP WIDGET, DRAWING TOOLS, WMS LAYERS #
        #########################################

        # Create drawing tools
        desired_drawtools = ["rectangle"]
        draw_control = create_drawcontrol(desired_drawtools)

        # Begin by displaying an empty layer group, and update the group
        # with desired WMS on interaction.
        self.map_layers = LayerGroup(layers=())
        self.map_layers.name = "Map Overlays"

        # Create map widget
        self.m = create_map(map_center=(-28, 135), zoom_level=4)
        self.m.layout = make_box_layout()

        # Add tools to map widget
        self.m.add_control(draw_control)
        self.m.add_control(
            SearchControl(
                position="topleft",
                url="https://nominatim.openstreetmap.org/search?format=json&q={s}",
                zoom=13,  # 'Village / Suburb' level zoom
                marker=Marker(draggable=False),
            )
        )
        self.m.add_layer(self.map_layers)

        # Update all maps to starting defaults
        update_map_layers(self)

        ############################
        # WIDGETS FOR APP CONTROLS #
        ############################

        # Create parameter widgets
        dropdown_basemap = create_dropdown(self.basemap_list, self.basemap_list[0][1])
        dropdown_dealayer = create_dropdown(self.dealayer_list, self.dealayer_list[0][1])
        dropdown_output = create_dropdown(self.output_list, self.output_list[0][1])
        date_picker = create_datepicker(value=date)
        dropdown_styles = create_dropdown(self.styles_list, self.styles_list[0])
        slider_abs = widgets.IntRangeSlider(
            value=[50, 3000],
            min=0,
            max=10000,
            step=25,
            description="",
            layout={"width": "85%"},
        )
        run_button = create_expanded_button("Export imagery", "info")

        # Expandable advanced section
        text_resolution = widgets.FloatText(
            value=30,
            description="",
            layout={"width": "100%", "margin": "0px", "padding": "0px"},
        )
        checkbox_pansharpen = create_checkbox(self.pansharpen, "Pansharpen Landsat")
        slider_power = widgets.FloatSlider(
            value=1.0,
            min=0.01,
            max=1.0,
            step=0.01,
            description="",
            layout={"width": "85%"},
        )
        checkbox_unsharp_mask = create_checkbox(self.unsharp_mask, "Enable", layout={"width": "100%"})
        text_unsharp_mask_radius = widgets.FloatText(
            value=20,
            step=1,
            description="Radius",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        text_unsharp_mask_amount = widgets.FloatText(
            value=0.3,
            step=0.1,
            description="Amount",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        checkbox_max_size = create_checkbox(self.max_size, "Enable")
        text_dpi = widgets.IntText(value=0, description="", step=50, layout={"width": "85%"})
        html_dpi = HTML("</br>Minimum DPI for image export</br>(100 DPI = 1000 pixels wide):")
        expand_box = widgets.VBox(
            [
                HTML("Resolution (metres):"),
                text_resolution,
                checkbox_pansharpen,
                HTML("</br>Apply power transformation to darken bright features:"),
                slider_power,
                HTML("</br>Apply unsharp masking to sharpen image:"),
                checkbox_unsharp_mask,
                text_unsharp_mask_radius,
                text_unsharp_mask_amount,
                HTML("</br>Override maximum size limit: (use with caution; may cause memory issues/crashes)"),
                checkbox_max_size,
                html_dpi,
                text_dpi,
            ],
            layout={"overflow": "hidden"},
        )
        expand = widgets.Accordion(children=[expand_box], selected_index=None)
        expand.set_title(0, "Advanced")

        # Add specific dialogs to class so they can be modified
        self.text_resolution = text_resolution
        self.checkbox_pansharpen = checkbox_pansharpen
        self.text_unsharp_mask_radius = text_unsharp_mask_radius
        self.text_unsharp_mask_amount = text_unsharp_mask_amount
        self.html_dpi = html_dpi

        ####################################
        # UPDATE FUNCTIONS FOR EACH WIDGET #
        ####################################

        # Run update functions whenever various widgets are changed.
        date_picker.observe(self.update_date, "value")
        dropdown_basemap.observe(self.update_basemap, "value")
        dropdown_dealayer.observe(self.update_dealayer, "value")
        dropdown_styles.observe(self.update_styles, "value")
        dropdown_output.observe(self.update_output, "value")
        run_button.on_click(self.run_app)
        draw_control.on_draw(update_geojson)
        slider_abs.observe(self.update_slider_abs, "value")

        # Advanced params
        text_resolution.observe(self.update_text_resolution, "value")
        checkbox_pansharpen.observe(self.update_checkbox_pansharpen, "value")
        slider_power.observe(self.update_slider_power, "value")
        checkbox_unsharp_mask.observe(self.update_checkbox_unsharp_mask, "value")
        text_unsharp_mask_radius.observe(self.update_text_unsharp_mask_radius, "value")
        text_unsharp_mask_amount.observe(self.update_text_unsharp_mask_amount, "value")
        checkbox_max_size.observe(self.update_checkbox_max_size, "value")
        text_dpi.observe(self.update_dpi, "value")

        ##################################
        # COLLECTION OF ALL APP CONTROLS #
        ##################################

        parameter_selection = VBox([
            HTML("<b>Date:</b>"),
            date_picker,
            HTML("<b>Satellite imagery:</b>"),
            dropdown_dealayer,
            HTML("<b>Style:</b>"),
            dropdown_styles,
            HTML("<b>Colour stretch:</b>"),
            slider_abs,
            HTML("<b>Output file format:</b>"),
            dropdown_output,
            HTML("</br>"),
            expand,
        ])
        map_selection = VBox([
            HTML("</br><b>Map overlay:</b>"),
            dropdown_basemap,
        ])
        parameter_selection.layout = make_box_layout()
        map_selection.layout = make_box_layout()

        ###############################
        # SPECIFICATION OF APP LAYOUT #
        ###############################

        #      0   1   2   3   4   5   6   7   8   9
        #    ---------------------------------------------
        # 0  | Header                        | Map sel.  |
        #    |-------------------------------------------|
        # 1  | Params |                                  |
        # 2  |        |                                  |
        # 3  |        |                                  |
        # 4  |        |               Map                |
        # 5  |        |                                  |
        #    |--------|                                  |
        # 6  | Run    |                                  |
        #    |-------------------------------------------|
        # 7  |   Status info     |    Figure/output      |
        # 8  |                   |                       |
        # 9  |                   |                       |
        # 10 |                   |                       |
        # 11 ---------------------------------------------

        # Create the layout #[rowspan, colspan]
        grid = GridspecLayout(12, 10, height="1400px", width="auto")

        # Header and controls
        grid[0, :8] = self.header
        grid[0, 8:] = map_selection
        grid[1:6, 0:2] = parameter_selection
        grid[6, 0:2] = run_button

        # Status info, map and plot
        grid[1:7, 2:] = self.m  # map
        grid[7:, 0:4] = self.status_info
        grid[7:, 4:] = self.output_plot

        # Display using HBox children attribute
        self.children = [grid]

    ######################################
    # DEFINITION OF ALL UPDATE FUNCTIONS #
    ######################################

    def update_date(self, change):
        """Update the imagery date and redraw the map layers."""
        self.date = str(change.new)
        update_map_layers(self)

    def update_slider_abs(self, change):
        """Update the colour stretch limits (vmin, vmax)."""
        self.vmin, self.vmax = change.new

    def update_slider_power(self, change):
        """Update the power transform exponent."""
        self.power = change.new

    def update_checkbox_pansharpen(self, change):
        """Enable pansharpening and reset/deactivate resolution."""
        self.pansharpen = change.new

        # Override default resolution if pansharpening is specified;
        # disable input if so
        if change.new:
            self.text_resolution.value = 15
            self.text_resolution.disabled = True
        else:
            self.text_resolution.value = 30
            self.text_resolution.disabled = False

    def update_checkbox_unsharp_mask(self, change):
        """Enable unsharp masking and show/hide its custom params."""
        self.unsharp_mask = change.new

        # Show unsharp masking params in menu if activated
        display = "block" if change.new else "none"
        self.text_unsharp_mask_radius.layout.display = display
        self.text_unsharp_mask_amount.layout.display = display

    def update_text_unsharp_mask_radius(self, change):
        """Change unsharp masking radius."""
        self.unsharp_mask_radius = change.new

    def update_text_unsharp_mask_amount(self, change):
        """Change unsharp masking amount."""
        self.unsharp_mask_amount = change.new

    def update_checkbox_max_size(self, change):
        """Override max size limit."""
        self.max_size = change.new

    def update_dpi(self, change):
        """Override min DPI and update the helper text with the output width."""
        self.dpi = change.new

        # Update DPI helper text to give output resolution
        self.html_dpi.value = (
            f"</br>Minimum DPI for image export</br>({change.new} DPI = {change.new * 10} pixels wide):"
        )

    def update_text_resolution(self, change):
        """Update resolution and clear cached data to trigger a re-load."""
        self.resolution = change.new

        # Clear data load params to trigger data re-load
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

    def update_dealayer(self, change):
        """Change layers shown on the map and reset resolution defaults."""
        self.dealayer = change.new

        if change.new == "ga_ls_ard_3":
            self.text_resolution.value = 30
            self.checkbox_pansharpen.disabled = False
            if self.pansharpen:
                self.text_resolution.value = 15
                self.text_resolution.disabled = True
        else:
            # Pansharpening only applies to Landsat; untick it so a stale
            # selection doesn't force a "reverting to 30 m" fallback on export
            self.checkbox_pansharpen.value = False
            self.text_resolution.value = 10
            self.checkbox_pansharpen.disabled = True
            self.text_resolution.disabled = False

        update_map_layers(self)

    def update_basemap(self, change):
        """Update basemap and redraw the map layers."""
        self.basemap = change.new
        update_map_layers(self)

    def update_styles(self, change):
        """Set imagery style and redraw the map layers."""
        self.style = change.new
        update_map_layers(self)

    def update_output(self, change):
        """Set output file format."""
        self.output_format = change.new

    def run_app(self, change):
        """Load (or reuse) imagery for the drawn area, then plot and export it."""
        # Clear progress bar and output areas before running
        self.status_info.clear_output()
        self.output_plot.clear_output()

        # Verify that polygon was drawn
        if self.gdf_drawn is None:
            with self.status_info:
                print('Please draw a valid rectangle on the map, then press "Export imagery"')
            return

        with self.status_info:
            # Load data and add to attribute
            if self.rgb_array is None:
                self.rgb_array = extract_data(self)
            else:
                print("Using previously loaded data")

        if self.rgb_array is None:
            with self.status_info:
                print(
                    "No satellite data found in the selected area. "
                    "Please select a new rectangle over an area with "
                    "satellite imagery visible on the map."
                )
            return

        with self.status_info:
            # Create unique file name from the (lat, lon) of the area centre
            centre_coords = self.gdf_drawn.geometry[0].centroid.coords[0][::-1]
            site = reverse_geocode(coords=centre_coords)
            fname = (
                f"{self.sensor} - {self.date} - {site} - {self.style}, "
                f"{self.resolution:.0f} m resolution.{self.output_format}"
            )

            # Remove spaces and commas if requested
            if self.standardise_name:
                fname = fname.replace(" - ", "_").replace(", ", "-").replace(" ", "-").lower()

            print(f"\nExporting image for {site}.\nThis may take several minutes...")

        ############
        # Plotting #
        ############
        with self.output_plot:
            plot_data(self, fname)