
# Bringing Your Own PyTorch Model to RasterFlow

<Badge color="purple">Private Preview</Badge>

<Tip>
  The following content is a read-only preview of an executable Jupyter notebook.

  To run this notebook interactively:

  1. Go to [**Wherobots Cloud**](https://cloud.wherobots.com).
  2. Start a runtime.
  3. Open the notebook.
  4. In the Jupyter Launcher:
     1. Click **File > Open Path**.
     2. Paste the following path to access this notebook: `examples/Analyzing_Data/RasterFlow_Bring_Your_Own_Model.ipynb`
     3. Press **Enter**.
</Tip>

This guide shows how to export a PyTorch model, store it in S3, and run it in RasterFlow. By the end, you'll understand the complete workflow from model creation to scalable raster inference.

We'll use a simple toy model with minimal dependencies to focus on the essential steps: **model export** and **RasterFlow integration**.

> **Note:** This notebook requires the Wherobots RasterFlow functionality to be enabled in Wherobots Cloud.

***

## Model Prerequisites

RasterFlow currently supports PyTorch models in the PT2 format with the following prerequisites:

* Supported tasks: semantic segmentation and regression
* Supported models: single image batch input of shape (Batch, Channel, Height, Width), single prediction batch output (Batch, Category, Height, Width).
* Our prediction workflows assume the model returns its raw output, without post-processing steps such as converting scores to categories or converting units.
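The shape contract above can be checked on plain shape tuples before you export anything. The following is a minimal sketch. `check_io_contract` is a hypothetical helper and is not part of RasterFlow or PyTorch. It verifies that batch, height, and width carry through from input to output, and returns the output category/channel count.

```python theme={"system"}
def check_io_contract(input_shape, output_shape):
    """Hypothetical helper: check a (B, C, H, W) -> (B, K, H, W) contract.

    Returns the number of output categories/channels K.
    """
    if len(input_shape) != 4 or len(output_shape) != 4:
        raise ValueError("expected 4-D shapes (Batch, Channel, Height, Width)")
    in_b, _, in_h, in_w = input_shape
    out_b, out_k, out_h, out_w = output_shape
    if out_b != in_b:
        raise ValueError(f"batch mismatch: {in_b} in, {out_b} out")
    if (out_h, out_w) != (in_h, in_w):
        raise ValueError(f"spatial mismatch: {(in_h, in_w)} in, {(out_h, out_w)} out")
    if out_k < 1:
        raise ValueError("output must have at least one category/channel")
    return out_k


# A 3-band input mapped to a single regression channel (e.g. canopy height)
print(check_io_contract((2, 3, 224, 224), (2, 1, 224, 224)))  # 1
```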

## Resources

For examples of models and scripts you can run in RasterFlow, see our [Hugging Face collection](https://huggingface.co/collections/wherobots/wherobotsai-models).

**PyTorch PT2 export documentation:**

* [Common Challenges and Solutions](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html) — Beginner
* [Export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html) — Advanced (complex models, accelerator optimization)
* [PT2 Archive Format](https://docs.pytorch.org/docs/stable/export/pt2_archive.html)

## Install Dependencies

Run the following cell to install `torch` and `torchvision`. These dependencies are required to build and export the model locally on this CPU instance. The model will execute on a separate GPU managed by RasterFlow.

```python theme={"system"}
!uv pip install torch==2.8 torchvision --extra-index-url https://download.pytorch.org/whl/cpu
```

## 1. Create a Toy Model

We'll create an example model matching the signature of the [Meta/WRI Canopy Height Model](https://huggingface.co/wherobots/meta-chm-v1-pt2): the input is a tensor of image data and the output is a tensor of continuous values (canopy height in meters). The shape of the output is the same as the input.

```python theme={"system"}
# Define a toy model: it ignores the input values and returns random
# predictions with the same shape, dtype, and device as the input
import torch.nn as nn
import torch

class ExampleModel(nn.Module):
    def forward(self, x):
        predictions = torch.randn_like(x)
        return predictions

```python theme={"system"}
# A single 3-band 224x224 image: (Batch, Channel, Height, Width)
x = torch.randn(1, 3, 224, 224)
print(x.shape)
```

```python theme={"system"}
model = ExampleModel()
result = model(x)
print(result.shape)
```

With our model defined, let's now export it.

## 2. PyTorch 2 Export Formats

PyTorch offers multiple export formats for different use cases: storing weights, training, edge inference, and server inference.

### Why not export with the `.pth` format?

You may be familiar with the checkpoint format saved as `.pth`:

```python theme={"system"}
torch.save(model.state_dict(), "model.pth")
```

This only stores model weights, not the model structure or execution logic. Loading requires all original dependencies:

```python theme={"system"}
# Rebuilding the model requires the original class definition (here, ExampleModel)
model = ExampleModel()
model.load_state_dict(torch.load("model.pth", weights_only=True))
```

Additionally, `state_dict` exports can't be optimized for specific accelerators, making deployment difficult.

### The `.pt2` Format — A Better Alternative

`torch.export` produces a single artifact for both training and inference with some key benefits:

1. **Device flexibility** — Export on CPU, load parameters to GPU at runtime
2. **Accelerator optimization** — Compile for faster execution on CUDA, AMD, or Intel GPUs
3. **Standardized metadata** — Store hyperparameters, configuration, and transforms alongside the model

***

We'll export our `ExampleModel` in PT2 format. For input preprocessing, we'll export transforms as an `nn.Module` in the same archive. Using `torch.nn.Sequential` simplifies export since its `forward` method takes a single `input` argument.

```python theme={"system"}
from pathlib import Path
import torchvision.transforms.v2 as T
```

```python theme={"system"}
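# One mean/std value per band, taken from the channel dimension of x
# (len(x) would return the batch size, not the number of bands)
num_bands = x.shape[1]
per_band_mean = [0.5] * num_bands
per_band_std = [0.1] * num_bands
norm_transform = torch.nn.Sequential(T.Normalize(mean=per_band_mean, std=per_band_std))
print(isinstance(norm_transform, nn.Module))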
```
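`T.Normalize` computes `(x - mean[c]) / std[c]` separately for each channel `c`. The sketch below reproduces that arithmetic in NumPy, which stands in for torchvision here, so the effect of the per-band lists is easy to see. `normalize_per_band` is a hypothetical helper. Unlike `T.Normalize`, it rejects a mean/std list whose length differs from the band count instead of broadcasting it.

```python theme={"system"}
import numpy as np


def normalize_per_band(image, mean, std):
    """Hypothetical helper: per-channel (x - mean) / std for a (C, H, W) array."""
    mean = np.asarray(mean, dtype=image.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=image.dtype).reshape(-1, 1, 1)
    if mean.shape[0] != image.shape[0] or std.shape[0] != image.shape[0]:
        raise ValueError("need one mean/std value per band")
    return (image - mean) / std


image = np.full((3, 2, 2), 0.7, dtype=np.float32)  # 3 bands, 2x2 pixels
out = normalize_per_band(image, mean=[0.5] * 3, std=[0.1] * 3)
print(out[:, 0, 0])  # each value is close to (0.7 - 0.5) / 0.1 = 2.0
```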

## 3. Define Input Shape Constraints

`torch.export` needs to know the expected input shape. Use `Dim` objects to specify:

* **Dynamic** dimensions — can be any value ≥1
* **Static** dimensions — must be a fixed size

#### FAQ

| Question                                 | Answer                                                                                                                           |
| ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
| How do I know what shape to use?         | Input shape affects runtime performance and accuracy. Check the model creator's recommendations.                                 |
| Should I use dynamic for all dimensions? | Typically no. Some models have data-dependent control flow logic that requires fixed dimensions (e.g., object detection models). |

**Rule of thumb:** Keep batch size dynamic; fix channel, height, and width. Test other dynamic configurations through trial and error to enable flexible channels, height, or width.

We'll export with a **dynamic batch size** (denoted by `Dim.AUTO`), **3 channels**, and **224×224 height/width**.

```python theme={"system"}
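from torch.export import Dim

# -1 marks a dynamic dimension; every other value is static
input_shape_constraints = [-1, 3, 224, 224]
# Trace with a batch of 2: torch.export specializes dimensions of size 0 or 1,
# so an example batch of 1 would be fixed to a static size
example_input_shape = [2, 3, 224, 224]
example_tensor = torch.randn(*example_input_shape, requires_grad=False)
dims = tuple(Dim.AUTO if dim == -1 else dim for dim in input_shape_constraints)
print(dims)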
```
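The `-1` convention can also be used to check an input shape before it reaches the model. The following is a minimal plain-Python sketch that does not involve `torch.export`. `matches_constraints` is a hypothetical helper. A `-1` entry accepts any size of at least 1, and every other entry must match exactly.

```python theme={"system"}
def matches_constraints(shape, constraints):
    """Hypothetical helper: -1 marks a dynamic dim (any size >= 1); other values are static."""
    if len(shape) != len(constraints):
        return False
    return all(
        size >= 1 if expected == -1 else size == expected
        for size, expected in zip(shape, constraints)
    )


input_shape_constraints = [-1, 3, 224, 224]
print(matches_constraints((64, 3, 224, 224), input_shape_constraints))  # True: batch is dynamic
print(matches_constraints((1, 4, 224, 224), input_shape_constraints))   # False: 4 bands, export expects 3
print(matches_constraints((1, 3, 256, 256), input_shape_constraints))   # False: 256x256, export expects 224x224
```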

Set the model to `eval` mode and configure device/dtype before export:

```python theme={"system"}
import inspect
device = torch.device("cpu")
dtype = torch.float32
model.eval()
model = model.to(device).to(dtype)
```

`torch.export` needs the forward function's argument names. Parse them using Python's `inspect` module:

```python theme={"system"}
model_arg = next(iter(inspect.signature(model.forward).parameters))
print(model_arg)
```
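Because `model.forward` is a bound method, `inspect.signature` leaves out `self`, so the first parameter is the input name that becomes the key in `dynamic_shapes`. The standalone sketch below needs no PyTorch. It uses illustrative classes whose `forward` signatures mirror `ExampleModel` and `torch.nn.Sequential`.

```python theme={"system"}
import inspect


class Dummy:
    # Mirrors ExampleModel.forward's signature
    def forward(self, x):
        return x


class DummySequential:
    # Mirrors the signature of torch.nn.Sequential.forward
    def forward(self, input):
        return input


def first_arg_name(module):
    """Return the first parameter name of module.forward (self is excluded for bound methods)."""
    return next(iter(inspect.signature(module.forward).parameters))


print(first_arg_name(Dummy()))            # x
print(first_arg_name(DummySequential()))  # input
```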

## 4. Export to ExportedProgram

Now that we have the following inputs, we can export to an `ExportedProgram`, an in-memory model object that can be saved to a `.pt2` archive.

* `model`: the toy `nn.Module` we defined earlier
* `args`: a tuple of the arguments to the model's forward pass
* `dynamic_shapes`: a dict mapping the forward function's input argument name (`model_arg`) to the tuple of dimension constraints we created earlier, `dims`

The `ExportedProgram` includes the `state_dict` (weights) and `example_inputs` (useful for testing).

```python theme={"system"}
model_program = torch.export.export(mod=model, args=(example_tensor,), dynamic_shapes={model_arg: dims})
print(model_program.example_inputs[0][0].shape)
```

Follow the same export steps for the transforms module:

```python theme={"system"}
norm_transform.eval()
norm_transform = norm_transform.to(device).to(dtype)
transform_arg = next(iter(inspect.signature(norm_transform.forward).parameters))
```

```python theme={"system"}
transforms_program = torch.export.export(
    mod=norm_transform, args=(example_tensor,), dynamic_shapes={transform_arg: dims}
)
```

## 5. Save to `.pt2` Archive

> **Note:** `torch.export.save` saves a single ExportedProgram. We'll use `torch.export.pt2_archive._package.package_pt2` to bundle both the model and transforms into one `.pt2` file.

```python theme={"system"}
from torch.export.pt2_archive._package import package_pt2

local_model_path = "example.pt2"

# Bundle both programs into one archive, keyed by name
exported_programs = {
    "model": model_program,
    "transforms": transforms_program,
}

package_pt2(
    f=local_model_path,
    exported_programs=exported_programs,
)
```

## 6. Run the Model with RasterFlow

RasterFlow supports both Wherobots Hosted Models and custom models like the one we just exported.

**Steps:**

1. Upload the `.pt2` file to S3 (we'll use Wherobots Managed Storage)
2. Define an `InferenceConfig` to tell RasterFlow how to run the model

Note: You can also load open models stored in PT2 format directly from Hugging Face.

```python theme={"system"}
import os
import s3fs

fs = s3fs.S3FileSystem(profile="default")

# Define the destination path on S3.
# USER_S3_PATH points at your personal bucket space in Wherobots Managed Storage;
# os.environ[...] fails loudly if it is not set
s3_model_path = os.environ["USER_S3_PATH"].rstrip("/") + "/" + local_model_path
fs.put(local_model_path, s3_model_path)
```

## Build the Model Input

To run our model, we need some input imagery. We'll test our model on National Agriculture Imagery Program (NAIP) 4-band imagery: red, green, blue, and near infrared. We'll select an AOI over Nashua, New Hampshire, that has some forest canopy for our toy canopy height model.

```python theme={"system"}
import wkls
import geopandas as gpd
import os

# Generate a geometry for Nashua, NH using WKLS (https://github.com/wherobots/wkls)
gdf = gpd.read_file(wkls.us.nh.nashua.geojson())

# Save the geometry to a parquet file in the user's S3 path
aoi_path = os.environ["USER_S3_PATH"].rstrip("/") + "/nashua.parquet"
gdf.to_parquet(aoi_path)
```

To prepare this imagery, we'll use RasterFlow to create a mosaic. Mosaics are backed by a cloud-native Zarr store, which supports reading spatial subsets and individual bands and running computations on the mosaic with RasterFlow.

This workflow takes a few minutes to complete.

```python theme={"system"}
from rasterflow_remote import RasterflowClient
client = RasterflowClient()
```

```python theme={"system"}
# ResamplingMethod must be imported before it is used below
from rasterflow_remote.data_models import ResamplingMethod

mosaic_index = client.build_gti_mosaic(
    gti="s3://wherobots-examples/rasterflow/indexes/naip_index.parquet",
    aoi=aoi_path,
    bands=["red", "green", "blue", "nir"],
    location_field="url",
    time_column="year",
    crs_epsg=3857,
    xy_chunksize=1024,
    query="res == .6",
    requester_pays=True,
    sort_field="time",
    resampling=ResamplingMethod.NEAREST,
    nodata=0.0,
)
mosaic_store = mosaic_index.first_row_mosaic
mosaic_index.mosaics
```

## Visualizing outputs

If RasterFlow is enabled for your organization, you can visualize the Zarr, GeoParquet, and other geospatial outputs using [cloud.wherobots.com/map](https://cloud.wherobots.com/map).

With our mosaic, we are now ready to run model prediction on the mosaic with RasterFlow.

We'll use the [`predict_mosaic`](https://docs.wherobots.com/reference/rasterflow/client#predict_mosaic) method to run our model. `predict_mosaic` uses RasterFlow's inference engine, which scales from small areas of interest up to global ones.

The inputs to this method are the store we want to run prediction on and an `InferenceConfig` object, which we define next.

## Defining the InferenceConfig

With the model on S3, define the inference job configuration.

> **Note:** Wherobots Hosted Models come with preconfigured `ModelRecipes`, so this step is only needed for custom models.

See the [InferenceConfig documentation](https://docs.wherobots.com/reference/rasterflow/data-models#inferenceconfig) for parameter details.

```python theme={"system"}
from dataclasses import asdict
from rasterflow_remote.data_models import InferenceConfig, MergeModeEnum, MosaicToMosaicActorEnum, ResamplingMethod

custom_inference_config = InferenceConfig(
    model_path=s3_model_path,
    actor=MosaicToMosaicActorEnum.REGRESSION_PYTORCH,
    patch_size=224,  # matches the static 224x224 height/width used at export
    clip_size=28,
    device="cuda",
    features=["red", "green", "blue"],  # mosaic bands fed to the model; 3 to match the exported channel count
    labels=["canopy_height"],
    max_batch_size=64,
    merge_mode=MergeModeEnum.WEIGHTED_AVERAGE,
)
```
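A mismatch between the `InferenceConfig` and the exported input constraints is easy to introduce, for example by changing `patch_size` without re-exporting. The following is a minimal pre-flight sketch. It assumes, as in this notebook, that `patch_size` must equal the static height and width used at export and that the number of `features` must equal the static channel count. `check_config_against_export` is a hypothetical helper and is not a RasterFlow API.

```python theme={"system"}
def check_config_against_export(patch_size, features, input_shape_constraints):
    """Hypothetical pre-flight check; returns a list of human-readable problems."""
    _, channels, height, width = input_shape_constraints
    problems = []
    if channels != -1 and len(features) != channels:
        problems.append(
            f"{len(features)} features but the model was exported for {channels} channels"
        )
    for name, size in (("height", height), ("width", width)):
        if size != -1 and patch_size != size:
            problems.append(f"patch_size {patch_size} != exported {name} {size}")
    return problems


# Values from this notebook: 224x224 patches, red/green/blue features, [-1, 3, 224, 224] export
print(check_config_against_export(224, ["red", "green", "blue"], [-1, 3, 224, 224]))  # []

# A config that would not match the exported model
for problem in check_config_against_export(256, ["red", "green", "blue", "nir"], [-1, 3, 224, 224]):
    print(problem)
```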

> **Note:** This step takes approximately 10 minutes to complete the first time it runs.

```python theme={"system"}
predict_mosaic_output = client.predict_mosaic(
    store=mosaic_store,
    **asdict(custom_inference_config),
)

predict_mosaic_output.mosaics
```

```python theme={"system"}
# This example AOI has one geometry, so we select the first (only) mosaic location.
predict_store = predict_mosaic_output.first_row_mosaic

predict_store
```

We can inspect the result of `predict_mosaic` and confirm that its spatial dimensions match those of our input store.

```python theme={"system"}
import xarray as xr
import s3fs
import zarr

predict_fs = s3fs.S3FileSystem(profile="default", asynchronous=True)
# FsspecStore expects a bucket/key path, so drop the "s3://" scheme
predict_zstore = zarr.storage.FsspecStore(predict_fs, path=predict_store.removeprefix("s3://"))
predict_ds = xr.open_zarr(predict_zstore)
```

```python theme={"system"}
mosaic_fs = s3fs.S3FileSystem(profile="default", asynchronous=True)
# FsspecStore expects a bucket/key path, so drop the "s3://" scheme
mosaic_zstore = zarr.storage.FsspecStore(mosaic_fs, path=mosaic_store.removeprefix("s3://"))
xr.open_zarr(mosaic_zstore)['variables']
```

```python theme={"system"}
predict_ds['variables']
```
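The comparison above is done by eye. The sketch below does the same check in code. `same_spatial_shape` is a hypothetical helper that compares two dimension-size mappings, such as `DataArray.sizes`. The dimension names `"y"` and `"x"` are assumptions, so check the printed `variables` arrays for the names RasterFlow actually uses. The dict values are illustrative only.

```python theme={"system"}
from collections.abc import Mapping


def same_spatial_shape(sizes_a: Mapping, sizes_b: Mapping, spatial_dims=("y", "x")) -> bool:
    """Hypothetical helper: True if both mappings have equal sizes along every spatial dim."""
    return all(
        dim in sizes_a and dim in sizes_b and sizes_a[dim] == sizes_b[dim]
        for dim in spatial_dims
    )


# Usage with the notebook objects (dimension names are assumptions):
# same_spatial_shape(predict_ds["variables"].sizes, xr.open_zarr(mosaic_zstore)["variables"].sizes)

# Stand-alone illustration with plain dicts: 4 input bands vs 1 output label, same y/x
mosaic_sizes = {"band": 4, "y": 4096, "x": 4096}
predict_sizes = {"band": 1, "y": 4096, "x": 4096}
print(same_spatial_shape(mosaic_sizes, predict_sizes))  # True
```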

## 7. What We Learned

Congratulations on completing this guide! You've learned how to:

* **Create a PyTorch model** compatible with RasterFlow's expected input/output signature (Batch, Channel, Height, Width)
* **Understand PT2 vs .pth formats** — PT2 bundles weights, structure, and execution logic into a single deployable artifact
* **Define input shape constraints** using `torch.export.Dim` for dynamic batch sizes and fixed spatial dimensions
* **Export models and transforms** together into a single `.pt2` archive using `package_pt2`
* **Upload custom models to S3** using Wherobots Managed Storage
* **Build GTI mosaics** from imagery indexes with `build_gti_mosaic`
* **Run scalable inference** using `InferenceConfig` and `predict_mosaic`

### Next Steps

**Explore more RasterFlow examples:**

* [RasterFlow Canopy Height Model](/tutorials/example-notebooks/rasterflow-chm) — Run the real Meta/WRI canopy height model at scale
* [RasterFlow ChesapeakeRSC Rural Road Segmentation](/tutorials/example-notebooks/rasterflow-chesapeake) — Semantic segmentation for rural road detection
* [RasterFlow Tile2Net](/tutorials/example-notebooks/rasterflow-tile2net) — Extract road networks from satellite imagery

**Learn more:**

* [RasterFlow Documentation](https://docs.wherobots.com/develop/rasterflow/index) — Full API reference
* [Wherobots AI Models on Hugging Face](https://huggingface.co/collections/wherobots/wherobotsai-models) — Pre-trained models ready to use
* [InferenceConfig Reference](https://docs.wherobots.com/reference/rasterflow/data-models#inferenceconfig) — All configuration options for custom models
