Quickstart: scib-rapids metrics

This tutorial demonstrates how to use scib-rapids to compute single-cell integration benchmarking metrics. scib-rapids provides the same metrics as scib-metrics, but uses CuPy/RAPIDS for GPU acceleration instead of JAX.

The API is identical, so if you’re familiar with scib-metrics you can use scib-rapids as a drop-in replacement.

import numpy as np
import scib_rapids
from scib_rapids.nearest_neighbors import NeighborsResults, pynndescent

Generate example data

We create a simple toy dataset with 1000 cells, 50 features, 5 cell types, and 3 batches. In practice, X would be an embedding taken from an AnnData object, for example a PCA or the output of an integration method.

rng = np.random.default_rng(42)
n_cells, n_features, n_labels, n_batches = 1000, 50, 5, 3

X = rng.normal(size=(n_cells, n_features)).astype(np.float32)
labels = rng.integers(0, n_labels, size=n_cells)
batch = rng.integers(0, n_batches, size=n_cells)

print(f"X: {X.shape}, labels: {np.unique(labels)}, batches: {np.unique(batch)}")

Compute nearest neighbors

Several metrics (LISI, kBET, graph connectivity, Leiden clustering) require a precomputed k-nearest neighbor graph. scib-rapids provides a pynndescent wrapper that returns a NeighborsResults dataclass.

nn = pynndescent(X, n_neighbors=30)
print(f"NeighborsResults: indices {nn.indices.shape}, distances {nn.distances.shape}")
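To make the shape of a kNN result concrete, here is a minimal brute-force sketch in plain NumPy. It is not how scib-rapids or PyNNDescent compute neighbors (PyNNDescent is approximate and far more scalable); `brute_force_knn` is a hypothetical helper written for this illustration. It returns an index array and a distance array, each of shape `(n_cells, n_neighbors)`, the same two pieces the printout above reports.

```python
import numpy as np

def brute_force_knn(X, n_neighbors):
    """Exact kNN from the full pairwise distance matrix (O(n^2) memory)."""
    X = np.asarray(X, dtype=np.float64)
    # Squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    sq_norms = (X ** 2).sum(axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T
    np.maximum(d2, 0.0, out=d2)   # clip tiny negatives caused by round-off
    np.fill_diagonal(d2, 0.0)     # a point is at distance 0 from itself
    # Column indices of the n_neighbors smallest distances, ascending per row
    indices = np.argsort(d2, axis=1)[:, :n_neighbors]
    distances = np.sqrt(np.take_along_axis(d2, indices, axis=1))
    return indices, distances

rng_demo = np.random.default_rng(0)
X_demo = rng_demo.normal(size=(200, 10))
knn_idx, knn_dist = brute_force_knn(X_demo, n_neighbors=15)
print(knn_idx.shape, knn_dist.shape)
```

In this sketch each cell appears as its own first neighbor at distance 0. Whether the scib-rapids wrapper keeps or drops the query point is not stated here, so check the returned arrays before relying on either convention.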

Bio-conservation metrics

These metrics measure how well biological signal (cell type identity) is preserved after integration.

Silhouette scores

# Average Silhouette Width — measures cell type separation
# Returns a score in [0, 1] (rescaled) where higher = better separation
asw = scib_rapids.silhouette_label(X, labels)
print(f"Silhouette label (ASW): {asw:.4f}")

# Isolated labels: score for cell types present in few batches
iso = scib_rapids.isolated_labels(X, labels, batch)
print(f"Isolated labels: {iso:.4f}")
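For readers who want to see what the rescaled silhouette measures, the sketch below computes it from scratch in NumPy. It assumes Euclidean distances and the usual rescaling `(s + 1) / 2` from [-1, 1] to [0, 1]; `silhouette_label_sketch` is a hypothetical name, the function needs at least two clusters, and the library's chunked GPU implementation will differ.

```python
import numpy as np

def pairwise_euclidean(X):
    """Dense Euclidean distance matrix; only suitable for small inputs."""
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.maximum(d2, 0.0, out=d2)
    np.fill_diagonal(d2, 0.0)
    return np.sqrt(d2)

def silhouette_label_sketch(X, labels):
    """Mean silhouette width over all cells, rescaled to [0, 1]."""
    X = np.asarray(X, dtype=np.float64)
    uniq, own = np.unique(labels, return_inverse=True)
    dist = pairwise_euclidean(X)
    onehot = np.eye(len(uniq))[own]      # (n_cells, n_clusters) membership
    sizes = onehot.sum(axis=0)           # cells per cluster
    sums = dist @ onehot                 # summed distance to each cluster
    rows = np.arange(len(own))
    # a: mean distance to the cell's own cluster, excluding the cell itself
    a = sums[rows, own] / np.maximum(sizes[own] - 1, 1)
    # b: smallest mean distance to any other cluster
    mean_to = sums / sizes
    mean_to[rows, own] = np.inf
    b = mean_to.min(axis=1)
    s = (b - a) / np.maximum(a, b)
    s[sizes[own] == 1] = 0.0             # convention for singleton clusters
    return float((s.mean() + 1.0) / 2.0)

# Two well-separated blobs should score close to 1
rng_demo = np.random.default_rng(0)
blobs = np.vstack([rng_demo.normal(size=(30, 2)),
                   rng_demo.normal(size=(30, 2)) + 20.0])
blob_labels = np.repeat([0, 1], 30)
blob_score = silhouette_label_sketch(blobs, blob_labels)
print(f"{blob_score:.3f}")
```

On the random toy data used in this tutorial the labels carry no structure, so a value near 0.5 (a raw silhouette near 0) is the expected outcome.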

Clustering-based metrics (NMI / ARI)

# K-means clustering agreement with true labels
kmeans_result = scib_rapids.nmi_ari_cluster_labels_kmeans(X, labels)
print(f"KMeans NMI: {kmeans_result['nmi']:.4f}")
print(f"KMeans ARI: {kmeans_result['ari']:.4f}")

# Leiden clustering agreement (uses the kNN graph)
leiden_result = scib_rapids.nmi_ari_cluster_labels_leiden(nn, labels)
print(f"Leiden NMI: {leiden_result['nmi']:.4f}")
print(f"Leiden ARI: {leiden_result['ari']:.4f}")
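Both scores compare a predicted clustering with the true labels through their contingency table. The sketch below shows the two formulas in NumPy. The function names are hypothetical, and the NMI normalization by the arithmetic mean of the two entropies is an assumption (it is the scikit-learn default); the normalization scib-rapids applies is not stated in this tutorial.

```python
import numpy as np

def contingency_table(a, b):
    """Counts of cells for every (cluster in a, cluster in b) pair."""
    a_codes = np.unique(a, return_inverse=True)[1]
    b_codes = np.unique(b, return_inverse=True)[1]
    table = np.zeros((a_codes.max() + 1, b_codes.max() + 1), dtype=np.int64)
    np.add.at(table, (a_codes, b_codes), 1)
    return table

def pairs(x):
    """Number of unordered pairs (x choose 2), elementwise."""
    return x * (x - 1) / 2.0

def ari_sketch(a, b):
    """Adjusted Rand index: pair agreement corrected for chance."""
    t = contingency_table(a, b)
    index = pairs(t).sum()
    row_pairs = pairs(t.sum(axis=1)).sum()
    col_pairs = pairs(t.sum(axis=0)).sum()
    expected = row_pairs * col_pairs / pairs(t.sum())
    max_index = 0.5 * (row_pairs + col_pairs)
    return float((index - expected) / (max_index - expected))

def entropy(p):
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

def nmi_sketch(a, b):
    """Mutual information divided by the mean of the two entropies."""
    p = contingency_table(a, b) / len(a)
    pa, pb = p.sum(axis=1), p.sum(axis=0)
    nz = p > 0
    mi = (p[nz] * np.log(p[nz] / np.outer(pa, pb)[nz])).sum()
    return float(mi / (0.5 * (entropy(pa) + entropy(pb))))

true_demo = [0, 0, 1, 1, 2, 2]
pred_demo = [1, 1, 0, 0, 2, 2]   # same partition, different cluster names
print(ari_sketch(true_demo, pred_demo), nmi_sketch(true_demo, pred_demo))
```

Both scores depend only on the partition, not on the cluster names, which is why the relabelled prediction above still counts as a perfect match.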

cLISI: cell-type Local Inverse Simpson Index

# cLISI — how well cell types are preserved in neighborhoods
# Higher = better cell type preservation (scaled to [0, 1])
clisi = scib_rapids.clisi_knn(nn, labels)
print(f"cLISI: {clisi:.4f}")
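The core of LISI is the inverse Simpson index of the labels found in each cell's neighborhood. The sketch below is a simplified version with uniform neighbor weights, whereas the published LISI weights neighbors with a perplexity-calibrated Gaussian kernel on the distances. The function names are hypothetical, and the rescaling to [0, 1] with the median over cells is an assumption about how the library normalizes its output.

```python
import numpy as np

def simple_lisi(indices, labels):
    """Per-cell inverse Simpson index of the labels among its neighbors."""
    labels = np.asarray(labels)
    neigh = labels[indices]                      # (n_cells, k) neighbor labels
    # Fraction of each cell's neighbors carrying each label
    props = np.stack([(neigh == c).mean(axis=1) for c in np.unique(labels)],
                     axis=1)
    return 1.0 / (props ** 2).sum(axis=1)        # ranges from 1 to n_labels

def clisi_sketch(indices, labels):
    """1 when every neighborhood is pure in cell type, 0 when fully mixed."""
    n = len(np.unique(labels))
    return float((n - np.median(simple_lisi(indices, labels))) / (n - 1))

def ilisi_sketch(indices, batches):
    """1 when every neighborhood mixes all batches evenly, 0 when pure."""
    n = len(np.unique(batches))
    return float((np.median(simple_lisi(indices, batches)) - 1) / (n - 1))

# Six cells: neighborhoods are pure in cell type but evenly mixed in batch
labels_demo = np.array([0, 0, 0, 1, 1, 1])
batch_demo = np.array([0, 1, 0, 1, 0, 1])
idx_demo = np.array([[0, 1]] * 3 + [[3, 4]] * 3)
print(clisi_sketch(idx_demo, labels_demo), ilisi_sketch(idx_demo, batch_demo))
```

The same per-cell quantity drives both scores. cLISI rewards low label diversity and iLISI, covered in the batch-correction section, rewards high batch diversity, so the two rescalings run in opposite directions.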

Batch-correction metrics

These metrics measure how well batch effects have been removed.

Silhouette batch & BRAS

# Silhouette batch — measures batch mixing within cell types
sil_batch = scib_rapids.silhouette_batch(X, labels, batch)
print(f"Silhouette batch: {sil_batch:.4f}")

# BRAS — Batch Removal Adapted Silhouette (uses cosine distance, mean_other)
bras_score = scib_rapids.bras(X, labels, batch)
print(f"BRAS: {bras_score:.4f}")
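For batch mixing, a silhouette near 0 is the good outcome, so the per-cell widths are usually transformed before averaging. The sketch below shows that transformation only. It assumes `sil` already holds per-cell silhouette widths computed with batch as the grouping inside each cell type, and it follows the scIB-style recipe of `1 - |s|`, averaged per cell type and then across cell types. `batch_asw_from_silhouette` is a hypothetical helper, and BRAS differs in its distance metric and between-cluster distance, which are not sketched here.

```python
import numpy as np

def batch_asw_from_silhouette(sil, labels):
    """Turn per-cell batch silhouettes into a [0, 1] mixing score."""
    sil = np.asarray(sil, dtype=np.float64)
    labels = np.asarray(labels)
    # 1 - |s| is 1 for perfectly mixed cells and 0 for fully separated ones
    per_label = [np.mean(1.0 - np.abs(sil[labels == c]))
                 for c in np.unique(labels)]
    # Averaging per cell type first keeps abundant cell types from dominating
    return float(np.mean(per_label))

# Cell type 0 is perfectly mixed, cell type 1 is fully separated by batch
sil_demo = [0.0, 0.0, 1.0, -1.0]
labels_demo = [0, 0, 1, 1]
print(batch_asw_from_silhouette(sil_demo, labels_demo))
```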

iLISI: integration LISI

# iLISI — how well batches are mixed in neighborhoods
# Higher = better batch integration (scaled to [0, 1])
ilisi = scib_rapids.ilisi_knn(nn, batch)
print(f"iLISI: {ilisi:.4f}")

kBET

# kBET — batch effect test via chi-square statistics
# acceptance_rate close to 1 means batches are well mixed
acceptance_rate, test_stats, p_values = scib_rapids.kbet(nn, batch)
print(f"kBET acceptance rate: {acceptance_rate:.4f}")

# Per-label kBET (uses diffusion distances within each cell type)
kbet_score = scib_rapids.kbet_per_label(nn, batch, labels)
print(f"kBET per label: {kbet_score:.4f}")
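The idea behind the acceptance rate can be sketched in a few lines: for every cell, compare the batch counts among its neighbors with the counts expected from the global batch frequencies using a chi-square test, then report the fraction of cells that are not rejected. This is an illustrative NumPy version, not the scib-rapids implementation. `kbet_sketch` and `chi2_sf` are hypothetical helpers, and the significance level `alpha=0.05` is an assumption. `chi2_sf` is written out only to avoid a SciPy dependency; `scipy.stats.chi2.sf` computes the same quantity.

```python
import math
import numpy as np

def chi2_sf(x, df):
    """Chi-square survival function for integer df >= 1."""
    y = x / 2.0
    if df % 2 == 1:
        a, q = 0.5, math.erfc(math.sqrt(y))   # closed form for df = 1
    else:
        a, q = 1.0, math.exp(-y)              # closed form for df = 2
    # Recurrence Q(a + 1, y) = Q(a, y) + y**a * exp(-y) / Gamma(a + 1)
    while a < df / 2.0:
        q += y ** a * math.exp(-y) / math.gamma(a + 1.0)
        a += 1.0
    return q

def kbet_sketch(indices, batches, alpha=0.05):
    """Return (acceptance rate, per-cell statistics, per-cell p-values)."""
    uniq, codes = np.unique(batches, return_inverse=True)
    n_batches, k = len(uniq), indices.shape[1]
    # Expected neighbor counts per batch under perfect mixing
    expected = k * np.bincount(codes, minlength=n_batches) / len(codes)
    neigh = codes[indices]
    observed = np.stack([(neigh == b).sum(axis=1) for b in range(n_batches)],
                        axis=1)
    stats = ((observed - expected) ** 2 / expected).sum(axis=1)
    p_values = np.array([chi2_sf(s, n_batches - 1) for s in stats])
    return float((p_values >= alpha).mean()), stats, p_values

# 40 cells in two batches of 20, with 20 neighbors per cell
batch_demo = np.repeat([0, 1], 20)
mixed_idx = np.tile(np.r_[0:10, 20:30], (40, 1))          # 10 from each batch
separated_idx = np.vstack([np.tile(np.arange(20), (20, 1)),
                           np.tile(np.arange(20, 40), (20, 1))])
print(kbet_sketch(mixed_idx, batch_demo)[0],
      kbet_sketch(separated_idx, batch_demo)[0])
```

The mixed graph is accepted everywhere and the separated graph is rejected everywhere, which are the two extremes of the acceptance rate.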

Graph connectivity

# Graph connectivity — fraction of cells in the largest connected component
# per cell type subgraph. Higher = better.
gc = scib_rapids.graph_connectivity(nn, labels)
print(f"Graph connectivity: {gc:.4f}")
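The definition in the comment above translates directly into code: restrict the kNN graph to the cells of one cell type, find the largest connected component, and average its relative size over cell types. The sketch below uses a small union-find on the CPU and treats kNN edges as undirected. Both function names are hypothetical, and the GPU implementation in scib-rapids will look different.

```python
import numpy as np

def largest_component_fraction(indices, keep):
    """Share of `keep` cells in the largest component of their subgraph."""
    nodes = np.flatnonzero(keep)
    parent = {int(i): int(i) for i in nodes}

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]   # path halving
            i = parent[i]
        return i

    for i in nodes:
        for j in indices[i]:
            if keep[j]:                     # keep edges inside the subgraph only
                ri, rj = find(int(i)), find(int(j))
                if ri != rj:
                    parent[ri] = rj         # merge the two components
    roots = [find(i) for i in parent]
    _, counts = np.unique(roots, return_counts=True)
    return counts.max() / len(nodes)

def graph_connectivity_sketch(indices, labels):
    labels = np.asarray(labels)
    return float(np.mean([largest_component_fraction(indices, labels == c)
                          for c in np.unique(labels)]))

# Cell type 0 splits into two pairs, cell type 1 forms a single ring
labels_demo = np.array([0, 0, 0, 0, 1, 1, 1, 1])
idx_demo = np.array([[1], [0], [3], [2], [5], [6], [7], [4]])
print(graph_connectivity_sketch(idx_demo, labels_demo))
```

Here cell type 0 contributes 0.5 and cell type 1 contributes 1.0, so the averaged score is 0.75.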

PCR comparison

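# PCR comparison: reduction in the variance that batch explains across principal components
# Compares pre- vs post-integration embeddings
# X_post is a random stand-in here; in practice, pass the integrated embedding
X_post = rng.normal(size=X.shape).astype(np.float32)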
pcr = scib_rapids.pcr_comparison(X, X_post, batch, categorical=True)
print(f"PCR comparison: {pcr:.4f}")
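Principal component regression can be sketched as follows: project the data onto its principal components, regress each component on the one-hot batch assignment, and average the resulting R² values weighted by the variance of each component. The comparison score is the relative drop in that quantity after integration. This NumPy version is a sketch under those assumptions, with hypothetical function names. Returning 0 when the batch contribution increases is also an assumption, and the library may apply scaling that is omitted here.

```python
import numpy as np

def pcr_sketch(X, batch):
    """Variance-weighted mean R^2 of the principal components on batch."""
    X = np.asarray(X, dtype=np.float64)
    codes = np.unique(batch, return_inverse=True)[1]
    Xc = X - X.mean(axis=0)
    U, S, _ = np.linalg.svd(Xc, full_matrices=False)
    scores = U * S              # PC scores, one centered column per component
    var = S ** 2                # proportional to the variance of each PC
    # For a categorical covariate, R^2 = between-group SS / total SS
    ss_between = np.zeros(scores.shape[1])
    for b in np.unique(codes):
        grp = scores[codes == b]
        ss_between += len(grp) * grp.mean(axis=0) ** 2
    r2 = ss_between / (scores ** 2).sum(axis=0)
    return float((r2 * var).sum() / var.sum())

def pcr_comparison_sketch(X_pre, X_post, batch):
    """Relative reduction of the batch contribution, floored at 0."""
    pre, post = pcr_sketch(X_pre, batch), pcr_sketch(X_post, batch)
    return 0.0 if post > pre else (pre - post) / pre

# Pre-integration data with a strong batch shift on one feature
rng_demo = np.random.default_rng(0)
batch_demo = np.repeat([0, 1, 2], 100)
X_pre_demo = rng_demo.normal(size=(300, 10))
X_pre_demo[:, 0] += 5.0 * batch_demo
X_post_demo = rng_demo.normal(size=(300, 10))   # stand-in without batch effect
pcr_demo = pcr_comparison_sketch(X_pre_demo, X_post_demo, batch_demo)
print(f"{pcr_demo:.3f}")
```

In the tutorial cell above both embeddings are pure noise, so neither contains a batch effect to remove and the reported value carries no meaning beyond showing the call signature.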

Summary table

Collect all metrics into a single table.

import pandas as pd

results = {
    "Metric": [
        "Silhouette label", "Isolated labels",
        "KMeans NMI", "KMeans ARI",
        "Leiden NMI", "Leiden ARI",
        "cLISI",
        "Silhouette batch", "BRAS",
        "iLISI", "kBET", "kBET per label",
        "Graph connectivity", "PCR comparison",
    ],
    "Type": [
        "Bio conservation", "Bio conservation",
        "Bio conservation", "Bio conservation",
        "Bio conservation", "Bio conservation",
        "Bio conservation",
        "Batch correction", "Batch correction",
        "Batch correction", "Batch correction", "Batch correction",
        "Batch correction", "Batch correction",
    ],
    "Score": [
        asw, iso,
        kmeans_result["nmi"], kmeans_result["ari"],
        leiden_result["nmi"], leiden_result["ari"],
        clisi,
        sil_batch, bras_score,
        ilisi, acceptance_rate, kbet_score,
        gc, pcr,
    ],
}

df = pd.DataFrame(results).set_index("Metric")
df["Score"] = df["Score"].map("{:.4f}".format)
df
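A table like this is often condensed into one score per category and an overall score. The sketch below averages the metrics within each type and combines them with the 0.6 / 0.4 weighting used in the scIB benchmark. The numbers are placeholders made up for illustration, not output from scib-rapids, and the aggregation itself is an assumption rather than something this package is shown to do. Aggregate before formatting scores as strings, since the formatted column above is no longer numeric.

```python
import pandas as pd

# Placeholder scores for illustration only, not results from scib-rapids
demo = pd.DataFrame({
    "Metric": ["Silhouette label", "cLISI", "iLISI", "Graph connectivity"],
    "Type": ["Bio conservation", "Bio conservation",
             "Batch correction", "Batch correction"],
    "Score": [0.50, 1.00, 0.20, 0.80],
}).set_index("Metric")

# Mean score per category
by_type = demo.groupby("Type")["Score"].mean()

# Weighted overall score (weights assumed from the scIB benchmark)
total = 0.6 * by_type["Bio conservation"] + 0.4 * by_type["Batch correction"]
print(by_type)
print(f"Total: {total:.2f}")
```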