# -*- coding: utf-8 -*-
"""
Image export widget, which can be used to interactively select and
export satellite imagery from multiple DEA products.
"""
# Import required packages
import datetime
import itertools
import json
from io import BytesIO
import datacube
import geopandas as gpd
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from datacube.utils import masking
from datacube.utils.geometry import Geometry
from ipyleaflet import (
LayerGroup,
Marker,
SearchControl,
basemap_to_tiles,
basemaps,
)
from ipywidgets import (
HTML,
Button,
GridspecLayout,
HBox,
Layout,
Output,
VBox,
)
from skimage import exposure
from skimage.filters import unsharp_mask
from dea_tools.app.widgetconstructors import (
create_checkbox,
create_datepicker,
create_dea_wms_layer,
create_drawcontrol,
create_dropdown,
create_html,
create_map,
)
from dea_tools.dask import create_local_dask_cluster
from dea_tools.datahandling import xr_pansharpen
from dea_tools.spatial import reverse_geocode
# Per-layer configuration used for both WMS display and data loading.
# Keyed by DEA WMS layer name; each entry holds:
#   "products": the datacube products searched when loading data
#   "styles":   display name -> (WMS style name, measurements to load)
sat_params = {
    # Landsat 5, 7, 8 and 9
    "ga_ls_ard_3": {
        "products": [
            "ga_ls5t_ard_3",
            "ga_ls7e_ard_3",
            "ga_ls8c_ard_3",
            "ga_ls9c_ard_3",
        ],
        "styles": {
            "True colour": (
                "true_colour",
                ["nbart_red", "nbart_green", "nbart_blue"],
            ),
            "False colour": (
                "false_colour",
                ["nbart_swir_1", "nbart_nir", "nbart_green"],
            ),
        },
    },
    # Sentinel-2A, 2B and 2C
    "ga_s2m_ard_3": {
        "products": [
            "ga_s2am_ard_3",
            "ga_s2bm_ard_3",
            "ga_s2cm_ard_3",
        ],
        "styles": {
            "True colour": (
                "simple_rgb",
                ["nbart_red", "nbart_green", "nbart_blue"],
            ),
            "False colour": (
                "infrared_green",
                ["nbart_swir_2", "nbart_nir_1", "nbart_green"],
            ),
        },
    },
}
def make_box_layout():
    """
    Return a new ``Layout`` that fills its container (100% width and
    height) with a fixed margin and padding.
    """
    box_options = {
        # "border": "solid 1px black",
        "margin": "0px 10px 10px 0px",
        "padding": "5px 5px 5px 5px",
        "width": "100%",
        "height": "100%",
    }
    return Layout(**box_options)
def create_expanded_button(description, button_style):
    """
    Return a ``Button`` with the given label and style whose width and
    height are both set to "auto".

    Parameters
    ----------
    description : str
        Text shown on the button.
    button_style : str
        Value passed through to the button's ``button_style``.
    """
    auto_sized = Layout(width="auto", height="auto")
    button = Button(
        description=description,
        button_style=button_style,
        layout=auto_sized,
    )
    return button
[docs]
def update_map_layers(self):
    """
    Rebuild the map overlays from the app's current basemap, DEA layer,
    date and style selections.

    Any previously loaded data and its load/query parameters are
    discarded first, so the next export re-loads data that matches the
    new selections.

    Parameters
    ----------
    self : imageexport_app
        App instance whose ``map_layers`` group is rebuilt in place.
    """
    # Invalidate cached load results
    for cached_attr in ("rgb_array", "sensor", "load_params", "query_params"):
        setattr(self, cached_attr, None)

    # Start from an empty layer group with only the basemap
    overlays = self.map_layers
    overlays.clear_layers()
    overlays.add_layer(self.basemap)

    # Each style entry is (WMS style name, measurements); only the WMS
    # style name is needed for display
    wms_style, _measurements = sat_params[self.dealayer]["styles"][self.style]

    # Draw the DEA imagery on top of the basemap
    overlays.add_layer(create_dea_wms_layer(self.dealayer, self.date, styles=wms_style))
