Generate animated GIFs of the recursive splitting of the BPT (Binary Partition Tree) or AA (axis-aligned) hierarchy of an input image
[1]:
# Imports for the whole notebook.
# NOTE(review): json, sys, os, importlib, math and tqdm are not used by the cells
# shown here -- confirm they are unused elsewhere before removing them.
import numpy as np
import json
import sys, os, importlib, math
import matplotlib.pyplot as plt
import cv2
from tqdm.auto import tqdm
from skimage.segmentation import mark_boundaries
import io, imageio
from matplotlib import rc
# Render all Matplotlib text (e.g. the frame titles) through LaTeX.
# This requires a working LaTeX installation on the machine running the notebook.
rc('text',usetex=True)
# Make the LaTeX `color` package available to text rendered by Matplotlib.
rc('text.latex', preamble='\\usepackage{color}')
# The `as shap_bpt` alias is redundant (same name), but harmless.
import shap_bpt as shap_bpt
# Record which shap_bpt version produced the outputs below.
print(shap_bpt.__version__)
1.0
[2]:
# Input image to decompose; path is relative to the notebook's working directory.
IMAGE_PATH = 'flamingo4.png'
image_bgr = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
if image_bgr is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of raising;
    # fail here with a clear message rather than with a TypeError on the next line.
    raise FileNotFoundError(f'could not read image {IMAGE_PATH!r}')
# OpenCV stores channels as BGR; reverse the channel axis to obtain RGB.
# astype() copies by default, so the result is a contiguous uint8 array.
image_to_explain = image_bgr[:, :, ::-1].astype(np.uint8)
print(image_to_explain.shape)
(224, 224, 3)
[3]:
%%time
# Build the BPT of the image. Later cells use its root node (index bptree.N - 1),
# its width/height and its `pixels` array of flat pixel indices.
bptree = shap_bpt.build_bpt_from_image(image_to_explain)
CPU times: user 127 ms, sys: 6.76 ms, total: 134 ms
Wall time: 135 ms
[4]:
import matplotlib.colors as mcolors
# Palette of Matplotlib's XKCD named colours, converted from hex strings by
# shap_bpt.hex_to_rgb. In the cells shown here it is only referenced by the
# commented-out alternative colouring inside `colorize`.
cmap = [shap_bpt.hex_to_rgb(hex_code) for hex_code in mcolors.XKCD_COLORS.values()]
def colorize(nodes, img, i=None):
    """Paint every segment of a partition with the mean colour of `img` over it.

    Parameters
    ----------
    nodes : sequence of segments
        All of the same kind as ``nodes[0]``: either
        ``shap_bpt.AxisAlignedSegment`` objects (rectangles
        ``img[ymin:ymax, xmin:xmax]``) or BPT segments exposing
        ``pixels_interval()`` and ``bpt.pixels`` (indices into the flattened
        ``H * W`` pixel grid).
    img : np.ndarray
        Image of shape (H, W, C) that the segments refer to. Its values are
        divided by 255, so 8-bit data is expected.
    i : ignored
        Unused; accepted so calls such as ``colorize(nodes, img, 0)`` keep working.

    Returns
    -------
    np.ndarray
        float32 array with the shape of `img`. Every covered pixel holds the
        mean colour of its segment scaled by 1/255; uncovered pixels are 0.
        An empty `nodes` yields an all-zero image.
    """
    # A fresh C-contiguous buffer guarantees that the reshape below is a view,
    # so writes through `flat_colored` land in `colored` (zeros_like of a
    # non-contiguous input would not guarantee this).
    colored = np.zeros(img.shape, dtype=np.float32)
    if len(nodes) == 0:
        return colored

    height, width = img.shape[0], img.shape[1]
    flat_pixels = img.reshape(height * width, -1)      # one row per pixel
    flat_colored = colored.reshape(height * width, -1)
    is_aa = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)
    for node in nodes:
        if is_aa:
            region = img[node.ymin:node.ymax, node.xmin:node.xmax]
            if region.size == 0:
                continue  # degenerate rectangle: nothing to paint, avoid a NaN mean
            colored[node.ymin:node.ymax, node.xmin:node.xmax] = region.mean(axis=(0, 1)) / 255.0
        else:
            start, end = node.pixels_interval()
            pixel_ids = node.bpt.pixels[start:end]
            if len(pixel_ids) == 0:
                continue  # empty segment: nothing to paint, avoid a NaN mean
            flat_colored[pixel_ids] = flat_pixels[pixel_ids].mean(axis=0) / 255.0
    return colored
