ShapBPT on the ImageNet dataset using a ResNet-50 model

This notebook provides a minimal, self-contained demonstration of the shap_bpt library for pixel-wise explanation of image classifiers. We load a pretrained model (such as ResNet-50), define a masking-based black-box prediction function, specify simple background (baseline) images as replacement values, and use shap_bpt to compute Owen values with the BPT (Binary Partition Tree) method. Each image has its own masking and replacement-background pipeline, but all images share the same black-box prediction function, which keeps the logic modular and simple. The goal is to clearly illustrate the essential workflow for generating and visualizing explanations; no quantitative evaluation of the explanations is performed.

Imports & library versions

[1]:
# Core scientific / plotting libraries
import cv2
import json
import numpy as np
import matplotlib.pyplot as plt

# PyTorch + torchvision, image processing
import torch
from torchvision import transforms
from skimage.filters import gaussian

# ShapBPT
import shap_bpt as shap_bpt
# Report the installed shap_bpt release so the outputs below can be tied to a library version.
print(f"shap_bpt version: {shap_bpt.__version__}")
shap_bpt version: 1.0

Device configuration (CPU / GPU / MPS)

[2]:
# Fix the PyTorch RNG seed up front so any stochastic torch op is repeatable.
torch.manual_seed(12345)

# Probe the available accelerators; the MPS backend only exists in some torch builds.
use_cuda = torch.cuda.is_available()
use_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

# Preference order: CUDA, then MPS, then plain CPU.
device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")
print('Using device:', device)
Using device: mps

Utilities

[3]:
def np_softmax(x, axis=None):
    """
    Numerically stable softmax for NumPy arrays.

    x:    array-like of logits (any shape).
    axis: axis along which to normalize. The default (None) normalizes over
          the whole array, which is the original behavior and the right choice
          for a single 1-D logit vector. Pass axis=-1 (or axis=1) to obtain
          one probability distribution per row of a (N, num_classes) batch.

    Returns: array with the same shape as x whose entries sum to 1 along
             `axis` (over the whole array when axis is None).
    """
    x = np.asarray(x)
    # Subtracting the maximum does not change the result but keeps exp() from overflowing.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    # keepdims=True lets the denominator broadcast back against e_x.
    return e_x / e_x.sum(axis=axis, keepdims=True)

Black-box model and preprocessing

[4]:
from torchvision.models import resnet50, ResNet50_Weights
# Pretrained ResNet-50 classifier, moved to the selected device.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)

# Alternative backbones: uncomment one pair of lines to swap the model.
# The eval() call below applies to whichever model is active.

# ViT-B/16 (torchvision)
# from torchvision.models import vit_b_16
# model = vit_b_16(weights='IMAGENET1K_V1').to(device)

# Swin-T (torchvision)
# from torchvision.models import swin_t
# model = swin_t(weights='IMAGENET1K_V1').to(device)

# ViT-B/16 (timm)
# import timm
# model = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)

# VGG-16 (torchvision)
# from torchvision.models import vgg16
# model = vgg16(weights='IMAGENET1K_V1').to(device)


# Inference only: switch the model to evaluation mode.
model.eval()

# Preprocessing: per-channel mean / std used to normalize the inputs
# (the values commonly used for ImageNet-trained models).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Stand-alone normalization transform, also reused as the last preprocessing step.
normalize = transforms.Normalize(mean=mean, std=std)

# HWC [0,1] -> CHW normalized tensor
preprocess = transforms.Compose([transforms.ToTensor(), normalize])

# Black-box prediction function: takes a batch of tensors and returns logits (NumPy)
def bb_predict(x):
    """
    Image-agnostic black-box query of the global `model`.

    x: batch of preprocessed image tensors; it is handed to `model` as-is,
       so the caller must already have placed it on the model's device.

    Returns: the raw model outputs (logits, no softmax applied) as a NumPy array.
    """
    # Forward pass without autograd bookkeeping: we only need the outputs.
    with torch.no_grad():
        outputs = model(x)
    # Bring the result back to host memory before converting to NumPy.
    return outputs.cpu().detach().numpy()

Load ImageNet class names

[5]:
# Load ImageNet class names (id -> human-readable label).
# Each JSON value is a sequence whose second element is the label.
with open('imagenet_class_index.json', encoding='utf-8') as file:
    class_index = json.load(file)

# Order the labels by numeric class id instead of relying on the key order in
# the file, so that class_names[i] always matches model output index i.
class_names = [class_index[key][1] for key in sorted(class_index, key=int)]

Helper to prepare one image (image-specific backgrounds & masking)

