Benchmarking lung integration#
Here we walk through applying integration benchmarking metrics to the lung atlas example from the scIB paper. This mirrors the scib-metrics tutorial but uses scib-rapids for GPU-accelerated computation.
import time
import numpy as np
import pandas as pd
import scanpy as sc
import scib_rapids
from scib_rapids.nearest_neighbors import NeighborsResults, pynndescent
Load and preprocess data#
# Read the lung atlas from the local path; a figshare URL is supplied as
# `backup_url` (presumably used by scanpy when the file is missing locally).
adata = sc.read(
    "data/lung_atlas.h5ad", backup_url="https://figshare.com/ndownloader/files/24539942"
)
adata  # bare expression: shows the object summary when run as a notebook cell

# Flag the top 2000 highly variable genes, with "batch" passed as the batch key.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=2000,
    flavor="cell_ranger",
    batch_key="batch",
)

# 30-component PCA restricted to the flagged genes.
# NOTE(review): `use_highly_variable` is deprecated in recent scanpy releases in
# favour of `mask_var` -- confirm against the scanpy version this notebook pins.
sc.tl.pca(
    adata,
    n_comps=30,
    use_highly_variable=True,
)
We subset to the highly variable genes so that each method has the same input.
# Keep only the columns flagged in `adata.var["highly_variable"]` and take an
# independent copy of the subset.
adata = adata[:, adata.var["highly_variable"]].copy()

# Expose the PCA coordinates under a second key so the uncorrected embedding
# can be benchmarked by name alongside the integrated ones.
adata.obsm["Unintegrated"] = adata.obsm["X_pca"]
Run integration methods#
Here we run a few embedding-based methods. By focusing on embedding-based methods, we can substantially reduce the runtime of the benchmarking metrics.
Harmony#
from harmony import harmonize

# Run Harmony on the PCA embedding, passing `adata.obs` with "batch" as the
# batch key, and store the returned embedding under "Harmony".
adata.obsm["Harmony"] = harmonize(
    adata.obsm["X_pca"],
    adata.obs,
    batch_key="batch",
)
scVI#
import scvi

# Register the "counts" layer and the "batch" column with scvi-tools.
scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="batch",
)

# Build the model with a 30-dimensional latent space and train it with the
# library's default training settings.
vae = scvi.model.SCVI(
    adata,
    gene_likelihood="nb",
    n_layers=2,
    n_latent=30,
)
vae.train()

# Store the latent representation as the "scVI" embedding.
adata.obsm["scVI"] = vae.get_latent_representation()
scANVI#
# Build a scANVI model from the trained scVI model, using the "cell_type"
# column as labels and "Unknown" as the unlabeled category.
lvae = scvi.model.SCANVI.from_scvi_model(
    vae, adata=adata, labels_key="cell_type", unlabeled_category="Unknown"
)

# Train for at most 20 epochs with `n_samples_per_label=100`.
lvae.train(
    max_epochs=20,
    n_samples_per_label=100,
)

# Store the latent representation as the "scANVI" embedding.
adata.obsm["scANVI"] = lvae.get_latent_representation()
Compute all metrics#
We define a helper function that computes all metrics for a given embedding and returns a results dictionary.
def compute_all_metrics(
    adata,
    embedding_key,
    batch_key="batch",
    label_key="cell_type",
    pre_integrated_key="X_pca",
    n_neighbors=90,
):
    """Compute all scib-rapids metrics for a single embedding.

    Parameters
    ----------
    adata
        Annotated data object. Embeddings are read from ``adata.obsm`` and the
        batch/label columns from ``adata.obs``.
    embedding_key
        Key in ``adata.obsm`` of the embedding to score.
    batch_key
        Column in ``adata.obs`` holding the batch assignment.
    label_key
        Column in ``adata.obs`` holding the cell-type label.
    pre_integrated_key
        Key in ``adata.obsm`` of the pre-integration embedding that is passed
        to the PCR comparison together with the scored embedding. Pass ``None``
        to skip that metric.
    n_neighbors
        Number of neighbors requested from ``pynndescent``.

    Returns
    -------
    dict
        Maps metric name to the value returned by the corresponding
        ``scib_rapids`` function. ``"PCR comparison"`` is ``nan`` when
        ``pre_integrated_key`` is ``None``.

    Raises
    ------
    KeyError
        If one of the requested keys is missing from ``adata.obsm`` or
        ``adata.obs``.
    """
    X = np.asarray(adata.obsm[embedding_key], dtype=np.float32)
    # Encode labels and batches as integer category codes.
    labels = np.asarray(pd.Categorical(adata.obs[label_key]).codes)
    batch = np.asarray(pd.Categorical(adata.obs[batch_key]).codes)

    # The neighbor graph is computed once and shared by every kNN-based metric.
    nn = pynndescent(X, n_neighbors=n_neighbors)

    results = {}

    # Bio conservation
    results["Isolated labels"] = scib_rapids.isolated_labels(X, labels, batch)
    kmeans = scib_rapids.nmi_ari_cluster_labels_kmeans(X, labels)
    results["KMeans NMI"] = kmeans["nmi"]
    results["KMeans ARI"] = kmeans["ari"]
    results["Silhouette label"] = scib_rapids.silhouette_label(X, labels)
    results["cLISI"] = scib_rapids.clisi_knn(nn, labels)

    # Batch correction
    results["Silhouette batch"] = scib_rapids.silhouette_batch(X, labels, batch)
    results["iLISI"] = scib_rapids.ilisi_knn(nn, batch)
    results["KBET"] = scib_rapids.kbet_per_label(nn, batch, labels)
    results["Graph connectivity"] = scib_rapids.graph_connectivity(nn, labels)

    # PCR comparison (requires pre-integrated embedding)
    if pre_integrated_key is not None:
        X_pre = np.asarray(adata.obsm[pre_integrated_key], dtype=np.float32)
        results["PCR comparison"] = scib_rapids.pcr_comparison(X_pre, X, batch, categorical=True)
    else:
        # Report "not computed" as NaN rather than 0.0: a literal zero reads as
        # a real score and would pull down any average taken over the batch
        # metrics, whereas pandas reductions skip NaN by default.
        results["PCR comparison"] = float("nan")

    return results
Perform the benchmark#
Run all metrics on each embedding.
# Embeddings to benchmark, in the order they are reported.
embedding_keys = ["Unintegrated", "Harmony", "scVI", "scANVI"]

# Maps embedding name -> {metric name -> value}.
all_results = {}
for key in embedding_keys:
    print(f"Computing metrics for {key}...")
    # perf_counter() is a monotonic, high-resolution clock, so the measured
    # interval cannot be distorted by system clock adjustments the way a
    # difference of time.time() readings can.
    t0 = time.perf_counter()
    all_results[key] = compute_all_metrics(adata, key)
    elapsed = time.perf_counter() - t0
    print(f" Done in {elapsed:.1f}s")
Visualize the results#
# One row per embedding, one column per metric.
df = pd.DataFrame(all_results).transpose()
df.index.name = "Embedding"

# Group the metric columns by the aspect of integration they are averaged into.
bio_metrics = [
    "Isolated labels",
    "KMeans NMI",
    "KMeans ARI",
    "Silhouette label",
    "cLISI",
]
batch_metrics = [
    "Silhouette batch",
    "iLISI",
    "KBET",
    "Graph connectivity",
    "PCR comparison",
]

# Per-embedding mean of each group, then a 60/40 weighted total.
df["Bio conservation"] = df[bio_metrics].mean(axis="columns")
df["Batch correction"] = df[batch_metrics].mean(axis="columns")
df["Total"] = df["Bio conservation"] * 0.6 + df["Batch correction"] * 0.4

# Bare expressions: each is shown as the cell output when run in a notebook.
df.round(4)
df.loc[:, ["Bio conservation", "Batch correction", "Total"]].sort_values(by="Total", ascending=False)