Spotiflow in Python#

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "matplotlib",
#     "tifffile",
#     "spotiflow",
#     "tqdm",
#     "napari[all]"
# ]
# ///

Overview#

GitHub | Paper | Spotiflow Documentation | Spotiflow API

In this notebook, we’ll see how to run Spotiflow to detect spots on single images or on a folder of images, and how to visualize and save the results.

The images we will use for this section can be downloaded from the Spotiflow Dataset.

Since we will be visualizing points (i.e. the detected spots), and to explore something different from the previous notebooks, we will use napari to visualize the results (ndv does not yet support points directly).

Note: If you want to use napari independently of this notebook, you can quickly launch napari by running uvx "napari[all]" in your terminal (it might take a little while the first time you run this command, but after that it will be very quick).

💡 Tip: Spotiflow runs significantly faster on a GPU. It supports both NVIDIA GPUs (CUDA) and Apple Silicon (MPS). If you don't have either, we recommend running this notebook on Google Colab for faster performance.

NVIDIA GPU (CUDA - Windows/Linux)

In order to use Spotiflow in this notebook with an NVIDIA GPU:

  1. you need to have the NVIDIA drivers installed on your system.

  2. you can run nvidia-smi in the terminal to check your CUDA version (shown in the top-right of the output, e.g. CUDA Version: 13.0.0).

  3. update the # /// script block at the top of this notebook to install the appropriate version of PyTorch with CUDA support (replace cu130 with your CUDA version):

    # /// script
    # requires-python = ">=3.12"
    # dependencies = [
    #     "matplotlib",
    #     "tifffile",
    #     "spotiflow",
    #     "tqdm",
    #     "napari[all]",
    #     "torch",
    #     "torchvision",
    # ]
    #
    # [tool.uv.sources]
    # torch = { index = "pytorch-cu130" }
    # torchvision = { index = "pytorch-cu130" }
    #
    # [[tool.uv.index]]
    # name = "pytorch-cu130"
    # url = "https://download.pytorch.org/whl/cu130"
    # explicit = true
    # ///
  4. re-run the notebook using uvx juv run.

Import Libraries#

import csv
from pathlib import Path

import matplotlib.pyplot as plt
import napari
import numpy as np
import tifffile
from spotiflow.model import Spotiflow
from tqdm import tqdm

Setup#

This is a helper that saves the results of spot detection to a .csv file in a napari-compatible format, so that we can easily visualize the results in napari later on.


def save_points_as_csv(points, output_path="points.csv", channel_last=False) -> None:
    """Save points as a napari-compatible CSV (drag-and-drop as Points layer).

    napari maps the CSV columns (axis-0, axis-1, ...) to the layer axes in order.
    `predict_multichannel` returns the channel as the *last* column (e.g. (y, x, channel)),
    so set `channel_last=True` to move it to the front (e.g. (channel, y, x)) and have the
    spots line up with a channel-first (C, ...) image in napari.
    """
    points = np.asarray(points)
    if channel_last:
        # move the last column (channel) to the front
        points = points[:, [-1, *range(points.shape[1] - 1)]]
    ndim = points.shape[1]
    headers = ["index"] + [f"axis-{i}" for i in range(ndim)]
    with open(output_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        for i, p in enumerate(points):
            writer.writerow([i, *p])
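
To check what the helper writes, the file can be read back with the standard library. The sketch below assumes the layout produced above (an `index` column followed by `axis-0`, `axis-1`, ...). `load_points_csv` is a hypothetical name used here for illustration; it is not part of Spotiflow or napari.

```python
import csv


def load_points_csv(path):
    """Read a napari-style points CSV back into coordinate rows.

    Assumes the first column is the point index and the remaining
    columns are the coordinates (axis-0, axis-1, ...).
    """
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)  # e.g. ["index", "axis-0", "axis-1"]
        axes = header[1:]  # drop the "index" column name
        # drop the index value of each row and parse the coordinates as floats
        rows = [[float(v) for v in row[1:]] for row in reader]
    return axes, rows
```

The function returns the axis names together with the coordinates, so the column order (for example channel first or channel last) can be checked before reusing the points.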

Running Spotiflow on 2D images#

Load the Image#

Since we will be using TIFF files, to load the images, we can use the imread method from the tifffile library.

image_path = "../../_static/images/spotiflow/2d_spots.tif"
image = tifffile.imread(image_path)

print(image.shape)
(6, 1040, 1392)

This is a 6-channel image. We can use the napari viewer to visualize it.

First we create the viewer, then we add the image to it as a layer. If you do not close the viewer, you can reuse it in the rest of the notebook to visualize the detected spots as well.

viewer = napari.Viewer()
viewer.add_image(image, name="Image")

Initialize the Model#

To initialize a pre-trained Spotiflow model, we can use the from_pretrained() class method.

Currently the available pre-trained models are:

  • general: trained on a diverse dataset of spots of different modalities acquired in different microscopes with different settings.

  • hybiss: trained on HybISS data acquired in 3 different microscopes

  • synth_complex: trained on synthetic data, which includes simulations of aberrated spots and fluorescence background.

  • fluo_live: trained on live-cell fluorescence imaging data corresponding to the Telomeres and Terra datasets in the manuscript.

  • synth_3d: trained on synthetic 3D data, which includes simulations of aberrated spots and Z-related artifacts.

  • smfish_3d: fine-tuned from the synth_3d model on smFISH 3D data of Platynereis dumerilii.

For this first example, we’ll use the general model.

Note: If you have never used Spotiflow before, the model you specify in from_pretrained() will be downloaded automatically the first time you run this notebook.

model = Spotiflow.from_pretrained("general")
INFO:spotiflow.model.spotiflow:Loading pretrained model: general
Downloading data from https://github.com/weigertlab/spotiflow-models/releases/download/0.6.0/general.zip to /home/runner/.spotiflow/models/general.zip

87885382/87885382 [==============================] - 1s 0us/step

Run Spotiflow#

After initializing the model, we can run Spotiflow on the image using either the predict() or the predict_multichannel() (see the dropdown below) method from the initialized model.

predict_multichannel() Parameters
model.predict_multichannel(
    img,
    channels=None,
    **predict_kwargs,
)

Input

  • img (required): multi-channel image in channel-last format, (Y, X, C) for 2D or (Z, Y, X, C) for 3D.

  • channels (default: None): an int or a tuple of channel indices to run the detection on. If None, runs on all channels.

  • **predict_kwargs: any keyword argument accepted by predict(), forwarded to each channel (see the useful ones below).

predict() keyword arguments

  • prob_thresh (default: None): probability threshold for peak detection. If None, uses the model’s optimal value.

  • min_distance (default: 1): minimum distance (in pixels) allowed between two detected spots.

  • scale (default: None): rescale factor applied to the image before detection.

  • subpix (default: None): whether to refine spot positions to subpixel accuracy using the stereographic flow. If None, deduced from the model configuration.

  • normalizer (default: "auto"): intensity normalization. "auto" uses percentile-based normalization (p_min=1, p_max=99.8).

  • device (default: None): compute device to use: "auto", "cpu", "cuda" or "mps". If None, inferred from the model location.

Returns

A tuple (points, details):

  • points: numpy array of spot coordinates: (N, 3) for 2D as (y, x, channel) or (N, 4) for 3D as (z, y, x, channel). The last column is the index of the channel the spot was detected in.

  • details: a list with one entry per processed channel, each holding that channel’s heatmap (probability heatmap per pixel), intens (intensities of the detected spots), prob (probability of each spot being a true positive), flow (stereographic flow vector field) and subpix (2D local offset vector field).


The difference between the two is:

  • predict() runs the model on a single-channel image (a 2D (Y, X) image or a 3D (Z, Y, X) stack) and returns the spots detected in it.

  • predict_multichannel() is designed for multi-channel data. It runs predict() on each channel independently and stacks the results, tagging every spot with the channel it was detected in. This is needed because the pre-trained models always work on a single channel at a time.
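
The per-channel behaviour described above can be pictured with a small sketch. This is not Spotiflow's implementation: `detect_per_channel` and the toy detector `brightest_pixel` are hypothetical stand-ins that only illustrate the idea of running a single-channel detector on each channel and appending the channel index as the last column.

```python
import numpy as np


def detect_per_channel(img, detect_fn, channels=None):
    """Run a single-channel detector on each channel and tag the results.

    `img` is channel-last, e.g. (Y, X, C). `detect_fn` is a stand-in for a
    single-channel detector: it takes one channel and returns an (N, ndim)
    array of coordinates.
    """
    if channels is None:
        channels = range(img.shape[-1])
    tagged = []
    for c in channels:
        coords = np.asarray(detect_fn(img[..., c]), dtype=float)
        coords = coords.reshape(-1, img.ndim - 1)
        # append the channel index as the last column
        channel_col = np.full((len(coords), 1), c, dtype=float)
        tagged.append(np.hstack([coords, channel_col]))
    if not tagged:
        return np.empty((0, img.ndim))
    return np.vstack(tagged)


def brightest_pixel(channel):
    # toy detector: returns the position of the single brightest pixel
    return [np.unravel_index(np.argmax(channel), channel.shape)]
```

With a real model, the role of `detect_fn` would be played by the single-channel prediction, and the stacked array has the same (y, x, channel) layout as the points described above.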

Since we are using a 6-channel image, we will use predict_multichannel().

⚠️ Spotiflow expects the channel axis to be last (the image must be in (Y, X, C) shape).

⚠️ In our case, the image is stored as (C, Y, X) so we need to move the channel axis to the end before running the model. We can do this easily with `numpy.transpose`.

tr_image = image.transpose(1, 2, 0)  # (C, Y, X) -> (Y, X, C)
print(tr_image.shape)
(1040, 1392, 6)
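
An equivalent way to do the same reordering, which also works unchanged for a (C, Z, Y, X) stack, is `numpy.moveaxis`. The sketch below assumes the channel axis is the first one; `to_channel_last` is a hypothetical helper name.

```python
import numpy as np


def to_channel_last(image, channel_axis=0):
    """Move the channel axis to the end, e.g. (C, Y, X) -> (Y, X, C)."""
    return np.moveaxis(image, channel_axis, -1)


# placeholder array with the same shape as the image used in this notebook
image = np.zeros((6, 1040, 1392), dtype=np.uint16)
print(to_channel_last(image).shape)  # (1040, 1392, 6)
```

Unlike `transpose(1, 2, 0)`, the axis order does not need to be spelled out, so the same call covers both the 2D and the 3D multi-channel case.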

Now we can run Spotiflow on the transposed image using the predict_multichannel().

We need to pass this method the image and the channels to run the detection on (the channels parameter). We skip the first two channels (indices 0 and 1) since they are the nuclei and cytoplasm channels.

Other parameters can also be useful (e.g. prob_thresh, min_distance, subpix), but for simplicity we will only pass the two above.
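
As a sketch of how the optional parameters could be passed, the wrapper below forwards prob_thresh and min_distance as keyword arguments. The wrapper name `detect_with_options` is hypothetical, and the values 0.5 and 2 are illustrative assumptions, not recommended settings.

```python
def detect_with_options(model, image, channels, prob_thresh=0.5, min_distance=2):
    """Call predict_multichannel with an explicit threshold and spot spacing.

    `model` is expected to be an initialized Spotiflow model and `image`
    a channel-last array. The default values here are illustrative only.
    """
    return model.predict_multichannel(
        image,
        channels=channels,
        prob_thresh=prob_thresh,  # probability threshold for peak detection
        min_distance=min_distance,  # minimum spot separation in pixels
    )
```

Leaving prob_thresh at None (the default of predict()) lets the model use its own optimal threshold, so an explicit value is only worth setting when the default detects too many or too few spots on your data.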

As shown in the dropdown above, predict_multichannel() will return two outputs (a tuple):

  • points: a numpy array containing the coordinates of the detected spots.

  • details: a list with one entry per channel, each containing the spot-wise details for that channel including heatmap, intens, prob, flow, and subpix.

points, details = model.predict_multichannel(tr_image, channels=(2, 3, 4, 5))
INFO:spotiflow.model.spotiflow:Data is assumed to be in channel-last format ((Z)YXC).

Explore, Display and Save the Results#

Points#

Let’s explore the outputs starting from understanding the points.

What is the shape of points?

print(points.shape)
(35103, 3)

What are the coordinates of the first detected spot?

print(points[0])
[1038.40406805  968.14990279    2.        ]

Since we used predict_multichannel(), each row of points is in (y, x, channel) format: the last column is the index of the channel the spot was detected in. Had we used predict() on a single-channel image, each row would simply be (y, x).
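
Because the channel index is stored in the last column, the number of spots per channel can be counted directly from the array. The sketch below uses a small made-up array rather than the real output; `spots_per_channel` is a hypothetical helper name.

```python
import numpy as np


def spots_per_channel(points):
    """Count spots per channel, reading the channel index from the last column."""
    points = np.asarray(points)
    channels, counts = np.unique(points[:, -1].astype(int), return_counts=True)
    return dict(zip(channels.tolist(), counts.tolist()))


# toy (y, x, channel) array, not real detections
toy = np.array([[10.2, 4.0, 2], [3.5, 8.1, 2], [7.0, 7.0, 5]])
print(spots_per_channel(toy))  # {2: 2, 5: 1}
```

The same last-column test (`points[:, -1] == c`) can be used as a boolean mask to select the spots of a single channel.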

Now let’s visualize the detected spots on top of the image with napari. Note that if you did close the viewer, you need to re-create it and re-add the image together with the points.

Importantly, the points coordinates are in (y, x, channel) format, while napari expects the coordinates to be in (channel, y, x) format, so we need to reorder the columns of points before passing them to napari. We can do this easily with numpy indexing.

# optional if you did close the viewer we previously created
# viewer = napari.Viewer()
# viewer.add_image(image, name="Image")

# we need to reorder the points coordinates as we have them in the (yxc) format and napari needs (cyx)
points_napari = points[:, [2, 0, 1]]  # (y, x, channel) -> (channel, y, x)
viewer.add_points(
    points_napari,
    name="Points",
    face_color="green",
    border_color="green",
    size=4,
    symbol="x",
)

We can also save the detected spots for future use as a csv file using the save_points_as_csv function we defined in the Setup section of this notebook.

Note: This function is compatible with napari, so you can directly drag and drop the saved csv file in napari (together with the original multi-channel image) to visualize the detected spots as a points layer: each spot will appear on the channel it was detected in.

save_points_as_csv(points, "2d_points.csv", channel_last=True)

Details#

Let’s have a look at the details output as well and let’s focus in particular on the heatmap and the intensities (intens) of the detected spots.

details is a list with one entry per channel, each containing the spot-wise details for that channel including heatmap, intens, prob, flow, and subpix.

We can first look at the heatmap of the per-pixel probabilities by adding it to the napari viewer as an image layer, so that we can interactively explore its values by hovering over it with the mouse cursor.

Remember, details is a list, so we need to specify the channel index to visualize the heatmap of a specific channel.

# `ch` is channel index, change it to visualize the heatmap of a different channel.
# note that this is the channel index as passed to predict_multichannel().
ch = 1

# optional if you did close the viewer we previously created
# viewer = napari.Viewer()
viewer.add_image(details[ch].heatmap, name=f"Heatmap (Channel {ch})", colormap="magma")

In details[channel].intens we can find the value of the pixel in the original image at the coordinates of each detected spot, which can be used as a measure of the intensity of the detected spots.

We can visualize the distribution of these intensities with a histogram using hist from matplotlib.pyplot.

plt.hist(details[0].intens, bins=15, color="green", alpha=0.5, label="channel 0")
plt.hist(details[1].intens, bins=15, color="magenta", alpha=0.5, label="channel 1")
plt.xlabel("Intensity")
plt.ylabel("Frequency")
plt.legend()
plt.show()
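
A histogram like this is often used to pick an intensity cutoff. The sketch below shows one way to apply such a cutoff with a boolean mask. It assumes the coordinates and intensities are row-aligned (the i-th intensity belongs to the i-th spot), which with predict_multichannel() means first selecting the points of the matching channel. `filter_by_intensity` is a hypothetical helper name.

```python
import numpy as np


def filter_by_intensity(coords, intens, min_intensity):
    """Keep only the spots whose intensity is at least `min_intensity`.

    Assumes `coords` (N, ndim) and `intens` (N,) are row-aligned, i.e. the
    i-th intensity belongs to the i-th spot.
    """
    coords = np.asarray(coords)
    intens = np.asarray(intens)
    keep = intens >= min_intensity  # boolean mask, one entry per spot
    return coords[keep], intens[keep]
```

The cutoff value itself depends on the data and is best chosen by looking at the histogram and at the spots in the viewer.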

Running Spotiflow on a Folder of Images#

To run Spotiflow on a folder of multi-channel images, we can simply loop over the files in the folder, load the images with tifffile, and run predict_multichannel() on each of them. To save the points, we can use the save_points_as_csv function we defined in the Setup section of this notebook.

# Path to the folder containing the images
folder_path = Path("data/00_spot_detection/")  # change this to your folder path

# Get the sorted list of all .tif images in the folder
images_path = sorted(folder_path.glob("*.tif"))

# Initialize the model once before the loop
model = Spotiflow.from_pretrained("general")

# specify the channels you want to process in predict_multichannel()
channels = (2, 3, 4, 5)

# NOTE: tqdm is used to show a progress bar, but you can remove it if you don't want it
for image_path in tqdm(images_path, desc="Processing images"):
    # Load the image
    image = tifffile.imread(image_path)
    # Transpose the image to channel-last format for `predict_multichannel`
    tr_image = image.transpose(1, 2, 0)  # (C, Y, X) -> (Y, X, C)
    # Run Spotiflow on the image
    points, details = model.predict_multichannel(tr_image, channels=channels)
    # Save the points as a CSV file
    output_path = folder_path / f"{image_path.stem}_points.csv"
    save_points_as_csv(points, str(output_path), channel_last=True)
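
The `glob("*.tif")` pattern above does not match files ending in `.tiff`, and whether it matches `.TIF` depends on the platform. If your folder mixes these, a suffix check is one possible alternative. This is a minimal sketch; `find_tiffs` is a hypothetical helper name.

```python
from pathlib import Path


def find_tiffs(folder):
    """Return the sorted TIFF files in `folder` (.tif or .tiff, any letter case)."""
    folder = Path(folder)
    return sorted(
        p
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in {".tif", ".tiff"}
    )
```

The returned list can be used in place of `images_path` in the loop. Since the CSV files are written to the same folder, filtering by suffix also keeps them out of the list of inputs.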

Bonus: Running Spotiflow on 3D images#

Running Spotiflow on 3D images is very similar to running it on 2D images. The only difference is that the input image is a 3D stack in (Z, Y, X) format for single-channel data or (Z, Y, X, C) for multi-channel data, and the output points are in (z, y, x) or (z, y, x, channel) format respectively.

Load and Visualize the Image#

image_path = "../../_static/images/spotiflow/3d_spots.tif"
image_3d = tifffile.imread(image_path)

print(image_3d.shape)
(128, 256, 256)

Let’s use napari again to visualize the 3D image as a stack.

We first need to create the viewer, and then we can add the image as a layer to the viewer.

The scale of the image we are using for this example is zyx=(0.2, 0.1, 0.1) (µm) so we can pass it to add_image() to have the correct physical aspect ratio when visualizing the stack.

viewer = napari.Viewer(ndisplay=3)  # ndisplay=3 to show the volume directly in 3D mode
viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))

Initialize the Model#

For 3D datasets, we need to use a pre-trained model that has been trained on 3D data, either synth_3d or smfish_3d. In this example, we will use the smfish_3d model.

model = Spotiflow.from_pretrained("smfish_3d")
INFO:spotiflow.model.spotiflow:Loading pretrained model: smfish_3d
Downloading data from https://github.com/weigertlab/spotiflow-models/releases/download/0.6.0/smfish_3d.zip to /home/runner/.spotiflow/models/smfish_3d.zip
263106200/263106200 [==============================] - 4s 0us/step

Run Spotiflow#

In this example we are using a single-channel z-stack, so we can use the predict() method from the initialized model. We also do not need to transpose the image since it is already in the required (Z, Y, X) format.

points, details = model.predict(image_3d)
INFO:spotiflow.model.spotiflow:Will use device: cpu
INFO:spotiflow.model.spotiflow:Predicting with prob_thresh = [0.4], min_distance = 1
INFO:spotiflow.model.spotiflow:Peak detection mode: fast
INFO:spotiflow.model.spotiflow:Image shape (128, 256, 256)
INFO:spotiflow.model.spotiflow:Predicting with (1, 1, 1) tiles
INFO:spotiflow.model.spotiflow:Normalizing...
INFO:spotiflow.model.spotiflow:Padding to shape (128, 256, 256, 1)
INFO:spotiflow.model.spotiflow:Correcting internal min_distance to 0.5 due to grid: (2, 2, 2).
INFO:spotiflow.model.spotiflow:Found 84 spots
points[0]  # z, y, x
array([114.00123863, 205.02350116, 243.045506  ])

Display and Save the Results#

Points#

As we did for the 2D case, we can visualize the 3D spots on top of the stack with napari.

Note that if you did close the viewer, you need to re-create it and re-add the image together with the points.

The points coordinates are in (z, y, x) format, already in the correct order for napari, so we can directly pass them to add_points() without needing to reorder the columns as we did for the 2D case.

Remember to also pass the scale to add_points() so that the points are shown in the correct physical aspect ratio on top of the stack.

# optional if you did close the viewer we previously created
# viewer = napari.Viewer(ndisplay=3)
# viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))

viewer.add_points(
    points,
    name="Points",
    face_color="green",
    border_color="green",
    size=4,
    symbol="x",
    scale=(0.2, 0.1, 0.1),
)

We can use the save_points_as_csv helper function we defined in the Setup section to save the detected spots.

save_points_as_csv(points, "path/to/your/folder/points_3d.csv")
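
The saved coordinates are in voxel units. If physical units are more convenient for downstream analysis, the coordinates can be multiplied by the voxel size. The sketch below assumes the zyx=(0.2, 0.1, 0.1) µm scale used for this example image; `to_physical` is a hypothetical helper name.

```python
import numpy as np


def to_physical(points_zyx, voxel_size=(0.2, 0.1, 0.1)):
    """Convert (z, y, x) voxel coordinates to physical units (here µm)."""
    points_zyx = np.asarray(points_zyx, dtype=float)
    # broadcasting multiplies each column by its own voxel size
    return points_zyx * np.asarray(voxel_size, dtype=float)


print(to_physical([[10, 20, 30]]))
```

Points converted this way are already in physical units, so they should be added to napari without the scale argument, otherwise the scaling would be applied twice.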

Details#

We can also visualize the heatmap of the probabilities per pixel and the distribution of the intensities of the detected spots as we did for 2D images using napari and matplotlib.pyplot.hist() respectively.

Note that the heatmap is half the pixel resolution of the original image, so when visualizing it with napari we need to double the scale we used for the original image to have it correctly overlaid on top of the original image.
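
Instead of hard-coding the factor of 2, the heatmap scale can be derived from the two array shapes. The sketch below assumes the heatmap covers the same field of view as the image; the heatmap shape (64, 128, 128) used in the example is an assumption based on the half-resolution statement above, and `heatmap_scale` is a hypothetical helper name.

```python
def heatmap_scale(image_shape, heatmap_shape, voxel_size):
    """Per-axis scale for a downsampled heatmap so it overlays the full-size image."""
    return tuple(
        v * (i / h) for v, i, h in zip(voxel_size, image_shape, heatmap_shape)
    )


# image shape from this notebook, assumed half-resolution heatmap shape
print(heatmap_scale((128, 256, 256), (64, 128, 128), (0.2, 0.1, 0.1)))
```

In the notebook this would be called with `image_3d.shape` and `details.heatmap.shape`, and the result passed as the scale of the heatmap layer.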

# optional if you did close the viewer we previously created
# viewer = napari.Viewer(ndisplay=3)
# viewer.add_image(image_3d, name="Image_3D", scale=(0.2, 0.1, 0.1))

viewer.add_image(
    details.heatmap,
    name="Heatmap_3D",
    colormap="magma",
    scale=(0.2 * 2, 0.1 * 2, 0.1 * 2),
)
plt.hist(details.intens, bins=15)
plt.xlabel("Intensity")
plt.ylabel("Frequency")
plt.show()