BCa Segmentation
LATUP-Net SE, WDL, BG-WT-TC-ET, BraTS2020
FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2023-2024 Frank C Langbein frank@langbein.org, Cardiff University
License-Identifier: AGPL-3.0-or-later
Dataset: BraTS2020, fixed crop, >=1% lesions
Architecture: LATUP-Net with SE attention
Loss: weighted Dice loss, ENet weights
Output channels: BG, WT, TC, ET
We start by setting up the Jupyter notebook (autoloading and reloading updated modules, etc.) and some style options that may be useful for display.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>
Dataset Setup
This defines the dataset source, BraTS2020, used for training.
The BraTS2020 dataset must be in the data folder specified below. See README.md for how to set this up.
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# and the BraTS2020 dataset is available in the folders below (see README.md).
# These have been registered as git submodules (but the BraTS dataset is not
# distributed by us).
from bca.bca.dataset import Dataset
# Fixed-crop BraTS2020 dataset
ds = Dataset(os.path.join("data","brats2020","MICCAI_BraTS2020_TrainingData"), # Dataset folder - brought in via a git submodule
# (not public, so replace as needed)
os.path.join("results","brats2020","MICCAI_BraTS2020_TrainingData")) # Results folder
ds.crop((56,184), (56,184), (13,141)) # Fixed centered crop
ds.filter_low_labels("seg",0.01) # Remove patients with non-background segmentation area less than 1%
# Show bounding-box cropped dataset info and slices.
print(ds)
ds.browse()
# Dataset: data/brats2020/MICCAI_BraTS2020_TrainingData [344 patients]
Cache: results/brats2020/MICCAI_BraTS2020_TrainingData
Channels (patient BraTS20_Training_001):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: [(56, 184), (56, 184), (13, 141)]
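The fixed crop and the 1% lesion filter can be sketched with NumPy as follows. This is a minimal illustration and not the bca implementation. It assumes the crop ranges are half-open index ranges, so each axis keeps 184 - 56 = 141 - 13 = 128 voxels, which matches the 128x128x128 network input used later. It also assumes the lesion fraction is measured on the cropped seg volume. `crop_volume` and `keep_patient` are hypothetical helper names.

```python
import numpy as np

def crop_volume(vol, x, y, z):
    """Crop a 3D volume to half-open index ranges x=(x0, x1), y=(y0, y1), z=(z0, z1)."""
    return vol[x[0]:x[1], y[0]:y[1], z[0]:z[1]]

def keep_patient(seg, min_fraction=0.01):
    """True if the non-background (label != 0) share of seg is at least min_fraction."""
    return np.count_nonzero(seg) / seg.size >= min_fraction

# Synthetic 240x240x155 label volume standing in for a BraTS seg channel
seg = np.zeros((240, 240, 155), dtype=np.uint8)
seg[100:140, 100:140, 60:100] = 2   # 40x40x40 "lesion" inside the crop window

cropped = crop_volume(seg, (56, 184), (56, 184), (13, 141))
print(cropped.shape)          # (128, 128, 128)
print(keep_patient(cropped))  # 64000 / 2097152, about 3.1% of voxels -> True
```

A patient whose cropped label volume has fewer than 1% non-background voxels (for example a 10x10x10 lesion, about 0.05%) would be dropped under these assumptions.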
from collections import OrderedDict
from bca.bca.loss import compute_channel_weights
# Specify output channels of network in an OrderedDict.
# Order of keys determines output channel order.
# Each channel is given by a key (the channel name) mapped to
# a (list-of-labels, use-in-loss) pair.
output_channels = OrderedDict({
"BG": ([0],0),
"WT": ([1,2,4],1),
"TC": ([1,4],1),
"ET": ([4],1)
})
# ENet weights
weights = compute_channel_weights(ds, "seg", output_channels, mode="enet", enet_c=1.22, normalise=True)
for ch in output_channels:
    print(f"{ch} weight: {weights[ch][0]}, voxel count: {weights[ch][1]} for {output_channels[ch][0]}")
BG weight: 0.0, voxel count: 0 for [0]
WT weight: 0.21626338547830623, voxel count: 34921136 for [1, 2, 4]
TC weight: 0.33549474183473676, voxel count: 14718903 for [1, 4]
ET weight: 0.4482418726869569, voxel count: 6857408 for [4]
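The sketch below shows how the label lists map to channels and how ENet-style weights are computed. The channel masks assume that each channel is the union of its listed BraTS labels (for example, WT = labels 1, 2 and 4).

For the weights, we assume the ENet formula w = 1 / ln(c + p) with c = `enet_c`. Here p is a channel's share of the summed voxel counts of the loss channels, and the weights are then normalised to sum to 1. With the voxel counts printed above, this reproduces the printed WT, TC and ET weights to at least four decimal places. Even so, it is our reconstruction and not the documented behaviour of `compute_channel_weights`. `channel_mask` and `enet_weights` are hypothetical helpers.

```python
import numpy as np

# Same layout as output_channels above: name -> (labels, use-in-loss)
channels = {"BG": ([0], 0), "WT": ([1, 2, 4], 1), "TC": ([1, 4], 1), "ET": ([4], 1)}

def channel_mask(seg, labels):
    """Binary mask of the voxels whose label is in `labels` (e.g. WT = 1, 2, 4)."""
    return np.isin(seg, labels).astype(np.float32)

def enet_weights(counts, c=1.22):
    """ENet-style weights 1 / ln(c + p), p = channel share of the summed counts,
    normalised to sum to 1 (our reconstruction, not the bca code)."""
    total = sum(counts.values())
    raw = {ch: 1.0 / np.log(c + n / total) for ch, n in counts.items()}
    norm = sum(raw.values())
    return {ch: float(w / norm) for ch, w in raw.items()}

# Label-to-channel mapping on a tiny label array
seg = np.array([0, 1, 2, 4, 4])
for ch, (labels, _) in channels.items():
    print(ch, channel_mask(seg, labels))

# Voxel counts printed above for the loss channels (BG has use-in-loss 0)
counts = {"WT": 34921136, "TC": 14718903, "ET": 6857408}
weights_sketch = enet_weights(counts)
print({ch: round(w, 4) for ch, w in weights_sketch.items()})
# {'WT': 0.2163, 'TC': 0.3355, 'ET': 0.4482}
```

Channels with a smaller share p receive a larger weight. ET, the smallest region, is therefore weighted most heavily.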
LATUP-Net Model
We first set up the model and then train it.
from bca.bca.latupnet import LATUPNet
from bca.bca.trainer import Trainer
from bca.bca.loss import WeightedDiceLoss
from bca.bca.metric import sDSC
# Note: the model name determines where the model and results are stored,
# so it should differ between models if we want to preserve them.
model = LATUPNet(name="LATUPNet-SE_WDL_BG-WT-TC-ET",
attention="SE",
loss=WeightedDiceLoss([ (weights[ch][0],n)
for n, ch in enumerate(list(output_channels.keys()))
if output_channels[ch][1] ]),
metrics=lambda chs=list(output_channels.keys()): [sDSC(channel=n, name="dsc"+ch) for n, ch in enumerate(chs)]) # FIXME
# Plot model architecture (or show text summary if text is True)
model.plot((128,128,128,3,len(output_channels)),text=False) # 3D spatial dimensions, number of input channels (flair, t1ce, t2) and number of output channels
Segmentation Models: using `tf.keras` framework.
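`WeightedDiceLoss` receives (weight, channel index) pairs for the channels whose use-in-loss flag is set, that is, WT, TC and ET at indices 1 to 3. A common form for such a loss is the weighted sum of per-channel soft Dice losses, sketched below in NumPy. We have not checked this against the bca source. The exact formula (smoothing constant, reduction over the batch) is therefore an assumption, and `soft_dice` and `weighted_dice_loss` are hypothetical names.

```python
import numpy as np

def soft_dice(p, y, eps=1e-6):
    """Soft Dice coefficient between probabilities p and binary targets y."""
    return (2.0 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)

def weighted_dice_loss(P, Y, weight_index_pairs):
    """Sum of weight * (1 - soft Dice) over (weight, channel index) pairs.

    P and Y are channel-last arrays, shape (..., n_channels).
    """
    return sum(w * (1.0 - soft_dice(P[..., c], Y[..., c]))
               for w, c in weight_index_pairs)

# (weight, channel index) pairs as built in the model cell above: WT=1, TC=2, ET=3
pairs = [(0.2163, 1), (0.3355, 2), (0.4482, 3)]

rng = np.random.default_rng(0)
Y = (rng.random((8, 8, 8, 4)) > 0.5).astype(np.float32)
print(weighted_dice_loss(Y, Y, pairs))        # perfect prediction: about 0
print(weighted_dice_loss(1.0 - Y, Y, pairs))  # inverted prediction: about 1.0 (sum of weights)
```

Because the weights sum to 1, this form of the loss ranges from 0 (perfect overlap in every weighted channel) to about 1 (no overlap). The BG channel never contributes.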
# Parameters determining the dataset sequences
K = 5 # 5-fold cross validation
EPOCHS = 200 # 200 epochs
SEED = 8120341116777169704 # Seed for dataset 5-fold split
BATCH_SIZE = 1 # Batch size for training
# Generate a list of K keras (train,test)-pair sequences for dataset splits
# We use a 128x128x128 input image with flair, t1ce and t2 channels. The
# output channels were specified above and are reused here.
seqs = ds.sequences(K, (128,128,128),
["flair", "t1ce", "t2"],
["seg+"+"+".join([str(n) for n in output_channels[ch][0]]) for ch in output_channels.keys()],
batch_size=BATCH_SIZE,
pre_proc=Dataset.norm_minmax,
seed=SEED)
# Show training sequence of first fold
seqs[0][0].browse()
# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
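`ds.sequences` splits the patients into K folds using `SEED`. The sketch below shows a seeded K-fold split. It is not the bca code, and the permutation from NumPy's `default_rng` will not reproduce the library's actual fold assignment. It does show the expected fold sizes: 344 patients split into 5 folds gives test folds of 69, 69, 69, 69 and 68 patients, so the first fold trains on 275 patients. `kfold_splits` is a hypothetical helper.

```python
import numpy as np

def kfold_splits(n_items, k, seed):
    """Return k (train_idx, test_idx) pairs from one seeded permutation.

    Every item lands in exactly one test fold; fold sizes differ by at most 1.
    """
    perm = np.random.default_rng(seed).permutation(n_items)
    folds = np.array_split(perm, k)
    return [(np.concatenate([f for j, f in enumerate(folds) if j != i]), folds[i])
            for i in range(k)]

splits = kfold_splits(344, 5, 8120341116777169704)
for i, (train, test) in enumerate(splits, 1):
    print(f"Fold {i}: {len(train)} train, {len(test)} test")
```

Fixing the seed makes the split reproducible across runs. This matters here because each fold is trained as a separate (possibly remote) task and later evaluated against the same test patients.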
We are now ready to train the models.
For remote execution, set remote=True in the train call; the scheduling code below can then be used to schedule the tasks. Note that this requires the remote schedulers to be defined in cfg.json for the bca package.
# Train the networks
# Specifying loss,loss for the best-model monitor uses the test loss alone to select the stored best model
results = trainer.train(seqs, jit_compile=True, remote=True)
* Fold 1: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDL_BG-WT-TC-ET-flair_t1ce_t2-seg+0_seg+1+2+4_seg+1+4_seg+4/200-1/8120341116777169704-5-0
* Fold 2: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDL_BG-WT-TC-ET-flair_t1ce_t2-seg+0_seg+1+2+4_seg+1+4_seg+4/200-1/8120341116777169704-5-1
* Fold 3: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDL_BG-WT-TC-ET-flair_t1ce_t2-seg+0_seg+1+2+4_seg+1+4_seg+4/200-1/8120341116777169704-5-2
* Fold 4: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDL_BG-WT-TC-ET-flair_t1ce_t2-seg+0_seg+1+2+4_seg+1+4_seg+4/200-1/8120341116777169704-5-3
* Fold 5: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDL_BG-WT-TC-ET-flair_t1ce_t2-seg+0_seg+1+2+4_seg+1+4_seg+4/200-1/8120341116777169704-5-4
=> Done: 5; Failed: 0; Training: 0; Start: 0
# Schedule training tasks
from bca.bca.scheduler import schedule, schedule_clean
task_folder="results"
schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming the issues have been fixed)
schedule(task_folder=task_folder) # Schedule task
All tasks complete.
Evaluation
This evaluates the trained models locally. Only completely trained models that have not yet been evaluated are processed; everything else is ignored.
# Evaluate model, including standardising the evaluation, as different models may have different outputs
import numpy as np
# Output standardisation function
def std_eval_BG_WT_TC_ET(P, Y):
    # Convert the prediction and ground-truth output to standardised output.
    # Here we assume the output channels are BG, WT, TC, ET (BG is not reported).
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)  # Threshold probabilities at 0.5
    return {
        'WT': [P[0][...,1], Y[0][...,1]],
        'TC': [P[0][...,2], Y[0][...,2]],
        'ET': [P[0][...,3], Y[0][...,3]]
    }
# Compute performance metrics
from bca.bca.trainer import dsc, hd95, sensitivity, specificity
# Evaluate global and per-sample metrics using the standardised output.
# This is slow, as it collects overall, per-channel and standardised-channel
# performance metrics. Note: adjust std_eval to match output_channels (the
# function above, or a new one if the output channels differ).
trainer.eval(seqs, mode="best", fs=[dsc,hd95,sensitivity,specificity], std_eval=std_eval_BG_WT_TC_ET)
# Save model architecture
trainer.plot_model(seqs, save_only=True)
* Fold 1
* Fold 2
* Fold 3
* Fold 4
* Fold 5
Evaluation complete.
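For reference, DSC, sensitivity and specificity on binarised masks follow directly from the voxel confusion counts. The sketch below computes them and applies the same 0.5 threshold as `std_eval_BG_WT_TC_ET`. The names deliberately differ from the `dsc`, `sensitivity` and `specificity` imported from `bca.bca.trainer`, whose exact definitions (for example, their handling of empty masks) we have not checked. HD95 is omitted because it needs surface-distance computations.

```python
import numpy as np

def confusion(pred, truth):
    """Voxel counts (TP, FP, FN, TN) for two binary masks of the same shape."""
    pred, truth = np.asarray(pred, dtype=bool), np.asarray(truth, dtype=bool)
    tp = np.count_nonzero(pred & truth)
    fp = np.count_nonzero(pred & ~truth)
    fn = np.count_nonzero(~pred & truth)
    tn = np.count_nonzero(~pred & ~truth)
    return tp, fp, fn, tn

def dice_score(pred, truth):
    """DSC = 2TP / (2TP + FP + FN); two empty masks count as 1.0 (a convention)."""
    tp, fp, fn, _ = confusion(pred, truth)
    denom = 2 * tp + fp + fn
    return 1.0 if denom == 0 else 2 * tp / denom

def sensitivity_score(pred, truth):
    """Sensitivity (recall) = TP / (TP + FN)."""
    tp, _, fn, _ = confusion(pred, truth)
    return tp / (tp + fn)

def specificity_score(pred, truth):
    """Specificity = TN / (TN + FP)."""
    _, fp, _, tn = confusion(pred, truth)
    return tn / (tn + fp)

# Probabilities thresholded at 0.5, as in std_eval_BG_WT_TC_ET
probs = np.array([0.9, 0.7, 0.4, 0.1, 0.6, 0.2, 0.0, 0.3])
truth = np.array([1, 1, 1, 1, 0, 0, 0, 0])
pred = probs >= 0.5                     # [T, T, F, F, T, F, F, F]

print(dice_score(pred, truth))          # 4/7, about 0.571
print(sensitivity_score(pred, truth))   # 0.5
print(specificity_score(pred, truth))   # 0.75
```

Specificity is dominated by the large number of true-negative background voxels. This is why it stays close to 1 even when the overlap with the lesion is poor.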
Results
This presents the results of the training and evaluation.
In the table below, the primed values (DSC', val_DSC', etc.) are per-sample metrics, computed for each sample and then averaged. The unprimed values are the metrics used during training, computed per batch and then averaged. The STD-* columns are the standardised metrics, computed per sample on the standardised outputs of the std_eval function above. This makes them more comparable across models (for the dataset they were evaluated on).
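A toy example shows why per-sample (primed) and pooled metrics can differ. Averaging DSC per sample weights every patient equally, so a missed small lesion costs as much as a perfectly segmented large one. Pooling voxels across samples, as a per-batch metric does when a batch holds several samples, lets large lesions dominate. This illustrates the two averaging schemes only, not how bca computes either column.

```python
import numpy as np

def dice(pred, truth):
    """DSC of two boolean masks (assumes at least one is non-empty)."""
    return 2 * np.count_nonzero(pred & truth) / (np.count_nonzero(pred) + np.count_nonzero(truth))

# Sample A: 100-voxel lesion, segmented perfectly
truth_a = np.ones(100, dtype=bool)
pred_a = truth_a.copy()
# Sample B: 1-voxel lesion, missed entirely
truth_b = np.zeros(100, dtype=bool)
truth_b[0] = True
pred_b = np.zeros(100, dtype=bool)

per_sample = np.mean([dice(pred_a, truth_a), dice(pred_b, truth_b)])  # per-sample, then averaged
pooled = dice(np.concatenate([pred_a, pred_b]), np.concatenate([truth_a, truth_b]))
print(per_sample, pooled)   # 0.5 and 200/201, about 0.995
```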
import numpy as np
import pandas as pd
from IPython import display
pd.set_option('display.max_columns', None)
# Show performance results
results = trainer.plot_results(seqs)
| | dscBG | val_dscBG | dscET | val_dscET | dscTC | val_dscTC | dscWT | val_dscWT | loss | val_loss | DSC' | val_DSC' | DSC_c0' | val_DSC_c0' | DSC_c1' | val_DSC_c1' | DSC_c2' | val_DSC_c2' | DSC_c3' | val_DSC_c3' | HD95' | val_HD95' | HD95_c0' | val_HD95_c0' | HD95_c1' | val_HD95_c1' | HD95_c2' | val_HD95_c2' | HD95_c3' | val_HD95_c3' | Sensitivity' | val_Sensitivity' | Sensitivity_c0' | val_Sensitivity_c0' | Sensitivity_c1' | val_Sensitivity_c1' | Sensitivity_c2' | val_Sensitivity_c2' | Sensitivity_c3' | val_Sensitivity_c3' | Specificity' | val_Specificity' | Specificity_c0' | val_Specificity_c0' | Specificity_c1' | val_Specificity_c1' | Specificity_c2' | val_Specificity_c2' | Specificity_c3' | val_Specificity_c3' | STD-ET-DSC | val_STD-ET-DSC | STD-ET-HD95 | val_STD-ET-HD95 | STD-ET-Sensitivity | val_STD-ET-Sensitivity | STD-ET-Specificity | val_STD-ET-Specificity | STD-TC-DSC | val_STD-TC-DSC | STD-TC-HD95 | val_STD-TC-HD95 | STD-TC-Sensitivity | val_STD-TC-Sensitivity | STD-TC-Specificity | val_STD-TC-Specificity | STD-WT-DSC | val_STD-WT-DSC | STD-WT-HD95 | val_STD-WT-HD95 | STD-WT-Sensitivity | val_STD-WT-Sensitivity | STD-WT-Specificity | val_STD-WT-Specificity |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Fold 1 | 0.995669 | 0.995592 | 0.727890 | 0.644266 | 0.698643 | 0.692800 | 0.680890 | 0.700147 | 0.330149 | 0.365428 | 0.633566 | 0.621650 | 0.987115 | 0.986964 | 0.545088 | 0.555856 | 0.497512 | 0.491605 | 0.504550 | 0.452175 | 0.03 | 3.000000e-02 | 0.000693 | 0.000988 | 0.788845 | 0.793519 | 0.926406 | 0.921075 | 0.944972 | 0.942603 | 0.647979 | 0.652017 | 0.979180 | 0.979106 | 0.535131 | 0.552489 | 0.478360 | 0.478304 | 0.599243 | 0.598171 | 2.051961e-09 | 1.606637e-09 | 5.046229e-09 | 3.172641e-09 | 3.332029e-10 | 3.278340e-10 | 9.667223e-10 | 9.652201e-10 | 1.862028e-09 | 1.960536e-09 | 0.753369 | 0.690347 | 0.028744 | 0.036594 | 0.811871 | 0.765282 | 0.997522 | 0.997813 | 0.490425 | 0.497893 | 0.048686 | 0.059895 | 0.823586 | 0.806006 | 0.987405 | 0.987219 | 0.637306 | 0.650561 | 0.077158 | 0.073205 | 0.805219 | 0.812999 | 0.976988 | 0.976508 |
| Fold 2 | 0.995031 | 0.994365 | 0.736259 | 0.653646 | 0.686533 | 0.659010 | 0.636079 | 0.644860 | 0.340067 | 0.384433 | 0.628637 | 0.617805 | 0.987368 | 0.986672 | 0.518836 | 0.527668 | 0.493432 | 0.485678 | 0.514913 | 0.471202 | 0.03 | 3.000000e-02 | 0.000663 | 0.001109 | 0.791011 | 0.784886 | 0.930938 | 0.903010 | 0.949914 | 0.921821 | 0.657464 | 0.650079 | 0.978942 | 0.978783 | 0.507215 | 0.506383 | 0.491853 | 0.469393 | 0.651847 | 0.645756 | 2.097967e-09 | 1.764390e-09 | 5.634752e-09 | 4.433572e-09 | 3.432406e-10 | 3.541310e-10 | 8.926030e-10 | 8.220422e-10 | 1.521461e-09 | 1.447905e-09 | 0.753741 | 0.663000 | 0.028536 | 0.057720 | 0.738971 | 0.653020 | 0.998522 | 0.998105 | 0.464268 | 0.456880 | 0.053147 | 0.094069 | 0.772165 | 0.758935 | 0.988247 | 0.986513 | 0.598298 | 0.606395 | 0.084547 | 0.104652 | 0.756469 | 0.766663 | 0.975446 | 0.973598 |
| Fold 3 | 0.025347 | 0.025446 | 0.736394 | 0.756288 | 0.005197 | 0.004234 | 0.788923 | 0.749593 | 0.522933 | 0.522844 | 0.301525 | 0.298589 | 0.025023 | 0.025120 | 0.633551 | 0.599809 | 0.005046 | 0.004115 | 0.542478 | 0.565310 | 0.03 | 3.000000e-02 | 0.000868 | 0.000290 | 0.791397 | 0.783350 | 0.921266 | 0.941559 | 0.941298 | 0.957086 | 0.353144 | 0.341026 | 0.012676 | 0.012725 | 0.619844 | 0.582340 | 0.085506 | 0.080059 | 0.694547 | 0.688981 | 5.894288e-09 | 5.585175e-09 | 2.141827e-08 | 2.019783e-08 | 4.160438e-10 | 4.061326e-10 | 5.105731e-12 | 5.090930e-12 | 1.737759e-09 | 1.731661e-09 | 0.728456 | 0.755516 | 0.031860 | 0.044724 | 0.815017 | 0.846326 | 0.997151 | 0.997492 | 0.004366 | 0.003594 | 0.925113 | 0.945232 | 0.002290 | 0.001866 | 0.544967 | 0.534979 | 0.743473 | 0.700638 | 0.038728 | 0.059432 | 0.956511 | 0.945238 | 0.980205 | 0.980673 |
| Fold 4 | 0.995885 | 0.995863 | 0.724875 | 0.730182 | 0.706891 | 0.644624 | 0.679283 | 0.646075 | 0.327905 | 0.353598 | 0.653286 | 0.627810 | 0.989190 | 0.989118 | 0.570890 | 0.527832 | 0.527647 | 0.470438 | 0.525417 | 0.523851 | 0.03 | 3.000000e-02 | 0.000868 | 0.000290 | 0.780621 | 0.826297 | 0.918945 | 0.950810 | 0.938177 | 0.968308 | 0.660921 | 0.632250 | 0.983182 | 0.981768 | 0.539962 | 0.525002 | 0.484694 | 0.448956 | 0.635844 | 0.573274 | 1.672320e-09 | 2.353631e-09 | 2.822241e-09 | 5.579681e-09 | 4.091735e-10 | 3.746585e-10 | 1.273451e-09 | 1.220659e-09 | 2.184386e-09 | 2.241978e-09 | 0.753021 | 0.739990 | 0.023493 | 0.036075 | 0.788955 | 0.821081 | 0.997995 | 0.997761 | 0.497175 | 0.423033 | 0.051051 | 0.062215 | 0.862810 | 0.775217 | 0.986984 | 0.988266 | 0.642489 | 0.607614 | 0.084408 | 0.131548 | 0.816958 | 0.755805 | 0.976226 | 0.980032 |
| Fold 5 | 0.995945 | 0.994798 | 0.627692 | 0.633902 | 0.716158 | 0.696756 | 0.715793 | 0.704773 | 0.360189 | 0.366297 | 0.630206 | 0.632980 | 0.988066 | 0.987029 | 0.577642 | 0.587442 | 0.516710 | 0.511620 | 0.438405 | 0.445828 | 0.03 | 3.000000e-02 | 0.000669 | 0.001090 | 0.797013 | 0.760435 | 0.929113 | 0.910006 | 0.948121 | 0.930530 | 0.636356 | 0.624267 | 0.980945 | 0.981257 | 0.581402 | 0.565254 | 0.488110 | 0.484591 | 0.494967 | 0.465967 | 1.844750e-09 | 1.562606e-09 | 3.605955e-09 | 2.675088e-09 | 3.659457e-10 | 3.826988e-10 | 1.178206e-09 | 1.135022e-09 | 2.228948e-09 | 2.057526e-09 | 0.658653 | 0.635264 | 0.039802 | 0.053960 | 0.873874 | 0.854362 | 0.996457 | 0.995637 | 0.533194 | 0.517545 | 0.057547 | 0.077982 | 0.890215 | 0.851809 | 0.987967 | 0.985329 | 0.676796 | 0.666330 | 0.070113 | 0.071571 | 0.821131 | 0.826828 | 0.979431 | 0.975578 |
| Mean | 0.801576 | 0.801213 | 0.710622 | 0.683657 | 0.562684 | 0.539485 | 0.700194 | 0.689090 | 0.376248 | 0.398520 | 0.569444 | 0.559767 | 0.795352 | 0.794981 | 0.569201 | 0.559722 | 0.408069 | 0.392691 | 0.505153 | 0.491673 | 0.03 | 3.000000e-02 | 0.000752 | 0.000753 | 0.789777 | 0.789698 | 0.925334 | 0.925292 | 0.944496 | 0.944070 | 0.591173 | 0.579928 | 0.786985 | 0.786728 | 0.556711 | 0.546293 | 0.405705 | 0.392261 | 0.615290 | 0.594430 | 2.712257e-09 | 2.574488e-09 | 7.705489e-09 | 7.211763e-09 | 3.735213e-10 | 3.690910e-10 | 8.632177e-10 | 8.296069e-10 | 1.906916e-09 | 1.887921e-09 | 0.729448 | 0.696823 | 0.030487 | 0.045815 | 0.805738 | 0.788014 | 0.997529 | 0.997362 | 0.397886 | 0.379789 | 0.227109 | 0.247879 | 0.670213 | 0.638767 | 0.899114 | 0.896461 | 0.659672 | 0.646308 | 0.070991 | 0.088082 | 0.831258 | 0.821507 | 0.977659 | 0.977278 |
| Std | 0.388114 | 0.387884 | 0.041714 | 0.049735 | 0.278914 | 0.268355 | 0.051063 | 0.039586 | 0.074223 | 0.062938 | 0.134252 | 0.130692 | 0.385165 | 0.384931 | 0.038297 | 0.029782 | 0.201899 | 0.194736 | 0.035650 | 0.045910 | 0.00 | 1.551584e-18 | 0.000095 | 0.000381 | 0.005314 | 0.021325 | 0.004565 | 0.018241 | 0.004310 | 0.016955 | 0.119319 | 0.119913 | 0.387157 | 0.387003 | 0.039475 | 0.027362 | 0.160161 | 0.156564 | 0.067492 | 0.075549 | 1.598293e-09 | 1.531660e-09 | 6.928957e-09 | 6.571391e-09 | 3.370202e-11 | 2.651110e-11 | 4.506533e-10 | 4.346205e-10 | 2.682901e-10 | 2.747459e-10 | 0.036691 | 0.045352 | 0.005374 | 0.008822 | 0.043602 | 0.074345 | 0.000706 | 0.000884 | 0.197987 | 0.190925 | 0.349014 | 0.348893 | 0.336313 | 0.320018 | 0.177074 | 0.180744 | 0.048741 | 0.035947 | 0.016991 | 0.026376 | 0.066734 | 0.067431 | 0.001845 | 0.002688 |
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run cell again; only one can run at the same time, unless there are
# enough GPU resources).
trainer.browse_predict(seqs)