{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3982f572-4151-47df-923d-3cd6794e7070",
   "metadata": {},
   "source": [
    "# Callbacks\n",
    "\n",
    "The image labeller implements a callback system that allows you to run arbitrary code whenever the displayed image changes or when an image has a label assigned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebbcc16-dd8a-4e9b-9d79-b7cb2d3dbee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If in a notebook\n",
    "%matplotlib ipympl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f576ccc0-d8e8-4d6c-b09c-5c3634c94bf6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from mpl_image_labeller import image_labeller\n",
    "\n",
    "# five random 10x10 arrays stand in for real images\n",
    "images = np.random.standard_normal(size=(5, 10, 10))\n",
    "labeller = image_labeller(\n",
    "    images,\n",
    "    classes=[\"good\", \"bad\", \"blarg\"],\n",
    ")\n",
    "\n",
    "\n",
    "def image_changed_callback(index, image):\n",
    "    \"\"\"Print the index it is given, then the sum of ``image``.\"\"\"\n",
    "    print(index)\n",
    "    pixel_total = image.sum()\n",
    "    print(pixel_total)\n",
    "\n",
    "\n",
    "def label_assigned(index, label):\n",
    "    \"\"\"Print which label was assigned to which image index.\"\"\"\n",
    "    message = f\"label {label} assigned to image {index}\"\n",
    "    print(message)\n",
    "\n",
    "\n",
    "# Register each callback with the event whose arguments it is written for:\n",
    "# image_changed_callback takes (index, image), label_assigned takes (index, label).\n",
    "labeller.on_image_changed(image_changed_callback)\n",
    "labeller.on_label_assigned(label_assigned)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e4c6ae8-c124-4e05-a9d2-6c4d987d9cf7",
   "metadata": {},
   "source": [
    "## Overlaying a mask\n",
    "\n",
    "One potential use of this is to overlay a mask that changes with each image. If the shape of the images changes from one image to the next, you will also need to adjust the `extent` of the overlay; in that case, uncomment the `set_extent` line in the `update_mask` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352f6e5e-7b0b-44aa-bb19-da2bc6bca654",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# a new labeller for this example; `images` comes from the earlier cell\n",
    "labeller = image_labeller(images, classes=[\"good\", \"bad\", \"blarg\"])\n",
    "\n",
    "# cutoff used by gen_mask, and also passed as the overlay's vmax below\n",
    "mask_threshold = 0.6\n",
    "\n",
    "from numpy.random import default_rng\n",
    "\n",
    "\n",
    "def gen_mask(idx, image):\n",
    "    \"\"\"Return a boolean mask with the same shape as ``image``.\n",
    "\n",
    "    The mask here is random, but you could base it on your data instead.\n",
    "    The generator is seeded with ``idx``, so a given index always yields\n",
    "    the same mask.\n",
    "    \"\"\"\n",
    "    uniform_values = default_rng(idx).random(image.shape)\n",
    "    return uniform_values > mask_threshold\n",
    "\n",
    "\n",
    "# Draw the mask for the first image on the labeller's axes.\n",
    "# True pixels (1.0) lie above vmax, so they are drawn with the colormap's\n",
    "# \"over\" color instead of a shade of gray.\n",
    "overlay = labeller.ax.imshow(\n",
    "    gen_mask(0, images[0]), cmap=\"gray\", vmin=0, vmax=mask_threshold, alpha=0.75\n",
    ")\n",
    "# Modify a copy of the colormap: make its \"over\" color fully transparent so\n",
    "# the True pixels let the underlying image show through.\n",
    "cmap = overlay.cmap.copy()\n",
    "cmap.set_over(alpha=0)\n",
    "overlay.set_cmap(cmap)\n",
    "\n",
    "\n",
    "def update_mask(idx, image):\n",
    "    \"\"\"Replace the overlay's data with the mask generated for ``idx`` and ``image``.\"\"\"\n",
    "    new_mask = gen_mask(idx, image)\n",
    "    overlay.set_data(new_mask)\n",
    "\n",
    "    # if your image is changing shape uncomment the next line\n",
    "    # overlay.set_extent((-0.5, new_mask.shape[1] - 0.5, new_mask.shape[0] - 0.5, -0.5))\n",
    "\n",
    "\n",
    "# refresh the overlay whenever the displayed image changes\n",
    "labeller.on_image_changed(update_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e35b8-5b0e-442e-ae91-910d4ff04017",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
