WherobotsAI Raster Inference - Bring Your Own Model
This example shows how to bring your own model to Raster Inference. We start with a machine learning model from Satlas [1] and describe it using the Machine Learning Model STAC Extension (MLM). Once an MLM JSON file describing the model is stored on S3, we can pass the file's URI to the RS_* inference functions to run the model on Raster Inference.
Note: This notebook requires the Wherobots Inference functionality to be enabled and a GPU runtime selected in Wherobots Cloud. Please contact us to enable these features.
Step 1: Creating the MLM metadata with stac_model
The MLM specification is what we use to define the model inference requirements. These include descriptions of model inputs (their shape, role, and preprocessing steps), the categories associated with a model, the model task, and the location of the model asset. Wherobots maintains the stac-model Python library, which can be used to create this metadata and validate that it complies with the MLM specification.
Below we break down the steps involved in filling out and validating the MLM fields using stac-model. But first we save the model artifact to our S3 user path, which the MLM metadata refers to, so that we can later run inference with it.
import os
import fsspec
user_uri = os.getenv("USER_S3_PATH")
original_model_uri = "s3://wherobots-modelhub-prod/professional/semantic-segmentation/solar-satlas-sentinel2/inductor/gpu/aot_inductor_gpu_tensor_cores.zip"
user_model_uri = f"{user_uri}aot_inductor_gpu_tensor_cores.zip"
user_mlm_uri = f"{user_uri}model-metadata.json"
fs = fsspec.filesystem('s3')
fs.copy(original_model_uri, user_model_uri)
The main library we will use to construct the metadata is stac-model, which implements validation for the MLM fields (the fields that begin with mlm:).
We also use pystac to combine MLM metadata with other STAC core and STAC extension metadata. For a primer on STAC and STAC extensions, check out https://stac-extensions.github.io/, which also lists the many extensions available for describing different kinds of spatio-temporal data.
The other libraries, shapely and dateutil, are used briefly to format geometry and time metadata in our metadata JSON.
import pystac
import shapely
from dateutil.parser import parse as parse_dt
from stac_model.base import ProcessingExpression
from stac_model.input import InputStructure, ModelInput, MLMStatistic
from stac_model.output import MLMClassification, ModelOutput, ModelResult
from stac_model.schema import MLModelExtension, MLModelProperties
The InputStructure object describes the shape of a tensor/array input to a model's predict function. This is the shape of the input after all data processing steps have been applied to the original data input.
For our example Satlas model, the expected input structure is a tensor with a flexible batch size (-1), 9 bands multiplied by 4 time steps to form a single channel dimension, and a height and width of 1024.
input_array = InputStructure(
shape=[-1, 9 * 4, 1024, 1024], dim_order=["batch", "channel", "height", "width"], data_type="float32"
)
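To make the shape convention concrete, here is a minimal pure-Python sketch (not part of stac-model or Raster Inference) of how a shape template with a flexible batch dimension can be checked. The helper name matches_template and the constants are hypothetical and exist only for this illustration.

```python
from typing import Sequence

BANDS_PER_STEP = 9  # Sentinel-2 bands used by the example model
TIME_STEPS = 4      # time steps stacked along the channel dimension


def matches_template(shape: Sequence[int], template: Sequence[int]) -> bool:
    """Return True if `shape` fits `template`, where -1 matches any size."""
    if len(shape) != len(template):
        return False
    return all(t == -1 or s == t for s, t in zip(shape, template))


# Same template as the InputStructure above: batch, channel, height, width.
template = [-1, BANDS_PER_STEP * TIME_STEPS, 1024, 1024]

print(template[1])                                       # 36
print(matches_template([10, 36, 1024, 1024], template))  # True
print(matches_template([10, 9, 1024, 1024], template))   # False
```

Any batch size fits the template, but a single 9-band image without the 4 stacked time steps does not.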
The MLMStatistic object describes the statistics for the input to a model's prediction function. These statistics are used to normalize the model input to the data range and data distribution expected by the model. Some models may only need the data range adjusted, and the model itself will handle normalizing the inputs to a given distribution without preprocessing. Others will adjust both the range and distribution. The field norm_type indicates how inputs need to be normalized given the statistics.
Wherobots Raster Inference functions expect that the statistics are applied to the band dimension of an overhead imagery input. band_names should be specified according to the STAC Collection that the model input expects. For interoperability with WherobotsAI Raster Inference, use strings for band names rather than the Model Band Object.
band_names = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"]*4
stats = [
MLMStatistic(maximum=255, minimum=0)
for band in band_names
]
With the band names, input structure (input_array), and statistics, we have a few other fields to fill out to fully describe our Model Input. norm_by_channel specifies whether Statistics should be applied per channel/band or to the input_array as a whole. If it is set to True, Statistics must be the same length as the band dimension of the input_array. norm_type specifies the normalization equation to apply to the input_array using the Statistics. resize_type indicates how to convert data samples to the size expected by the model's predict function, which is specified by the InputStructure object.
For processing operations that cannot be described by the normalization and resize_type fields, the pre_processing_function can be used to either link to documentation or specify the mathematical expression to apply to model inputs before prediction.
model_input = ModelInput(
name="9 Band Sentinel-2 4 Time Step Series Batch",
bands=band_names,
input=input_array,
norm_by_channel=True,
norm_type="min-max",
resize_type="crop",
statistics=stats,
pre_processing_function=ProcessingExpression(format="documentation-link", expression="https://github.com/allenai/satlas/blob/main/CustomInference.md#sentinel-2-inference-example"),
)
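As an illustration of what norm_type="min-max" with norm_by_channel=True implies, the sketch below rescales each channel with its own minimum and maximum. It assumes the conventional min-max formula (x - minimum) / (maximum - minimum); the function name min_max_normalize is hypothetical, and this is not the code Raster Inference runs internally.

```python
from typing import List, Tuple


def min_max_normalize(
    pixels_by_channel: List[List[float]],
    stats: List[Tuple[float, float]],
) -> List[List[float]]:
    """Rescale each channel to [0, 1] using its own (minimum, maximum) pair."""
    if len(pixels_by_channel) != len(stats):
        # Mirrors the rule that there must be one statistic per band.
        raise ValueError("need exactly one (minimum, maximum) pair per channel")
    normalized = []
    for channel, (lo, hi) in zip(pixels_by_channel, stats):
        normalized.append([(value - lo) / (hi - lo) for value in channel])
    return normalized


# Two toy channels with 8-bit values, matching minimum=0 and maximum=255 above.
channels = [[0.0, 51.0, 255.0], [102.0, 204.0, 255.0]]
stats_pairs = [(0.0, 255.0)] * len(channels)
result = min_max_normalize(channels, stats_pairs)
print(result)  # [[0.0, 0.2, 1.0], [0.4, 0.8, 1.0]]
```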
Similar to the Input Structure object, we need to describe the model output with the Result Structure object. Here -1 denotes a flexible batch size, 1 refers to the category dimension (in this case the model predicts only solar farm/not solar farm), and 1024 refers to our height and width dimensions again. dim_order enumerates the meaning of each shape element.
confidence = ModelResult(
shape=[-1, 1, 1024, 1024],
dim_order=["batch", "category", "height", "width"],
data_type="float32"
)
We use the STAC Classification Extension to describe categories predicted by the model. The only required fields are an integer value to represent the category and the name of the category.
The Model Output Object ties together the classes, task, and any post-processing functions that need to be applied to the Result Structure Object. Options for tasks are specified in the Tasks Enum. WherobotsAI Raster Inference currently supports scene-classification, object-detection, and semantic-segmentation.
class_map = {"Solar Farm": 1,}
class_objects = [
MLMClassification(value=class_value, name=class_name)
for class_name, class_value in class_map.items()
]
model_output = ModelOutput(
name="confidence array",
tasks={"semantic-segmentation"},
result=confidence,
classes=class_objects,
post_processing_function=None,
)
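The relationship between the class map and the classification objects can be sketched with plain dictionaries. This is only an illustration of the value/name pairing described above; the dictionaries are stand-ins, not actual MLMClassification instances, and name_by_value is a hypothetical helper variable.

```python
class_map = {"Solar Farm": 1}

# Plain-dict stand-ins for the classification objects: one value and one name each.
class_dicts = [{"value": value, "name": name} for name, value in class_map.items()]

# Reverse lookup: category value predicted by the model -> human-readable name.
name_by_value = {c["value"]: c["name"] for c in class_dicts}
if len(name_by_value) != len(class_dicts):
    raise ValueError("category values must be unique")

print(class_dicts)       # [{'value': 1, 'name': 'Solar Farm'}]
print(name_by_value[1])  # Solar Farm
```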
To describe the actual model file we load and run in Raster Inference, we use the STAC Core spec for Asset Objects. The MLM spec adds additional fields and roles for different asset types. To use the MLM, you are required to specify the mlm:model asset that points to a model file and where it is hosted (href). The Artifact Enum specifies the type of the model file. Currently, PyTorch torch.jit.script and torch.compile models are supported by Raster Inference.
assets = {
"model": pystac.Asset(
title="AOTInductor model exported from private, edited, hard fork of Satlas github repo.",
description=(
"A Swin Transformer backbone with a U-net head trained on the 9-band Sentinel-2 Top of Atmosphere product."
),
href=user_model_uri,
media_type="application/zip; application=pytorch",
roles=[
"mlm:model",
"data"
],
extra_fields={"mlm_artifact_type": "torch.compile",}
),
"source_code": pystac.Asset(
title="Model implementation.",
description="Source code to export the model.",
href="https://github.com/wherobots/modelhub/blob/main/model-forge/satlas/solar/export.py",
media_type="text/x-python",
roles=[
"mlm:model",
"code"
]
)
}
After specifying our model input, model output, and assets, we can assemble this information in the top-level Item Properties. Note that the required fields of the spec are mlm:name, mlm:architecture, mlm:tasks, mlm:input, and mlm:output. Additionally, WherobotsAI Raster Inference requires the following fields and options:
- Only one value for tasks is supported.
- framework="pytorch" is currently required.
- framework_version="2.3.0+cu121" is recommended, as this is the default version installed in our GPU Runtimes. You can still install a different version during Notebook Instance setup if needed; however, Raster Inference expects to run on features available in PyTorch 2.3.
- batch_size_suggestion is recommended. If not specified, Raster Inference defaults to a batch size of 10, or the batch size that is set in the Sedona configuration wherobots.inference.args. For example:
    config = (
        SedonaContext.builder()
        .appName("raster-inference")
        .config("spark.wherobots.inference.args", "10")  # sets the batch size for RS_ inference functions to 10
        .getOrCreate()
    )
- accelerator is recommended. If not set, we assume CUDA is available.
- accelerator_constrained is recommended to indicate that a model must run on the accelerator.
ml_model_meta = MLModelProperties(
name="Satlas Solar Farm Segmentation",
architecture="Swin Transformer V2 with U-Net head",
tasks={"semantic-segmentation"},
framework="pytorch",
framework_version="2.3.0+cu121",
batch_size_suggestion=10,
accelerator="cuda",
accelerator_constrained=True,
accelerator_summary="It is necessary to use GPU since it was compiled for NVIDIA Ampere and newer architectures with AOTInductor and the computational demands of the model.",
input=[model_input],
output=[model_output],
)
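The required fields and the Raster Inference constraints listed above can be expressed as a small pre-flight check on a properties dictionary. This is a hedged sketch: it assumes the serialized keys all carry the mlm: prefix (for example mlm:framework), and find_problems and example_properties are hypothetical names that are not part of stac-model, which performs its own validation.

```python
from typing import Any, Dict, List

REQUIRED_MLM_FIELDS = (
    "mlm:name",
    "mlm:architecture",
    "mlm:tasks",
    "mlm:input",
    "mlm:output",
)


def find_problems(properties: Dict[str, Any]) -> List[str]:
    """Collect human-readable problems instead of stopping at the first one."""
    problems = [
        f"missing required field {name}"
        for name in REQUIRED_MLM_FIELDS
        if name not in properties
    ]
    if len(properties.get("mlm:tasks", [])) != 1:
        problems.append("exactly one task is supported")
    if properties.get("mlm:framework") != "pytorch":
        problems.append('mlm:framework must be "pytorch"')
    return problems


# Hypothetical, heavily abbreviated properties used only to exercise the check.
example_properties = {
    "mlm:name": "Satlas Solar Farm Segmentation",
    "mlm:architecture": "Swin Transformer V2 with U-Net head",
    "mlm:tasks": ["semantic-segmentation"],
    "mlm:framework": "pytorch",
    "mlm:input": [],
    "mlm:output": [],
}
print(find_problems(example_properties))  # []
print(find_problems({"mlm:tasks": ["a", "b"]}))
```

Returning a list of problems rather than raising on the first one makes it easier to fix several metadata issues in a single pass.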
A requirement of describing the model with the MLM is specifying its spatial and temporal relevance. These fields are not currently used by Raster Inference but can be useful for search and discovery, and they must have a value to comply with STAC.
After assembling all of our model metadata, we now need to create a STAC item with pystac, into which we will insert our MLM Extension metadata.
start_datetime_str = "1900-01-01"
end_datetime_str = "9999-01-01"
start_datetime = parse_dt(start_datetime_str).isoformat() + "Z"
end_datetime = parse_dt(end_datetime_str).isoformat() + "Z"
bbox = [
-7.882190080512502,
37.13739173208318,
27.911651652899923,
58.21798141355221
]
geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__
item_name = "item_solar_satlas_sentinel2"
item = pystac.Item(
id=item_name,
geometry=geometry,
bbox=bbox,
datetime=None,
properties={
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": (
"Sourced from satlas source code released by Allen AI under Apache 2.0"
),
},
assets=assets,
)
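For readers who want to see what the time strings and the bounding-box geometry look like, here is a standard-library sketch of the same two steps. It is an approximation for illustration only: to_utc_string and bbox_to_polygon are hypothetical helpers, and the vertex order of the ring may differ from the one shapely produces.

```python
from datetime import datetime
from typing import Any, Dict, Sequence


def to_utc_string(date_str: str) -> str:
    """Format an ISO date as a timestamp string with a trailing 'Z'."""
    return datetime.fromisoformat(date_str).isoformat() + "Z"


def bbox_to_polygon(bbox: Sequence[float]) -> Dict[str, Any]:
    """Build a GeoJSON-style polygon from [min_x, min_y, max_x, max_y]."""
    min_x, min_y, max_x, max_y = bbox
    ring = [
        [min_x, min_y],
        [max_x, min_y],
        [max_x, max_y],
        [min_x, max_y],
        [min_x, min_y],  # repeat the first vertex to close the ring
    ]
    return {"type": "Polygon", "coordinates": [ring]}


print(to_utc_string("1900-01-01"))  # 1900-01-01T00:00:00Z
polygon = bbox_to_polygon([-7.88, 37.14, 27.91, 58.22])
print(len(polygon["coordinates"][0]))  # 5
```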
We add a link to the source dataset that the model was trained on and should be run on. If the model is trained on multiple datasets, multiple links can be added with the DERIVED_FROM relation type. We also add a self-referential link to the Item to aid in search and discovery.
item.add_link(
pystac.Link(
target="https://earth-search.aws.element84.com/v1/collections/sentinel-2-l1c",
rel=pystac.RelType.DERIVED_FROM,
media_type=pystac.MediaType.JSON,
)
)
item.set_self_href(user_mlm_uri)
Finally, we add our extension metadata to the item we created with pystac. Using the .ext() method, we produce an item that has all of the methods from pystac as well as the custom MLM methods needed to correctly format and validate MLM metadata with .apply().
item_mlm = MLModelExtension.ext(item, add_if_missing=True)
item_mlm.apply(ml_model_meta.model_dump(by_alias=True, exclude_unset=False, exclude_defaults=True))
This can now be saved to a JSON file and copied to the user_mlm_uri path we specified on S3. To use the model in Raster Inference, the model artifact must also exist at the href we specified in the model asset object (user_model_uri, where we copied it at the start of this step).
import json
with open("model-metadata.json", "w") as json_file:
    json.dump(item_mlm.item.to_dict(), json_file, indent=4)
fs.put("model-metadata.json", user_mlm_uri)
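Before relying on the uploaded file, it can help to reload the JSON and confirm that the MLM fields and the model href are present. The sketch below does this against a tiny stand-in item written to a temporary directory; summarize_item, stand_in_item, and the example-bucket href are hypothetical and only illustrate the idea.

```python
import json
import os
import tempfile
from typing import Any, Dict


def summarize_item(path: str) -> Dict[str, Any]:
    """Reload a saved item and list its mlm: properties and model asset hrefs."""
    with open(path) as json_file:
        item = json.load(json_file)
    mlm_keys = sorted(k for k in item.get("properties", {}) if k.startswith("mlm:"))
    model_hrefs = [
        asset["href"]
        for asset in item.get("assets", {}).values()
        if "mlm:model" in asset.get("roles", [])
    ]
    return {"mlm_keys": mlm_keys, "model_hrefs": model_hrefs}


# Stand-in for the real item dictionary; values are placeholders.
stand_in_item = {
    "properties": {
        "mlm:name": "example",
        "mlm:tasks": ["semantic-segmentation"],
        "description": "not an MLM field",
    },
    "assets": {
        "model": {"href": "s3://example-bucket/model.zip", "roles": ["mlm:model", "data"]}
    },
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model-metadata.json")
    with open(path, "w") as json_file:
        json.dump(stand_in_item, json_file, indent=4)
    summary = summarize_item(path)

print(summary)
```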
Step 1: Set Up The WherobotsDB Context
import warnings
warnings.filterwarnings('ignore')
from wherobots.inference.data.io import read_raster_table
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr
config = SedonaContext.builder().appName('segmentation-batch-inference')\
.getOrCreate()
sedona = SedonaContext.create(config)
Step 2: Load Satellite Imagery
Next, we load the satellite imagery that we will be running inference over. These GeoTIFF images are loaded as out-db rasters in WherobotsDB, where each row represents a different scene.
tif_folder_path = 's3a://wherobots-benchmark-prod/data/ml/satlas/'
files_df = read_raster_table(tif_folder_path, sedona, limit=400)
df_raster_input = files_df.withColumn(
"outdb_raster", expr("RS_FromPath(path)")
)
df_raster_input.cache().count()
df_raster_input.show(truncate=False)
df_raster_input.createOrReplaceTempView("df_raster_input")
Step 3: Run Predictions And Visualize Results
To run predictions, we specify the MLM model metadata file we saved to user_mlm_uri. Predictions can be run with the Raster Inference SQL function RS_Segment or the Python API.
Here we generate 400 raster predictions using RS_Segment.
predictions_df = sedona.sql(f"""
SELECT
outdb_raster,
segment_result.*
FROM (
SELECT
outdb_raster,
RS_SEGMENT('{user_mlm_uri}', outdb_raster) AS segment_result
FROM
df_raster_input
) AS segment_fields
""")
predictions_df.cache().count()
predictions_df.show()
predictions_df.createOrReplaceTempView("predictions")
Now that we've generated predictions using our model over our satellite imagery, we can use the RS_Segment_To_Geoms function to extract the geometries that the model has identified as possible solar farms. We'll specify the following:
- a raster column to use for georeferencing our results
- the prediction result from the previous step
- our category label "1" returned by the model representing Solar Farms and the class map to use for assigning labels to the prediction
- a confidence threshold between 0 and 1.
df_multipolys = sedona.sql("""
WITH t AS (
SELECT RS_SEGMENT_TO_GEOMS(outdb_raster, confidence_array, array(1), class_map, 0.65) result
FROM predictions
)
SELECT result.* FROM t
""")
df_multipolys.cache().count()
df_multipolys.show()
df_multipolys.createOrReplaceTempView("multipolygon_predictions")
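To build intuition for the confidence threshold, the sketch below turns a small grid of per-pixel confidence scores into a binary mask. It is a conceptual illustration only and not the implementation of RS_Segment_To_Geoms; in particular, whether the boundary value itself counts as a detection is an assumption made here, and threshold_mask is a hypothetical helper.

```python
from typing import List


def threshold_mask(confidence: List[List[float]], threshold: float) -> List[List[int]]:
    """Mark pixels whose confidence is at or above the threshold with 1."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    return [[1 if value >= threshold else 0 for value in row] for row in confidence]


# Toy 2 x 3 confidence grid; real outputs are 1024 x 1024 per scene.
confidence = [
    [0.10, 0.70, 0.90],
    [0.64, 0.65, 0.20],
]
mask = threshold_mask(confidence, 0.65)
print(mask)  # [[0, 1, 1], [0, 1, 0]]
```

Raising the threshold shrinks the detected area and lowers the number of false positives, at the cost of missing lower-confidence solar farms.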
Since we ran inference across the state of Arizona, many scenes contain no solar farms and therefore have no positive detections. Before plotting, we collect the predicted geometries for each scene and then filter out scenes without segmentation detections.
df_merged_predictions = sedona.sql("""
SELECT
element_at(class_name, 1) AS class_name,
cast(element_at(average_pixel_confidence_score, 1) AS double) AS average_pixel_confidence_score,
ST_Collect(geometry) AS merged_geom
FROM
multipolygon_predictions
""")
Filtering out the empty geometries leaves us with a few predicted solar farm polygons for our 400 satellite image samples.
df_filtered_predictions = df_merged_predictions.filter("ST_IsEmpty(merged_geom) = False")
df_filtered_predictions.cache().count()
df_filtered_predictions.show()
We'll plot these with SedonaKepler. Compare the satellite basemap with the predictions and see if there's a match!
from sedona.maps.SedonaKepler import SedonaKepler
config = {
'version': 'v1',
'config': {
'mapStyle': {
'styleType': 'dark',
'topLayerGroups': {},
'visibleLayerGroups': {},
'mapStyles': {}
},
}
}
map = SedonaKepler.create_map(config=config)
SedonaKepler.add_df(map, df=df_filtered_predictions, name="Solar Farm Detections")
map
wherobots.inference Python API
If you prefer Python, wherobots.inference offers a module for registering the SQL inference functions as Python functions. Below we run the same inference as before with RS_SEGMENT.
from wherobots.inference.engine.register import create_semantic_segmentation_udfs
from pyspark.sql.functions import col
rs_segment = create_semantic_segmentation_udfs(batch_size = 10, sedona=sedona)
df = df_raster_input.withColumn("segment_result", rs_segment(user_mlm_uri, col("outdb_raster"))).select(
"outdb_raster",
col("segment_result.confidence_array").alias("confidence_array"),
col("segment_result.class_map").alias("class_map")
)
df.show(3)
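The batch_size argument controls how many rasters are grouped per model call. How Raster Inference forms batches internally is not described here, so the following is only a generic sketch of fixed-size batching with hypothetical scene names; iter_batches is not part of wherobots.inference.

```python
from typing import Iterator, List, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `batch_size` items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Hypothetical scene names standing in for the 400 rasters loaded earlier.
scene_names = [f"scene_{i}.tif" for i in range(400)]
batches = list(iter_batches(scene_names, batch_size=10))
print(len(batches))     # 40
print(len(batches[0]))  # 10
```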
References
- Bastani, Favyen, Wolters, Piper, Gupta, Ritwik, Ferdinando, Joe, and Kembhavi, Aniruddha. "SatlasPretrain: A Large-Scale Dataset for Remote Sensing Image Understanding." arXiv preprint arXiv:2211.15660 (2023). https://doi.org/10.48550/arXiv.2211.15660