Collecting large-scale training and validation datasets from the ODC

Background

In many Earth observation workflows, we need to extract satellite measurements from sparsely distributed geographic locations, often represented as points or polygons. These locations may come from field observations, environmental monitoring programs, or labelled training/validation datasets. Unlike typical raster processing tasks (which operate on continuous imagery), these tasks require selectively sampling many small, irregular areas across large spatial and temporal extents.

The collect_training_data function exists to streamline and scale this process within the context of the Open Data Cube (ODC). You should use this function for:

  • Generating training datasets for machine‑learning models, where each input geometry represents an example associated with a class label.

  • Gathering validation samples, both for ML workflows and for classical remote sensing algorithms.

  • Pairing field observations with satellite data, such as matching on-ground measurements of vegetation, soil, biodiversity, water quality, or land condition with satellite-derived features.

Instead of manually writing loops to load data for each location, handle varying geometry shapes, set up complex parallelisations, or manage errors during I/O operations, collect_training_data automates these tasks and provides built‑in support for:

  • parallel processing,

  • zonal statistics over polygons,

  • per-sample time ranges,

  • retry logic for unstable reads on cloud platforms,

  • data cleaning and stacking into a uniform table.

This makes it well‑suited for workflows where the end product is tabular data, such as training and validating ML models like random forests or gradient boosting.

In summary, collect_training_data is built for large‑scale sampling of ODC satellite data where locations are numerous and sparse, and the desired output is a structured table of feature values.

A Note on Limitations

  • Because collect_training_data is designed to produce tabular feature vectors, it is not well‑suited for machine‑learning workflows that require image chips or multi‑dimensional inputs, such as convolutional neural networks (CNNs), U‑Nets, or segmentation models that operate directly on spatial image patches. Those workflows require preserving spatial structure, whereas collect_training_data intentionally collapses spatial information into per‑sample feature values.

  • Large, highly parallelised extraction workflows can be fragile, especially those that run for hours. It can be frustrating to have a long‑running job fail near the end, resulting in no usable output. For very large datasets (roughly more than ~20,000 samples), or for cases where each sample is expensive to load (e.g. pulling several years of satellite observations per geometry), it is often safer and more reliable to break the task into smaller batches and merge the intermediate results afterwards.
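The batching strategy suggested above can be sketched as follows. Here extract_batch is a hypothetical stand-in for a call to collect_training_data on a subset of the GeoDataFrame; the chunking, checkpointing, and merging logic is the part that matters.

```python
import numpy as np
import pandas as pd

def extract_batch(batch):
    # Hypothetical stand-in for collect_training_data(gdf=batch, ...);
    # here it just returns one dummy feature per input row.
    return pd.DataFrame({"feature": np.arange(len(batch))}, index=batch.index)

# Toy stand-in for a GeoDataFrame of 10 samples
samples = pd.DataFrame({"class": [0, 1] * 5})

batch_size = 4
results = []
for start in range(0, len(samples), batch_size):
    batch = samples.iloc[start : start + batch_size]
    result = extract_batch(batch)
    # In a real workflow, save each batch to disk here, e.g.
    # result.to_csv(f"batch_{start}.csv"), so that a late failure
    # does not lose the earlier batches.
    results.append(result)

# Merge the intermediate results into one table
df = pd.concat(results)
print(df.shape)  # (10, 1)
```

Saving each intermediate result to disk before moving on means a failure part-way through only costs you the current batch.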

Description

In this notebook we will learn how to use the functionalities of collect_training_data by moving through the following sections:

  • Discussing how collect_training_data works and what its outputs are.

  • Discussing how to define a feature_func to pass to collect_training_data.

  • Defining a simple feature_func and running it in serial and parallel.

  • Applying zonal statistics to polygon samples.

  • Defining a complex feature_func and running it in parallel.

  • How to use the time_field parameter to allow loading from different time ranges for every sample.

  • A discussion on retry queues and the parameters available to configure the retry queue.

This notebook is intended as a reference for using the collect_training_data function. If you would like to see an implementation of this function within an end-to-end ML workflow, see the in-depth exploration in the Scalable Machine Learning series of notebooks.

Getting started

To run this analysis, run all the cells in the notebook, starting with the “Load packages” cell.

Load packages

[1]:
import datacube
import xarray as xr
import geopandas as gpd
import odc.geo.xr

import sys
sys.path.insert(1, "../Tools/")
from dea_tools.classification import collect_training_data
from dea_tools.datahandling import load_ard, load_reproject

Analysis parameters

  • path: The path to the input vector file from which we will extract training data. A default geojson is provided.

  • field: The name of the column in your vector attribute table that contains the class labels. The class labels must be integers.

  • ncpus: The number of CPU cores to use when parallelising the data collection.

[2]:
path = "../Real_world_examples/Scalable_machine_learning/data/crop_training_WA.geojson"
field = "class"
ncpus = 2

Preview input data

We can load and preview our input data vector using geopandas. The geojson should contain a column with class labels (e.g. ‘class’).

The class labels must be represented by integers.
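If your vector file stores labels as strings rather than integers, one simple pre-processing step (not part of collect_training_data itself; the column names here are hypothetical) is pandas.factorize:

```python
import pandas as pd

# Toy attribute table with string class labels
gdf = pd.DataFrame({"class_name": ["crop", "non-crop", "crop", "crop"]})

# Encode the string labels as integers in a new 'class' column;
# factorize also returns the mapping from codes back to names
gdf["class"], label_names = pd.factorize(gdf["class_name"])
print(gdf["class"].tolist())  # [0, 1, 0, 0]
print(list(label_names))      # ['crop', 'non-crop']
```

Keeping label_names around lets you map predictions back to the original class names later.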

[3]:
# Load input data vector file
input_data = gpd.read_file(path)

# Print data
input_data.head(3)
[3]:
class geometry
0 1 POINT (116.60407 -31.46883)
1 1 POINT (117.03464 -32.40830)
2 1 POINT (117.30838 -32.33747)

Our point dataset has 430 points. Throughout this demo, we will take a small sample of these points when running functions to speed up the run times.

[4]:
input_data = input_data.sample(n=6, random_state=1).reset_index(drop=True)

Using “collect_training_data”

The function collect_training_data provides a flexible and scalable way to extract samples from an Open Data Cube (ODC) based on geometries stored in a GeoDataFrame. It loops over each input feature (point or polygon), loads the corresponding satellite data defined by a datacube query, and applies a user-defined feature extraction function (feature_func).

The output is a pandas.DataFrame where:

  • each row represents a single extracted sample,

  • the index contains class labels from the vector data,

  • the columns contain the computed feature values (e.g., spectral bands, indices, or other user-defined metrics).

This function can run serially or in parallel, making it well-suited for large datasets containing thousands of sparsely distributed locations.
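The output layout described above can be mocked up with toy values (these are not real satellite measurements, just an illustration of the table's shape):

```python
import pandas as pd

# Mock of the output: one row per extracted sample, class labels
# as the index, and feature values (e.g. band means) as columns
df = pd.DataFrame(
    {"nbart_red": [1200, 850], "nbart_nir": [2600, 1900]},
    index=pd.Index([1, 0], name="class"),
)
print(df.index.name)  # class
print(df.shape)       # (2, 2)
```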

What is a “feature_func”?

The feature_func parameter is where you define how the satellite data should be processed for each geometry. It is a small (or large, if you like!) Python function that receives the dc_query dictionary, loads data from the Open Data Cube, and returns an xarray.Dataset or xarray.DataArray containing the features you want to extract. The feature_func can be as simple or as complex as you like, but it must be a self-contained function that does all of the processing you require, and it can only accept a dc_query object as an input.

In simple examples, this function might just load a few bands and take a time-average. For example:

def feature_function(query):
    dc = datacube.Datacube(app="feature_layers")
    ds = dc.load(**query)
    return ds.mean("time")

Think of feature_func as your custom feature‑builder: you control what data is loaded and how it is transformed. We will define one shortly.

First, simple example

Let’s start off by running collect_training_data with a simple example. First, we need to define an input query dictionary describing the data we want to load from the datacube. The query may include:

  • measurements: list of satellite bands or derived products to load

  • resolution: pixel size (e.g. (-30, 30))

  • output_crs: desired projection for loaded data

  • time: an optional global time window (mutually exclusive with time_field)

The next step is to define a feature_func. In this first example we will load data from the Landsat GeoMAD product, ga_ls8cls9c_gm_cyear_3, and return all the bands except count.

[5]:
# Set up the inputs for the ODC query
time = "2024"
resolution = (-30, 30)
output_crs = "EPSG:3577"

# Generate a datacube query object
dc_query = {
    "time": time,
    "resolution": resolution,
    "output_crs": output_crs,
}
[6]:
def simple_feature_layers(query):

    # Connect to the datacube
    dc = datacube.Datacube(app="custom_feature_layers")

    # Load ls8 geomedian and drop count
    ds = dc.load(
        product="ga_ls8cls9c_gm_cyear_3", skip_broken_datasets=True, **query
    ).drop_vars("count")
    return ds

Now, we can pass this function to collect_training_data. We will run this in serial (with 1 CPU). Note the function will give us a few helpful print statements along the way.

[7]:
df = collect_training_data(
    gdf=input_data,
    dc_query=dc_query,
    feature_func=simple_feature_layers,
    field=field,
    ncpus=1,
)
Collecting training data in serial mode
Removed 0 rows with NaNs or Infs in numeric columns
Output shape: (6, 10)

Now let’s inspect the output. You can see we have returned a pandas.DataFrame where the index is the class label (0’s and 1’s in this case), and the columns contain measurements from the GeoMAD product.

[8]:
df
[8]:
nbart_blue nbart_green nbart_red nbart_nir nbart_swir_1 nbart_swir_2 sdev edev bcdev
class
1 1087 1641 2216 3479 3634 2460 0.003306 1809.477051 0.132994
1 478 687 864 1920 2396 1622 0.000531 290.915375 0.036897
1 756 1082 1473 2693 3886 2664 0.001877 503.030487 0.042939
0 401 702 1341 2196 2862 2307 0.000908 307.273621 0.032364
1 868 1262 1666 2764 3488 2450 0.004713 855.537231 0.067494
1 826 1282 1869 2879 3636 2785 0.003180 644.176453 0.048564

Parallelisation

In the case where we have many, many samples to collect, we can choose to run the function in parallel. collect_training_data can automatically parallelise processing across multiple CPUs:

  • If ncpus=1, extraction runs in simple serial mode.

  • If ncpus>1, the function uses multiprocessing to distribute samples across the number of CPUs specified by the user.

Parallel mode is particularly useful for large training datasets where loading thousands of polygons or points can be time-consuming. In this default example we only have access to two CPUs. Our results should be identical to the example above, only it’ll complete in ~half the time (minus some overhead in setting up the parallelisation).

Note: the row positions of the output dataframe might not be identical to the serial version above since the multiprocessing method employed is not guaranteed to return items in order.
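Because row order may differ between runs, one way to check that two outputs contain the same rows regardless of ordering is to sort both tables before comparing (a sketch with toy tables standing in for the serial and parallel results):

```python
import pandas as pd

# Toy stand-ins for serial and parallel outputs: same rows, shuffled order
serial = pd.DataFrame({"red": [1, 2, 3], "nir": [4, 5, 6]}, index=[1, 0, 1])
parallel = serial.iloc[[2, 0, 1]]

# Sort both by all feature columns so the comparison ignores row order
key = list(serial.columns)
same = serial.sort_values(key).reset_index().equals(
    parallel.sort_values(key).reset_index()
)
print(same)  # True
```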

[9]:
df = collect_training_data(
    gdf=input_data,
    dc_query=dc_query,
    feature_func=simple_feature_layers,
    field=field,
    ncpus=ncpus  # use both cores on this machine
)

df
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 0 rows with NaNs or Infs in numeric columns
Output shape: (6, 10)
[9]:
nbart_blue nbart_green nbart_red nbart_nir nbart_swir_1 nbart_swir_2 sdev edev bcdev
class
1 478 687 864 1920 2396 1622 0.000531 290.915375 0.036897
1 1087 1641 2216 3479 3634 2460 0.003306 1809.477051 0.132994
1 756 1082 1473 2693 3886 2664 0.001877 503.030487 0.042939
0 401 702 1341 2196 2862 2307 0.000908 307.273621 0.032364
1 868 1262 1666 2764 3488 2450 0.004713 855.537231 0.067494
1 826 1282 1869 2879 3636 2785 0.003180 644.176453 0.048564

Zonal Statistics

If your input geometries contain polygons, you may optionally compute zonal statistics by setting:

zonal_stats="mean"   # or "median", "max", "min"

  • None (default): returns all pixel values inside each polygon.

  • "mean", "median", "max", "min": returns a single aggregate value per band.

This is useful when your feature sets are based on aggregated polygon-level properties rather than per-pixel samples. Below we will convert our point samples to small polygons, and then extract data both with and without applying zonal stats.
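Conceptually, zonal_stats="mean" reduces the pixels inside each polygon to a single value per band. A toy numpy sketch of that reduction (not the function's actual implementation):

```python
import numpy as np

# Toy pixel values for one polygon: 4 pixels x 2 bands, with one NaN
pixels = np.array([
    [100.0, 200.0],
    [110.0, 210.0],
    [ 90.0, 190.0],
    [np.nan, 200.0],
])

# zonal_stats=None would keep every pixel row; "mean" collapses them,
# ignoring NaNs, into a single feature vector for the polygon
zonal_mean = np.nanmean(pixels, axis=0)
print(zonal_mean)  # [100. 200.]
```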

[10]:
# make a copy so we don't alter our original file
input_data_polygon = input_data.copy()

# convert our points to small 60m x 60m polygons
input_data_polygon["geometry"] = (
    input_data["geometry"].to_crs("epsg:3577").buffer(30).envelope.to_crs("epsg:4326")
)

# set the statistics to use for zonal-stats
statistic = "mean"

# ensure the 'time' variable in the query is set
dc_query.update({"time": "2024"})

Our output from the function will again contain six samples because, although we provided polygons, we instructed collect_training_data to take a mean of all pixel values returned within each polygon extent.

[11]:
df = collect_training_data(
    gdf=input_data_polygon,
    dc_query=dc_query,
    feature_func=simple_feature_layers,
    field=field,
    zonal_stats=statistic,  # set the zonal stats
    ncpus=ncpus,
)

df
Applying zonal statistic: mean
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 0 rows with NaNs or Infs in numeric columns
Output shape: (6, 10)
[11]:
nbart_blue nbart_green nbart_red nbart_nir nbart_swir_1 nbart_swir_2 sdev edev bcdev
class
1 556.50 829.25 1088.00 2103.25 2753.75 1921.50 0.000550 314.351196 0.032992
1 940.00 1438.50 1958.75 3073.00 3239.25 2238.25 0.002478 1302.383911 0.103427
1 759.00 1095.25 1493.25 2779.00 3924.75 2666.75 0.001902 555.482605 0.046886
0 409.00 724.00 1386.50 2269.50 2928.75 2373.00 0.000774 286.934326 0.029271
1 896.75 1302.25 1725.25 2832.75 3605.75 2529.75 0.004121 801.337524 0.061265
1 846.75 1318.25 1917.00 2955.00 3693.00 2818.25 0.003616 690.448364 0.049931

Now, let’s try setting zonal_stats=None and see how many samples are returned. After running the code you’ll see that in this case 24 rows are returned, because within each polygon we return all valid pixel values.

In the output printed below, you’ll also notice that ~24 rows were removed because our samples contained NaN values (probably at the boundaries of the small polygons we created). If we wanted to keep these samples, then we can pass clean=False into collect_training_data. By default, it is set to True.
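The cleaning step applied by clean=True is equivalent to dropping rows containing NaN or infinite values. If you run with clean=False, a sketch of applying the same filter yourself afterwards (toy values):

```python
import numpy as np
import pandas as pd

# Toy output with one NaN row and one infinite row
df = pd.DataFrame({
    "red": [1.0, np.nan, 3.0, 4.0],
    "nir": [5.0, 6.0, np.inf, 8.0],
})

# Keep only rows where every feature value is finite
clean = df[np.isfinite(df).all(axis=1)]
print(clean.shape)  # (2, 2)
```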

[12]:
df = collect_training_data(
    gdf=input_data_polygon,
    dc_query=dc_query,
    feature_func=simple_feature_layers,
    field=field,
    zonal_stats=None,  # This is the default behaviour
    ncpus=ncpus,
)

df
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 24 rows with NaNs or Infs in numeric columns
Output shape: (24, 10)
[12]:
nbart_blue nbart_green nbart_red nbart_nir nbart_swir_1 nbart_swir_2 sdev edev bcdev
class
1 475.0 683.0 861.0 1874.0 2393.0 1662.0 0.000684 295.805664 0.035670
1 478.0 687.0 864.0 1920.0 2396.0 1622.0 0.000531 290.915375 0.036897
1 647.0 992.0 1353.0 2345.0 3195.0 2279.0 0.000485 359.263489 0.030462
1 626.0 955.0 1274.0 2274.0 3031.0 2123.0 0.000501 311.420197 0.028938
1 768.0 1200.0 1639.0 2559.0 2751.0 1947.0 0.001015 558.416321 0.054844
1 872.0 1349.0 1855.0 2932.0 3107.0 2167.0 0.002552 1202.241821 0.099947
1 1033.0 1564.0 2125.0 3322.0 3465.0 2379.0 0.003039 1639.400513 0.125921
1 1087.0 1641.0 2216.0 3479.0 3634.0 2460.0 0.003306 1809.477051 0.132994
1 756.0 1082.0 1473.0 2693.0 3886.0 2664.0 0.001877 503.030487 0.042939
1 780.0 1118.0 1518.0 2786.0 4019.0 2774.0 0.001819 567.402405 0.046460
1 708.0 1030.0 1422.0 2723.0 3806.0 2530.0 0.001940 509.284027 0.045060
1 792.0 1151.0 1560.0 2914.0 3988.0 2699.0 0.001971 642.213501 0.053087
0 419.0 750.0 1427.0 2355.0 3000.0 2444.0 0.000743 282.387909 0.027148
0 415.0 747.0 1441.0 2354.0 3005.0 2449.0 0.000707 257.812927 0.025330
0 401.0 702.0 1341.0 2196.0 2862.0 2307.0 0.000908 307.273621 0.032364
0 401.0 697.0 1337.0 2173.0 2848.0 2292.0 0.000736 300.262817 0.032242
1 914.0 1327.0 1764.0 2880.0 3685.0 2585.0 0.003619 745.375244 0.054979
1 923.0 1340.0 1784.0 2882.0 3710.0 2586.0 0.003211 728.596924 0.054985
1 882.0 1280.0 1687.0 2805.0 3540.0 2498.0 0.004941 875.840637 0.067602
1 868.0 1262.0 1666.0 2764.0 3488.0 2450.0 0.004713 855.537231 0.067494
1 871.0 1358.0 1969.0 3042.0 3759.0 2858.0 0.003902 770.102539 0.053305
1 843.0 1306.0 1888.0 2891.0 3615.0 2766.0 0.003421 630.546143 0.047134
1 847.0 1327.0 1942.0 3008.0 3762.0 2864.0 0.003963 716.968201 0.050719
1 826.0 1282.0 1869.0 2879.0 3636.0 2785.0 0.003180 644.176453 0.048564

A more complicated example

Below, we will define a more complicated feature_func than the simple example shown above. We will:

  • Load a time series of satellite bands from Landsat 8.

  • Append the PV bands from the fractional cover percentiles product.

  • Add an external Digital Elevation Model (DEM) dataset that is stored as a GeoTIFF on a DEA public data bucket (this could equally be a local GeoTIFF). We will use the function dea_tools.datahandling.load_reproject for this, which is built to load and reproject part of a raster dataset into a given GeoBox or custom CRS/resolution.

Note: when using load_reproject with many CPUs, memory use can become inefficient. If you encounter this, consider instead using rio_slurp_xarray, which can be more memory efficient but does not provide dask-backed reprojection. Drop-in replacement code for our example is shown below:

from datacube.testutils.io import rio_slurp_xarray
path_url = "https://dea-public-data.s3-ap-southeast-2.amazonaws.com/projects/elevation/ga_srtm_dem1sv1_0/dems1sv1_0.tif"
dem = rio_slurp_xarray(
    path_url,
    geobox=ds.odc.geobox,
    resampling="nearest",
).to_dataset(name="DEM")
[13]:
def complex_feature_layers(query):

    # Connect to the datacube
    dc = datacube.Datacube(app="custom_feature_layers")

    # load landsat 8 time series
    ds = load_ard(
        dc=dc,
        products=["ga_ls8c_ard_3"],
        measurements=["nbart_green", "nbart_red", "nbart_blue"],
        mask_contiguity=True,
        mask_pixel_quality=True,
        verbose=False,
        skip_broken_datasets=True,
        group_by="solar_day",
        **query
    )

    # take the mean of the time series
    ds = ds.mean("time")

    # Add Fractional cover percentiles
    fc = dc.load(
        product="ga_ls_fc_pc_cyear_3",
        measurements=["pv_pc_10", "pv_pc_50", "pv_pc_90"],  # only the PV band
        like=ds,  # will match LS8 extent
        time=query.get("time"),  # use time if in query
    )

    # Add a DEM, using dea_tools.datahandling.load_reproject for this
    path_url = "https://dea-public-data.s3-ap-southeast-2.amazonaws.com/projects/elevation/ga_srtm_dem1sv1_0/dems1sv1_0.tif"
    dem = load_reproject(
        path=path_url, how=ds.odc.geobox, resampling="nearest"
    ).drop_vars('band').to_dataset(name="DEM")

    # Merge results into single dataset
    result = xr.merge([ds, fc, dem], compat="override")

    return result

Now we can run collect_training_data again with the more complex function. Note that we are updating dc_query['time'] to load just the first month of 2024 so the function runs faster for this example.

[14]:
# update the time variable in the query
dc_query.update({"time": ("2024-01-01", "2024-01-31")})

df = collect_training_data(
    gdf=input_data_polygon, # use our polygon example again
    dc_query=dc_query,
    feature_func=complex_feature_layers,
    zonal_stats=statistic,
    field=field,
    ncpus=ncpus,  # use both cores on this machine
)

df
Applying zonal statistic: mean
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 0 rows with NaNs or Infs in numeric columns
Output shape: (6, 8)
[14]:
nbart_green nbart_red nbart_blue pv_pc_10 pv_pc_50 pv_pc_90 DEM
class
1 1848.1250 2617.2500 1173.1250 4.00 8.00 68.75 287.490784
1 833.7500 1121.7500 547.8750 19.25 24.75 33.25 443.540558
1 1218.6250 1735.3750 844.0000 11.00 19.00 71.50 328.243011
1 1442.1250 2013.0000 984.8750 5.50 9.50 94.00 353.246704
0 634.2500 1317.2500 349.5000 7.00 17.25 22.75 286.581696
1 1557.6875 2337.0625 962.9375 6.00 9.25 79.75 191.174728

Using the “time_field” Parameter

The time_field parameter allows each geometry in your GeoDataFrame to be sampled from a different time range. This is useful when your training points or polygons correspond to events that occurred at different dates, such as water observations, fires, crop growth stages, land‑use changes, or field campaign measurements.

Normally, you specify a single time range in the dc_query dictionary (e.g., "time": ("2020-01-01", "2020-12-31")), and that same period is used for every sample. However, in many real datasets, each training point has its own timestamp. By setting, for example:

time_field="date"

collect_training_data will:

  1. Read the timestamp from the specified column in the GeoDataFrame.

  2. Use that date (or date range) to load data for that individual sample.

  3. Apply the feature_func to data from the correct time window.

The values in time_field must be in a format accepted by datacube.load(), such as:

  • A single date string "2021-05-03"

  • A tuple defining a range ("2021-05-01", "2021-05-10")
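Conceptually, this amounts to building a separate datacube query per row, taking the time range from that row's time_field column. A sketch of the idea (not the function's actual code; the column name "date" matches the example below):

```python
# Base query shared by all samples
base_query = {"resolution": (-30, 30), "output_crs": "EPSG:3577"}

# Toy per-sample time values, as they might appear in a "date" column
rows = [{"date": "2019"}, {"date": ("2021-05-01", "2021-05-10")}]

# One query per sample, each with its own time range
per_sample_queries = [{**base_query, "time": row["date"]} for row in rows]
print(per_sample_queries[0]["time"])  # 2019
print(per_sample_queries[1]["time"])  # ('2021-05-01', '2021-05-10')
```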

Let’s set up an example of this below where each sample will load from a different year.

[15]:
# add a date column to our geopandas dataframe
dates = ["2019", "2020", "2021", "2022", "2023", "2024"]
input_data["date"] = dates

# we also need to remove 'time' from the dc_query
# if we don't, the function will raise a ValueError
dc_query.pop("time", None)
[15]:
('2024-01-01', '2024-01-31')

Now run collect_training_data with the time_field set.

Note:

  • We can also set return_time_coords=True, which will append a new column to our output that contains the time-stamp of each sample. This could be useful, for example, if you’d like to do post-processing of the datasets.

  • Additionally, we will also demonstrate returning the x, y locations of each sample using the return_coords parameter. If True, additional columns (x_coord, y_coord) are added for each extracted point/pixel.

[16]:
df = collect_training_data(
    gdf=input_data,
    dc_query=dc_query,
    feature_func=simple_feature_layers,
    field=field,
    time_field="date",  # column with our per-row time ranges
    return_time_coords=True,  # these params will return new columns!
    return_coords=True,
    ncpus=ncpus,
)

df
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 0 rows with NaNs or Infs in numeric columns
Output shape: (6, 13)
[16]:
nbart_blue nbart_green nbart_red nbart_nir nbart_swir_1 nbart_swir_2 sdev edev bcdev x_coord y_coord time_coord
class
1 520 790 1053 2046 2745 1888 0.000210 133.539413 0.017182 -1325115.0 -3362445.0 2020-07-01 23:59:59.999999
1 1014 1611 2291 3622 3616 2336 0.001784 1365.965332 0.089852 -1448355.0 -3510075.0 2019-07-02 11:59:59.999999
0 423 727 1400 2111 2962 2423 0.000149 183.791794 0.019001 -577005.0 -3137235.0 2022-07-02 11:59:59.999999
1 727 1007 1302 2062 3499 2648 0.001103 617.994568 0.063858 -1350555.0 -3644055.0 2021-07-02 11:59:59.999999
1 826 1282 1869 2879 3636 2785 0.003180 644.176453 0.048564 -1508655.0 -3451545.0 2024-07-01 23:59:59.999999
1 843 1256 1697 2877 3320 2293 0.011416 2073.405518 0.160282 -1328625.0 -3649755.0 2023-07-02 11:59:59.999999

Retry Queue: Purpose and Key Parameters

When running collect_training_data in parallel, occasional read failures can occur, especially when loading many samples from cloud-backed storage such as S3. These failures typically result in missing (NaN) values in the extracted features. To make the process more robust, the function includes a retry queue that automatically re‑attempts samples that appear to have failed.

Three parameters control this behaviour:

  • ``fail_ratio``: Determines when an individual sample is considered a failure. For example, with fail_ratio=0.5 (default), a sample is flagged as failed if more than 50% of its feature values are missing.

  • ``fail_threshold``: Sets the acceptable fraction of failed samples in the whole dataset. If the proportion of failed samples is greater than this threshold, those samples are placed back into the queue for reprocessing. This prevents the function from retrying unnecessarily when only a very small number of failures occur. Default is 0.05 (i.e. 5% of samples).

  • ``max_retries``: Limits how many times the function will attempt to recollect failed samples. Once this limit is reached, or the failure rate falls below fail_threshold, the retry process stops. Default is 2.

Most of the time you can leave these parameters at their defaults.
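The interaction of the three parameters can be sketched as a toy model of the logic (this is an illustration of the thresholds described above, not the function's actual implementation):

```python
import numpy as np

fail_ratio = 0.5       # a sample "fails" if > 50% of its features are NaN
fail_threshold = 0.05  # retry while > 5% of samples have failed
max_retries = 2        # give up after this many retry passes

# Toy feature matrix: 4 samples x 4 features; sample 1 is mostly NaN
features = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [np.nan, np.nan, np.nan, 4.0],
    [1.0, 2.0, np.nan, 4.0],
    [1.0, 2.0, 3.0, 4.0],
])

# Fraction of missing features per sample, compared against fail_ratio
nan_fraction = np.isnan(features).mean(axis=1)
failed = nan_fraction > fail_ratio
print(failed.tolist())  # [False, True, False, False]

# Retry only while the overall failure rate exceeds fail_threshold
# (and fewer than max_retries passes have been made)
needs_retry = failed.mean() > fail_threshold
print(needs_retry)      # True, since 1/4 = 25% > 5%
```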

Next Steps

The collect_training_data function is used in two further examples within DEA-Notebooks, both of which explore its use within an ML workflow:

  • Machine Learning with ODC

  • The Scalable Machine Learning notebook series.


Additional information

License: The code in this notebook is licensed under the Apache License, Version 2.0. Digital Earth Australia data is licensed under the Creative Commons by Attribution 4.0 license.

Contact: If you need assistance, please post a question on the Open Data Cube Discord chat or on the GIS Stack Exchange using the open-data-cube tag (you can view previously asked questions here). If you would like to report an issue with this notebook, you can file one on GitHub.

Last modified: March 2026

Compatible datacube version:

[17]:
print(datacube.__version__)
1.8.19

Tags

Landsat 8 geomedian, Landsat 8 TMAD, machine learning, collect_training_data, Fractional Cover