def extract_data(self):
    """
    Load satellite data for the drawn rectangle on the selected date.

    Searches every datacube product listed for the selected DEA layer,
    loads the measurements required by the selected style and returns
    them as a plain numpy array. Landsat 7, 8 and 9 true colour data is
    pansharpened when pansharpening is enabled; if pansharpening is
    enabled but not possible, the related widgets are reset instead.

    As side effects this sets ``self.query_params`` and, when data is
    found, ``self.sensor`` and ``self.load_params``.

    Parameters
    ----------
    self : imageexport_app
        App instance providing the drawn geometry (``gdf_drawn``), date,
        layer, style, resolution and pansharpening settings.

    Returns
    -------
    numpy.ndarray or None
        Array with the band axis moved last, or None if no datasets
        match the query.
    """
    # Connect to datacube database
    dc = datacube.Datacube(app="Exporting satellite images")

    # Configure local dask cluster
    client = create_local_dask_cluster(return_client=True, display_client=True)

    # Guarantee the dask client is shut down even if searching or
    # loading fails part-way through
    try:
        # Convert to geopolygon
        geopolygon = Geometry(geom=self.gdf_drawn.geometry[0], crs=self.gdf_drawn.crs)

        # Create query after shifting the time interval by a UTC offset
        # of -10 hours. This avoids issues on the east coast of Australia
        # where satellite overpasses can occur on either side of 24:00
        # hours UTC
        start_date = np.datetime64(self.date) - np.timedelta64(10, "h")
        end_date = np.datetime64(self.date) + np.timedelta64(14, "h")
        self.query_params = {
            "time": (str(start_date), str(end_date)),
            "geopolygon": geopolygon,
        }

        # Find matching datasets across all products of the layer
        dss = list(
            itertools.chain.from_iterable(
                dc.find_datasets(product=product, **self.query_params)
                for product in sat_params[self.dealayer]["products"]
            )
        )

        # If no data is found, there is nothing to load
        if not dss:
            return None

        # Get CRS
        crs = str(dss[0].crs)

        # Get sensor; fall back to the alternative metadata layout used
        # by S2 NRT datasets when the first set of keys is missing
        metadata = dss[0].metadata_doc
        try:
            sensor = metadata["properties"]["eo:platform"].capitalize()
        except (KeyError, TypeError):
            sensor = metadata["platform"]["code"].capitalize()

        # Use hyphens throughout and upper-case the final character
        self.sensor = sensor[0:-1].replace("_", "-") + sensor[-1].capitalize()

        # Meets pansharpening requirements
        can_pansharpen = self.style == "True colour" and self.sensor in [
            "Landsat-7",
            "Landsat-8",
            "Landsat-9",
        ]
        do_pansharpen = self.pansharpen and can_pansharpen

        # Set up load params
        measurements = sat_params[self.dealayer]["styles"][self.style][1]
        resolution = (-self.resolution, self.resolution)
        if do_pansharpen:
            self.load_params = {
                "measurements": measurements + ["nbart_panchromatic"],
                "resolution": resolution,
                "align": (7.5, 7.5),
                "output_crs": crs,
            }
        else:
            self.load_params = {
                "measurements": measurements,
                "resolution": resolution,
                "output_crs": crs,
                "skip_broken_datasets": True,
            }

        # Load data from datasets
        print(f"Loading {self.sensor} satellite data")
        ds = dc.load(
            datasets=dss,
            resampling="bilinear",
            group_by="solar_day",
            dask_chunks={"time": 1, "x": 2048, "y": 2048},
            **self.load_params,
            **self.query_params,
        )
        ds = masking.mask_invalid_data(ds)

        # Create plain numpy array, optionally after pansharpening
        if do_pansharpen:
            # Perform Brovey pan-sharpening and return numpy.array
            print(f"Pansharpening {self.sensor} image to 15 m resolution")
            rgb_array = xr_pansharpen(ds, transform="brovey").to_array().squeeze("time").values
        else:
            # If pansharpening is requested but not possible, deactivate
            # pansharpening and reset to 30 m resolution
            if self.pansharpen:
                print("\nUnable to pansharpen; reverting to 30 m resolution")
                self.checkbox_pansharpen.value = False
                self.text_resolution.disabled = False
                self.text_resolution.value = 30
            rgb_array = ds.isel(time=0).to_array().values

        # Transpose numpy array so the band axis comes last
        return np.transpose(rgb_array, axes=[1, 2, 0])

    finally:
        # Close down the dask client
        client.close()
