BCa Segmentation¶

LATUP-Net SE, WDS, BG-NEC-EDE-ENH, BraTS2021¶

FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2023-2024 Frank C Langbein frank@langbein.org, Cardiff University

License-Identifier: AGPL-3.0-or-later

Dataset: BraTS2021, fixed crop, >=1% lesions
Architecture: LATUP-Net with SE attention
Loss: weighted Dice score (WDS), ENet weights
Output channels: BG, NEC, EDE, ENH

We start by setting up the Jupyter notebook (autoload to reload updated modules, inline plots, etc.) and some style options that may be useful for display.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
In [2]:
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>

Dataset Setup¶

This defines the dataset source, BraTS2021, used for training.

The BraTS2021 dataset must be in the data folder specified below. See README.md for how to set this up.

In [3]:
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# and the BraTS2021 dataset is available in the folders below (see README.md).
# These have been registered as git submodules (but the BraTS dataset is not
# distributed by us).
from bca.bca.dataset import Dataset

# Fixed-crop BraTS2021 dataset
ds = Dataset(os.path.join("data","brats2021","Training"),    # Dataset folder - brought in via a git submodule
                                                             #  (not public, so replace as needed)
             os.path.join("results","brats2021","Training")) # Results folder
ds.crop((56,184), (56,184), (13,141)) # Fixed centered crop to 128x128x128 voxels
ds.filter_low_labels("seg",0.01) # Remove patients with less than 1% non-background segmentation voxels

# Show bounding-box cropped dataset info and slices
print(ds)
ds.browse()
# Dataset: data/brats2021/Training [1151 patients]
Cache: results/brats2021/Training
Channels (patient BraTS2021_00000):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: [(56, 184), (56, 184), (13, 141)]

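The crop and the 1% lesion filter can be sketched in plain NumPy as follows. This is a minimal sketch, not the bca implementation: the helpers `crop_volume`, `lesion_fraction` and `keep_patient` are hypothetical, and measuring the lesion fraction on the cropped (rather than the full) volume is an assumption.

```python
import numpy as np

# Crop bounds from the notebook: (start, stop) per axis, each 128 voxels wide.
CROP = ((56, 184), (56, 184), (13, 141))

def crop_volume(volume, bounds=CROP):
    """Hypothetical helper: fixed centered crop of a (240, 240, 155) volume."""
    (x0, x1), (y0, y1), (z0, z1) = bounds
    return volume[x0:x1, y0:y1, z0:z1]

def lesion_fraction(seg):
    """Fraction of voxels with a non-background label (label != 0)."""
    return np.count_nonzero(seg) / seg.size

def keep_patient(seg, threshold=0.01):
    """Hypothetical filter: keep patients with >= 1% lesion voxels after cropping
    (assumption: the fraction is measured on the cropped volume)."""
    return lesion_fraction(crop_volume(seg)) >= threshold

# Synthetic segmentations (not BraTS data): a 30^3 and a 10^3 lesion with label 2.
large = np.zeros((240, 240, 155), dtype=np.uint8)
large[90:120, 90:120, 60:90] = 2
small = np.zeros((240, 240, 155), dtype=np.uint8)
small[90:100, 90:100, 60:70] = 2

print(crop_volume(large).shape)
print(keep_patient(large), keep_patient(small))
```

The 30^3 lesion covers about 1.3% of the 128^3 crop and is kept; the 10^3 lesion covers under 0.05% and is dropped.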

Train LATUP-Net model¶

Output channels and their weights for weighted Dice loss¶

Specify the output channels and compute loss weights for the tumour regions from their relative volumes, using ENet weighting (inverse logarithm of the relative volume, not normalised).
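ENet weighting (from the ENet paper by Paszke et al.) assigns each class the weight w = 1/ln(c + p), where p is the class's relative frequency and c > 1 a constant (here `enet_c=1.22`), so rarer regions receive larger weights. The sketch below assumes that `compute_channel_weights` takes p as each region's volume divided by the summed volumes of WT, TC and ET. This is an assumption, although the WT, TC and ET weights printed below for this dataset correspond to p of roughly 0.63, 0.23 and 0.14, which sum to one. The function `enet_weights` and the voxel counts are hypothetical.

```python
import math

def enet_weights(volumes, c=1.22):
    """ENet-style weights w = 1 / ln(c + p), with p the region's share of the
    summed volumes (an assumption about how bca normalises p). Requires c > 1
    so that every logarithm is positive."""
    total = sum(volumes.values())
    return {name: 1.0 / math.log(c + v / total) for name, v in volumes.items()}

# Hypothetical voxel counts for the nested regions (WT contains TC contains ET)
volumes = {"WT": 626, "TC": 235, "ET": 140}
weights = enet_weights(volumes)
for name, w in weights.items():
    print(f"{name} weight: {w:.3f}")
```

Since ln grows slowly, the weights differ by much less than the volumes: here ET is about 4.5 times smaller than WT but its weight is only about twice as large.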

In [4]:
from collections import OrderedDict
from bca.bca.loss import compute_channel_weights

# Specify the output channels of the network in an OrderedDict.
# The order of the keys determines the output channel order.
# Each channel is given by a key (the channel name) mapped to a
# (list-of-labels, use-in-loss) pair. The list of labels contains
# the original dataset labels combined into the channel. Use-in-loss
# is a boolean flag (0/1) indicating whether the channel is used in
# the weights and consequently in the loss function.
output_channels = OrderedDict({
    "BG": ([0],0),
    "NEC": ([1],1),
    "EDE": ([2],1),
    "ENH": ([4],1)
  })

# Specify the tumour regions used to compute the loss weights, in the same format.
weight_channels = OrderedDict({
    "WT": ([1,2,4],1),
    "TC": ([1,4],1),
    "ET": ([4],1)
  })
# ENet weights, not normalised
weights_enet = compute_channel_weights(ds, "seg", weight_channels, mode="enet", enet_c=1.22, normalise=False)

# Weighted Dice score as loss: adjust the loss function weights to
#     W0 - \sum_c W_c DSC_c
# where c is the output channel and DSC_c its Dice score. This
# is specific to the above output_channels spec (channels 1, 2, 3 =
# NEC, EDE, ENH), with regions WT = NEC+EDE+ENH, TC = NEC+ENH, ET = ENH
# and ENet region weights w1 (WT), w2 (TC), w3 (ET):
#     w1*WT-loss                 + w2*TC-loss         + w3*ET-loss
#   = w1*(3-DSC_1-DSC_2-DSC_3) + w2*(2-DSC_1-DSC_3) + w3*(1-DSC_3)
#   = w1*3+w2*2+w3*1 - (w1+w2)*DSC_1 - w1*DSC_2 - (w1+w2+w3)*DSC_3
# This is without using the background channel.
for we in weights_enet:
  print(f"{we} weight: {weights_enet[we][0]}")

# Map enet weights on WT, TC and ET to WDS weights
w1 = weights_enet["WT"][0]
w2 = weights_enet["TC"][0]
w3 = weights_enet["ET"][0]

weight0 = w1*3+w2*2+w3*1
weight1 = w1+w2
weight2 = w1
weight3 = w1+w2+w3

print("Dice score WDS loss function weights:")
print(f"weight0: {weight0}")
print(f"weight1: {weight1}")
print(f"weight2: {weight2}")
print(f"weight3: {weight3}")
WT weight: 1.631778162179294
TC weight: 2.6674780079425817
ET weight: 3.2558520361583145
Dice score WDS loss function weights:
weight0: 13.48614253858136
weight1: 4.299256170121875
weight2: 1.631778162179294
weight3: 7.55510820628019
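The expansion in the loss comment above can be checked numerically: for any Dice scores, the region form w1*(3-DSC_1-DSC_2-DSC_3) + w2*(2-DSC_1-DSC_3) + w3*(1-DSC_3) equals weight0 - weight1*DSC_1 - weight2*DSC_2 - weight3*DSC_3. The sketch below uses the printed ENet weights and random Dice values; `region_form` and `channel_form` are hypothetical names, and it does not call the bca `WeightedDiceScore` class.

```python
import random

# ENet weights for WT, TC and ET as printed above
w1, w2, w3 = 1.631778162179294, 2.6674780079425817, 3.2558520361583145

def region_form(d1, d2, d3):
    """Weighted sum of (1 - Dice) terms per region: WT = NEC+EDE+ENH, TC = NEC+ENH, ET = ENH."""
    return w1 * (3 - d1 - d2 - d3) + w2 * (2 - d1 - d3) + w3 * (1 - d3)

def channel_form(d1, d2, d3):
    """The same loss with per-channel weights, as passed to WeightedDiceScore."""
    weight0 = 3 * w1 + 2 * w2 + w3
    return weight0 - (w1 + w2) * d1 - w1 * d2 - (w1 + w2 + w3) * d3

rng = random.Random(0)
for _ in range(1000):
    d = [rng.random() for _ in range(3)]  # Dice scores of NEC, EDE, ENH in [0, 1)
    assert abs(region_form(*d) - channel_form(*d)) < 1e-9

# A perfect segmentation gives zero loss (up to floating-point rounding)
print(channel_form(1.0, 1.0, 1.0))
```

The loss ranges from weight0 (all Dice scores zero) down to zero, and the ENH channel carries the largest weight because it contributes to all three regions.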

LATUP-Net Model¶

We first set up the model and then train it.

In [5]:
from bca.bca.latupnet import LATUPNet
from bca.bca.trainer import Trainer
from bca.bca.loss import WeightedDiceScore
from bca.bca.metric import sDSC

# Note, the model name determines where the model and results are stored,
# so it should differ between models whose results we want to preserve.
# Further note, the WeightedDiceScore weights are given per output channel,
# as computed above (the background channel gets weight 0.0).
model = LATUPNet(name="LATUPNet-SE_WDS_BG-NEC-EDE-ENH",
                 attention="SE",
                 loss=WeightedDiceScore(weight0,
                                        [(0.0, 0),
                                         (weight1, 1),
                                         (weight2, 2),
                                         (weight3, 3)]),
                 metrics=lambda chs=list(output_channels.keys()): [sDSC(channel=n, name="dsc"+ch) for n, ch in enumerate(chs)])

# Plot model architecture (or show text summary if text is True)
model.plot((128,128,128,3,len(output_channels)),text=False) # 3D Spatial dim, number of inputs (flair, t1ce, t2) and number of classes
Segmentation Models: using `tf.keras` framework.
[Figure: LATUP-Net model architecture]
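The `attention="SE"` option presumably refers to squeeze-and-excitation channel attention (Hu et al.). As a rough illustration of the generic technique, not LATUP-Net's actual layers, an SE block on a channels-last 3D feature map can be sketched in NumPy. Global average pooling squeezes each channel to one value, two small dense layers with a bottleneck ratio r (an assumed r=4) and a sigmoid produce per-channel gates, and the input is rescaled by these gates. The weights are random placeholders, not trained parameters.

```python
import numpy as np

def se_block(x, w1, b1, w2, b2):
    """Squeeze-and-excitation on a channels-last 3D feature map x of shape (D, H, W, C)."""
    s = x.mean(axis=(0, 1, 2))                 # squeeze: global average pool -> (C,)
    z = np.maximum(0.0, s @ w1 + b1)           # excitation dense layer 1 + ReLU -> (C // r,)
    g = 1.0 / (1.0 + np.exp(-(z @ w2 + b2)))   # excitation dense layer 2 + sigmoid -> (C,)
    return x * g                               # rescale each channel by its gate in (0, 1)

rng = np.random.default_rng(0)
C, r = 16, 4                                   # channels and assumed bottleneck ratio
x = rng.standard_normal((8, 8, 8, C))
w1, b1 = rng.standard_normal((C, C // r)), np.zeros(C // r)
w2, b2 = rng.standard_normal((C // r, C)), np.zeros(C)
y = se_block(x, w1, b1, w2, b2)
print(y.shape)
```

The block adds only about 2*C*C/r parameters per layer, which is why channel attention of this kind suits a lightweight 3D network.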
In [6]:
# Parameters for the dataset sequences and training
K = 5                      # 5-fold cross validation
EPOCHS = 200               # 200 epochs
SEED = 8120341116777169704 # Seed for dataset 5-fold split
BATCH_SIZE = 1             # Batch size for training

# Generate a list of K keras (train,test)-pair sequences for the dataset splits.
# We use a 128x128x128 input image with flair, t1ce and t2 channels. The
# output channels are those specified above, each given as "seg+<labels>".
seqs = ds.sequences(K, (128,128,128),
                    ["flair", "t1ce", "t2"],
                    ["seg+"+"+".join([str(n) for n in output_channels[ch][0]]) for ch in output_channels.keys()],
                    batch_size=BATCH_SIZE,
                    pre_proc=Dataset.norm_minmax,
                    seed=SEED)

# Show training sequence of first fold
seqs[0][0].browse()

# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
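As an illustration of a seeded K-fold split, the sketch below permutes the patient indices with a NumPy `Generator` seeded with `SEED` and splits them into K nearly equal folds, each serving once as the test set; the same seed always reproduces the same split. This is an assumption for illustration (the helper `kfold_splits` is hypothetical), and bca's `Dataset.sequences` may split differently. With 1151 patients and K=5, this scheme yields test folds of 231 and 230 patients, leaving 920 or 921 for training.

```python
import numpy as np

def kfold_splits(n_items, k, seed):
    """Hypothetical helper: seeded K-fold split into (train, test) index arrays."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_items), k)  # K nearly equal folds
    splits = []
    for i in range(k):
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        splits.append((train, folds[i]))
    return splits

splits = kfold_splits(1151, 5, 8120341116777169704)
for i, (train, test) in enumerate(splits, start=1):
    print(f"Fold {i}: {len(train)} train, {len(test)} test")
```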

We are ready to train the models.

For remote execution, set remote=True in the trainer.train call; the tasks can then be scheduled with the scheduler cell following the training cell. Note that this requires the remote schedulers to be defined in cfg.json for the bca package.

In [7]:
# Train the networks
# (Specifying "loss,loss" for the best-model monitor would use the test loss alone to select the stored best model.)
results = trainer.train(seqs, jit_compile=True, remote=True)
* Fold 1: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-0
* Fold 2: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-1
* Fold 3: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-2
* Fold 4: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-3
* Fold 5: done - results/brats2021/Training/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-SE_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-5-4
=> Done: 5; Failed: 0; Training: 0; Start: 0
In [8]:
# Schedule training tasks
from bca.bca.scheduler import schedule, schedule_clean
task_folder = "results"
schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming the issues have been fixed)
schedule(task_folder=task_folder)       # Schedule tasks
All tasks complete.

Evaluation¶

This evaluates, locally, every completely trained model that has not been evaluated yet; incompletely trained models are ignored.

In [9]:
# Evaluate model, including standardizing the evaluation as different models may have different outputs

import numpy as np

# Output standardisation function
def std_eval_BG_NEC_EDE_ENH(P, Y):
  # Convert the prediction P and ground truth Y to the standardised BraTS
  # regions WT, TC and ET. We assume the output channels are Background,
  # Necrotic, Edema, Enhancing (indices 0-3), in this order. As regions are
  # combined, we threshold the predicted masks (in place) and clip the sums.
  P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
  return {
    'WT': [np.clip(P[0][...,1]+P[0][...,2]+P[0][...,3],0.0,1.0), np.clip(Y[0][...,1]+Y[0][...,2]+Y[0][...,3],0.0,1.0)],
    'TC': [np.clip(P[0][...,1]            +P[0][...,3],0.0,1.0), np.clip(Y[0][...,1]            +Y[0][...,3],0.0,1.0)],
    'ET': [np.clip(                        P[0][...,3],0.0,1.0), np.clip(                        Y[0][...,3],0.0,1.0)]
  }
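To illustrate the standardisation, the sketch below one-hot encodes a tiny label map into the BG, NEC, EDE and ENH channels (BraTS labels 0, 1, 2 and 4), combines them into the WT, TC and ET regions in the same way as `std_eval_BG_NEC_EDE_ENH`, and computes a plain Dice score per region. The helpers and the smoothing term `eps` are assumptions for illustration; the `dsc` metric from bca may differ in details.

```python
import numpy as np

# Output channels and their dataset labels, in channel order
LABELS = {"BG": [0], "NEC": [1], "EDE": [2], "ENH": [4]}

def one_hot(seg):
    """Stack one binary mask per output channel (channels last)."""
    return np.stack([np.isin(seg, labs) for labs in LABELS.values()], axis=-1).astype(np.float32)

def regions(x):
    """Combine channel masks into the standard regions, clipping overlapping sums."""
    return {
        "WT": np.clip(x[..., 1] + x[..., 2] + x[..., 3], 0.0, 1.0),
        "TC": np.clip(x[..., 1] + x[..., 3], 0.0, 1.0),
        "ET": x[..., 3],
    }

def dice(a, b, eps=1e-8):
    """Dice score 2|A n B| / (|A| + |B|) with a small smoothing term."""
    return (2.0 * np.sum(a * b) + eps) / (np.sum(a) + np.sum(b) + eps)

truth = np.array([[0, 1, 2], [4, 2, 0]])
pred = np.array([[0, 1, 2], [2, 2, 0]])  # the ENH voxel is mislabelled as EDE
rt, rp = regions(one_hot(truth)), regions(one_hot(pred))
scores = {k: float(dice(rp[k], rt[k])) for k in rt}
print(scores)
```

The single mislabelled voxel leaves WT perfect (Dice 1) but lowers TC to 2/3 and ET to about 0, showing why the nested regions are reported separately.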
In [10]:
# Compute performance metrics

from bca.bca.trainer import dsc, hd95, sensitivity, specificity

# Evaluate global and per-sample using standardised output.
# This is slow as it collects overall, per-channel and standardised-region
# performance metrics. Note, std_eval must match output_channels (a different
# channel layout needs a different standardisation function).
trainer.eval(seqs, mode="best", fs=[dsc,hd95,sensitivity,specificity], std_eval=std_eval_BG_NEC_EDE_ENH)
# Save model architecture
trainer.plot_model(seqs, save_only=True)
* Fold 1
* Fold 2
* Fold 3
* Fold 4
* Fold 5
Evaluation complete.

Results¶

This presents the results of the training and evaluation.

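Note that in the table below the "primed" values (DSC', val_DSC', etc.) are per-sample metrics (then averaged), while the others are the metrics used during training (per-batch values, then averaged). The val_ prefix marks metrics on the test fold, the _c0 to _c3 suffixes refer to the output channels BG, NEC, EDE and ENH, and the STD- columns refer to the standardised WT, TC and ET regions.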

In [11]:
import numpy as np
import pandas as pd
from IPython import display

pd.set_option('display.max_columns', None)

# Show performance results
results = trainer.plot_results(seqs)
[Figure: training results]
dscBG val_dscBG dscEDE val_dscEDE dscENH val_dscENH dscNEC val_dscNEC loss val_loss DSC' val_DSC' DSC_c0' val_DSC_c0' DSC_c1' val_DSC_c1' DSC_c2' val_DSC_c2' DSC_c3' val_DSC_c3' HD95' val_HD95' HD95_c0' val_HD95_c0' HD95_c1' val_HD95_c1' HD95_c2' val_HD95_c2' HD95_c3' val_HD95_c3' Sensitivity' val_Sensitivity' Sensitivity_c0' val_Sensitivity_c0' Sensitivity_c1' val_Sensitivity_c1' Sensitivity_c2' val_Sensitivity_c2' Sensitivity_c3' val_Sensitivity_c3' Specificity' val_Specificity' Specificity_c0' val_Specificity_c0' Specificity_c1' val_Specificity_c1' Specificity_c2' val_Specificity_c2' Specificity_c3' val_Specificity_c3' STD-ET-DSC val_STD-ET-DSC STD-ET-HD95 val_STD-ET-HD95 STD-ET-Sensitivity val_STD-ET-Sensitivity STD-ET-Specificity val_STD-ET-Specificity STD-TC-DSC val_STD-TC-DSC STD-TC-HD95 val_STD-TC-HD95 STD-TC-Sensitivity val_STD-TC-Sensitivity STD-TC-Specificity val_STD-TC-Specificity STD-WT-DSC val_STD-WT-DSC STD-WT-HD95 val_STD-WT-HD95 STD-WT-Sensitivity val_STD-WT-Sensitivity STD-WT-Specificity val_STD-WT-Specificity
Fold 1 0.996743 0.996314 0.849419 0.819485 0.872379 0.865779 0.818300 0.771504 2.281749 2.581642 0.817946 0.798754 0.993720 0.993276 0.730677 0.687577 0.747165 0.720370 0.800221 0.793793 0.03 0.03 0.000587 0.000703 0.976445 0.978856 0.802062 0.803616 0.941690 0.936823 0.877083 0.856179 0.991657 0.990982 0.811333 0.766920 0.844849 0.828312 0.860493 0.838500 5.264700e-09 5.326098e-09 3.525680e-09 4.740193e-09 1.174866e-08 1.094228e-08 6.464572e-10 6.265988e-10 5.138088e-09 4.995343e-09 0.854141 0.846378 0.024546 0.032558 0.866339 0.859896 0.998870 0.998623 0.919358 0.884816 0.015765 0.026667 0.928446 0.895131 0.998308 0.998225 0.901473 0.887413 0.026605 0.039057 0.886001 0.866613 0.995970 0.995740
Fold 2 0.997172 0.997050 0.862925 0.850935 0.882923 0.861064 0.823485 0.778158 2.162760 2.542337 0.810994 0.798136 0.993766 0.993715 0.700406 0.672812 0.749052 0.739309 0.800753 0.786707 0.03 0.03 0.000662 0.000402 0.975362 0.983383 0.800872 0.808382 0.939415 0.945815 0.876678 0.862424 0.991079 0.991124 0.780312 0.755498 0.856986 0.853255 0.878336 0.849818 4.098989e-09 4.149007e-09 3.706054e-09 3.627890e-09 7.963913e-09 8.087300e-09 6.065125e-10 5.918014e-10 4.119465e-09 4.289030e-09 0.866001 0.839730 0.014517 0.024585 0.863464 0.845884 0.999120 0.999003 0.927899 0.896508 0.008763 0.015563 0.939755 0.929854 0.998463 0.998059 0.913998 0.906480 0.021014 0.019704 0.896314 0.892900 0.996744 0.996584
Fold 3 0.997242 0.996998 0.864705 0.846983 0.879148 0.848919 0.826403 0.782631 2.157285 2.602767 0.823023 0.803331 0.994194 0.993778 0.735771 0.704371 0.756621 0.738391 0.805504 0.776786 0.03 0.03 0.000530 0.000932 0.977037 0.976534 0.804038 0.795680 0.939737 0.944533 0.882898 0.869369 0.992937 0.992700 0.840676 0.808800 0.824868 0.811339 0.873112 0.864638 4.293809e-09 3.939135e-09 2.780770e-09 2.153821e-09 8.908874e-09 8.455401e-09 8.092824e-10 7.767591e-10 4.676171e-09 4.370583e-09 0.858699 0.827428 0.020446 0.049230 0.860204 0.829442 0.998876 0.998627 0.925564 0.892831 0.013544 0.032084 0.929095 0.903257 0.998686 0.997444 0.913334 0.908136 0.026483 0.037895 0.920391 0.917348 0.995643 0.994975
Fold 4 0.997315 0.996808 0.865449 0.851278 0.884234 0.852788 0.831358 0.761958 2.112667 2.671732 0.831074 0.805827 0.994827 0.994177 0.740093 0.680961 0.770382 0.759840 0.818994 0.788330 0.03 0.03 0.000579 0.000734 0.977697 0.973883 0.803099 0.799474 0.940628 0.940997 0.867267 0.849648 0.995335 0.994869 0.791420 0.743529 0.797834 0.788140 0.884478 0.872053 5.262196e-09 5.268123e-09 1.867426e-09 1.784759e-09 1.289316e-08 1.308059e-08 1.095642e-09 1.029457e-09 5.192544e-09 5.177652e-09 0.864919 0.833736 0.015145 0.022088 0.862474 0.835546 0.999014 0.998876 0.926008 0.887283 0.009367 0.022343 0.948314 0.920789 0.998299 0.997428 0.916571 0.902230 0.018863 0.027026 0.954955 0.939250 0.994336 0.993493
Fold 5 0.997080 0.997064 0.856938 0.849109 0.880667 0.869730 0.826782 0.793358 2.178144 2.417244 0.811989 0.807012 0.993374 0.993367 0.724723 0.709835 0.732526 0.727662 0.797332 0.797185 0.03 0.03 0.000693 0.000279 0.978153 0.972129 0.801797 0.804683 0.942030 0.935342 0.893016 0.879509 0.990295 0.990185 0.853770 0.816641 0.839019 0.826729 0.888979 0.884481 4.192014e-09 3.725751e-09 3.765894e-09 3.659198e-09 8.244777e-09 7.217203e-09 5.707433e-10 5.621231e-10 4.186569e-09 3.464878e-09 0.860669 0.849118 0.015087 0.024606 0.847782 0.835373 0.999123 0.999138 0.926118 0.915746 0.008891 0.025489 0.915008 0.905287 0.999039 0.999009 0.911308 0.910587 0.017560 0.027891 0.892599 0.893579 0.996672 0.996773
Mean 0.997110 0.996847 0.859887 0.843558 0.879870 0.859656 0.825265 0.777522 2.178521 2.563145 0.819005 0.802612 0.993976 0.993662 0.726334 0.691111 0.751149 0.737114 0.804561 0.788560 0.03 0.03 0.000610 0.000610 0.976939 0.976957 0.802374 0.802367 0.940700 0.940702 0.879388 0.863426 0.992261 0.991972 0.815502 0.778277 0.832711 0.821555 0.877080 0.861898 4.622342e-09 4.481623e-09 3.129165e-09 3.193172e-09 9.951876e-09 9.556555e-09 7.457275e-10 7.173478e-10 4.662567e-09 4.459497e-09 0.860886 0.839278 0.017948 0.030613 0.860053 0.841228 0.999001 0.998854 0.924989 0.895437 0.011266 0.024429 0.932124 0.910864 0.998559 0.998033 0.911337 0.902969 0.022105 0.030315 0.910052 0.901938 0.995873 0.995513
Std 0.000199 0.000282 0.006028 0.012133 0.004138 0.007790 0.004299 0.010543 0.056022 0.084170 0.007434 0.003609 0.000499 0.000322 0.013943 0.013977 0.012385 0.013361 0.007678 0.006984 0.00 0.00 0.000059 0.000237 0.000979 0.003946 0.001094 0.004386 0.001033 0.004115 0.008457 0.010367 0.001763 0.001662 0.028049 0.029185 0.020270 0.021439 0.009887 0.016211 5.270768e-10 6.794098e-10 7.223313e-10 1.082895e-09 1.991662e-09 2.153271e-09 1.930463e-10 1.726463e-10 4.535888e-10 6.048319e-10 0.004307 0.007982 0.003939 0.009955 0.006444 0.010730 0.000111 0.000204 0.002926 0.010951 0.002861 0.005435 0.011278 0.012615 0.000278 0.000584 0.005211 0.008241 0.003789 0.007255 0.025278 0.024613 0.000874 0.001197
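The Mean and Std rows summarise the five folds. The reported Std values appear to be population standard deviations (ddof=0): for dscBG, ddof=0 reproduces 0.000199, while the pandas default ddof=1 would give about 0.000223. The sketch below recomputes both rows for the dscBG and val_dscBG columns from the per-fold values in the table.

```python
import pandas as pd

# Per-fold values copied from the dscBG and val_dscBG columns of the results table
folds = pd.DataFrame(
    {"dscBG": [0.996743, 0.997172, 0.997242, 0.997315, 0.997080],
     "val_dscBG": [0.996314, 0.997050, 0.996998, 0.996808, 0.997064]},
    index=[f"Fold {i}" for i in range(1, 6)])

# Population standard deviation (ddof=0) matches the reported Std row
summary = pd.DataFrame({"Mean": folds.mean(), "Std": folds.std(ddof=0)}).T
print(summary.round(6))
```

With only five folds, the choice of ddof changes the standard deviation by a factor of sqrt(5/4), about 12%, so it is worth stating when comparing with other results.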
In [12]:
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be
# stopped easily; to restart, run the cell again; only one can run at a time, unless
# there are enough GPU resources).
trainer.browse_predict(seqs)