# dea_datahandling.py
"""
Loading and manipulating Digital Earth Australia products and data
using the Open Data Cube and xarray.
License: The code in this notebook is licensed under the Apache License,
Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0). Digital Earth
Australia data is licensed under the Creative Commons by Attribution 4.0
license (https://creativecommons.org/licenses/by/4.0/).
Contact: If you need assistance, please post a question on the Open Data
Cube Discord chat (https://discord.com/invite/4hhBQVas5U) or on the GIS Stack
Exchange (https://gis.stackexchange.com/questions/ask?tags=open-data-cube)
using the `open-data-cube` tag (you can view previously asked questions
here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
If you would like to report an issue with this script, you can file one
on GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
Last modified: Feb 2024
"""
# Import required packages
import datetime
import os
import warnings
import zipfile
import requests
from collections import Counter
import rioxarray
import numpy as np
import pandas as pd
import xarray as xr
import sklearn.decomposition
from scipy.ndimage import binary_dilation
from skimage.color import hsv2rgb, rgb2hsv
from skimage.exposure import match_histograms
import odc.geo.xr
import odc.algo
from odc.algo import mask_cleanup
from datacube.utils.dates import normalise_dt
def _dc_query_only(**kw):
"""
Remove load-only datacube parameters, the rest can be
passed to Query/dc.find_datasets.
Returns
-------
dict of query parameters
"""
def _impl(
measurements=None,
output_crs=None,
resolution=None,
resampling=None,
skip_broken_datasets=None,
dask_chunks=None,
fuse_func=None,
align=None,
datasets=None,
progress_cbk=None,
group_by=None,
**query,
):
return query
return _impl(**kw)
def _common_bands(dc, products):
"""
Takes a list of products and returns a list of measurements/bands
that are present in all products
Returns
-------
List of band names
"""
common = None
bands = None
for p in products:
p = dc.index.products.get_by_name(p)
if common is None:
common = set(p.measurements)
bands = list(p.measurements)
else:
common = common.intersection(set(p.measurements))
return [band for band in bands if band in common]
def load_ard(
dc,
products=None,
cloud_mask="fmask",
min_gooddata=0.00,
mask_pixel_quality=True,
mask_filters=None,
mask_contiguity=False,
fmask_categories=["valid", "snow", "water"],
s2cloudless_categories=["valid"],
ls7_slc_off=True,
dtype="auto",
predicate=None,
**kwargs,
):
"""
Load multiple Geoscience Australia Landsat or Sentinel 2
Collection 3 products (e.g. Landsat 5, 7, 8, 9; Sentinel 2A and 2B),
optionally apply pixel quality/cloud masking and contiguity masks,
and drop time steps that contain less than a minimum proportion
of good quality (i.e. non-cloudy, non-shadowed) pixels.
The function supports loading the following Landsat products:
* ga_ls5t_ard_3
* ga_ls7e_ard_3
* ga_ls8c_ard_3
* ga_ls9c_ard_3
And Sentinel-2 products:
* ga_s2am_ard_3
* ga_s2bm_ard_3
Cloud masking can be performed using the Fmask (Function of Mask)
cloud mask for Landsat and Sentinel-2, and the s2cloudless
(Sentinel Hub cloud detector for Sentinel-2 imagery) cloud mask for
Sentinel-2.
Last modified: June 2023
Parameters
----------
dc : datacube Datacube object
The Datacube to connect to, i.e. ``dc = datacube.Datacube()``.
This allows you to also use development datacubes if required.
products : list
A list of product names to load. Valid options are
['ga_ls5t_ard_3', 'ga_ls7e_ard_3', 'ga_ls8c_ard_3', 'ga_ls9c_ard_3']
for Landsat, ['ga_s2am_ard_3', 'ga_s2bm_ard_3'] for Sentinel 2.
cloud_mask : string, optional
The cloud mask used by the function. This is used for both
masking out poor quality pixels (e.g. clouds) if
``mask_pixel_quality=True``, and for calculating the
``min_gooddata`` percentage when dropping cloudy or low quality
satellite observations. Two cloud masks are supported:
* 'fmask' (default; available for Landsat and Sentinel-2)
* 's2cloudless' (Sentinel-2 only)
min_gooddata : float, optional
The minimum percentage of good quality pixels required for a
satellite observation to be loaded. Defaults to 0.00 which will
return all observations regardless of pixel quality (set to e.g.
0.99 to return only observations with more than 99% good quality
pixels).
mask_pixel_quality : str or bool, optional
Whether to mask out poor quality (e.g. cloudy) pixels by setting
them as nodata. Depending on the choice of cloud mask, the
function will identify good quality pixels using the categories
passed to the ``fmask_categories`` or ``s2cloudless_categories``
params. Set to False to turn off pixel quality masking completely.
Poor quality pixels will be set to NaN (converting all data to
`float32`) if ``dtype='auto'``, or set to the data's native
nodata value (usually -999) if ``dtype='native'`` (see 'dtype'
below for more details).
mask_filters : iterable of tuples, optional
Iterable tuples of morphological operations - ("<operation>", <radius>)
to apply to the inverted pixel quality mask, where:
operation: string; one of these morphological operations:
* ``'dilation'`` = Expands poor quality pixels/clouds outwards
* ``'erosion'`` = Shrinks poor quality pixels/clouds inwards
* ``'closing'`` = Remove small holes in clouds by expanding
then shrinking poor quality pixels
* ``'opening'`` = Remove small or narrow clouds by shrinking
then expanding poor quality pixels
radius: int
e.g. ``mask_filters=[('erosion', 5), ("opening", 2), ("dilation", 2)]``
mask_contiguity : str or bool, optional
Whether to mask out pixels that are missing data in any band
(i.e. "non-contiguous" pixels). This can be important for
generating clean composite datasets. The default of False will
not apply any contiguity mask.
If loading NBART data, set:
* ``mask_contiguity='nbart'`` (or ``mask_contiguity=True``)
If loading NBAR data, specify:
* ``mask_contiguity='nbar'``
Non-contiguous pixels will be set to NaN if `dtype='auto'`, or
set to the data's native nodata value if `dtype='native'` (see
'dtype' below).
fmask_categories : list, optional
A list of Fmask cloud mask categories to consider as good
quality pixels when calculating `min_gooddata` and when masking
data by pixel quality if ``mask_pixel_quality=True``.
The default is ``['valid', 'snow', 'water']``; all other Fmask
categories ('cloud', 'shadow', 'nodata') will be treated as low
quality pixels. Choose from: 'nodata', 'valid', 'cloud',
'shadow', 'snow', and 'water'.
s2cloudless_categories : list, optional
A list of s2cloudless cloud mask categories to consider as good
quality pixels when calculating `min_gooddata` and when masking
data by pixel quality if ``mask_pixel_quality=True``. The default
is `['valid']`; all other s2cloudless categories ('cloud',
'nodata') will be treated as low quality pixels. Choose from:
'nodata', 'valid', or 'cloud'.
ls7_slc_off : bool, optional
An optional boolean indicating whether to include data from
after the Landsat 7 SLC failure (i.e. SLC-off). Defaults to
True, which keeps all Landsat 7 observations, including those
acquired after 31 May 2003.
dtype : string, optional
Controls the data type/dtype that layers are coerced to after
loading. Valid values: 'native', 'auto', and 'float{16|32|64}'.
When 'auto' is used, the data will be converted to `float32`
if masking is used, otherwise data will be returned in the
native data type of the data. Be aware that if data is loaded
in its native dtype, nodata and masked pixels will be returned
with the data's native nodata value (typically -999), not NaN.
predicate : function, optional
DEPRECATED: Please use `dataset_predicate` instead.
An optional function that can be passed in to restrict the datasets that
are loaded. A predicate function should take a
`datacube.model.Dataset` object as an input (i.e. as returned
from `dc.find_datasets`), and return a boolean. For example,
a predicate function could be used to return True for only
datasets acquired in January: `dataset.time.begin.month == 1`
**kwargs :
A set of keyword arguments to `dc.load` that define the
spatiotemporal query and load parameters used to extract data.
Keyword arguments can either be listed directly in the
`load_ard` call like any other parameter (e.g.
`measurements=['nbart_red']`), or by passing in a query kwarg
dictionary (e.g. `**query`). Keywords can include `measurements`,
`x`, `y`, `time`, `resolution`, `resampling`, `group_by`, `crs`;
see the `dc.load` documentation for all possible options:
https://datacube-core.readthedocs.io/en/latest/api/indexed-data/generate/datacube.Datacube.load.html
Returns
-------
combined_ds : xarray.Dataset
An xarray.Dataset containing only satellite observations with
a proportion of good quality pixels greater than `min_gooddata`.
Notes
-----
The `load_ard` function builds on the Open Data Cube's native `dc.load`
function by adding the ability to load multiple satellite data
products at once, and automatically apply cloud masking and filtering.
For loading non-satellite data products (e.g. DEA Water Observations),
use `dc.load` instead.
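Examples
--------
An illustrative sketch only: assumes a working DEA datacube
connection, and uses placeholder coordinates and dates.
>>> import datacube
>>> dc = datacube.Datacube(app="load_ard_example")
>>> ds = load_ard(
...     dc=dc,
...     products=["ga_ls8c_ard_3", "ga_ls9c_ard_3"],
...     measurements=["nbart_red", "nbart_green", "nbart_blue"],
...     x=(153.35, 153.45),
...     y=(-27.55, -27.45),
...     time=("2021-01", "2021-03"),
...     min_gooddata=0.90,
...     output_crs="EPSG:3577",
...     resolution=(-30, 30),
...     group_by="solar_day",
... )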
"""
#########
# Setup #
#########
# Verify that products were provided
if not products:
raise ValueError(
"Please provide a list of product names to load data from. "
"Valid options are: ['ga_ls5t_ard_3', 'ga_ls7e_ard_3', "
"'ga_ls8c_ard_3', 'ga_ls9c_ard_3'] for Landsat, and "
"['ga_s2am_ard_3', 'ga_s2bm_ard_3'] for Sentinel 2."
)
# Determine whether products are all Landsat, all S2, or mixed
elif all(["ls" in product for product in products]):
product_type = "ls"
elif all(["s2" in product for product in products]):
product_type = "s2"
else:
product_type = "mixed"
warnings.warn(
"You have selected a combination of Landsat and Sentinel-2 "
"products. This can produce unexpected results as these "
"products use the same names for different spectral bands "
"(e.g. Landsat and Sentinel-2's 'nbart_swir_2'); use with "
"caution."
)
# Set contiguity band depending on `mask_contiguity`;
# "oa_nbart_contiguity" if True, False or "nbart",
# "oa_nbar_contiguity" if "nbar"
if mask_contiguity in (True, False, "nbart"):
contiguity_band = "oa_nbart_contiguity"
elif mask_contiguity == "nbar":
contiguity_band = "oa_nbar_contiguity"
else:
raise ValueError(
f"Unsupported value '{mask_contiguity}' passed to "
"`mask_contiguity`. Please provide either 'nbart', 'nbar', "
"True, or False."
)
# Set pixel quality (PQ) band depending on `cloud_mask`
if cloud_mask == "fmask":
pq_band = "oa_fmask"
pq_categories = fmask_categories
elif cloud_mask == "s2cloudless":
pq_band = "oa_s2cloudless_mask"
pq_categories = s2cloudless_categories
# Raise error if s2cloudless is requested for Landsat products
if product_type in ["ls", "mixed"]:
raise ValueError(
"The 's2cloudless' cloud mask is not available for "
"Landsat products. Please set `mask_pixel_quality` to "
"'fmask' or False."
)
else:
raise ValueError(
f"Unsupported value '{cloud_mask}' passed to "
"`cloud_mask`. Please provide either 'fmask', "
"'s2cloudless', True, or False."
)
# To ensure that the categorical PQ/contiguity masking bands are
# loaded using nearest neighbour resampling, we need to add these to
# the resampling kwarg if it exists and is not "nearest".
# This only applies if a string resampling method is supplied;
# if a resampling dictionary (e.g. `resampling={'*': 'bilinear',
# 'oa_fmask': 'mode'}` is passed instead we assume the user wants
# to select custom resampling methods for each of their bands.
resampling = kwargs.get("resampling", None)
if isinstance(resampling, str) and resampling not in (None, "nearest"):
kwargs["resampling"] = {
"*": resampling,
pq_band: "nearest",
contiguity_band: "nearest",
}
# We extract and deal with `dask_chunks` separately as every
# function call uses dask internally regardless of whether the user
# sets `dask_chunks` themselves
dask_chunks = kwargs.pop("dask_chunks", None)
# Create a list of requested measurements so that we can eventually
return only the measurements the user originally asked for
requested_measurements = kwargs.pop("measurements", None)
# Copy our measurements list so we can temporarily append extra PQ
# and/or contiguity masking bands when loading our data
measurements = requested_measurements.copy() if requested_measurements else None
# Deal with "load all" case: pick a set of bands that are common
# across requested products
if measurements is None:
measurements = _common_bands(dc, products)
# Deal with edge case where user supplies alias for PQ/contiguity
# by stripping PQ/contiguity masks of their "oa_" prefix
else:
contiguity_band = (
contiguity_band.replace("oa_", "")
if contiguity_band.replace("oa_", "") in measurements
else contiguity_band
)
pq_band = (
pq_band.replace("oa_", "")
if pq_band.replace("oa_", "") in measurements
else pq_band
)
# If `measurements` are specified but do not include PQ or
# contiguity variables, add these to `measurements`
if pq_band not in measurements:
measurements.append(pq_band)
if mask_contiguity and contiguity_band not in measurements:
measurements.append(contiguity_band)
# Get list of data and mask bands so that we can later exclude
# mask bands from being masked themselves
data_bands = [
band for band in measurements if band not in (pq_band, contiguity_band)
]
mask_bands = [band for band in measurements if band not in data_bands]
#################
# Find datasets #
#################
# Pull out query params only to pass to dc.find_datasets
query = _dc_query_only(**kwargs)
# If predicate is specified, use this function to filter the list
# of datasets prior to load
if predicate:
print(
"The 'predicate' parameter will be deprecated in future "
"versions of this function as this functionality has now "
"been added to Datacube itself. Please use "
"`dataset_predicate=...` instead."
)
query["dataset_predicate"] = predicate
# Extract list of datasets for each product using query params
dataset_list = []
# Get list of datasets for each product
print("Finding datasets")
for product in products:
# Obtain list of datasets for product
print(
f" {product} (ignoring SLC-off observations)"
if not ls7_slc_off and product == "ga_ls7e_ard_3"
else f" {product}"
)
datasets = dc.find_datasets(product=product, **query)
# Remove Landsat 7 SLC-off observations if ls7_slc_off=False
if not ls7_slc_off and product == "ga_ls7e_ard_3":
datasets = [
i
for i in datasets
if normalise_dt(i.time.begin) < datetime.datetime(2003, 5, 31)
]
# Add any returned datasets to list
dataset_list.extend(datasets)
# Raise exception if no datasets are returned
if len(dataset_list) == 0:
raise ValueError(
"No data available for query: ensure that "
"the products specified have data for the "
"time and location requested"
)
#############
# Load data #
#############
# Note we always load using dask here so that we can lazy load data
# before filtering by `min_gooddata`
ds = dc.load(
datasets=dataset_list,
measurements=measurements,
dask_chunks={} if dask_chunks is None else dask_chunks,
**kwargs,
)
####################
# Filter good data #
####################
# Calculate pixel quality mask
pq_mask = odc.algo.fmask_to_bool(ds[pq_band], categories=pq_categories)
# The good data percentage calculation has to load all pixel quality
# data, which can be slow. If the user has chosen no filtering
# by using the default `min_gooddata = 0`, we can skip this step
# completely to save processing time
if min_gooddata > 0.0:
# Compute good data for each observation as % of total pixels
print(f"Counting good quality pixels for each time step using {cloud_mask}")
data_perc = pq_mask.sum(axis=[1, 2], dtype="int32") / (
pq_mask.shape[1] * pq_mask.shape[2]
)
keep = (data_perc >= min_gooddata).persist()
# Filter by `min_gooddata` to drop low quality observations
total_obs = len(ds.time)
ds = ds.sel(time=keep)
pq_mask = pq_mask.sel(time=keep)
print(
f"Filtering to {len(ds.time)} out of {total_obs} "
f"time steps with at least {min_gooddata:.1%} "
f"good quality pixels"
)
# Morphological filtering on cloud masks
if (mask_filters is not None) & (mask_pixel_quality != False):
print(f"Applying morphological filters to pixel quality mask: {mask_filters}")
pq_mask = ~mask_cleanup(~pq_mask, mask_filters=mask_filters)
warnings.warn(
"As of `dea_tools` v0.3.0, pixel quality masks are "
"inverted before being passed to `mask_filters` (i.e. so "
"that good quality/clear pixels are False and poor quality "
"pixels/clouds are True). This means that 'dilation' will "
"now expand cloudy pixels, rather than shrink them as in "
"previous versions."
)
###############
# Apply masks #
###############
# Create a combined mask to hold both pixel quality and contiguity.
# This is more efficient than creating multiple dask tasks for
# similar masking operations.
mask = None
# Add pixel quality mask to combined mask
if mask_pixel_quality:
print(f"Applying {cloud_mask} pixel quality/cloud mask")
mask = pq_mask
# Add contiguity mask to combined mask
if mask_contiguity:
print(f"Applying contiguity mask ({contiguity_band})")
cont_mask = ds[contiguity_band] == 1
# If the mask already contains pixel quality data (i.e. if
# mask_pixel_quality=True above), multiply it with cont_mask to
# perform a logical 'and' (keeping only pixels good in both)
mask = cont_mask if mask is None else mask * cont_mask
# Split into data/masks bands, as conversion to float and masking
# should only be applied to data bands
ds_data = ds[data_bands]
ds_masks = ds[mask_bands]
# Mask data if either of the above masks were generated
if mask is not None:
ds_data = odc.algo.keep_good_only(ds_data, where=mask)
# Automatically set dtype to either native or float32 depending
# on whether masking was requested
if dtype == "auto":
dtype = "native" if mask is None else "float32"
# Set nodata values using odc.algo tools to reduce peak memory
# use when converting data dtype
if dtype != "native":
ds_data = odc.algo.to_float(ds_data, dtype=dtype)
# Put data and mask bands back together
attrs = ds.attrs
ds = xr.merge([ds_data, ds_masks])
ds.attrs.update(attrs)
###############
# Return data #
###############
# Drop bands not originally requested by user
if requested_measurements:
ds = ds[requested_measurements]
# If user supplied `dask_chunks`, return data as a dask array
# without actually loading it into memory
if dask_chunks is not None:
print(f"Returning {len(ds.time)} time steps as a dask array")
return ds
else:
print(f"Loading {len(ds.time)} time steps")
return ds.compute()
def mostcommon_crs(dc, product, query):
"""
Takes a given query and returns the most common CRS for observations
returned for that spatial extent. This can be useful when your study
area lies on the boundary of two UTM zones, forcing you to decide
which CRS to use for your `output_crs` in `dc.load`.
Parameters
----------
dc : datacube Datacube object
The Datacube to connect to, i.e. `dc = datacube.Datacube()`.
This allows you to also use development datacubes if required.
product : str
A product name (or list of product names) to load CRSs from.
query : dict
A datacube query including x, y and time range to assess for the
most common CRS
Returns
-------
epsg_string : str
An EPSG string giving the most common CRS from all datasets
returned by the query above
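Examples
--------
An illustrative sketch only (assumes a working DEA datacube
connection; coordinates and dates are placeholders):
>>> import datacube
>>> dc = datacube.Datacube(app="mostcommon_crs_example")
>>> query = {
...     "x": (153.35, 153.45),
...     "y": (-27.55, -27.45),
...     "time": ("2021-01", "2021-03"),
... }
>>> crs = mostcommon_crs(dc=dc, product="ga_ls8c_ard_3", query=query)
>>> ds = dc.load(
...     product="ga_ls8c_ard_3",
...     output_crs=crs,
...     resolution=(-30, 30),
...     group_by="solar_day",
...     **query,
... )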
"""
# Find list of datasets matching query for either product or
# list of products
if isinstance(product, list):
matching_datasets = []
for i in product:
matching_datasets.extend(dc.find_datasets(product=i, **query))
else:
matching_datasets = dc.find_datasets(product=product, **query)
# Extract all CRSs
crs_list = [str(i.crs) for i in matching_datasets]
# If CRSs are returned
if len(crs_list) > 0:
# Identify most common CRS
crs_counts = Counter(crs_list)
crs_mostcommon = crs_counts.most_common(1)[0][0]
# Warn user if multiple CRSs are encountered
if len(crs_counts.keys()) > 1:
warnings.warn(
f"Multiple UTM zones {list(crs_counts.keys())} "
f"were returned for this query. Defaulting to "
f"the most common zone: {crs_mostcommon}",
UserWarning,
)
return crs_mostcommon
else:
raise ValueError(
f"No CRS was returned as no data was found for "
f"the supplied product ({product}) and query. "
f"Please ensure that data is available for "
f"{product} for the spatial extents and time "
f"period specified in the query (e.g. by using "
f"the Data Cube Explorer for this datacube "
f"instance)."
)
def download_unzip(url, output_dir=None, remove_zip=True):
"""
Downloads and unzips a .zip file from an external URL to a local
directory.
Parameters
----------
url : str
A string giving a URL path to the zip file you wish to download
and unzip
output_dir : str, optional
An optional string giving the directory to unzip files into.
Defaults to None, which will unzip files in the current working
directory
remove_zip : bool, optional
An optional boolean indicating whether to remove the downloaded
.zip file after files are unzipped. Defaults to True, which will
delete the .zip file.
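Examples
--------
An illustrative sketch only; the URL below is a placeholder for
any publicly accessible .zip file:
>>> download_unzip(
...     url="https://example.com/data/my_data.zip",
...     output_dir="downloaded_data",
... )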
"""
# Get basename for zip file
zip_name = os.path.basename(url)
# Raise exception if the file is not of type .zip
if not zip_name.endswith(".zip"):
raise ValueError(
f"The URL provided does not point to a .zip "
f"file (e.g. {zip_name}). Please specify a "
f"URL path to a valid .zip file"
)
# Download zip file
print(f"Downloading {zip_name}")
r = requests.get(url)
with open(zip_name, "wb") as f:
f.write(r.content)
# Extract into output_dir
with zipfile.ZipFile(zip_name, "r") as zip_ref:
zip_ref.extractall(output_dir)
print(
f"Unzipping output files to: "
f"{output_dir if output_dir else os.getcwd()}"
)
# Optionally cleanup
if remove_zip:
os.remove(zip_name)
def wofs_fuser(dest, src):
"""
Fuse two WOfS water measurements represented as ``ndarray``s.
Note: this is a copy of the function located here:
https://github.com/GeoscienceAustralia/digitalearthau/blob/develop/digitalearthau/utils.py
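Examples
--------
An illustrative sketch only: `wofs_fuser` is typically passed to
``dc.load`` via ``fuse_func`` when loading DEA Water Observations
data grouped by solar day. Assumes a working DEA datacube
connection with the ``ga_ls_wo_3`` product indexed; coordinates
and dates are placeholders.
>>> import datacube
>>> dc = datacube.Datacube(app="wofs_fuser_example")
>>> ds_wo = dc.load(
...     product="ga_ls_wo_3",
...     x=(153.35, 153.45),
...     y=(-27.55, -27.45),
...     time=("2021-01", "2021-03"),
...     output_crs="EPSG:3577",
...     resolution=(-30, 30),
...     group_by="solar_day",
...     fuse_func=wofs_fuser,
... )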
"""
empty = (dest & 1).astype(bool)
both = ~empty & ~((src & 1).astype(bool))
dest[empty] = src[empty]
dest[both] |= src[both]
def dilate(array, dilation=10, invert=True):
"""
Dilate a binary array by a specified number of pixels using a
disk-like radial dilation.
By default, invalid (e.g. False or 0) values are dilated. This is
suitable for applications such as cloud masking (e.g. creating a
buffer around cloudy or shadowed pixels). This functionality can
be reversed by specifying `invert=False`.
Parameters
----------
array : array
The binary array to dilate.
dilation : int, optional
An optional integer specifying the number of pixels to dilate
by. Defaults to 10, which will dilate `array` by 10 pixels.
invert : bool, optional
An optional boolean specifying whether to invert the binary
array prior to dilation. The default is True, which dilates the
invalid values in the array (e.g. False or 0 values).
Returns
-------
An array of the same shape as `array`, with pixels dilated by the
number of pixels specified by `dilation`. By default, the array's
invalid (e.g. False or 0) pixels are the ones that are dilated.
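Examples
--------
A minimal, self-contained sketch using a small synthetic 3D mask
(True = valid, False = invalid); by default the invalid pixels are
buffered outwards:
>>> import numpy as np
>>> mask = np.ones((1, 5, 5), dtype=bool)
>>> mask[0, 2, 2] = False  # single invalid (e.g. cloudy) pixel
>>> buffered = dilate(mask, dilation=1)  # invalid pixel grows to a 3x3 block
>>> buffered.shape
(1, 5, 5)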
"""
y, x = np.ogrid[
-dilation : (dilation + 1),
-dilation : (dilation + 1),
]
# disk-like radial dilation
kernel = (x * x) + (y * y) <= (dilation + 0.5) ** 2
# If invert=True, invert True values to False etc
if invert:
array = ~array
return ~binary_dilation(
array.astype(bool), structure=kernel.reshape((1,) + kernel.shape)
)
def paths_to_datetimeindex(paths, string_slice=(0, 10)):
"""
Helper function to generate a Pandas datetimeindex object
from dates contained in a file path string.
Parameters
----------
paths : list of strings
A list of file path strings that will be used to extract times
string_slice : tuple
An optional tuple giving the start and stop position that
contains the time information in the provided paths. These are
applied to the basename (i.e. file name) in each path, not the
path itself. Defaults to (0, 10).
Returns
-------
datetime : pandas.DatetimeIndex
A pandas.DatetimeIndex object containing 'datetime64[ns]' values
derived from the file paths provided by `paths`.
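Examples
--------
A minimal sketch using hypothetical file paths whose file names
begin with a date:
>>> paths = [
...     "/data/2021-01-08_nbart_red.tif",
...     "/data/2021-01-24_nbart_red.tif",
... ]
>>> paths_to_datetimeindex(paths)
DatetimeIndex(['2021-01-08', '2021-01-24'], dtype='datetime64[ns]', freq=None)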
"""
date_strings = [os.path.basename(i)[slice(*string_slice)] for i in paths]
return pd.to_datetime(date_strings)
def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
sl = other_ind[:axis] + (idx,) + other_ind[axis:]
return values[sl]
def first(array: xr.DataArray, dim: str, index_name: str = None) -> xr.DataArray:
"""
Finds the first occurring non-null value along the given dimension.
Parameters
----------
array : xr.DataArray
The array to search.
dim : str
The name of the dimension to reduce by finding the first
non-null value.
index_name : str, optional
If given, the name of a coordinate to be added containing the
index of where on the dimension the first non-null value was
found.
Returns
-------
reduced : xr.DataArray
An array of the first non-null values.
The `dim` dimension will be removed, and replaced with a coord
of the same name, containing the value of that dimension where
the first value was found.
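Examples
--------
A minimal, self-contained sketch:
>>> import numpy as np
>>> import xarray as xr
>>> da = xr.DataArray(
...     [[np.nan, 5.0], [2.0, np.nan], [3.0, 4.0]],
...     dims=("time", "x"),
...     coords={"time": [10, 20, 30]},
... )
>>> first(da, dim="time").values
array([2., 5.])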
"""
axis = array.get_axis_num(dim)
idx_first = np.argmax(~pd.isnull(array), axis=axis)
reduced = array.reduce(_select_along_axis, idx=idx_first, axis=axis)
reduced[dim] = array[dim].isel({dim: xr.DataArray(idx_first, dims=reduced.dims)})
if index_name is not None:
reduced[index_name] = xr.DataArray(idx_first, dims=reduced.dims)
return reduced
def last(array: xr.DataArray, dim: str, index_name: str = None) -> xr.DataArray:
"""
Finds the last occurring non-null value along the given dimension.
Parameters
----------
array : xr.DataArray
The array to search.
dim : str
The name of the dimension to reduce by finding the last non-null
value.
index_name : str, optional
If given, the name of a coordinate to be added containing the
index of where on the dimension the last non-null value was found.
Returns
-------
reduced : xr.DataArray
An array of the last non-null values.
The `dim` dimension will be removed, and replaced with a coord
of the same name, containing the value of that dimension where
the last value was found.
"""
axis = array.get_axis_num(dim)
rev = (slice(None),) * axis + (slice(None, None, -1),)
idx_last = -1 - np.argmax(~pd.isnull(array)[rev], axis=axis)
reduced = array.reduce(_select_along_axis, idx=idx_last, axis=axis)
reduced[dim] = array[dim].isel({dim: xr.DataArray(idx_last, dims=reduced.dims)})
if index_name is not None:
reduced[index_name] = xr.DataArray(idx_last, dims=reduced.dims)
return reduced
def nearest(
array: xr.DataArray, dim: str, target, index_name: str = None
) -> xr.DataArray:
"""
Finds the nearest values to a target label along the given
dimension, for all other dimensions.
E.g. For a DataArray with dimensions ('time', 'x', 'y')
nearest_array = nearest(array, 'time', '2017-03-12')
will return an array with the dimensions ('x', 'y'), with non-null
values found closest for each (x, y) pixel to that location along
the time dimension.
The returned array will include the 'time' coordinate for each x,y
pixel that the nearest value was found.
Parameters
----------
array : xr.DataArray
The array to search.
dim : str
The name of the dimension to look for the target label.
target : same type as array[dim]
The value to look up along the given dimension.
index_name : str, optional
If given, the name of a coordinate to be added containing the
index of where on the dimension the nearest value was found.
Returns
-------
nearest_array : xr.DataArray
An array of the nearest non-null values to the target label.
The `dim` dimension will be removed, and replaced with a coord
of the same name, containing the value of that dimension closest
to the given target label.
"""
before_target = slice(None, target)
after_target = slice(target, None)
da_before = array.sel({dim: before_target})
da_after = array.sel({dim: after_target})
da_before = last(da_before, dim, index_name) if da_before[dim].shape[0] else None
da_after = first(da_after, dim, index_name) if da_after[dim].shape[0] else None
if da_before is None and da_after is not None:
return da_after
if da_after is None and da_before is not None:
return da_before
target = array[dim].dtype.type(target)
is_before_closer = abs(target - da_before[dim]) < abs(target - da_after[dim])
nearest_array = xr.where(is_before_closer, da_before, da_after, keep_attrs=True)
nearest_array[dim] = xr.where(
is_before_closer, da_before[dim], da_after[dim], keep_attrs=True
)
if index_name is not None:
nearest_array[index_name] = xr.where(
is_before_closer,
da_before[index_name],
da_after[index_name],
keep_attrs=True,
)
return nearest_array
def parallel_apply(ds, dim, func, use_threads=False, *args, **kwargs):
"""
Applies a custom function in parallel along the dimension of an
xarray.Dataset or xarray.DataArray.
The function can be any function that can be applied to an
individual xarray.Dataset or xarray.DataArray (e.g. data for a
single timestep). The function should also return data in
xarray.Dataset or xarray.DataArray format.
This function is useful as a simple method for parallelising code
that cannot easily be parallelised using Dask.
Parameters
----------
ds : xarray.Dataset or xarray.DataArray
xarray data with a dimension `dim` to apply the custom function
along.
dim : string
The dimension along which the custom function will be applied.
func : function
The function that will be applied in parallel to each array
along dimension `dim`. The first argument passed to this
function should be the array along `dim`.
use_threads : bool, optional
Whether to use threads instead of processes for parallelisation.
Defaults to False, which means multiprocessing is used.
In brief, the difference between threads and processes is that threads
share memory, while processes have separate memory.
*args :
Any number of arguments that will be passed to `func`.
**kwargs :
Any number of keyword arguments that will be passed to `func`.
Returns
-------
xarray.Dataset
A concatenated dataset containing an output for each array
along the input `dim` dimension.
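Examples
--------
An illustrative sketch only: applies a simple custom NDVI function
to each timestep of a previously loaded dataset ``ds`` containing
``nbart_red`` and ``nbart_nir`` bands (band names assume DEA ARD
data). Note that with processes (``use_threads=False``), ``func``
must be picklable, i.e. defined at the top level of a module.
>>> def ndvi(ds_i):
...     return (ds_i.nbart_nir - ds_i.nbart_red) / (
...         ds_i.nbart_nir + ds_i.nbart_red
...     )
>>> ndvi_da = parallel_apply(ds, dim="time", func=ndvi, use_threads=True)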
"""
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial
from itertools import repeat
from tqdm import tqdm
# Use threads or processes
if use_threads:
Executor = ThreadPoolExecutor
else:
Executor = ProcessPoolExecutor
with Executor() as executor:
# Update func to add kwargs
func = partial(func, **kwargs)
# Apply func in parallel
groups = [group for (i, group) in ds.groupby(dim)]
to_iterate = (groups, *(repeat(i, len(groups)) for i in args))
out_list = list(tqdm(executor.map(func, *to_iterate), total=len(groups)))
# Combine to match the original dataset
return xr.concat(out_list, dim=ds[dim])
def _apply_weights(da, band_weights):
"""
Apply weights from a dictionary to the bands of a
multispectral xarray.DataArray. Raises a ValueError if any
bands in `da` are not present in the `band_weights` dictionary.
Parameters
----------
da : xarray.DataArray object
DataArray containing multispectral data. The dataarray
should contain a "variable" dimension that corresponds
to the different bands of the data.
band_weights : dict
Mapping of band names to weights to be applied. The keys
of the dictionary should be the names of the bands in the
"variable" dimension of `da`, and the values should be
the weights to be applied to each band.
Returns
-------
xarray.DataArray object
DataArray with weights applied to the bands.
"""
# Identify any bands without weights, and raise an
# error if they exist
bands_without_weights = set(da["variable"].values) - set(band_weights.keys())
if len(bands_without_weights) > 0:
raise ValueError(
f"The following multispectral bands are missing from the "
f"`band_weights` dictionary: {bands_without_weights}.\n"
f"Ensure that weights are supplied for all multispectral "
f"bands in `ds`, or set `band_weights=None`."
)
# Create xr.DataArray with weights for each variable
# along the "variable" dimension
weights_da = xr.DataArray(
data=list(band_weights.values()),
coords={"variable": list(band_weights.keys())},
dims="variable",
)
# Apply weights
return da.weighted(weights_da)
def _brovey_pansharpen(ds, pan_band, band_weights=None):
"""
Perform pansharpening on multiple timesteps of a multispectral
dataset using the Brovey transform (with optional per-band weights).
Source: https://pro.arcgis.com/en/pro-app/latest/help/analysis/
raster-functions/fundamentals-of-pan-sharpening-pro.htm
Parameters
----------
ds : xarray.Dataset
Dataset containing multispectral and panchromatic bands.
pan_band : str
Name of the panchromatic band in the dataset.
band_weights : dict, optional
Mapping of band names to weights to be applied to each band when
calculating the sum of all multispectral bands. The keys of
the dictionary should be the names of the bands, and the values
should be the weights to apply to each band, e.g.:
``{"nbart_red": 0.4, "nbart_green": 0.4, "nbart_blue": 0.2}``.
The default accounts for Landsat 8 and 9's pan band only
partially overlapping with the blue band; this may not be
suitable for all applications. Setting `band_weights=None`
will use a simple unweighted sum.
Returns
-------
ds_pansharpened : xarray.Dataset
Pansharpened dataset with the same dimensions as the input dataset.
"""
# Create new dataarrays with and without pan band
da_nopan = ds.drop(pan_band).to_array()
da_pan = ds[pan_band]
# Calculate weighted sum
if band_weights is not None:
da_total = _apply_weights(da_nopan, band_weights).sum(dim="variable")
else:
da_total = da_nopan.sum(dim="variable")
# Perform Brovey Transform in form of: band / total * panchromatic
da_pansharpened = da_nopan / da_total * da_pan
ds_pansharpened = da_pansharpened.to_dataset("variable")
return ds_pansharpened
def _esri_pansharpen(ds, pan_band, band_weights=None):
"""
Perform pansharpening on multiple timesteps of a multispectral
dataset using the ESRI transform (with optional per-band weights).
Source: https://pro.arcgis.com/en/pro-app/latest/help/analysis/
raster-functions/fundamentals-of-pan-sharpening-pro.htm
Parameters
----------
ds : xarray.Dataset
Dataset containing multispectral and panchromatic bands.
pan_band : str
Name of the panchromatic band in the dataset.
band_weights : dict, optional
Mapping of band names to weights to be applied to each band when
calculating the mean of all multispectral bands. The keys of
the dictionary should be the names of the bands, and the values
should be the weights to apply to each band, e.g.:
``{"nbart_red": 0.4, "nbart_green": 0.4, "nbart_blue": 0.2}``.
The default accounts for Landsat 8 and 9's pan band only
partially overlapping with the blue band; this may not be
suitable for all applications. Setting `band_weights=None`
will use a simple unweighted mean.
Returns
-------
ds_pansharpened : xarray.Dataset
Pansharpened dataset with the same dimensions as the input dataset.
"""
# Create new dataarrays with and without pan band
da_nopan = ds.drop(pan_band).to_array()
da_pan = ds[pan_band]
# Calculate weighted mean
if band_weights is not None:
da_mean = _apply_weights(da_nopan, band_weights).mean(dim="variable")
else:
da_mean = da_nopan.mean(dim="variable")
# Calculate adjustment and apply to multispectral bands
adj = da_pan - da_mean
da_pansharpened = da_nopan + adj
ds_pansharpened = da_pansharpened.to_dataset("variable")
return ds_pansharpened
def _simple_mean_pansharpen(ds, pan_band):
"""
Perform pansharpening on multiple timesteps of a multispectral
dataset using the Simple Mean transform.
Source: https://pro.arcgis.com/en/pro-app/latest/help/analysis/
raster-functions/fundamentals-of-pan-sharpening-pro.htm
Parameters
----------
ds : xarray.Dataset
Dataset containing multispectral and panchromatic bands.
pan_band : str
Name of the panchromatic band in the dataset.
Returns
-------
ds_pansharpened : xarray.Dataset
Pansharpened dataset with the same dimensions as the input dataset.
"""
# Create new dataarrays with and without pan band
ds_nopan = ds.drop(pan_band)
da_pan = ds[pan_band]
# Take mean of pan band and RGBs
ds_pansharpened = (ds_nopan + da_pan) / 2.0
return ds_pansharpened
def _hsv_timestep_pansharpen(ds_i, pan_band):
"""
Perform pansharpening on a single timestep of a multispectral
dataset using the Hue Saturation Value (HSV) transform.
Parameters
----------
ds : xarray.Dataset
Dataset containing multispectral and panchromatic bands.
pan_band : str
Name of the panchromatic band in the dataset.
Returns
-------
ds_pansharpened : xarray.Dataset
Pansharpened dataset with the same dimensions as the input dataset.
"""
# Squeeze out any single dimensions
ds_i = ds_i.squeeze()
# Convert to an xr.DataArray and move "variable" to end
da_i = ds_i.to_array().transpose(..., "variable")
# Create new dataarrays with and without pan band
da_i_nopan = da_i.drop(pan_band, dim="variable")
da_i_pan = da_i.sel(variable=pan_band)
# Convert to HSV colour space
hsv = rgb2hsv(da_i_nopan)
# Replace value (lightness) channel with pan band data
hsv[:, :, 2] = da_i_pan.values
# Convert back to RGB colour space
pansharpened_array = hsv2rgb(hsv)
# Insert the pansharpened values back into the original array
# and return as an xarray.Dataset
da_i_nopan[:] = pansharpened_array
ds_pansharpened = da_i_nopan.to_dataset("variable")
return ds_pansharpened
def _pca_timestep_pansharpen(ds_i, pan_band, pca_rescaling="histogram"):
"""
Perform pansharpening on a single timestep of a multispectral
dataset using the principal component analysis (PCA) transform.
Parameters
----------
ds : xarray.Dataset
Dataset containing multispectral and panchromatic bands.
pan_band : str
Name of the panchromatic band in the dataset.
pca_rescaling : str, optional
Method to use for rescaling pan band to more closely match the
distribution of values in the first PCA component. "simple"
scales the pan band values to more closely match the first PCA
component by subtracting the mean of the pan band values from
each value, scaling the resulting values by the ratio of the
standard deviations of the first PCA component and the pan band,
and adding back the mean of the first PCA component.
"histogram" uses a histogram matching technique to adjust the
pan band values so that the resulting histogram more closely
matches the histogram of the first PCA component.
Returns
-------
ds_pansharpened : xarray.Dataset
Pansharpened dataset with the same dimensions as the input dataset.
"""
# Squeeze out any single dimensions
ds_i = ds_i.squeeze()
# Reshape to 2D by stacking x and y dimensions to prepare it
# as an input to PCA. Drop NA rows as these are not supported
# by `pca.fit_transform`.
da_2d = (
ds_i.to_array()
.stack(pixel=("y", "x"))
.transpose("pixel", "variable")
.dropna(dim="pixel")
)
# Create new dataarrays with and without pan band
da_2d_nopan = da_2d.drop(pan_band, dim="variable")
da_2d_pan = da_2d.sel(variable=pan_band)
# Apply PCA transformation
pca = sklearn.decomposition.PCA()
pca_array = pca.fit_transform(da_2d_nopan)
# Rescale pan band to more closely match the first PCA component
if pca_rescaling == "simple":
pca_array[:, 0] = (da_2d_pan.values - da_2d_pan.values.mean()) * (
pca_array[:, 0].std() / da_2d_pan.values.std()
) + pca_array[:, 0].mean()
elif pca_rescaling == "histogram":
pca_array[:, 0] = match_histograms(da_2d_pan.values, pca_array[:, 0])
# Apply reverse PCA transform to restore multispectral array
pansharpened_array = pca.inverse_transform(pca_array)
# Insert the pansharpened values back into the original array,
# unstack to 2D and return as an xarray.Dataset
da_2d_nopan[:] = pansharpened_array
ds_pansharpened = da_2d_nopan.unstack("pixel").to_dataset("variable")
return ds_pansharpened
def xr_pansharpen(
ds,
transform,
pan_band="nbart_panchromatic",
return_pan=False,
output_dtype=None,
parallelise=False,
band_weights={"nbart_red": 0.4, "nbart_green": 0.4, "nbart_blue": 0.2},
pca_rescaling="histogram",
):
"""
Apply pan-sharpening to multispectral satellite data with one
or more timesteps. The following pansharpening transforms are
currently supported:
- Brovey ("brovey"), with optional band weighting
- ESRI ("esri"), with optional band weighting
- Simple mean ("simple mean")
- PCA ("pca")
- HSV ("hsv"), similar to IHS
Note: Pan-sharpening transforms do not necessarily maintain
the spectral integrity of the input satellite data, and may
be more suitable for visualisation than quantitative work.
Parameters
----------
ds : xarray.Dataset
An xarray dataset containing the three input multispectral
bands, and a panchromatic band. This dataset should have
already been resampled to the spatial resolution of the
panchromatic band (15 m for Landsat). Due to differences in
the electromagnetic spectrum covered by the panchromatic band,
Landsat 8 and 9 data should be supplied with 'blue', 'green',
and 'red' multispectral bands, while Landsat 7 should be
supplied with 'green', 'red' and 'NIR'.
transform : string
The pansharpening transform to apply to the data. Valid options
include "brovey", "esri", "simple mean", "pca", "hsv".
pan_band : string, optional
The name of the panchromatic band that will be used to
pansharpen the multispectral data.
return_pan : bool, optional
Whether to return the panchromatic band in the output dataset.
Defaults to False.
output_dtype : string or numpy.dtype, optional
The dtype used for the output values. Defaults to the input
dtype of the multispectral bands in `ds`.
parallelise : bool, optional
Whether to parallelise transformations across multiple cores.
Used for PCA and HSV transforms that are applied to each
timestep in `ds` individually; defaults to False.
band_weights : dict, optional
Used for the Brovey and ESRI transforms. Mapping of band
names to weights to be applied to each band when calculating
the sum (Brovey) or mean (ESRI) of all multispectral bands.
The keys of the dictionary should be the names of the bands,
and the values should be the weights to apply to each band, e.g.:
``{"nbart_red": 0.4, "nbart_green": 0.4, "nbart_blue": 0.2}``.
The default accounts for Landsat 8 and 9's pan band only
partially overlapping with the blue band; this may not be
suitable for all applications. Setting `band_weights=None`
will use a simple unweighted sum (for the Brovey transform)
or unweighted mean (for the ESRI transform).
pca_rescaling : str, optional
Used for the PCA transform. The method to use for rescaling
pan band to more closely match the distribution of values
in the first PCA component. "simple" scales the pan band
values to more closely match the first PCA component by
subtracting the mean of the pan band values from each value,
scaling the resulting values by the ratio of the standard
deviations of the first PCA component and the pan band, and
adding back the mean of the first PCA component.
"histogram" uses a histogram matching technique to adjust the
pan band values so that the resulting histogram more closely
matches the histogram of the first PCA component.
Returns
-------
ds_pansharpened : xarray.Dataset
An xarray dataset containing the three pansharpened input
multispectral bands and optionally the panchromatic band
(if `return_pan=True`).
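Examples
--------
An illustrative sketch only: loads Landsat 8 RGB and panchromatic
bands at the 15 m resolution of the pan band using `load_ard`
(assumes a working DEA datacube connection; coordinates and dates
are placeholders), then applies Brovey pansharpening:
>>> import datacube
>>> dc = datacube.Datacube(app="xr_pansharpen_example")
>>> ds = load_ard(
...     dc=dc,
...     products=["ga_ls8c_ard_3"],
...     measurements=[
...         "nbart_red",
...         "nbart_green",
...         "nbart_blue",
...         "nbart_panchromatic",
...     ],
...     x=(153.35, 153.45),
...     y=(-27.55, -27.45),
...     time=("2021-01", "2021-03"),
...     output_crs="EPSG:3577",
...     resolution=(-15, 15),
...     group_by="solar_day",
... )
>>> ds_pansharpened = xr_pansharpen(ds, transform="brovey")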
"""
# Assert whether pan band exists in the dataset
if pan_band not in ds.data_vars:
raise ValueError(
f"The specified panchromatic band '{pan_band}' cannot be found in `ds`. "
f"Specify a panchromatic band name that exists in the dataset using `pan_band=...`."
)
# Assert whether exactly three multispectral bands are included in `ds`
n_multi = len(ds.drop(pan_band).data_vars)
if n_multi != 3:
raise ValueError(
f"`ds` should contain exactly three multispectral bands (not "
f"including the panchromatic band). However, {n_multi} "
f"multispectral bands were found: {list(ds.drop(pan_band).data_vars)}. "
)
# Define dict linking functions to each transform
transform_dict = {
"brovey": _brovey_pansharpen,
"esri": _esri_pansharpen,
"simple mean": _simple_mean_pansharpen,
"pca": _pca_timestep_pansharpen,
"hsv": _hsv_timestep_pansharpen,
}
# If Brovey, ESRI or Simple Mean pansharpening is specified, apply to
# entire `xr.Dataset` in one go (with optional weights for Brovey, ESRI)
if transform in ("brovey", "esri", "simple mean"):
print(f"Applying {transform.capitalize()} pansharpening")
extra_params = (
{"band_weights": band_weights} if transform in ("brovey", "esri") else {}
)
ds_pansharpened = transform_dict[transform](
ds,
pan_band=pan_band,
**extra_params,
)
# Otherwise, apply PCA or HSV pansharpening to each
# timestep in the `xr.Dataset` using `.apply`
elif transform in ("pca", "hsv"):
extra_params = {"pca_rescaling": pca_rescaling} if transform == "pca" else {}
# Apply pansharpening to all timesteps in data in parallel
if ("time" in ds.dims) and parallelise:
print(f"Applying {transform.upper()} pansharpening in parallel")
ds_pansharpened = parallel_apply(
ds,
"time",
transform_dict[transform],
pan_band,
*extra_params.values(), # TODO: Update once `parallel_apply` supports kwargs
)
# Apply pansharpening to all timesteps in data sequentially
elif ("time" in ds.dims) and not parallelise:
print(f"Applying {transform.upper()} pansharpening")
ds_pansharpened = ds.groupby("time").apply(
transform_dict[transform],
pan_band=pan_band,
**extra_params,
)
# Otherwise, apply func directly if only one timestep
else:
print(f"Applying {transform.upper()} pansharpening")
ds_pansharpened = transform_dict[transform](
ds, pan_band=pan_band, **extra_params
)
else:
raise ValueError(
f"Unsupported value '{transform}' passed to `method`. Please "
f"provide one of {list(transform_dict.keys())}."
)
# Optionally insert pan band back into dataset
if return_pan:
ds_pansharpened[pan_band] = ds[pan_band]
# Return data in original or requested dtype
return ds_pansharpened.astype(
ds.to_array().dtype if output_dtype is None else output_dtype
)
def load_reproject(
path,
how,
resolution="auto",
tight=False,
resampling="nearest",
chunks={"x": 2048, "y": 2048},
bands=None,
masked=True,
reproject_kwds=None,
**kwargs,
):
"""
Load and reproject part of a raster dataset into a given GeoBox or
custom CRS/resolution.
Parameters
----------
path : str
Path to the raster dataset to be loaded and reprojected.
how : GeoBox, str or int
How to reproject the raster. Can be a GeoBox or a CRS (e.g.
"ESPG:XXXX" string or integer).
resolution : str or int, optional
The resolution to reproject the raster dataset into if `how` is
a CRS, by default "auto". Supports:
- "same" use exactly the same resolution as the input raster
- "fit" use center pixel to determine required scale change
- "auto" uses the same resolution on the output if CRS units
are the same between source and destination; otherwise "fit"
- Else, a specific resolution in the units of the output crs
tight : bool, optional
By default the output pixel grid is adjusted to align pixel edges
to the X/Y axes; supplying `tight=True` produces an unaligned geobox.
resampling : str, optional
Resampling method to use when reprojecting data, by default
"nearest", supports all standard GDAL options ("average",
"bilinear", "min", "max", "cubic" etc).
chunks : dict, optional
The size of the Dask chunks to load the data with, by default
{"x": 2048, "y": 2048}.
bands : str or list, optional
Bands to optionally filter to when loading data.
masked : bool, optional
Whether to mask the data by its nodata value, by default True.
reproject_kwds : dict, optional
Additional keyword arguments to pass to the `.odc.reproject()`
method, by default None.
**kwargs : dict
Additional keyword arguments to be passed to the
`rioxarray.open_rasterio` function.
Returns
-------
xarray.Dataset
The reprojected raster dataset.
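Examples
--------
An illustrative sketch only; the file path below is a placeholder
for any local or cloud-hosted raster readable by `rioxarray`:
>>> da = load_reproject(
...     path="my_raster.tif",
...     how="EPSG:3577",
...     resolution=30,
...     resampling="bilinear",
... )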
"""
# Use empty kwds if not provided
reproject_kwds = {} if reproject_kwds is None else reproject_kwds
# Load data with rasterio
da = rioxarray.open_rasterio(
filename=path,
masked=masked,
chunks=chunks,
**kwargs,
)
# Optionally filter to bands
if bands is not None:
da = da.sel(band=bands)
# Reproject into GeoBox
da = da.odc.reproject(
how=how,
resolution=resolution,
tight=tight,
resampling=resampling,
dst_nodata=np.nan if masked else None,
**reproject_kwds,
)
# Squeeze if only one band
da = da.squeeze()
return da