def plot_data(self, fname):
    """
    Stretch, optionally sharpen, export and preview the loaded image.

    Reads ``self.rgb_array`` and applies, in order: an optional power
    transformation, a rescale between the chosen minimum and maximum
    values, and optional unsharp masking. The result is written to disk
    and shown as a figure. Progress messages are printed to
    ``self.status_info``.

    Parameters
    ----------
    self : imageexport_app
        App instance providing the image array and export settings.
    fname : str
        Output file name. When the image is exported at a minimum DPI,
        the DPI is inserted into the name after the word "resolution",
        and that modified name is the one written and reported.
    """
    # Data to plot
    to_plot = self.rgb_array

    # If percentile stretch is supplied, calculate vmin and vmax
    # from percentiles
    if self.percentile_stretch:
        vmin, vmax = np.nanpercentile(to_plot, self.percentile_stretch)
    else:
        vmin, vmax = self.vmin, self.vmax

    # Raise by power to dampen bright features and enhance dark.
    # Raise vmin and vmax by same amount to ensure proper stretch
    if self.power < 1.0:
        with self.status_info:
            print(f"\nApplying power transformation ({self.power})")
        to_plot = to_plot**self.power
        vmin, vmax = vmin**self.power, vmax**self.power

    # Rescale/stretch imagery between vmin and vmax
    to_plot = exposure.rescale_intensity(to_plot.astype(float), in_range=(vmin, vmax), out_range=(0.0, 1.0))

    # Unsharp mask
    if self.unsharp_mask:
        with self.status_info:
            print(
                f"\nApplying unsharp masking with {self.unsharp_mask_radius} "
                f"radius and {self.unsharp_mask_amount} amount"
            )
        to_plot = unsharp_mask(to_plot, radius=self.unsharp_mask_radius, amount=self.unsharp_mask_amount)

    # Create figure with aspect ratio of data
    fig = plt.figure(dpi=100)
    fig.set_size_inches(10, 10 / (to_plot.shape[1] / to_plot.shape[0]))

    # Remove axes to plot just array data
    ax = plt.Axes(
        fig,
        [0.0, 0.0, 1.0, 1.0],
    )
    ax.set_axis_off()
    fig.add_axes(ax)

    # Add data to plot
    ax.imshow(to_plot)

    # If a min DPI is specified and the image is narrower than the
    # figure would be at that DPI (the figure is 10 inches wide)
    if (self.dpi > 0) and (to_plot.shape[1] < self.dpi * 10):
        # Export figure to file using exact DPI, recording the DPI in
        # the file name
        exported_fname = fname.replace("resolution", f"resolution, {self.dpi} DPI")
        with self.status_info:
            print(f"\nExporting image at {self.dpi} DPI")
        fig.savefig(exported_fname, dpi=self.dpi)

    # If no minimum DPI is specified, export raw array data in native
    # resolution
    else:
        exported_fname = fname
        plt.imsave(fname=exported_fname, arr=np.ascontiguousarray(to_plot), format=self.output_format)

    # Add plot preview below map and finish, reporting the name of the
    # file that was actually written
    plt.show()
    with self.status_info:
        print(f"\nImage successfully exported to:\n{exported_fname}.")
