Simple BCa Example¶

FileCopyrightText: Copyright (C) 2022 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2022-2023 Frank C Langbein frank@langbein.org, Cardiff University

License-Identifier: AGPL-3.0-or-later

This is a simple example of how to use the BCa framework, intended for demonstration only. For details, see the BCa API documentation.

We start by setting up the Jupyter notebook (automatically reloading updated modules, etc.) and some style options that may be useful for display.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
In [2]:
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>

Dataset Setup¶

This defines the dataset source used for training.

In [3]:
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# This has been registered as a sub-module.
from bca.bca.dataset import Dataset

# Bounding-box cropped dataset; using a small version of the full dataset for efficiency / this is just a demo
ds = Dataset(os.path.join("data","brats2020","MICCAI_BraTS2020_Small"),    # Dataset folder - brought in via a git submodule (not public, so replace as needed for repeats)
             os.path.join("results","brats2020","MICCAI_BraTS2020_Small")) # Results folder
ds.crop_to_bb() # Crop to axis-aligned bounding box
# For a fixed crop instead, use
#ds.crop((56,184), (56,184), (13,141))

# Show bounding-box cropped dataset info and slices.
print(ds)
ds.browse()
# Dataset: data/brats2020/MICCAI_BraTS2020_Small [34 patients]
Cache: results/brats2020/MICCAI_BraTS2020_Small
Channels (patient BraTS20_Training_001):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: bounding-box

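
As a rough illustration of what cropping to an axis-aligned bounding box involves, the following NumPy sketch finds the smallest box containing all non-zero voxels and slices the volume to it. This is not the BCa implementation: `bounding_box` and `crop_to_box` are hypothetical helper names, and how `crop_to_bb` actually chooses its box (per patient or over the whole dataset, and from which channels) is not shown in this notebook.

```python
import numpy as np

def bounding_box(volume):
    """Return per-axis (start, stop) ranges enclosing all non-zero voxels."""
    nonzero = np.argwhere(volume != 0)      # (n_voxels, ndim) array of indices
    if nonzero.size == 0:
        raise ValueError("volume has no non-zero voxels")
    starts = nonzero.min(axis=0)
    stops = nonzero.max(axis=0) + 1         # exclusive upper bound
    return [(int(a), int(b)) for a, b in zip(starts, stops)]

def crop_to_box(volume, box):
    """Slice the volume to the given per-axis (start, stop) ranges."""
    return volume[tuple(slice(a, b) for a, b in box)]

# Toy volume with the BraTS shape; a filled block stands in for the brain.
vol = np.zeros((240, 240, 155), dtype=np.int16)
vol[56:184, 56:184, 13:141] = 1

box = bounding_box(vol)
cropped = crop_to_box(vol, box)
print(box, cropped.shape)
```

The same `crop_to_box` helper also covers a fixed crop: pass the ranges directly instead of computing them from the data.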

Setup Model¶

We use a standard 3D UNet defined in the unet package. This is a class that generates the model, not the model itself. Use it as an example of how to define your own model classes.

In [4]:
from bca.bca.unet import UNet3D
from bca.bca.trainer import Trainer, dsc_loss, dsc, iou
import tensorflow.keras as keras

model =  UNet3D(name="UNet3D_dice",
                enc=[{"filters": 16},
                     {"filters": 32},
                     {"filters": 64},
                     {"filters":128, "kernel_regularizer":keras.regularizers.l2(0.02)},
                     {"filters":256, "kernel_regularizer":keras.regularizers.l2(0.02), "max_pooling":None}],
                dec=[{"filters":128},
                     {"filters": 64},
                     {"filters": 32, "kernel_regularizer":keras.regularizers.l2(0.02)},
                     {"filters": 16, "kernel_regularizer":keras.regularizers.l2(0.02)}],
                loss=dsc_loss,
                metrics=[dsc,iou])
# Plot model structure for single output map; we need to fix an input shape to be able to do this.
# First part is shape of input, last entry in tuple is number of classes.
model.plot((128,128,128,4,3))
Segmentation Models: using `tf.keras` framework.
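
To get a feel for the encoder configuration above, the sketch below traces the spatial size and channel count per encoder level for a 128x128x128 input. It rests on assumptions that are typical for U-Nets but are not confirmed by this notebook: convolutions preserve the spatial size, and every level except the last (the one with `"max_pooling": None`) is followed by a 2x2x2 max pooling. `trace_encoder` is a hypothetical helper, not part of BCa.

```python
def trace_encoder(spatial, filters, pooled):
    """Return (spatial shape, channels) at each encoder level.

    Assumes size-preserving convolutions and, where pooled, a pooling
    step that halves each spatial dimension before the next level.
    """
    levels = []
    for n_filters, do_pool in zip(filters, pooled):
        levels.append((spatial, n_filters))
        if do_pool:
            spatial = tuple(s // 2 for s in spatial)  # assumed 2x2x2 pooling
    return levels

filters = [16, 32, 64, 128, 256]           # "filters" entries of enc above
pooled = [True, True, True, True, False]   # last level has max_pooling=None

levels = trace_encoder((128, 128, 128), filters, pooled)
for shape, channels in levels:
    print(shape, channels)
```

Under these assumptions the bottleneck operates on an 8x8x8 grid with 256 channels, and the four decoder levels bring the resolution back up to the input size.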

Training¶

We use 5-fold cross-validation to explore model performance for a specific input/output combination. The generated sequence data (cropped, scaled to 128x128x128, normalised to $[0,1]$) is not stored in the repository and is recreated when the notebook is executed.
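
For reference, min-max normalisation to $[0,1]$ can be sketched in a few lines of NumPy. This is only an illustration of the idea: `minmax_normalise` is a hypothetical name, and whether `Dataset.norm_minmax` normalises per volume, per channel, or differently is not shown here.

```python
import numpy as np

def minmax_normalise(volume, eps=1e-8):
    """Scale a volume linearly so its minimum maps to 0 and its maximum to 1."""
    v = volume.astype(np.float32)
    lo, hi = v.min(), v.max()
    # eps guards against division by zero for constant volumes
    return (v - lo) / max(float(hi - lo), eps)

rng = np.random.default_rng(0)
vol = rng.integers(0, 2000, size=(8, 8, 8)).astype(np.int16)
norm = minmax_normalise(vol)
print(norm.dtype, norm.min(), norm.max())
```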

In [5]:
# Parameters determining the dataset sequences
K = 5                      # 5-fold cross validation
EPOCHS = 100               # 100 epochs
SEED = 8120341116777169704 # Seed for dataset 5-fold split
BATCH_SIZE = 4             # Batch size for training

# Generate a list of K keras (train,test)-pair sequences for dataset splits
# seg+1+2+4 combines labels 1,2,4 into a single 0/1 bitmap;
# seg=1=2 would generate a map with values 1 and 2 (and 0 for background)
seqs = ds.sequences(K, (128,128,128), ["flair"], ["seg+1+2+4"],
                    batch_size=BATCH_SIZE, pre_proc=Dataset.norm_minmax, seed=SEED)

# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)

# Train the networks
results = trainer.train(seqs, jit_compile=True, remote=False)
* Fold 1: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-0
* Fold 2: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-1
* Fold 3: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-2
* Fold 4: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-3
* Fold 5: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-4
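
The seeded K-fold split behind these runs can be pictured with the following sketch. It shows one common way to build reproducible folds with NumPy and is an assumption about the general approach only; the random generator and splitting rule used by `ds.sequences` are not shown in this notebook, and `kfold_indices` is a hypothetical helper.

```python
import numpy as np

def kfold_indices(n_samples, k, seed):
    """Return k (train, test) index pairs from one seeded shuffle."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    folds = np.array_split(order, k)        # fold sizes differ by at most one
    splits = []
    for i in range(k):
        test = folds[i]
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        splits.append((train, test))
    return splits

splits = kfold_indices(34, 5, seed=8120341116777169704)
for train, test in splits:
    print(len(train), len(test))
```

Because the shuffle depends only on the seed, re-running with the same seed yields the same folds, which is what allows results to be cached per fold as in the paths above.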

For remote execution, set remote=True in the call to trainer.train; the following code can then be used to schedule the tasks. Note that this requires the remote schedulers to be defined in cfg.json for the bca package.

In [6]:
#from bca.bca.scheduler import schedule, schedule_clean
#
#task_folder="results"
#schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming issues have been fixed)
#schedule(task_folder=task_folder)       # Schedule tasks

Evaluation¶

This runs the evaluation locally for every completely trained model that has not yet been evaluated. Models that are not completely trained are ignored.

In [7]:
# Evaluate model, including standardising the evaluation as different models may have different outputs

import numpy as np

def std_eval(P, Y):
  # Convert prediction and trained output to standardised output
  # Output: dictionary mapping region names to (prediction, expected) pairs, built from a single (prediction, expected) pair
  # For BraTS2020:
  #   whole     1,2,4
  #   necrotic  1
  #   enhancing 4
  #   edema     2
  #   core      1,4
  # Nothing to do here, as P, Y is seg+1+2+4. Just for demo, as these should be the same as the direct output
  # evaluation of the network. Note that P and Y are lists of outputs, even if the network only has one output,
  # so we return the first list element (which in this case is a single sample).
  # Also note that we apply a threshold to the prediction for the standard measure to obtain a bitmap and expand
  # the dimensions to have the required sample and channel index
  p = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
  return { 'whole': [p, Y[0]] }

trainer.eval(seqs, std_eval=std_eval)
trainer.plot_model(seqs, save_only=True)
* Fold 1
* Fold 2
* Fold 3
* Fold 4
* Fold 5
Evaluation complete.
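
For models trained on the full label map, a standardised evaluation would need to derive the regions listed in the comments of `std_eval` from the raw BraTS labels. A minimal sketch of that mapping, using only NumPy, might look as follows; `REGIONS` and `region_masks` are hypothetical names, and the sketch is independent of how BCa builds targets such as `seg+1+2+4` internally.

```python
import numpy as np

# Region definitions as listed in the std_eval comments (BraTS2020 labels)
REGIONS = {
    "whole": (1, 2, 4),
    "necrotic": (1,),
    "enhancing": (4,),
    "edema": (2,),
    "core": (1, 4),
}

def region_masks(seg):
    """Map a label volume to one 0/1 float mask per region."""
    return {name: np.isin(seg, labels).astype(np.float32)
            for name, labels in REGIONS.items()}

seg = np.array([[0, 1, 2],
                [4, 4, 0]], dtype=np.uint8)
masks = region_masks(seg)
print(masks["whole"])
print(masks["core"])
```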

Results¶

This presents the results of the training/evaluation.

Note that the "prime" values below (DSC', val_DSC', etc.) are per-sample metrics that are then aggregated (averaged, etc.), while the others are the metrics used for training, computed per batch and then aggregated in the same way.
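
The difference between per-sample and per-batch values can be seen on a toy batch. The sketch below uses the standard definitions of the Dice coefficient and IoU on binary masks; the exact smoothing terms and reductions in BCa's `dsc` and `iou` are not shown in this notebook, so `dice_coefficient` and `iou_score` are hypothetical stand-ins.

```python
import numpy as np

def dice_coefficient(pred, target, eps=1e-7):
    """Dice = 2|A∩B| / (|A| + |B|) over all elements passed in."""
    inter = np.sum(pred * target)
    return (2.0 * inter + eps) / (np.sum(pred) + np.sum(target) + eps)

def iou_score(pred, target, eps=1e-7):
    """IoU = |A∩B| / |A∪B| over all elements passed in."""
    inter = np.sum(pred * target)
    union = np.sum(pred) + np.sum(target) - inter
    return (inter + eps) / (union + eps)

# Batch of two flattened binary samples
pred = np.array([[1, 1, 0, 0],
                 [1, 0, 0, 0]], dtype=np.float32)
target = np.array([[1, 0, 0, 0],
                   [1, 1, 1, 1]], dtype=np.float32)

# Per-sample: score each sample, then average
per_sample = float(np.mean([dice_coefficient(p, t) for p, t in zip(pred, target)]))
# Per-batch: pool all voxels of the batch into one score
per_batch = float(dice_coefficient(pred, target))
print(per_sample, per_batch)
```

The two numbers differ because pooling weights samples by their mask size, whereas per-sample averaging weights every sample equally.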

In [8]:
import numpy as np
import pandas as pd
from IPython import display

pd.set_option('display.max_columns', None)

results = trainer.plot_results(seqs)

| | DSC | val_DSC | IoU | val_IoU | loss | val_loss | DSC' | val_DSC' | IoU' | val_IoU' | STD-whole-DSC | val_STD-whole-DSC | STD-whole-IoU | val_STD-whole-IoU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Fold 1 | 0.358000 | 0.336694 | 0.219870 | 0.202495 | 15.421636 | 15.440179 | 0.352439 | 0.339223 | 0.226080 | 0.215503 | 0.355217 | 0.342502 | 0.228415 | 0.218075 |
| Fold 2 | 0.519200 | 0.505013 | 0.356384 | 0.341096 | 14.569013 | 14.577881 | 0.500532 | 0.508442 | 0.351002 | 0.353505 | 0.503951 | 0.513441 | 0.354301 | 0.358231 |
| Fold 3 | 0.272825 | 0.228741 | 0.159403 | 0.129298 | 15.085237 | 15.128102 | 0.279333 | 0.229827 | 0.170907 | 0.136390 | 0.280663 | 0.230238 | 0.171931 | 0.136675 |
| Fold 4 | 0.503390 | 0.408319 | 0.340219 | 0.264984 | 14.917871 | 14.997222 | 0.492676 | 0.422088 | 0.341240 | 0.287731 | 0.498517 | 0.428275 | 0.346619 | 0.293706 |
| Fold 5 | 0.145744 | 0.221021 | 0.079355 | 0.124717 | 15.223004 | 15.135514 | 0.147070 | 0.234069 | 0.081401 | 0.134032 | 0.147921 | 0.235105 | 0.081922 | 0.134710 |
| Mean | 0.359832 | 0.339957 | 0.231046 | 0.212518 | 15.043352 | 15.055780 | 0.354410 | 0.346730 | 0.234126 | 0.225432 | 0.357254 | 0.349912 | 0.236638 | 0.228280 |
| Std | 0.141007 | 0.108113 | 0.105732 | 0.082486 | 0.289092 | 0.279702 | 0.133491 | 0.107931 | 0.102489 | 0.085633 | 0.135030 | 0.109945 | 0.104050 | 0.087657 |

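
The Mean and Std rows can be checked directly from the per-fold values. The sketch below does this for the DSC column with NumPy, using the values copied from the results above. The Std row appears to be the population standard deviation (`ddof=0`, NumPy's default) rather than the sample standard deviation.

```python
import numpy as np

# DSC values for folds 1 to 5, copied from the results above
dsc_per_fold = np.array([0.358000, 0.519200, 0.272825, 0.503390, 0.145744])

mean = dsc_per_fold.mean()
std_population = dsc_per_fold.std()        # ddof=0, matches the Std row
std_sample = dsc_per_fold.std(ddof=1)      # larger; does not match the Std row

print(f"mean={mean:.6f} std(ddof=0)={std_population:.6f} std(ddof=1)={std_sample:.6f}")
```

With only five folds the two estimates differ noticeably, so it is worth stating which one is reported when comparing against other results.
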
In [9]:
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run the cell again; only one instance can run at a time, unless there
# are enough GPU resources).
trainer.browse_predict(seqs) 