[6]:
def prepare_image(image_path: str, target_size=(224, 224), seed: int = 0):
    """
    Loads an image, builds image-specific replacement backgrounds and a
    masking-based black-box wrapper nu_masked for shap_bpt.

    image_path:  path of the image file, read with OpenCV.
    target_size: size handed to cv2.resize, i.e. (width, height).
    seed:        seed of the random generator used for the noise background.

    Returns a dict with all objects needed for explanation:
        "image_path"           the input path
        "image_np"             resized RGB image, (H,W,3) uint8
        "image_tensor"         preprocessed image, (1,3,H,W) on `device`
        "background_image_set" replacement images, (5,H,W,3) uint8
        "background_tensors"   preprocessed replacements, (5,3,H,W) on `device`
        "nu_masked"            masking-based prediction function (see below)
        "logits_full"          logits of the unmasked image, (num_classes,)
        "probs_full"           softmax of logits_full
        "predicted_class"      index of the most probable class

    Raises FileNotFoundError if OpenCV cannot read the image.
    """
    # --- Load and resize the image (BGR -> RGB) ---
    image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_rgb = cv2.resize(image_rgb, target_size)
    image_np = image_rgb.astype(np.uint8)  # HWC, uint8, [0..255]

    H, W, _ = image_np.shape

    # --- Preprocess for the classifier ---
    image_tensor = preprocess(
        image_np.astype(np.float32) / 255.0
    ).unsqueeze(0).to(device)  # shape (1,3,H,W)

    # --- Image-specific replacement backgrounds ---
    # Black, gray, white
    bkgnd0 = np.full_like(image_np, 0)
    bkgnd1 = np.full_like(image_np, 127)
    bkgnd2 = np.full_like(image_np, 255)

    # Strongly blurred original (the filter output is rescaled by 255 to the uint8 range)
    bkgnd3 = (gaussian(image_np, sigma=8, channel_axis=-1) * 255).astype(np.uint8)

    # Smoothed random noise (same size as the image), reproducible through `seed`
    rng = np.random.RandomState(seed)
    noise = rng.normal(loc=128, scale=128, size=image_np.shape)
    noise = np.clip(noise, 0, 255).astype(np.uint8)
    bkgnd4 = (gaussian(noise, sigma=2.0, channel_axis=-1) * 255).astype(np.uint8)

    # Stack backgrounds into one array
    background_image_set = np.stack([bkgnd0, bkgnd1, bkgnd2, bkgnd3, bkgnd4], axis=0)

    # Preprocess all backgrounds with the same transform as the main image
    background_tensors = torch.stack(
        [preprocess(bkgnd.astype(np.float32) / 255.0) for bkgnd in background_image_set],
        dim=0,
    ).to(device)  # shape (B,3,H,W)

    # Broadcast-ready views, built once and shared by every nu_masked call.
    B, C = background_tensors.shape[:2]
    originals = image_tensor.unsqueeze(0)           # (1,1,C,H,W)
    replacements = background_tensors.unsqueeze(0)  # (1,B,C,H,W)

    # --- Image-specific masking-based prediction function ---
    def nu_masked(masks: np.ndarray) -> np.ndarray:
        """
        masks: Boolean NumPy array of shape (N, H, W)
               True  = keep original pixel
               False = replace with background pixel

        Returns: (N, num_classes) NumPy array of averaged logits.
        """
        N, Hm, Wm = masks.shape
        assert (Hm, Wm) == (H, W), "Mask size must match image size."

        # (N,H,W) -> (N,1,1,H,W): one mask per sample, shared by every
        # background and every channel through broadcasting.
        keep = torch.from_numpy(masks).bool().to(device)[:, None, None, :, :]

        # True -> original pixel, False -> background pixel.
        # Broadcasting yields every (mask, background) combination: (N,B,C,H,W).
        X = torch.where(keep, originals, replacements)

        # Flatten to a plain batch; row n*B + b holds mask n with background b.
        X = X.reshape(N * B, C, H, W)

        # Use the global, image-agnostic black-box function
        logits = bb_predict(X)  # (B*N, num_classes)

        # Undo the flattening and average over backgrounds
        logits = logits.reshape(N, B, -1)  # (N,B,num_classes)
        return logits.mean(axis=1)         # (N,num_classes)

    # --- Predict the class for the full (unmasked) image ---
    logits_full = bb_predict(image_tensor)[0]     # (num_classes,)
    probs_full = np_softmax(logits_full)
    pred_class = int(np.argmax(probs_full))

    print(
        f"{image_path}: predicted as '{class_names[pred_class]}' "
        f"with probability {probs_full[pred_class]:.4f}"
    )

    return {
        "image_path": image_path,
        "image_np": image_np,
        "image_tensor": image_tensor,
        "background_image_set": background_image_set,
        "background_tensors": background_tensors,
        "nu_masked": nu_masked,
        "logits_full": logits_full,
        "probs_full": probs_full,
        "predicted_class": pred_class,
    }

