Quickstart: scib-rapids metrics
This tutorial demonstrates how to use scib-rapids to compute single-cell integration benchmarking metrics. scib-rapids provides the same metrics as scib-metrics, but uses CuPy/RAPIDS for GPU acceleration instead of JAX.
The API is identical: if you’re familiar with scib-metrics, you can use scib-rapids as a drop-in replacement.
import numpy as np
import scib_rapids
from scib_rapids.nearest_neighbors import NeighborsResults, pynndescent
Generate example data
We create a simple toy dataset with 1000 cells, 50 features, 5 cell types, and 3 batches. In practice, X would be a PCA embedding from an AnnData object.
rng = np.random.default_rng(42)
n_cells, n_features, n_labels, n_batches = 1000, 50, 5, 3
X = rng.normal(size=(n_cells, n_features)).astype(np.float32)
labels = rng.integers(0, n_labels, size=n_cells)
batch = rng.integers(0, n_batches, size=n_cells)
print(f"X: {X.shape}, labels: {np.unique(labels)}, batches: {np.unique(batch)}")
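Because `X` above is pure Gaussian noise, the labels and batches carry no signal, so most scores will sit near their chance level. The following is a minimal sketch of a more informative toy dataset, assuming you want well-separated cell types and a mild batch shift. The names `X_structured`, `label_centers`, and `batch_shift`, and the scale values, are illustrative choices and not part of scib-rapids.

```python
import numpy as np

rng = np.random.default_rng(42)
n_cells, n_features, n_labels, n_batches = 1000, 50, 5, 3

labels = rng.integers(0, n_labels, size=n_cells)
batch = rng.integers(0, n_batches, size=n_cells)

# One random center per cell type (strong signal) and per batch (weak signal).
# Both scale values are arbitrary illustration choices.
label_centers = rng.normal(scale=3.0, size=(n_labels, n_features))
batch_shift = rng.normal(scale=0.5, size=(n_batches, n_features))

# Each cell = its cell-type center + its batch shift + unit Gaussian noise.
noise = rng.normal(size=(n_cells, n_features))
X_structured = (label_centers[labels] + batch_shift[batch] + noise).astype(np.float32)

print(X_structured.shape, X_structured.dtype)
```

Passing `X_structured` in place of `X` to the cells below should give scores that actually respond to the label and batch structure.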
Compute nearest neighbors
Several metrics (LISI, kBET, graph connectivity, Leiden clustering) require a precomputed k-nearest neighbor graph. scib-rapids provides a pynndescent wrapper that returns a NeighborsResults dataclass.
nn = pynndescent(X, n_neighbors=30)
print(f"NeighborsResults: indices {nn.indices.shape}, distances {nn.distances.shape}")
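To make the shape of a kNN result concrete, here is a minimal brute-force sketch in NumPy that returns an index array and a distance array, each of shape `(n_cells, n_neighbors)`. It is an exact Euclidean search for illustration only, not the approximate pynndescent wrapper. The function name `brute_force_knn` is hypothetical, and keeping each cell as its own first neighbor is an assumption of this sketch rather than documented library behavior.

```python
import numpy as np

def brute_force_knn(X, n_neighbors):
    """Exact Euclidean kNN; returns (indices, distances), each (n_cells, n_neighbors)."""
    X = np.asarray(X, dtype=np.float64)
    # Squared distances via ||a||^2 + ||b||^2 - 2 a.b, clipped at 0 for round-off.
    sq_norms = (X ** 2).sum(axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (X @ X.T)
    np.maximum(d2, 0.0, out=d2)
    np.fill_diagonal(d2, 0.0)
    # Sort each row by distance; column 0 is the cell itself (assumption: keep self).
    order = np.argsort(d2, axis=1, kind="stable")[:, :n_neighbors]
    dists = np.sqrt(np.take_along_axis(d2, order, axis=1))
    return order, dists

X_demo = np.random.default_rng(0).normal(size=(100, 5))
indices, distances = brute_force_knn(X_demo, n_neighbors=10)
print(indices.shape, distances.shape)
```

The quadratic memory cost is the reason approximate methods are used on real datasets.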
Bio-conservation metrics
These metrics measure how well biological signal (cell type identity) is preserved after integration.
Silhouette scores
# Average Silhouette Width — measures cell type separation
# Returns a score in [0, 1] (rescaled) where higher = better separation
asw = scib_rapids.silhouette_label(X, labels)
print(f"Silhouette label (ASW): {asw:.4f}")
# Isolated labels — score for cell types present in few batches
iso = scib_rapids.isolated_labels(X, labels, batch)
print(f"Isolated labels: {iso:.4f}")
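For intuition about the rescaled score, the sketch below computes a plain silhouette width in NumPy and maps it from [-1, 1] to [0, 1] with `(s + 1) / 2`. It assumes Euclidean distance, at least two labels, and a mean over all cells. The function name `silhouette_label_sketch` is hypothetical, and this is not the GPU implementation.

```python
import numpy as np

def silhouette_label_sketch(X, labels):
    """Mean silhouette width over cells, rescaled from [-1, 1] to [0, 1]."""
    X = np.asarray(X, dtype=np.float64)
    classes, codes = np.unique(labels, return_inverse=True)
    n = X.shape[0]
    rows = np.arange(n)

    # Pairwise Euclidean distances (O(n^2) memory, fine for toy data).
    sq = (X ** 2).sum(axis=1)
    dist = np.sqrt(np.maximum(sq[:, None] + sq[None, :] - 2.0 * (X @ X.T), 0.0))
    np.fill_diagonal(dist, 0.0)

    # sums[i, c] = summed distance from cell i to all cells of class c.
    onehot = np.eye(len(classes))[codes]
    counts = onehot.sum(axis=0)
    sums = dist @ onehot

    # a: mean distance to the cell's own class, excluding the cell itself.
    own_counts = counts[codes]
    a = sums[rows, codes] / np.maximum(own_counts - 1.0, 1.0)
    # b: smallest mean distance to any other class.
    means = sums / counts
    means[rows, codes] = np.inf
    b = means.min(axis=1)

    s = (b - a) / np.maximum(a, b)
    s[own_counts == 1] = 0.0  # singleton classes get silhouette 0 by convention
    return float((s.mean() + 1.0) / 2.0)

X_demo = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
good = silhouette_label_sketch(X_demo, np.array([0, 0, 1, 1]))
bad = silhouette_label_sketch(X_demo, np.array([0, 1, 0, 1]))
print(good, bad)
```

Labels that follow the two spatial groups score close to 1, while labels that cut across them fall below 0.5.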
Clustering-based metrics (NMI / ARI)
# K-means clustering agreement with true labels
kmeans_result = scib_rapids.nmi_ari_cluster_labels_kmeans(X, labels)
print(f"KMeans NMI: {kmeans_result['nmi']:.4f}")
print(f"KMeans ARI: {kmeans_result['ari']:.4f}")
# Leiden clustering agreement (uses the kNN graph)
leiden_result = scib_rapids.nmi_ari_cluster_labels_leiden(nn, labels)
print(f"Leiden NMI: {leiden_result['nmi']:.4f}")
print(f"Leiden ARI: {leiden_result['ari']:.4f}")
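Both agreement scores can be derived from the contingency table between the true labels and the cluster assignments. The sketch below shows the standard formulas in NumPy, assuming the arithmetic-mean normalization for NMI. Whether scib-rapids uses the same normalization is not stated in this tutorial, and the helper names are hypothetical.

```python
import numpy as np

def contingency(a, b):
    """Count matrix: rows = categories of a, columns = categories of b."""
    _, ai = np.unique(a, return_inverse=True)
    _, bi = np.unique(b, return_inverse=True)
    table = np.zeros((ai.max() + 1, bi.max() + 1), dtype=np.int64)
    np.add.at(table, (ai, bi), 1)
    return table

def entropy(p):
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

def nmi(a, b):
    """Normalized mutual information: 2 * I(A; B) / (H(A) + H(B))."""
    p = contingency(a, b).astype(np.float64)
    p /= p.sum()
    pa, pb = p.sum(axis=1), p.sum(axis=0)
    nz = p > 0
    mi = float((p[nz] * np.log(p[nz] / np.outer(pa, pb)[nz])).sum())
    denom = entropy(pa) + entropy(pb)
    return 2.0 * mi / denom if denom > 0 else 1.0

def pairs(x):
    """Number of unordered pairs, x choose 2, elementwise."""
    return x * (x - 1) / 2.0

def ari(a, b):
    """Adjusted Rand index: pair-counting agreement corrected for chance."""
    t = contingency(a, b)
    sum_ij = pairs(t).sum()
    sum_a = pairs(t.sum(axis=1)).sum()
    sum_b = pairs(t.sum(axis=0)).sum()
    expected = sum_a * sum_b / pairs(t.sum())
    max_index = 0.5 * (sum_a + sum_b)
    if max_index == expected:
        return 1.0
    return float((sum_ij - expected) / (max_index - expected))

truth = np.array([0, 0, 1, 1, 2, 2])
relabeled = np.array([1, 1, 0, 0, 2, 2])  # same partition, different names
coarse = np.array([0, 0, 0, 1, 1, 1])     # merges and splits true groups
print(nmi(truth, relabeled), ari(truth, relabeled))
print(nmi(truth, coarse), ari(truth, coarse))
```

Both scores depend only on the partition, so renaming clusters leaves them at 1.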
cLISI: cell-type Local Inverse Simpson Index
# cLISI — how well cell types are preserved in neighborhoods
# Higher = better cell type preservation (scaled to [0, 1])
clisi = scib_rapids.clisi_knn(nn, labels)
print(f"cLISI: {clisi:.4f}")
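The quantity behind LISI is the inverse Simpson index of the category frequencies in each cell's neighborhood. The sketch below is a simplified, unweighted version: the published LISI weights neighbors with a perplexity-based Gaussian kernel on the distances, which is omitted here. The median aggregation and the rescaling `(n_labels - lisi) / (n_labels - 1)` are assumptions about how the [0, 1] score is obtained, and both function names are hypothetical.

```python
import numpy as np

def inverse_simpson(neighbor_labels, n_categories):
    """Unweighted inverse Simpson index per cell.

    neighbor_labels: (n_cells, k) integer category codes of each cell's neighbors.
    Returns values in [1, n_categories]: 1 = a single category, n = uniform mix.
    """
    neighbor_labels = np.asarray(neighbor_labels)
    n_cells, k = neighbor_labels.shape
    counts = np.zeros((n_cells, n_categories))
    rows = np.repeat(np.arange(n_cells), k)
    np.add.at(counts, (rows, neighbor_labels.ravel()), 1.0)
    probs = counts / counts.sum(axis=1, keepdims=True)
    return 1.0 / (probs ** 2).sum(axis=1)

def clisi_scaled(lisi_per_cell, n_labels):
    # Assumption: take the median over cells, then flip and rescale so that
    # 1 means every neighborhood contains a single cell type.
    return float((n_labels - np.median(lisi_per_cell)) / (n_labels - 1))

pure = np.array([[0, 0, 0, 0], [1, 1, 1, 1]])   # one cell type per neighborhood
mixed = np.array([[0, 1, 0, 1], [1, 0, 1, 0]])  # two cell types, evenly mixed
print(clisi_scaled(inverse_simpson(pure, 2), 2))
print(clisi_scaled(inverse_simpson(mixed, 2), 2))
```

With the toy data above, `neighbor_labels` could be built as `labels[nn.indices]`, assuming `nn.indices` holds integer row indices into `labels`.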
Batch-correction metrics
These metrics measure how well batch effects have been removed.
Silhouette batch & BRAS
# Silhouette batch — measures batch mixing within cell types
sil_batch = scib_rapids.silhouette_batch(X, labels, batch)
print(f"Silhouette batch: {sil_batch:.4f}")
# BRAS — Batch Removal Adapted Silhouette (uses cosine distance, mean_other)
bras_score = scib_rapids.bras(X, labels, batch)
print(f"BRAS: {bras_score:.4f}")
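A batch silhouette near 0 means batches overlap, so a batch-mixing score is usually obtained by flipping the absolute value. The sketch below shows one plausible aggregation, assuming per-cell batch silhouettes have already been computed within each cell type: take `1 - |s|`, average within each cell type, then average across cell types. The function name and the input values are hypothetical, and the exact aggregation used by scib-rapids is an assumption here.

```python
import numpy as np

def batch_silhouette_aggregate(sil_per_cell, labels):
    """Aggregate per-cell batch silhouettes into one score in [0, 1].

    sil_per_cell: silhouette of each cell with respect to batch, computed
                  within its own cell type (values in [-1, 1]).
    labels:       cell type of each cell.
    """
    sil_per_cell = np.asarray(sil_per_cell, dtype=np.float64)
    labels = np.asarray(labels)
    # |s| near 0 means batches overlap; flip so that 1 = well mixed.
    mixed = 1.0 - np.abs(sil_per_cell)
    # Average within each cell type first so large cell types do not dominate.
    per_label = [mixed[labels == lab].mean() for lab in np.unique(labels)]
    return float(np.mean(per_label))

# Hypothetical silhouettes: cell type 0 is well mixed, cell type 1 is not.
demo_sil = np.array([0.0, 0.0, 1.0, -1.0])
demo_labels = np.array([0, 0, 1, 1])
print(batch_silhouette_aggregate(demo_sil, demo_labels))
```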
iLISI: integration LISI
# iLISI — how well batches are mixed in neighborhoods
# Higher = better batch integration (scaled to [0, 1])
ilisi = scib_rapids.ilisi_knn(nn, batch)
print(f"iLISI: {ilisi:.4f}")
kBET
# kBET — batch effect test via chi-square statistics
# acceptance_rate close to 1 means batches are well mixed
acceptance_rate, test_stats, p_values = scib_rapids.kbet(nn, batch)
print(f"kBET acceptance rate: {acceptance_rate:.4f}")
# Per-label kBET (uses diffusion distances within each cell type)
kbet_score = scib_rapids.kbet_per_label(nn, batch, labels)
print(f"kBET per label: {kbet_score:.4f}")
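The chi-square test behind kBET compares the batch composition of each neighborhood with the global batch frequencies. The sketch below is restricted to exactly three batches, as in the toy data: the test then has 2 degrees of freedom, where the chi-square survival function has the closed form `exp(-x / 2)`. For other batch counts a general survival function such as `scipy.stats.chi2.sf` would be needed. The function name `kbet_sketch` and the `alpha=0.05` threshold are assumptions of this sketch.

```python
import numpy as np

def kbet_sketch(neighbor_batches, batch, alpha=0.05):
    """Chi-square batch test per neighborhood, for exactly three batches.

    neighbor_batches: (n_cells, k) batch codes (0, 1, 2) of each cell's neighbors.
    batch:            (n_cells,) batch codes, used for the expected frequencies.
    """
    n_batches = 3
    neighbor_batches = np.asarray(neighbor_batches)
    batch = np.asarray(batch)
    k = neighbor_batches.shape[1]

    # Expected neighbor counts per batch if neighborhoods mirror the global mix.
    expected = np.bincount(batch, minlength=n_batches) / len(batch) * k
    observed = np.stack(
        [(neighbor_batches == b).sum(axis=1) for b in range(n_batches)], axis=1
    )
    stats = ((observed - expected) ** 2 / expected).sum(axis=1)

    # 3 batches -> 2 degrees of freedom -> survival function is exp(-x / 2).
    p_values = np.exp(-stats / 2.0)
    # A neighborhood is "accepted" when the test does not reject good mixing.
    acceptance_rate = float((p_values >= alpha).mean())
    return acceptance_rate, stats, p_values

demo_batch = np.array([0, 0, 1, 1, 2, 2])
well_mixed = np.tile(np.repeat([0, 1, 2], 10), (6, 1))   # 10 neighbors per batch
segregated = np.repeat(demo_batch[:, None], 30, axis=1)  # own batch only
print(kbet_sketch(well_mixed, demo_batch)[0])
print(kbet_sketch(segregated, demo_batch)[0])
```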
Graph connectivity
# Graph connectivity — fraction of cells in the largest connected component
# per cell type subgraph. Higher = better.
gc = scib_rapids.graph_connectivity(nn, labels)
print(f"Graph connectivity: {gc:.4f}")
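The definition in the comments above can be sketched with a union-find structure: keep only kNN edges whose endpoints share a label, find the connected components, and report the largest component's share of each label, averaged over labels. This is a plain Python illustration that treats edges as undirected. The function name `graph_connectivity_sketch` is hypothetical.

```python
import numpy as np

def graph_connectivity_sketch(indices, labels):
    """Mean over labels of (largest connected component size / label size)."""
    indices = np.asarray(indices)
    labels = np.asarray(labels)
    n = len(labels)
    parent = list(range(n))

    def find(i):
        # Union-find lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Only connect neighbors that share a label: this is the per-label subgraph.
    for i in range(n):
        for j in indices[i]:
            j = int(j)
            if labels[i] == labels[j]:
                ri, rj = find(i), find(j)
                if ri != rj:
                    parent[ri] = rj

    roots = np.array([find(i) for i in range(n)])
    scores = []
    for lab in np.unique(labels):
        _, comp_sizes = np.unique(roots[labels == lab], return_counts=True)
        scores.append(comp_sizes.max() / comp_sizes.sum())
    return float(np.mean(scores))

demo_labels = np.array([0, 0, 0, 1, 1, 1])
# Cell 2 only points to itself, so label 0 splits into {0, 1} and {2}.
demo_indices = np.array([[1], [0], [2], [4], [5], [3]])
print(graph_connectivity_sketch(demo_indices, demo_labels))
```

In the demo, label 0 contributes 2/3 and label 1 contributes 1, so the score is their mean.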
PCR comparison
# PCR comparison — reduction in batch variance explained by PCA
# Compares pre- vs post-integration embeddings
# X_post is a random stand-in here; in practice, pass the integrated embedding
X_post = rng.normal(size=X.shape).astype(np.float32)
pcr = scib_rapids.pcr_comparison(X, X_post, batch, categorical=True)
print(f"PCR comparison: {pcr:.4f}")
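Principal component regression asks how much of each principal component's variance a covariate explains, weighted by that component's share of total variance. The NumPy sketch below handles a categorical covariate, where the least-squares fit is simply the per-batch mean. Reporting the relative drop `(pre - post) / pre`, clipped at 0, is an assumption about how the comparison is scaled, and both function names are hypothetical.

```python
import numpy as np

def pcr_sketch(X, batch):
    """Variance-weighted R^2 of each principal component regressed on batch."""
    X = np.asarray(X, dtype=np.float64)
    Xc = X - X.mean(axis=0)
    # PCA via SVD: columns of `pcs` are component scores, `var` their variances.
    U, S, _ = np.linalg.svd(Xc, full_matrices=False)
    pcs = U * S
    var = S ** 2
    # For a categorical covariate the regression fit is the per-batch mean.
    codes = np.unique(batch, return_inverse=True)[1]
    onehot = np.eye(codes.max() + 1)[codes]
    batch_means = (onehot.T @ pcs) / onehot.sum(axis=0)[:, None]
    fitted = batch_means[codes]
    ss_res = ((pcs - fitted) ** 2).sum(axis=0)
    ss_tot = (pcs ** 2).sum(axis=0)  # component scores are already centered
    r2 = 1.0 - ss_res / np.maximum(ss_tot, 1e-12)
    return float((r2 * var).sum() / var.sum())

def pcr_comparison_sketch(X_pre, X_post, batch):
    pre, post = pcr_sketch(X_pre, batch), pcr_sketch(X_post, batch)
    if pre <= 0:
        return 0.0
    # Assumption: relative drop in batch variance, 0 if integration made it worse.
    return max(0.0, (pre - post) / pre)

rng = np.random.default_rng(0)
demo_batch = np.repeat([0, 1, 2], 50)
shift = rng.normal(scale=5.0, size=(3, 10))
X_pre = rng.normal(size=(150, 10)) + shift[demo_batch]
# Idealized "integration": subtract each batch's mean from its cells.
X_post = X_pre.copy()
for b in range(3):
    X_post[demo_batch == b] -= X_post[demo_batch == b].mean(axis=0)
print(pcr_sketch(X_pre, demo_batch), pcr_comparison_sketch(X_pre, X_post, demo_batch))
```

Removing the per-batch means leaves no variance for batch to explain, so the comparison approaches 1.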
Summary table
Collect all metrics into a single table.
import pandas as pd
results = {
    "Metric": [
        "Silhouette label", "Isolated labels",
        "KMeans NMI", "KMeans ARI",
        "Leiden NMI", "Leiden ARI",
        "cLISI",
        "Silhouette batch", "BRAS",
        "iLISI", "kBET", "kBET per label",
        "Graph connectivity", "PCR comparison",
    ],
    "Type": [
        "Bio conservation", "Bio conservation",
        "Bio conservation", "Bio conservation",
        "Bio conservation", "Bio conservation",
        "Bio conservation",
        "Batch correction", "Batch correction",
        "Batch correction", "Batch correction", "Batch correction",
        "Batch correction", "Batch correction",
    ],
    "Score": [
        asw, iso,
        kmeans_result["nmi"], kmeans_result["ari"],
        leiden_result["nmi"], leiden_result["ari"],
        clisi,
        sil_batch, bras_score,
        ilisi, acceptance_rate, kbet_score,
        gc, pcr,
    ],
}
df = pd.DataFrame(results).set_index("Metric")
df["Score"] = df["Score"].map("{:.4f}".format)
df