Machine learning with the Open Data Cube 
Sign up to the DEA Sandbox to run this notebook interactively from a browser
Compatibility: Notebook currently compatible with both the NCI and DEA Sandbox environments
Products used: ga_ls8cls9c_gm_cyear_3
Special requirements: A vector file of labelled training data (e.g. a shapefile or GeoJSON) is required to use this notebook. An example dataset is provided.
Prerequisites: A basic understanding of supervised learning techniques is required. Introduction to statistical learning is a useful resource to begin with - it can be downloaded for free here. The Scikit-learn documentation provides information on the available models and their parameters.
Description
This notebook demonstrates a potential workflow using functions from the dea_tools.classification script to implement a supervised learning landcover classifier within the ODC (Open Data Cube) framework.
This example predicts a single class of cultivated / agricultural areas. The notebook demonstrates how to:
Extract the desired ODC data for each labelled area (this becomes our training dataset).
Train a simple decision tree model and adjust parameters.
Evaluate the output of the classification using quantitative metrics and qualitative tools.
Predict the extent of cropping on new data using the trained model.
This notebook is intended as a quick reference for machine learning on the ODC. For a more in-depth exploration, please use the Scalable Machine Learning series of notebooks.
Getting started
To run this analysis, run all the cells in the notebook, starting with the “Load packages” cell.
Load packages
Import Python packages that are used for the analysis.
[1]:
%matplotlib inline
import os
import datacube
import pydotplus
import numpy as np
import xarray as xr
import geopandas as gpd
from io import StringIO
from sklearn import tree
from sklearn import model_selection
from IPython.display import Image
from odc.io.cgroups import get_cpu_quota
from sklearn.metrics import accuracy_score
from odc.geo.xr import write_cog
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, '../Tools/')
from dea_tools.classification import collect_training_data, predict_xr
import warnings
warnings.filterwarnings("ignore")
Connect to the datacube
Connect to the datacube so we can access DEA data.
[2]:
dc = datacube.Datacube(app='Machine_learning_with_ODC')
Analysis parameters
path: The path to the input vector file containing labelled training data. A default GeoJSON is provided.
field: The name of the column in the file's attribute table that contains the class labels.
time: The time range you wish to extract data for, typically the same date the labels were created.
zonal_stats: An option to calculate the 'mean', 'median', or 'std' of the pixel values within each polygon feature; setting it to None will result in all pixels being extracted.
resolution: The spatial resolution, in metres, to resample the satellite data to, e.g. if working with Landsat data this should be (-30, 30).
output_crs: The coordinate reference system for the data you are querying.
ncpus: Set this value to > 1 to parallelize the collection of training data, e.g. ncpus=8.
If running the notebook for the first time, keep the default settings below. This will demonstrate how the analysis works and provide meaningful results.
[3]:
path = '../Real_world_examples/Scalable_machine_learning/data/crop_training_WA.geojson'
field = 'class'
time = ('2015')
resolution = (-30, 30)
output_crs = 'EPSG:3577'
Automatically detect the number of cpus
[4]:
if get_cpu_quota() is not None:
    ncpus = round(get_cpu_quota())
else:
    ncpus = os.cpu_count()
print(f"ncpus = {ncpus}")
ncpus = 2
Preview input data and study area
We can load and preview our input data geojson using geopandas. The geojson should contain a column with class labels (e.g. class below). These labels will be used to train our model; importantly, the class labels should be integers (a quick check is sketched after the table below).
[5]:
# Load input data shapefile
input_data = gpd.read_file(path)
# Down sample the number of input data polygons to speed up the analysis
input_data = input_data.sample(frac=0.25, random_state=1).reset_index(drop=True)
# Plot first five rows
input_data.head()
[5]:
| | class | geometry |
|---|---|---|
| 0 | 1 | POINT (116.58083 -31.42120) |
| 1 | 1 | POINT (118.05265 -30.23842) |
| 2 | 1 | POINT (117.45115 -32.71308) |
| 3 | 0 | POINT (126.01858 -28.75520) |
| 4 | 1 | POINT (117.67845 -32.78576) |
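The class column above already holds integer labels. If your own labels were stored as strings or floats instead, a quick cast like the following would ensure they are integers. This is an optional sketch only, reusing the input_data and field variables defined in this notebook.
# Optional sketch: cast the class label column to integers if needed
input_data[field] = input_data[field].astype(int)
print(input_data[field].dtype)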
The data can also be explored using the interactive map below. Hover over each feature to see its class label.
[6]:
# Plot training data in an interactive map
input_data.explore(column=field, legend=False)
[6]:
Extract training data
To train our model, we need to obtain satellite data that corresponds with the labelled input data locations above. The collect_training_data function takes our labelled features (points or polygons) and extracts data from the specified product at these locations into a single array.
The following function is passed to collect_training_data. It extracts bands from the ga_ls8cls9c_gm_cyear_3 product as feature layers alongside our labelled data so we can train a supervised model. The feature function can be modified to extract different combinations of features from the datacube; it is one of the important parts to experiment with when generating your own model (a sketch of a modified feature function follows the custom_function cell below).
You can find a more detailed description of the collect_training_data function and its attributes in the Scalable Machine Learning series.
[7]:
# Generate a new datacube query object
query = {
'time': time,
'resolution': resolution,
'output_crs': output_crs,
'group_by': 'solar_day',
}
[8]:
def custom_function(query):
    # Initialise datacube
    dc = datacube.Datacube(app='custom_feature_layers')
    # Load data using query
    result = dc.load(product='ga_ls8cls9c_gm_cyear_3', **query)
    return result
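As an example of how the feature function can be modified, the sketch below adds a derived index as an extra feature layer. It is illustrative only and is not used elsewhere in this notebook; it assumes the product includes nbart_nir and nbart_red bands.
def custom_function_with_ndvi(query):
    # Illustrative variant of the feature function above (not used in the rest of this notebook)
    dc = datacube.Datacube(app='custom_feature_layers')
    # Load the annual geomedian using the query
    ds = dc.load(product='ga_ls8cls9c_gm_cyear_3', **query)
    # Add NDVI as an additional feature layer (assumes nbart_nir and nbart_red are present)
    ds['ndvi'] = (ds.nbart_nir - ds.nbart_red) / (ds.nbart_nir + ds.nbart_red)
    return ds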
Note: The following cell can take several minutes to run. The class labels will be contained in the first column of the output array model_input, with the corresponding variable names captured in the list column_names.
[9]:
column_names, model_input = collect_training_data(
gdf=input_data,
dc_query=query,
ncpus=ncpus,
feature_func=custom_function,
field=field
)
Collecting training data in parallel mode
Percentage of possible fails after run 1 = 0.0 %
Removed 0 rows wth NaNs &/or Infs
Output shape: (108, 11)
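As an optional sanity check (not part of the original workflow), the returned objects can be inspected directly:
# The first column of model_input holds the class labels;
# column_names lists the label field followed by the feature variable names
print(column_names)
print(model_input[:5, 0])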
Extract testing data
So that we can assess the accuracy of our classification, we split our data into training and testing sets: 80% is used for training, with 20% held back for testing. When splitting our data, we stratify by class membership so that the training and testing sets contain a similar proportion of each class.
[10]:
# Split into training and testing data
model_train, model_test = model_selection.train_test_split(
model_input, stratify=model_input[:, 0], train_size=0.8, random_state=0)
print("Train shape:", model_train.shape)
print("Test shape:", model_test.shape)
Train shape: (86, 11)
Test shape: (22, 11)
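To confirm that stratification has preserved similar class proportions in both splits, the class counts can be checked. This is an optional sketch using the arrays defined above.
# Count class membership in the training and testing splits (labels are in column 0)
print(np.unique(model_train[:, 0], return_counts=True))
print(np.unique(model_test[:, 0], return_counts=True))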
Model preparation
This section automatically creates a list of variable names and their respective indices for each of the training data variables.
[11]:
# Select the variables we want to use to train our model
model_variables = column_names[1:]
# Extract relevant indices from the processed geojson
model_col_indices = [
column_names.index(var_name) for var_name in model_variables
]
A decision tree model is chosen as it is one of the simplest supervised machine learning models we can implement.
Its strengths are its explainability and cheap computational cost.
Parameter tuning can be conducted in the model initialisation below - details on how the different parameters will affect the model are here.
[12]:
# Initialise model
model = tree.DecisionTreeClassifier(max_depth=10, random_state=1)
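If you want to explore parameter values more systematically, a cross-validated grid search is one option. The sketch below is illustrative only; the candidate max_depth values and five-fold cross-validation are assumptions, not part of the original workflow. It reuses model_train and model_col_indices, which are defined above.
# Optional sketch: search over max_depth with 5-fold cross-validation
param_grid = {'max_depth': [2, 4, 6, 8, 10, None]}
search = model_selection.GridSearchCV(tree.DecisionTreeClassifier(random_state=1),
                                      param_grid, cv=5, scoring='accuracy')
search.fit(model_train[:, model_col_indices], model_train[:, 0])
print(search.best_params_, search.best_score_)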
Train model
The model is fitted / trained using the prepared training data. The fitting process uses the decision tree approach to create a generalised representation of reality based on the training data. This fitted / trained model can then be used to predict which class new data belongs to.
[13]:
# Train model
model.fit(model_train[:, model_col_indices], model_train[:, 0])
[13]:
DecisionTreeClassifier(max_depth=10, random_state=1)
Analyse results
Feature importance
The decision tree classifier allows us to inspect the feature importance of each input variable. Feature importance represents the relative contribution of each variable in predicting the desired landcover class. When summed, the importance of all variables should add up to 1.0.
[14]:
# This shows the feature importance of the input features for predicting the class labels provided
fig, ax = plt.subplots(1, 1, figsize=(12, 3))
ax.bar(x=model_variables, height=model.feature_importances_)
ax.grid(alpha=0.5);
[Figure: bar chart showing the feature importance of each input variable]
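The same information can be printed as a ranked list. This is an optional sketch using the model and variables defined above.
# Print each variable with its importance, from most to least important
for name, importance in sorted(zip(model_variables, model.feature_importances_),
                               key=lambda pair: pair[1], reverse=True):
    print(f"{name}: {importance:.3f}")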
This decision tree representation visualises the trained model. Here we can see that the model decides which landcover class to assign based on the value of the important variables in the plot above.
The gini value shown at each node in the tree below is a measure of node impurity, i.e. how heterogeneous the class labels within that node are (smaller values indicate purer nodes). The decision tree uses this metric to choose splits that divide the data into progressively purer groups.
[15]:
# Prepare a dictionary of class names
class_names = {
1: 'Cropping',
0: 'Not Cropping'
}
# Get list of unique classes in model
class_codes = np.unique(model_train[:, 0])
class_names_in_model = [class_names[k] for k in class_codes]
# Plot decision tree
dot_data = StringIO()
tree.export_graphviz(model,
out_file=dot_data,
feature_names=model_variables,
class_names=class_names_in_model,
filled=True,
rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
[15]:
[Figure: graphviz visualisation of the trained decision tree, showing the split condition, gini value, sample count and class at each node]
Accuracy
We can use the 20% sample of test data we partitioned earlier to test the accuracy of the trained model on this new, “unseen” data.
An accuracy value of 1.0 indicates that the model was able to correctly predict 100% of the classes in the test data.
[16]:
predictions = model.predict(model_test[:, model_col_indices])
accuracy_score(predictions, model_test[:, 0])
[16]:
0.8636363636363636
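Overall accuracy can hide how errors are distributed between the classes. A confusion matrix gives a fuller picture; the sketch below imports confusion_matrix from scikit-learn (it is not imported at the top of this notebook) and reuses the predictions computed above.
from sklearn.metrics import confusion_matrix

# Rows are the true classes, columns the predicted classes
print(confusion_matrix(model_test[:, 0], predictions))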
Prediction
Now that we have a trained model, we can load new data and use the predict_xr
function to predict landcover classes.
The trained model can technically be used to classify any dataset or product with the same bands as the data it was originally trained on. However, it is strongly advisable to only classify data from the same product the model was trained on.
[17]:
# provide some lat/lon extents
xmin, ymin, xmax, ymax = 116.8208, -32.9305, 117.2258, -32.6851
# Set up the query parameters
query = {
'time': time,
'x': (xmin, xmax),
'y': (ymin, ymax),
'crs': 'EPSG:4326',
'resolution': resolution
}
# Use custom function to generate input data
ds = custom_function(query)
Once the data has been loaded, we can classify it using the DEA predict_xr function:
[18]:
# Predict landcover using the trained model
predicted = predict_xr(model, ds, clean=True)
predicting...
Plotting
To qualitatively evaluate how well the classification performed, we can plot the classified/predicted data next to our input satellite imagery.
Note: The output below is unlikely to be optimal the first time the classification is run. The model training process is one of experimentation and assumption checking that occurs in an iterative cycle - you may need to revisit the steps above and make changes to model parameters or input training data until you achieve a satisfactory result.
[19]:
# Set up plot
fig, ax = plt.subplots(1, 2, figsize=(10, 5), layout='constrained')
# Plot classified image
predicted.Predictions.plot(ax=ax[0],
cmap='Greens',
add_labels=False,
add_colorbar=False)
# Plot true colour image
(ds[['nbart_red', 'nbart_green', 'nbart_blue']]
.squeeze('time')
.to_array()
.plot.imshow(ax=ax[1], robust=True, add_labels=False))
ax[1].get_yaxis().set_visible(False)
ax[0].set_yticklabels([])
ax[0].set_xticklabels([])
ax[1].set_yticklabels([])
ax[1].set_xticklabels([])
# Add plot titles
ax[0].set_title('Classified data')
ax[1].set_title('True colour image');
[Figure: classified data (left) and true colour image (right)]
Exporting classification
We can now export the predicted landcover out to a GeoTIFF .tif
file. This file can be loaded into GIS software (e.g. QGIS, ArcMap) to be inspected more closely.
[20]:
# Write the predicted data out to a GeoTIFF
write_cog(predicted.Predictions,
'predicted.tif',
overwrite=True);
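To confirm the export worked, the GeoTIFF can be read back in. This is an optional sketch and assumes the rioxarray package is available in your environment.
import rioxarray

# Re-open the exported classification and check its shape and CRS
reloaded = rioxarray.open_rasterio('predicted.tif')
print(reloaded.shape, reloaded.rio.crs)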
Additional information
License: The code in this notebook is licensed under the Apache License, Version 2.0. Digital Earth Australia data is licensed under the Creative Commons by Attribution 4.0 license.
Contact: If you need assistance, please post a question on the Open Data Cube Discord chat or on the GIS Stack Exchange using the open-data-cube
tag (you can view previously asked questions here). If you would like to report an issue with this notebook, you can file one on
GitHub.
Last modified: September 2025
Compatible datacube version:
[21]:
print(datacube.__version__)
1.8.19
Tags
Tags: NCI compatible, sandbox compatible, landsat 8, annual geomedian, predict_xr, get_training_data_for_shp, machine learning, decision tree, classification tools