Segmentation
Before you start¶
This example demonstrates query inference with a segmentation model, using Raster Inference to identify solar farms in satellite imagery. We will use a machine learning model from Satlas [1], which was trained on imagery from the European Space Agency's Sentinel-2 satellites.
This is a read-only preview of this notebook.
To execute the cells in this Jupyter Notebook, do the following:
- Log in to Wherobots Cloud.
- Start a GPU-Optimized runtime instance. We recommend using a Tiny GPU-Optimized runtime.
- Open the examples/python/wherobots-ai/gpu/segmentation.ipynb notebook.
For more information on starting and using notebooks, see the Wherobots Documentation.
Access a GPU-Optimized runtime¶
This notebook requires a GPU-Optimized runtime. For more information on GPU-Optimized runtimes, see Runtime types.
To access this runtime category, do the following:
- Sign up for a paid Wherobots Organization Edition (Professional or Enterprise).
- Submit a Compute Request for a GPU-Optimized runtime.
Step 1: Set Up The WherobotsDB Context¶
import warnings
warnings.filterwarnings('ignore')
from wherobots.inference.data.io import read_raster_table
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr
config = SedonaContext.builder().appName('segmentation-batch-inference')\
    .getOrCreate()
sedona = SedonaContext.create(config)
# To use your own model, uncomment the line that sets the `user_mlm_uri` variable and include the path to your MLM JSON.
# To learn more about bringing your own model, see [Bring your own model](https://docs.wherobots.com/latest/tutorials/wherobotsai/wherobots-inference/bring-your-own-model/) in the Wherobots Documentation.
# user_mlm_uri = [PATH-TO-MLM-JSON]
Step 2: Load Satellite Imagery¶
Next, we load the satellite imagery that we will run inference on. These GeoTIFF images are loaded as out-db rasters in WherobotsDB, where each row represents a different scene.
tif_folder_path = 's3a://wherobots-benchmark-prod/data/ml/satlas/'
files_df = read_raster_table(tif_folder_path, sedona, limit=400)
df_raster_input = files_df.withColumn(
    "outdb_raster", expr("RS_FromPath(path)")
)
df_raster_input.cache().count()
df_raster_input.show(truncate=False)
df_raster_input.createOrReplaceTempView("df_raster_input")
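As an optional sanity check, you can inspect the georeferencing metadata of one of the out-db rasters with Sedona's RS_MetaData function:
# Optional sanity check: look at the metadata of one loaded scene.
sedona.sql("""
    SELECT RS_MetaData(outdb_raster) AS raster_metadata
    FROM df_raster_input
    LIMIT 1
""").show(truncate=False)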
Step 3: Run Predictions And Visualize Results¶
To run predictions, we specify the model we wish to use. Some models are pre-loaded and made available in Wherobots Cloud; we can also load our own. Predictions can be run with the Raster Inference SQL function RS_SEGMENT or with the Python API.
Here we generate 400 raster predictions using RS_SEGMENT.
model_id = 'solar-satlas-sentinel2'
predictions_df = sedona.sql(f"""
    SELECT
        outdb_raster,
        segment_result.*
    FROM (
        SELECT
            outdb_raster,
            RS_SEGMENT('{model_id}', outdb_raster) AS segment_result
        FROM
            df_raster_input
    ) AS segment_fields
""")
predictions_df.cache().count()
predictions_df.show()
predictions_df.createOrReplaceTempView("predictions")
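After the segment_result.* expansion above, the prediction DataFrame should carry columns such as confidence_array and class_map, which the next step consumes. To confirm the exact schema on your runtime, you can print it:
# Optional: confirm the columns produced by RS_SEGMENT before extracting geometries.
predictions_df.printSchema()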
Instead of using one of our hosted models, you can specify your own model via the model_id variable. To do so, replace the model_id variable with the S3 URI pointing to your Machine Learning Model Extension (MLM) metadata JSON, then pass that as an argument to RS_SEGMENT.
For example:
user_mlm_uri = 's3://wherobots-modelhub-prod/professional/semantic-segmentation/solar-satlas-sentinel2/model-metadata.json'
predictions_df = sedona.sql(f"SELECT name, outdb_raster, RS_DETECT_BBOXES('{user_mlm_uri}', outdb_raster) AS preds FROM df_raster_input")
To learn more about bringing your own model, see Bring your own model in the Wherobots Documentation.
Now that we've generated predictions using our model over our satellite imagery, we can use the RS_Segment_To_Geoms function to extract the geometries that the model has identified as possible solar farms. We'll specify the following:
- a raster column to use for georeferencing our results
- the prediction result from the previous step
- the category label "1", which the model returns for solar farms, and the class map to use for assigning labels to the predictions
- a confidence threshold between 0 and 1.
df_multipolys = sedona.sql("""
    WITH t AS (
        SELECT RS_SEGMENT_TO_GEOMS(outdb_raster, confidence_array, array(1), class_map, 0.65) AS result
        FROM predictions
    )
    SELECT result.* FROM t
""")
df_multipolys.cache().count()
df_multipolys.show()
df_multipolys.createOrReplaceTempView("multipolygon_predictions")
Since we ran inference across the state of Arizona, many scenes don't contain solar farms and don't have positive detections. Let's filter out scenes without segmentation detections so that we can plot the results.
df_merged_predictions = sedona.sql("""
    SELECT
        element_at(class_name, 1) AS class_name,
        cast(element_at(average_pixel_confidence_score, 1) AS double) AS average_pixel_confidence_score,
        ST_Collect(geometry) AS merged_geom
    FROM
        multipolygon_predictions
""")
This leaves us with a few predicted solar farm polygons from our 400 satellite image samples.
df_filtered_predictions = df_merged_predictions.filter("ST_IsEmpty(merged_geom) = False")
df_filtered_predictions.cache().count()
df_filtered_predictions.show()
We'll plot these with SedonaKepler. Compare the satellite basemap with the predictions and see if there's a match!
from sedona.maps.SedonaKepler import SedonaKepler
config = {
'version': 'v1',
'config': {
'mapStyle': {
'styleType': 'satellite',
'topLayerGroups': {},
'visibleLayerGroups': {},
'mapStyles': {}
},
}
}
map = SedonaKepler.create_map(config=config)
SedonaKepler.add_df(map, df=df_filtered_predictions, name="Solar Farm Detections")
map
wherobots.inference Python API¶
If you prefer Python, wherobots.inference offers a module for registering the SQL inference functions as Python functions. Below, we run the same inference as before with RS_SEGMENT.
from wherobots.inference.engine.register import create_semantic_segmentation_udfs
from pyspark.sql.functions import col
rs_segment = create_semantic_segmentation_udfs(batch_size=10, sedona=sedona)
df = df_raster_input.withColumn("segment_result", rs_segment(model_id, col("outdb_raster"))).select(
    "outdb_raster",
    col("segment_result.confidence_array").alias("confidence_array"),
    col("segment_result.class_map").alias("class_map")
)
df.show(3)
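From here, you can extract geometries from the Python API output the same way as in Step 3. Here is a minimal sketch, assuming the column names shown above; the view name python_api_predictions is arbitrary:
# Sketch: register the Python API output and reuse the geometry-extraction query from Step 3.
df.createOrReplaceTempView("python_api_predictions")
df_geoms = sedona.sql("""
    WITH t AS (
        SELECT RS_SEGMENT_TO_GEOMS(outdb_raster, confidence_array, array(1), class_map, 0.65) AS result
        FROM python_api_predictions
    )
    SELECT result.* FROM t
""")
df_geoms.show()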
References¶
[1] Bastani, Favyen, Wolters, Piper, Gupta, Ritwik, Ferdinando, Joe, and Kembhavi, Aniruddha. "SatlasPretrain: A Large-Scale Dataset for Remote Sensing Image Understanding." arXiv preprint arXiv:2211.15660 (2023). https://doi.org/10.48550/arXiv.2211.15660