WherobotsAI Raster Inference - Scene Classification
This example demonstrates inference queries with a classification model in WherobotsAI Raster Inference to identify land cover in satellite imagery. We use a machine learning model from TorchGeo [1], trained on imagery from the European Space Agency’s Sentinel-2 satellites. The model assigns each scene to one of the 10 land cover categories of the EuroSAT dataset [2]:
- Annual Crop
- Forest
- Herbaceous Vegetation
- Highway
- Industrial Buildings
- Pasture
- Permanent Crop
- Residential Buildings
- River
- SeaLake
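A ten-class model like this one usually emits one raw score per category, and those scores are then turned into probabilities. The sketch below is a minimal, hypothetical illustration of that step. It assumes the model's outputs arrive as raw logits in the same order as the list above, which this page does not confirm. The helper names (`EUROSAT_CLASSES`, `label_probabilities`) are illustrative only and are not part of the Wherobots API.

```python
import math

# Class names in the order listed above. Matching the model's output
# order is an assumption made for this illustration.
EUROSAT_CLASSES = [
    "Annual Crop", "Forest", "Herbaceous Vegetation", "Highway",
    "Industrial Buildings", "Pasture", "Permanent Crop",
    "Residential Buildings", "River", "SeaLake",
]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_probabilities(logits):
    """Pair each EuroSAT class name with its softmax probability."""
    if len(logits) != len(EUROSAT_CLASSES):
        raise ValueError("expected one score per EuroSAT class")
    return dict(zip(EUROSAT_CLASSES, softmax(logits)))


# Hypothetical logits for a single scene
probs = label_probabilities([0.1, 3.2, 0.5, -1.0, -0.3, 0.8, 0.2, -0.7, 0.0, 1.1])
print(max(probs, key=probs.get))  # Forest
```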
Note: This notebook requires the Wherobots Inference functionality to be enabled and a GPU runtime selected in Wherobots Cloud. Please contact us to enable these features.
1: Set up the Wherobots Context
import warnings
warnings.filterwarnings('ignore')
from wherobots.inference.data.io import read_raster_table
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr
config = SedonaContext.builder().appName('classification-batch-inference')\
.getOrCreate()
sedona = SedonaContext.create(config)
2: Load satellite imagery
Next, we load the satellite imagery that we will run inference on. These GeoTIFF images are loaded as out-db rasters in WherobotsDB, and each row represents a different scene.
tif_folder_path = 's3a://wherobots-examples/data/eurosat_small'
files_df = read_raster_table(tif_folder_path, sedona)
df_raster_input = files_df.withColumn(
"outdb_raster", expr("RS_FromPath(path)")
)
df_raster_input.cache().count()
df_raster_input.show(truncate=False)
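The loading step above yields one row per scene, with a `name` and a `path` column, and `RS_FromPath(path)` then references each file as an out-db raster. The following sketch is a local-filesystem analogue of that one-row-per-file structure, written in plain Python. It is not how `read_raster_table` is implemented. The helper `list_raster_rows` is hypothetical, and deriving `name` from the file stem is an assumption.

```python
import tempfile
from pathlib import Path


def list_raster_rows(folder):
    """Return one {name, path} row per .tif file under folder, sorted by path.

    A local stand-in for the S3-backed raster table; using the file stem
    as `name` is an assumption for illustration.
    """
    rows = []
    for p in sorted(Path(folder).rglob("*.tif")):
        rows.append({"name": p.stem, "path": str(p)})
    return rows


# Usage: a throwaway folder with two scenes and one non-raster file.
with tempfile.TemporaryDirectory() as tmp:
    for fname in ["scene_a.tif", "scene_b.tif", "notes.txt"]:
        (Path(tmp) / fname).touch()
    demo_rows = list_raster_rows(tmp)

print([r["name"] for r in demo_rows])  # ['scene_a', 'scene_b']
```

Keeping only the path per row, rather than the pixel data, mirrors the out-db idea: the raster bytes are read only when a function actually needs them.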
3: Run predictions with the sedona.sql API
To run predictions, we specify the model we want to use. Some models are preloaded and available in Wherobots Cloud. Predictions can be run with the Raster Inference SQL function RS_CLASSIFY or with the Python API.
Here we generate 200 predictions using RS_CLASSIFY.
%%time
df_raster_input.createOrReplaceTempView("df_raster_input")
model_id = 'landcover-eurosat-sentinel2'
predictions_df = sedona.sql(f"SELECT name, outdb_raster, RS_CLASSIFY('{model_id}', outdb_raster) AS preds FROM df_raster_input")
predictions_df.cache().count()
predictions_df.show(truncate=False)
predictions_df.createOrReplaceTempView("predictions_df")
From the prediction results, we can retrieve the most confident classification label and its probability score.
max_predictions_df = sedona.sql("SELECT name, outdb_raster, RS_MAX_CONFIDENCE(preds).max_confidence_label, RS_MAX_CONFIDENCE(preds).max_confidence_score FROM predictions_df")
max_predictions_df.show(20, truncate=False)
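Conceptually, `RS_MAX_CONFIDENCE` selects the label with the highest score and returns it together with that score. The plain-Python sketch below illustrates that selection. It assumes `preds` behaves like a mapping from label to probability. That is an assumption, because the actual `RS_CLASSIFY` output type is not described here. The function `max_confidence` is hypothetical. Only the two output field names come from the query above.

```python
def max_confidence(preds):
    """Return the highest-scoring label and its score.

    `preds` is assumed to map label -> probability; the real RS_CLASSIFY
    output structure may differ.
    """
    if not preds:
        raise ValueError("preds is empty")
    label = max(preds, key=preds.get)  # on ties, the first label wins
    return {"max_confidence_label": label, "max_confidence_score": preds[label]}


example = {"Forest": 0.71, "Pasture": 0.18, "River": 0.11}
result = max_confidence(example)
print(result)  # {'max_confidence_label': 'Forest', 'max_confidence_score': 0.71}
```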
Raster Inference Python API
If you prefer Python, wherobots.inference offers a module that registers the SQL inference functions as Python functions. Below, we run the same inference as before with rs_classify, the Python counterpart of RS_CLASSIFY.
from wherobots.inference.engine.register import create_single_label_classification_udfs
rs_classify, rs_max_confidence = create_single_label_classification_udfs(batch_size=10, sedona=sedona)
df_predictions = df_raster_input.withColumn("preds", rs_classify(model_id, 'outdb_raster'))
df_predictions.show(1)
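The `batch_size=10` argument presumably controls how many rasters are grouped into each model call. Batching is the usual way to keep a GPU busy without exceeding its memory. The sketch below shows generic batching in plain Python, under that assumption. `batched`, `classify_in_batches`, and `fake_model` are hypothetical stand-ins, not Wherobots functions.

```python
from itertools import islice


def batched(iterable, batch_size):
    """Yield consecutive lists of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch


def classify_in_batches(rasters, model_fn, batch_size=10):
    """Call model_fn once per batch and flatten results back into row order."""
    results = []
    for batch in batched(rasters, batch_size):
        results.extend(model_fn(batch))
    return results


calls = []  # records the size of each simulated model call


def fake_model(batch):
    """Hypothetical model: returns one label-like string per input."""
    calls.append(len(batch))
    return [f"pred-{x}" for x in batch]


preds = classify_in_batches(range(25), fake_model, batch_size=10)
print(calls)  # [10, 10, 5]
```

Because the results are extended in batch order, each prediction stays aligned with its input row, which is what a column-wise UDF such as `rs_classify` needs.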
from pyspark.sql.functions import col
df_max_predictions = df_predictions.withColumn("max_confidence_temp", rs_max_confidence(col("preds"))) \
.withColumn("max_confidence_label", col("max_confidence_temp.max_confidence_label")) \
.withColumn("max_confidence_score", col("max_confidence_temp.max_confidence_score")) \
.drop("max_confidence_temp", "preds")
df_max_predictions.cache().count()
df_max_predictions.show(2, truncate=False)
Visualize the model predictions and source imagery
# Alternative loader (not used below): read the GeoTIFF bytes and build in-db rasters
df_rast = (
    sedona.read.format("binaryFile")
    .option("pathGlobFilter", "*.tif")
    .option("recursiveFileLookup", "true")
    .load(tif_folder_path)
    .selectExpr("RS_FromGeoTiff(content) as raster")
)
htmlDF = df_max_predictions.selectExpr("RS_Band(outdb_raster, Array(4, 3, 2)) as image_raster", "name", "max_confidence_label")\
.selectExpr("RS_NormalizeAll(image_raster, 1, 65535, True) as image_raster", "name", "max_confidence_label")\
.selectExpr("RS_AsImage(image_raster, 500) as image_raster", "name", "max_confidence_label")
from sedona.raster_utils.SedonaUtils import SedonaUtils
from pyspark.sql.functions import rand
SedonaUtils.display_image(htmlDF.orderBy(rand()).limit(3))
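The visualization pipeline takes bands 4, 3, and 2 by 1-based index. Assuming the files store Sentinel-2 bands in order, these are presumably red, green, and blue (B4, B3, B2). The pipeline then rescales the pixel values linearly before rendering them as an image. The sketch below reproduces those two ideas in plain Python on a tiny hypothetical scene, using a single min/max across all bands. It is an illustration only. Sedona's `RS_NormalizeAll` may treat no-data values and its parameters differently, and the 0 to 255 display range here is chosen for readability rather than taken from the notebook.

```python
def select_bands(bands, indices):
    """Pick bands by 1-based index, mirroring RS_Band(..., Array(4, 3, 2))."""
    return [bands[i - 1] for i in indices]


def normalize_all(bands, min_lim, max_lim):
    """Linearly rescale all pixels into [min_lim, max_lim].

    A single min/max is taken across every band, so relative brightness
    between bands is preserved.
    """
    values = [v for band in bands for v in band]
    lo, hi = min(values), max(values)
    if hi == lo:  # flat image: avoid division by zero
        return [[float(min_lim)] * len(band) for band in bands]
    return [
        [min_lim + (v - lo) * (max_lim - min_lim) / (hi - lo) for v in band]
        for band in bands
    ]


# Hypothetical 4-band scene with 2 pixels per band (bands 1..4)
scene = [[10, 20], [30, 40], [50, 60], [70, 80]]
rgb = select_bands(scene, [4, 3, 2])  # [[70, 80], [50, 60], [30, 40]]
display = normalize_all(rgb, 0, 255)
print(display)  # [[204.0, 255.0], [102.0, 153.0], [0.0, 51.0]]
```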
References
- [1] Stewart, A. J., Robinson, C., Corley, I. A., Ortiz, A., Lavista Ferres, J. M., & Banerjee, A. (2022). TorchGeo: Deep Learning With Geospatial Data. In Proceedings of the 30th International Conference on Advances in Geographic Information Systems (pp. 1-12). Association for Computing Machinery. https://doi.org/10.1145/3557915.3560953
- [2] Helber, P., Bischke, B., Dengel, A., & Borth, D. (2019). EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing.