Benchmarking lung integration#

Here we walk through applying integration benchmarking metrics to the lung atlas example from the scIB paper. This mirrors the scib-metrics tutorial, but uses scib-rapids for GPU-accelerated computation.

import time

import numpy as np
import pandas as pd
import scanpy as sc

import scib_rapids
from scib_rapids.nearest_neighbors import NeighborsResults, pynndescent

Load and preprocess data#

# Location of the lung atlas on disk, and the URL scanpy falls back to when
# the file is not present locally.
LUNG_ATLAS_PATH = "data/lung_atlas.h5ad"
LUNG_ATLAS_URL = "https://figshare.com/ndownloader/files/24539942"

adata = sc.read(LUNG_ATLAS_PATH, backup_url=LUNG_ATLAS_URL)
adata

# Select highly variable genes per batch, then run PCA restricted to them.
hvg_settings = {"n_top_genes": 2000, "flavor": "cell_ranger", "batch_key": "batch"}
sc.pp.highly_variable_genes(adata, **hvg_settings)
sc.tl.pca(adata, n_comps=30, use_highly_variable=True)

We subset to the highly variable genes so that each method has the same input.

# Keep only the genes flagged as highly variable; copy so we hold a real
# AnnData object rather than a view of the original.
hvg_mask = adata.var.highly_variable
adata = adata[:, hvg_mask].copy()

# The plain PCA embedding serves as the "no integration" baseline.
unintegrated_embedding = adata.obsm["X_pca"]
adata.obsm["Unintegrated"] = unintegrated_embedding

Run integration methods#

Here we run a few embedding-based methods. By focusing on embedding-based methods, we can substantially reduce the runtime of the benchmarking metrics.

Harmony#

from harmony import harmonize

# Harmony takes the PCA embedding plus the obs table and the name of the
# batch column, and its output is stored under obsm["Harmony"].
pca_embedding = adata.obsm["X_pca"]
harmony_embedding = harmonize(pca_embedding, adata.obs, batch_key="batch")
adata.obsm["Harmony"] = harmony_embedding

scVI#

import scvi

# Register the layer named "counts" and the batch column with scvi-tools.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", layer="counts")

# Model hyperparameters for this tutorial run.
scvi_settings = {"gene_likelihood": "nb", "n_layers": 2, "n_latent": 30}
vae = scvi.model.SCVI(adata, **scvi_settings)
vae.train()

# Store the latent representation of the trained model as the scVI embedding.
scvi_latent = vae.get_latent_representation()
adata.obsm["scVI"] = scvi_latent

scANVI#

# Build the scANVI model from the already-trained scVI model, using the
# "cell_type" column as labels.
scanvi_settings = {
    "adata": adata,
    "labels_key": "cell_type",
    "unlabeled_category": "Unknown",
}
lvae = scvi.model.SCANVI.from_scvi_model(vae, **scanvi_settings)
lvae.train(max_epochs=20, n_samples_per_label=100)

# Store the latent representation of the trained model as the scANVI embedding.
scanvi_latent = lvae.get_latent_representation()
adata.obsm["scANVI"] = scanvi_latent

Compute all metrics#

We define a helper function that computes all metrics for a given embedding and returns a results dictionary.

def compute_all_metrics(
    adata,
    embedding_key,
    batch_key="batch",
    label_key="cell_type",
    pre_integrated_key="X_pca",
    n_neighbors=90,
):
    """Compute all scib-rapids metrics for a single embedding.

    Parameters
    ----------
    adata
        Annotated data object; embeddings are read from ``adata.obsm`` and
        batch/label annotations from ``adata.obs``.
    embedding_key
        Key in ``adata.obsm`` of the embedding to evaluate.
    batch_key
        Column in ``adata.obs`` holding the batch assignment.
    label_key
        Column in ``adata.obs`` holding the cell-type label.
    pre_integrated_key
        Key in ``adata.obsm`` of the embedding used as the pre-integration
        reference for the PCR comparison. Pass ``None`` to skip that metric.
    n_neighbors
        Number of neighbors requested from ``pynndescent``; the resulting
        neighbor graph is shared by the kNN-based metrics.

    Returns
    -------
    dict
        Maps each metric name to the value returned by the corresponding
        scib-rapids function. ``"PCR comparison"`` is NaN when
        ``pre_integrated_key`` is ``None``.
    """
    X = np.asarray(adata.obsm[embedding_key], dtype=np.float32)
    # Encode the annotations as integer category codes.
    labels = np.asarray(pd.Categorical(adata.obs[label_key]).codes)
    batch = np.asarray(pd.Categorical(adata.obs[batch_key]).codes)

    # Compute nearest neighbors once and reuse them for every kNN-based metric.
    nn = pynndescent(X, n_neighbors=n_neighbors)

    results = {}

    # Bio conservation
    results["Isolated labels"] = scib_rapids.isolated_labels(X, labels, batch)
    kmeans = scib_rapids.nmi_ari_cluster_labels_kmeans(X, labels)
    results["KMeans NMI"] = kmeans["nmi"]
    results["KMeans ARI"] = kmeans["ari"]
    results["Silhouette label"] = scib_rapids.silhouette_label(X, labels)
    results["cLISI"] = scib_rapids.clisi_knn(nn, labels)

    # Batch correction
    results["Silhouette batch"] = scib_rapids.silhouette_batch(X, labels, batch)
    results["iLISI"] = scib_rapids.ilisi_knn(nn, batch)
    results["KBET"] = scib_rapids.kbet_per_label(nn, batch, labels)
    results["Graph connectivity"] = scib_rapids.graph_connectivity(nn, labels)

    # PCR comparison (requires pre-integrated embedding)
    if pre_integrated_key is not None:
        X_pre = np.asarray(adata.obsm[pre_integrated_key], dtype=np.float32)
        results["PCR comparison"] = scib_rapids.pcr_comparison(X_pre, X, batch, categorical=True)
    else:
        # The metric was not computed: record it as missing rather than as a
        # score of 0.0, which would be averaged into aggregate scores as if
        # it were a real (worst-case) measurement. pandas' mean() skips NaN
        # by default.
        results["PCR comparison"] = np.nan

    return results

Perform the benchmark#

Run all metrics on each embedding.

# Embeddings to benchmark; each must be a key in adata.obsm.
embedding_keys = ["Unintegrated", "Harmony", "scVI", "scANVI"]

# Maps embedding name -> dict of metric name -> value.
all_results = {}
for key in embedding_keys:
    print(f"Computing metrics for {key}...")
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    t0 = time.perf_counter()
    all_results[key] = compute_all_metrics(adata, key)
    elapsed = time.perf_counter() - t0
    print(f"  Done in {elapsed:.1f}s")

Visualize the results#

# One row per embedding, one column per metric.
df = pd.DataFrame(all_results).T
df.index.name = "Embedding"

# Metrics grouped by the aspect of integration they assess.
bio_metrics = ["Isolated labels", "KMeans NMI", "KMeans ARI", "Silhouette label", "cLISI"]
batch_metrics = ["Silhouette batch", "iLISI", "KBET", "Graph connectivity", "PCR comparison"]

# Each aggregate score is the row-wise mean over its metric group.
aggregate_groups = {
    "Bio conservation": bio_metrics,
    "Batch correction": batch_metrics,
}
for score_name, metric_names in aggregate_groups.items():
    df[score_name] = df[metric_names].mean(axis=1)

# The total is a weighted sum of the two aggregate scores.
bio_weight, batch_weight = 0.6, 0.4
df["Total"] = bio_weight * df["Bio conservation"] + batch_weight * df["Batch correction"]

df.round(4)
summary_columns = ["Bio conservation", "Batch correction", "Total"]
df[summary_columns].sort_values("Total", ascending=False)