[docs]
class imageexport_app(HBox):
    """
    Interactive app for selecting and exporting DEA satellite imagery.

    The app is an ``HBox`` with a single grid child containing a header,
    parameter controls, a map with a rectangle drawing tool, a status
    panel and an output plot panel. Drawing a rectangle and pressing
    "Export imagery" loads data with ``extract_data`` and exports an
    image with ``plot_data``.
    """

    def __init__(self):
        """Create all widgets, connect their callbacks and lay out the app."""
        super().__init__()

        ######################
        # INITIAL ATTRIBUTES #
        ######################

        # Basemap
        self.basemap_list = [
            ("Open Street Map", basemap_to_tiles(basemaps.OpenStreetMap.Mapnik)),
            ("ESRI World Imagery", basemap_to_tiles(basemaps.Esri.WorldImagery)),
        ]
        self.basemap = self.basemap_list[0][1]

        # Satellite data using yesterday's date. Subtract a timedelta
        # rather than decrementing the day number, which would produce
        # an invalid day of 0 on the first day of a month
        yesterday = datetime.date.today() - datetime.timedelta(days=1)
        date = datetime.datetime.combine(yesterday, datetime.time.min)
        self.date = date.strftime("%Y-%m-%d")
        self.dealayer_list = [
            ("Landsat", "ga_ls_ard_3"),
            ("Sentinel-2", "ga_s2m_ard_3"),
        ]
        self.dealayer = self.dealayer_list[0][1]

        # Styles
        self.styles_list = ["True colour", "False colour"]
        self.style = self.styles_list[0]

        # Analysis params
        self.resolution = 30
        self.pansharpen = False
        self.standardise_name = False
        self.vmin = 50
        self.vmax = 3000
        self.percentile_stretch = None  # (1, 99)
        self.power = 1.0
        self.output_list = [("JPG", "jpg"), ("PNG", "png")]
        self.output_format = self.output_list[0][1]
        self.unsharp_mask = False
        self.unsharp_mask_radius = 20
        self.unsharp_mask_amount = 0.3
        self.max_size = False
        self.dpi = 0

        # Drawing params
        self.target = None
        self.action = None
        self.gdf_drawn = None

        # Data load params
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

        ##################
        # HEADER FOR APP #
        ##################

        # Create the Header widget
        header_title_text = "<h3>Digital Earth Australia satellite image export</h3>"
        instruction_text = (
            "<p>Select the desired satellite data, imagery "
            "date and image style, zoom in until satellite "
            "imagery appears on the map, then draw a "
            "rectangle to select an area of imagery to "
            "export as a high-resolution image file.</p>"
        )
        self.header = create_html(f"{header_title_text}{instruction_text}")
        self.header.layout = make_box_layout()

        #####################################
        # HANDLER FUNCTION FOR DRAW CONTROL #
        #####################################

        # Define the action to take once something is drawn on the map
        def update_geojson(target, action, geo_json):
            # Get data from action
            self.action = action

            # Clear data load params to trigger data re-load
            self._clear_loaded_data()

            # Convert data to geopandas
            json_data = json.dumps(geo_json)
            binary_data = json_data.encode()
            buffer = BytesIO(binary_data)
            buffer.seek(0)
            gdf = gpd.read_file(buffer)
            gdf.crs = "EPSG:4326"

            # Convert to Albers and compute area
            gdf_drawn_albers = gdf.copy().to_crs("EPSG:3577")
            m2_per_km2 = 10**6
            area = gdf_drawn_albers.area.values[0] / m2_per_km2
            polyarea_label = "Total area of satellite data to extract"
            polyarea_text = f"<b>{polyarea_label}</b>: {area:.2f} km<sup>2</sup>"

            # Test area size; only keep the drawn geometry if it is
            # within the limit or the limit has been overridden
            if self.max_size:
                confirmation_text = (
                    '<span style="color: #33cc33"> '
                    "<b>(Overriding maximum size limit; use with caution as may lead to memory issues)</b></span>"
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + confirmation_text
                self.gdf_drawn = gdf
            elif area <= 10000:
                confirmation_text = (
                    '<span style="color: #33cc33"> <b>(Area to extract falls within recommended limit)</b></span>'
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + confirmation_text
                self.gdf_drawn = gdf
            else:
                warning_text = (
                    '<span style="color: #ff5050"> '
                    "<b>(Area to extract is too large, "
                    "please select an area less than 10000 "
                    "km<sup>2</sup>)</b></span>"
                )
                self.header.value = header_title_text + instruction_text + polyarea_text + warning_text
                self.gdf_drawn = None

        ###########################
        # WIDGETS FOR APP OUTPUTS #
        ###########################

        self.status_info = Output(layout=make_box_layout())
        self.output_plot = Output(layout=make_box_layout())

        #########################################
        # MAP WIDGET, DRAWING TOOLS, WMS LAYERS #
        #########################################

        # Create drawing tools
        desired_drawtools = ["rectangle"]
        draw_control = create_drawcontrol(desired_drawtools)

        # Begin by displaying an empty layer group, and update the group with desired WMS on interaction.
        self.map_layers = LayerGroup(layers=())
        self.map_layers.name = "Map Overlays"

        # Create map widget
        self.m = create_map(map_center=(-28, 135), zoom_level=4)
        self.m.layout = make_box_layout()

        # Add tools to map widget
        self.m.add_control(draw_control)
        self.m.add_control(
            SearchControl(
                position="topleft",
                url="https://nominatim.openstreetmap.org/search?format=json&q={s}",
                zoom=13,  # 'Village / Suburb' level zoom
                marker=Marker(draggable=False),
            )
        )
        self.m.add_layer(self.map_layers)

        # Update all maps to starting defaults
        update_map_layers(self)

        ############################
        # WIDGETS FOR APP CONTROLS #
        ############################

        # Create parameter widgets
        dropdown_basemap = create_dropdown(self.basemap_list, self.basemap_list[0][1])
        dropdown_dealayer = create_dropdown(self.dealayer_list, self.dealayer_list[0][1])
        dropdown_output = create_dropdown(self.output_list, self.output_list[0][1])
        date_picker = create_datepicker(value=date)
        dropdown_styles = create_dropdown(self.styles_list, self.styles_list[0])
        slider_abs = widgets.IntRangeSlider(
            value=[50, 3000],
            min=0,
            max=10000,
            step=25,
            description="",
            layout={"width": "85%"},
        )
        run_button = create_expanded_button("Export imagery", "info")

        # Expandable advanced section
        text_resolution = widgets.FloatText(
            value=30,
            description="",
            layout={"width": "100%", "margin": "0px", "padding": "0px"},
        )
        checkbox_pansharpen = create_checkbox(self.pansharpen, "Pansharpen Landsat")
        slider_power = widgets.FloatSlider(
            value=1.0,
            min=0.01,
            max=1.0,
            step=0.01,
            description="",
            layout={"width": "85%"},
        )
        checkbox_unsharp_mask = create_checkbox(self.unsharp_mask, "Enable", layout={"width": "100%"})
        text_unsharp_mask_radius = widgets.FloatText(
            value=20,
            step=1,
            description="Radius",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        text_unsharp_mask_amount = widgets.FloatText(
            value=0.3,
            step=0.1,
            description="Amount",
            layout={
                "width": "100%",
                "margin": "0px",
                "padding": "0px",
                "display": "none",
            },
        )
        checkbox_max_size = create_checkbox(self.max_size, "Enable")
        text_dpi = widgets.IntText(value=0, description="", step=50, layout={"width": "85%"})
        html_dpi = HTML("</br>Minimum DPI for image export</br>(100 DPI = 1000 pixels wide):")
        expand_box = widgets.VBox(
            [
                HTML("Resolution (metres):"),
                text_resolution,
                checkbox_pansharpen,
                HTML("</br>Apply power transformation to darken bright features:"),
                slider_power,
                HTML("</br>Apply unsharp masking to sharpen image:"),
                checkbox_unsharp_mask,
                text_unsharp_mask_radius,
                text_unsharp_mask_amount,
                HTML("</br>Override maximum size limit: (use with caution; may cause memory issues/crashes)"),
                checkbox_max_size,
                html_dpi,
                text_dpi,
            ],
            layout={"overflow": "hidden"},
        )
        expand = widgets.Accordion(children=[expand_box], selected_index=None)
        expand.set_title(0, "Advanced")

        # Add specific dialogs to class so they can be modified
        self.text_resolution = text_resolution
        self.checkbox_pansharpen = checkbox_pansharpen
        self.text_unsharp_mask_radius = text_unsharp_mask_radius
        self.text_unsharp_mask_amount = text_unsharp_mask_amount
        self.html_dpi = html_dpi

        ####################################
        # UPDATE FUNCTIONS FOR EACH WIDGET #
        ####################################

        # Run update functions whenever various widgets are changed.
        date_picker.observe(self.update_date, "value")
        dropdown_basemap.observe(self.update_basemap, "value")
        dropdown_dealayer.observe(self.update_dealayer, "value")
        dropdown_styles.observe(self.update_styles, "value")
        dropdown_output.observe(self.update_output, "value")
        run_button.on_click(self.run_app)
        draw_control.on_draw(update_geojson)
        slider_abs.observe(self.update_slider_abs, "value")

        # Advanced params
        text_resolution.observe(self.update_text_resolution, "value")
        checkbox_pansharpen.observe(self.update_checkbox_pansharpen, "value")
        slider_power.observe(self.update_slider_power, "value")
        checkbox_unsharp_mask.observe(self.update_checkbox_unsharp_mask, "value")
        text_unsharp_mask_radius.observe(self.update_text_unsharp_mask_radius, "value")
        text_unsharp_mask_amount.observe(self.update_text_unsharp_mask_amount, "value")
        checkbox_max_size.observe(self.update_checkbox_max_size, "value")
        text_dpi.observe(self.update_dpi, "value")

        ##################################
        # COLLECTION OF ALL APP CONTROLS #
        ##################################

        parameter_selection = VBox([
            HTML("<b>Date:</b>"),
            date_picker,
            HTML("<b>Satellite imagery:</b>"),
            dropdown_dealayer,
            HTML("<b>Style:</b>"),
            dropdown_styles,
            HTML("<b>Colour stretch:</b>"),
            slider_abs,
            HTML("<b>Output file format:</b>"),
            dropdown_output,
            HTML("</br>"),
            expand,
        ])
        map_selection = VBox([
            HTML("</br><b>Map overlay:</b>"),
            dropdown_basemap,
        ])
        parameter_selection.layout = make_box_layout()
        map_selection.layout = make_box_layout()

        ###############################
        # SPECIFICATION OF APP LAYOUT #
        ###############################

        #       0    1    2    3    4    5    6    7    8    9
        #     ---------------------------------------------------
        #  0  | Header                                | Map sel. |
        #     |--------------------------------------------------|
        #  1  | Params  |                                        |
        #  2  |         |                                        |
        #  3  |         |                                        |
        #  4  |         |                  Map                   |
        #  5  |         |                                        |
        #     |---------|                                        |
        #  6  | Run     |                                        |
        #     |--------------------------------------------------|
        #  7  | Status info       | Figure/output                |
        #  8  |                   |                              |
        #  9  |                   |                              |
        # 10  |                   |                              |
        # 11  ---------------------------------------------------

        # Create the layout #[rowspan, colspan]
        grid = GridspecLayout(12, 10, height="1400px", width="auto")

        # Header and controls
        grid[0, :8] = self.header
        grid[0, 8:] = map_selection
        grid[1:6, 0:2] = parameter_selection
        grid[6, 0:2] = run_button

        # Status info, map and plot
        grid[1:7, 2:] = self.m  # map
        grid[7:, 0:4] = self.status_info
        grid[7:, 4:] = self.output_plot

        # Display using HBox children attribute
        self.children = [grid]

    ######################################
    # DEFINITION OF ALL UPDATE FUNCTIONS #
    ######################################

    def _clear_loaded_data(self):
        """Discard loaded data and its load/query params to trigger a re-load."""
        self.rgb_array = None
        self.sensor = None
        self.load_params = None
        self.query_params = None

    # Update date
    def update_date(self, change):
        """Store the newly picked date and refresh the map layers."""
        # Ignore a cleared date picker rather than storing the text "None"
        if change.new is None:
            return
        self.date = str(change.new)
        update_map_layers(self)

    # Update colour stretch
    def update_slider_abs(self, change):
        """Store the minimum and maximum values used for the colour stretch."""
        self.vmin, self.vmax = change.new

    # Update power transform
    def update_slider_power(self, change):
        """Store the exponent used for the power transformation."""
        self.power = change.new

    # Enable pansharpening and reset/deactivate resolution
    def update_checkbox_pansharpen(self, change):
        """Toggle pansharpening and lock or unlock the resolution input."""
        self.pansharpen = change.new

        # Override default resolution if pansharpening is specified;
        # disable input if so
        if change.new:
            self.text_resolution.value = 15
            self.text_resolution.disabled = True
        else:
            self.text_resolution.value = 30
            self.text_resolution.disabled = False

    # Enable unsharp masking and show/hide custom params
    def update_checkbox_unsharp_mask(self, change):
        """Toggle unsharp masking and show or hide its radius/amount inputs."""
        self.unsharp_mask = change.new

        # Show unsharp masking params in menu if activated
        display = "block" if change.new else "none"
        self.text_unsharp_mask_radius.layout.display = display
        self.text_unsharp_mask_amount.layout.display = display

    # Change unsharp masking radius
    def update_text_unsharp_mask_radius(self, change):
        """Store the unsharp masking radius."""
        self.unsharp_mask_radius = change.new

    # Change unsharp masking amount
    def update_text_unsharp_mask_amount(self, change):
        """Store the unsharp masking amount."""
        self.unsharp_mask_amount = change.new

    # Override max size limit
    def update_checkbox_max_size(self, change):
        """Store whether the maximum area limit is overridden."""
        self.max_size = change.new

    # Override min DPI
    def update_dpi(self, change):
        """Store the minimum export DPI and update its helper text."""
        self.dpi = change.new

        # Update DPI helper text to give output resolution
        self.html_dpi.value = (
            f"</br>Minimum DPI for image export</br>({change.new} DPI = {change.new * 10} pixels wide):"
        )

    # Update resolution
    def update_text_resolution(self, change):
        """Store the load resolution and discard data loaded at the old one."""
        self.resolution = change.new

        # Clear data load params to trigger data re-load
        self._clear_loaded_data()

    # Change layers shown on the map
    def update_dealayer(self, change):
        """Switch satellite layer, reset resolution defaults and refresh the map."""
        self.dealayer = change.new

        if change.new == "ga_ls_ard_3":
            # Landsat: 30 m default, or a locked 15 m when pansharpening
            self.text_resolution.value = 30
            self.checkbox_pansharpen.disabled = False
            if self.pansharpen:
                self.text_resolution.value = 15
                self.text_resolution.disabled = True
        else:
            # Other layers: 10 m default; the pansharpen option is
            # labelled Landsat-only so it is disabled
            self.text_resolution.value = 10
            self.checkbox_pansharpen.disabled = True
            self.text_resolution.disabled = False

        update_map_layers(self)

    # Update basemap
    def update_basemap(self, change):
        """Store the selected basemap and refresh the map layers."""
        self.basemap = change.new
        update_map_layers(self)

    # Set imagery style
    def update_styles(self, change):
        """Store the selected imagery style and refresh the map layers."""
        self.style = change.new
        update_map_layers(self)

    # Set output file format
    def update_output(self, change):
        """Store the selected output file format."""
        self.output_format = change.new

    def run_app(self, change):
        """
        Load data for the drawn rectangle (if not already loaded) and
        export it as an image.

        Called when the "Export imagery" button is clicked; ``change``
        is the callback argument and is not used.
        """
        # Clear progress bar and output areas before running
        self.status_info.clear_output()
        self.output_plot.clear_output()

        # Verify that polygon was drawn
        if self.gdf_drawn is None:
            with self.status_info:
                print('Please draw a valid rectangle on the map, then press "Export imagery"')
            return

        with self.status_info:
            # Load data and add to attribute
            if self.rgb_array is None:
                self.rgb_array = extract_data(self)
            else:
                print("Using previously loaded data")

        # No data could be loaded for the selection
        if self.rgb_array is None:
            with self.status_info:
                print(
                    "No satellite data found in the selected area. "
                    "Please select a new rectangle over an area with "
                    "satellite imagery visible on the map."
                )
            return

        with self.status_info:
            # Create unique file name; the centroid coordinates are
            # reversed before being passed to the geocoder
            centre_coords = self.gdf_drawn.geometry[0].centroid.coords[0][::-1]
            site = reverse_geocode(coords=centre_coords)
            fname = (
                f"{self.sensor} - {self.date} - {site} - {self.style}, "
                f"{self.resolution:.0f} m resolution.{self.output_format}"
            )

            # Remove spaces and commas if requested
            if self.standardise_name:
                fname = fname.replace(" - ", "_").replace(", ", "-").replace(" ", "-").lower()
            print(f"\nExporting image for {site}.\nThis may take several minutes...")

        ############
        # Plotting #
        ############

        with self.output_plot:
            plot_data(self, fname)