Simple BCa Example
FileCopyrightText: Copyright (C) 2022 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2022-2023 Frank C Langbein frank@langbein.org, Cardiff University
License-Identifier: AGPL-3.0-or-later
This is a simple example of how to use the BCa framework, for demonstration only. For details, see also the BCa API documentation.
We start by setting up the Jupyter notebook (autoloading, so that updated modules are reloaded, etc.) and some style options, which may be useful for display.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>
Dataset Setup
This defines the dataset source used for training.
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# This has been registered as a sub-module.
from bca.bca.dataset import Dataset
# Bounding-box cropped dataset; using a small version of the full dataset for efficiency / this is just a demo
ds = Dataset(os.path.join("data","brats2020","MICCAI_BraTS2020_Small"), # Dataset folder - brought in via a git submodule (not public, so replace as needed for repeats)
os.path.join("results","brats2020","MICCAI_BraTS2020_Small")) # Results folder
ds.crop_to_bb() # Crop to axis-aligned bounding box
# For a fixed crop, use instead:
#ds.crop((56,184), (56,184), (13,141))
# Show bounding-box cropped dataset info and slices.
print(ds)
ds.browse()
Dataset: data/brats2020/MICCAI_BraTS2020_Small [34 patients]
Cache: results/brats2020/MICCAI_BraTS2020_Small
Channels (patient BraTS20_Training_001):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: bounding-box
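To make "crop to axis-aligned bounding box" concrete, here is a minimal NumPy sketch, assuming the box is the smallest axis-aligned box that encloses every non-zero voxel over all channels of a volume. `bounding_box` and `crop` are hypothetical helpers, not part of bca, and `ds.crop_to_bb()` may compute its box differently (for example over all patients at once rather than per volume).

```python
import numpy as np


def bounding_box(volumes):
    """Return per-axis (start, stop) of the smallest axis-aligned box containing
    every non-zero voxel across all given channel volumes (same shape assumed).
    Assumes at least one non-zero voxel exists."""
    mask = np.zeros(volumes[0].shape, dtype=bool)
    for v in volumes:
        mask |= v != 0
    box = []
    for axis in range(mask.ndim):
        # Collapse all other axes to see which indices along `axis` contain data
        other = tuple(a for a in range(mask.ndim) if a != axis)
        idx = np.flatnonzero(mask.any(axis=other))
        box.append((int(idx[0]), int(idx[-1]) + 1))
    return box


def crop(volume, box):
    """Cut `volume` down to the (start, stop) ranges in `box`."""
    return volume[tuple(slice(lo, hi) for lo, hi in box)]


# Toy example: a 10x10x10 volume with a non-zero block at [2:5, 3:7, 4:6]
vol = np.zeros((10, 10, 10), dtype=np.int16)
vol[2:5, 3:7, 4:6] = 1
box = bounding_box([vol])
print(box)                   # [(2, 5), (3, 7), (4, 6)]
print(crop(vol, box).shape)  # (3, 4, 2)
```

Taking the union over all channels keeps any structure visible in only one modality inside the crop.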
Setup Model
We use a standard 3D UNet defined in the unet package. UNet3D is a class that generates the model, not the model itself; use it as an example of how to define your own model classes.
from bca.bca.unet import UNet3D
from bca.bca.trainer import Trainer, dsc_loss, dsc, iou
import tensorflow.keras as keras
model = UNet3D(name="UNet3D_dice",
enc=[{"filters": 16},
{"filters": 32},
{"filters": 64},
{"filters":128, "kernel_regularizer":keras.regularizers.l2(0.02)},
{"filters":256, "kernel_regularizer":keras.regularizers.l2(0.02), "max_pooling":None}],
dec=[{"filters":128},
{"filters": 64},
{"filters": 32, "kernel_regularizer":keras.regularizers.l2(0.02)},
{"filters": 16, "kernel_regularizer":keras.regularizers.l2(0.02)}],
loss=dsc_loss,
metrics=[dsc,iou])
# Plot model structure for single output map; we need to fix an input shape to be able to do this.
# First part is shape of input, last entry in tuple is number of classes.
model.plot((128,128,128,4,3))
Segmentation Models: using `tf.keras` framework.
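The model is trained with `dsc_loss` and monitored with `dsc` and `iou`. As a rough illustration of what these overlap measures compute, here is a NumPy sketch of the standard Dice similarity coefficient $2|A \cap B|/(|A|+|B|)$ and the Jaccard index (IoU) $|A \cap B|/|A \cup B|$. The names `dice_np` and `iou_np` and the smoothing constant `eps` are assumptions for illustration; bca's own implementations may differ. In particular, the reported training loss in the results is around 15, so the loss evidently contains more than a plain $1 - \mathrm{DSC}$ term (plausibly the L2 kernel-regularisation penalties, which Keras adds to the training loss).

```python
import numpy as np


def dice_np(y_true, y_pred, eps=1e-7):
    """Dice similarity coefficient 2|A n B| / (|A| + |B|) for binary or soft masks."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + eps) / (np.sum(y_true) + np.sum(y_pred) + eps)


def iou_np(y_true, y_pred, eps=1e-7):
    """Jaccard index |A n B| / |A u B|, with the union written as |A| + |B| - |A n B|."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + eps) / (union + eps)


a = np.array([1, 1, 1, 0])  # ground truth: 3 foreground voxels
b = np.array([1, 1, 0, 0])  # prediction: 2 foreground voxels, both correct
print(dice_np(a, b))  # approximately 0.8 (= 4 / 5)
print(iou_np(a, b))   # approximately 0.667 (= 2 / 3)
```

For binary masks the two are monotonically related, $\mathrm{DSC} = 2\,\mathrm{IoU}/(1+\mathrm{IoU})$, which is why the DSC and IoU columns in the results rank the folds in the same order.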
Training
We use 5-fold cross-validation to explore model performance for a specific input/output combination. The generated sequence data (cropped, scaled to 128x128x128, normalised to $[0,1]$) is not stored in the repo and will be recreated when the notebook is executed.
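Two of the preprocessing steps can be sketched in a few lines of NumPy: min-max normalisation of an input channel to $[0,1]$ via $(x - \min x)/(\max x - \min x)$, and merging segmentation labels into a single 0/1 bitmap, as the `seg+1+2+4` output specifier does. This is a sketch under assumptions: `norm_minmax` and `merge_labels` here are hypothetical stand-ins, `Dataset.norm_minmax` may normalise per volume, per channel or with special background handling, and the resize to 128x128x128 is not shown.

```python
import numpy as np


def norm_minmax(volume):
    """Linearly rescale intensities to [0, 1]: (x - min) / (max - min)."""
    volume = volume.astype(np.float32)
    vmin, vmax = volume.min(), volume.max()
    if vmax == vmin:  # constant volume: avoid division by zero
        return np.zeros_like(volume)
    return (volume - vmin) / (vmax - vmin)


def merge_labels(seg, labels):
    """Map every voxel whose label is in `labels` to 1 and everything else to 0."""
    return np.isin(seg, labels).astype(np.uint8)


flair = np.array([[0, 50], [100, 200]], dtype=np.int16)
print(norm_minmax(flair))  # values 0.0, 0.25, 0.5, 1.0

seg = np.array([[0, 1], [2, 4]], dtype=np.uint8)
print(merge_labels(seg, [1, 2, 4]))  # background 0 stays 0, labels 1, 2, 4 become 1
```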
# Parameters determining the dataset sequences
K = 5 # 5-fold cross validation
EPOCHS = 100 # 100 epochs
SEED = 8120341116777169704 # Seed for dataset 5-fold split
BATCH_SIZE = 4 # Batch size for training
# Generate a list of K keras (train,test)-pair sequences for dataset splits
# seg+1+2+4 combines labels 1, 2, 4 into a single 0/1 bitmap;
# seg=1=2 would generate a map with values 1 and 2 (and 0 for background)
seqs = ds.sequences(K, (128,128,128), ["flair"], ["seg+1+2+4"],
batch_size=BATCH_SIZE, pre_proc=Dataset.norm_minmax, seed=SEED)
# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
# Train the networks
results = trainer.train(seqs, jit_compile=True, remote=False)
* Fold 1: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-0
* Fold 2: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-1
* Fold 3: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-2
* Fold 4: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-3
* Fold 5: done - results/brats2020/MICCAI_BraTS2020_Small/bb-128_128_128-norm_minmax/UNet3D_dice-flair-seg+1+2+4/100-4/8120341116777169704-5-4
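The result paths above encode the seed, K and the fold index (`8120341116777169704-5-0` to `-5-4`), so a fixed seed makes the split reproducible. The idea behind a seeded K-fold split can be sketched as follows: shuffle the patient indices with a seeded generator, cut them into K nearly equal folds, and use each fold once as the test set. `kfold_splits` is a hypothetical helper; it does not reproduce bca's actual folds, whose RNG and splitting details are not shown here.

```python
import numpy as np


def kfold_splits(n_items, k, seed):
    """Shuffle indices 0..n_items-1 with a fixed seed, cut them into k nearly
    equal folds, and return one (train, test) index pair per fold."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_items)
    folds = np.array_split(order, k)
    splits = []
    for i in range(k):
        test = folds[i]
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        splits.append((train, test))
    return splits


# 34 patients, K = 5, same seed value as in the notebook
splits = kfold_splits(34, 5, seed=8120341116777169704)
for i, (train, test) in enumerate(splits, 1):
    print(f"Fold {i}: {len(train)} train / {len(test)} test")
```

With 34 patients, `np.array_split` yields test folds of sizes 7, 7, 7, 7 and 6, so every patient is tested exactly once across the five folds.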
For remote execution, set remote=True in the train call; the following code can then be used to schedule the tasks. Note that this requires the remote schedulers to be defined in cfg.json for the bca package.
#from bca.bca.scheduler import schedule, schedule_clean
#
#task_folder="results"
#schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming the issues have been fixed)
#schedule(task_folder=task_folder) # Schedule tasks
Evaluation
This evaluates the trained models locally. Only models that have completed training and have not yet been evaluated are processed; everything else is skipped.
# Evaluate model, including standardising the evaluation as different models may have different outputs
import numpy as np
def std_eval(P, Y):
    # Convert prediction and trained output to standardised output
    # Output: dictionary mapping names to (prediction, expected) pairs from a single (prediction, expected) pair
    # For BraTS2020:
    #   whole      1,2,4
    #   necrotic   1
    #   enhancing  4
    #   edema      2
    #   core       1,4
    # Nothing to do here, as P and Y are already seg+1+2+4. This is just for demonstration, as the result should
    # be the same as the direct output evaluation of the network. Note that P and Y are lists of outputs, even if
    # the network only has one output, so we return the first list element (which in this case is a single sample).
    # Also note that we threshold the prediction at 0.5 to obtain a bitmap for the standard measure; np.where keeps
    # the shape of P[0], so the required sample and channel indices are retained.
    p = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return { 'whole': [p, Y[0]] }
trainer.eval(seqs, std_eval=std_eval)
trainer.plot_model(seqs, save_only=True)
* Fold 1
* Fold 2
* Fold 3
* Fold 4
* Fold 5
Evaluation complete.
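For a network that predicts the full BraTS label map rather than the merged `seg+1+2+4` bitmap, `std_eval` would need to build one bitmap pair per region listed in its comments (whole 1,2,4; necrotic 1; enhancing 4; edema 2; core 1,4). The sketch below assumes a hypothetical network whose single output `P[0]` is already a discrete label map with values in {0, 1, 2, 4} (for example after an argmax and mapping class indices back to BraTS labels); `std_eval_labels` and `REGIONS` are illustrative names, not bca API, and only the return convention follows `std_eval` above.

```python
import numpy as np

# BraTS 2020 label convention, as listed in the std_eval comments:
# 1 = necrotic, 2 = edema, 4 = enhancing
REGIONS = {
    "whole": (1, 2, 4),
    "core": (1, 4),
    "enhancing": (4,),
    "necrotic": (1,),
    "edema": (2,),
}


def std_eval_labels(P, Y):
    """Hypothetical std_eval for a network whose single output is a label map:
    return one [prediction, expected] bitmap pair per BraTS region."""
    p, y = P[0], Y[0]
    return {name: [np.isin(p, labels).astype(np.float32),
                   np.isin(y, labels).astype(np.float32)]
            for name, labels in REGIONS.items()}


P = [np.array([[0, 1], [2, 4]])]  # toy predicted label map
Y = [np.array([[0, 1], [2, 2]])]  # toy expected label map
out = std_eval_labels(P, Y)
print(out["whole"][0])      # predicted whole tumour bitmap
print(out["enhancing"][1])  # no enhancing voxels expected
```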
Results
This presents the results of the training and evaluation.
Note that in the table below the "prime" values (DSC', val_DSC', etc.) are per-sample metrics (computed for each sample and then averaged), while the others are the metrics used during training (computed per batch and then averaged).
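Why per-batch and per-sample values can differ is easiest to see on a toy batch. The sketch assumes that a per-batch Dice pools all voxels of the batch into one overlap ratio, whereas the per-sample variant computes one Dice per sample and then averages; this is a plausible reading of the note above, not a statement of how bca's `dsc` is implemented. A large, well-segmented tumour then dominates the pooled value, while a small missed tumour drags the per-sample mean down.

```python
import numpy as np


def dice(y_true, y_pred, eps=1e-7):
    """Dice coefficient over all voxels of the given arrays."""
    inter = np.sum(y_true * y_pred)
    return (2.0 * inter + eps) / (np.sum(y_true) + np.sum(y_pred) + eps)


# Toy batch of two 10-voxel "samples"
y_true = np.zeros((2, 10))
y_pred = np.zeros((2, 10))
y_true[0, :8] = 1
y_pred[0, :8] = 1  # sample 0: 8 tumour voxels, perfectly predicted
y_true[1, :2] = 1  # sample 1: 2 tumour voxels, nothing predicted

batch_dsc = dice(y_true, y_pred)  # pooled over the whole batch
sample_dsc = np.mean([dice(t, p) for t, p in zip(y_true, y_pred)])
print(f"per-batch DSC:  {batch_dsc:.3f}")   # 0.889 (= 16 / 18)
print(f"per-sample DSC: {sample_dsc:.3f}")  # 0.500 (mean of 1 and ~0)
```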
import numpy as np
import pandas as pd
from IPython import display
pd.set_option('display.max_columns', None)
results = trainer.plot_results(seqs)
| | DSC | val_DSC | IoU | val_IoU | loss | val_loss | DSC' | val_DSC' | IoU' | val_IoU' | STD-whole-DSC | val_STD-whole-DSC | STD-whole-IoU | val_STD-whole-IoU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Fold 1 | 0.358000 | 0.336694 | 0.219870 | 0.202495 | 15.421636 | 15.440179 | 0.352439 | 0.339223 | 0.226080 | 0.215503 | 0.355217 | 0.342502 | 0.228415 | 0.218075 |
| Fold 2 | 0.519200 | 0.505013 | 0.356384 | 0.341096 | 14.569013 | 14.577881 | 0.500532 | 0.508442 | 0.351002 | 0.353505 | 0.503951 | 0.513441 | 0.354301 | 0.358231 |
| Fold 3 | 0.272825 | 0.228741 | 0.159403 | 0.129298 | 15.085237 | 15.128102 | 0.279333 | 0.229827 | 0.170907 | 0.136390 | 0.280663 | 0.230238 | 0.171931 | 0.136675 |
| Fold 4 | 0.503390 | 0.408319 | 0.340219 | 0.264984 | 14.917871 | 14.997222 | 0.492676 | 0.422088 | 0.341240 | 0.287731 | 0.498517 | 0.428275 | 0.346619 | 0.293706 |
| Fold 5 | 0.145744 | 0.221021 | 0.079355 | 0.124717 | 15.223004 | 15.135514 | 0.147070 | 0.234069 | 0.081401 | 0.134032 | 0.147921 | 0.235105 | 0.081922 | 0.134710 |
| Mean | 0.359832 | 0.339957 | 0.231046 | 0.212518 | 15.043352 | 15.055780 | 0.354410 | 0.346730 | 0.234126 | 0.225432 | 0.357254 | 0.349912 | 0.236638 | 0.228280 |
| Std | 0.141007 | 0.108113 | 0.105732 | 0.082486 | 0.289092 | 0.279702 | 0.133491 | 0.107931 | 0.102489 | 0.085633 | 0.135030 | 0.109945 | 0.104050 | 0.087657 |
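The Mean and Std rows can be recomputed from the five fold rows. For the DSC column, the reported values match the plain mean and the population standard deviation (NumPy's default `ddof=0`), not the sample standard deviation (`ddof=1`), which is worth keeping in mind when comparing against results reported with the other convention.

```python
import numpy as np

# Per-fold DSC values copied from the table above
fold_dsc = np.array([0.358000, 0.519200, 0.272825, 0.503390, 0.145744])

mean = fold_dsc.mean()
std_pop = fold_dsc.std()            # ddof=0 (population), NumPy's default
std_sample = fold_dsc.std(ddof=1)   # ddof=1 (sample), for comparison

print(f"mean={mean:.6f}  std(ddof=0)={std_pop:.6f}  std(ddof=1)={std_sample:.6f}")
```

The `ddof=0` value agrees with the Std row (0.141007) to within rounding, while `ddof=1` gives roughly 0.158.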
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run cell again; only one can run at the same time, unless there are
# enough GPU resources).
trainer.browse_predict(seqs)