def make_segments(nodes, img):
    """Return an (H, W) uint32 label map where pixel value k means "covered by nodes[k]".

    All `nodes` are treated like ``nodes[0]``: axis-aligned segments
    (``shap_bpt.AxisAlignedSegment``) are painted as the rectangle
    ``[ymin:ymax, xmin:xmax]``; any other segment is painted through the flat
    pixel indices ``bpt.pixels[start:end]`` given by its ``pixels_interval()``.
    Only the first two dimensions of `img` are used. Later nodes overwrite
    earlier ones where they overlap; pixels covered by no node are 0, the same
    value as pixels covered by nodes[0].
    """
    axis_aligned = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)
    height, width = img.shape[0], img.shape[1]
    labels = np.zeros((height, width), dtype=np.uint32)
    flat_labels = labels.reshape(-1)  # view onto the fresh, C-contiguous `labels`
    for label, node in enumerate(nodes):
        if not axis_aligned:
            start, end = node.pixels_interval()
            flat_labels[node.bpt.pixels[start:end]] = label
            continue
        labels[node.ymin:node.ymax, node.xmin:node.xmax] = label
    return labels
[6]:
base_segment = shap_bpt.BaseSegment()
# Build visualization for AA or BPT (keep one):
#  - AA : axis-aligned segments, starting from the rectangle covering the whole image;
#  - BPT: segments of the BPT built above, starting from node N-1 as the root.
# root_node, name, title = shap_bpt.AxisAlignedSegment(0, bptree.width, 0, bptree.height, base_segment), 'aa', 'AA Hierarchy'
root_node, name, title = shap_bpt.BPT_Segment(bptree, bptree.N-1, base_segment), 'bpt', 'BPT Hierarchy'

K = 11                            # splitting rounds = number of GIF frames (titles run 0..K-1)
FRAME_DPI = 100                   # resolution of each frame (3 in x 3 in figure -> 300 x 300 px)
cut_color = (.5, 0, .25, 1)       # RGBA of the cuts made in the most recent round
old_cut_color = (0, 0, 0, 1)      # RGBA of cuts made in earlier rounds
height, width = image_to_explain.shape[:2]

segments = [root_node]            # current partition of the image
all_nodes = [root_node]           # every node produced so far (root + all splits)
prev_boundaries = None            # RGBA overlay of earlier cuts, drawn in old_cut_color
frames = []                       # rendered frames of the GIF
leaves = np.zeros(K, dtype=int)   # leaves[ii]: segments that could not be split in round ii


def _render_frame(figure):
    """Rasterize `figure` and return its pixels as an (H, W, channels) uint8 array."""
    with io.BytesIO() as buf:
        # Save at the figure's own dpi: `figure.bbox` is measured at figure.dpi,
        # so any other dpi would make the raw buffer disagree with the reshape below.
        figure.savefig(buf, format='raw', dpi=figure.dpi)
        raw = buf.getvalue()
    width_px, height_px = int(figure.bbox.bounds[2]), int(figure.bbox.bounds[3])
    return np.frombuffer(raw, dtype=np.uint8).reshape(height_px, width_px, -1)


fig, ax = plt.subplots(figsize=(3, 3), dpi=FRAME_DPI)
for ii in range(K):
    # Mean-colour rendering of the current partition, brightened for readability.
    img = colorize(segments, image_to_explain, 0)
    img = np.clip(0.2 + img * 1.1, 0, 1)
    sgm = make_segments(segments, image_to_explain)
    # Transparent RGBA canvas with the current cut lines painted in cut_color.
    boundaries = mark_boundaries(np.zeros((height, width, 4)), sgm,
                                 mode='thick', color=cut_color)

    ax.set_xticks([]); ax.set_yticks([])
    ax.set_title(f'{title} ({ii}/{K-1})', fontsize=16)
    ax.imshow(img)
    if ii > 0:  # round 0 shows the single unsplit root segment: no cuts yet
        ax.imshow(boundaries)
    if prev_boundaries is not None:
        # Segments are only ever split, so earlier cuts are still cut lines now;
        # drawing them on top in old_cut_color leaves only this round's new cuts
        # in cut_color.
        ax.imshow(prev_boundaries)
    # Keep this round's cuts, recoloured as "old", for the next frame. A pixel is
    # a cut only if all four RGBA channels match; a per-channel match would also
    # pick up pixels that merely share one channel value with cut_color.
    is_cut = np.all(boundaries == cut_color, axis=-1)
    prev_boundaries = boundaries.copy()
    prev_boundaries[is_cut] = old_cut_color

    # Refine the partition: replace every splittable segment by its parts.
    new_segments = []
    for seg in segments:
        # NOTE(review): the segment is passed as both arguments, as in the original
        # call -- confirm against the shap_bpt `split` signature.
        split = seg.split(seg, seg)
        if split is None:  # leaf: cannot be split further, keep it as is
            new_segments.append(seg)
            leaves[ii] += 1
        else:
            new_segments.extend(split)
            all_nodes.extend(split)
    segments = new_segments

    frames.append(_render_frame(fig))
    ax.clear()

# Show every frame for 1000 and hold the last one for 3000 (imageio `duration`
# units -- presumably milliseconds with the Pillow GIF plugin; confirm for the
# installed imageio version). The list length always equals len(frames), even for K == 1.
durations = [1000] * (len(frames) - 1) + [3000]
imageio.mimsave(f'sequence_{name}.gif', frames, duration=durations, loop=0)
plt.close(fig)
print('saved.')
saved.