BCa Segmentation Interpretation
LATUP-Net None, WDS, BG-NEC-EDE-ENH, BraTS2020
FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee <AlwadeeEJ@cardiff.ac.uk>, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2024 Frank C Langbein <frank@langbein.org>, Cardiff University
License-Identifier: AGPL-3.0-or-later
Dataset: BraTS2020, fixed crop, >=1% lesions
Architecture: LATUP-Net without attention
Loss: weighted Dice score (WDS), ENet weights
Output channels: BG, NEC, EDE, ENH
We start by setting up the Jupyter notebook (autoloading to reload updated modules, etc.) and some style options that may be useful for display.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>
Dataset Setup
This defines the dataset source, BraTS2020, used for training.
The BraTS2020 dataset must be in the data folder specified below. See README.md for how to set this up.
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# and the BraTS2020 dataset is available in the folders below (see README.md).
# These have been registered as git submodules (but the BraTS dataset is not
# distributed by us).
from bca.bca.dataset import Dataset
# Fixed-crop BraTS2020 dataset
ds = Dataset(os.path.join("data","brats2020","MICCAI_BraTS2020_TrainingData"), # Dataset folder - brought in via a git submodule
# (not public, so replace as needed)
os.path.join("results","brats2020","MICCAI_BraTS2020_TrainingData")) # Results folder
ds.crop((56,184), (56,184), (13,141)) # Fixed centered crop
ds.filter_low_labels("seg",0.01) # Remove patients with non-background segmentation area less than 1%
# Show cropped dataset info and browse slices
print(ds)
ds.browse()
Dataset: data/brats2020/MICCAI_BraTS2020_TrainingData [344 patients]
Cache: results/brats2020/MICCAI_BraTS2020_TrainingData
Channels (patient BraTS20_Training_001):
  flair: (240, 240, 155) (int16)
  seg: (240, 240, 155) (uint8)
  t1: (240, 240, 155) (int16)
  t1ce: (240, 240, 155) (int16)
  t2: (240, 240, 155) (int16)
Crop: [(56, 184), (56, 184), (13, 141)]
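The fixed crop and the 1% lesion filter above can be sketched with plain NumPy slicing. This is a minimal illustration, not the bca implementation: `fixed_crop`, `lesion_fraction` and `keep_patient` are hypothetical helpers, and we assume the lesion fraction is measured on the cropped volume (the notebook calls `filter_low_labels` after `crop`). The crop removes 56 voxels from each side of the 240-voxel in-plane axes and 13 (start) and 14 (end) of the 155 axial slices, leaving a 128x128x128 volume.

```python
import numpy as np

# Crop bounds as passed to ds.crop above: [lo, hi) per axis
BOUNDS = ((56, 184), (56, 184), (13, 141))

def fixed_crop(volume, bounds):
    """Crop a 3D volume to [lo, hi) along each axis (hypothetical helper)."""
    (x0, x1), (y0, y1), (z0, z1) = bounds
    return volume[x0:x1, y0:y1, z0:z1]

def lesion_fraction(seg):
    """Fraction of voxels carrying a non-background (non-zero) label."""
    return np.count_nonzero(seg) / seg.size

def keep_patient(seg, bounds, threshold=0.01):
    """Assumption: keep a patient if the cropped lesion fraction is at least the threshold."""
    return lesion_fraction(fixed_crop(seg, bounds)) >= threshold

# Toy 240x240x155 segmentation with a 30^3 lesion cube inside the crop
seg = np.zeros((240, 240, 155), dtype=np.uint8)
seg[100:130, 100:130, 60:90] = 2
print(fixed_crop(seg, BOUNDS).shape)   # (128, 128, 128)
print(keep_patient(seg, BOUNDS))       # 27000 / 2097152 ≈ 1.29% -> True
```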
from collections import OrderedDict
from bca.bca.loss import compute_channel_weights
# Specify output channels of network in an OrderedDict.
# Order of keys determines output channel order.
# Each channel is given by a key (name of the channel) and
# a (list-of-labels, use-in-loss) pair. The list of labels
# contains the original dataset labels combined into the
# channel. Use-in-loss is a flag (0/1) indicating whether the
# channel is used in the weight and consequently loss function.
output_channels = OrderedDict({
"BG": ([0],0),
"NEC": ([1],1),
"EDE": ([2],1),
"ENH": ([4],1)
})
# Specify channels for loss weights in an OrderedDict.
weight_channels = OrderedDict({
"WT": ([1,2,4],1),
"TC": ([1,4],1),
"ET": ([4],1)
})
# ENet weights
weights_enet = compute_channel_weights(ds, "seg", weight_channels, mode="enet", enet_c=1.22, normalise=False)
# Weighted Dice score as loss: adjust loss function weights to
# W0 - \sum_c w_c DSC_c
# where c is the output channel and DSC_c its Dice score. This
# is specific to the above output_channels spec (channels 1, 2, 3 =
# NEC, EDE, ENH), with each region loss taken as the sum of the Dice
# losses of the channels it contains:
# w1*WholeTumor + w2*TumorCore + w3*EnhancingTumor
# = w1*(3-DSC_1-DSC_2-DSC_3) + w2*(2-DSC_1-DSC_3) + w3*(1-DSC_3)
# = w1*3+w2*2+w3*1 - (w1+w2)*DSC_1 - w1*DSC_2 - (w1+w2+w3)*DSC_3
# With Dice loss DSL_c = 1-DSC_c instead of score (equivalent, even though we use WDS):
# = w1*(DSL_1+DSL_2+DSL_3) + w2*(DSL_1+DSL_3) + w3*DSL_3
# = (w1+w2) * DSL_1 + w1 * DSL_2 + (w1+w2+w3) * DSL_3
# This does not use the background channel, even though it is an output channel.
for we in weights_enet:
    print(f"{we} weight: {weights_enet[we][0]}")
# Map enet weights on WT, TC and ET to WDS weights
w1 = weights_enet["WT"][0]
w2 = weights_enet["TC"][0]
w3 = weights_enet["ET"][0]
weight0 = w1*3+w2*2+w3*1
weight1 = w1+w2
weight2 = w1
weight3 = w1+w2+w3
print("WDS loss function weights:")
print(f"weight0: {weight0}")
print(f"weight1: {weight1}")
print(f"weight2: {weight2}")
print(f"weight3: {weight3}")
WT weight: 1.6427562244432643
TC weight: 2.5484483848158628
ET weight: 3.404885782140262
WDS loss function weights:
weight0: 13.430051225101781
weight1: 4.1912046092591275
weight2: 1.6427562244432643
weight3: 7.59609039139939
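The printed weights can be reproduced under an assumption about `compute_channel_weights`: the ENet class-weighting rule w_k = 1/ln(c + p_k), with c = enet_c = 1.22 and p_k the share of weight channel k in the total size of all weight channels. Back-solving the printed weights gives p ≈ 0.618, 0.261 and 0.121, which sum to 1; this is consistent with, but not proof of, that normalisation. The sketch also maps the WT/TC/ET weights to the WDS weights and checks that the two forms of the loss in the comments above agree. `enet_weights`, `wds_weights`, `wds_loss` and `region_loss` are hypothetical helpers.

```python
import math

def enet_weights(sizes, c=1.22):
    """ENet-style class weights w_k = 1 / ln(c + p_k).

    Assumption: p_k is the size of weight channel k divided by the total
    size over all weight channels, not the fraction of all voxels.
    """
    total = sum(sizes.values())
    return {k: 1.0 / math.log(c + n / total) for k, n in sizes.items()}

def wds_weights(w_wt, w_tc, w_et):
    """Map WT/TC/ET weights to W0 and per-channel weights for NEC, EDE, ENH."""
    weight0 = 3 * w_wt + 2 * w_tc + w_et
    return weight0, {"NEC": w_wt + w_tc, "EDE": w_wt, "ENH": w_wt + w_tc + w_et}

def wds_loss(weight0, weights, dsc):
    """Weighted Dice score loss W0 - sum_c w_c DSC_c."""
    return weight0 - sum(weights[ch] * dsc[ch] for ch in weights)

def region_loss(w_wt, w_tc, w_et, dsc):
    """w1*WT + w2*TC + w3*ET, each region loss being the sum of its channels' Dice losses."""
    dsl = {ch: 1.0 - d for ch, d in dsc.items()}
    return (w_wt * (dsl["NEC"] + dsl["EDE"] + dsl["ENH"])
            + w_tc * (dsl["NEC"] + dsl["ENH"])
            + w_et * dsl["ENH"])

# Region proportions back-solved from the printed ENet weights (illustrative only)
w = enet_weights({"WT": 0.6181011, "TC": 0.2605233, "ET": 0.1213756})
weight0, weights = wds_weights(w["WT"], w["TC"], w["ET"])
print({k: round(v, 4) for k, v in w.items()})
print(round(weight0, 4), {k: round(v, 4) for k, v in weights.items()})

# Both forms of the loss agree for arbitrary per-channel Dice scores
dsc = {"NEC": 0.70, "EDE": 0.71, "ENH": 0.73}
assert math.isclose(wds_loss(weight0, weights, dsc),
                    region_loss(w["WT"], w["TC"], w["ET"], dsc))
```

Rarer regions (here ET) receive larger weights, which is the intent of the ENet rule; the constant c bounds how large those weights can get.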
LATUP-Net Model
We first set up the model and then train it.
from bca.bca.latupnet import LATUPNet
from bca.bca.trainer import Trainer
from bca.bca.loss import WeightedDiceScore
from bca.bca.metric import sDSC
# Note, the model name determines where model/results are stored,
# so it should differ between models if we want to preserve them.
# Further note, the WeightedDiceScore weights are specified per
# output channel, as computed above (the BG channel gets weight 0).
# This network is without attention.
model = LATUPNet(name="LATUPNet-None_WDS_BG-NEC-EDE-ENH",
attention="None",
loss=WeightedDiceScore(weight0,
[(0.0, 0),
(weight1, 1),
(weight2, 2),
(weight3, 3)]),
metrics=lambda chs=list(output_channels.keys()): [sDSC(channel=n, name="dsc"+ch) for n, ch in enumerate(chs)])
# Plot model architecture (or show text summary if text is True)
model.plot((128,128,128,3,len(output_channels)),text=False) # 3D Spatial dim, number of inputs (flair, t1ce, t2) and number of classes
Segmentation Models: using `tf.keras` framework.
# Parameters determining the dataset sequences
K = 0.8 # 80/20 split; single fold sufficient for interpretation analysis
EPOCHS = 200 # 200 epochs
SEED = 8120341116777169704 # Seed for dataset split
BATCH_SIZE = 1 # Batch size for training
# Generate a list of keras (train,test)-pair sequences for the dataset
# splits (a single 80/20 split here, see K above). We use a 128x128x128
# input image with flair, t1ce and t2 channels; the output channels are
# taken from output_channels above.
seqs = ds.sequences(K, (128,128,128),
["flair", "t1ce", "t2"],
["seg+"+"+".join([str(n) for n in output_channels[ch][0]]) for ch in output_channels.keys()],
batch_size=BATCH_SIZE,
pre_proc=Dataset.norm_minmax,
seed=SEED)
# Show test sequence of first fold
seqs[0][1].browse()
# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
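The output names passed to `ds.sequences` are built from the label lists in output_channels ("seg+0", "seg+1", "seg+2", "seg+4"; a multi-label channel such as [1,4] would become "seg+1+4"). A minimal NumPy sketch of what we assume this amounts to: each output channel is the union of its listed labels, giving one-hot targets, and each input channel is min-max scaled to [0, 1]. `seg_to_channels` and `norm_minmax` are hypothetical stand-ins; the bca Dataset may, for example, compute the minimum and maximum differently.

```python
import numpy as np

# Label lists per output channel, mirroring output_channels above
# (BraTS labels: 0 background, 1 necrotic core, 2 edema, 4 enhancing tumour)
OUTPUT_LABELS = {"BG": [0], "NEC": [1], "EDE": [2], "ENH": [4]}

def seg_to_channels(seg, spec=OUTPUT_LABELS):
    """Stack one binary mask per output channel (1 where the label is in the channel's list)."""
    return np.stack([np.isin(seg, labels).astype(np.float32) for labels in spec.values()],
                    axis=-1)

def norm_minmax(x, eps=1e-8):
    """Scale an image channel linearly to [0, 1] (hypothetical stand-in for Dataset.norm_minmax)."""
    x = x.astype(np.float32)
    lo, hi = x.min(), x.max()
    return (x - lo) / max(hi - lo, eps)

seg = np.array([[0, 1], [2, 4]], dtype=np.uint8)
Y = seg_to_channels(seg)
print(Y.shape)           # (2, 2, 4): spatial dims + one channel per output
print(Y.sum(axis=-1))    # every voxel falls into exactly one channel

flair = np.array([[10, 20], [30, 50]], dtype=np.int16)
print(norm_minmax(flair))
```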
We are now ready to train the model.
For remote execution, set remote=True in the train call below; the scheduling code that follows can then be used to schedule the tasks. Note that this requires the remote schedulers to be defined in cfg.json for the bca package.
# Train the networks
# Specifying loss,loss for the best monitor uses the test loss alone to store best model
results = trainer.train(seqs, jit_compile=True, remote=False)
* Fold 1: done - results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-None_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-0.8-0
# Schedule training tasks
from bca.bca.scheduler import schedule, schedule_clean
task_folder="results"
schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming issues have been fixed)
schedule(task_folder=task_folder) # Schedule tasks
All tasks complete.
Evaluation
This evaluates, locally, every completely trained model that has not yet been evaluated; models that have not finished training are ignored.
# Evaluate model, including standardising the evaluation, as different models may have different outputs
import numpy as np
# Performance metrics
from bca.bca.trainer import dsc, hd95, sensitivity, specificity
# Output standardisation function: convert prediction P and ground truth Y
# to the standard WT/TC/ET regions. Only the first batch element is used
# (BATCH_SIZE = 1).
def std_eval_BG_NEC_EDE_ENH(P, Y):
    # We assume the output channels are Background, Necrotic, Edema and
    # Enhancing, in this order. As regions are combined, we threshold the
    # predicted masks and clip the sums to [0,1].
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return {
        'WT': [np.clip(P[0][...,1]+P[0][...,2]+P[0][...,3],0,1), np.clip(Y[0][...,1]+Y[0][...,2]+Y[0][...,3],0.0,1.0)],
        'TC': [np.clip(P[0][...,1]+P[0][...,3],0,1),             np.clip(Y[0][...,1]+Y[0][...,3],0.0,1.0)],
        'ET': [np.clip(P[0][...,3],0,1),                         np.clip(Y[0][...,3],0.0,1.0)]
    }
# Disable shuffle for sequences, so we can map them to specific patients in evaluation/interpretation results
for seq in seqs:
    seq[0].disable_shuffle()
    seq[1].disable_shuffle()
# Evaluate globally and per sample using the standardised output.
# This is slow as it collects overall, per-channel and standardised channel
# performance metrics. Note, adjust std_eval according to output_channels
# (a different output channel specification needs a different function).
trainer.eval(seqs, mode="best", fs=[dsc,hd95,sensitivity,specificity], std_eval=std_eval_BG_NEC_EDE_ENH)
# Save model architecture
trainer.plot_model(seqs, save_only=True)
* Fold 1 Evaluation complete.
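The standardisation and the region metrics can be illustrated on a toy six-voxel example. We assume dsc, sensitivity and specificity follow the usual voxel-count definitions (DSC = 2|P∩Y|/(|P|+|Y|), sensitivity = TP/(TP+FN), specificity = TN/(TN+FP)); the small eps smoothing and all function names here are hypothetical, and the bca implementations (including how they treat empty masks, and hd95, which is omitted) may differ.

```python
import numpy as np

LABELS = (0, 1, 2, 4)  # channel order BG, NEC, EDE, ENH

def one_hot(labels):
    """One binary channel per BraTS label, in the order of LABELS."""
    return np.stack([(labels == l).astype(np.float32) for l in LABELS], axis=-1)

def std_regions(masks):
    """Combine BG/NEC/EDE/ENH channel masks into WT, TC and ET masks (as std_eval does)."""
    nec, ede, enh = masks[..., 1], masks[..., 2], masks[..., 3]
    return {"WT": np.clip(nec + ede + enh, 0, 1),
            "TC": np.clip(nec + enh, 0, 1),
            "ET": enh}

def dice(p, y, eps=1e-8):
    return (2 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)

def sensitivity(p, y, eps=1e-8):   # TP / (TP + FN)
    return (np.sum(p * y) + eps) / (np.sum(y) + eps)

def specificity(p, y, eps=1e-8):   # TN / (TN + FP)
    return (np.sum((1 - p) * (1 - y)) + eps) / (np.sum(1 - y) + eps)

y_labels = np.array([0, 1, 2, 4, 4, 2])
p_labels = np.array([0, 1, 2, 2, 4, 0])
probs = one_hot(p_labels) * 0.9 + 0.05        # hypothetical soft network output
P = np.where(probs >= 0.5, 1.0, 0.0)          # threshold, as in std_eval
Y = one_hot(y_labels)

P_reg, Y_reg = std_regions(P), std_regions(Y)
scores = {r: (dice(P_reg[r], Y_reg[r]), sensitivity(P_reg[r], Y_reg[r]),
              specificity(P_reg[r], Y_reg[r])) for r in ("WT", "TC", "ET")}
for r, (d, se, sp) in scores.items():
    print(f"{r}: DSC={d:.3f} sensitivity={se:.3f} specificity={sp:.3f}")
```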
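Results
This presents the results of the training and evaluation.
Note that in the table below the "prime" values (DSC', val_DSC', etc.) are per-sample metrics (then averaged), while the others are the metrics used during training (per-batch values, then averaged). The _c0 to _c3 suffixes index the output channels (BG, NEC, EDE, ENH), and the STD-* columns are the metrics on the standardised WT, TC and ET regions.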
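Why the aggregation level matters, as a toy illustration independent of how bca computes its metrics internally: averaging per-sample Dice scores weights every patient equally, whereas a Dice score over voxels pooled across samples is dominated by large lesions, so a missed small lesion barely moves it. The `dice` helper and its eps smoothing are hypothetical.

```python
import numpy as np

def dice(p, y, eps=1e-8):
    return (2 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)

# Two hypothetical samples: a large lesion predicted well, a small lesion missed
y1 = np.zeros(1000); y1[:100] = 1
p1 = np.zeros(1000); p1[:90] = 1
y2 = np.zeros(1000); y2[:10] = 1
p2 = np.zeros(1000)

per_sample = np.mean([dice(p1, y1), dice(p2, y2)])
pooled = dice(np.concatenate([p1, p2]), np.concatenate([y1, y2]))
print(f"mean of per-sample Dice: {per_sample:.3f}")  # 0.474
print(f"pooled Dice:             {pooled:.3f}")      # 0.900
```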
import numpy as np
import pandas as pd
from IPython import display
pd.set_option('display.max_columns', None)
# Show performance results
results = trainer.plot_results(seqs)
| | dscBG | val_dscBG | dscEDE | val_dscEDE | dscENH | val_dscENH | dscNEC | val_dscNEC | loss | val_loss | DSC' | val_DSC' | DSC_c0' | val_DSC_c0' | DSC_c1' | val_DSC_c1' | DSC_c2' | val_DSC_c2' | DSC_c3' | val_DSC_c3' | HD95' | val_HD95' | HD95_c0' | val_HD95_c0' | HD95_c1' | val_HD95_c1' | HD95_c2' | val_HD95_c2' | HD95_c3' | val_HD95_c3' | Sensitivity' | val_Sensitivity' | Sensitivity_c0' | val_Sensitivity_c0' | Sensitivity_c1' | val_Sensitivity_c1' | Sensitivity_c2' | val_Sensitivity_c2' | Sensitivity_c3' | val_Sensitivity_c3' | Specificity' | val_Specificity' | Specificity_c0' | val_Specificity_c0' | Specificity_c1' | val_Specificity_c1' | Specificity_c2' | val_Specificity_c2' | Specificity_c3' | val_Specificity_c3' | STD-ET-DSC | val_STD-ET-DSC | STD-ET-HD95 | val_STD-ET-HD95 | STD-ET-Sensitivity | val_STD-ET-Sensitivity | STD-ET-Specificity | val_STD-ET-Specificity | STD-TC-DSC | val_STD-TC-DSC | STD-TC-HD95 | val_STD-TC-HD95 | STD-TC-Sensitivity | val_STD-TC-Sensitivity | STD-TC-Specificity | val_STD-TC-Specificity | STD-WT-DSC | val_STD-WT-DSC | STD-WT-HD95 | val_STD-WT-HD95 | STD-WT-Sensitivity | val_STD-WT-Sensitivity | STD-WT-Specificity | val_STD-WT-Specificity |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Fold 1 | 0.996693 | 0.996037 | 0.831685 | 0.794831 | 0.810207 | 0.735351 | 0.822893 | 0.729323 | 2.958389 | 3.979707 | 0.783193 | 0.740052 | 0.993394 | 0.992743 | 0.700566 | 0.621537 | 0.713448 | 0.682418 | 0.725365 | 0.66351 | 0.03 | 0.03 | 0.000868 | 0.00029 | 0.953077 | 0.932543 | 0.795097 | 0.78635 | 0.947411 | 0.933073 | 0.837678 | 0.797262 | 0.992398 | 0.992276 | 0.749051 | 0.683631 | 0.788551 | 0.759163 | 0.82071 | 0.75398 | 3.478570e-09 | 3.412487e-09 | 2.147606e-09 | 2.122357e-09 | 5.128315e-09 | 4.838586e-09 | 6.650216e-10 | 6.207910e-10 | 5.973334e-09 | 6.068205e-09 | 0.795166 | 0.707139 | 0.01801 | 0.048411 | 0.82997 | 0.781261 | 0.998544 | 0.997669 | 0.90741 | 0.853945 | 0.011282 | 0.038492 | 0.927539 | 0.896944 | 0.997558 | 0.995577 | 0.897099 | 0.887334 | 0.014896 | 0.017584 | 0.907272 | 0.903532 | 0.99414 | 0.992831 |
| Mean | 0.996693 | 0.996037 | 0.831685 | 0.794831 | 0.810207 | 0.735351 | 0.822893 | 0.729323 | 2.958389 | 3.979707 | 0.783193 | 0.740052 | 0.993394 | 0.992743 | 0.700566 | 0.621537 | 0.713448 | 0.682418 | 0.725365 | 0.66351 | 0.03 | 0.03 | 0.000868 | 0.00029 | 0.953077 | 0.932543 | 0.795097 | 0.78635 | 0.947411 | 0.933073 | 0.837678 | 0.797262 | 0.992398 | 0.992276 | 0.749051 | 0.683631 | 0.788551 | 0.759163 | 0.82071 | 0.75398 | 3.478570e-09 | 3.412487e-09 | 2.147606e-09 | 2.122357e-09 | 5.128315e-09 | 4.838586e-09 | 6.650216e-10 | 6.207910e-10 | 5.973334e-09 | 6.068205e-09 | 0.795166 | 0.707139 | 0.01801 | 0.048411 | 0.82997 | 0.781261 | 0.998544 | 0.997669 | 0.90741 | 0.853945 | 0.011282 | 0.038492 | 0.927539 | 0.896944 | 0.997558 | 0.995577 | 0.897099 | 0.887334 | 0.014896 | 0.017584 | 0.907272 | 0.903532 | 0.99414 | 0.992831 |
| Std | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.00 | 0.00 | 0.000000 | 0.00000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.00000 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.000000 | 0.000000 | 0.00000 | 0.000000 | 0.00000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.000000 |
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run cell again; only one can run at the same time, unless there are
# enough GPU resources).
trainer.browse_predict(seqs)
# Stop this before continuing below!
Interpret model
# Get evaluation data to find best/median/worst sample based on loss
# (we only have one fold; seqs are not shuffled, see above). The loss is a
# weighted-Dice proxy on the standardised WT/TC/ET regions, not the exact
# per-channel training loss.
test_std_eval = trainer.get_eval(seqs, mode="best")[0]["test_std_per_sample"]
loss_sample = []
for l in range(0,len(test_std_eval["WT"]["DSC"])):
    loss_sample.append(weight0 - weight1*test_std_eval["WT"]["DSC"][l] - weight2*test_std_eval["TC"]["DSC"][l] - weight3*test_std_eval["ET"]["DSC"][l])
loss_sample = np.array(loss_sample)
print(f"Best sample: {np.argmin(loss_sample)}")
# The median lookup assumes an odd number of test samples
print(f"Median sample: {np.argwhere(loss_sample==np.median(loss_sample))[0][0]}")
print(f"Worst sample: {np.argmax(loss_sample)}")
* Fold 1
Best sample: 26
Median sample: 16
Worst sample: 46
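The median lookup above relies on `np.median` returning a value that occurs in loss_sample. This holds for an odd number of test samples, but generally fails for an even number, because `np.median` then averages the two middle values. A sketch of an order-based alternative; `best_median_worst` is a hypothetical helper that, for an even count, picks the upper of the two middle samples and otherwise (odd count, no ties) gives the same indices as the code above.

```python
import numpy as np

def best_median_worst(losses):
    """Indices of the lowest-loss, median-loss and highest-loss samples."""
    order = np.argsort(losses, kind="stable")   # sample indices sorted by loss
    return int(order[0]), int(order[len(order) // 2]), int(order[-1])

print(best_median_worst(np.array([0.9, 0.2, 3.1, 1.4, 0.5])))  # (1, 0, 2)
print(best_median_worst(np.array([0.9, 0.2, 3.1, 1.4])))       # (1, 3, 2)
```

In the notebook this would be called as `best, median, worst = best_median_worst(loss_sample)`.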
# Visualise GradCAM heatmaps
# Adjust these to match the layer names of the model
layers = ['dec1_conv2', 'dec2_conv3', 'upsample_layer_3_conv2', 'dec2_conv1', 'enc3_conv1', 'enc2_conv1']
# (image index, slice index, layer index) for test samples
for im,sl,l in [
    (26,60,2), # Attention and no attention, test, best
    (16,46,2), # No attention, test, median
    (17,80,2), # Attention, test, median
    (46,66,2)  # Attention and no attention, test, worst
]:
    trainer.interpreter(seqs[0][1], index_image=im, index_slice=sl, index_layer=l, vis='HM', layers=layers)
1/1 [==============================] - 20s 20s/step GradCAM for 26-60
1/1 [==============================] - 21s 21s/step GradCAM for 16-46
1/1 [==============================] - 22s 22s/step GradCAM for 17-80
1/1 [==============================] - 20s 20s/step GradCAM for 46-66
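For reference, a Grad-CAM heatmap for a layer weights each of its feature maps by the spatially averaged gradient of a target score with respect to that map, sums the weighted maps and applies a ReLU. The NumPy sketch below assumes the activations and gradients have already been obtained (in TensorFlow, for example, with `tf.GradientTape`); which target score `trainer.interpreter` differentiates for the segmentation output is not shown here, and `grad_cam` is a hypothetical helper. It works for 2D slices and 3D volumes alike, since it averages over all axes except the last (feature) axis.

```python
import numpy as np

def grad_cam(activations, gradients, eps=1e-8):
    """Grad-CAM heatmap from one layer's activations and the gradients of a target score.

    activations, gradients: arrays of shape (*spatial, K) for K feature maps.
    """
    spatial_axes = tuple(range(activations.ndim - 1))
    alpha = gradients.mean(axis=spatial_axes)    # importance weight per feature map, shape (K,)
    cam = np.maximum(activations @ alpha, 0.0)   # ReLU of the weighted sum of feature maps
    return cam / (cam.max() + eps)               # scale to [0, 1] for display

# Toy layer with two feature maps: map 0 fires on the left half, map 1 on the right half
A = np.zeros((4, 4, 2))
A[:, :2, 0] = 1.0
A[:, 2:, 1] = 1.0
G = np.zeros_like(A)
G[..., 0] = 0.5    # increasing map 0 raises the score
G[..., 1] = -0.5   # increasing map 1 lowers it
cam = grad_cam(A, G)
print(cam)         # left half close to 1, right half 0
```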
# Visualise confusion matrix for test set
trainer.interpreter(seqs[0][1], vis='CM')
Model: results/brats2020/MICCAI_BraTS2020_TrainingData/f56_184x56_184x13_141-128_128_128-norm_minmax/LATUPNet-None_WDS_BG-NEC-EDE-ENH-flair_t1ce_t2-seg+0_seg+1_seg+2_seg+4/200-1/8120341116777169704-0.8-0/best
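A voxel-wise confusion matrix over the four output channels can be sketched as follows, assuming each voxel is assigned the channel with the highest predicted probability (argmax) and compared with its ground-truth channel index. The confusion matrix shown by `trainer.interpreter` may be computed differently (for example on thresholded or standardised regions), and `confusion_matrix` here is a hypothetical helper.

```python
import numpy as np

CHANNELS = ["BG", "NEC", "EDE", "ENH"]   # order of output_channels

def confusion_matrix(y_true, y_pred, n_classes):
    """Voxel counts: rows = true channel, columns = predicted channel."""
    idx = n_classes * y_true.ravel() + y_pred.ravel()
    return np.bincount(idx, minlength=n_classes * n_classes).reshape(n_classes, n_classes)

# Toy ground truth and soft predictions for six voxels (channel indices)
y_true = np.array([0, 0, 1, 2, 3, 3])
probs = np.eye(4)[[0, 1, 1, 2, 3, 2]] * 0.7 + 0.1   # hypothetical network output
y_pred = probs.argmax(axis=-1)

cm = confusion_matrix(y_true, y_pred, len(CHANNELS))
print(cm)
# Row-normalised: fraction of each true channel assigned to each predicted channel
print(cm / np.maximum(cm.sum(axis=1, keepdims=True), 1))
```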