{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8833c097",
   "metadata": {},
   "source": [
    "# Generate an animated GIF of the recursive splitting of the BPT/AA hierarchy of an input image\n",
    "\n",
    "Starting from a single root segment, every segment is split once per frame. Each frame shows the current segments filled with their mean colour, the cuts introduced in that frame highlighted, and the older cuts in black. The frames are saved as `sequence_<name>.gif`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a8a8e500-072d-4418-998b-2ef98feb7677",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    }
   ],
   "source": [
    "import io\n",
    "import json\n",
    "import sys, os, importlib, math\n",
    "\n",
    "import cv2\n",
    "import imageio\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import rc\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import shap_bpt\n",
    "\n",
    "# Figure text is rendered through LaTeX (usetex), which needs a LaTeX installation.\n",
    "rc('text', usetex=True)\n",
    "rc('text.latex', preamble='\\\\usepackage{color}')\n",
    "\n",
    "print(shap_bpt.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load-image",
   "metadata": {},
   "source": [
    "## Load the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c42a1905-c74c-4e2d-b35e-3b9eedcab823",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(224, 224, 3)\n"
     ]
    }
   ],
   "source": [
    "IMAGE_PATH = 'flamingo4.png'\n",
    "\n",
    "# cv2.imread returns None (it does not raise) when the file cannot be read,\n",
    "# so fail here with a clear message instead of a TypeError further down.\n",
    "image_bgr = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)\n",
    "if image_bgr is None:\n",
    "    raise FileNotFoundError(f'Could not read image: {IMAGE_PATH}')\n",
    "\n",
    "# OpenCV returns the channels in BGR order; the rest of the notebook works in RGB.\n",
    "image_to_explain = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)\n",
    "assert image_to_explain.ndim == 3 and image_to_explain.shape[2] == 3\n",
    "print(image_to_explain.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-build-bpt",
   "metadata": {},
   "source": [
    "## Build the BPT from the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "df413b6f-5df7-4005-a061-a2739a468dc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 127 ms, sys: 6.76 ms, total: 134 ms\n",
      "Wall time: 135 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "bptree = shap_bpt.build_bpt_from_image(image_to_explain)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-helpers",
   "metadata": {},
   "source": [
    "## Rendering helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "54c00643-9276-4468-8ad9-1ef218c575c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Palette built from matplotlib's XKCD colours. It is not used by the default\n",
    "# colouring below; it is kept for colouring segments by index when debugging.\n",
    "cmap = [shap_bpt.hex_to_rgb(c[1]) for c in list(mcolors.XKCD_COLORS.items())]\n",
    "\n",
    "\n",
    "def colorize(nodes, img, i=0):\n",
    "    \"\"\"Fill every segment with the mean colour of the pixels it covers.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nodes : sequence of segments\n",
    "        The type of the first node decides how all nodes are read: as\n",
    "        ``shap_bpt.AxisAlignedSegment`` (rectangle ``ymin:ymax, xmin:xmax``)\n",
    "        or as BPT segments (flat pixel indices ``node.bpt.pixels[s:e]``).\n",
    "    img : np.ndarray\n",
    "        Image of shape (height, width, 3). Means are divided by 255, so\n",
    "        8-bit channel values are assumed.\n",
    "    i : int, optional\n",
    "        Unused; kept so that existing three-argument calls keep working.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        float32 array with the shape of ``img``; pixels that no node covers\n",
    "        stay 0.\n",
    "    \"\"\"\n",
    "    is_aa = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)\n",
    "    flat_pixels = img.reshape((img.shape[0] * img.shape[1], 3))\n",
    "    # np.zeros always allocates a C-contiguous array, so the reshape below is\n",
    "    # guaranteed to be a view and writes through it land in `colored`.\n",
    "    colored = np.zeros(img.shape, dtype=np.float32)\n",
    "    flat_colored = colored.reshape(flat_pixels.shape)\n",
    "    for node in nodes:\n",
    "        if is_aa:\n",
    "            region = img[node.ymin:node.ymax, node.xmin:node.xmax, :]\n",
    "            colored[node.ymin:node.ymax, node.xmin:node.xmax] = region.mean(axis=(0, 1)) / 255.0\n",
    "        else:\n",
    "            s, e = node.pixels_interval()\n",
    "            pixel_ids = node.bpt.pixels[s:e]\n",
    "            flat_colored[pixel_ids, :] = flat_pixels[pixel_ids].mean(axis=0) / 255.0\n",
    "    return colored\n",
    "\n",
    "\n",
    "def make_segments(nodes, img):\n",
    "    \"\"\"Build a label map giving each pixel the index of the node that covers it.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nodes : sequence of segments\n",
    "        Interpreted as in ``colorize`` (the first node decides the type).\n",
    "    img : np.ndarray\n",
    "        Only ``img.shape[0]`` and ``img.shape[1]`` are used.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        uint32 array of shape (height, width); label ``k`` marks the pixels of\n",
    "        ``nodes[k]``. Pixels that no node covers keep label 0.\n",
    "    \"\"\"\n",
    "    is_aa = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)\n",
    "    labels = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint32)\n",
    "    flat_labels = labels.reshape(-1)  # view on `labels`\n",
    "    for k, node in enumerate(nodes):\n",
    "        if is_aa:\n",
    "            labels[node.ymin:node.ymax, node.xmin:node.xmax] = k\n",
    "        else:\n",
    "            s, e = node.pixels_interval()\n",
    "            flat_labels[node.bpt.pixels[s:e]] = k\n",
    "    return labels\n",
    "\n",
    "\n",
    "def split_once(segments):\n",
    "    \"\"\"Split every segment once.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    new_segments : list\n",
    "        The children of every segment that could be split, plus the segments\n",
    "        that could not be split (``split`` returned None), in order.\n",
    "    children : list\n",
    "        Only the newly created child segments.\n",
    "    n_leaves : int\n",
    "        Number of segments that could not be split.\n",
    "    \"\"\"\n",
    "    new_segments, children, n_leaves = [], [], 0\n",
    "    for segment in segments:\n",
    "        # NOTE(review): the segment is passed as both arguments, as in the\n",
    "        # original code - confirm their meaning against the shap_bpt API.\n",
    "        split = segment.split(segment, segment)\n",
    "        if split is None:\n",
    "            new_segments.append(segment)\n",
    "            n_leaves += 1\n",
    "        else:\n",
    "            new_segments.extend(split)\n",
    "            children.extend(split)\n",
    "    return new_segments, children, n_leaves\n",
    "\n",
    "\n",
    "def figure_to_array(fig, dpi):\n",
    "    \"\"\"Rasterise ``fig`` at ``dpi`` and return it as a (height, width, channels) uint8 array.\n",
    "\n",
    "    The pixel size is derived from the figure size in inches and the *same*\n",
    "    dpi that is passed to ``savefig``, so it does not depend on the figure's\n",
    "    own dpi setting.\n",
    "    \"\"\"\n",
    "    with io.BytesIO() as buffer:\n",
    "        fig.savefig(buffer, format='raw', dpi=dpi)\n",
    "        raw = np.frombuffer(buffer.getvalue(), dtype=np.uint8)\n",
    "    width, height = (int(v) for v in fig.get_size_inches() * dpi)\n",
    "    return raw.reshape(height, width, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-render-gif",
   "metadata": {},
   "source": [
    "## Render the frames and save the GIF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1843dc88-109b-486a-9099-96c7159fc5ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saved.\n"
     ]
    }
   ],
   "source": [
    "# --- configuration ---\n",
    "HIERARCHY = 'bpt'                # 'bpt' or 'aa'\n",
    "NUM_FRAMES = 11                  # frame 0 shows the root, then one split level per frame\n",
    "FRAME_DPI = 100                  # resolution used to rasterise each frame\n",
    "CUT_COLOR = (.5, 0, .25, 1)      # RGBA of the cuts introduced in the current frame\n",
    "OLD_CUT_COLOR = (0, 0, 0, 1)     # RGBA of the cuts introduced in earlier frames\n",
    "FRAME_MS, LAST_FRAME_MS = 1000, 3000\n",
    "\n",
    "base_segment = shap_bpt.BaseSegment()\n",
    "\n",
    "# Build the root of the chosen hierarchy.\n",
    "if HIERARCHY == 'aa':\n",
    "    root_node = shap_bpt.AxisAlignedSegment(0, bptree.width, 0, bptree.height, base_segment)\n",
    "    name, title = 'aa', 'AA Hierarchy'\n",
    "elif HIERARCHY == 'bpt':\n",
    "    root_node = shap_bpt.BPT_Segment(bptree, bptree.N-1, base_segment)\n",
    "    name, title = 'bpt', 'BPT Hierarchy'\n",
    "else:\n",
    "    raise ValueError(f'Unknown HIERARCHY: {HIERARCHY!r} (expected aa or bpt)')\n",
    "\n",
    "segments = [root_node]\n",
    "all_nodes = [root_node]\n",
    "prev_boundaries = None\n",
    "\n",
    "frames = []\n",
    "leaves = np.zeros(NUM_FRAMES, dtype=int)   # leaves[ii]: segments that could not be split at frame ii\n",
    "height, width = image_to_explain.shape[:2]\n",
    "\n",
    "# Create the figure at the dpi used for saving so both agree on the pixel size.\n",
    "fig, ax = plt.subplots(figsize=(3,3), dpi=FRAME_DPI)\n",
    "for ii in range(NUM_FRAMES):\n",
    "    img = colorize(segments, image_to_explain)\n",
    "    img = np.clip(0.2 + img * 1.1, 0, 1)   # brighten the mean colours\n",
    "    sgm = make_segments(segments, image_to_explain)\n",
    "    # RGBA overlay: the segment boundaries are painted with CUT_COLOR.\n",
    "    boundaries = mark_boundaries(np.tile((255,255,255,0), (height, width, 1)), sgm,\n",
    "                                 mode='thick', color=CUT_COLOR)\n",
    "    ax.set_xticks([]) ; ax.set_yticks([])\n",
    "    ax.set_title(f'{title} ({ii}/{NUM_FRAMES-1})', fontsize=16)\n",
    "    ax.imshow(img)\n",
    "    if ii > 0:\n",
    "        ax.imshow(boundaries)\n",
    "        # Recolour a copy (not the array already handed to imshow): a pixel is a\n",
    "        # cut only if all four channels match CUT_COLOR.\n",
    "        recolored = boundaries.copy()\n",
    "        recolored[np.all(boundaries == CUT_COLOR, axis=-1)] = OLD_CUT_COLOR\n",
    "        if prev_boundaries is not None:\n",
    "            # Drawn last, so cuts from earlier frames cover the highlighted ones.\n",
    "            ax.imshow(prev_boundaries)\n",
    "        prev_boundaries = recolored\n",
    "\n",
    "    segments, children, leaves[ii] = split_once(segments)\n",
    "    all_nodes.extend(children)\n",
    "\n",
    "    frames.append(figure_to_array(fig, FRAME_DPI))\n",
    "    ax.clear()\n",
    "\n",
    "# One duration per frame (milliseconds); the last frame is held longer.\n",
    "durations = [FRAME_MS] * (len(frames) - 1) + [LAST_FRAME_MS]\n",
    "imageio.mimsave(f'sequence_{name}.gif', frames, duration=durations, loop=0)\n",
    "plt.close(fig)\n",
    "print('saved.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "metal11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}