import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from queue import PriorityQueue
from tqdm.auto import tqdm
import numpy as np
import math, os
######################################################################################################
# Shapley image explanations with data-dependent Binary Partition Trees
######################################################################################################
# Cython implementation of the BPT algorithm
from . import bpt as bpt
######################################################################################################
def mask2image(matrix, color):
    """Paint a 2-D mask with a single colour.

    Args:
        matrix: 2-D array of shape (H, W). Each entry scales the colour at
            that pixel, so a boolean / 0-1 mask selects where the colour
            appears.
        color: sequence of channel intensities, e.g. an RGB or RGBA tuple.

    Returns:
        Array of shape (H, W, len(color)) holding ``matrix * color`` clipped
        to the [0, 1] range.

    Raises:
        ValueError: if ``matrix`` is not 2-dimensional.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError('matrix is expected to be 2-dimensional (H, W).')
    # Broadcast (H, W, 1) against (1, 1, C) to obtain one plane per channel.
    channels = np.expand_dims(np.array(color), axis=(0, 1))
    return np.clip(matrix[..., np.newaxis] * channels, 0, 1)
def hex_to_rgb(value):
    """Convert a hexadecimal colour string to a tuple of floats in [0, 1].

    The string (with or without leading ``#``) is cut into groups of
    ``len // 3`` hex digits, and every group is divided by the largest value
    that group can hold. ``'#rrggbb'`` is therefore scaled by 255 and the
    shorthand ``'#rgb'`` by 15. An 8-digit ``'#rrggbbaa'`` string yields four
    components.

    Args:
        value: hexadecimal colour string, e.g. ``'#0053d1'``.

    Returns:
        Tuple of channel intensities as floats.

    Raises:
        ValueError: if fewer than three hex digits are given or a group is
            not valid hexadecimal.
    """
    digits = value.lstrip('#')
    width = len(digits) // 3
    if width == 0:
        raise ValueError(f'Not a valid hex color: {value!r}')
    # Largest value representable with `width` hex digits (255 for width 2).
    full_scale = float(16 ** width - 1)
    return tuple(int(digits[i:i + width], 16) / full_scale
                 for i in range(0, len(digits), width))
######################################################################################################
from matplotlib.colors import LinearSegmentedColormap
# Custom colormap for Shapley values - similar to 'seismic' but with lighter tones.
# Diverging map: blue stops at the low end (0.0, 0.2), white at the midpoint (0.5)
# and red stops at the high end (0.8, 1.0). plot_owen_values draws with symmetric
# vmin/vmax, so white corresponds to a Shapley value of zero.
shapley_values_colormap = LinearSegmentedColormap.from_list("shapley_values_colormap",
[(0.0, '#0053d1'),
(0.2, '#248df4'),
(0.5, 'white'),
(0.8, '#f23754'),
(1.0, '#cb0021')])
######################################################################################################
class BaseSegment:
    """Abstract node of a hierarchical image partition.

    Concrete segments override every method. A bare ``BaseSegment`` is also
    used as the sentinel parent of the root segment: its ``fill_mask`` does
    nothing, which ends the walk up the parent chain.
    """

    def __init__(self, parent=None):
        # Segment whose pixels are added to the mask together with this one.
        self.parent = parent

    def split(self, lparent=None, rparent=None):
        """Return the two child segments, attached to the given parents."""
        raise NotImplementedError('split() must be implemented by subclasses')

    def fill_mask(self, mat, ascend_hier=True):
        """Mark this segment's pixels in ``mat``; the base segment has none."""
        return

    def add_inside_coalition(self, shap_values, contrib):
        """Spread ``contrib`` over the pixels inside the segment."""
        raise NotImplementedError('add_inside_coalition() must be implemented by subclasses')

    def subtract_outside_coalition(self, shap_values, contrib):
        """Subtract ``contrib``, spread over the pixels outside the segment."""
        raise NotImplementedError('subtract_outside_coalition() must be implemented by subclasses')

    def area(self):
        """Return the number of pixels in the segment."""
        raise NotImplementedError('area() must be implemented by subclasses')

    def plot(self, ax, color=None):
        """Draw the segment on the matplotlib axes ``ax``."""
        raise NotImplementedError('plot() must be implemented by subclasses')

    def contains(self, aa):
        """Return True when segment ``aa`` lies entirely inside this one."""
        raise NotImplementedError('contains() must be implemented by subclasses')

    def equals(self, aa):
        """Return True when segment ``aa`` covers the same pixels as this one."""
        raise NotImplementedError('equals() must be implemented by subclasses')
######################################################################################################
# A symmetric, disjoint, axis-aligned, hierarchical partition
######################################################################################################
class AxisAlignedSegment(BaseSegment):
    """Rectangular segment covering ``[ymin, ymax) x [xmin, xmax)``.

    ``y`` indexes the first (row) axis of masks and value maps and ``x`` the
    second (column) axis. Refinement halves the rectangle along one axis.
    """

    def __init__(self, xmin, xmax, ymin, ymax, parent):
        super().__init__(parent)
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax

    def split(self, lparent, rparent):
        """Halve the rectangle and return ``(left, right)`` children.

        The cut is made across x when the x extent is strictly the larger one
        (and greater than 1), otherwise across y. ``lparent`` / ``rparent``
        become the parents of the two children.
        """
        extent_x = self.xmax - self.xmin
        extent_y = self.ymax - self.ymin
        assert extent_x>=1 or extent_y>=1
        if extent_x > extent_y and extent_x > 1:
            # cut along x: both halves keep the full y range
            cut = self.xmin + extent_x // 2
            left = AxisAlignedSegment(self.xmin, cut, self.ymin, self.ymax, lparent)
            right = AxisAlignedSegment(cut, self.xmax, self.ymin, self.ymax, rparent)
        else:
            # cut along y: both halves keep the full x range
            cut = self.ymin + extent_y // 2
            left = AxisAlignedSegment(self.xmin, self.xmax, self.ymin, cut, lparent)
            right = AxisAlignedSegment(self.xmin, self.xmax, cut, self.ymax, rparent)
        return (left, right)

    def fill_mask(self, mat, ascend_hier=True):
        """Set the rectangle to True in ``mat``, then optionally do the same for the parent chain."""
        rows = slice(self.ymin, self.ymax)
        cols = slice(self.xmin, self.xmax)
        mat[rows, cols] = True
        if ascend_hier:
            self.parent.fill_mask(mat, ascend_hier)

    def add_inside_coalition(self, shap_values, contrib):
        """Add ``contrib[c] / area`` to every pixel of the rectangle, per class ``c``."""
        per_pixel = contrib / self.area()
        rows = slice(self.ymin, self.ymax)
        cols = slice(self.xmin, self.xmax)
        for c, value in enumerate(per_pixel):
            shap_values[c, rows, cols] += value

    def subtract_outside_coalition(self, shap_values, contrib):
        """Subtract ``contrib[c]``, spread evenly, from every pixel outside the rectangle."""
        n_outside = shap_values[0].size - self.area()
        if n_outside == 0:
            # the rectangle covers the whole map: nothing lies outside
            return
        per_pixel = contrib / n_outside
        rows = slice(self.ymin, self.ymax)
        for c, value in enumerate(per_pixel):
            shap_values[c, :self.ymin, :] -= value       # band above
            shap_values[c, self.ymax:, :] -= value       # band below
            shap_values[c, rows, :self.xmin] -= value    # strip to the left
            shap_values[c, rows, self.xmax:] -= value    # strip to the right

    def area(self):
        """Number of pixels in the rectangle."""
        extent_x = self.xmax - self.xmin
        extent_y = self.ymax - self.ymin
        return extent_x * extent_y

    def plot(self, ax, color=(.3,.7,1.0)):
        """Draw the rectangle as a filled patch on ``ax``."""
        patch = Rectangle((self.xmin, self.ymin),
                          self.xmax-self.xmin, self.ymax-self.ymin,
                          facecolor=color, fill=True, lw=None)
        ax.add_patch(patch)

    def contains(self, aa):
        """True when rectangle ``aa`` lies entirely inside this one."""
        inside_x = self.xmin <= aa.xmin and aa.xmax <= self.xmax
        inside_y = self.ymin <= aa.ymin and aa.ymax <= self.ymax
        return inside_x and inside_y

    def equals(self, aa):
        """True when rectangle ``aa`` has exactly the same bounds."""
        same_x = self.xmin == aa.xmin and aa.xmax == self.xmax
        same_y = self.ymin == aa.ymin and aa.ymax == self.ymax
        return same_x and same_y
######################################################################################################
# Binary Partition Tree reader (using the code of the AGAT-Team)
######################################################################################################
class BPT:
    """Flat, array-encoded Binary Partition Tree of an image.

    Nodes are numbered ``0 .. N-1``. Indices below ``U`` are leaves and the
    root is node ``N-1``. ``pixels`` holds flat pixel indices. Leaf ``i``
    owns ``pixels[leaf_idx[i]]``, and internal node ``k`` owns the slice
    ``pixels[cl_start[k-U]:cl_end[k-U]]`` with children ``cl_left[k-U]`` and
    ``cl_right[k-U]``.
    """

    def __init__(self):
        self.width = self.height = -1
        self.N = self.U = 0
        self.pixels = None
        self.leaf_idx = None
        self.cl_start = self.cl_end = None
        self.cl_left = self.cl_right = None

    def load_from_file(self, bpt_fname):
        """Read the tree from a text file.

        The file holds ten lines: width, height, U and N (one integer each),
        followed by six whitespace-separated integer rows for ``pixels``,
        ``leaf_idx``, ``cl_start``, ``cl_end``, ``cl_left`` and ``cl_right``.
        """
        def read_row(stream):
            # One line of whitespace-separated integers -> integer array.
            return np.array([int(token) for token in stream.readline().split()])

        with open(bpt_fname, 'r') as f:
            self.width = int(f.readline())
            self.height = int(f.readline())
            self.U = int(f.readline())
            self.N = int(f.readline())
            self.pixels = read_row(f)
            self.leaf_idx = read_row(f)
            self.cl_start = read_row(f)
            self.cl_end = read_row(f)
            self.cl_left = read_row(f)
            self.cl_right = read_row(f)

    def from_bpt_builder(self, bpt_builder):
        """Copy the ten-field tuple returned by ``bpt_builder.encode()`` into this object."""
        (self.width, self.height, self.U, self.N,
         self.pixels, self.leaf_idx,
         self.cl_start, self.cl_end,
         self.cl_left, self.cl_right) = bpt_builder.encode()

    def print_tree(self, index=None, lvl=0):
        """Print the subtree rooted at ``index`` (default: the root), one node per line.

        Each line is indented by ``lvl`` spaces. Internal nodes also show
        their size, the sizes of both children and the children's indices.
        """
        if index is None:
            index = self.N - 1
        print(' ' * lvl, end='')
        print(f'index={index} ', end='')
        if index < self.U:  # leaf node
            # End the line here, otherwise the next node is printed on the same line.
            print()
            return
        node = index - self.U
        s = self.cl_start[node]
        e = self.cl_end[node]
        l, r = self.cl_left[node], self.cl_right[node]
        # A leaf child has size 1; an internal child spans its pixel interval.
        al = 1 if l < self.U else self.cl_end[l - self.U] - self.cl_start[l - self.U]
        ar = 1 if r < self.U else self.cl_end[r - self.U] - self.cl_start[r - self.U]
        print(f' {e-s} -> {al} + {ar} left={l} right={r}')
        self.print_tree(l, lvl + 1)
        self.print_tree(r, lvl + 1)
######################################################################################################
def add_noise(img, sigma=1.0, alpha=0.5):
    """Blend an image with smoothed, reproducible Gaussian noise.

    Args:
        img: image array, expected to hold values in the 0-255 range.
        sigma: standard deviation of the Gaussian filter that smooths the noise.
        alpha: weight of the original image in the blend; the noise gets
            ``1 - alpha``. Must lie in [0, 1].

    Returns:
        ``img * alpha + noise * (1 - alpha)`` clipped to [0, 255], as a float
        array with the shape of ``img``.

    The random generator is seeded with a constant, so equal inputs always
    give the same output.
    """
    from scipy.ndimage import gaussian_filter
    assert 0.0 <= alpha <= 1.0
    rng = np.random.Generator(np.random.PCG64(1234))
    # Noise centred on mid-grey (128) with standard deviation 64 before smoothing.
    noise = rng.standard_normal(size=img.shape) * 64.0 + 128.0
    # Use the caller's sigma (it was previously ignored in favour of a fixed 1.0).
    noise = gaussian_filter(noise, sigma=sigma)
    return np.clip(img * alpha + noise * (1.0 - alpha), 0.0, 255.0)
######################################################################################################
def image_rgb2lab(rgb_image):
    """Convert an RGB image to CIE Lab, rescaled and cast to uint8.

    The Lab channels returned by ``skimage.color.rgb2lab`` are shifted and
    scaled from their nominal ranges (L in 0..100, a and b in -128..127)
    towards 0..255 before the cast.
    """
    from skimage.color import rgb2lab
    # Nominal channel ranges: L (0:100), a (-128:127), b (-128:127)
    shift = [0, 128, 128]
    gain = (255.0/100.0, 255.0/256.0, 255.0/256.0)
    lab = rgb2lab(rgb_image)
    rescaled = (lab + shift) * gain
    return rescaled.astype(np.uint8)
######################################################################################################
# The input image is expected to be of type uint8, with shape H*W*3 (RGB) or H*W*1 (grayscale).
def build_bpt_from_image(image, use_lab=True, **kwargs):
    """Build the Binary Partition Tree of an image.

    Args:
        image: uint8 array of shape (H, W), (H, W, 1) or (H, W, 3).
        use_lab: convert RGB images to Lab (see ``image_rgb2lab``) before
            building the tree. Ignored for single-channel images, which have
            no RGB channels to convert.
        **kwargs: forwarded to ``bpt.BinaryPartitionTreeBuilder``.

    Returns:
        A ``BPT`` filled from the builder's encoding.

    Raises:
        TypeError: if the pixel type is not uint8.
        ValueError: if the image is not (H, W), (H, W, 1) or (H, W, 3).
    """
    if image.dtype != np.uint8:
        raise TypeError('Image pixel type is expected to be uint8.')
    if image.ndim == 2:
        # Promote a plain (H, W) grayscale image to (H, W, 1).
        image = image.reshape((image.shape[0], image.shape[1], 1))
    if image.ndim != 3:
        raise ValueError('Image shape is expected to be 3-dimensional.')
    if image.shape[2] not in (1, 3):
        raise ValueError('Image is expected to be RGB (H*W*3) or grayscale (H*W*1).')
    if use_lab and image.shape[2] == 3:
        # The Lab conversion needs three channels; grayscale is used as-is.
        image = image_rgb2lab(image)
    bpt_builder = bpt.BinaryPartitionTreeBuilder(image=image, **kwargs)
    bpt_builder.compute()
    bptree = BPT()
    bptree.from_bpt_builder(bpt_builder)
    del bpt_builder
    return bptree
######################################################################################################
# A non-symmetric, disjoint, hierarchical partition of a Binary Partition Tree node
######################################################################################################
class BPT_Segment(BaseSegment):
    """Segment backed by one node of a Binary Partition Tree.

    The segment's pixels are the flat indices stored in
    ``bpt.pixels[start:end]``, where ``(start, end)`` is given by
    ``pixels_interval()``.
    """

    def __init__(self, bpt, index, parent):
        super().__init__(parent)
        self.bpt = bpt
        self.index = index

    def split(self, lparent, rparent):
        """Return segments for the node's two children, or None for a single-pixel segment."""
        if self.area() == 1:
            return None
        node = self.index - self.bpt.U
        ls = BPT_Segment(self.bpt, self.bpt.cl_left[node], lparent)
        rs = BPT_Segment(self.bpt, self.bpt.cl_right[node], rparent)
        return (ls, rs)

    def fill_mask(self, mat, ascend_hier=True):
        """Set the segment's pixels in ``mat``, then optionally those of the parent chain.

        ``mat`` must be C-contiguous (as produced by ``np.zeros``) so that
        ``ravel()`` returns a view and the assignment reaches ``mat``.
        """
        s, e = self.pixels_interval()
        mat.ravel()[self.bpt.pixels[s:e]] = True
        if ascend_hier:
            self.parent.fill_mask(mat, ascend_hier)

    def add_inside_coalition(self, shap_values, contrib):
        """Add ``contrib[c] / area`` to every pixel of the segment, per class ``c``."""
        contrib = contrib / self.area()
        s, e = self.pixels_interval()
        inside = self.bpt.pixels[s:e]
        for c, value in enumerate(contrib):
            shap_values[c].ravel()[inside] += value

    def subtract_outside_coalition(self, shap_values, contrib):
        """Subtract ``contrib[c]``, spread evenly, from every pixel outside the segment."""
        if shap_values[0].size == self.area():
            return  # the segment covers the whole map: nothing lies outside
        contrib = contrib / (shap_values[0].size - self.area())
        s, e = self.pixels_interval()
        for c, value in enumerate(contrib):
            # Everything before and after the segment's slice of bpt.pixels.
            shap_values[c].ravel()[self.bpt.pixels[:s]] -= value
            shap_values[c].ravel()[self.bpt.pixels[e:]] -= value

    def plot(self, ax, color=(.3,.7,1.0)):
        """Show the segment on ``ax`` as a coloured mask image."""
        # Rows first: the flat pixel indices address a (height, width) array,
        # the same layout as the masks built by Explainer.empty_mask.
        # NOTE(review): assumes bpt.width is the number of columns - confirm
        # against the BPT builder.
        img = np.zeros((self.bpt.height, self.bpt.width), dtype=np.int8)
        self.fill_mask(img, ascend_hier=False)
        ax.imshow(mask2image(img, color))

    def area(self):
        """Number of pixels in the segment, as a float."""
        s, e = self.pixels_interval()
        return float(e - s)

    def pixels_interval(self):
        """Return the half-open ``(start, end)`` range of this node in ``bpt.pixels``."""
        if self.index < self.bpt.U:  # leaf node: exactly one pixel
            start = self.bpt.leaf_idx[self.index]
            return (start, start + 1)
        node = self.index - self.bpt.U
        return (self.bpt.cl_start[node], self.bpt.cl_end[node])

    def contains(self, other):
        """True when ``other``'s pixel interval lies inside this segment's interval."""
        s1, e1 = self.pixels_interval()
        s2, e2 = other.pixels_interval()
        return s1 <= s2 and e2 <= e1

    def equals(self, other):
        """True when both segments have the same pixel interval."""
        s1, e1 = self.pixels_interval()
        s2, e2 = other.pixels_interval()
        return s1 == s2 and e2 == e1
######################################################################################################
# A partition of the features, refined by recursive splitting
# A coalition that is part of the global coalition structure
######################################################################################################
class Coalition:
    """One marginal term ``f(S u AB) - f(S)`` of the recursive Owen formula.

    ``segment`` is the coalition AB that may be refined further. Its parent
    chain describes the context S it is added to. Instances order themselves
    by ``priority`` so that a priority queue pops the most important one first.
    """

    def __init__(self, explainer, segment, f_SuAB, f_S, weight):
        self.segment = segment   # segment for recursive refinement
        self.f_SuAB = f_SuAB     # prediction with this coalition AB
        self.f_S = f_S           # prediction without this coalition AB
        self.weight = weight     # recursive weight of the Owen formula
        # More negative priority = split earlier: largest weighted marginal first.
        marginal = np.subtract(self.f_SuAB, self.f_S)
        self.priority = -np.max(np.abs(marginal)) * self.weight
        if explainer.balance_area:
            self.priority *= self.segment.area()

    def prepare_split(self, explainer):
        """Split AB into A and B and build the two masks that must be evaluated.

        Returns ``(m_SuA, m_SuB, split_completer)``. Once the predictions for
        both masks are known, ``split_completer(f_SuA, f_SuB)`` returns the
        four child coalitions, each carrying half of this coalition's weight.
        """
        context = self.segment.parent
        # A and B added directly on top of S ...
        coS_A, coS_B = self.segment.split(context, context)
        # ... and A on top of S u B, B on top of S u A (parents flipped).
        coSuB_A, coSuA_B = self.segment.split(coS_B, coS_A)
        assert self.segment.area() == coS_A.area() + coS_B.area()

        m_SuA = explainer.empty_mask()
        m_SuB = explainer.empty_mask()
        coS_A.fill_mask(m_SuA)
        coS_B.fill_mask(m_SuB)

        def split_completer(f_SuA, f_SuB):
            # the four recursive branches, in a fixed order
            half = self.weight/2.0
            return [
                Coalition(explainer, coS_A, f_SuA, self.f_S, half),           # A given S
                Coalition(explainer, coS_B, f_SuB, self.f_S, half),           # B given S
                Coalition(explainer, coSuA_B, self.f_SuAB, f_SuA, half),      # B given S u A
                Coalition(explainer, coSuB_A, self.f_SuAB, f_SuB, half),      # A given S u B
            ]

        return (m_SuA, m_SuB, split_completer)

    def plot(self, ax, explainer, color=(1, 0, 0, 1)):
        """Draw the context (parent chain) in ``color`` and the segment on top of it."""
        context_mask = explainer.empty_mask()
        self.segment.parent.fill_mask(context_mask)
        ax.imshow(mask2image(context_mask, color))
        ax.axis('off')
        self.segment.plot(ax)

    def __lt__(self, other):
        return self.priority < other.priority

    def get_shapley(self, shap_values):
        """Add this coalition's weighted marginal to the pixels of its segment."""
        weighted = np.subtract(self.f_SuAB, self.f_S) * self.weight
        self.segment.add_inside_coalition(shap_values, weighted)
######################################################################################################
# Explainer object. Implementation of the recursive refinement following Owen formula
######################################################################################################
class Explainer:
    """Owen/Shapley explainer that recursively refines a hierarchical partition.

    Args:
        fm: black-box predictor with masker. It is called with an array of
            boolean masks of shape (n, H, W) and must return one row of class
            scores per mask. An all-True mask is the unmasked input and an
            all-False mask is the fully masked input.
        image_to_explain: image array; its first two axes give the mask shape.
        num_explained_classes: number of top-scoring classes (on the unmasked
            input) to explain.
        balance_area: when True, split priorities are multiplied by the
            segment area.
        verbose: show a progress bar during ``explain_instance``.
    """

    def __init__(self, fm, image_to_explain, num_explained_classes, balance_area=False, verbose=False):
        self.fm = fm  # black box predictor with masker
        self.image_to_explain = image_to_explain
        self.num_explained_classes = num_explained_classes
        self.balance_area = balance_area
        mask_shape = (self.image_to_explain.shape[0], self.image_to_explain.shape[1])
        # foreground prediction (no masking, original input).
        # np.bool_ is used because the np.bool alias is missing in NumPy 1.24-1.26.
        ym = self.fm(np.array([np.ones(mask_shape, dtype=np.bool_)]))[0]
        # classes sorted by decreasing score, truncated to the explained ones
        self.output_indexes = np.flip(np.argsort(ym))[:self.num_explained_classes]
        self.base_f_S = np.array([float(ym[i]) for i in self.output_indexes])
        # background prediction (everything masked)
        ym = self.fm(np.array([np.zeros(mask_shape, dtype=np.bool_)]))[0]
        self.base_f_0 = np.array([float(ym[i]) for i in self.output_indexes])
        self.verbose = verbose

    def empty_mask(self, dtype=np.bool_):
        """Return an all-zero (H, W) mask matching the image to explain."""
        return np.zeros((self.image_to_explain.shape[0],
                         self.image_to_explain.shape[1]), dtype=dtype)

    def predict_masked(self, masks):
        """Evaluate the model on ``masks``; return the explained classes' scores, one row per mask."""
        rows = self.fm(np.array(masks))
        f = [[float(ym[i]) for i in self.output_indexes] for ym in rows]
        return np.array(f)

    def explain_instance(self, max_evals, method='BPT', bpt=None,
                         batch_size=64, verbose_plot=False, pbar=None,
                         min_area=1, max_weight=None):
        """Compute the Owen/Shapley coefficients of the image.

        Coalitions are popped from a priority queue, split in two, and the
        two new masks are evaluated in batches until the budget is spent.
        Every coalition left in the queue then adds its weighted marginal to
        the pixels of its segment.

        Args:
            max_evals: budget of model evaluations. Each split costs two and
                the check is made before the split, so the budget can be
                exceeded by one evaluation.
            method: ``'BPT'`` (data-dependent Binary Partition Tree) or
                ``'AA'`` (axis-aligned halving).
            bpt: pre-built ``BPT`` for ``method='BPT'``; built from the image
                when None.
            batch_size: a batch stops growing once it holds this many masks.
            verbose_plot: plot every split (debugging aid).
            pbar: progress bar to update when ``verbose`` is set; a tqdm bar
                is created when None.
            min_area: segments of at most this many pixels are not split.
            max_weight: when given, coalitions whose weight is at or below
                this value are not split.

        Returns:
            Array of shape (num_explained_classes, H, W), or None for an
            unknown ``method``.
        """
        assert min_area >= 1
        shap_values = np.zeros((self.num_explained_classes,
                                self.image_to_explain.shape[0],
                                self.image_to_explain.shape[1]))
        if method == 'BPT':
            if bpt is None:
                bpt = build_bpt_from_image(self.image_to_explain)
            init_coalition = self.init_bpt(bpt)
        elif method == 'AA':
            init_coalition = self.init_axisaligned()
        else:
            print('Unknown method', method)
            return None

        if self.verbose:
            pbar = pbar if pbar is not None else tqdm(total=max_evals, disable=False, leave=False)

        q = PriorityQueue()
        q.put(init_coalition)
        eval_count, reached_terminals = 0, 0
        while not q.empty():
            if eval_count >= max_evals:  # no more v(s) budget
                # flush: every pending coalition contributes as it stands
                while not q.empty():
                    coalition = q.get()
                    coalition.get_shapley(shap_values)
                break

            batch_masks = []
            batch_completers = []
            batch_owens = []
            while not q.empty() and len(batch_masks) < batch_size and \
                    eval_count + len(batch_masks) < max_evals:
                coalition = q.get()
                if (coalition.segment.area() <= min_area or
                        (max_weight is not None and coalition.weight <= max_weight)):
                    reached_terminals += 1  # do not split further
                    coalition.get_shapley(shap_values)
                else:
                    (m_SuA, m_SuB, split_completer) = coalition.prepare_split(self)
                    batch_masks.append(m_SuA)
                    batch_masks.append(m_SuB)
                    batch_completers.append(split_completer)
                    batch_owens.append(coalition)

            if len(batch_masks) > 0:
                f = self.predict_masked(batch_masks)
                eval_count += len(batch_masks)
                if self.verbose:
                    pbar.update(len(batch_masks))

                for i in range(len(batch_completers)):
                    # masks were appended pairwise: rows 2i and 2i+1 belong to split i
                    f_SuA, f_SuB = f[i*2], f[i*2 + 1]
                    splits = batch_completers[i](f_SuA, f_SuB)
                    for o in splits:
                        q.put(o)

                    if verbose_plot:
                        fig, axes = plt.subplots(1, 5, figsize=(5, 1))
                        # Coalition.plot has no alpha argument: pass the
                        # transparency through the RGBA colour instead.
                        batch_owens[i].plot(axes[0], self, color=(1, 0, 0, 0.5))
                        for k, s in enumerate(splits):
                            s.plot(axes[k+1], self)
                        plt.show()

        if self.verbose:
            pbar.refresh()
            if reached_terminals > 0:
                print(f'Reached {reached_terminals} terminals.')
        return shap_values

    def init_axisaligned(self):
        """Return the root coalition covering the whole image with one rectangle."""
        base = BaseSegment()
        # x spans the columns (shape[1]) and y the rows (shape[0]), matching
        # the mat[ymin:ymax, xmin:xmax] indexing of AxisAlignedSegment.
        s0 = AxisAlignedSegment(0, self.image_to_explain.shape[1],
                                0, self.image_to_explain.shape[0], base)
        return Coalition(self, s0, self.base_f_S, self.base_f_0, 1.0)

    def init_bpt(self, bpt):
        """Return the root coalition for the root node (``N-1``) of ``bpt``."""
        base = BaseSegment()
        s0 = BPT_Segment(bpt, bpt.N-1, base)
        return Coalition(self, s0, self.base_f_S, self.base_f_0, 1.0)
######################################################################################################
def plot_owen_values(explainer, shap_values, class_names, names=None):
    """Plot Shapley/Owen value maps next to the explained image.

    One row is drawn per explanation. The first column shows the image and
    each further column shows the value map of one explained class over a
    grey version of the image. A shared horizontal colorbar uses the last row.

    Args:
        explainer: the ``Explainer`` that produced the values; supplies the
            image, ``base_f_S`` and ``output_indexes``.
        shap_values: one map stack (classes, H, W) or several (rows, classes, H, W).
        class_names: indexable by class index; used for the column titles.
        names: optional row labels, one per explanation.
    """
    shap_values = np.array(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[np.newaxis]  # single explanation -> one row
    # Symmetric colour range; the 99.9th percentile keeps outliers from washing out the map.
    max_val = np.nanpercentile(np.abs(shap_values.flatten()), 99.9)
    num_explained_classes = len(explainer.base_f_S)
    num_rows = len(shap_values)

    # gridspec_kw is used because the height_ratios= keyword needs Matplotlib 3.6+.
    fig, axes = plt.subplots(num_rows+1, num_explained_classes+1,
                             figsize=(2*(num_explained_classes+1), 2*(num_rows+0.3)),
                             squeeze=False,
                             gridspec_kw={'height_ratios': [1]*num_rows + [0.3]})

    base_image = explainer.image_to_explain
    if np.max(base_image) > 1:
        base_image = base_image.astype(np.uint8)
    if base_image.ndim == 3 and base_image.shape[2] == 1:
        # (H, W, 1) grayscale: drop the channel axis so it is expanded below
        base_image = base_image[:, :, 0]
    if base_image.ndim == 2:
        base_image = np.stack([base_image, base_image, base_image], axis=-1)
    img_grey = (0.2989 * base_image[:, :, 0] +
                0.5870 * base_image[:, :, 1] +
                0.1140 * base_image[:, :, 2])

    im = None  # last value map drawn; feeds the shared colorbar
    for r in range(num_rows):
        axes[r, 0].imshow(base_image)
        for i in range(num_explained_classes):
            axes[r, i+1].imshow(img_grey.astype(base_image.dtype), alpha=0.50, cmap='gray')
            im = axes[r, i+1].imshow(shap_values[r, i], cmap=shapley_values_colormap,
                                     vmin=-max_val, vmax=max_val, alpha=0.80)
            if r == 0:
                axes[r, i+1].set_title(f'{class_names[explainer.output_indexes[i]]}', fontsize=10)
        for col in range(num_explained_classes+1):
            axes[r, col].set_xticks([])
            axes[r, col].set_yticks([])

    if names is not None:
        for r in range(num_rows):
            axes[r, 0].set_ylabel(names[r])

    # Use the last row for the colorbar
    for ax in axes[-1, :]:
        ax.set_axis_off()
    if im is not None:
        cb = fig.colorbar(im, ax=axes[-1, :], label="Shapley/Owen value",
                          orientation="horizontal", aspect=80, fraction=0.9)
        cb.outline.set_visible(False)
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    plt.show()
######################################################################################################