Finetuning or Training Spotiflow on Custom Data
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "matplotlib",
# "tifffile",
# "spotiflow",
# "tqdm",
# "napari[all]"
# ]
# ///
Overview
GitHub | Paper | Spotiflow Documentation | Spotiflow API
In this notebook, we will walk through how to finetune a Spotiflow model following the instructions from the Spotiflow example notebook.
This process is very similar to training a model from scratch, but instead of initializing a new model with random weights, we will start with a pretrained model and finetune it on our data.
This is useful when the default models don't perform well on your data. Finetuning, or training from scratch, lets Spotiflow learn directly from your own annotated examples, which typically improves spot detection on data that differ from what the pretrained models were trained on.
You can find more details in the Spotiflow documentation for training and finetuning.
💡 Tip: Spotiflow runs significantly faster on a GPU. It supports both NVIDIA GPUs (CUDA) and Apple Silicon (MPS). If you don't have either, we recommend running this notebook on Google Colab for faster performance.
NVIDIA GPU (CUDA - Windows/Linux)
To use Spotiflow in this notebook with an NVIDIA GPU:
1. Make sure the NVIDIA drivers are installed on your system.
2. Run nvidia-smi in the terminal to check your CUDA version (shown in the top-right of the output, e.g. CUDA Version: 13.0).
3. Update the # /// script block at the top of this notebook so that it installs a PyTorch build with CUDA support. Replace cu130 in all three places with the PyTorch CUDA build that matches your system (e.g. cu126 for CUDA 12.6):
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "matplotlib",
# "tifffile",
# "spotiflow",
# "napari[all]",
# "torch",
# "torchvision",
# ]
#
# [tool.uv.sources]
# torch = { index = "pytorch-cu130" }
# torchvision = { index = "pytorch-cu130" }
#
# [[tool.uv.index]]
# name = "pytorch-cu130"
# url = "https://download.pytorch.org/whl/cu130"
# explicit = true
# ///
Finally, re-run the notebook using uvx juv run.
Data
To go through the training process, you need pairs of images and corresponding spot annotations.
The spot annotations should be .csv files containing the spot coordinates. Use two columns (x and y) for 2D data, or three columns (z, x and y) for 3D data:
| x | y |
|---|---|
| 100.5 | 200.3 |
| 150.2 | 250.1 |
or, for 3D data:
| z | x | y |
|---|---|---|
| 10.0 | 100.5 | 200.3 |
| 15.0 | 150.2 | 250.1 |
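As an illustration of this layout, the sketch below writes spot coordinates to a .csv file with a header row and then reads it back. It assumes the header names shown in the tables above (x, y, and z for 3D). The helpers write_spots_csv and read_spots_csv are hypothetical, and the Spotiflow documentation remains the reference for the exact columns its loader expects.

```python
import csv

import numpy as np


def write_spots_csv(path, coords, columns=("x", "y")):
    """Write an (N, D) array of spot coordinates to a CSV file with a header row.

    Hypothetical helper: `columns` lists one name per coordinate axis,
    e.g. ("x", "y") for 2D or ("z", "x", "y") for 3D.
    """
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[1] != len(columns):
        raise ValueError(f"expected an (N, {len(columns)}) array, got shape {coords.shape}")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(columns)  # header row, e.g. x,y
        writer.writerows(coords.tolist())  # one spot per row


def read_spots_csv(path):
    """Read a CSV written by write_spots_csv back into (header, array)."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = [[float(v) for v in row] for row in reader]
    return header, np.array(rows, dtype=float).reshape(-1, len(header))


# the two example spots from the 2D table above
write_spots_csv("example_spots_2d.csv", [[100.5, 200.3], [150.2, 250.1]])
header, spots = read_spots_csv("example_spots_2d.csv")
print(header, spots.shape)

# the same spots with a z coordinate, as in the 3D table
write_spots_csv("example_spots_3d.csv", [[10.0, 100.5, 200.3], [15.0, 150.2, 250.1]],
                columns=("z", "x", "y"))
header_3d, spots_3d = read_spots_csv("example_spots_3d.csv")
```

Writing the header explicitly keeps each file self-describing, so a 2D file and a 3D file can be told apart without inspecting the image.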
Each image file and its spot annotation file MUST have the same name (apart from the extension) and MUST be in the same folder. The data should be split into train and val folders, with an optional test folder:
spots_data
├── train
│   ├── img_001.csv
│   ├── img_001.tif
│   ...
│   ├── img_002.csv
│   └── img_002.tif
├── val
│   ├── val_img_001.csv
│   ├── val_img_001.tif
│   ...
│   ├── val_img_002.csv
│   └── val_img_002.tif
└── test (optional)
    ├── test_img_001.csv
    ├── test_img_001.tif
    ...
    ├── test_img_002.csv
    └── test_img_002.tif
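Because images and annotations are paired by file name, a quick check before training can catch a missing or misnamed annotation. The sketch below uses two hypothetical helpers, find_unpaired and check_dataset, which rely only on pathlib. They compare the .tif and .csv file names in each split folder. The sketch assumes lowercase .tif and .csv extensions, exactly as in the tree above, and treats the test folder as optional.

```python
import tempfile
from pathlib import Path


def find_unpaired(split_dir):
    """Return (images without a .csv, .csv files without an image) as sorted name stems.

    Files are paired on their name without extension, e.g. img_001.tif <-> img_001.csv.
    """
    split_dir = Path(split_dir)
    tif_stems = {p.stem for p in split_dir.glob("*.tif")}
    csv_stems = {p.stem for p in split_dir.glob("*.csv")}
    return sorted(tif_stems - csv_stems), sorted(csv_stems - tif_stems)


def check_dataset(root, required=("train", "val"), optional=("test",)):
    """Print a per-split report and return True only if every file is paired."""
    root = Path(root)
    ok = True
    for split in (*required, *optional):
        split_dir = root / split
        if not split_dir.is_dir():
            if split in required:
                print(f"{split}: required folder is missing")
                ok = False
            continue  # a missing optional folder (test) is fine
        no_csv, no_tif = find_unpaired(split_dir)
        n_pairs = len({p.stem for p in split_dir.glob("*.tif")}) - len(no_csv)
        print(f"{split}: {n_pairs} pairs, images without .csv: {no_csv}, "
              f".csv without image: {no_tif}")
        ok = ok and not no_csv and not no_tif
    return ok


# Self-contained demo on a throwaway folder; for the real data, call
# check_dataset("data/00_spot_detection_spotiflow_finetuning") instead.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp)
    layout = {"train": ["img_001.tif", "img_001.csv", "img_002.tif"],
              "val": ["val_img_001.tif", "val_img_001.csv"]}
    for split, names in layout.items():
        (demo / split).mkdir()
        for name in names:
            (demo / split / name).touch()
    demo_unpaired = find_unpaired(demo / "train")
    demo_ok = check_dataset(demo)  # img_002.tif has no annotation -> False
```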
For this notebook, we will use the Spotiflow Finetuning Dataset suggested in the Spotiflow example notebook (MERFISH dataset from Zhang et al., 2021).
Import Libraries
import matplotlib.pyplot as plt
import numpy as np
from spotiflow.model import Spotiflow, SpotiflowTrainingConfig
from spotiflow.utils import get_data
Load the Data
Since the data are organized as described above, we can load the training, validation and test sets with the get_data function from the Spotiflow API. It looks for the .tif and .csv files in the train, val and (optionally) test folders and returns the images and spot coordinates ready for training.
Set the include_test argument to True if you have a test folder whose data you want to use for evaluation after training.
data_path = "data/00_spot_detection_spotiflow_finetuning"
train_imgs, train_spots, val_imgs, val_spots, test_imgs, test_spots = get_data(
data_path, include_test=True
)
We can then visualize an example image and its corresponding spot annotations to verify that the data has been loaded correctly.
train_imgs and train_spots (and the corresponding val and test variables) are lists of images and spot-coordinate arrays, so we can index them to visualize a single example.
index = 0
img = train_imgs[index]
spots = train_spots[index]
# set the contrast limits to the 1st and 98th percentiles of the image pixel values
plt.imshow(img, cmap="gray", clim=tuple(np.percentile(img, (1, 98))))
plt.scatter(spots[:, 1], spots[:, 0], facecolors="none", edgecolors="green")
plt.axis("off")
plt.title("Example Training Image")
plt.show()
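A numeric summary can complement the visual check and reveal problems such as swapped coordinates, which would push spots outside the image. This is a minimal sketch for 2D data using a hypothetical helper, summarize_split. It assumes each spot array is (N, 2) in (row, column) order. That is the order the plotting code above relies on, with column 0 on the vertical axis.

```python
import numpy as np


def summarize_split(name, imgs, spots):
    """Print image/spot counts and return how many spots fall outside their image.

    Assumes 2D images and (N, 2) spot arrays in (row, column) order.
    """
    counts = [len(np.asarray(p).reshape(-1, 2)) for p in spots]
    n_outside = 0
    for img, pts in zip(imgs, spots):
        pts = np.asarray(pts, dtype=float).reshape(-1, 2)
        h, w = img.shape[:2]
        inside = ((pts[:, 0] >= 0) & (pts[:, 0] < h)
                  & (pts[:, 1] >= 0) & (pts[:, 1] < w))
        n_outside += int((~inside).sum())
    print(f"{name}: {len(imgs)} images, {sum(counts)} spots "
          f"({min(counts)} to {max(counts)} per image), "
          f"{n_outside} outside the image bounds")
    return n_outside


# Toy check with synthetic data; with the notebook's variables, call e.g.
# summarize_split("train", train_imgs, train_spots)
toy_imgs = [np.zeros((64, 64)), np.zeros((32, 48))]
toy_spots = [np.array([[10.0, 20.0], [63.5, 5.0]]), np.array([[40.0, 10.0]])]
toy_outside = summarize_split("toy", toy_imgs, toy_spots)  # row 40 >= height 32
```

A large count of out-of-bounds spots usually points to an x/y mix-up in the annotation files rather than to bad annotations.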
Finetune a Pretrained Model
We can now finetune a pretrained Spotiflow model on our data.
Load a Pretrained Model
The first step is to load the pretrained model using the Spotiflow.from_pretrained() method, which allows you to specify the name of the pretrained model you want to load.
For this example, we will finetune the synth_complex model (automatically downloaded if not already present on your system).
Note: If you want to load a model from a path, you can instead use the Spotiflow.from_folder() method and specify the path to the folder where the model is located.
Note: If you want to train a model from scratch instead of finetuning a pretrained one, first define a new model configuration using the SpotiflowModelConfig class, then initialize a new model by passing that configuration to the Spotiflow class. SpotiflowTrainingConfig only controls the training procedure. See the Spotiflow Training notebook for more details.
model = Spotiflow.from_pretrained("synth_complex")
Prepare the Training Configuration
The next step is to prepare the training configuration using the SpotiflowTrainingConfig class.
This step is identical to the one used for training a model from scratch.
Here we will only change a few parameters, but you can find the full description of the training configuration in the parameter reference below.
SpotiflowTrainingConfig Parameters
Optimization
| Parameter | Default | What it does |
|---|---|---|
| | | Learning rate for the AdamW optimizer. Controls how large each weight update is. When finetuning, consider lowering it (e.g. … |
| | | Optimization algorithm. Currently only … |
| | | Number of crops processed simultaneously before each weight update. Larger batches give more stable gradients but use more GPU memory; reduce it if you run out of memory. |
| | | Number of full passes over the (sampled) training data. More epochs give the model more time to learn, but too many can lead to overfitting. Watch the validation loss: if it starts rising while the training loss keeps falling, you've trained too long. For a quick tutorial run, a much smaller value (e.g. … |
| | | Number of epochs without validation-loss improvement before the learning rate is automatically reduced (… |
| | | Number of epochs without validation-loss improvement before training stops early. |
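The two patience rows above can be made concrete by replaying a list of validation losses through a generic plateau rule. This is an illustration, not Spotiflow's scheduler. The function simulate_schedule, the reduction factor of 0.5 and the loss values are all made up for the example.

```python
def simulate_schedule(val_losses, lr, reduce_patience, stop_patience, factor=0.5):
    """Replay per-epoch validation losses through a generic plateau rule.

    After `reduce_patience` epochs without a new best loss, the learning rate
    is multiplied by `factor`. After `stop_patience` epochs without a new
    best loss, training stops. Returns (epochs_run, final_lr).
    """
    best = float("inf")
    since_best = 0    # epochs since the last improvement (drives early stopping)
    since_reduce = 0  # epochs since the last improvement or LR reduction
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, since_best, since_reduce = loss, 0, 0
        else:
            since_best += 1
            since_reduce += 1
        if since_reduce >= reduce_patience:
            lr *= factor
            since_reduce = 0
            print(f"epoch {epoch}: no improvement for {reduce_patience} epochs, lr -> {lr:g}")
        if since_best >= stop_patience:
            print(f"epoch {epoch}: early stopping")
            return epoch, lr
    return len(val_losses), lr


# made-up validation losses: improvement, a plateau, one more improvement, another plateau
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.69, 0.70, 0.70, 0.70, 0.70, 0.70]
epochs_run, final_lr = simulate_schedule(losses, lr=1e-3, reduce_patience=3, stop_patience=4)
```

With these values the learning rate is halved at epochs 6 and 10, and training stops at epoch 11, before the last loss is seen. Making the learning-rate patience shorter than the stopping patience gives the model a chance to improve at a lower learning rate before giving up.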
Sampling & Cropping
| Parameter | Default | What it does |
|---|---|---|
| | | Side length (in pixels) of the random square crops extracted from each image during training. Internally clamped to fit your images and the network's minimum size, so for small images the effective crop may be smaller. |
| | | Crop size along the Z axis, used only for 3D data. Ignored for 2D. |
| | | If … |
| | | Number of crops sampled per epoch. |
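The crop parameters above can be pictured with a small 2D example. A random crop is cut from the image, the spot coordinates are shifted into the crop's frame, and spots that fall outside the crop are dropped. This is a sketch with a hypothetical random_crop helper. It clamps the crop only to the image size (Spotiflow also accounts for the network's minimum size) and assumes (row, column) spot coordinates.

```python
import numpy as np


def random_crop(img, spots, crop_size, rng=None):
    """Cut a random crop from a 2D image and shift the spots into crop coordinates.

    The requested square crop is clamped per axis to the image size, so small
    images yield a smaller (possibly rectangular) crop. Returns
    (crop, spots_in_crop, (row_offset, col_offset)).
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    ch, cw = min(crop_size, h), min(crop_size, w)  # effective crop size
    y0 = int(rng.integers(0, h - ch + 1))  # upper bound is exclusive
    x0 = int(rng.integers(0, w - cw + 1))
    crop = img[y0:y0 + ch, x0:x0 + cw]
    shifted = np.asarray(spots, dtype=float).reshape(-1, 2) - np.array([y0, x0])
    inside = ((shifted[:, 0] >= 0) & (shifted[:, 0] < ch)
              & (shifted[:, 1] >= 0) & (shifted[:, 1] < cw))
    return crop, shifted[inside], (y0, x0)


img = np.arange(100 * 80, dtype=float).reshape(100, 80)
spots = np.array([[5.0, 5.0], [50.0, 40.0], [99.0, 79.0]])
crop, crop_spots, offset = random_crop(img, spots, crop_size=64, rng=np.random.default_rng(0))
# a crop size larger than the image is clamped, so the whole image is returned
full, full_spots, full_offset = random_crop(img, spots, crop_size=512)
```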
Loss
| Parameter | Default | What it does |
|---|---|---|
| | | Loss function for the spot heatmap. One of … |
| | | Loss function for the stereographic flow regression. Currently only … |
| | | Weight given to positive (spot) pixels in the heatmap loss. Spots are sparse, so positive pixels are up-weighted to counteract the class imbalance. Increase it if the model misses faint spots, decrease it if it over-detects. |
| | | Number of resolution levels at which the heatmap loss is computed. |
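The positive-pixel weight can be understood through a generic weighted binary cross-entropy, in which the weight multiplies the loss term for spot pixels. This has the same form as the pos_weight argument of PyTorch's BCEWithLogitsLoss. The NumPy sketch below, with the hypothetical function weighted_bce, illustrates the idea only; Spotiflow's actual heatmap loss may be defined differently.

```python
import numpy as np


def weighted_bce(pred, target, pos_weight=1.0, eps=1e-7):
    """Binary cross-entropy with an extra weight on the positive (spot) term.

    pred: predicted probabilities; target: heatmap values in [0, 1].
    Generic formulation, not necessarily Spotiflow's exact loss.
    """
    p = np.clip(pred, eps, 1 - eps)  # avoid log(0)
    loss = -(pos_weight * target * np.log(p) + (1 - target) * np.log(1 - p))
    return float(loss.mean())


# one spot pixel in a 10 x 10 heatmap, and a prediction that misses it
target = np.zeros((10, 10))
target[5, 5] = 1.0
pred = np.full((10, 10), 0.01)

loss_w1 = weighted_bce(pred, target, pos_weight=1.0)
loss_w10 = weighted_bce(pred, target, pos_weight=10.0)
print(f"pos_weight=1: {loss_w1:.4f}, pos_weight=10: {loss_w10:.4f}")
```

With 99 background pixels and a single spot pixel, the unweighted loss is dominated by the easy background. Raising the weight makes the missed spot cost roughly ten times more, which is why a higher weight pushes the model toward detecting faint spots.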
Other
| Parameter | Default | What it does |
|---|---|---|
| | | Metadata string recording which pretrained model this one was finetuned from. Stored in the saved config; does not affect training itself. |
| | | If … |
train_config = SpotiflowTrainingConfig(
num_epochs=30,
finetuned_from="synth_complex", # optional, good for keeping track
)
Finetune the Model
We can now train the model by calling the model's fit() method.
We pass it the training and validation data, the training configuration defined in the previous step, and the directory where the finetuned model should be saved.
model.fit(
train_imgs,
train_spots,
val_imgs,
val_spots,
save_dir="data/00_spot_detection_spotiflow_finetuning/finetuned_model",
train_config=train_config,
)
Evaluate the Finetuned Model
Now that the model is finetuned, we can compare its performance on the test set with that of the pretrained model.
Because fit() updated the model variable in place, we first load the synth_complex model again. We then run both it and the finetuned model on an image from the test set.
# load the pretrained model
synth_complex = Spotiflow.from_pretrained("synth_complex")
# run both the pretrained and finetuned models on the test set
img_idx = 3 # index of the test image to evaluate
points_pretrained, _ = synth_complex.predict(test_imgs[img_idx])
points_finetuned, _ = model.predict(test_imgs[img_idx])
Let's now plot the predictions of both models side by side to compare them visually.
# 2 columns, 1 row
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
clim = tuple(np.percentile(test_imgs[img_idx], (1, 98)))
axes[0].imshow(test_imgs[img_idx], cmap="gray", clim=clim)
axes[0].scatter(
points_pretrained[:, 1],
points_pretrained[:, 0],
facecolors="none",
edgecolors="magenta",
alpha=0.7,
)
axes[0].set_title("Pretrained")
axes[0].axis("off")
axes[1].imshow(test_imgs[img_idx], cmap="gray", clim=clim)
axes[1].scatter(
points_finetuned[:, 1],
points_finetuned[:, 0],
facecolors="none",
edgecolors="green",
alpha=0.7,
)
axes[1].set_title("Finetuned")
axes[1].axis("off")
plt.show()
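Because the test images come with annotations, the visual comparison can be complemented with numbers. The sketch below uses a hypothetical function, match_f1. It matches predicted and annotated spots one-to-one within a distance cutoff and reports precision, recall and F1. It uses a simple greedy matching and an arbitrary 3-pixel cutoff rather than Spotiflow's own evaluation code, and it assumes both point arrays are 2D and in the same (row, column) order.

```python
import numpy as np


def match_f1(pred, gt, cutoff=3.0):
    """Precision, recall and F1 after greedy one-to-one matching within `cutoff` pixels.

    Candidate pairs closer than `cutoff` are accepted in order of increasing
    distance, and each point is used at most once.
    """
    pred = np.asarray(pred, dtype=float).reshape(-1, 2)
    gt = np.asarray(gt, dtype=float).reshape(-1, 2)
    tp = 0
    if len(pred) and len(gt):
        dist = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)  # (P, G)
        pi, gi = np.nonzero(dist <= cutoff)
        used_pred, used_gt = set(), set()
        for k in np.argsort(dist[pi, gi], kind="stable"):
            p, g = int(pi[k]), int(gi[k])
            if p not in used_pred and g not in used_gt:
                used_pred.add(p)
                used_gt.add(g)
        tp = len(used_pred)  # true positives = matched pairs
    precision = tp / len(pred) if len(pred) else 0.0
    recall = tp / len(gt) if len(gt) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


# Toy example; with the notebook's variables, compare e.g.
# match_f1(points_pretrained, test_spots[img_idx]) and
# match_f1(points_finetuned, test_spots[img_idx])
toy_gt = [[10, 10], [20, 20], [30, 30]]
toy_pred = [[11, 10], [20, 22], [50, 50]]
toy_scores = match_f1(toy_pred, toy_gt)
print("precision, recall, F1:", toy_scores)
```

Greedy matching is simple and usually close to optimal for well-separated spots. For dense spots, an optimal assignment (e.g. the Hungarian algorithm) is the more rigorous choice.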