BCa Segmentation
LATUP-Net Tests, BraTS2020
This is for testing only. Please only commit notebooks with empty/cleared outputs, and do not commit the trained models (if you do not change the model name, the model directories will have to be deleted, or the results may be inconsistent).
FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee AlwadeeEJ@cardiff.ac.uk, PhD student at Cardiff University
FileCopyrightText: Copyright (C) 2023-2024 Frank C Langbein frank@langbein.org, Cardiff University
License-Identifier: AGPL-3.0-or-later
Dataset: BraTS2020, fixed crop, >=1% lesions
Architecture: LATUP-Net variations for testing
Loss: weighted Dice loss or weighted Dice score with various weights
Output channels: various as specified below
We start by setting up the Jupyter notebook (autoloading to reload updated modules, inline plots, etc.) and some style options that may be useful for display.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
%%html
<style>
.dataframe th { font-size: 11px; }
.dataframe td { font-size: 11px; }
</style>
Dataset Setup
This defines the dataset source, BraTS2020, used for training.
The BraTS2020 dataset must be in the data folder specified below. See README.md for how to set this up.
# We assume https://qyber.black/ca/code-bca is cloned into bca in the top folder
# and the BraTS2020 dataset is available in the folders below (see README.md).
# These have been registered as git submodules (but the BraTS dataset is not
# distributed by us).
from bca.bca.dataset import Dataset
# Fixed-crop BraTS2020 dataset
ds = Dataset(os.path.join("data","brats2020","MICCAI_BraTS2020_TrainingData"),     # Dataset folder - brought in via a git submodule
                                                                                  # (not public, so replace as needed)
             os.path.join("results","brats2020","MICCAI_BraTS2020_TrainingData")) # Results folder
ds.crop((56,184), (56,184), (13,141)) # Fixed centered crop
ds.filter_low_labels("seg",0.01) # Remove patients with non-background segmentation area less than 1%
# Show the cropped dataset info and slices.
print(ds)
ds.browse()
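The fixed crop and the 1% lesion filter can be illustrated on a plain NumPy label volume. This is a minimal sketch, assuming that the crop ranges are half-open [start, stop) index ranges and that the lesion fraction is measured on the cropped volume; `fixed_crop` and `lesion_fraction` are hypothetical helpers, not part of the bca package.

```python
import numpy as np

def fixed_crop(volume, x_range, y_range, z_range):
    """Crop a 3D volume to half-open [start, stop) ranges along each spatial axis (hypothetical helper)."""
    (x0, x1), (y0, y1), (z0, z1) = x_range, y_range, z_range
    return volume[x0:x1, y0:y1, z0:z1]

def lesion_fraction(seg):
    """Fraction of voxels carrying a non-background (non-zero) label (hypothetical helper)."""
    return np.count_nonzero(seg) / seg.size

# Synthetic label volume of the BraTS size 240x240x155 with one cubic "lesion".
seg = np.zeros((240, 240, 155), dtype=np.uint8)
seg[100:130, 100:130, 60:90] = 2

cropped = fixed_crop(seg, (56, 184), (56, 184), (13, 141))
frac = lesion_fraction(cropped)
print(cropped.shape, round(frac, 4), frac >= 0.01)  # (128, 128, 128) 0.0129 True
```

With the notebook's ranges each axis keeps 128 voxels, matching the 128x128x128 network input used later.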
from collections import OrderedDict
from bca.bca.loss import compute_channel_weights
# Specify the output channels of the network in an OrderedDict.
# The order of the keys determines the output channel order.
# Each channel is given by a key (the name of the channel)
# mapped to a (list-of-labels, use-in-loss) pair. The list of
# labels contains the original dataset labels combined into
# the channel. Use-in-loss is a boolean flag indicating whether
# the channel is used in the weights and consequently in the loss function.
output_channels = OrderedDict({
    #"BG": ([0],0), # FIXME: try with background channel
    "WT": ([1,2,4],1),
    "TC": ([1,4],1),
    "ET": ([4],1)
})
# FIXME: try with individual channels
# FIXME: try without normalisation; try with "inverse" mode (?); WDS loss weights, etc.
weights = compute_channel_weights(ds, "seg", output_channels, mode="enet", enet_c=1.22, normalise=True)
for ch in output_channels:
    print(f"{ch} weight: {weights[ch][0]}, voxel count: {weights[ch][1]} for {output_channels[ch][0]}")
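How the output channels map labels to masks, and what an ENet-style weight looks like, can be sketched in NumPy. This assumes (it is not confirmed by the source) that `compute_channel_weights` with `mode="enet"` uses the ENet rule w = 1/ln(c + p), with p the fraction of voxels in the channel, and that `normalise=True` rescales the weights to sum to one; `channel_masks` and `enet_weights` are hypothetical helpers.

```python
import numpy as np
from collections import OrderedDict

def channel_masks(seg, output_channels):
    """Binary mask per output channel: 1 where the voxel label is in the channel's label list."""
    return OrderedDict(
        (name, np.isin(seg, labels).astype(np.float32))
        for name, (labels, _use_in_loss) in output_channels.items()
    )

def enet_weights(masks, c=1.22, normalise=True):
    """Hypothetical ENet-style weights w = 1 / ln(c + p), p = voxel frequency of the channel."""
    counts = {name: int(m.sum()) for name, m in masks.items()}
    total = next(iter(masks.values())).size
    w = {name: 1.0 / np.log(c + counts[name] / total) for name in masks}
    if normalise:
        s = sum(w.values())
        w = {name: v / s for name, v in w.items()}
    # Same (weight, voxel count) layout as weights[ch] in the notebook.
    return {name: (w[name], counts[name]) for name in masks}

output_channels = OrderedDict({"WT": ([1, 2, 4], 1), "TC": ([1, 4], 1), "ET": ([4], 1)})
seg = np.zeros((4, 4, 4), dtype=np.uint8)
seg[0, 0, :] = 1  # 4 necrotic voxels
seg[1, 0, :] = 2  # 4 edema voxels
seg[2, 0, :] = 4  # 4 enhancing voxels
weights = enet_weights(channel_masks(seg, output_channels))
for ch in output_channels:
    print(f"{ch} weight: {weights[ch][0]:.3f}, voxel count: {weights[ch][1]}")
```

Under this rule, rarer channels (here ET) receive larger weights, which counteracts the class imbalance between the large whole-tumour region and the small enhancing region.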
LATUP-Net Model
We first set up the model and then train it.
from bca.bca.latupnet import LATUPNet
from bca.bca.trainer import Trainer
from bca.bca.loss import WeightedDiceLoss
from bca.bca.metric import sDSC
# Note, the model name determines where the model/results are stored,
# so it should be different for different models if we want to preserve them
model = LATUPNet(name="LATUPNet-TEST",
                 attention="SE",
                 loss=WeightedDiceLoss([ (weights[ch][0],n)
                                         for n, ch in enumerate(list(output_channels.keys()))
                                         if output_channels[ch][1] ]),
                 metrics=lambda chs=list(output_channels.keys()): [sDSC(channel=n, name="dsc"+ch) for n, ch in enumerate(chs)]) # FIXME
# Note, metrics must be a lambda expression to create the
# metrics inside the graph or it may fail
# Note, the WeightedDiceLoss can also accept a list of channels
# to combine by averaging the output channels for a Dice score mask,
# e.g.,
# WeightedDiceLoss([ (0.4,[0,1]), (0.6,[2,3]) ])
# This would combine channels 0,1 and 2,3 (these are output indices,
# not labels) by averaging, compute the Dice score for each average,
# and take the 0.4/0.6 weighted Dice loss. This is not supported by
# using the output_channels mechanism above, but can be specified
# separately for a specific output_channels setup, if needed.
#
# The loss can also be WeightedDiceScore, which works similarly but
# applies the weights to the Dice score instead of the loss.
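The weighted Dice loss described in the comments above can be sketched in NumPy, including the averaging of several output channels into one mask. This sketch assumes a smoothed soft Dice score and the form sum_i w_i (1 - DSC_i); with weights summing to one this equals 1 - sum_i w_i DSC_i. The smoothing constant `eps` and the function names are hypothetical, and the bca `WeightedDiceLoss` (which presumably operates on TensorFlow tensors) may differ in detail.

```python
import numpy as np

def soft_dice(p, y, eps=1e-6):
    """Smoothed soft Dice score between a probability map p and a target mask y."""
    inter = np.sum(p * y)
    return (2.0 * inter + eps) / (np.sum(p) + np.sum(y) + eps)

def weighted_dice_loss(P, Y, spec, eps=1e-6):
    """Weighted Dice loss for P, Y of shape (..., channels).

    spec is a list of (weight, channel) or (weight, [channels]) pairs; a list of
    output channels is averaged into a single mask before computing the Dice score.
    """
    loss = 0.0
    for w, ch in spec:
        chs = list(ch) if isinstance(ch, (list, tuple)) else [ch]
        p = np.mean(P[..., chs], axis=-1)
        y = np.mean(Y[..., chs], axis=-1)
        loss += w * (1.0 - soft_dice(p, y, eps))
    return loss

rng = np.random.default_rng(0)
Y = (rng.random((8, 8, 8, 4)) > 0.5).astype(np.float32)
per_channel = [(0.25, 0), (0.25, 1), (0.25, 2), (0.25, 3)]
print(weighted_dice_loss(Y, Y, per_channel))        # perfect prediction
print(weighted_dice_loss(1.0 - Y, Y, per_channel))  # disjoint prediction
print(weighted_dice_loss(Y, Y, [(0.4, [0, 1]), (0.6, [2, 3])]))  # combined channels
```

In this sketch, averaging channels produces soft masks (values 0, 0.5, 1), so even a perfect prediction gives a combined-channel loss above zero; this is one thing to check when comparing against the real implementation.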
# Plot model architecture (or show text summary if text is True)
model.plot((128,128,128,3,len(output_channels)),text=False) # 3D spatial dims, number of inputs (flair, t1ce, t2) and number of output channels
# Parameters determining the dataset sequences
K = 0.8 # 80/20 split
EPOCHS = 200 # 200 epochs
SEED = 8120341116777169704 # Seed for split
BATCH_SIZE = 1 # Batch size for training
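A reproducible 80/20 split of the kind controlled by K and SEED can be sketched as a seeded shuffle followed by a cut. `train_test_split` is a hypothetical helper assuming this shuffle-then-cut strategy; how `Dataset.sequences` actually assigns patients to the splits is not shown here.

```python
import numpy as np

def train_test_split(ids, k=0.8, seed=0):
    """Shuffle ids with a seeded generator and put the first k fraction into training."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(ids))
    n_train = int(round(k * len(ids)))
    return [ids[i] for i in perm[:n_train]], [ids[i] for i in perm[n_train:]]

patients = [f"patient_{i:03d}" for i in range(10)]
train, test = train_test_split(patients, k=0.8, seed=8120341116777169704)
print(len(train), len(test))  # 8 2
```

Fixing the seed makes the split identical across runs, so models trained in different sessions are evaluated on the same held-out patients.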
# Generate the list of keras (train,test)-pair sequences for the dataset split(s)
# given by K. We use 128x128x128 input images with flair, t1ce and t2 channels;
# the output channels specified above are simply used here.
seqs = ds.sequences(K, (128,128,128),
                    ["flair", "t1ce", "t2"],
                    ["seg+"+"+".join([str(n) for n in output_channels[ch][0]]) for ch in output_channels.keys()],
                    batch_size=BATCH_SIZE,
                    pre_proc=Dataset.norm_minmax,
                    seed=SEED)
# Show training sequence of first fold
seqs[0][0].browse()
# Setup trainer for the model/data
trainer = Trainer(model, epochs=EPOCHS)
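The target strings built for `ds.sequences`, such as "seg+1+2+4", appear to name the segmentation image followed by the labels merged into that channel. The sketch below shows how such a string could be turned into a binary target stack, plus a min-max normalisation in the spirit of `Dataset.norm_minmax`. `parse_label_spec`, `target_stack` and `norm_minmax` are hypothetical stand-ins, and whether the real pre-processing normalises per channel or per volume is an assumption.

```python
import numpy as np

def parse_label_spec(spec):
    """Split e.g. "seg+1+2+4" into the image name "seg" and the labels [1, 2, 4]."""
    name, *labels = spec.split("+")
    return name, [int(label) for label in labels]

def target_stack(seg, specs):
    """Stack one binary mask per spec along a new last (channel) axis."""
    masks = [np.isin(seg, parse_label_spec(s)[1]) for s in specs]
    return np.stack(masks, axis=-1).astype(np.float32)

def norm_minmax(x):
    """Linearly rescale x to [0, 1]; a constant image is mapped to zeros."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo) if hi > lo else np.zeros_like(x)

seg = np.array([[[0, 1], [2, 4]], [[4, 4], [0, 2]]])
T = target_stack(seg, ["seg+1+2+4", "seg+1+4", "seg+4"])
print(T.shape)  # (2, 2, 2, 3)
```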
We are ready to train the models.
For remote execution, set remote=True in the train call; the scheduling code that follows can then be used to schedule the tasks. Note, this requires the remote schedulers to be defined in cfg.json for the bca package.
# Train the networks
# Specifying loss,loss for the best monitor uses the test loss alone to store the best model
results = trainer.train(seqs, jit_compile=True, remote=False)
# Schedule training tasks
from bca.bca.scheduler import schedule, schedule_clean
task_folder="results"
schedule_clean(task_folder=task_folder) # Clean up failed tasks (assuming the issues have been fixed)
schedule(task_folder=task_folder) # Schedule tasks
Evaluation
This evaluates the trained models locally: every completely trained model that has not been evaluated yet is evaluated, and everything else (including incompletely trained models) is ignored.
# Evaluate model, including standardizing the evaluation as different models may have different outputs
import numpy as np
# Output standardisation functions
def std_eval_BG_WT_TC_ET(P, Y):
    # Convert prediction and trained output to standardized output.
    # Input: P - prediction, Y - ground truth.
    # Output: dictionary mapping names to (prediction, expected) pairs, computed
    # from a single (prediction, expected) pair.
    # For BraTS2020:
    #   whole      1,2,4
    #   necrotic   1
    #   enhancing  4
    #   edema      2
    #   core       1,4
    # Here we assume the output channels are BG, WT, TC, ET.
    # Note, the first [0] is the first output of the network (in case it has multiple outputs).
    # The second set of indices is [SAMPLE,data-axes...,output-channel] where SAMPLE=0 (only
    # one, but needed for the batch axis in the metrics functions).
    # FIXME: threshold the prediction mask?
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return {
        'WT': [P[0][...,1], Y[0][...,1]],
        'TC': [P[0][...,2], Y[0][...,2]],
        'ET': [P[0][...,3], Y[0][...,3]]
    }
def std_eval_WT_TC_ET(P, Y):
    # Convert prediction and trained output to standardized output.
    # Here we assume the output channels are WT, TC, ET.
    # FIXME: threshold the prediction mask?
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return {
        'WT': [P[0][...,0], Y[0][...,0]],
        'TC': [P[0][...,1], Y[0][...,1]],
        'ET': [P[0][...,2], Y[0][...,2]]
    }
def std_eval_BG_NEC_EDE_ENH(P, Y):
    # Convert prediction and trained output to standardized output.
    # Here we assume the output channels are BackGround, NECrotic, EDEma, ENHancing.
    # As regions are combined, we threshold the predicted masks and clip the sums (FIXME: good idea or not?)
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return {
        'WT': [np.clip(P[0][...,1]+P[0][...,2]+P[0][...,3],0,1), np.clip(Y[0][...,1]+Y[0][...,2]+Y[0][...,3],0.0,1.0)],
        'TC': [np.clip(P[0][...,1]            +P[0][...,3],0,1), np.clip(Y[0][...,1]            +Y[0][...,3],0.0,1.0)],
        'ET': [np.clip(                        P[0][...,3],0,1), np.clip(                        Y[0][...,3],0.0,1.0)]
    }
def std_eval_NEC_EDE_ENH(P, Y):
    # Convert prediction and trained output to standardized output.
    # Here we assume the output channels are NECrotic, EDEma, ENHancing.
    # As regions are combined, we threshold the predicted masks and clip the sums (FIXME: good idea or not?)
    P[0] = np.where(P[0] >= 0.5, 1.0, 0.0).astype(np.float32)
    return {
        'WT': [np.clip(P[0][...,0]+P[0][...,1]+P[0][...,2],0,1), np.clip(Y[0][...,0]+Y[0][...,1]+Y[0][...,2],0.0,1.0)],
        'TC': [np.clip(P[0][...,0]            +P[0][...,2],0,1), np.clip(Y[0][...,0]            +Y[0][...,2],0.0,1.0)],
        'ET': [np.clip(                        P[0][...,2],0,1), np.clip(                        Y[0][...,2],0.0,1.0)]
    }
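The std_eval functions rebuild the BraTS regions (WT = {1,2,4}, TC = {1,4}, ET = {4}) from the network's channels. The sketch below checks, on a synthetic label map, that the clipped sums used in `std_eval_NEC_EDE_ENH` reproduce the label-based regions when the NEC/EDE/ENH channels are mutually exclusive; for thresholded predictions, where channels may overlap, the clipping keeps the combined mask binary. `REGIONS`, `regions_from_labels` and `regions_from_nec_ede_enh` are hypothetical names.

```python
import numpy as np

# BraTS2020 labels: 1 = necrotic, 2 = edema, 4 = enhancing.
REGIONS = {"WT": [1, 2, 4], "TC": [1, 4], "ET": [4]}

def regions_from_labels(seg):
    """Region masks taken directly from the label map."""
    return {r: np.isin(seg, labels).astype(np.float32) for r, labels in REGIONS.items()}

def regions_from_nec_ede_enh(Y):
    """Same combination as std_eval_NEC_EDE_ENH, for a (..., 3) NEC/EDE/ENH mask stack."""
    return {
        "WT": np.clip(Y[..., 0] + Y[..., 1] + Y[..., 2], 0.0, 1.0),
        "TC": np.clip(Y[..., 0] + Y[..., 2], 0.0, 1.0),
        "ET": np.clip(Y[..., 2], 0.0, 1.0),
    }

rng = np.random.default_rng(0)
seg = rng.choice([0, 1, 2, 4], size=(16, 16, 16))
Y = np.stack([seg == label for label in (1, 2, 4)], axis=-1).astype(np.float32)
a, b = regions_from_labels(seg), regions_from_nec_ede_enh(Y)
print({r: bool(np.array_equal(a[r], b[r])) for r in REGIONS})
```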
# Compute performance metrics
from bca.bca.trainer import dsc, hd95, sensitivity, specificity
# Evaluate global and per-sample using standardised output.
# This is slow as it collects overall, per-channel and standardised channel
# performance metrics. Note, adjust std_eval according to output_channels (use
# one of the functions above, or write a new one if needed).
trainer.eval(seqs, mode="best", fs=[dsc,hd95,sensitivity,specificity], std_eval=std_eval_WT_TC_ET)
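For reference, the overlap metrics passed in `fs` are commonly defined on binary masks via true/false positives and negatives. This NumPy sketch uses its own hypothetical function names and assumes the conventions that two empty masks have a Dice score of 1 and that an undefined sensitivity or specificity is 1; the bca implementations of `dsc`, `sensitivity` and `specificity` may handle such edge cases differently.

```python
import numpy as np

def confusion_counts(p, y):
    """True positives, false positives, false negatives, true negatives of binary masks."""
    p, y = np.asarray(p, dtype=bool), np.asarray(y, dtype=bool)
    return (int(np.sum(p & y)), int(np.sum(p & ~y)),
            int(np.sum(~p & y)), int(np.sum(~p & ~y)))

def dice_score(p, y):
    tp, fp, fn, _ = confusion_counts(p, y)
    return 1.0 if tp + fp + fn == 0 else 2 * tp / (2 * tp + fp + fn)

def sensitivity_score(p, y):
    tp, _, fn, _ = confusion_counts(p, y)
    return tp / (tp + fn) if tp + fn > 0 else 1.0

def specificity_score(p, y):
    _, fp, _, tn = confusion_counts(p, y)
    return tn / (tn + fp) if tn + fp > 0 else 1.0

y = np.zeros(10); y[:4] = 1   # ground truth: voxels 0-3
p = np.zeros(10); p[2:6] = 1  # prediction: voxels 2-5
print(dice_score(p, y), sensitivity_score(p, y), specificity_score(p, y))  # 0.5 0.5 0.6666666666666666
```

Specificity is dominated by the large background, so it tends to be close to 1 for brain MRI volumes; Dice and sensitivity are usually the more informative of the three.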
# Save model architecture
trainer.plot_model(seqs, save_only=True)
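HD95, the remaining metric, is based on the 95th percentile of surface-to-surface distances between predicted and reference masks. The brute-force NumPy sketch below extracts 6-connected surfaces and takes the larger of the two directed 95th percentiles; this is one common convention (others pool both directions before taking the percentile). Distances are in voxel units unless `spacing` is given, `surface` and `hd95` are hypothetical names, and the pairwise distance matrix makes it practical only for small volumes.

```python
import numpy as np

def surface(mask):
    """Foreground voxels of a 3D mask with at least one 6-neighbour in the background."""
    m = np.pad(np.asarray(mask, dtype=bool), 1, constant_values=False)
    core = m[1:-1, 1:-1, 1:-1]
    interior = core.copy()
    for axis in range(3):
        for shift in (-1, 1):
            # The shifted copy gives each voxel's neighbour along this axis/direction.
            interior &= np.roll(m, shift, axis=axis)[1:-1, 1:-1, 1:-1]
    return core & ~interior

def hd95(p, y, spacing=(1.0, 1.0, 1.0)):
    """Max of the two directed 95th-percentile surface distances (brute force)."""
    sp = np.argwhere(surface(p)) * np.asarray(spacing)
    sy = np.argwhere(surface(y)) * np.asarray(spacing)
    if len(sp) == 0 or len(sy) == 0:
        return float("inf")  # undefined for empty masks; conventions vary
    d = np.sqrt(((sp[:, None, :] - sy[None, :, :]) ** 2).sum(axis=-1))
    return float(max(np.percentile(d.min(axis=1), 95), np.percentile(d.min(axis=0), 95)))

p = np.zeros((10, 10, 10), dtype=bool); p[2:6, 2:6, 2:6] = True
q = np.zeros((10, 10, 10), dtype=bool); q[3:7, 2:6, 2:6] = True  # p shifted by one voxel
print(hd95(p, p), hd95(p, q))  # 0.0 1.0
```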
Results
This presents the results of the training/evaluation.
Note that below, the "prime" values (DSC', val_DSC', etc.) are per-sample metrics (which are then averaged, etc.), while the others are the metrics used for training (per-batch values, which are then averaged, etc.). The STD values are the standardised metrics computed on the standardised outputs of the std_eval function above, which makes the values more comparable (for the dataset they were computed on). They are also based on per-sample values.
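Averaging per-sample values and computing a value over a whole batch can differ substantially when a batch holds several samples with very different lesion sizes. A small NumPy illustration on synthetic masks, using a hypothetical `dice` helper; with BATCH_SIZE = 1, as set above, a batch is a single sample, so this particular effect does not arise there.

```python
import numpy as np

def dice(p, y):
    """Dice score of binary arrays (assumes at least one of them is non-empty)."""
    return 2.0 * np.sum(p * y) / (np.sum(p) + np.sum(y))

# Synthetic two-sample batch: a large lesion predicted perfectly, a small one missed.
y1 = np.zeros(100); y1[:50] = 1; p1 = y1.copy()
y2 = np.zeros(100); y2[:2] = 1
p2 = np.zeros(100); p2[50:52] = 1

per_sample = np.mean([dice(p1, y1), dice(p2, y2)])
pooled = dice(np.concatenate([p1, p2]), np.concatenate([y1, y2]))
print(per_sample, round(pooled, 4))  # 0.5 0.9615
```

Per-sample averaging weights every patient equally, so a missed small lesion pulls the score down, whereas pooling is dominated by the large lesion.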
import numpy as np
import pandas as pd
from IPython import display
pd.set_option('display.max_columns', None)
# Show performance results
results = trainer.plot_results(seqs)
# Plot the prediction results (runs tensorflow in a sub-process, so that it can be stopped
# easily; to restart, run cell again; only one can run at the same time, unless there are
# enough GPU resources).
trainer.browse_predict(seqs)