Multiomics Cancer Classification#
In this tutorial, we will use the Multi-Omics Graph cOnvolutional NETworks (MOGONET) pipeline by Wang et al. (Nature Communications, 2021) [1], implemented in PyKale [2], to integrate patient multiomics data for cancer classification.
We will work with multiomics data from the breast invasive carcinoma (BRCA) collection of TCGA [3], whose five subtypes serve as the classification labels. Three omics modalities will be used: mRNA expression, DNA methylation, and miRNA expression.
The multimodal approach used in this tutorial is late fusion: a cross-omics tensor is constructed to fuse the prediction probabilities from the three omics modalities.
The main tasks of this tutorial are:
Load BRCA dataset.
Define a MOGONET model.
Train and evaluate the MOGONET model on the multiomics data.
Obtain the feature importance and visualize the interpretation of the model.
Step 0: Environment Preparation#
As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.
To prepare the helper functions and necessary materials, we download them from the GitHub repository.
Package Installation#
The main package required for this tutorial is PyKale. PyKale is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.
Then, we install PyG (PyTorch Geometric) and related packages.
[Estimated running time] 3 mins
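A minimal sketch of what the installation cell might look like (assumed commands, not necessarily the tutorial's exact cell; the appropriate PyG wheels depend on your PyTorch and CUDA versions):
!pip install pykale
!pip install torch_geometric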
We then hide warning messages to keep the output clear.
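For example, the standard library warnings module can be used for this (a minimal sketch of the step):
import warnings

warnings.filterwarnings("ignore")  # suppress warning messages for cleaner output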
Configuration#
To minimize the footprint of the notebook when specifying configurations, we provide a config.py file that defines default parameters. These can be customized by supplying a .yaml configuration file, such as configs/BRCA.yaml.
First, we load the configuration from configs/BRCA.yaml.
from config import get_cfg_defaults
cfg = get_cfg_defaults()
cfg.merge_from_file("configs/BRCA.yaml")
In this tutorial, we list the hyperparameters we would like users to play with outside the .yaml file:
cfg.SOLVER.MAX_EPOCHS_PRETRAIN: Number of epochs in the pre-training stage.
cfg.SOLVER.MAX_EPOCHS: Number of epochs in the training stage.
cfg.DATASET.NUM_MODALITIES: Number of modalities in the pipeline.
1: mRNA expression.
2: mRNA expression + DNA methylation.
3: mRNA expression + DNA methylation + miRNA expression.
[NOTE] Because this tutorial aims to demonstrate the PyKale pipeline, we only set cfg.SOLVER.MAX_EPOCHS_PRETRAIN=100 and cfg.SOLVER.MAX_EPOCHS=500 to reduce the training time.
If you are interested, please increase them to get more accurate predictions.
cfg.SOLVER.MAX_EPOCHS_PRETRAIN = 100
cfg.SOLVER.MAX_EPOCHS = 500
cfg.DATASET.NUM_MODALITIES = 3
Print hyperparameters:
print(cfg)
DATASET:
NAME: TCGA_BRCA
NUM_CLASSES: 5
NUM_MODALITIES: 3
RANDOM_SPLIT: False
ROOT: dataset/
URL: https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip
MODEL:
EDGE_PER_NODE: 10
EQUAL_WEIGHT: False
GCN_DROPOUT_RATE: 0.5
GCN_HIDDEN_DIM: [400, 400, 200]
GCN_LR: 0.0005
GCN_LR_PRETRAIN: 0.001
VCDN_LR: 0.001
OUTPUT:
OUT_DIR: ./outputs
SOLVER:
MAX_EPOCHS: 500
MAX_EPOCHS_PRETRAIN: 100
SEED: 2023
Step 1: Data Loading and Preparation#
We use the BRCA multiomics benchmark in this tutorial, which has been provided by the authors of the MOGONET paper in their repository.
If users are interested in more details regarding data organization, downloading, loading, and pre-processing, please refer to the Data page of the tutorial.
Delete any existing data so that a fresh copy is downloaded:
!rm -rf dataset/
To load the data, we first define a list of the data file names:
file_names = []
for modality in range(1, cfg.DATASET.NUM_MODALITIES + 1):
file_names.append(f"{modality}_tr.csv")
file_names.append(f"{modality}_lbl_tr.csv")
file_names.append(f"{modality}_te.csv")
file_names.append(f"{modality}_lbl_te.csv")
file_names.append(f"{modality}_feat_name.csv")
Then, we download, load, and pre-process the data with PyKale.
[Estimated running time] 20s
import torch
from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset
from kale.prepdata.tabular_transform import ToOneHotEncoding, ToTensor
multiomics_data = SparseMultiomicsDataset(
    root=cfg.DATASET.ROOT,
    raw_file_names=file_names,
    num_modalities=cfg.DATASET.NUM_MODALITIES,
    num_classes=cfg.DATASET.NUM_CLASSES,
    edge_per_node=cfg.MODEL.EDGE_PER_NODE,
    url=cfg.DATASET.URL,
    random_split=cfg.DATASET.RANDOM_SPLIT,
    equal_weight=cfg.MODEL.EQUAL_WEIGHT,
    pre_transform=ToTensor(dtype=torch.float),
    target_pre_transform=ToOneHotEncoding(dtype=torch.float),
)
Downloading https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip
Extracting dataset/raw/TCGA_BRCA.zip
Processing...
Done!
Inspect the dataset:
print(multiomics_data)
Dataset info:
number of modalities: 3
number of classes: 5
modality | total samples | num train | num test | num features
-----------------------------------------------------------------
1 | 875 | 612 | 263 | 1000
2 | 875 | 612 | 263 | 1000
3 | 875 | 612 | 263 | 503
-----------------------------------------------------------------
Step 2: Model Definition#
If users are interested in more details regarding the model, please refer to the Helper Function and Model Definition page of the tutorial.
To initialize the model, we first call MogonetModel from model.py.
from model import MogonetModel
mogonet_model = MogonetModel(cfg, dataset=multiomics_data)
Visualize the model architecture:
print(mogonet_model)
Model info:
Unimodal encoder:
(1) MogonetGCN(
(conv1): MogonetGCNConv(1000, 400)
(conv2): MogonetGCNConv(400, 400)
(conv3): MogonetGCNConv(400, 200)
) (2) MogonetGCN(
(conv1): MogonetGCNConv(1000, 400)
(conv2): MogonetGCNConv(400, 400)
(conv3): MogonetGCNConv(400, 200)
) (3) MogonetGCN(
(conv1): MogonetGCNConv(503, 400)
(conv2): MogonetGCNConv(400, 400)
(conv3): MogonetGCNConv(400, 200)
)
Unimodal decoder:
(1) LinearClassifier(
(fc): Linear(in_features=200, out_features=5, bias=True)
) (2) LinearClassifier(
(fc): Linear(in_features=200, out_features=5, bias=True)
) (3) LinearClassifier(
(fc): Linear(in_features=200, out_features=5, bias=True)
)
Multimodal decoder:
VCDN(
(model): Sequential(
(0): Linear(in_features=125, out_features=125, bias=True)
(1): LeakyReLU(negative_slope=0.25)
(2): Linear(in_features=125, out_features=5, bias=True)
)
)
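Note that the VCDN input dimension of 125 comes from the cross-omics tensor it operates on: with 5 classes and 3 modalities, the tensor of per-class probability products has 5^3 = 125 entries per sample. A quick check:
# The VCDN input size equals num_classes ** num_modalities (5 ** 3 = 125 here).
print(cfg.DATASET.NUM_CLASSES ** cfg.DATASET.NUM_MODALITIES)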
Step 3: Model Training#
Pretrain Unimodal Encoders#
Before training the multiomics model, we first pretrain encoders for each modality independently. This step helps each GCN encoder learn a good representation of its respective modality before integration.
We define the trainer for the pretraining stage:
import pytorch_lightning as pl
network = mogonet_model.get_model(pretrain=True)
trainer_pretrain = pl.Trainer(
    max_epochs=cfg.SOLVER.MAX_EPOCHS_PRETRAIN,
    default_root_dir=cfg.OUTPUT.OUT_DIR,
    accelerator="auto",
    devices="auto",
    enable_model_summary=False,
)
We pretrain the model:
[Estimated running time] 15s for 100 epochs
trainer_pretrain.fit(network)
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 99: 100%|██████████| 1/1 [00:00<00:00, 57.69it/s, v_num=0]
`Trainer.fit` stopped: `max_epochs=100` reached.
Epoch 99: 100%|██████████| 1/1 [00:00<00:00, 19.96it/s, v_num=0]
Train the Multimodal Model#
After pretraining the unimodal pathways, we now train the full MOGONET model by enabling the VCDN (View Correlation Discovery Network). In this stage, all modality-specific encoders and the VCDN are trained.
We define the trainer for the multimodal training stage:
network = mogonet_model.get_model(pretrain=False)
trainer = pl.Trainer(
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    default_root_dir=cfg.OUTPUT.OUT_DIR,
    accelerator="auto",
    devices="auto",
    enable_model_summary=False,
    log_every_n_steps=1,
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
We start the multimodal training:
[Estimated running time] 1 min for 500 epochs
trainer.fit(network)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 499: 100%|██████████| 1/1 [00:00<00:00, 42.63it/s, v_num=1]
`Trainer.fit` stopped: `max_epochs=500` reached.
Epoch 499: 100%|██████████| 1/1 [00:00<00:00, 12.46it/s, v_num=1]
Step 4: Evaluation#
Once training is complete, we evaluate the model on the test set using trainer.test().
trainer.test(network)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 37.12it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         Accuracy          │     0.824999988079071     │
│         F1 macro          │           0.75            │
│        F1 weighted        │    0.8050000071525574     │
└───────────────────────────┴───────────────────────────┘
[{'Accuracy': 0.824999988079071,
'F1 weighted': 0.8050000071525574,
'F1 macro': 0.75}]
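Since trainer.test() returns the metrics as a list of dictionaries (one per test dataloader), you can also keep the numbers programmatically, for example:
results = trainer.test(network, verbose=False)
print(f"Test accuracy: {results[0]['Accuracy']:.3f}")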
Step 5: Interpretation Study#
We use kale.interpret to perform the interpretation study. It provides a function that systematically masks input features and observes the effect on performance, highlighting which features are most important for classification. Please refer to the Interpretation Study page for more details.
Because the interpretation study requires masking one feature at a time and observing the resulting performance drop, we first define the trainer for the interpretation experiments.
[NOTE] The final results may differ from the expected ones because we only train the model for a few epochs to reduce the waiting time in this tutorial.
from kale.interpret.model_weights import select_top_features_by_masking
import pytorch_lightning as pl
trainer_biomarker = pl.Trainer(
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    accelerator="auto",
    devices="auto",
    enable_progress_bar=False,
)
Then, we start the experiment.
To suppress the verbose messages in the following experiments:
import logging
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
Run the interpretation experiments:
[Estimated running time] Because the following block runs the model 2,503 times on the BRCA dataset (once per feature across the three modalities: 1000 + 1000 + 503), it may take about 6 minutes.
f1_key = "F1" if multiomics_data.num_classes == 2 else "F1 macro"
df_featimp_top = select_top_features_by_masking(
    trainer=trainer_biomarker,
    model=network,
    dataset=multiomics_data,
    metric=f1_key,
    num_top_feats=30,
    verbose=False,
)
Print the most important features:
print("{:>4}\t{:<20}\t{:>5}\t{}".format("Rank", "Feature name", "Omics", "Importance"))
for rank, row in enumerate(df_featimp_top.itertuples(index=False), 1):
    print(f"{rank:>4}\t{row.feat_name:<20}\t{row.omics:>5}\t{row.imp:.4f}")
Rank Feature name Omics Importance
1 FABP7|2173 0 28.0000
2 WNT6|7475 0 28.0000
3 MIA|8190 0 28.0000
4 CRHR1|1394 0 27.0000
5 SOX10|6663 0 27.0000
6 SERPINB5|5268 0 27.0000
7 GABRP|2568 0 27.0000
8 TMEM207 1 25.0000
9 KRT6B|3854 0 25.0000
10 PTX3|5806 0 25.0000
11 SCRG1|11341 0 24.0000
12 FLJ41941 1 23.0000
13 OR1J4 1 23.0000
14 GPR37L1 1 23.0000
15 SLC7A8|23428 0 22.0000
16 LEMD1|93273 0 20.0000
17 PGBD5|79605 0 19.0000
18 BPI|671 0 19.0000
19 CRIP1|1396 0 19.0000
20 ATP1B3|483 0 19.0000
21 BBOX1|8424 0 18.0000
22 FOXC1|2296 0 17.0000
23 SOX21 1 17.0000
24 C10orf116|10974 0 17.0000
25 hsa-mir-29b-1 2 16.0960
26 CXCL3|2921 0 16.0000
27 FGFBP1|9982 0 16.0000
28 hsa-mir-944 2 14.0840
29 hsa-mir-9-2 2 14.0840
30 hsa-mir-9-1 2 14.0840
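The omics column indexes the modalities in the order they were loaded (0: mRNA expression, 1: DNA methylation, 2: miRNA expression). To visualize the ranking, a simple bar chart can be drawn from df_featimp_top. Below is a minimal sketch, assuming matplotlib is installed and that df_featimp_top is a pandas-style data frame with the feat_name and imp columns used in the print loop above:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 8))
ax.barh(df_featimp_top["feat_name"], df_featimp_top["imp"])
ax.invert_yaxis()  # put the highest-ranked feature at the top
ax.set_xlabel("Importance")
ax.set_title("Top 30 features selected by masking")
plt.tight_layout()
plt.show()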
References#
[1] Wang, T., Shao, W., Huang, Z., Tang, H., Zhang, J., Ding, Z., & Huang, K. (2021). MOGONET integrates multi-omics data using graph convolutional networks allowing patient classification and biomarker identification. Nature Communications, 12(1), 3445.
[2] Lu, H., Liu, X., Zhou, S., Turner, R., Bai, P., Koot, R. E., … & Xu, H. (2022, October). PyKale: Knowledge-aware machine learning from multiple sources in Python. In Proceedings of the 31st ACM International Conference on Information & Knowledge Management (pp. 4274-4278).
[3] Lingle, W., Erickson, B. J., Zuley, M. L., Jarosz, R., Bonaccio, E., Filippini, J., Net, J. M., Levi, L., Morris, E. A., Figler, G. G., Elnajjar, P., Kirk, S., Lee, Y., Giger, M., & Gruszauskas, N. (2016). The Cancer Genome Atlas Breast Invasive Carcinoma Collection (TCGA-BRCA) (Version 3) [Data set]. The Cancer Imaging Archive.