BCa Segmentation

LATUP-Net SE, WDS-manual, BG-NEC-EDE-ENH, BraTS2021

FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2023-2024 Frank C Langbein frank@langbein.org, Cardiff University

License-Identifier: AGPL-3.0-or-later

Dataset: BraTS2021, fixed crop, >=1% lesions
Architecture: LATUP-Net with SE attention
Loss: weighted dice score, manual weights
Output channels: BG, NEC, EDE, ENH

We start by setting up the Jupyter notebook (autoloading so that updated modules are reloaded, inline plots, etc.) and some style options, which may be useful for display.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
In [2]:
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>

Dataset Setup

This defines the dataset source, BraTS2021, used for training.

The BraTS2021 dataset must be in the data folder specified below. See README.md for how to set this up.

In [3]:
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# and the BraTS2021 dataset is available in the folders below (see README.md).
# These have been registered as git submodules (but the BraTS dataset is not
# distributed by us).
from bca.bca.dataset import Dataset

# Fixed-crop BraTS2021 dataset
ds = Dataset(os.path.join("data","brats2021","Training"),    # Dataset folder - brought in via a git submodule
                                                             #  (not public, so replace as needed)
             os.path.join("results","brats2021","Training")) # Results folder
ds.crop((56,184), (56,184), (13,141)) # Fixed centered crop
ds.filter_low_labels("seg",0.01) # Remove patients with non-background segmentation area less than 1%

# Show fixed-crop dataset info and browse slices
print(ds)
ds.browse()
# Dataset: data/brats2021/Training [1151 patients]
Cache: results/brats2021/Training
Channels (patient BraTS2021_00000):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: [(56, 184), (56, 184), (13, 141)]

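The fixed crop and the 1% lesion filter can be sketched in NumPy as below. This is an illustration, not the bca implementation: it assumes the crop ranges are half-open (start, stop) index pairs and that the lesion fraction is measured on the cropped segmentation; `fixed_crop`, `lesion_fraction` and `keep_patient` are hypothetical helpers. Note that each crop range spans 128 voxels, matching the 128x128x128 network input used later.

```python
import numpy as np

# Crop ranges copied from ds.crop above; assumed (start, stop) with stop exclusive.
CROP = ((56, 184), (56, 184), (13, 141))

def fixed_crop(volume, crop=CROP):
    """Return the fixed centered sub-volume (hypothetical helper)."""
    return volume[tuple(slice(lo, hi) for lo, hi in crop)]

def lesion_fraction(seg):
    """Fraction of voxels carrying a non-background label (label != 0)."""
    return np.count_nonzero(seg) / seg.size

def keep_patient(seg, threshold=0.01):
    """Hypothetical filter: keep a patient if at least `threshold` of the
    cropped segmentation is non-background (assumed semantics)."""
    return lesion_fraction(fixed_crop(seg)) >= threshold

# Synthetic label volume with the BraTS2021 shape listed above.
seg = np.zeros((240, 240, 155), dtype=np.uint8)
seg[100:140, 100:140, 60:100] = 2  # a 40x40x40 block labelled "edema"

cropped = fixed_crop(seg)
print(cropped.shape)  # (128, 128, 128)
print(lesion_fraction(cropped), keep_patient(seg))
```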

Train LATUP-Net model

Output channels and their weights for weighted Dice loss

Specify the output channels and set the weights of the weighted Dice loss. Here the weights are chosen manually (WDS-manual) rather than computed from the inverted relative channel volumes.
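As a sketch of what a channel specification of this kind implies, the snippet below maps a label volume to one binary target mask per channel and lists the channels flagged for the loss. It assumes (not confirmed here) that the bca dataset sequences build targets this way; `seg_to_channels` and `loss_channels` are hypothetical helpers, and the `OrderedDict` mirrors the one defined in the next cell.

```python
from collections import OrderedDict

import numpy as np

# Same layout as the specification in the next cell: name -> (labels, use-in-loss).
output_channels = OrderedDict({
    "BG": ([0], 0),
    "NEC": ([1], 1),
    "EDE": ([2], 1),
    "ENH": ([4], 1),
})

def seg_to_channels(seg, channels=output_channels):
    """One binary mask per output channel, in key order; a voxel belongs to a
    channel if its label is in that channel's label list (hypothetical helper)."""
    masks = [np.isin(seg, labels).astype(np.float32) for labels, _ in channels.values()]
    return np.stack(masks, axis=-1)

def loss_channels(channels=output_channels):
    """Indices of the channels whose use-in-loss flag is set."""
    return [i for i, (_, use) in enumerate(channels.values()) if use]

seg = np.array([[0, 1], [2, 4]], dtype=np.uint8)  # toy 2x2 slice with every label
onehot = seg_to_channels(seg)
print(onehot.shape)     # (2, 2, 4)
print(loss_channels())  # [1, 2, 3]
```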

In [4]:
from collections import OrderedDict
from bca.bca.loss import compute_channel_weights

# Specify output channels of network in an OrderedDict.
# Order of keys determines output channel order.
# Each channel is given by a key (the name of the channel) mapped
# to a (list-of-labels, use-in-loss) pair. The list of labels
# contains the original dataset labels merged into the channel.
# Use-in-loss is a boolean (0/1) indicating whether we use the
# channel in the weights and consequently the loss function.
output_channels = OrderedDict({
    "BG": ([0],0),
    "NEC": ([1],1),
    "EDE": ([2],1),
    "ENH": ([4],1)
  })

# Weighted Dice score as loss: adjust loss function weights to
#     W0 - \sum_c w_c DSC_c
# where c is the output channel and DSC_c its Dice score. This
# is specific to the above output_channels spec (channel indices
# 1 = NEC, 2 = EDE, 3 = ENH). We weight the BraTS regions with
# w1 = 0.5, w2 = 0.6, w3 = 0.7 and score each region by the Dice
# scores of the channels it contains (WT = NEC+EDE+ENH,
# TC = NEC+ENH, ET = ENH):
#     w1*WholeTumor              + w2*TumorCore       + w3*EnhancingTumor
#   = w1*(3-DSC_1-DSC_2-DSC_3) + w2*(2-DSC_1-DSC_3) + w3*(1-DSC_3)
#   = w1*3+w2*2+w3*1 - (w1+w2)*DSC_1 - w1*DSC_2 - (w1+w2+w3)*DSC_3
# With the Dice loss DSL_c = 1-DSC_c instead of the score (equivalent,
# even though we use WDS):
#   = w1*(DSL_1+DSL_2+DSL_3) + w2*(DSL_1+DSL_3) + w3*DSL_3
#   = (w1+w2)*DSL_1 + w1*DSL_2 + (w1+w2+w3)*DSL_3
# The background channel is not used here, even though it is an output channel.

w1 = 0.5
w2 = 0.6
w3 = 0.7

weight0 = w1*3+w2*2+w3*1
weight1 = w1+w2
weight2 = w1
weight3 = w1+w2+w3

print("Dice score loss function weights:")
print(f"weight0: {weight0}")
print(f"weight1: {weight1}")
print(f"weight2: {weight2}")
print(f"weight3: {weight3}")
Dice score loss function weights:
weight0: 3.4000000000000004
weight1: 1.1
weight2: 0.5
weight3: 1.8
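The expansion in the comments above can be checked numerically: for arbitrary per-channel Dice scores, the region-weighted form, the W0 - Σ_c w_c DSC_c form and the Dice-loss form all give the same value. This sketch uses plain NumPy with random scores and does not call bca.

```python
import numpy as np

w1, w2, w3 = 0.5, 0.6, 0.7
weight0 = w1*3 + w2*2 + w3*1
# Per-channel weights in output-channel order BG, NEC, EDE, ENH (BG unused).
weights = np.array([0.0, w1 + w2, w1, w1 + w2 + w3])

def region_form(dsc):
    """w1*WT + w2*TC + w3*ET, each region scored by its channels' Dice scores."""
    _, d1, d2, d3 = dsc
    return w1*(3 - d1 - d2 - d3) + w2*(2 - d1 - d3) + w3*(1 - d3)

def score_form(dsc):
    """W0 - sum_c w_c * DSC_c."""
    return weight0 - np.dot(weights, dsc)

def loss_form(dsc):
    """The same expression written with Dice losses DSL_c = 1 - DSC_c."""
    return np.dot(weights, 1.0 - np.asarray(dsc))

rng = np.random.default_rng(0)
for _ in range(3):
    dsc = rng.uniform(0.0, 1.0, size=4)  # random per-channel Dice scores
    print(region_form(dsc), score_form(dsc), loss_form(dsc))
```

The Dice-loss form only matches because W0 equals the sum of the channel weights (3*w1 + 2*w2 + w3), so a perfect prediction gives zero loss.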

LATUP-Net Model

We first set up the model and then train it.

In [5]:
from bca.bca.latupnet import LATUPNet
from bca.bca.trainer import Trainer
from bca.bca.loss import WeightedDiceScore
from bca.bca.metric import sDSC

# Note, the model name determines where the model/results are stored,
# so it should differ between models if we want to preserve them.
# The WeightedDiceScore weights are specific to output_channels and were
# computed above (channel 0, background, gets weight 0).
model = LATUPNet(name="LATUPNet-SE_WDSm_BG-NEC-EDE-ENH",
                 attention="SE",
                 loss=WeightedDiceScore(weight0,
                                        [(0.0, 0),
                                         (weight1, 1),
                                         (weight2, 2),
                                         (weight3, 3)]),
                 metrics=lambda chs=list(output_channels.keys()): [sDSC(channel=n, name="dsc"+ch) for n, ch in enumerate(chs)])

# Plot model architecture (or show text summary if text is True)
model.plot((128,128,128,3,len(output_channels)),text=False) # 3D Spatial dim, number of inputs (flair, t1ce, t2) and number of classes
Segmentation Models: using `tf.keras` framework.
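For reference, a minimal NumPy sketch of a loss of the form W0 - Σ_c w_c DSC_c follows. It is not the bca `WeightedDiceScore` implementation, whose Dice definition, smoothing and batch reduction are not shown in this notebook; it only assumes a soft Dice score per channel and a list of (weight, channel) pairs laid out like the arguments passed above. `soft_dice` and `weighted_dice_loss` are hypothetical names.

```python
import numpy as np

def soft_dice(y_true, y_pred, eps=1e-6):
    """Soft Dice score per channel for arrays shaped (..., C); eps avoids 0/0."""
    axes = tuple(range(y_true.ndim - 1))  # reduce over everything except channels
    inter = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    return (2.0 * inter + eps) / (denom + eps)

def weighted_dice_loss(y_true, y_pred, weight0, weighted_channels):
    """W0 - sum_c w_c * DSC_c; weighted_channels holds (w_c, c) pairs (assumed layout)."""
    dsc = soft_dice(y_true, y_pred)
    return weight0 - sum(w * dsc[c] for w, c in weighted_channels)

weight0 = 3.4
pairs = [(0.0, 0), (1.1, 1), (0.5, 2), (1.8, 3)]  # BG, NEC, EDE, ENH

# Toy 4x4x4 one-hot target: background everywhere except a 2x2x2 "NEC" cube.
y = np.zeros((4, 4, 4, 4))
y[..., 0] = 1.0
y[1:3, 1:3, 1:3, 0] = 0.0
y[1:3, 1:3, 1:3, 1] = 1.0

# A prediction of pure background misses the NEC cube entirely.
y_bg = np.zeros_like(y)
y_bg[..., 0] = 1.0

loss_perfect = weighted_dice_loss(y, y, weight0, pairs)
loss_bad = weighted_dice_loss(y, y_bg, weight0, pairs)
print(loss_perfect, loss_bad)
```

A perfect prediction gives (numerically) zero loss, and missing NEC costs roughly its weight 1.1; empty channels score 1 here because of the eps smoothing, which is one convention among several.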
In [6]:
# Parameters determining the dataset sequences
K = 5                      # 5-fold cross validation
EPOCHS = 200               # 200 epochs
SEED = 8120341116777169704 # Seed for dataset 5-fold split
BATCH_SIZE = 1             # Batch size for training

# Generate a list of K keras (train,test)-pair sequences for the dataset splits.
# We use a 128x128x128 input image with flair, t1ce and t2 channels. The
# output channels are those specified above, each given as "seg+" followed
# by its labels joined with "+" (e.g. "seg+4" for ENH).
seqs = ds.sequences(K, (128,128,128),
                    ["flair", "t1ce", "t2"],
                    ["seg+"+"+".join([str(n) for n in output_channels[ch][0]]) for ch in output_channels.keys()],
                    batch_size=BATCH_SIZE,
                    pre_proc=Dataset.norm_minmax,
                    seed=SEED)

# Show training sequence of first fold
seqs[0][0].browse()

# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
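The two data-preparation ideas used here, a seeded K-fold split and per-channel min-max normalisation, can be sketched as follows. This does not reproduce `Dataset.sequences` or `Dataset.norm_minmax`; their exact shuffling and normalisation details are assumptions, so the folds produced below will differ from those of the bca package. `kfold_indices` is a hypothetical helper, and n = 1151 is the unfiltered patient count printed above, used only for illustration.

```python
import numpy as np

def kfold_indices(n, k, seed):
    """Hypothetical seeded K-fold split: shuffle 0..n-1 once, cut it into k folds
    and return (train, test) index arrays with fold i as the test set."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), k)
    return [(np.concatenate(folds[:i] + folds[i + 1:]), folds[i]) for i in range(k)]

def norm_minmax(x, eps=1e-8):
    """Rescale an image channel to [0, 1]; eps guards against constant images."""
    x = x.astype(np.float32)
    return (x - x.min()) / (x.max() - x.min() + eps)

splits = kfold_indices(1151, 5, seed=8120341116777169704)
for i, (train, test) in enumerate(splits):
    print(f"fold {i}: {len(train)} train / {len(test)} test")

print(norm_minmax(np.array([-5, 0, 5, 15], dtype=np.int16)))
```

Fixing the seed makes the split reproducible, which is why the same SEED appears in the result folder names of every fold.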

We are ready to train the models.

For remote execution, set remote=True in the train call and then use the scheduling code below to run the tasks. This requires the remote schedulers to be defined in cfg.json for the bca package.

In [7]:
# Train the networks
#  Specifying loss,loss for the best monitor uses the test loss alone to store best model
results = trainer.train(seqs, jit_compile=True, remote=True)
* Fold 1: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDSm_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-0
* Fold 2: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDSm_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-1
* Fold 3: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDSm_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-2
* Fold 4: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDSm_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-3
* Fold 5: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDSm_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-4
=> Done: 5; Failed: 0; Training: 0; Start: 0
In [8]:
# Schedule training tasks
from bca.bca.scheduler import schedule, schedule_clean
task_folder="results"
schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming issues have been fixed)
schedule(task_folder=task_folder)       # Schedule tasks
All tasks complete.

Evaluation

This evaluates, locally, every completely trained model that has not been evaluated yet; models whose training is incomplete are ignored.

In [9]:
# Evaluate model, including standardizing the evaluation as different models may have different outputs

import numpy as np

# Output standardisation function
def std_eval_BG_NEC_EDE_ENH(P, Y):
  # Convert prediction and ground-truth output to standardized output.
  # We assume output channels are Background, Necrotic, Edema, Enhancing,
  # in this order (indices 0-3). As regions are combined, we threshold the
  # predicted masks (in place) and clip the sums to [0, 1].
  P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
  return {
    'WT': [np.clip(P[0][...,1]+P[0][...,2]+P[0][...,3],0.0,1.0), np.clip(Y[0][...,1]+Y[0][...,2]+Y[0][...,3],0.0,1.0)],
    'TC': [np.clip(P[0][...,1]            +P[0][...,3],0.0,1.0), np.clip(Y[0][...,1]            +Y[0][...,3],0.0,1.0)],
    'ET': [np.clip(                        P[0][...,3],0.0,1.0), np.clip(                        Y[0][...,3],0.0,1.0)]
  }
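To see what the standardisation produces, here is a toy check with four voxels, one per label, using a compact form equivalent to the function above. The (batch, voxels, channels) layout and the [prediction, target] pair per region are read off the code above; the toy arrays are made up for illustration.

```python
import numpy as np

def std_eval_BG_NEC_EDE_ENH(P, Y):
    """Compact equivalent of the function above (channels 1-3: NEC, EDE, ENH)."""
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)  # modifies P in place
    def region(A, chs):
        return np.clip(sum(A[0][..., c] for c in chs), 0.0, 1.0)
    return {name: [region(P, chs), region(Y, chs)]
            for name, chs in (("WT", (1, 2, 3)), ("TC", (1, 3)), ("ET", (3,)))}

# Toy batch of one "image" with four voxels labelled BG, NEC, EDE, ENH (one-hot).
Y = np.eye(4, dtype=np.float32)[None]        # shape (1, 4 voxels, 4 channels)
P = 0.8 * np.eye(4, dtype=np.float32)[None]  # confident prediction ...
P[0, 2, 2] = 0.4                             # ... except the edema voxel

out = std_eval_BG_NEC_EDE_ENH(P, Y)
for name, (pred, true) in out.items():
    print(name, pred, true)
```

The under-confident edema voxel drops out of the predicted whole tumour after thresholding, while TC and ET are unaffected because they do not include the edema channel.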
In [10]:
# Compute performance metrics

from bca.bca.trainer import dsc, hd95, sensitivity, specificity

# Evaluate global and per-sample using standardised output.
# This is slow as it collects overall, per-channel and standardised channel
# performance metrics. Note, std_eval must match output_channels; other
# output channel specifications need their own standardisation function.
trainer.eval(seqs, mode="best", fs=[dsc,hd95,sensitivity,specificity], std_eval=std_eval_BG_NEC_EDE_ENH)
# Save model architecture
trainer.plot_model(seqs, save_only=True)
* Fold 1
* Fold 2
* Fold 3
* Fold 4
* Fold 5
Evaluation complete.
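Dice, sensitivity and specificity are standard confusion-matrix quantities on binary masks. A NumPy sketch follows; it is not the `bca.bca.trainer` code (HD95, which needs surface distances, is omitted), and the handling of empty masks chosen here is an assumption that may differ from bca. The function names are hypothetical, chosen so they do not shadow the imported ones.

```python
import numpy as np

def confusion(pred, true):
    """True/false positive/negative voxel counts for two binary masks."""
    pred, true = pred.astype(bool), true.astype(bool)
    tp = np.count_nonzero(pred & true)
    fp = np.count_nonzero(pred & ~true)
    fn = np.count_nonzero(~pred & true)
    tn = np.count_nonzero(~pred & ~true)
    return tp, fp, fn, tn

def binary_dsc(pred, true):
    """Dice = 2TP / (2TP + FP + FN); two empty masks count as a perfect match."""
    tp, fp, fn, _ = confusion(pred, true)
    return 1.0 if tp + fp + fn == 0 else 2 * tp / (2 * tp + fp + fn)

def binary_sensitivity(pred, true):
    """TP / (TP + FN); undefined (nan) without positive voxels."""
    tp, _, fn, _ = confusion(pred, true)
    return tp / (tp + fn) if tp + fn else float("nan")

def binary_specificity(pred, true):
    """TN / (TN + FP); undefined (nan) without negative voxels."""
    _, fp, _, tn = confusion(pred, true)
    return tn / (tn + fp) if tn + fp else float("nan")

true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
pred = np.array([1, 1, 0, 0, 1, 0, 0, 0])
print(binary_dsc(pred, true), binary_sensitivity(pred, true), binary_specificity(pred, true))
```

Because tumour regions are small relative to the cropped volume, true negatives dominate, which is why specificity values in the results table sit close to 1.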

Results

This presents the results of the training/evaluation.

Note that below, the "prime" values (DSC', val_DSC', etc.) are per-sample metrics (then averaged over samples, etc.), while the others are the metrics used during training (per-batch values, then averaged, etc.).
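Why per-sample and batch-level values can differ is easiest to see on a toy batch: averaging Dice per sample weights every patient equally, whereas a Dice computed over all voxels of a batch is dominated by large lesions. The sketch assumes a batch-level Dice pools all voxels in the batch; this is one common convention, not a confirmed detail of bca.

```python
import numpy as np

def dice(pred, true):
    """Dice score of two binary arrays, pooling all of their voxels."""
    return 2.0 * np.sum(pred * true) / (np.sum(pred) + np.sum(true))

# Toy batch of two samples with 100 voxels each.
true = np.zeros((2, 100))
pred = np.zeros((2, 100))
true[0, :50] = 1
pred[0, :50] = 1  # sample 0: large lesion, segmented perfectly
true[1, :2] = 1   # sample 1: 2-voxel lesion, missed entirely

per_sample = np.mean([dice(pred[i], true[i]) for i in range(2)])  # averaged per sample
pooled = dice(pred, true)                                          # one value per batch
print(per_sample, pooled)  # 0.5 vs. about 0.98
```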

In [11]:
import numpy as np
import pandas as pd
from IPython import display

pd.set_option('display.max_columns', None)

# Show performance results
results = trainer.plot_results(seqs)
dscBG val_dscBG dscEDE val_dscEDE dscENH val_dscENH dscNEC val_dscNEC loss val_loss DSC' val_DSC' DSC_c0' val_DSC_c0' DSC_c1' val_DSC_c1' DSC_c2' val_DSC_c2' DSC_c3' val_DSC_c3' HD95' val_HD95' HD95_c0' val_HD95_c0' HD95_c1' val_HD95_c1' HD95_c2' val_HD95_c2' HD95_c3' val_HD95_c3' Sensitivity' val_Sensitivity' Sensitivity_c0' val_Sensitivity_c0' Sensitivity_c1' val_Sensitivity_c1' Sensitivity_c2' val_Sensitivity_c2' Sensitivity_c3' val_Sensitivity_c3' Specificity' val_Specificity' Specificity_c0' val_Specificity_c0' Specificity_c1' val_Specificity_c1' Specificity_c2' val_Specificity_c2' Specificity_c3' val_Specificity_c3' STD-ET-DSC val_STD-ET-DSC STD-ET-HD95 val_STD-ET-HD95 STD-ET-Sensitivity val_STD-ET-Sensitivity STD-ET-Specificity val_STD-ET-Specificity STD-TC-DSC val_STD-TC-DSC STD-TC-HD95 val_STD-TC-HD95 STD-TC-Sensitivity val_STD-TC-Sensitivity STD-TC-Specificity val_STD-TC-Specificity STD-WT-DSC val_STD-WT-DSC STD-WT-HD95 val_STD-WT-HD95 STD-WT-Sensitivity val_STD-WT-Sensitivity STD-WT-Specificity val_STD-WT-Specificity
Fold 1 0.997067 0.996743 0.851743 0.827488 0.872404 0.865756 0.811303 0.767460 0.599276 0.671598 0.789410 0.772337 0.993161 0.992811 0.676388 0.638176 0.717828 0.691543 0.770262 0.766816 0.03 0.03 0.000587 0.000703 0.976445 0.978856 0.802062 0.803616 0.941690 0.936823 0.854968 0.838048 0.992788 0.992356 0.793066 0.761653 0.752804 0.736081 0.881213 0.862100 2.736592e-09 2.805653e-09 1.655317e-09 2.237184e-09 5.621872e-09 5.399412e-09 7.684794e-10 7.396539e-10 2.900711e-09 2.846254e-09 0.853254 0.841468 0.017371 0.027014 0.840528 0.829494 0.999115 0.998909 0.921492 0.882578 0.013570 0.025531 0.924165 0.882702 0.998668 0.998549 0.908798 0.898704 0.019428 0.025980 0.951813 0.937901 0.993714 0.993375
Fold 2 0.996338 0.996436 0.828626 0.818769 0.844324 0.819963 0.794776 0.759160 0.685998 0.773956 0.774433 0.762601 0.992495 0.992639 0.665475 0.641035 0.703980 0.695564 0.735783 0.721165 0.03 0.03 0.000662 0.000402 0.975362 0.983383 0.800872 0.808382 0.939415 0.945815 0.828227 0.817736 0.992866 0.992789 0.803756 0.791217 0.753867 0.753777 0.762419 0.733163 2.829100e-09 2.859312e-09 1.445461e-09 1.466185e-09 5.023470e-09 4.955461e-09 7.081631e-10 6.884853e-10 4.139295e-09 4.327110e-09 0.827744 0.803967 0.023315 0.028380 0.912226 0.915638 0.997840 0.997928 0.887065 0.854021 0.018644 0.025881 0.959740 0.943976 0.996942 0.996829 0.886366 0.881078 0.026340 0.030253 0.951302 0.949250 0.991785 0.992217
Fold 3 0.996980 0.996724 0.848720 0.832545 0.858340 0.833141 0.803474 0.775143 0.632234 0.716845 0.773902 0.762099 0.992442 0.992076 0.650190 0.638328 0.705440 0.691816 0.747536 0.726178 0.03 0.03 0.000530 0.000932 0.977037 0.976534 0.804038 0.795680 0.939737 0.944533 0.871400 0.860952 0.989330 0.989275 0.789913 0.771148 0.804557 0.792241 0.901800 0.891143 2.441952e-09 2.257222e-09 2.833335e-09 2.360367e-09 4.387782e-09 4.190890e-09 5.452996e-10 5.304022e-10 2.001346e-09 1.947216e-09 0.834146 0.808603 0.019059 0.042413 0.793397 0.774272 0.999175 0.998939 0.907559 0.876248 0.013208 0.032944 0.893558 0.872656 0.998649 0.997326 0.905820 0.901817 0.025586 0.026922 0.905445 0.906377 0.995584 0.994797
Fold 4 0.997049 0.996642 0.857172 0.846042 0.876835 0.844452 0.818587 0.758842 0.582752 0.712325 0.798201 0.776854 0.993388 0.992858 0.691764 0.642417 0.733617 0.727972 0.774034 0.744171 0.03 0.03 0.000579 0.000734 0.977697 0.973883 0.803099 0.799474 0.940628 0.940997 0.864053 0.850201 0.992070 0.991681 0.782247 0.750522 0.813026 0.804967 0.868870 0.853635 3.271533e-09 3.263570e-09 2.644652e-09 2.632091e-09 6.775927e-09 6.727313e-09 6.806372e-10 6.519818e-10 2.984916e-09 3.042914e-09 0.857249 0.827675 0.018310 0.029782 0.853392 0.831718 0.999020 0.998829 0.919890 0.884456 0.013260 0.032526 0.934728 0.912289 0.998249 0.997517 0.912412 0.898801 0.018045 0.024904 0.928501 0.916047 0.994816 0.994130
Fold 5 0.996918 0.996959 0.848211 0.840660 0.863606 0.851793 0.798684 0.782862 0.637115 0.679559 0.790005 0.791187 0.993083 0.993190 0.682012 0.685907 0.718192 0.715672 0.766733 0.769981 0.03 0.03 0.000693 0.000279 0.978153 0.972129 0.801797 0.804683 0.942030 0.935342 0.858433 0.850147 0.992264 0.992364 0.800171 0.781201 0.793005 0.780964 0.848292 0.846060 3.162221e-09 2.976177e-09 2.130569e-09 2.127705e-09 6.413324e-09 6.066608e-09 6.715817e-10 6.775811e-10 3.433433e-09 3.032654e-09 0.842977 0.828669 0.022202 0.030903 0.854939 0.838303 0.998553 0.998552 0.901828 0.896070 0.017404 0.028534 0.923379 0.916730 0.997645 0.997778 0.906932 0.907809 0.020421 0.026826 0.936149 0.937920 0.993947 0.994038
Mean 0.996871 0.996701 0.846894 0.833101 0.863102 0.843021 0.805365 0.768693 0.627475 0.710857 0.785190 0.773016 0.992914 0.992715 0.673166 0.649173 0.715811 0.704513 0.758870 0.745662 0.03 0.03 0.000610 0.000610 0.976939 0.976957 0.802374 0.802367 0.940700 0.940702 0.855416 0.843417 0.991864 0.991693 0.793830 0.771148 0.783452 0.773606 0.852519 0.837220 2.888280e-09 2.832387e-09 2.141867e-09 2.164706e-09 5.644475e-09 5.467937e-09 6.748322e-10 6.576208e-10 3.091940e-09 3.039230e-09 0.843074 0.822076 0.020051 0.031698 0.850896 0.837885 0.998741 0.998632 0.907567 0.878675 0.015217 0.029083 0.927114 0.905671 0.998031 0.997600 0.904066 0.897642 0.021964 0.026977 0.934642 0.929499 0.993969 0.993711
Std 0.000272 0.000169 0.009674 0.009611 0.011408 0.015653 0.008607 0.009293 0.035597 0.036167 0.009523 0.010705 0.000377 0.000366 0.014294 0.018438 0.010717 0.014726 0.014718 0.020104 0.00 0.00 0.000059 0.000237 0.000979 0.003946 0.001094 0.004386 0.001033 0.004115 0.014688 0.014745 0.001302 0.001260 0.007600 0.014279 0.025400 0.025244 0.048281 0.054226 2.992146e-10 3.282831e-10 5.390290e-10 3.876303e-10 8.759849e-10 8.766539e-10 7.308037e-11 6.970915e-11 7.001409e-10 7.600563e-10 0.011129 0.013859 0.002301 0.005515 0.037941 0.045168 0.000501 0.000377 0.012641 0.013892 0.002328 0.003160 0.021312 0.025499 0.000659 0.000567 0.009128 0.008918 0.003360 0.001792 0.017116 0.015794 0.001277 0.000872
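The Mean and Std rows summarise the five folds. Recomputing them from the rounded per-fold values suggests the Std row is the population standard deviation (ddof=0) rather than the sample one; this is an inference from the numbers, not something stated by the package. For example, for val_STD-WT-DSC:

```python
import numpy as np

# val_STD-WT-DSC for folds 1-5, copied from the table above.
val_wt_dsc = np.array([0.898704, 0.881078, 0.901817, 0.898801, 0.907809])

mean = val_wt_dsc.mean()
std_pop = val_wt_dsc.std()            # ddof=0 (population), NumPy's default
std_sample = val_wt_dsc.std(ddof=1)   # ddof=1 (sample), pandas' default
print(f"{mean:.6f} {std_pop:.6f} {std_sample:.6f}")
```

The population value reproduces the 0.008918 in the Std row, while the sample value is noticeably larger; with only five folds, the choice matters when comparing spreads across notebooks.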
In [12]:
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run the cell again; only one can run at a time, unless there are
# enough GPU resources).
trainer.browse_predict(seqs)