Prepare all images

[7]:
# Images to explain (relative paths, passed unchanged to prepare_image)
image_paths = ["flamingo4.png", "bird4.png", "acoustic_guitar.png"]

# Build one bundle per image: tensors, replacement backgrounds, masking function, prediction
image_data_list = list(map(prepare_image, image_paths))
flamingo4.png: predicted as 'flamingo' with probability 0.5045
bird4.png: predicted as 'indigo_bunting' with probability 0.4445
acoustic_guitar.png: predicted as 'acoustic_guitar' with probability 0.3564

Display all input images together

[8]:
# Show each image (left column) together with its replacement backgrounds (right columns)
num_images = len(image_data_list)

# Take the number of background columns from the data rather than hard-coding it,
# so the grid stays correct if prepare_image changes its set of backgrounds.
num_bkgnds = (
    image_data_list[0]["background_image_set"].shape[0] if image_data_list else 0
)

# squeeze=False keeps `axes` two-dimensional even when there is a single image row.
fig, axes = plt.subplots(
    num_images,
    1 + num_bkgnds,
    figsize=(3 * (1 + num_bkgnds), 3 * num_images),
    squeeze=False,
)

for row_idx, data in enumerate(image_data_list):
    # Original image, titled with its file name and predicted class
    cls_id = data["predicted_class"]
    ax = axes[row_idx, 0]
    ax.imshow(data["image_np"])
    ax.set_title(
        f"{data['image_path']}\nPredicted class: {class_names[cls_id]}"
    )
    ax.set_xticks([])
    ax.set_yticks([])

    # Replacement backgrounds, one per remaining column
    bkgnd_imgs = data["background_image_set"]  # shape (num_bkgnds, H, W, 3)

    for col_idx in range(num_bkgnds):
        ax = axes[row_idx, col_idx + 1]
        ax.imshow(bkgnd_imgs[col_idx].astype(np.uint8))
        ax.set_title(f"Replacement {col_idx}")
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.show()
../_images/notebooks_ImageNET_BPT_17_0.png

Configure explanation parameters

[9]:
# Number of top classes to explain (here: the top-4 classes)
num_explained_classes = 4

# Budget of black-box evaluations we are willing to spend on each explanation
MAX_EVALS_BUDGET = 100

# Batch size for evaluations within shap_bpt
batch_size = 4

Run shap_bpt with method = “BPT” (Binary Partition Tree)

[10]:
# One explainer and one set of Owen values per image, in image_data_list order
explainers, shap_values_list = [], []

for data in image_data_list:
    print(f"\n=== Explaining {data['image_path']} ===")

    # Explainer bound to this image and to its own masking function
    explainer = shap_bpt.Explainer(
        data["nu_masked"],
        data["image_np"],
        num_explained_classes=num_explained_classes,
        verbose=True,
    )

    # Owen values from the BPT method, limited to the evaluation budget
    shap_values = explainer.explain_instance(
        MAX_EVALS_BUDGET,
        method="BPT",
        batch_size=batch_size,
    )

    explainers.append(explainer)
    shap_values_list.append(shap_values)

    # Sanity check: Owen values sum to nu(S) − nu(∅).
    # Index 0 presumably selects the first explained class; confirm against shap_bpt.
    expected_total = explainer.base_f_S[0] - explainer.base_f_0[0]
    computed_total = np.sum(shap_values[0])
    print('Expected Shapley explanation: ', expected_total)
    print('Computed Shapley explanation: ', computed_total)


=== Explaining flamingo4.png ===
Expected Shapley explanation:  6.482598960399628
Computed Shapley explanation:  6.482598960399629

=== Explaining bird4.png ===
Expected Shapley explanation:  6.265199840068817
Computed Shapley explanation:  6.265199840068817

=== Explaining acoustic_guitar.png ===
Expected Shapley explanation:  5.936330497264862
Computed Shapley explanation:  5.936330497264861

Visualize the Owen values

[11]:
# Render the Owen value explanations image by image.
# Every plot_owen_values call produces its own figure, so they appear in sequence.
explained = zip(image_data_list, explainers, shap_values_list)

for data, explainer, shap_values in explained:
    print(f"\nPlotting explanations for {data['image_path']}")
    shap_bpt.plot_owen_values(explainer, shap_values, class_names)


Plotting explanations for flamingo4.png
../_images/notebooks_ImageNET_BPT_23_1.png

Plotting explanations for bird4.png
../_images/notebooks_ImageNET_BPT_23_3.png

Plotting explanations for acoustic_guitar.png
../_images/notebooks_ImageNET_BPT_23_5.png
[ ]: