Private Preview
The following content is a read-only preview of an executable Jupyter notebook. To run this notebook interactively:
  1. Go to the Wherobots Model Hub.
  2. Select the specific notebook you wish to run.
  3. Click Run Model in Notebook.
This guide shows how to export a PyTorch model, store it in S3, and run it in RasterFlow. By the end, you’ll understand the complete workflow from model creation to scalable raster inference. We’ll use a simple toy model with minimal dependencies to focus on the essential steps: model export and RasterFlow integration.
Note: This notebook requires the Wherobots RasterFlow functionality to be enabled and a GPU runtime selected in Wherobots Cloud.

Model Prerequisites

RasterFlow currently supports PyTorch models in the PT2 format with the following prerequisites:
  • Supported tasks: semantic segmentation and regression
  • Supported inputs/outputs: a single image batch input of shape (Batch, Channel, Height, Width) and a single prediction batch output of shape (Batch, Category, Height, Width).
  • Our prediction workflows assume the model returns raw PyTorch output, without post-processing such as argmax to determine categories or unit conversion. A minimal sketch of this contract follows this list.
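As a minimal, hypothetical sketch of this contract (not a real RasterFlow model), a segmentation model would accept a (Batch, Channel, Height, Width) tensor and return raw per-category logits:
import torch
import torch.nn as nn

class RawLogitsSegmenter(nn.Module):
    """Hypothetical example: (B, C, H, W) in, raw (B, Category, H, W) logits out."""
    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        # Return raw logits only; no argmax, softmax, or unit conversion here
        return self.head(x)

print(RawLogitsSegmenter()(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 5, 224, 224])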

Resources

For examples of models and scripts you can run in RasterFlow, see our Hugging Face collection. For details on the export format, see the PyTorch PT2 export documentation.

Optional Setup for CPU Runtimes

You can run the following to install torch and torchvision for CPU runtimes to walk through this example:
!uv pip install torch==2.8 torchvision --extra-index-url https://download.pytorch.org/whl/cpu

1. Create a Toy Model

We’ll create an example model matching the signature of the Meta/WRI Canopy Height Model: the input is a tensor of image data and the output is a tensor of continuous values (canopy height in meters). The shape of the output is the same as the input.
# Define a toy model whose output has the same shape as its input
import torch.nn as nn
import torch

class ExampleModel(nn.Module):
    def forward(self, x):
        # Random values stand in for per-pixel canopy height predictions
        predictions = torch.randn(x.shape)
        return predictions
x = torch.randn(1,3,224,224)
print(x.shape)
model = ExampleModel()
result = model(x)
print(result.shape)
With our model defined, let’s now export it.

2. PyTorch 2 Export Formats

PyTorch offers multiple export formats for different use cases: storing weights, training, edge inference, and server inference.

Why not export with the .pth format?

You may be familiar with the checkpoint format saved as .pth:
torch.save(model.state_dict(), "model.pth")
This only stores model weights, not the model structure or execution logic. Loading requires all original dependencies:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load("model.pth", weights_only=True))
Additionally, state_dict exports can’t be optimized for specific accelerators, making deployment difficult.

The .pt2 Format — A Better Alternative

torch.export produces a single artifact for both training and inference with some key benefits:
  1. Device flexibility — Export on CPU, load parameters to GPU at runtime
  2. Accelerator optimization — Compile for faster execution on CUDA, AMD, or Intel GPUs
  3. Standardized metadata — Store hyperparameters, configuration, and transforms alongside the model

We’ll export our ExampleModel in PT2 format. For input preprocessing, we’ll export transforms as an nn.Module in the same archive. Using torch.nn.Sequential simplifies export since its forward method takes a single input argument.
from pathlib import Path
import torchvision.transforms.v2 as T
num_bands = x.shape[1]  # number of input channels (3 in our example)
per_band_mean = [0.5] * num_bands
per_band_std = [0.1] * num_bands
norm_transform = torch.nn.Sequential(T.Normalize(mean=per_band_mean, std=per_band_std))
print(isinstance(norm_transform, nn.Module))
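As a quick sanity check using the example tensor x from earlier, the transform should run on a batch and preserve its shape:
print(norm_transform(x).shape)  # torch.Size([1, 3, 224, 224])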

3. Define Input Shape Constraints

torch.export needs to know the expected input shape. Use Dim objects to specify:
  • Dynamic dimensions — can be any value ≥1
  • Static dimensions — must be a fixed size

FAQ

Question: How do I know what shape to use?
Answer: Input shape affects runtime performance and accuracy. Check the model creator’s recommendations.

Question: Should I use dynamic for all dimensions?
Answer: Typically no. Some models have data-dependent control flow that requires fixed dimensions (e.g., object detection models).

Rule of thumb: Keep batch size dynamic; fix channel, height, and width. Test other dynamic configurations (flexible channels, height, or width) through trial and error. We’ll export with a dynamic batch size (denoted by Dim.AUTO), 3 channels, and 224×224 height/width.
from torch.export import Dim
input_shape_constraints = [-1, 3, 224, 224]
example_input_shape = [2, 3, 224, 224]
example_tensor = torch.randn(*example_input_shape, requires_grad=False)
dims = tuple(Dim.AUTO if dim == -1 else dim for dim in input_shape_constraints)
print(dims)
Set the model to eval mode and configure device/dtype before export:
import inspect
device = torch.device("cuda")  # use "cpu" if you installed the CPU-only wheels above
dtype = torch.float32
model.eval()
model = model.to(device).to(dtype)
torch.export needs the forward function’s argument names. Parse them using Python’s inspect module:
model_arg = next(iter(inspect.signature(model.forward).parameters))
print(model_arg)

4. Export to ExportedProgram

Now that we have the following inputs, we can export to an ExportedProgram, an in-memory model object that can be saved to a .pt2 archive:
  • model, we defined this toy model earlier as an nn.Module
  • args, a tuple of the arguments to the model’s forward pass function
  • dynamic_shapes, a dict mapping the argument name of the input to the tuple of dimension constraints we created earlier: dims
The ExportedProgram includes the state_dict (weights) and example_inputs (useful for testing).
model_program = torch.export.export(mod=model, args=(example_tensor,), dynamic_shapes={model_arg: dims})
print(model_program.example_inputs[0][0].shape)
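As an optional check (assuming the export above succeeded), ExportedProgram.module() returns a runnable module, so we can confirm the exported graph still produces the expected output shape:
exported_out = model_program.module()(example_tensor)
print(exported_out.shape)  # matches the input shape for this toy model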
Follow the same export steps for the transforms module:
norm_transform.eval()
norm_transform = norm_transform.to(device).to(dtype)
transform_arg = next(iter(inspect.signature(norm_transform.forward).parameters))
transforms_program = torch.export.export(
    mod=norm_transform, args=(example_tensor,), dynamic_shapes={transform_arg: dims}
)
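The same check works for the exported transforms graph:
print(transforms_program.module()(example_tensor).shape)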

5. Save to .pt2 Archive

Note: torch.export.save saves a single ExportedProgram. We’ll use torch.export.pt2_archive._package.package_pt2 to bundle both the model and transforms into one .pt2 file.
from torch.export.pt2_archive._package import package_pt2
exported_programs = {}
local_model_path = "example.pt2"
exported_programs["model"] = model_program
exported_programs["transforms"] = transforms_program

package_pt2(
    f=local_model_path,
    exported_programs=exported_programs
)
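Before uploading, a quick local check confirms the archive was written:
from pathlib import Path
archive = Path(local_model_path)
print(archive.exists(), f"{archive.stat().st_size / 1e6:.2f} MB")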

6. Run the Model with RasterFlow

RasterFlow supports both Wherobots Hosted Models and custom models like the one we just exported. Steps:
  1. Upload the .pt2 file to S3 (we’ll use Wherobots Managed Storage)
  2. Define an InferenceConfig to tell RasterFlow how to run the model
Note: You can also load open models stored in .pt2 format directly from Hugging Face.
import os
import s3fs

fs = s3fs.S3FileSystem(profile="default")

# Define the destination path on S3
# We use the USER_S3_PATH environment variable to ensure it goes to your personal bucket space
s3_model_path = os.getenv("USER_S3_PATH") + local_model_path
fs.put(local_model_path, s3_model_path)
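You can confirm the upload with s3fs before moving on:
print(fs.exists(s3_model_path))  # should print True once the upload completes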

Build the Model Input

To run our model, we need some input imagery. We’ll test our model on National Agricultural Imagery Program (NAIP) 4-band imagery: red, green, blue, and near infrared (NIR). We’ll select an AOI over Nashua, New Hampshire that has some forest canopy for our toy canopy height model.
import wkls
import geopandas as gpd
import os

# Generate a geometry for Nashua, NH using WKLS (https://github.com/wherobots/wkls)
gdf = gpd.read_file(wkls.us.nh.nashua.geojson())

# Save the geometry to a parquet file in the user's S3 path
aoi_path = os.getenv("USER_S3_PATH") + "nashua.parquet"
gdf.to_parquet(aoi_path)
To prepare this imagery, we’ll use RasterFlow to create a mosaic. Mosaics are backed by a cloud native Zarr store that enables accessing spatial subsets, individual bands, and computing on the mosaic with RasterFlow. This workflow takes a few minutes to complete, so you can skip ahead to the next cell where we load the prepared output from the workflow.
Note: We’re using rasterflow_version="v1.43.1" explicitly here. This version is currently being validated before becoming the default in Wherobots Cloud images.
from rasterflow_remote import RasterflowClient
from rasterflow_remote.data_models import ResamplingMethod

client = RasterflowClient(rasterflow_version="v1.43.1")
mosaic_path = client.build_gti_mosaic(
        gti = "s3://wherobots-examples/rasterflow/indexes/naip_index.parquet",
        aoi = aoi_path,
        bands = ["red", "green", "blue", "nir"],
        location_field = "url",
        crs_epsg = 3857,
        xy_chunksize = 1024,
        query = "res == .6",
        requester_pays = True,
        sort_field = "time",
        resampling = ResamplingMethod.NEAREST,
        nodata= 0.0,
)

Visualize the Input Mosaic

We will use hvplot and datashader to visualize a small subset of the mosaic’s red band.
# Import libraries for visualization and coordinate transformation
import hvplot.xarray
import xarray as xr
import s3fs 
import zarr
from pyproj import Transformer
from holoviews.element.tiles import EsriImagery 

# Open the Zarr store (a pre-built copy of the mosaic, so you can skip the build step above)
mosaic_path = "s3://wherobots-examples/rasterflow/mosaics/nashua.zarr"
fs = s3fs.S3FileSystem(profile="default", asynchronous=True, anon=True)
zstore = zarr.storage.FsspecStore(fs, path=mosaic_path)
ds = xr.open_zarr(zstore)
# Create a transformer to convert from lat/lon to meters
transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)

# Transform bounding box coordinates from lat/lon to meters
min_lon, min_lat, max_lon, max_lat = gdf.total_bounds
(min_x, max_x), (min_y, max_y) = transformer.transform(
    [min_lon, max_lon], 
    [min_lat, max_lat]
)

# Select the red band and slice the dataset to the bounding box
# y=slice(max_y, min_y) handles the standard "North-to-South" image orientation
ds_subset = ds.sel(band="red",
    x=slice(min_x, max_x), 
    y=slice(max_y, min_y) 
)

# Select the first time step and extract the variables array
arr_subset = ds_subset.isel(time=0)["variables"]

# Create a base map layer using Esri satellite imagery
base_map = EsriImagery()

# Create an overlay layer from the model outputs with hvplot
output_layer = arr_subset.hvplot(
    x = "x",
    y = "y",
    geo = True,           # Enable geographic plotting
    dynamic = True,       # Enable dynamic rendering for interactivity
    rasterize = True,     # Use datashader for efficient rendering of large datasets
    cmap = "viridis",     # Color map for visualization
    aspect = "equal",     # Maintain equal aspect ratio
    title = "Nashua, NH NAIP Red Band" 
).opts(
    width = 600, 
    height = 600,
    alpha = 0.7           # Set transparency to see the base map underneath
)

# Combine the base map and output layer
final_plot = base_map * output_layer
final_plot
With our mosaic, we are now ready to run model prediction on it with RasterFlow. We’ll use the predict_mosaic method, which leverages RasterFlow’s inference engine that scales from small to global areas of interest. Its inputs are the input store we want to run prediction on and an InferenceConfig object, which we define next.

Defining the InferenceConfig

With the model on S3, define the inference job configuration.
Note: Wherobots Hosted Models come with preconfigured ModelRecipes, so this step is only needed for custom models.
See the InferenceConfig documentation for parameter details.
from dataclasses import asdict
from rasterflow_remote.data_models import InferenceConfig, InferenceActorEnum, MergeModeEnum, ResamplingMethod

custom_inference_config = InferenceConfig(
    model_path = s3_model_path,
    actor = InferenceActorEnum.REGRESSION_PYTORCH,
    patch_size = 224,
    clip_size = 28,
    device = "cuda",
    features = ["red", "green", "blue"],
    labels = ["canopy_height"],
    max_batch_size=64,
    merge_mode = MergeModeEnum.WEIGHTED_AVERAGE,
)
Note: This step will take approximately 10 minutes to complete when run for the first time.
predict_mosaic_path = client.predict_mosaic(
        store=mosaic_path,
        **asdict(custom_inference_config)
)
We can inspect the result of predict_mosaic and confirm that its spatial dimensions match our input store.
predict_ds = xr.open_zarr(predict_mosaic_path)
# Print both explicitly; in a notebook only the last expression is displayed
print(ds['variables'].sizes)
print(predict_ds['variables'].sizes)

7. What We Learned

Congratulations on completing this guide! You’ve learned how to:
  • Create a PyTorch model compatible with RasterFlow’s expected input/output signature (Batch, Channel, Height, Width)
  • Understand PT2 vs .pth formats — PT2 bundles weights, structure, and execution logic into a single deployable artifact
  • Define input shape constraints using torch.export.Dim for dynamic batch sizes and fixed spatial dimensions
  • Export models and transforms together into a single .pt2 archive using package_pt2
  • Upload custom models to S3 using Wherobots Managed Storage
  • Build GTI mosaics from imagery indexes with build_gti_mosaic
  • Run scalable inference using InferenceConfig and predict_mosaic

Next Steps

Explore more RasterFlow examples in our Hugging Face collection, and see the InferenceConfig documentation to learn more.