Spotiflow in Python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "matplotlib",
# "tifffile",
# "spotiflow",
# "tqdm",
# "napari[all]"
# ]
# ///
Overview
GitHub | Paper | Spotiflow Documentation | Spotiflow API
In this notebook, we’ll see how to run Spotiflow to detect spots on single images or on a folder of images, and how to visualize and save the results.
The images we will use for this section can be downloaded from the Spotiflow Dataset.
Since we will be visualizing points (i.e. the detected spots), and to explore something different from the previous notebooks, we will use napari to visualize the results (ndv does not yet support points directly).
Note: If you want to use napari independently of this notebook, you can quickly launch napari by running uvx "napari[all]" in your terminal (it might take a little while the first time you run this command, but after that it will be very quick).
💡 Tip: Spotiflow runs significantly faster on a GPU. It supports both NVIDIA GPUs (CUDA) and Apple Silicon (MPS). If you don't have either, we recommend running this notebook on Google Colab for faster performance.
NVIDIA GPU (CUDA - Windows/Linux)
To use Spotiflow in this notebook with an NVIDIA GPU:
1. Make sure the NVIDIA drivers are installed on your system.
2. Run nvidia-smi in the terminal to check your CUDA version (shown in the top-right corner of the output, e.g. CUDA Version: 13.0.0).
3. Update the # /// script block at the top of this notebook so that it installs PyTorch with CUDA support. Replace every occurrence of cu130 with the tag for your CUDA version (e.g. cu126 for CUDA 12.6):
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "matplotlib",
# "tifffile",
# "spotiflow",
# "tqdm",
# "napari[all]",
# "torch",
# "torchvision",
# ]
#
# [tool.uv.sources]
# torch = { index = "pytorch-cu130" }
# torchvision = { index = "pytorch-cu130" }
#
# [[tool.uv.index]]
# name = "pytorch-cu130"
# url = "https://download.pytorch.org/whl/cu130"
# explicit = true
# ///
4. Re-run the notebook using uvx juv run.
Import Libraries
import csv
from pathlib import Path
import matplotlib.pyplot as plt
import napari
import numpy as np
import tifffile
from spotiflow.model import Spotiflow
from tqdm import tqdm
Setup
This section defines save_points_as_csv, a helper that saves the results of spot detection to a .csv file in a napari-compatible format. This lets us easily visualize the results in napari later on.
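Below is a minimal sketch of such a helper. It assumes napari's CSV layout for points layers: an index column followed by axis-0, axis-1, and so on. It mirrors the save_points_as_csv(points, path, channel_last=...) calls used later in this notebook. Treat it as an illustration, not the exact helper.

```python
import csv
from pathlib import Path

import numpy as np


def save_points_as_csv(points, path, channel_last=False):
    """Sketch: save spot coordinates to a CSV laid out like napari's points CSV.

    points: (N, D) array of coordinates.
    channel_last: if True, rows are (..., y, x, channel) as returned by
        predict_multichannel(). The channel column is then moved to the front,
        so the saved order is (channel, ..., y, x), matching a (C, Y, X) image.
    """
    points = np.asarray(points, dtype=float)
    if channel_last:
        # Move the last column (channel index) to the first position
        points = np.roll(points, 1, axis=1)
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create the output folder if needed
    header = ["index"] + [f"axis-{i}" for i in range(points.shape[1])]
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i, row in enumerate(points):
            writer.writerow([i, *row.tolist()])
    return path
```

With channel_last=True, the saved coordinates follow the same (channel, y, x) order used for the napari points layer in this notebook.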
Running Spotiflow on 2D images
Load the Image
Since the images are TIFF files, we can load them with the imread function from the tifffile library.
image_path = "../../_static/images/spotiflow/2d_spots.tif"
image = tifffile.imread(image_path)
print(image.shape)
(6, 1040, 1392)
This is a 6-channel image stored in (C, Y, X) order. We can visualize it with napari.
First we need to create the viewer, and then we can add the image to it as a layer. If you keep the viewer open, you can reuse it throughout the rest of the notebook to visualize the detected spots as well.
viewer = napari.Viewer()
viewer.add_image(image, name="Image")

Initialize the Model
To initialize a pre-trained Spotiflow model, we can use the from_pretrained() class method.
Currently, the available pre-trained models are:
- general: trained on a diverse dataset of spots from different modalities, acquired on different microscopes with different settings.
- hybiss: trained on HybISS data acquired on 3 different microscopes.
- synth_complex: trained on synthetic data, which includes simulations of aberrated spots and fluorescence background.
- fluo_live: trained on live-cell fluorescence imaging data corresponding to the Telomeres and Terra datasets in the manuscript.
- synth_3d: trained on synthetic 3D data, which includes simulations of aberrated spots and Z-related artifacts.
- smfish_3d: fine-tuned from the synth_3d model on smFISH 3D data of Platynereis dumerilii.
For this first example, we’ll use the general model.
Note: If you have never used Spotiflow before, the model you specify in from_pretrained() is downloaded automatically the first time you run this notebook. It is then cached locally (under ~/.spotiflow/models, as the log below shows).
model = Spotiflow.from_pretrained("general")
INFO:spotiflow.model.spotiflow:Loading pretrained model: general
Downloading data from https://github.com/weigertlab/spotiflow-models/releases/download/0.6.0/general.zip to /home/runner/.spotiflow/models/general.zip
87885382/87885382 [==============================] - 1s 0us/step
Run Spotiflow
After initializing the model, we can run Spotiflow on the image using either its predict() method or its predict_multichannel() method (parameters summarized below).
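predict_multichannel() Parameters
model.predict_multichannel(
    img,
    channels=None,
    **predict_kwargs,
)
Input
| Parameter | Default | Description |
|---|---|---|
| img | — | Multi-channel image in channel-last format: (Y, X, C) for 2D or (Z, Y, X, C) for 3D data. |
| channels | None | Indices of the channels to run detection on, e.g. (2, 3, 4, 5). If None, all channels are processed. |
| **predict_kwargs | — | Any keyword argument accepted by predict() (see the table below). |
predict() keyword arguments
| Parameter | Default | Description |
|---|---|---|
| prob_thresh | None | Probability threshold for peak detection. If None, the model's optimized threshold is used. |
| min_distance | 1 | Minimum distance (in pixels) allowed between two detected spots. |
| scale | None | Rescale factor applied to the image before detection. |
| subpix | None | Whether to refine spot positions to subpixel accuracy using the stereographic flow. |
| normalizer | "auto" | Intensity normalization. |
| device | None | Compute device to use, e.g. "cpu", "cuda" or "mps". |
Returns
A tuple (points, details):
| Output | Description |
|---|---|
| points | NumPy array of spot coordinates, one row per spot: (y, x, channel) for 2D images or (z, y, x, channel) for 3D stacks. |
| details | A list with one entry per processed channel, each holding that channel's details (heatmap, intens, prob, flow, subpix). |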
The difference between the two is:
- predict() runs the model on a single-channel image (a 2D (Y, X) image or a 3D (Z, Y, X) stack) and returns the spots detected in it.
- predict_multichannel() is designed for multi-channel data. It runs predict() on each channel independently and stacks the results, tagging every spot with the channel it was detected in. This is needed because the pre-trained models always work on a single channel at a time.
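The following toy, NumPy-only sketch shows the per-channel pattern described above. It runs a single-channel detector on each selected channel of a channel-last image, appends the channel index as an extra column, and stacks the results. The detector fake_predict is a hypothetical stand-in (a plain intensity threshold), not Spotiflow's model. The sketch is not Spotiflow's actual implementation.

```python
import numpy as np


def fake_predict(img2d, thresh=0.5):
    """Hypothetical single-channel detector: (y, x) of pixels above thresh."""
    return np.argwhere(img2d > thresh).astype(float)


def fake_predict_multichannel(img, channels=None, thresh=0.5):
    """Toy version of the multi-channel pattern for a (Y, X, C) image."""
    if channels is None:
        channels = range(img.shape[-1])  # process every channel
    all_points, details = [], []
    for c in channels:
        pts = fake_predict(img[..., c], thresh)
        # Tag each spot with the channel it was detected in -> (y, x, channel)
        tagged = np.column_stack([pts, np.full(len(pts), c, dtype=float)])
        all_points.append(tagged)
        details.append({"n_spots": len(pts)})  # one entry per processed channel
    return np.concatenate(all_points, axis=0), details


img = np.zeros((4, 5, 3))
img[1, 2, 1] = 1.0  # one bright pixel in channel 1
img[3, 0, 2] = 1.0  # one bright pixel in channel 2
points, details = fake_predict_multichannel(img, channels=(1, 2))
print(points.tolist())  # [[1.0, 2.0, 1.0], [3.0, 0.0, 2.0]]
print(len(details))  # 2
```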
Since we are using a 6-channel image, we will use predict_multichannel().
⚠️ Spotiflow expects the channel axis to be last (the image must be in (Y, X, C) shape).
⚠️ In our case, the image is stored as (C, Y, X) so we need to move the channel axis to the end before running the model. We can do this easily with `numpy.transpose`.
tr_image = image.transpose(1, 2, 0) # (C, Y, X) -> (Y, X, C)
print(tr_image.shape)
(1040, 1392, 6)
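An equivalent and slightly more general way to move the channel axis is np.moveaxis, which only needs to know where the channel axis currently is. The sketch below uses small dummy arrays in place of the real image. It also checks that the result is a view of the original data rather than a copy.

```python
import numpy as np

# Dummy (C, Y, X) array with the same axis layout as our 6-channel image
image = np.zeros((6, 10, 14), dtype=np.uint16)

tr_a = image.transpose(1, 2, 0)   # explicit permutation: (C, Y, X) -> (Y, X, C)
tr_b = np.moveaxis(image, 0, -1)  # move axis 0 (channels) to the end

print(tr_a.shape, tr_b.shape)  # (10, 14, 6) (10, 14, 6)
print(np.array_equal(tr_a, tr_b))  # True
print(np.shares_memory(tr_b, image))  # True: a view, no data is copied

# The same call also works for a (C, Z, Y, X) stack -> (Z, Y, X, C)
stack = np.zeros((2, 8, 16, 16))
print(np.moveaxis(stack, 0, -1).shape)  # (8, 16, 16, 2)
```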
Now we can run Spotiflow on the transposed image using predict_multichannel().
We need to pass it the image and the channel(s) to run the detection on (the channels parameter). We skip the first 2 channels (indices 0 and 1) because they are the nuclei and cytoplasm channels.
Other parameters can also be useful (e.g. prob_thresh, min_distance, subpix), but for simplicity we will only pass these two.
As summarized above, predict_multichannel() returns a tuple with two outputs:
- points: a NumPy array containing the coordinates of the detected spots.
- details: a list with one entry per processed channel, each containing the spot-wise details for that channel, including heatmap, intens, prob, flow, and subpix.
points, details = model.predict_multichannel(tr_image, channels=(2, 3, 4, 5))
INFO:spotiflow.model.spotiflow:Data is assumed to be in channel-last format ((Z)YXC).
Explore, Display and Save the Results
Points
Let’s explore the outputs, starting with points.
What is the shape of points?
print(points.shape)
(35103, 3)
What are the coordinates of the first detected spot?
print(points[0])
[1038.40406805 968.14990279 2. ]
Since we used predict_multichannel(), each row of points is in (y, x, channel) format, where the last column is the index of the channel in which the spot was detected. Had we used predict() on a single-channel image, each row would simply be (y, x).
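Because the last column of points stores the channel index, plain NumPy operations are enough to summarize the detections per channel. For example, you can count spots per channel or select the spots of a single channel. The sketch below uses a small made-up points array in the same (y, x, channel) layout.

```python
import numpy as np

# Made-up detections in (y, x, channel) format, as returned by predict_multichannel()
points = np.array([
    [10.2, 5.1, 2.0],
    [11.7, 40.3, 2.0],
    [3.3, 7.9, 3.0],
    [25.0, 12.4, 5.0],
])

# Count spots per channel
channel_ids, counts = np.unique(points[:, -1].astype(int), return_counts=True)
print(dict(zip(channel_ids.tolist(), counts.tolist())))  # {2: 2, 3: 1, 5: 1}

# Keep only the (y, x) coordinates of the spots detected in channel 2
ch2_yx = points[points[:, -1] == 2][:, :2]
print(ch2_yx.shape)  # (2, 2)
```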
Now let’s visualize the detected spots on top of the image with napari. If you closed the viewer, you need to re-create it and re-add the image before adding the points.
Importantly, the point coordinates are in (y, x, channel) format, while napari expects (channel, y, x) to match the (C, Y, X) image layer. We therefore need to reorder the columns of points before passing them to napari, which is easy to do with NumPy indexing.
# optional: only needed if you closed the viewer we created earlier
# viewer = napari.Viewer()
# viewer.add_image(image, name="Image")
# we need to reorder the points coordinates as we have them in the (yxc) format and napari needs (cyx)
points_napari = points[:, [2, 0, 1]] # (y, x, channel) -> (channel, y, x)
viewer.add_points(
points_napari,
name="Points",
face_color="green",
border_color="green",
size=4,
symbol="x",
)
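The column reordering above simply moves the last column (the channel) to the front. Below is a NumPy sketch of the same operation written with np.roll. It also covers (z, y, x, channel) rows from a multi-channel 3D stack without spelling out the index list. The helper name channel_first is hypothetical.

```python
import numpy as np


def channel_first(points):
    """Move the last column (channel index) to the front of every row."""
    return np.roll(points, 1, axis=1)


points_2d = np.array([[1038.4, 968.1, 2.0]])        # (y, x, channel)
points_3d = np.array([[114.0, 205.0, 243.0, 1.0]])  # (z, y, x, channel), made-up values

# Same result as the explicit column indexing used for the napari layer
print(np.array_equal(channel_first(points_2d), points_2d[:, [2, 0, 1]]))  # True
print(channel_first(points_3d).tolist())  # [[1.0, 114.0, 205.0, 243.0]]
```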

We can also save the detected spots for future use as a CSV file, using the save_points_as_csv function defined in the Setup section of this notebook.
Note: This function is compatible with napari, so you can directly drag and drop the saved csv file in napari (together with the original multi-channel image) to visualize the detected spots as a points layer: each spot will appear on the channel it was detected in.
save_points_as_csv(points, "2d_points.csv", channel_last=True)
Details
Let’s also look at the details output, focusing in particular on the heatmap and on the intensities (intens) of the detected spots.
details is a list with one entry per processed channel, each containing the spot-wise details for that channel, including heatmap, intens, prob, flow, and subpix.
We can first visualize the per-pixel probability heatmap in napari, where we can interactively explore its values by hovering over it with the mouse cursor.
Remember, details is a list, so we need to specify the channel index to visualize the heatmap of a specific channel.
# `ch` indexes the `details` list, which has one entry per processed channel:
# with channels=(2, 3, 4, 5), ch = 1 corresponds to image channel 3.
ch = 1
# optional: only needed if you closed the viewer we created earlier
# viewer = napari.Viewer()
viewer.add_image(details[ch].heatmap, name=f"Heatmap (Channel {ch})", colormap="magma")

details[ch].intens contains the value of the original image at the coordinates of each detected spot, which can be used as a measure of spot intensity.
We can visualize the distribution of these intensities with a histogram using hist from matplotlib.pyplot.
# details[0] and details[1] are the first two processed channels (image channels 2 and 3)
plt.hist(details[0].intens, bins=15, color="green", alpha=0.5, label="channel 2")
plt.hist(details[1].intens, bins=15, color="magenta", alpha=0.5, label="channel 3")
plt.xlabel("Intensity")
plt.ylabel("Frequency")
plt.legend()
plt.show()
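To build intuition for what a per-spot intensity is, the sketch below looks up the image value at each spot's nearest pixel. This is an assumption-laden stand-in: Spotiflow may compute intens differently (for example on the normalized image, or with interpolation at subpixel positions). Do not use it to reproduce details[ch].intens exactly. The function name is hypothetical.

```python
import numpy as np


def nearest_pixel_intensity(img2d, points_yx):
    """Image value at the pixel closest to each (y, x) spot coordinate."""
    idx = np.rint(points_yx).astype(int)
    # Clip so that spots near the edge (e.g. y = 1039.6 in a 1040-pixel image) stay in bounds
    idx[:, 0] = np.clip(idx[:, 0], 0, img2d.shape[0] - 1)
    idx[:, 1] = np.clip(idx[:, 1], 0, img2d.shape[1] - 1)
    return img2d[idx[:, 0], idx[:, 1]]


img = np.arange(12, dtype=float).reshape(3, 4)  # toy 3 x 4 image
spots = np.array([[0.2, 0.9], [2.6, 3.4]])       # (y, x) subpixel coordinates
print(nearest_pixel_intensity(img, spots).tolist())  # [1.0, 11.0]
```

For our data, img2d would be image[c] and points_yx the first two columns of points[points[:, -1] == c].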
Running Spotiflow on a Folder of Images
To run Spotiflow on a folder of multi-channel images, we can simply loop over the files in the folder, load the images with tifffile, and run predict_multichannel() on each of them. To save the points, we can use the save_points_as_csv function we defined in the Setup section of this notebook.
# Path to the folder containing the images
folder_path = Path("data/00_spot_detection/")  # change this to your folder path
# Get the sorted list of all .tif images in the folder
images_path = sorted(folder_path.glob("*.tif"))
# Initialize the model once before the loop
model = Spotiflow.from_pretrained("general")
# Specify the channels you want to process in predict_multichannel()
channels = (2, 3, 4, 5)
# NOTE: tqdm is used to show a progress bar, but you can remove it if you don't want it
for image_path in tqdm(images_path, desc="Processing images"):
    # Load the image
    image = tifffile.imread(image_path)
    # Transpose the image to channel-last format for `predict_multichannel`
    tr_image = image.transpose(1, 2, 0)  # (C, Y, X) -> (Y, X, C)
    # Run Spotiflow on the image
    points, details = model.predict_multichannel(tr_image, channels=channels)
    # Save the points as a CSV file
    output_path = folder_path / f"{image_path.stem}_points.csv"
    save_points_as_csv(points, str(output_path), channel_last=True)
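When processing a whole folder, it can also be handy to keep a per-image summary. The sketch below assumes you collect each image's points array in a dictionary inside the loop, e.g. results[image_path.stem] = points (a hypothetical addition to the loop above). It then writes one row per image with the number of spots per channel. The function name save_spot_count_summary and the output layout are illustrative choices.

```python
import csv
from pathlib import Path

import numpy as np


def save_spot_count_summary(results, channels, path):
    """Write one row per image with the spot count of each channel.

    results: dict mapping image name -> (N, 3) points array in (y, x, channel) format.
    channels: channel indices that were passed to predict_multichannel().
    """
    path = Path(path)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image"] + [f"channel_{c}" for c in channels])
        for name, points in results.items():
            channel_col = np.asarray(points)[:, -1]
            counts = [int(np.sum(channel_col == c)) for c in channels]
            writer.writerow([name, *counts])
    return path
```

After the loop, save_spot_count_summary(results, channels, folder_path / "spot_counts.csv") would write the table next to the per-image CSV files.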
Bonus: Running Spotiflow on 3D images
Running Spotiflow on 3D images is very similar to running it on 2D images. The main difference is the shape of the data. The input is a 3D stack in (Z, Y, X) format for single-channel data, or (Z, Y, X, C) for multi-channel data. The output points are in (z, y, x) or (z, y, x, channel) format, respectively.
Load and Visualize the Image
image_path = "../../_static/images/spotiflow/3d_spots.tif"
image_3d = tifffile.imread(image_path)
print(image_3d.shape)
(128, 256, 256)
Let’s use napari again to visualize the 3D image as a stack.
We first need to create the viewer, and then we can add the image to it as a layer.
The voxel size of the image used in this example is (0.2, 0.1, 0.1) µm in (z, y, x). We pass it as the scale argument of add_image() so that the stack is displayed with the correct physical aspect ratio.
viewer = napari.Viewer(ndisplay=3) # ndisplay=3 to show the volume directly in 3D mode
viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))

Initialize the Model
For 3D datasets, we need to use a pre-trained model that has been trained on 3D data, either synth_3d or smfish_3d. In this example, we will use the smfish_3d model.
model = Spotiflow.from_pretrained("smfish_3d")
INFO:spotiflow.model.spotiflow:Loading pretrained model: smfish_3d
Downloading data from https://github.com/weigertlab/spotiflow-models/releases/download/0.6.0/smfish_3d.zip to /home/runner/.spotiflow/models/smfish_3d.zip
263106200/263106200 [==============================] - 4s 0us/step
Run Spotiflow
In this example we are using a single-channel z-stack, so we can use the predict() method from the initialized model. We also do not need to transpose the image since it is already in the required (Z, Y, X) format.
points, details = model.predict(image_3d)
INFO:spotiflow.model.spotiflow:Will use device: cpu
INFO:spotiflow.model.spotiflow:Predicting with prob_thresh = [0.4], min_distance = 1
INFO:spotiflow.model.spotiflow:Peak detection mode: fast
INFO:spotiflow.model.spotiflow:Image shape (128, 256, 256)
INFO:spotiflow.model.spotiflow:Predicting with (1, 1, 1) tiles
INFO:spotiflow.model.spotiflow:Normalizing...
INFO:spotiflow.model.spotiflow:Padding to shape (128, 256, 256, 1)
INFO:spotiflow.model.spotiflow:Correcting internal min_distance to 0.5 due to grid: (2, 2, 2).
INFO:spotiflow.model.spotiflow:Found 84 spots
points[0] # z, y, x
array([114.00123863, 205.02350116, 243.045506 ])
Display and Save the Results
Points
As we did for the 2D case, we can visualize the 3D spots on top of the stack with napari.
Note that if you closed the viewer, you need to re-create it and re-add the image before adding the points.
The points coordinates are in (z, y, x) format, already in the correct order for napari, so we can directly pass them to add_points() without needing to reorder the columns as we did for the 2D case.
Remember to also pass the scale to add_points() so that the points are shown in the correct physical aspect ratio on top of the stack.
# optional: only needed if you closed the viewer we created earlier
# viewer = napari.Viewer(ndisplay=3)
# viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))
viewer.add_points(
points,
name="Points",
face_color="green",
border_color="green",
size=4,
symbol="x",
scale=(0.2, 0.1, 0.1),
)
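The same voxel size can be used to express the spot coordinates in micrometers, for example to measure distances between spots. The sketch below multiplies (z, y, x) pixel coordinates by the scale. It then computes each spot's nearest-neighbor distance with NumPy broadcasting. The coordinates are made up, and the all-pairs distance matrix is only practical for a modest number of spots, such as the 84 found here.

```python
import numpy as np

scale = np.array([0.2, 0.1, 0.1])  # voxel size in µm, (z, y, x)

# Made-up spot coordinates in pixels, (z, y, x) as returned by predict()
points = np.array([
    [10.0, 20.0, 30.0],
    [10.0, 20.0, 40.0],
    [15.0, 20.0, 30.0],
])

points_um = points * scale  # pixels -> micrometers, axis by axis

# All pairwise distances (N x N) via broadcasting; ignore each spot's distance to itself
diff = points_um[:, None, :] - points_um[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))
np.fill_diagonal(dist, np.inf)
nearest_um = dist.min(axis=1)
print(np.round(nearest_um, 3).tolist())  # [1.0, 1.0, 1.0]
```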

We can use the save_points_as_csv helper function we defined in the Setup section to save the detected spots.
save_points_as_csv(points, "path/to/your/folder/points_3d.csv")
Details
We can also visualize the heatmap of the probabilities per pixel and the distribution of the intensities of the detected spots as we did for 2D images using napari and matplotlib.pyplot.hist() respectively.
Note that the heatmap has half the resolution of the original image along each axis. When visualizing it with napari, we therefore need to double the scale we used for the original image so that the heatmap is correctly overlaid on top of it.
# optional: only needed if you closed the viewer we created earlier
# viewer = napari.Viewer(ndisplay=3)
# viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))
viewer.add_image(
details.heatmap,
name="Heatmap_3D",
colormap="magma",
scale=(0.2 * 2, 0.1 * 2, 0.1 * 2),
)
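Rather than hard-coding the factor of 2, the heatmap's scale can be derived from the ratio between the image and heatmap shapes. The sketch below uses the shapes expected here: a (128, 256, 256) stack and a heatmap assumed to be at half resolution. With the real data you would use image_3d.shape and details.heatmap.shape.

```python
import numpy as np

image_scale = np.array([0.2, 0.1, 0.1])  # voxel size of the image, (z, y, x) in µm

image_shape = (128, 256, 256)   # image_3d.shape
heatmap_shape = (64, 128, 128)  # assumed details.heatmap.shape (half resolution)

# Per-axis downsampling factor between image and heatmap
factor = np.array(image_shape) / np.array(heatmap_shape)
heatmap_scale = tuple((image_scale * factor).tolist())
print(heatmap_scale)  # (0.4, 0.2, 0.2)
```

heatmap_scale can then be passed as scale= to viewer.add_image() for the heatmap layer.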

plt.hist(details.intens, bins=15)
plt.xlabel("Intensity")
plt.ylabel("Frequency")
plt.show()
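details also holds a per-spot probability (prob). Assume it has one value per detected spot, in the same order as the rows of points returned by predict(). It can then be used to keep only the most confident detections. The sketch below uses made-up stand-ins for points and details.prob, and the 0.6 cutoff is arbitrary.

```python
import numpy as np

# Made-up stand-ins for `points` (z, y, x) and `details.prob` from predict()
points = np.array([
    [114.0, 205.0, 243.0],
    [60.5, 12.3, 99.8],
    [7.2, 150.1, 30.4],
])
prob = np.array([0.92, 0.45, 0.71])

min_prob = 0.6  # arbitrary cutoff for this example
keep = prob >= min_prob
confident_points = points[keep]
print(keep.tolist())  # [True, False, True]
print(confident_points.shape)  # (2, 3)
```

Raising prob_thresh when calling predict() has a similar effect at detection time.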