diff --git a/docker/ReadME.md b/docker/ReadME.md new file mode 100644 index 0000000..1b66efc --- /dev/null +++ b/docker/ReadME.md @@ -0,0 +1,185 @@ +# Docker Submission Instructions + +*Useful tip: view the page outline / table of contents by clicking the icon shown in image below* + +![width=200](media/show-page-outline.png) + +### Important Notes +* Please contact the organizers using [email](mailto:mkfmelbatel@connect.ust.hk) if you have any questions regarding this step of the challenge. We understand that this may be new to some participants and we would like to help resolve any issues faced. + +* These instructions have been created referring to the instructions from [Syn-ISS 2023](https://www.synapse.org/#!Synapse:syn50908388/wiki/621840) and [SimCol-to-3D 2022](https://www.synapse.org/#!Synapse:syn28548633/wiki/617244) challenge. + +* In order to be considered as a valid submission for the TriALS challenge the participant is required to submit both the Docker image and their writeup. + +* To upload any files to Synapse you must be a certified user. See this link for details on [how to get certified](https://help.synapse.org/docs/Synapse-User-Account-Types.2007072795.html#SynapseUserAccountTypes-CertifiedUser). + +### Overview +This document covers the details related to creating the docker images for submitting to the TriALS sub-challenge. + +### What is the Docker image supposed to do +The docker image should read the test data identifiers from a CSV file and look for the corresponding images in the input directory specified when running the docker container. The docker container should write the predicted masks as image files to the specified output directory as well. + +The docker templates already provide a Dockerfile to do this. Below are instructions about using the provided docker templates: (1) where the participant's should place their code, and (2) how they can generate a docker image for submission. The same instructions apply for task 1 and task 2. + +## Editing the Docker templates + +### Download the Docker templates +The docker templates can be obtained in one of the following ways: +1. Downloading `.zip` archives of the templates from the [latest release](https://github.com/xmed-lab/TriALS/releases/latest). +2. Downloading the entire GitHub repository locally. Please follow [this URL](https://github.com/xmed-lab/TriALS.git) to get the repository files. +3. Cloning the GitHub respository using a Git client. + +### How does the code flow +Before editing the docker template here is information about how the Docker template executes. +1. The `Dockerfile` is set up to launch the `run.sh` file as the entry point. Any other entrypoint will be ignored. +2. The `run.sh` script expects three inputs: + * a path to csv file containing list of test image identifiers, + * a path to a folder where the test image files can be found, and + * a path to a folder where the image files containing the predicted masks will be written to. +3. The `run.sh` calls the Python script `main.py` passing along these three input parameters to it. +4. The `main.py` script imports the segmentation model `MySegmentation` from the `Model.py` file. +6. The `Model.py` file contains the segmentation functionality to take an input image and produce a segmentation mask for the current task. + +**Note:** We provide a sample model trained under `nnUNet_results` (located at `template/src/nnUNet_results`). You can download it from [this link](https://drive.google.com/drive/folders/1G53ttrukdTpdQLIgsW55VZbb_1adoD8g?usp=sharing) and place it in `template/src/nnUNet_results`. + +``` +├── nnUNet_results/ +│ ├── Dataset102_TriALS/ +│ │ ├── nnUNetTrainer__nnUNetPlans__3d_fullres +│ │ │ ├── ... +``` + +Now, let us edit the Docker template files. + +### Update the Dockerfile (optional) +Please update the `Dockerfile` to specify any base image that your code needs like PyTorch, Tensorflow, NVidia. +This done by adding a [`FROM` instruction](https://docs.docker.com/engine/reference/builder/#from). +For example, this is how the docker container can be instructed to use the tensorflow base image. +```Docker +FROM tensorflow/tensorflow +``` +A catalog of base images can be found in the [Docker Hub](https://hub.docker.com/search?image_filter=official&q=&type=image). + +### Insert segmentation code in Model.py +Please insert your model related code in the Python script `Model.py` within the `MySegmentation` class. +Commented blocks specifying the region in the code flow of the `segment()` function are provided in the template. Here is an example of such a comment block: + + +## Creating the Docker images +Now that the docker template has been updated to include your model related changes in it, the following instructions will guide you in creating a docker image that you need to submit to the challenge. + +### Have you done the following +* Set up Docker on your machine. Please refer to the Docker Guide for instructions on how to [get started with Docker](https://docs.docker.com/get-docker/). +* Downloaded the Docker template. These are located under the `docker/template` folder. +* Updated the files in the Docker template following the [instructions](#editing-the-docker-templates). + +Please finish the above listed tasks before proceeding further. + +### Building the Docker image + +Build the docker image by following these steps: +* Open a command line tool. +* Browse to the directory where the `Dockerfile` is located. +* Run the following command to build the image (please check that you have included the `.` at the end of the command). +```Docker +$ docker build -t . +``` +where,
+`image-name` is the name to be given to the docker image created. + +The docker image must be named using the following format. +``` +trials--: +``` +where,
+`task-name` is either "task1" or "task2",
+`team-name` is the team abbreviation that was provided during registration,
+`version` is the version number of the image starting from `v1` and increasing as you submit newer versions of the image.
+ +Note: the highest version number tagged image will be used for the final evaluation of your model. + +As an example, a team named "medhacker" submitting a second version of their Docker image for the binary segmentation task "task1" must name their Docker image as `trials-task1-medhacker:v2`. + + +### Testing Docker image +It is recommended that you verify the docker image built to ensure it works as intended before submitting to the challenge. +Sample volume are available to do this with proper file and folder names that the organizers will use for evaluating the submissions using the test dataset. +The sample images are located in the Docker template folders within a subfolder named `sample-test-data`. +You can test by running the docker image using the following command in a command line tool. +```bash +$ docker run -it --rm -v ":/data" /data/test.csv /data/inputs /data/predictions +``` +where,
+`path-to-sample-test-data` is the location of the sample test data folder on the machine that is being used to test the Docker image,
+`image-name` is the name of the Docker image being tested. + +## Submitting the Docker images + +At any point in the following steps if more information related to Synapse is needed then refer to the [Docker Registry documentation page](https://help.synapse.org/docs/Synapse-Docker-Registry.2011037752.html). + +### Create a Synapse project +To submit files to a challenge on Synapse you need to create a Synapse project first. The project must be named using the challenge name and team names as shown below. +``` +TriALS-MedHacker +``` +The Synapse documentation can be referred to [create a project](https://help.synapse.org/docs/Setting-Up-a-Project.2055471258.html#SettingUpaProject-CreatingaProject). + +Please add the team named [`TriALS 2024 Admin`](https://www.synapse.org/Team:3491688) to the project and give them "Download" permissions. Follow the documentation on how to [share a project](https://help.synapse.org/docs/Sharing-Settings,-Permissions,-and-Conditions-for-Use.2024276030.html#SharingSettings,Permissions,andConditionsforUse-EditSharingSettingsonaProject). + +### Login to Synapse in Docker +* Type the following in a command line tool to login to synapse using docker. +```bash +$ docker login -u docker.synapse.org +``` +* Enter your synapse account password when prompted. + +### Tag the Docker image +This step requires your new project's Synapse ID. This can be found by looking at the web URL for the project page. For example, the Synapse ID of the project at the URL https://www.synapse.org/#!Synapse:syn150935 is `syn150935`. + +Type the following in a command line tool to tag the docker image before uploading to Synapse. +```bash +$ docker tag docker.synapse.org// +``` +where,
+`image-name` is the name of the Docker image being prepared for submission to the challenge,
+`synapse-project-ID` is the Synapse ID of your project that is being used to submit to the TriALS challenge, e.g., syn150935. + +### Push the Docker image to Synapse +Type the following in a command line tool to push the tagged local Docker image so that it appears in your Synapse project. +```bash +$ docker push docker.synapse.org// +``` +where,
+`synapse-project-ID` is the Synapse ID of your project that is being used to submit to the TriALS challenge, e.g., syn150935,
+`image-name` is the name of the Docker image being prepared for submission to the challenge. + +> [!IMPORTANT] +> This command will fail if your Synapse user account is not a certified user account. See this link for details on [how to get certified](https://help.synapse.org/docs/Synapse-User-Account-Types.2007072795.html#SynapseUserAccountTypes-CertifiedUser). + +### Verify Docker image on Synapse +The Docker images for a project appear under the Docker navigation tab of the project. See the example image below. + +![](media/project-docker-registry-view.png) + +### Submit Docker image to challenge +* Under the Docker tab of your Synapse project click the Docker image that you want to submit to the challenge. + +* Click the `Docker Repository Tools` button and select `Submit Docker Repository to Challenge` in the menu. See reference image below. + +![](media/docker-image-submit-to-challenge.png) + +* Select the version that you want to submit. + +* On the next page, select the challenge task that you want to submit the Docker image to. See image below. + +
+ warmup.png +
+ + +* Then, select the option: `I am submitting as an individual`.
+Ignore the team submission option even though you are part of a team. The organizers have the information about the team through the email registration process. + +* You will receive a confirmation email once the docker submission has been validated by the organizers. + +Thank you! \ No newline at end of file diff --git a/docker/media/docker-image-select-evaluation-queue.png b/docker/media/docker-image-select-evaluation-queue.png new file mode 100644 index 0000000..38d1838 Binary files /dev/null and b/docker/media/docker-image-select-evaluation-queue.png differ diff --git a/docker/media/docker-image-submit-to-challenge.png b/docker/media/docker-image-submit-to-challenge.png new file mode 100644 index 0000000..8bc2bad Binary files /dev/null and b/docker/media/docker-image-submit-to-challenge.png differ diff --git a/docker/media/project-docker-registry-view.png b/docker/media/project-docker-registry-view.png new file mode 100644 index 0000000..a54a918 Binary files /dev/null and b/docker/media/project-docker-registry-view.png differ diff --git a/docker/media/show-page-outline.png b/docker/media/show-page-outline.png new file mode 100644 index 0000000..94a24c1 Binary files /dev/null and b/docker/media/show-page-outline.png differ diff --git a/docker/template/Dockerfile b/docker/template/Dockerfile new file mode 100644 index 0000000..69a8b9e --- /dev/null +++ b/docker/template/Dockerfile @@ -0,0 +1,18 @@ +# Start with Python base image +# ATTENTION: Modify the image below to your need - PyTorch, Tensorflow, NVidia, etc. +FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel + +# Setting a working directory explicitly +WORKDIR /user + +# copy the contents of source files to the docker container +COPY ./src /user + +# install Python dependencies that are specified in the requirements.txt file +# ATTENTION: Make sure your code dependencies are listed in the requirements.txt file +#RUN pip install -r requirements.txt + +RUN python -m pip install --no-cache-dir /user/ + +# launching the code when the docker container in this image is run +ENTRYPOINT [ "/bin/bash", "/user/run.sh" ] \ No newline at end of file diff --git a/docker/template/sample-test-data/test.csv b/docker/template/sample-test-data/test.csv new file mode 100644 index 0000000..fbe9d54 --- /dev/null +++ b/docker/template/sample-test-data/test.csv @@ -0,0 +1,4 @@ +venous_0 +venous_11 +venous_12 +venous_13 \ No newline at end of file diff --git a/docker/template/src/Model.py b/docker/template/src/Model.py new file mode 100644 index 0000000..a13524a --- /dev/null +++ b/docker/template/src/Model.py @@ -0,0 +1,32 @@ +import torch +import os +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +# predict a numpy array + +class MySegmentation: + def __init__(self, task="Dataset102_TriALS",nnunet_model_dir='nnUNet_results', + model_name='nnUNetTrainer__nnUNetPlans__3d_fullres', + folds=(0, 1, 2, 3, 4) + ): + # network parameters + self.predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_device=True, + device=torch.device('cuda', 0), + verbose=True, + verbose_preprocessing=False, + allow_tqdm=True + ) + self.predictor.initialize_from_trained_model_folder( + os.path.join(nnunet_model_dir, + f'{task}/{model_name}'), + use_folds=folds, + checkpoint_name='checkpoint_final.pth', + ) + + def process_image(self, image_np, properties): + ret = self.predictor.predict_single_npy_array( + image_np, properties, None, None, False) + return ret diff --git a/docker/template/src/main.py b/docker/template/src/main.py new file mode 100644 index 0000000..5b1b0a3 --- /dev/null +++ b/docker/template/src/main.py @@ -0,0 +1,63 @@ +import sys +import os +from os.path import join as osjoin +import csv +from Model import MySegmentation +from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + +if len(sys.argv) != 4: + raise (RuntimeError(f"Expected 3 arguments, was provided {len(sys.argv) - 1} argument(s).")) + +test_csv_path = sys.argv[1] +input_dir_path = sys.argv[2] +output_dir_path = sys.argv[3] + +print("=" * 30) +print("Running segmentation:") +print(f" For IDs listed in {test_csv_path}") +print(f" Using images under {input_dir_path}") +print(f" Storing predictions under {output_dir_path}") +print("=" * 30) + +# check csv file +if not os.path.exists(test_csv_path): + raise (FileNotFoundError(f"Could not find csv file: {test_csv_path}")) + +# check folders +if not os.path.exists(input_dir_path): + raise (NotADirectoryError(f"Could not find directory: {input_dir_path}")) + +if not os.path.exists(output_dir_path): + os.makedirs(output_dir_path) + +# read csv file containing file identifiers +# csv file contains a single column specifying the identifiers for the images +# such that the input image filename can be constructed as venous_.nii.gz +with open(test_csv_path, "r") as csvfile: + reader_obj = csv.reader(csvfile) + orders = list(reader_obj) + +model = MySegmentation() + +row_counter = 0 +for row in orders: + input_image_path = osjoin(input_dir_path, f"{row[0]}_0000.nii.gz") + + if not os.path.exists(input_image_path): + FileNotFoundError(f"Could not find input image at: {input_image_path}") + + #read the input volume + image_np, properties = SimpleITKIO().read_images([input_image_path]) + + print(f"Segmenting image {row_counter:03d}: {row[0]}_0000.nii.gz") + + #segment the volume + pred_labels = model.process_image( + image_np, properties) + + #write the segmentation volume + SimpleITKIO().write_seg(pred_labels, os.path.join(output_dir_path, f"{row[0]}.nii.gz"), properties) + + print("Done.") + + row_counter += 1 diff --git a/docker/template/src/nnUNet_results/Dataset102_TriALS/.gitignore b/docker/template/src/nnUNet_results/Dataset102_TriALS/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/__init__.py b/docker/template/src/nnunetv2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/batch_running/__init__.py b/docker/template/src/nnunetv2/batch_running/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/batch_running/benchmarking/__init__.py b/docker/template/src/nnunetv2/batch_running/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py b/docker/template/src/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py new file mode 100644 index 0000000..ca37206 --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py @@ -0,0 +1,41 @@ +if __name__ == '__main__': + """ + This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler! + """ + gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB', + 'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB'] + datasets = [2, 3, 4, 5] + trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] + plans = ['nnUNetPlans'] + configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] + num_gpus = 1 + + benchmark_configurations = {d: configs for d in datasets} + + exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']'\"" + resources = "-R \"tensorcore\"" + queue = "-q gpu" + preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && " + train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train' + + folds = (0, ) + + use_these_modules = { + tr: plans for tr in trainers + } + + additional_arguments = f' -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for g in gpu_models: + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}" + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in benchmark_configurations.keys(): + for config in benchmark_configurations[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + f.write(f'{command}\"\n') \ No newline at end of file diff --git a/docker/template/src/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py b/docker/template/src/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py new file mode 100644 index 0000000..d966321 --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py @@ -0,0 +1,70 @@ +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.file_path_utilities import get_output_folder + +if __name__ == '__main__': + trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] + datasets = [2, 3, 4, 5] + plans = ['nnUNetPlans'] + configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] + output_file = join(nnUNet_results, 'benchmark_results.csv') + + torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" # + cudnn_version = 8700 # 8302 # + num_gpus = 1 + + unique_gpus = set() + + # collect results in the most janky way possible. Amazing coding skills! + all_results = {} + for tr in trainers: + all_results[tr] = {} + for p in plans: + all_results[tr][p] = {} + for c in configs: + all_results[tr][p][c] = {} + for d in datasets: + dataset_name = maybe_convert_to_dataset_name(d) + output_folder = get_output_folder(dataset_name, tr, p, c, fold=0) + expected_benchmark_file = join(output_folder, 'benchmark_result.json') + all_results[tr][p][c][d] = {} + if isfile(expected_benchmark_file): + # filter results for what we want + results = [i for i in load_json(expected_benchmark_file).values() + if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and + i['torch_version'] == torch_version] + for r in results: + all_results[tr][p][c][d][r['gpu_name']] = r + unique_gpus.add(r['gpu_name']) + + # haha. Fuck this. Collect GPUs in the code above. + # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]]) + + unique_gpus = list(unique_gpus) + unique_gpus.sort() + + with open(output_file, 'w') as f: + f.write('Dataset,Trainer,Plans,Config') + for g in unique_gpus: + f.write(f",{g}") + f.write("\n") + for d in datasets: + for tr in trainers: + for p in plans: + for c in configs: + gpu_results = [] + for g in unique_gpus: + if g in all_results[tr][p][c][d].keys(): + gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2)) + else: + gpu_results.append("MISSING") + # skip if all are missing + if all([i == 'MISSING' for i in gpu_results]): + continue + f.write(f"{d},{tr},{p},{c}") + for g in gpu_results: + f.write(f",{g}") + f.write("\n") + f.write("\n") + diff --git a/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon.py new file mode 100644 index 0000000..b670661 --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -0,0 +1,114 @@ +from typing import Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.evaluation.evaluate_predictions import load_summary_json +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +def collect_results(trainers: dict, datasets: List, output_file: str, + configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"), + folds=tuple(np.arange(5))): + results_dirs = (nnUNet_results,) + datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets] + with open(output_file, 'w') as f: + for i, d in zip(datasets, datasets_names): + for c in configurations: + for module in trainers.keys(): + for plans in trainers[module]: + for r in results_dirs: + expected_output_folder = get_output_folder(d, module, plans, c) + if isdir(expected_output_folder): + results_folds = [] + f.write(f"{d},{c},{module},{plans},{r}") + for fl in folds: + expected_output_folder_fold = get_output_folder(d, module, plans, c, fl) + expected_summary_file = join(expected_output_folder_fold, "validation", + "summary.json") + if not isfile(expected_summary_file): + print('expected output file not found:', expected_summary_file) + f.write(",") + results_folds.append(np.nan) + else: + foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][ + 'Dice'] + results_folds.append(foreground_mean) + f.write(f",{foreground_mean:02.4f}") + f.write(f",{np.nanmean(results_folds):02.4f}\n") + + +def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers): + txt = np.loadtxt(input_file, dtype=str, delimiter=',') + num_folds = txt.shape[1] - 6 + valid_configs = {} + for d in datasets: + if isinstance(d, int): + d = maybe_convert_to_dataset_name(d) + configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d]) + valid_configs[d] = [i for i in configs_in_txt if i in configs] + assert max(folds) < num_folds + + with open(output_file, 'w') as f: + f.write("name") + for d in valid_configs.keys(): + for c in valid_configs[d]: + f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4])) + f.write(',mean\n') + valid_entries = txt[:, 4] == nnUNet_results + for t in trainers.keys(): + trainer_locs = valid_entries & (txt[:, 2] == t) + for pl in trainers[t]: + f.write(f"{t}__{pl}") + trainer_plan_locs = trainer_locs & (txt[:, 3] == pl) + r = [] + for d in valid_configs.keys(): + trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d) + for v in valid_configs[d]: + trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v) + if np.any(trainer_plan_d_config_locs): + # we cannot have more than one row + assert np.sum(trainer_plan_d_config_locs) == 1 + + # now check that we have all folds + selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]] + + fold_results = selected_row[[i + 5 for i in folds]] + + if '' in fold_results: + print('missing fold in', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + else: + mean_dice = np.mean([float(i) for i in fold_results]) + f.write(f",{mean_dice:02.4f}") + r.append(mean_dice) + else: + print('missing:', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + f.write(f",{np.mean(r):02.4f}\n") + + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), + 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + } + all_results_file= join(nnUNet_results, 'customDecResults.csv') + datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, 1, 2, 3, 4) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + + folds = (0, ) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + diff --git a/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py b/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py new file mode 100644 index 0000000..2795d3d --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py @@ -0,0 +1,18 @@ +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize +from nnunetv2.paths import nnUNet_results + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans', ), + } + all_results_file = join(nnUNet_results, 'hrnet_results.csv') + datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, ) + configs = ('2d', ) + output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + diff --git a/docker/template/src/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/docker/template/src/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py new file mode 100644 index 0000000..0a75fbd --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -0,0 +1,86 @@ +from copy import deepcopy +import numpy as np + + +def merge(dict1, dict2): + keys = np.unique(list(dict1.keys()) + list(dict2.keys())) + keys = np.unique(keys) + res = {} + for k in keys: + all_configs = [] + if dict1.get(k) is not None: + all_configs += list(dict1[k]) + if dict2.get(k) is not None: + all_configs += list(dict2[k]) + if len(all_configs) > 0: + res[k] = tuple(np.unique(all_configs)) + return res + + +if __name__ == "__main__": + # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of + # datasets for evaluation and future development + configurations_all = { + 2: ("3d_fullres", "2d"), + 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 4: ("2d", "3d_fullres"), + 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 20: ("2d", "3d_fullres"), + 24: ("2d", "3d_fullres"), + 27: ("2d", "3d_fullres"), + 38: ("2d", "3d_fullres"), + 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 82: ("2d", "3d_fullres"), + # 83: ("2d", "3d_fullres"), + } + + configurations_3d_fr_only = { + i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] + } + + configurations_3d_c_only = { + i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] + } + + configurations_3d_lr_only = { + i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] + } + + configurations_2d_only = { + i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] + } + + num_gpus = 1 + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R \"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\"" + resources = "-R \"tensorcore\"" + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" + queue = "-q gpu-lowprio" + preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " + train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train' + + folds = (0, ) + # use_this = configurations_2d_only + use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + # use_this = merge(use_this, configurations_3d_c_only) + + use_these_modules = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), + # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + } + + additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in use_this.keys(): + for config in use_this[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + f.write(f'{command}\"\n') + diff --git a/docker/template/src/nnunetv2/batch_running/release_trainings/__init__.py b/docker/template/src/nnunetv2/batch_running/release_trainings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py b/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py b/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py new file mode 100644 index 0000000..828c396 --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py @@ -0,0 +1,113 @@ +from typing import Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.evaluation.evaluate_predictions import load_summary_json +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +def collect_results(trainers: dict, datasets: List, output_file: str, + configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"), + folds=tuple(np.arange(5))): + results_dirs = (nnUNet_results,) + datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets] + with open(output_file, 'w') as f: + for i, d in zip(datasets, datasets_names): + for c in configurations: + for module in trainers.keys(): + for plans in trainers[module]: + for r in results_dirs: + expected_output_folder = get_output_folder(d, module, plans, c) + if isdir(expected_output_folder): + results_folds = [] + f.write(f"{d},{c},{module},{plans},{r}") + for fl in folds: + expected_output_folder_fold = get_output_folder(d, module, plans, c, fl) + expected_summary_file = join(expected_output_folder_fold, "validation", + "summary.json") + if not isfile(expected_summary_file): + print('expected output file not found:', expected_summary_file) + f.write(",") + results_folds.append(np.nan) + else: + foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][ + 'Dice'] + results_folds.append(foreground_mean) + f.write(f",{foreground_mean:02.4f}") + f.write(f",{np.nanmean(results_folds):02.4f}\n") + + +def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers): + txt = np.loadtxt(input_file, dtype=str, delimiter=',') + num_folds = txt.shape[1] - 6 + valid_configs = {} + for d in datasets: + if isinstance(d, int): + d = maybe_convert_to_dataset_name(d) + configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d]) + valid_configs[d] = [i for i in configs_in_txt if i in configs] + assert max(folds) < num_folds + + with open(output_file, 'w') as f: + f.write("name") + for d in valid_configs.keys(): + for c in valid_configs[d]: + f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4])) + f.write(',mean\n') + valid_entries = txt[:, 4] == nnUNet_results + for t in trainers.keys(): + trainer_locs = valid_entries & (txt[:, 2] == t) + for pl in trainers[t]: + f.write(f"{t}__{pl}") + trainer_plan_locs = trainer_locs & (txt[:, 3] == pl) + r = [] + for d in valid_configs.keys(): + trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d) + for v in valid_configs[d]: + trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v) + if np.any(trainer_plan_d_config_locs): + # we cannot have more than one row + assert np.sum(trainer_plan_d_config_locs) == 1 + + # now check that we have all folds + selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]] + + fold_results = selected_row[[i + 5 for i in folds]] + + if '' in fold_results: + print('missing fold in', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + else: + mean_dice = np.mean([float(i) for i in fold_results]) + f.write(f",{mean_dice:02.4f}") + r.append(mean_dice) + else: + print('missing:', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + f.write(f",{np.mean(r):02.4f}\n") + + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainer_v1loss': ('nnUNetPlans',), + } + all_results_file = join(nnUNet_results, 'customDecResults.csv') + datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, 1, 2, 3, 4) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + + folds = (0, ) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + diff --git a/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py b/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py new file mode 100644 index 0000000..7c5934f --- /dev/null +++ b/docker/template/src/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py @@ -0,0 +1,93 @@ +from copy import deepcopy +import numpy as np + + +def merge(dict1, dict2): + keys = np.unique(list(dict1.keys()) + list(dict2.keys())) + keys = np.unique(keys) + res = {} + for k in keys: + all_configs = [] + if dict1.get(k) is not None: + all_configs += list(dict1[k]) + if dict2.get(k) is not None: + all_configs += list(dict2[k]) + if len(all_configs) > 0: + res[k] = tuple(np.unique(all_configs)) + return res + + +if __name__ == "__main__": + # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of + # datasets for evaluation and future development + configurations_all = { + # 1: ("3d_fullres", "2d"), + 2: ("3d_fullres", "2d"), + # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 4: ("2d", "3d_fullres"), + 5: ("2d", "3d_fullres"), + # 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 20: ("2d", "3d_fullres"), + 24: ("2d", "3d_fullres"), + 27: ("2d", "3d_fullres"), + 35: ("2d", "3d_fullres"), + 38: ("2d", "3d_fullres"), + # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 82: ("2d", "3d_fullres"), + # 83: ("2d", "3d_fullres"), + } + + configurations_3d_fr_only = { + i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] + } + + configurations_3d_c_only = { + i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] + } + + configurations_3d_lr_only = { + i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] + } + + configurations_2d_only = { + i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] + } + + num_gpus = 1 + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" + resources = "-R \"tensorcore\"" + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G" + queue = "-q gpu-lowprio" + preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " + train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train' + + folds = (0, 1, 2, 3, 4) + # use_this = configurations_2d_only + # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + # use_this = merge(use_this, configurations_3d_c_only) + use_this = configurations_all + + use_these_modules = { + 'nnUNetTrainer': ('nnUNetPlans',), + } + + additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in use_this.keys(): + for config in use_this[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + f.write(f'{command}\"\n') + diff --git a/docker/template/src/nnunetv2/configuration.py b/docker/template/src/nnunetv2/configuration.py new file mode 100644 index 0000000..cdc8cb6 --- /dev/null +++ b/docker/template/src/nnunetv2/configuration.py @@ -0,0 +1,10 @@ +import os + +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + +default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) + +ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low +# resolution axis must be 3x as large as the next largest spacing) + +default_n_proc_DA = get_allowed_n_proc_DA() diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset027_ACDC.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset027_ACDC.py new file mode 100644 index 0000000..569ff6f --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset027_ACDC.py @@ -0,0 +1,87 @@ +import os +import shutil +from pathlib import Path + +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def make_out_dirs(dataset_id: int, task_name="ACDC"): + dataset_name = f"Dataset{dataset_id:03d}_{task_name}" + + out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name + out_train_dir = out_dir / "imagesTr" + out_labels_dir = out_dir / "labelsTr" + out_test_dir = out_dir / "imagesTs" + + os.makedirs(out_dir, exist_ok=True) + os.makedirs(out_train_dir, exist_ok=True) + os.makedirs(out_labels_dir, exist_ok=True) + os.makedirs(out_test_dir, exist_ok=True) + + return out_dir, out_train_dir, out_labels_dir, out_test_dir + + +def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): + """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases.""" + patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) + patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()]) + + num_training_cases = 0 + # Copy training files and corresponding labels. + for patient_dir in patients_train: + for file in patient_dir.iterdir(): + if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: + # The stem is 'patient.nii', and the suffix is '.gz'. + # We split the stem and append _0000 to the patient part. + shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") + num_training_cases += 1 + elif file.suffix == ".gz" and "_gt" in file.name: + shutil.copy(file, labels_dir / file.name.replace("_gt", "")) + + # Copy test files. + for patient_dir in patients_test: + for file in patient_dir.iterdir(): + if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: + shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") + + return num_training_cases + + +def convert_acdc(src_data_folder: str, dataset_id=27): + out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id) + num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir) + + generate_dataset_json( + str(out_dir), + channel_names={ + 0: "cineMRI", + }, + labels={ + "background": 0, + "RV": 1, + "MLV": 2, + "LVC": 3, + }, + file_ending=".nii.gz", + num_training_cases=num_training_cases, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input_folder", + type=str, + help="The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.", + ) + parser.add_argument( + "-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27" + ) + args = parser.parse_args() + print("Converting...") + convert_acdc(args.input_folder, args.dataset_id) + print("Done!") diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py new file mode 100644 index 0000000..eca22d0 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py @@ -0,0 +1,85 @@ +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +import tifffile +from batchgenerators.utilities.file_and_folder_operations import * +import shutil + + +if __name__ == '__main__': + """ + This is going to be my test dataset for working with tif as input and output images + + All we do here is copy the files and rename them. Not file conversions take place + """ + dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM' + + imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') + imagests = join(nnUNet_raw, dataset_name, 'imagesTs') + labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + # we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train + # and Fluo-C3DH-A549-SIM_test + train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train' + test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test' + + # with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the + # tif files + + # tif is broken when it comes to spacing. No standards. Grr. So when we use tif nnU-Net expects a separate file + # that specifies the spacing. This file needs to exist for EVERY training/test case to allow for different spacings + # between files. Important! The spacing must align with the axes. + # Here when we do print(tifffile.imread('IMAGE').shape) we get (29, 300, 350). The low resolution axis is the first. + # The spacing on the website is griven in the wrong axis order. Great. + spacing = (1, 0.126, 0.126) + + # train set + for seq in ['01', '02']: + images_dir = join(train_source, seq) + seg_dir = join(train_source, seq + '_GT', 'SEG') + # if we were to be super clean we would go by IDs but here we just trust the files are sorted the correct way. + # Simpler filenames in the cell tracking challenge would be soooo nice. + images = subfiles(images_dir, suffix='.tif', sort=True, join=False) + segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False) + for i, (im, se) in enumerate(zip(images, segs)): + target_name = f'{seq}_image_{i:03d}' + # we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input + # channels distributed over separate files + shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif')) + # spacing file! + save_json({'spacing': spacing}, join(imagestr, target_name + '.json')) + shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif')) + # spacing file! + save_json({'spacing': spacing}, join(labelstr, target_name + '.json')) + + # test set, same a strain just without the segmentations + for seq in ['01', '02']: + images_dir = join(test_source, seq) + images = subfiles(images_dir, suffix='.tif', sort=True, join=False) + for i, im in enumerate(images): + target_name = f'{seq}_image_{i:03d}' + shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif')) + # spacing file! + save_json({'spacing': spacing}, join(imagests, target_name + '.json')) + + # now we generate the dataset json + generate_dataset_json( + join(nnUNet_raw, dataset_name), + {0: 'fluorescence_microscopy'}, + {'background': 0, 'cell': 1}, + 60, + '.tif' + ) + + # custom split to ensure we are stratifying properly. This dataset only has 2 folds + caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)] + splits = [] + splits.append( + {'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]} + ) + splits.append( + {'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]} + ) + save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json')) \ No newline at end of file diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset114_MNMs.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset114_MNMs.py new file mode 100644 index 0000000..20eecd6 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset114_MNMs.py @@ -0,0 +1,198 @@ +import csv +import os +import random +from pathlib import Path + +import nibabel as nib +from batchgenerators.utilities.file_and_folder_operations import load_json, save_json + +from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_preprocessed + + +def read_csv(csv_file: str): + patient_info = {} + + with open(csv_file) as csvfile: + reader = csv.reader(csvfile) + headers = next(reader) + patient_index = headers.index("External code") + ed_index = headers.index("ED") + es_index = headers.index("ES") + vendor_index = headers.index("Vendor") + + for row in reader: + patient_info[row[patient_index]] = { + "ed": int(row[ed_index]), + "es": int(row[es_index]), + "vendor": row[vendor_index], + } + + return patient_info + + +# ------------------------------------------------------------------------------ +# Conversion to nnUNet format +# ------------------------------------------------------------------------------ +def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int): + out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs") + patients_train = [f for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()] + patients_test = [f for f in (src_data_folder / "Testing").iterdir() if f.is_dir()] + + patient_info = read_csv(str(src_data_folder / csv_file_name)) + + save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir) + save_cardiac_phases(patients_test, patient_info, out_test_dir) + + # There are non-orthonormal direction cosines in the test and validation data. + # Not sure if the data should be fixed, or we should skip the problematic data. + # patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()] + # save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir) + + generate_dataset_json( + str(out_dir), + channel_names={ + 0: "cineMRI", + }, + labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3}, + file_ending=".nii.gz", + num_training_cases=len(patients_train) * 2, # 2 since we have ED and ES for each patient + ) + + +def save_cardiac_phases( + patients: list[Path], patient_info: dict[str, dict[str, int]], out_dir: Path, labels_dir: Path = None +): + for patient in patients: + print(f"Processing patient: {patient.name}") + + image = nib.load(patient / f"{patient.name}_sa.nii.gz") + ed_frame = patient_info[patient.name]["ed"] + es_frame = patient_info[patient.name]["es"] + + save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient) + + if labels_dir: + label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz") + save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient) + + +def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path): + # Save only extracted diastole and systole slices from the 4D H x W x D x time volume. + image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine) + image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine) + + # Labels do not have modality identifiers. Labels always end with 'gt'. + suffix = ".nii.gz" if image.get_filename().endswith("_gt.nii.gz") else "_0000.nii.gz" + + nib.save(image_ed, str(out_dir / f"{patient.name}_frame{ed_frame:02d}{suffix}")) + nib.save(image_es, str(out_dir / f"{patient.name}_frame{es_frame:02d}{suffix}")) + + +# ------------------------------------------------------------------------------ +# Create custom splits +# ------------------------------------------------------------------------------ +def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25): + existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id}_MNMs", "splits_final.json") + splits = load_json(existing_splits) + + patients_train = [f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()] + # Filter out any patients not in the training set + patient_info = { + patient: data + for patient, data in read_csv(str(src_data_folder / csv_file)).items() + if patient in patients_train + } + + # Get train and validation patients for both vendors + patients_a = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "A"] + patients_b = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "B"] + train_a, val_a = get_vendor_split(patients_a, num_val_patients) + train_b, val_b = get_vendor_split(patients_b, num_val_patients) + + # Build filenames from corresponding patient frames + train_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_a for frame in ["es", "ed"]] + train_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_b for frame in ["es", "ed"]] + train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :] + train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], train_b[len(train_b) // 2 :] + val_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_a for frame in ["es", "ed"]] + val_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_b for frame in ["es", "ed"]] + + for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]: + # For each train set, we evaluate on A, B and (A + B) respectively + # See table 3 from the original paper for more details. + splits.append({"train": train_set, "val": val_a}) + splits.append({"train": train_set, "val": val_b}) + splits.append({"train": train_set, "val": val_a + val_b}) + + save_json(splits, existing_splits) + + +def get_vendor_split(patients: list[str], num_val_patients: int): + random.shuffle(patients) + total_patients = len(patients) + num_training_patients = total_patients - num_val_patients + return patients[:num_training_patients], patients[num_training_patients:] + + +if __name__ == "__main__": + import argparse + + class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter) + parser.add_argument( + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet " + "format. It can also be used to create additional custom splits, for explicitly training on combinations " + "of vendors A and B (see `--custom-splits`).\n" + "If you wish to generate the custom splits, run the following pipeline:\n\n" + "(1) Run `Dataset114_MNMs -i \n" + "(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n" + "(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\n" + "(4) Re-run `Dataset114_MNMs`, with `-s True`.\n" + "(5) Re-run training.\n", + ) + parser.add_argument( + "-i", + "--input_folder", + type=str, + default="./data/M&Ms/OpenDataset/", + help="The downloaded MNMs dataset dir. Should contain a csv file, as well as Training, Validation and Testing " + "folders.", + ) + parser.add_argument( + "-c", + "--csv_file_name", + type=str, + default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv", + help="The csv file containing the dataset information.", + ), + parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.") + parser.add_argument( + "-s", + "--custom_splits", + type=bool, + default=False, + help="Whether to append custom splits for training and testing on different vendors. If True, will create " + "splits for training on patients from vendors A, B or a mix of A and B. Splits are tested on a hold-out " + "validation sets of patients from A, B or A and B combined. See section 2.4 and table 3 from " + "https://arxiv.org/abs/2011.07592 for more info.", + ) + + args = parser.parse_args() + args.input_folder = Path(args.input_folder) + + if args.custom_splits: + print("Appending custom splits...") + create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id) + else: + print("Converting...") + convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id) + + print("Done!") diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset115_EMIDEC.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset115_EMIDEC.py new file mode 100644 index 0000000..e307e14 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset115_EMIDEC.py @@ -0,0 +1,61 @@ +import shutil +from pathlib import Path + +from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json + + +def copy_files(src_data_dir: Path, src_test_dir: Path, train_dir: Path, labels_dir: Path, test_dir: Path): + """Copy files from the EMIDEC dataset to the nnUNet dataset folder. Returns the number of training cases.""" + patients_train = sorted([f for f in src_data_dir.iterdir() if f.is_dir()]) + patients_test = sorted([f for f in src_test_dir.iterdir() if f.is_dir()]) + + # Copy training files and corresponding labels. + for patient in patients_train: + train_file = patient / "Images" / f"{patient.name}.nii.gz" + label_file = patient / "Contours" / f"{patient.name}.nii.gz" + shutil.copy(train_file, train_dir / f"{train_file.stem.split('.')[0]}_0000.nii.gz") + shutil.copy(label_file, labels_dir) + + # Copy test files. + for patient in patients_test: + test_file = patient / "Images" / f"{patient.name}.nii.gz" + shutil.copy(test_file, test_dir / f"{test_file.stem.split('.')[0]}_0000.nii.gz") + + return len(patients_train) + + +def convert_emidec(src_data_dir: str, src_test_dir: str, dataset_id=27): + out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id, task_name="EMIDEC") + num_training_cases = copy_files(Path(src_data_dir), Path(src_test_dir), train_dir, labels_dir, test_dir) + + generate_dataset_json( + str(out_dir), + channel_names={ + 0: "cineMRI", + }, + labels={ + "background": 0, + "cavity": 1, + "normal_myocardium": 2, + "myocardial_infarction": 3, + "no_reflow": 4, + }, + file_ending=".nii.gz", + num_training_cases=num_training_cases, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_dir", type=str, help="The EMIDEC dataset directory.") + parser.add_argument("-t", "--test_dir", type=str, help="The EMIDEC test set directory.") + parser.add_argument( + "-d", "--dataset_id", required=False, type=int, default=115, help="nnU-Net Dataset ID, default: 115" + ) + args = parser.parse_args() + print("Converting...") + convert_emidec(args.input_dir, args.test_dir, args.dataset_id) + print("Done!") diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py new file mode 100644 index 0000000..90dcc6c --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py @@ -0,0 +1,87 @@ +import multiprocessing +import shutil +from multiprocessing import Pool + +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw +from skimage import io +from acvl_utils.morphology.morphology_helper import generic_filter_components +from scipy.ndimage import binary_fill_holes + + +def load_and_covnert_case(input_image: str, input_seg: str, output_image: str, output_seg: str, + min_component_size: int = 50): + seg = io.imread(input_seg) + seg[seg == 255] = 1 + image = io.imread(input_image) + image = image.sum(2) + mask = image == (3 * 255) + # the dataset has large white areas in which road segmentations can exist but no image information is available. + # Remove the road label in these areas + mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if + sizes[j] > min_component_size]) + mask = binary_fill_holes(mask) + seg[mask] = 0 + io.imsave(output_seg, seg, check_contrast=False) + shutil.copy(input_image, output_image) + + +if __name__ == "__main__": + # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download + source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal' + + dataset_name = 'Dataset120_RoadSegmentation' + + imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') + imagests = join(nnUNet_raw, dataset_name, 'imagesTs') + labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') + labelsts = join(nnUNet_raw, dataset_name, 'labelsTs') + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + maybe_mkdir_p(labelsts) + + train_source = join(source, 'training') + test_source = join(source, 'testing') + + with multiprocessing.get_context("spawn").Pool(8) as p: + + # not all training images have a segmentation + valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png') + num_train = len(valid_ids) + r = [] + for v in valid_ids: + r.append( + p.starmap_async( + load_and_covnert_case, + (( + join(train_source, 'input', v), + join(train_source, 'output', v), + join(imagestr, v[:-4] + '_0000.png'), + join(labelstr, v), + 50 + ),) + ) + ) + + # test set + valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png') + for v in valid_ids: + r.append( + p.starmap_async( + load_and_covnert_case, + (( + join(test_source, 'input', v), + join(test_source, 'output', v), + join(imagests, v[:-4] + '_0000.png'), + join(labelsts, v), + 50 + ),) + ) + ) + _ = [i.get() for i in r] + + generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1}, + num_train, '.png', dataset_name=dataset_name) diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset137_BraTS21.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset137_BraTS21.py new file mode 100644 index 0000000..b4817d2 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset137_BraTS21.py @@ -0,0 +1,98 @@ +import multiprocessing +import shutil +from multiprocessing import Pool + +import SimpleITK as sitk +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None: + # use this for segmentation only!!! + # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3 + img = sitk.ReadImage(in_file) + img_npy = sitk.GetArrayFromImage(img) + + uniques = np.unique(img_npy) + for u in uniques: + if u not in [0, 1, 2, 4]: + raise RuntimeError('unexpected label') + + seg_new = np.zeros_like(img_npy) + seg_new[img_npy == 4] = 3 + seg_new[img_npy == 2] = 1 + seg_new[img_npy == 1] = 2 + img_corr = sitk.GetImageFromArray(seg_new) + img_corr.CopyInformation(img) + sitk.WriteImage(img_corr, out_file) + + +def convert_labels_back_to_BraTS(seg: np.ndarray): + new_seg = np.zeros_like(seg) + new_seg[seg == 1] = 2 + new_seg[seg == 3] = 4 + new_seg[seg == 2] = 1 + return new_seg + + +def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder): + a = sitk.ReadImage(join(input_folder, filename)) + b = sitk.GetArrayFromImage(a) + c = convert_labels_back_to_BraTS(b) + d = sitk.GetImageFromArray(c) + d.CopyInformation(a) + sitk.WriteImage(d, join(output_folder, filename)) + + +def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12): + """ + reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the + """ + maybe_mkdir_p(output_folder) + nii = subfiles(input_folder, suffix='.nii.gz', join=False) + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii))) + + +if __name__ == '__main__': + brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training' + + task_id = 137 + task_name = "BraTS2021" + + foldername = "Dataset%03.0d_%s" % (task_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(labelstr) + + case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False) + + for c in case_ids: + shutil.copy(join(brats_data_dir, c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz')) + + copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + "_seg.nii.gz"), + join(labelstr, c + '.nii.gz')) + + generate_dataset_json(out_base, + channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'}, + labels={ + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': (3, ) + }, + num_training_cases=len(case_ids), + file_ending='.nii.gz', + regions_class_order=(1, 2, 3), + license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', + reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', + dataset_release='1.0') diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py new file mode 100644 index 0000000..1f33cd7 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py @@ -0,0 +1,70 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218): + """ + AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into + the train set. Having a 5-fold cross-validation is superior to a single train:val split + """ + task_name = "AMOS2022_postChallenge_task1" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + imagests = join(out_base, "imagesTs") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) + + training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] + tr_ctr = 0 + for tr in training_identifiers: + if int(tr.split("_")[-1]) <= 410: # these are the CT images + tr_ctr += 1 + shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] + for ts in test_identifiers: + if int(ts.split("_")[-1]) <= 500: # these are the CT images + shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) + + val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] + for vl in val_identifiers: + if int(vl.split("_")[-1]) <= 409: # these are the CT images + tr_ctr += 1 + shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) + + generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, + num_training_cases=tr_ctr, file_ending='.nii.gz', + dataset_name=task_name, reference='https://amos22.grand-challenge.org/', + release='https://zenodo.org/record/7262581', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="This is the dataset as released AFTER the challenge event. It has the " + "validation set gt in it! We just use the validation images as additional " + "training cases because AMOS doesn't specify how they should be used. nnU-Net's" + " 5-fold CV is better than some random train:val split.") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. " + "Use this link: https://zenodo.org/record/7262581." + "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!") + parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218') + args = parser.parse_args() + amos_base = args.input_folder + convert_amos_task1(amos_base, args.d) + + diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py new file mode 100644 index 0000000..9a5e2c6 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py @@ -0,0 +1,65 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219): + """ + AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into + the train set. Having a 5-fold cross-validation is superior to a single train:val split + """ + task_name = "AMOS2022_postChallenge_task2" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + imagests = join(out_base, "imagesTs") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) + + training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] + for tr in training_identifiers: + shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] + for ts in test_identifiers: + shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) + + val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] + for vl in val_identifiers: + shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) + + generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, + num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz', + dataset_name=task_name, reference='https://amos22.grand-challenge.org/', + release='https://zenodo.org/record/7262581', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="This is the dataset as released AFTER the challenge event. It has the " + "validation set gt in it! We just use the validation images as additional " + "training cases because AMOS doesn't specify how they should be used. nnU-Net's" + " 5-fold CV is better than some random train:val split.") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. " + "Use this link: https://zenodo.org/record/7262581." + "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!") + parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219') + args = parser.parse_args() + amos_base = args.input_folder + convert_amos_task2(amos_base, args.d) + + # /home/isensee/Downloads/amos22/amos22/ + diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py new file mode 100644 index 0000000..20a794c --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py @@ -0,0 +1,50 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220): + task_name = "KiTS2023" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(labelstr) + + cases = subdirs(kits_base_dir, prefix='case_', join=False) + for tr in cases: + shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + generate_dataset_json(out_base, {0: "CT"}, + labels={ + "background": 0, + "kidney": (1, 2, 3), + "masses": (2, 3), + "tumor": 2 + }, + regions_class_order=(1, 3, 2), + num_training_cases=len(cases), file_ending='.nii.gz', + dataset_name=task_name, reference='none', + release='prerelease', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="KiTS2023") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)") + parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220') + args = parser.parse_args() + amos_base = args.input_folder + convert_kits2023(amos_base, args.d) + + # /media/isensee/raw_data/raw_datasets/kits23/dataset + diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py new file mode 100644 index 0000000..56ef16e --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py @@ -0,0 +1,70 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed + + +def convert_autopet(autopet_base_dir:str = '/media/isensee/My Book1/AutoPET/nifti/FDG-PET-CT-Lesions', + nnunet_dataset_id: int = 221): + task_name = "AutoPETII_2023" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(labelstr) + + patients = subdirs(autopet_base_dir, prefix='PETCT', join=False) + n = 0 + identifiers = [] + for pat in patients: + patient_acquisitions = subdirs(join(autopet_base_dir, pat), join=False) + for pa in patient_acquisitions: + n += 1 + identifier = f"{pat}_{pa}" + identifiers.append(identifier) + if not isfile(join(imagestr, f'{identifier}_0000.nii.gz')): + shutil.copy(join(autopet_base_dir, pat, pa, 'CTres.nii.gz'), join(imagestr, f'{identifier}_0000.nii.gz')) + if not isfile(join(imagestr, f'{identifier}_0001.nii.gz')): + shutil.copy(join(autopet_base_dir, pat, pa, 'SUV.nii.gz'), join(imagestr, f'{identifier}_0001.nii.gz')) + if not isfile(join(imagestr, f'{identifier}.nii.gz')): + shutil.copy(join(autopet_base_dir, pat, pa, 'SEG.nii.gz'), join(labelstr, f'{identifier}.nii.gz')) + + generate_dataset_json(out_base, {0: "CT", 1:"CT"}, + labels={ + "background": 0, + "tumor": 1 + }, + num_training_cases=n, file_ending='.nii.gz', + dataset_name=task_name, reference='https://autopet-ii.grand-challenge.org/', + release='release', + # overwrite_image_reader_writer='NibabelIOWithReorient', + description=task_name) + + # manual split + splits = [] + for fold in range(5): + val_patients = patients[fold :: 5] + splits.append( + { + 'train': [i for i in identifiers if not any([i.startswith(v) for v in val_patients])], + 'val': [i for i in identifiers if any([i.startswith(v) for v in val_patients])], + } + ) + pp_out_dir = join(nnUNet_preprocessed, foldername) + maybe_mkdir_p(pp_out_dir) + save_json(splits, join(pp_out_dir, 'splits_final.json'), sort_keys=False) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted autopet dataset (must have PETCT_XXX subfolders)") + parser.add_argument('-d', required=False, type=int, default=221, help='nnU-Net Dataset ID, default: 221') + args = parser.parse_args() + amos_base = args.input_folder + convert_autopet(amos_base, args.d) diff --git a/docker/template/src/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py b/docker/template/src/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py new file mode 100644 index 0000000..80b295d --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py @@ -0,0 +1,32 @@ +import os + +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + +if __name__ == '__main__': + # creates a dummy dataset where there are no files in imagestr and labelstr + source_dataset = 'Dataset004_Hippocampus' + + target_dataset = 'Dataset987_dummyDataset4' + target_dataset_dir = join(nnUNet_raw, target_dataset) + maybe_mkdir_p(target_dataset_dir) + + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset)) + + # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy + # datasets around between systems. As long as the source dataset is there it will continue working even if + # nnUNet_raw is in different locations + + # paths must be relative to target_dataset_dir!!! + for k in dataset.keys(): + dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir) + dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']] + + # load old dataset.json + dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json')) + dataset_json['dataset'] = dataset + + # save + save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False) diff --git a/docker/template/src/nnunetv2/dataset_conversion/__init__.py b/docker/template/src/nnunetv2/dataset_conversion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/dataset_conversion/convert_MSD_dataset.py b/docker/template/src/nnunetv2/dataset_conversion/convert_MSD_dataset.py new file mode 100644 index 0000000..40dddc1 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/convert_MSD_dataset.py @@ -0,0 +1,133 @@ +import argparse +import multiprocessing +import shutil +from multiprocessing import Pool +from typing import Optional +import SimpleITK as sitk +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets +from nnunetv2.configuration import default_num_processes +import numpy as np + + +def split_4d_nifti(filename, output_folder): + img_itk = sitk.ReadImage(filename) + dim = img_itk.GetDimension() + file_base = os.path.basename(filename) + if dim == 3: + shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz")) + return + elif dim != 4: + raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename)) + else: + img_npy = sitk.GetArrayFromImage(img_itk) + spacing = img_itk.GetSpacing() + origin = img_itk.GetOrigin() + direction = np.array(img_itk.GetDirection()).reshape(4,4) + # now modify these to remove the fourth dimension + spacing = tuple(list(spacing[:-1])) + origin = tuple(list(origin[:-1])) + direction = tuple(direction[:-1, :-1].reshape(-1)) + for i, t in enumerate(range(img_npy.shape[0])): + img = img_npy[t] + img_itk_new = sitk.GetImageFromArray(img) + img_itk_new.SetSpacing(spacing) + img_itk_new.SetOrigin(origin) + img_itk_new.SetDirection(direction) + sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i)) + + +def convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None, + num_processes: int = default_num_processes) -> None: + if source_folder.endswith('/') or source_folder.endswith('\\'): + source_folder = source_folder[:-1] + + labelsTr = join(source_folder, 'labelsTr') + imagesTs = join(source_folder, 'imagesTs') + imagesTr = join(source_folder, 'imagesTr') + assert isdir(labelsTr), f"labelsTr subfolder missing in source folder" + assert isdir(imagesTs), f"imagesTs subfolder missing in source folder" + assert isdir(imagesTr), f"imagesTr subfolder missing in source folder" + dataset_json = join(source_folder, 'dataset.json') + assert isfile(dataset_json), f"dataset.json missing in source_folder" + + # infer source dataset id and name + task, dataset_name = os.path.basename(source_folder).split('_') + task_id = int(task[4:]) + + # check if target dataset id is taken + target_id = task_id if overwrite_target_id is None else overwrite_target_id + existing_datasets = find_candidate_datasets(target_id) + assert len(existing_datasets) == 0, f"Target dataset id {target_id} is already taken, please consider changing " \ + f"it using overwrite_target_id. Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)" + + target_dataset_name = f"Dataset{target_id:03d}_{dataset_name}" + target_folder = join(nnUNet_raw, target_dataset_name) + target_imagesTr = join(target_folder, 'imagesTr') + target_imagesTs = join(target_folder, 'imagesTs') + target_labelsTr = join(target_folder, 'labelsTr') + maybe_mkdir_p(target_imagesTr) + maybe_mkdir_p(target_imagesTs) + maybe_mkdir_p(target_labelsTr) + + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + results = [] + + # convert 4d train images + source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + source_images = [join(imagesTr, i) for i in source_images] + + results.append( + p.starmap_async( + split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images)) + ) + ) + + # convert 4d test images + source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + source_images = [join(imagesTs, i) for i in source_images] + + results.append( + p.starmap_async( + split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images)) + ) + ) + + # copy segmentations + source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + for s in source_images: + shutil.copy(join(labelsTr, s), join(target_labelsTr, s)) + + [i.get() for i in results] + + dataset_json = load_json(dataset_json) + dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} + dataset_json['file_ending'] = ".nii.gz" + dataset_json["channel_names"] = dataset_json["modality"] + del dataset_json["modality"] + del dataset_json["training"] + del dataset_json["test"] + save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) + + +def entry_point(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', type=str, required=True, + help='Downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset! Example: ' + '/home/fabian/Downloads/Task05_Prostate') + parser.add_argument('-overwrite_id', type=int, required=False, default=None, + help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from ' + 'folder name). Only use this if you already have an equivalently numbered dataset!') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'Number of processes used. Default: {default_num_processes}') + args = parser.parse_args() + convert_msd_dataset(args.i, args.overwrite_id, args.np) + + +if __name__ == '__main__': + entry_point() + # convert_msd_dataset('/home/fabian/Downloads/Task05_Prostate', overwrite_target_id=201) diff --git a/docker/template/src/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py b/docker/template/src/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py new file mode 100644 index 0000000..fb77533 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py @@ -0,0 +1,53 @@ +import shutil +from copy import deepcopy + +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json +from nnunetv2.paths import nnUNet_raw + + +def convert(source_folder, target_dataset_name): + """ + remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY + source_folder + """ + if isdir(join(nnUNet_raw, target_dataset_name)): + raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... ' + f'(we might break something). If you are sure you want to proceed, please manually ' + f'delete {join(nnUNet_raw, target_dataset_name)}') + maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) + shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr')) + shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr')) + if isdir(join(source_folder, 'imagesTs')): + shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs')) + if isdir(join(source_folder, 'labelsTs')): + shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs')) + if isdir(join(source_folder, 'imagesVal')): + shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal')) + if isdir(join(source_folder, 'labelsVal')): + shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal')) + shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name)) + + dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json')) + del dataset_json['tensorImageSize'] + del dataset_json['numTest'] + del dataset_json['training'] + del dataset_json['test'] + dataset_json['channel_names'] = deepcopy(dataset_json['modality']) + del dataset_json['modality'] + + dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} + dataset_json['file_ending'] = ".nii.gz" + save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) + + +def convert_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("input_folder", type=str, + help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! ' + 'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not ' + 'know where v1 tasks are.') + parser.add_argument("output_dataset_name", type=str, + help='New dataset NAME (not path!). Must follow the DatasetXXX_NAME convention!') + args = parser.parse_args() + convert(args.input_folder, args.output_dataset_name) diff --git a/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py new file mode 100644 index 0000000..e68c6a6 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py @@ -0,0 +1,74 @@ +import SimpleITK as sitk +import shutil + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.label_handling.label_handling import LabelManager + + +def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray: + assert label_manager.has_ignore_label, "This preprocessor only works with datasets that have an ignore label!" + seg_new = np.ones_like(seg) * label_manager.ignore_label + x, y, z = seg.shape + # x + num_slices = max(1, round(x * percent_of_slices)) + selected_slices = np.random.choice(x, num_slices, replace=False) + seg_new[selected_slices] = seg[selected_slices] + # y + num_slices = max(1, round(y * percent_of_slices)) + selected_slices = np.random.choice(y, num_slices, replace=False) + seg_new[:, selected_slices] = seg[:, selected_slices] + # z + num_slices = max(1, round(z * percent_of_slices)) + selected_slices = np.random.choice(z, num_slices, replace=False) + seg_new[:, :, selected_slices] = seg[:, :, selected_slices] + return seg_new + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_regions_ignore' + dataset_id = 996 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # additionally optimize entire hippocampus region, remove Posterior + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels'] = { + 'background': 0, + 'hippocampus': (1, 2), + 'anterior': 1, + 'ignore': 3 + } + dj['regions_class_order'] = (2, 1) + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) + + # now add ignore label to segmentation images + np.random.seed(1234) + lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order')) + + segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr')) + for s in segs: + seg_itk = sitk.ReadImage(s) + seg_npy = sitk.GetArrayFromImage(seg_itk) + seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3) + seg_itk_new = sitk.GetImageFromArray(seg_npy) + seg_itk_new.CopyInformation(seg_itk) + sitk.WriteImage(seg_itk_new, s) + diff --git a/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py new file mode 100644 index 0000000..b40c534 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py @@ -0,0 +1,37 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_regions' + dataset_id = 997 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError( + f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # additionally optimize entire hippocampus region, remove Posterior + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels'] = { + 'background': 0, + 'hippocampus': (1, 2), + 'anterior': 1 + } + dj['regions_class_order'] = (2, 1) + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) diff --git a/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py new file mode 100644 index 0000000..1781a27 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py @@ -0,0 +1,33 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_ignore' + dataset_id = 998 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # set class 2 to ignore label + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels']['ignore'] = 2 + del dj['labels']['Posterior'] + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) diff --git a/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py new file mode 100644 index 0000000..33075da --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py @@ -0,0 +1,27 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus' + dataset_id = 999 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) diff --git a/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py b/docker/template/src/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/dataset_conversion/generate_dataset_json.py b/docker/template/src/nnunetv2/dataset_conversion/generate_dataset_json.py new file mode 100644 index 0000000..429fa05 --- /dev/null +++ b/docker/template/src/nnunetv2/dataset_conversion/generate_dataset_json.py @@ -0,0 +1,103 @@ +from typing import Tuple + +from batchgenerators.utilities.file_and_folder_operations import save_json, join + + +def generate_dataset_json(output_folder: str, + channel_names: dict, + labels: dict, + num_training_cases: int, + file_ending: str, + regions_class_order: Tuple[int, ...] = None, + dataset_name: str = None, reference: str = None, release: str = None, license: str = None, + description: str = None, + overwrite_image_reader_writer: str = None, **kwargs): + """ + Generates a dataset.json file in the output folder + + channel_names: + Channel names must map the index to the name of the channel, example: + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. + Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! + + num_training_cases: is used to double check all cases are there! + + file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and + segmentations! + + dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for + completeness and as a reminder that these would be great! + + overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from + BaseReaderWriter, place it into nnunet.imageio and reference it here by name + + kwargs: whatever you put here will be placed in the dataset.json as well + + """ + has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()]) + if has_regions: + assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \ + f"You need that." + # channel names need strings as keys + keys = list(channel_names.keys()) + for k in keys: + if not isinstance(k, str): + channel_names[str(k)] = channel_names[k] + del channel_names[k] + + # labels need ints as values + for l in labels.keys(): + value = labels[l] + if isinstance(value, (tuple, list)): + value = tuple([int(i) for i in value]) + labels[l] = value + else: + labels[l] = int(labels[l]) + + dataset_json = { + 'channel_names': channel_names, # previously this was called 'modality'. I didn't like this so this is + # channel_names now. Live with it. + 'labels': labels, + 'numTraining': num_training_cases, + 'file_ending': file_ending, + } + + if dataset_name is not None: + dataset_json['name'] = dataset_name + if reference is not None: + dataset_json['reference'] = reference + if release is not None: + dataset_json['release'] = release + if license is not None: + dataset_json['licence'] = license + if description is not None: + dataset_json['description'] = description + if overwrite_image_reader_writer is not None: + dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer + if regions_class_order is not None: + dataset_json['regions_class_order'] = regions_class_order + + dataset_json.update(kwargs) + + save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) diff --git a/docker/template/src/nnunetv2/ensembling/__init__.py b/docker/template/src/nnunetv2/ensembling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/ensembling/ensemble.py b/docker/template/src/nnunetv2/ensembling/ensemble.py new file mode 100644 index 0000000..d4a9be4 --- /dev/null +++ b/docker/template/src/nnunetv2/ensembling/ensemble.py @@ -0,0 +1,206 @@ +import argparse +import multiprocessing +import shutil +from copy import deepcopy +from multiprocessing import Pool +from typing import List, Union, Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \ + maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def average_probabilities(list_of_files: List[str]) -> np.ndarray: + assert len(list_of_files), 'At least one file must be given in list_of_files' + avg = None + for f in list_of_files: + if avg is None: + avg = np.load(f)['probabilities'] + # maybe increase precision to prevent rounding errors + if avg.dtype != np.float32: + avg = avg.astype(np.float32) + else: + avg += np.load(f)['probabilities'] + avg /= len(list_of_files) + return avg + + +def merge_files(list_of_files, + output_filename_truncated: str, + output_file_ending: str, + image_reader_writer: BaseReaderWriter, + label_manager: LabelManager, + save_probabilities: bool = False): + # load the pkl file associated with the first file in list_of_files + properties = load_pickle(list_of_files[0][:-4] + '.pkl') + # load and average predictions + probabilities = average_probabilities(list_of_files) + segmentation = label_manager.convert_logits_to_segmentation(probabilities) + image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties) + if save_probabilities: + np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities) + save_pickle(probabilities, output_filename_truncated + '.pkl') + + +def ensemble_folders(list_of_input_folders: List[str], + output_folder: str, + save_merged_probabilities: bool = False, + num_processes: int = default_num_processes, + dataset_json_file_or_dict: str = None, + plans_json_file_or_dict: str = None): + """we need too much shit for this function. Problem is that we now have to support region-based training plus + multiple input/output formats so there isn't really a way around this. + + If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json + and/or dataset.json in it. These are usually copied into those folders by nnU-Net during prediction. + We just pick the dataset.json and plans.json from the first of the folders and we DONT check whether the 5 + folders contain the same plans etc! This can be a feature if results from different datasets are to be merged (only + works if label dict in dataset.json is the same between these datasets!!!)""" + if dataset_json_file_or_dict is not None: + if isinstance(dataset_json_file_or_dict, str): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + else: + dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json')) + + if plans_json_file_or_dict is not None: + if isinstance(plans_json_file_or_dict, str): + plans = load_json(plans_json_file_or_dict) + else: + plans = plans_json_file_or_dict + else: + plans = load_json(join(list_of_input_folders[0], 'plans.json')) + + plans_manager = PlansManager(plans) + + # now collect the files in each of the folders and enforce that all files are present in all folders + files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in list_of_input_folders] + # first build a set with all files + s = deepcopy(files_per_folder[0]) + for f in files_per_folder[1:]: + s.update(f) + for f in files_per_folder: + assert len(s.difference(f)) == 0, "Not all folders contain the same files for ensembling. Please only " \ + "provide folders that contain the predictions" + lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s] + output_files_truncated = [join(output_folder, fi[:-4]) for fi in s] + + image_reader_writer = plans_manager.image_reader_writer_class() + label_manager = plans_manager.get_label_manager(dataset_json) + + maybe_mkdir_p(output_folder) + shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder) + + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + num_preds = len(s) + _ = pool.starmap( + merge_files, + zip( + lists_of_lists_of_files, + output_files_truncated, + [dataset_json['file_ending']] * num_preds, + [image_reader_writer] * num_preds, + [label_manager] * num_preds, + [save_merged_probabilities] * num_preds + ) + ) + + +def entry_point_ensemble_folders(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', nargs='+', type=str, required=True, + help='list of input folders') + parser.add_argument('-o', type=str, required=True, help='output folder') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"Numbers of processes used for ensembling. Default: {default_num_processes}") + parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output ' + 'probabilities in separate .npz files') + + args = parser.parse_args() + ensemble_folders(args.i, args.o, args.save_npz, args.np) + + +def ensemble_crossvalidations(list_of_trained_model_folders: List[str], + output_folder: str, + folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4), + num_processes: int = default_num_processes, + overwrite: bool = True) -> None: + """ + Feature: different configurations can now have different splits + """ + dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json')) + plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json')) + + # first collect all unique filenames + files_per_folder = {} + unique_filenames = set() + for tr in list_of_trained_model_folders: + files_per_folder[tr] = {} + for f in folds: + if not isdir(join(tr, f'fold_{f}', 'validation')): + raise RuntimeError(f'Expected model output directory does not exist. You must train all requested ' + f'folds of the specified model.\nModel: {tr}\nFold: {f}') + files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) + if len(files_here) == 0: + raise RuntimeError(f"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your " + f"validation with the --npz flag. Use nnUNetv2_train [...] --val --npz.") + files_per_folder[tr][f] = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) + unique_filenames.update(files_per_folder[tr][f]) + + # verify that all trained_model_folders have all predictions + ok = True + for tr, fi in files_per_folder.items(): + all_files_here = set() + for f in folds: + all_files_here.update(fi[f]) + diff = unique_filenames.difference(all_files_here) + if len(diff) > 0: + ok = False + print(f'model {tr} does not seem to contain all predictions. Missing: {diff}') + if not ok: + raise RuntimeError('There were missing files, see print statements above this one') + + # now we need to collect where these files are + file_mapping = [] + for tr in list_of_trained_model_folders: + file_mapping.append({}) + for f in folds: + for fi in files_per_folder[tr][f]: + # check for duplicates + assert fi not in file_mapping[-1].keys(), f"Duplicate detected. Case {fi} is present in more than " \ + f"one fold of model {tr}." + file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi) + + lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames] + output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames] + + image_reader_writer = plans_manager.image_reader_writer_class() + maybe_mkdir_p(output_folder) + label_manager = plans_manager.get_label_manager(dataset_json) + + if not overwrite: + tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated] + lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]] + output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]] + + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + num_preds = len(lists_of_lists_of_files) + _ = pool.starmap( + merge_files, + zip( + lists_of_lists_of_files, + output_files_truncated, + [dataset_json['file_ending']] * num_preds, + [image_reader_writer] * num_preds, + [label_manager] * num_preds, + [False] * num_preds + ) + ) + + shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json')) + shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json')) diff --git a/docker/template/src/nnunetv2/evaluation/__init__.py b/docker/template/src/nnunetv2/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/evaluation/accumulate_cv_results.py b/docker/template/src/nnunetv2/evaluation/accumulate_cv_results.py new file mode 100644 index 0000000..f1a79f0 --- /dev/null +++ b/docker/template/src/nnunetv2/evaluation/accumulate_cv_results.py @@ -0,0 +1,58 @@ +import shutil +from typing import Union, List, Tuple + +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile + +from nnunetv2.configuration import default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def accumulate_cv_results(trained_model_folder, + merged_output_folder: str, + folds: Union[List[int], Tuple[int, ...]], + num_processes: int = default_num_processes, + overwrite: bool = True): + """ + There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to + collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files! + """ + + if overwrite and isdir(merged_output_folder): + shutil.rmtree(merged_output_folder) + maybe_mkdir_p(merged_output_folder) + + dataset_json = load_json(join(trained_model_folder, 'dataset.json')) + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + rw = plans_manager.image_reader_writer_class() + shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json')) + shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json')) + + did_we_copy_something = False + for f in folds: + expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation') + if not isdir(expected_validation_folder): + raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!") + predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False) + for pf in predicted_files: + if overwrite and isfile(join(merged_output_folder, pf)): + raise RuntimeError(f'More than one of your folds has a prediction for case {pf}') + if overwrite or not isfile(join(merged_output_folder, pf)): + shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf)) + did_we_copy_something = True + + if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')): + label_manager = plans_manager.get_label_manager(dataset_json) + gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr') + if not isdir(gt_folder): + gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations') + compute_metrics_on_folder(gt_folder, + merged_output_folder, + join(merged_output_folder, 'summary.json'), + rw, + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else + label_manager.foreground_labels, + label_manager.ignore_label, + num_processes) diff --git a/docker/template/src/nnunetv2/evaluation/evaluate_predictions.py b/docker/template/src/nnunetv2/evaluation/evaluate_predictions.py new file mode 100644 index 0000000..a2c342a --- /dev/null +++ b/docker/template/src/nnunetv2/evaluation/evaluate_predictions.py @@ -0,0 +1,263 @@ +import multiprocessing +import os +from copy import deepcopy +from typing import Tuple, List, Union + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \ + isfile +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \ + determine_reader_writer_from_file_ending +from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO +# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]): + return str(label_or_region) + + +def key_to_label_or_region(key: str): + try: + return int(key) + except ValueError: + key = key.replace('(', '') + key = key.replace(')', '') + split = key.split(',') + return tuple([int(i) for i in split if len(i) > 0]) + + +def save_summary_json(results: dict, output_file: str): + """ + stupid json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit + ourselves + """ + results_converted = deepcopy(results) + # convert keys in mean metrics + results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()} + # convert metric_per_case + for i in range(len(results_converted["metric_per_case"])): + results_converted["metric_per_case"][i]['metrics'] = \ + {label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k] + for k in results["metric_per_case"][i]['metrics'].keys()} + # sort_keys=True will make foreground_mean the first entry and thus easy to spot + save_json(results_converted, output_file, sort_keys=True) + + +def load_summary_json(filename: str): + results = load_json(filename) + # convert keys in mean metrics + results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()} + # convert metric_per_case + for i in range(len(results["metric_per_case"])): + results["metric_per_case"][i]['metrics'] = \ + {key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k] + for k in results["metric_per_case"][i]['metrics'].keys()} + return results + + +def labels_to_list_of_regions(labels: List[int]): + return [(i,) for i in labels] + + +def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray: + if np.isscalar(region_or_label): + return segmentation == region_or_label + else: + mask = np.zeros_like(segmentation, dtype=bool) + for r in region_or_label: + mask[segmentation == r] = True + return mask + + +def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None): + if ignore_mask is None: + use_mask = np.ones_like(mask_ref, dtype=bool) + else: + use_mask = ~ignore_mask + tp = np.sum((mask_ref & mask_pred) & use_mask) + fp = np.sum(((~mask_ref) & mask_pred) & use_mask) + fn = np.sum((mask_ref & (~mask_pred)) & use_mask) + tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask) + return tp, fp, fn, tn + + +def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter, + labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]], + ignore_label: int = None) -> dict: + # load images + seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file) + seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file) + # spacing = seg_ref_dict['spacing'] + + ignore_mask = seg_ref == ignore_label if ignore_label is not None else None + + results = {} + results['reference_file'] = reference_file + results['prediction_file'] = prediction_file + results['metrics'] = {} + for r in labels_or_regions: + results['metrics'][r] = {} + mask_ref = region_or_label_to_mask(seg_ref, r) + mask_pred = region_or_label_to_mask(seg_pred, r) + tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask) + if tp + fp + fn == 0: + results['metrics'][r]['Dice'] = np.nan + results['metrics'][r]['IoU'] = np.nan + else: + results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn) + results['metrics'][r]['IoU'] = tp / (tp + fp + fn) + results['metrics'][r]['FP'] = fp + results['metrics'][r]['TP'] = tp + results['metrics'][r]['FN'] = fn + results['metrics'][r]['TN'] = tn + results['metrics'][r]['n_pred'] = fp + tp + results['metrics'][r]['n_ref'] = fn + tp + return results + + +def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str, + image_reader_writer: BaseReaderWriter, + file_ending: str, + regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]], + ignore_label: int = None, + num_processes: int = default_num_processes, + chill: bool = True) -> dict: + """ + output_file must end with .json; can be None + """ + if output_file is not None: + assert output_file.endswith('.json'), 'output_file should end with .json' + files_pred = subfiles(folder_pred, suffix=file_ending, join=False) + files_ref = subfiles(folder_ref, suffix=file_ending, join=False) + if not chill: + present = [isfile(join(folder_pred, i)) for i in files_ref] + assert all(present), "Not all files in folder_pred exist in folder_ref" + files_ref = [join(folder_ref, i) for i in files_pred] + files_pred = [join(folder_pred, i) for i in files_pred] + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + # for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))): + # compute_metrics(*i) + results = pool.starmap( + compute_metrics, + list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), + [ignore_label] * len(files_pred))) + ) + + # mean metric per class + metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys()) + means = {} + for r in regions_or_labels: + means[r] = {} + for m in metric_list: + means[r][m] = np.nanmean([i['metrics'][r][m] for i in results]) + + # foreground mean + foreground_mean = {} + for m in metric_list: + values = [] + for k in means.keys(): + if k == 0 or k == '0': + continue + values.append(means[k][m]) + foreground_mean[m] = np.mean(values) + + [recursive_fix_for_json_export(i) for i in results] + recursive_fix_for_json_export(means) + recursive_fix_for_json_export(foreground_mean) + result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean} + if output_file is not None: + save_summary_json(result, output_file) + return result + # print('DONE') + + +def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str, + output_file: str = None, + num_processes: int = default_num_processes, + chill: bool = False): + dataset_json = load_json(dataset_json_file) + # get file ending + file_ending = dataset_json['file_ending'] + + # get reader writer class + example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0] + rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)() + + # maybe auto set output file + if output_file is None: + output_file = join(folder_pred, 'summary.json') + + lm = PlansManager(plans_file).get_label_manager(dataset_json) + compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, + lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label, + num_processes, chill=chill) + + +def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]], + output_file: str = None, + num_processes: int = default_num_processes, + ignore_label: int = None, + chill: bool = False): + example_file = subfiles(folder_ref, join=True)[0] + file_ending = os.path.splitext(example_file)[-1] + rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True, + verbose=False)() + # maybe auto set output file + if output_file is None: + output_file = join(folder_pred, 'summary.json') + compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, + labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill) + + +def evaluate_folder_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') + parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') + parser.add_argument('-djfile', type=str, required=True, + help='dataset.json file') + parser.add_argument('-pfile', type=str, required=True, + help='plans.json file') + parser.add_argument('-o', type=str, required=False, default=None, + help='Output file. Optional. Default: pred_folder/summary.json') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'number of processes used. Optional. Default: {default_num_processes}') + parser.add_argument('--chill', action='store_true', help='dont crash if folder_pred does not have all files that are present in folder_gt') + args = parser.parse_args() + compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill) + + +def evaluate_simple_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') + parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') + parser.add_argument('-l', type=int, nargs='+', required=True, + help='list of labels') + parser.add_argument('-il', type=int, required=False, default=None, + help='ignore label') + parser.add_argument('-o', type=str, required=False, default=None, + help='Output file. Optional. Default: pred_folder/summary.json') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'number of processes used. Optional. Default: {default_num_processes}') + parser.add_argument('--chill', action='store_true', help='dont crash if folder_pred does not have all files that are present in folder_gt') + + args = parser.parse_args() + compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill) + + +if __name__ == '__main__': + folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr' + folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation' + output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json' + image_reader_writer = SimpleITKIO() + file_ending = '.nii.gz' + regions = labels_to_list_of_regions([1, 2]) + ignore_label = None + num_processes = 12 + compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label, + num_processes) diff --git a/docker/template/src/nnunetv2/evaluation/find_best_configuration.py b/docker/template/src/nnunetv2/evaluation/find_best_configuration.py new file mode 100644 index 0000000..7e9f774 --- /dev/null +++ b/docker/template/src/nnunetv2/evaluation/find_best_configuration.py @@ -0,0 +1,333 @@ +import argparse +import os.path +from copy import deepcopy +from typing import Union, List, Tuple + +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json + +from nnunetv2.configuration import default_num_processes +from nnunetv2.ensembling.ensemble import ensemble_crossvalidations +from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results +from nnunetv2.postprocessing.remove_connected_components import determine_postprocessing +from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \ + convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + +default_trained_models = tuple([ + {'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'}, +]) + + +def filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]): + valid = [] + for trained_model in model_dict: + plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), + trained_model['plans'] + '.json')) + # check if configuration exists + # 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not + # specified in the plans file + if trained_model['configuration'] not in plans_manager.available_configurations: + print(f"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\n" + f"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.") + continue + + # check if trained model output folder exists. This is a requirement. No mercy here. + expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'], + trained_model['configuration'], fold=None) + if not isdir(expected_output_folder): + raise RuntimeError(f"Trained model {trained_model} does not have an output folder. " + f"Expected: {expected_output_folder}. Please run the training for this model! (don't forget " + f"the --npz flag if you want to ensemble multiple configurations)") + + valid.append(trained_model) + return valid + + +def generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str, + plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer', + folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), + folder_with_segs_from_prev_stage: str = None, + input_folder: str = 'INPUT_FOLDER', + output_folder: str = 'OUTPUT_FOLDER', + save_npz: bool = False): + fold_str = '' + for f in folds: + fold_str += f' {f}' + + predict_command = '' + trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None) + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + configuration_manager = plans_manager.get_configuration(configuration_name) + if 'previous_stage' in plans_manager.available_configurations: + prev_stage = configuration_manager.previous_stage_name + predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name, + folds, None, output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\n' + folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE' + + predict_command = f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \ + f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}' + if folder_with_segs_from_prev_stage is not None: + predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}' + if save_npz: + predict_command += ' --save_probabilities' + return predict_command + + +def find_best_configuration(dataset_name_or_id, + allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models, + allow_ensembling: bool = True, + num_processes: int = default_num_processes, + overwrite: bool = True, + folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), + strict: bool = False): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + all_results = {} + + allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id) + + for m in allowed_trained_models: + output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None) + if not isdir(output_folder) and strict: + raise RuntimeError(f'{dataset_name}: The output folder of plans {m["plans"]} configuration ' + f'{m["configuration"]} is missing. Please train the model (all requested folds!) first!') + identifier = os.path.basename(output_folder) + merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite) + all_results[identifier] = { + 'source': merged_output_folder, + 'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice'] + } + + if allow_ensembling: + for i in range(len(allowed_trained_models)): + for j in range(i + 1, len(allowed_trained_models)): + m1, m2 = allowed_trained_models[i], allowed_trained_models[j] + + output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None) + output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None) + identifier = get_ensemble_name(output_folder_1, output_folder_2, folds) + + output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier) + + ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds, + num_processes, overwrite=overwrite) + + # evaluate ensembled predictions + plans_manager = PlansManager(join(output_folder_1, 'plans.json')) + dataset_json = load_json(join(output_folder_1, 'dataset.json')) + label_manager = plans_manager.get_label_manager(dataset_json) + rw = plans_manager.image_reader_writer_class() + + compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), + output_folder_ensemble, + join(output_folder_ensemble, 'summary.json'), + rw, + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else + label_manager.foreground_labels, + label_manager.ignore_label, + num_processes) + all_results[identifier] = \ + { + 'source': output_folder_ensemble, + 'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice'] + } + + # pick best and report inference command + best_score = max([i['result'] for i in all_results.values()]) + best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score] # may never happen but theoretically + # there can be a tie. Let's pick the first model in this case because it's going to be the simpler one (ensembles + # come after single configs) + best_key = best_keys[0] + + print() + print('***All results:***') + for k, v in all_results.items(): + print(f'{k}: {v["result"]}') + print(f'\n*Best*: {best_key}: {all_results[best_key]["result"]}') + print() + + print('***Determining postprocessing for best model/ensemble***') + determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), + plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'), + dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'), + num_processes=num_processes, keep_postprocessed_files=True) + + # in addition to just reading the console output (how it was previously) we should return the information + # needed to run the full inference via API + return_dict = { + 'folds': folds, + 'dataset_name_or_id': dataset_name_or_id, + 'considered_models': allowed_trained_models, + 'ensembling_allowed': allow_ensembling, + 'all_results': {i: j['result'] for i, j in all_results.items()}, + 'best_model_or_ensemble': { + 'result_on_crossval_pre_pp': all_results[best_key]["result"], + 'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'], + 'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'), + 'some_plans_file': join(all_results[best_key]['source'], 'plans.json'), + # just needed for label handling, can + # come from any of the ensemble members (if any) + 'selected_model_or_models': [] + } + } + # convert best key to inference command: + if best_key.startswith('ensemble___'): + prefix, m1, m2, folds_string = best_key.split('___') + tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1) + tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c1, + 'trainer': tr1, + 'plans_identifier': pl1, + }) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c2, + 'trainer': tr2, + 'plans_identifier': pl2, + }) + else: + tr, pl, c = convert_identifier_to_trainer_plans_config(best_key) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c, + 'trainer': tr, + 'plans_identifier': pl, + }) + + save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json')) # save this so that we don't have to run this + # everything someone wants to be reminded of the inference commands. They can just load this and give it to + # print_inference_instructions + + # print it + print_inference_instructions(return_dict, instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt')) + return return_dict + + +def print_inference_instructions(inference_info_dict: dict, instructions_file: str = None): + def _print_and_maybe_write_to_file(string): + print(string) + if f_handle is not None: + f_handle.write(f'{string}\n') + + f_handle = open(instructions_file, 'w') if instructions_file is not None else None + print() + _print_and_maybe_write_to_file('***Run inference like this:***\n') + output_folders = [] + + dataset_name_or_id = inference_info_dict['dataset_name_or_id'] + if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1: + is_ensemble = True + _print_and_maybe_write_to_file('An ensemble won! What a surprise! Run the following commands to run predictions with the ensemble members:\n') + else: + is_ensemble = False + + for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']): + tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier'] + if is_ensemble: + output_folder_name = f"OUTPUT_FOLDER_MODEL_{j+1}" + else: + output_folder_name = f"OUTPUT_FOLDER" + output_folders.append(output_folder_name) + + _print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr, inference_info_dict['folds'], + save_npz=is_ensemble, output_folder=output_folder_name)) + + if is_ensemble: + output_folder_str = output_folders[0] + for o in output_folders[1:]: + output_folder_str += f' {o}' + output_ensemble = f"OUTPUT_FOLDER" + _print_and_maybe_write_to_file('\nThe run ensembling with:\n') + _print_and_maybe_write_to_file(f"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} -np {default_num_processes}") + + _print_and_maybe_write_to_file("\n***Once inference is completed, run postprocessing like this:***\n") + _print_and_maybe_write_to_file(f"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP " + f"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} " + f"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}") + + +def dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]): + """ + function is called dumb because it's dumb + """ + ret = [] + for t in trainers: + for c in configs: + for p in plans: + ret.append( + {'plans': p, 'configuration': c, 'trainer': t} + ) + return tuple(ret) + + +def find_best_configuration_entry_point(): + parser = argparse.ArgumentParser() + parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id') + parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'], + help='List of plan identifiers. Default: nnUNetPlans') + parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'], + help="List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']") + parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'], + help='List of trainers. Default: nnUNetTrainer') + parser.add_argument('-np', required=False, default=default_num_processes, type=int, + help='Number of processes to use for ensembling, postprocessing etc') + parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), + help='Folds to use. Default: 0 1 2 3 4') + parser.add_argument('--disable_ensembling', action='store_true', required=False, + help='Set this flag to disable ensembling') + parser.add_argument('--no_overwrite', action='store_true', + help='If set we will not overwrite already ensembled files etc. May speed up concecutive ' + 'runs of this command (why would you want to do that?) at the risk of not updating ' + 'outdated results.') + args = parser.parse_args() + + model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p) + dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id) + + find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling, + num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f, + strict=False) + + +def accumulate_crossval_results_entry_point(): + parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint ' + 'folder and evaluates them') + parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id') + parser.add_argument('-c', type=str, required=True, + default='3d_fullres', + help="Configuration") + parser.add_argument('-o', type=str, required=False, default=None, + help="Output folder. If not specified, the output folder will be located in the trained " \ + "model directory (named crossval_results_folds_XXX).") + parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), + help='Folds to use. Default: 0 1 2 3 4') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='Trainer class. Default: nnUNetTrainer') + args = parser.parse_args() + trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c) + + if args.o is None: + merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}') + else: + merged_output_folder = args.o + + accumulate_cv_results(trained_model_folder, merged_output_folder, args.f) + + +if __name__ == '__main__': + find_best_configuration(4, + default_trained_models, + True, + 8, + False, + (0, 1, 2, 3, 4)) diff --git a/docker/template/src/nnunetv2/experiment_planning/__init__.py b/docker/template/src/nnunetv2/experiment_planning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py b/docker/template/src/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py b/docker/template/src/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py new file mode 100644 index 0000000..a4bec96 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py @@ -0,0 +1,199 @@ +import multiprocessing +import os +from time import sleep +from typing import List, Type, Union + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p +from tqdm import tqdm + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +class DatasetFingerprintExtractor(object): + def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False): + """ + extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a + json file in the input_folder + + Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere + else. Don't compute stuff we don't need (except for intensity_statistics_per_channel) + """ + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + self.verbose = verbose + + self.dataset_name = dataset_name + self.input_folder = join(nnUNet_raw, dataset_name) + self.num_processes = num_processes + self.dataset_json = load_json(join(self.input_folder, 'dataset.json')) + self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json) + + # We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is + # also not critically important to get all pixels as long as there are enough. Let's use 10e7 voxels in total + # (for the entire dataset) + self.num_foreground_voxels_for_intensitystats = 10e7 + + @staticmethod + def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234, + num_samples: int = 10000): + """ + images=image with multiple channels = shape (c, x, y(, z)) + """ + assert images.ndim == 4 + assert segmentation.ndim == 4 + + assert not np.any(np.isnan(segmentation)), "Segmentation contains NaN values. grrrr.... :-(" + assert not np.any(np.isnan(images)), "Images contains NaN values. grrrr.... :-(" + + rs = np.random.RandomState(seed) + + intensities_per_channel = [] + # we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have + intensity_statistics_per_channel = [] + + # segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work + foreground_mask = segmentation[0] > 0 + + for i in range(len(images)): + foreground_pixels = images[i][foreground_mask] + num_fg = len(foreground_pixels) + # sample with replacement so that we don't get issues with cases that have less than num_samples + # foreground_pixels. We could also just sample less in those cases but that would than cause these + # training cases to be underrepresented + intensities_per_channel.append( + rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else []) + intensity_statistics_per_channel.append({ + 'mean': np.mean(foreground_pixels) if num_fg > 0 else np.nan, + 'median': np.median(foreground_pixels) if num_fg > 0 else np.nan, + 'min': np.min(foreground_pixels) if num_fg > 0 else np.nan, + 'max': np.max(foreground_pixels) if num_fg > 0 else np.nan, + 'percentile_99_5': np.percentile(foreground_pixels, 99.5) if num_fg > 0 else np.nan, + 'percentile_00_5': np.percentile(foreground_pixels, 0.5) if num_fg > 0 else np.nan, + + }) + + return intensities_per_channel, intensity_statistics_per_channel + + @staticmethod + def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter], + num_samples: int = 10000): + rw = reader_writer_class() + images, properties_images = rw.read_images(image_files) + segmentation, properties_seg = rw.read_seg(segmentation_file) + + # we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly. + # Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't + # need to save the cropped data anymore. Given that cropping is not too expensive it makes sense to do it this + # way. This is only possible because we are now using our new input/output interface. + data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation) + + foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \ + DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped, + num_samples=num_samples) + + spacing = properties_images['spacing'] + + shape_before_crop = images.shape[1:] + shape_after_crop = data_cropped.shape[1:] + relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop) + return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \ + relative_size_after_cropping + + def run(self, overwrite_existing: bool = False) -> dict: + # we do not save the properties file in self.input_folder because that folder might be read-only. We can only + # reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is + preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name) + maybe_mkdir_p(preprocessed_output_folder) + properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json') + + if not isfile(properties_file) or overwrite_existing: + reader_writer_class = determine_reader_writer_from_dataset_json(self.dataset_json, + # yikes. Rip the following line + self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0]) + + # determine how many foreground voxels we need to sample per training case + num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats // + len(self.dataset)) + + r = [] + with multiprocessing.get_context("spawn").Pool(self.num_processes) as p: + for k in self.dataset.keys(): + r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case, + ((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class, + num_foreground_samples_per_case),))) + remaining = list(range(len(self.dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + with tqdm(desc=None, total=len(self.dataset), disable=self.verbose) as pbar: + while len(remaining) > 0: + all_alive = all([j.is_alive() for j in workers]) + if not all_alive: + raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' + 'OK jokes aside.\n' + 'One of your background processes is missing. This could be because of ' + 'an error (look for an error message) or because it was killed ' + 'by your OS due to running out of RAM. If you don\'t see ' + 'an error message, out of RAM is likely the problem. In that case ' + 'reducing the number of workers might help') + done = [i for i in remaining if r[i].ready()] + for _ in done: + pbar.update() + remaining = [i for i in remaining if i not in done] + sleep(0.1) + + # results = ptqdm(DatasetFingerprintExtractor.analyze_case, + # (training_images_per_case, training_labels_per_case), + # processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class, + # num_samples=num_foreground_samples_per_case, disable=self.verbose) + results = [i.get()[0] for i in r] + + shapes_after_crop = [r[0] for r in results] + spacings = [r[1] for r in results] + foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in + range(len(results[0][2]))] + # we drop this so that the json file is somewhat human readable + # foreground_intensity_stats_by_case_and_modality = [r[3] for r in results] + median_relative_size_after_cropping = np.median([r[4] for r in results], 0) + + num_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + intensity_statistics_per_channel = {} + for i in range(num_channels): + intensity_statistics_per_channel[i] = { + 'mean': float(np.mean(foreground_intensities_per_channel[i])), + 'median': float(np.median(foreground_intensities_per_channel[i])), + 'std': float(np.std(foreground_intensities_per_channel[i])), + 'min': float(np.min(foreground_intensities_per_channel[i])), + 'max': float(np.max(foreground_intensities_per_channel[i])), + 'percentile_99_5': float(np.percentile(foreground_intensities_per_channel[i], 99.5)), + 'percentile_00_5': float(np.percentile(foreground_intensities_per_channel[i], 0.5)), + } + + fingerprint = { + "spacings": spacings, + "shapes_after_crop": shapes_after_crop, + 'foreground_intensity_properties_per_channel': intensity_statistics_per_channel, + "median_relative_size_after_cropping": median_relative_size_after_cropping + } + + try: + save_json(fingerprint, properties_file) + except Exception as e: + if isfile(properties_file): + os.remove(properties_file) + raise e + else: + fingerprint = load_json(properties_file) + return fingerprint + + +if __name__ == '__main__': + dfe = DatasetFingerprintExtractor(2, 8) + dfe.run(overwrite_existing=False) diff --git a/docker/template/src/nnunetv2/experiment_planning/experiment_planners/__init__.py b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py new file mode 100644 index 0000000..1055170 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -0,0 +1,542 @@ +import shutil +from copy import deepcopy +from functools import lru_cache +from typing import List, Union, Tuple, Type + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm + +from nnunetv2.configuration import ANISO_THRESHOLD +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme +from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +class ExperimentPlanner(object): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + """ + overwrite_target_spacing only affects 3d_fullres! (but by extension 3d_lowres which starts with fullres may + also be affected + """ + + self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + self.suppress_transpose = suppress_transpose + self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name) + preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name) + self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json')) + self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json) + + # load dataset fingerprint + if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')): + raise RuntimeError('Fingerprint missing for this dataset. Please run nnUNet_extract_dataset_fingerprint') + + self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json')) + + self.anisotropy_threshold = ANISO_THRESHOLD + + self.UNet_base_num_features = 32 + self.UNet_class = PlainConvUNet + # the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as + # much as possible + self.UNet_reference_val_3d = 560000000 # 455600128 550000000 + self.UNet_reference_val_2d = 85000000 # 83252480 + self.UNet_reference_com_nfeatures = 32 + self.UNet_reference_val_corresp_GB = 8 + self.UNet_reference_val_corresp_bs_2d = 12 + self.UNet_reference_val_corresp_bs_3d = 2 + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_featuremap_min_edge_length = 4 + self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + self.UNet_min_batch_size = 2 + self.UNet_max_features_2d = 512 + self.UNet_max_features_3d = 320 + + self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the + # median shape then we need a lowres config as well + + self.preprocessor_name = preprocessor_name + self.plans_identifier = plans_name + self.overwrite_target_spacing = overwrite_target_spacing + assert overwrite_target_spacing is None or len(overwrite_target_spacing), 'if overwrite_target_spacing is ' \ + 'used then three floats must be ' \ + 'given (as list or tuple)' + assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \ + 'if overwrite_target_spacing is used then three floats must be given (as list or tuple)' + + self.plans = None + + def determine_reader_writer(self): + example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] + return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) + + @staticmethod + @lru_cache(maxsize=None) + def static_estimate_VRAM_usage(patch_size: Tuple[int], + n_stages: int, + strides: Union[int, List[int], Tuple[int, ...]], + UNet_class: Union[Type[PlainConvUNet], Type[ResidualEncoderUNet]], + num_input_channels: int, + features_per_stage: Tuple[int], + blocks_per_stage_encoder: Union[int, Tuple[int]], + blocks_per_stage_decoder: Union[int, Tuple[int]], + num_labels: int): + """ + Works for PlainConvUNet, ResidualEncoderUNet + """ + dim = len(patch_size) + conv_op = convert_dim_to_conv_op(dim) + norm_op = get_matching_instancenorm(conv_op) + net = UNet_class(num_input_channels, n_stages, + features_per_stage, + conv_op, + 3, + strides, + blocks_per_stage_encoder, + num_labels, + blocks_per_stage_decoder, + norm_op=norm_op) + return net.compute_conv_feature_map_size(patch_size) + + def determine_resampling(self, *args, **kwargs): + """ + returns what functions to use for resampling data and seg, respectively. Also returns kwargs + resampling function must be callable(data, current_spacing, new_spacing, **kwargs) + + determine_resampling is called within get_plans_for_configuration to allow for different functions for each + configuration + """ + resampling_data = resample_data_or_seg_to_shape + resampling_data_kwargs = { + "is_seg": False, + "order": 3, + "order_z": 0, + "force_separate_z": None, + } + resampling_seg = resample_data_or_seg_to_shape + resampling_seg_kwargs = { + "is_seg": True, + "order": 1, + "order_z": 0, + "force_separate_z": None, + } + return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs + + def determine_segmentation_softmax_export_fn(self, *args, **kwargs): + """ + function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be + used as target. current_spacing and new_spacing are merely there in case we want to use it somehow + + determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different + functions for each configuration + + """ + resampling_fn = resample_data_or_seg_to_shape + resampling_fn_kwargs = { + "is_seg": False, + "order": 1, + "order_z": 0, + "force_separate_z": None, + } + return resampling_fn, resampling_fn_kwargs + + def determine_fullres_target_spacing(self) -> np.ndarray: + """ + per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data + and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training + + For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic + (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low + resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially + impact performance (due to the low number of slices). + """ + if self.overwrite_target_spacing is not None: + return np.array(self.overwrite_target_spacing) + + spacings = self.dataset_fingerprint['spacings'] + sizes = self.dataset_fingerprint['shapes_after_crop'] + + target = np.percentile(np.vstack(spacings), 50, 0) + + # todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)] + + target_size = np.percentile(np.vstack(sizes), 50, 0) + # we need to identify datasets for which a different target spacing could be beneficial. These datasets have + # the following properties: + # - one axis which much lower resolution than the others + # - the lowres axis has much less voxels than the others + # - (the size in mm of the lowres axis is also reduced) + worst_spacing_axis = np.argmax(target) + other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] + other_spacings = [target[i] for i in other_axes] + other_sizes = [target_size[i] for i in other_axes] + + has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) + has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes) + + if has_aniso_spacing and has_aniso_voxels: + spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] + target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) + # don't let the spacing of that axis get higher than the other axes + if target_spacing_of_that_axis < max(other_spacings): + target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5 + target[worst_spacing_axis] = target_spacing_of_that_axis + return target + + def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]: + if 'channel_names' not in self.dataset_json.keys(): + print('WARNING: "modalities" should be renamed to "channel_names" in dataset.json. This will be ' + 'enforced soon!') + modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \ + self.dataset_json['modality'] + normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()] + if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.): + use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in + normalization_schemes] + else: + use_nonzero_mask_for_norm = [False] * len(normalization_schemes) + assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \ + 'True or False and cannot be None' + normalization_schemes = [i.__name__ for i in normalization_schemes] + return normalization_schemes, use_nonzero_mask_for_norm + + def determine_transpose(self): + if self.suppress_transpose: + return [0, 1, 2], [0, 1, 2] + + # todo we should use shapes for that as well. Not quite sure how yet + target_spacing = self.determine_fullres_target_spacing() + + max_spacing_axis = np.argmax(target_spacing) + remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis] + transpose_forward = [max_spacing_axis] + remaining_axes + transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)] + return transpose_forward, transpose_backward + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...], List[int]], + data_identifier: str, + approximate_n_voxels_dataset: float) -> dict: + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + # now estimate vram consumption + num_stages = len(pool_op_kernel_sizes) + estimate = self.static_estimate_VRAM_usage(tuple(patch_size), + num_stages, + tuple([tuple(i) for i in pool_op_kernel_sizes]), + self.UNet_class, + len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()), + tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else + self.UNet_max_features_3d, + self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(len(pool_op_kernel_sizes))]), + self.UNet_blocks_per_stage_encoder[:num_stages], + self.UNet_blocks_per_stage_decoder[:num_stages - 1], + len(self.dataset_json['labels'].keys())) + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort(patch_size / median_shape[:len(spacing)])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + estimate = self.static_estimate_VRAM_usage(tuple(patch_size), + num_stages, + tuple([tuple(i) for i in pool_op_kernel_sizes]), + self.UNet_class, + len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()), + tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else + self.UNet_max_features_3d, + self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(len(pool_op_kernel_sizes))]), + self.UNet_blocks_per_stage_encoder[:num_stages], + self.UNet_blocks_per_stage_decoder[:num_stages - 1], + len(self.dataset_json['labels'].keys())) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + num_stages = len(pool_op_kernel_sizes) + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'UNet_class_name': self.UNet_class.__name__, + 'UNet_base_num_features': self.UNet_base_num_features, + 'n_conv_per_stage_encoder': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'num_pool_per_axis': network_num_pool_per_axis, + 'pool_op_kernel_sizes': pool_op_kernel_sizes, + 'conv_kernel_sizes': conv_kernel_sizes, + 'unet_max_num_features': self.UNet_max_features_3d if len(spacing) == 3 else self.UNet_max_features_2d, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + } + return plan + + def plan_experiment(self): + """ + MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY + + Ideally I would like to move transpose_forward/backward into the configurations so that this can also be done + differently for each configuration but this would cause problems with identifying the correct axes for 2d. There + surely is a way around that but eh. I'm feeling lazy and featuritis must also not be pushed to the extremes. + + So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too + hard. + """ + + # first get transpose + transpose_forward, transpose_backward = self.determine_transpose() + + # get fullres spacing and transpose it + fullres_spacing = self.determine_fullres_target_spacing() + fullres_spacing_transposed = fullres_spacing[transpose_forward] + + # get transposed new median shape (what we would have after resampling) + new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in + zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])] + new_median_shape = np.median(new_shapes, 0) + new_median_shape_transposed = new_median_shape[transpose_forward] + + approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) * + self.dataset_json['numTraining']) + # only run 3d if this is a 3d dataset + if new_median_shape_transposed[0] != 1: + plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed, + new_median_shape_transposed, + self.generate_data_identifier('3d_fullres'), + approximate_n_voxels_dataset) + # maybe add 3d_lowres as well + patch_size_fullres = plan_3d_fullres['patch_size'] + median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64) + num_voxels_in_patch = np.prod(patch_size_fullres, dtype=np.float64) + + plan_3d_lowres = None + lowres_spacing = deepcopy(plan_3d_fullres['spacing']) + + spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation! + + while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold: + # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they + # is/are similar (factor 2) to the other ax(i/e)s. + max_spacing = max(lowres_spacing) + if np.any((max_spacing / lowres_spacing) > 2): + lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor + else: + lowres_spacing *= spacing_increase_factor + median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed, + dtype=np.float64) + # print(lowres_spacing) + plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing, + [round(i) for i in plan_3d_fullres['spacing'] / + lowres_spacing * new_median_shape_transposed], + self.generate_data_identifier('3d_lowres'), + float(np.prod(median_num_voxels) * + self.dataset_json['numTraining'])) + num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64) + print(f'Attempting to find 3d_lowres config. ' + f'\nCurrent spacing: {lowres_spacing}. ' + f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. ' + f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}') + if plan_3d_lowres is not None: + plan_3d_lowres['batch_dice'] = False + plan_3d_fullres['batch_dice'] = True + else: + plan_3d_fullres['batch_dice'] = False + else: + plan_3d_fullres = None + plan_3d_lowres = None + + # 2D configuration + plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:], + new_median_shape_transposed[1:], + self.generate_data_identifier('2d'), approximate_n_voxels_dataset) + plan_2d['batch_dice'] = True + + print('2D U-Net configuration:') + print(plan_2d) + print() + + # median spacing and shape, just for reference when printing the plans + median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward] + median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward] + + # instead of writing all that into the plans we just copy the original file. More files, but less crowded + # per file. + shutil.copy(join(self.raw_dataset_folder, 'dataset.json'), + join(nnUNet_preprocessed, self.dataset_name, 'dataset.json')) + + # json is stupid and I hate it... "Object of type int64 is not JSON serializable" -> my ass + plans = { + 'dataset_name': self.dataset_name, + 'plans_name': self.plans_identifier, + 'original_median_spacing_after_transp': [float(i) for i in median_spacing], + 'original_median_shape_after_transp': [int(round(i)) for i in median_shape], + 'image_reader_writer': self.determine_reader_writer().__name__, + 'transpose_forward': [int(i) for i in transpose_forward], + 'transpose_backward': [int(i) for i in transpose_backward], + 'configurations': {'2d': plan_2d}, + 'experiment_planner_used': self.__class__.__name__, + 'label_manager': 'LabelManager', + 'foreground_intensity_properties_per_channel': self.dataset_fingerprint[ + 'foreground_intensity_properties_per_channel'] + } + + if plan_3d_lowres is not None: + plans['configurations']['3d_lowres'] = plan_3d_lowres + if plan_3d_fullres is not None: + plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres' + print('3D lowres U-Net configuration:') + print(plan_3d_lowres) + print() + if plan_3d_fullres is not None: + plans['configurations']['3d_fullres'] = plan_3d_fullres + print('3D fullres U-Net configuration:') + print(plan_3d_fullres) + print() + if plan_3d_lowres is not None: + plans['configurations']['3d_cascade_fullres'] = { + 'inherits_from': '3d_fullres', + 'previous_stage': '3d_lowres' + } + + plans['configurations']['2d_p256'] = { + 'inherits_from': '2d', + 'patch_size': [256, 256] + } + + plans['configurations']['2d_p512'] = { + 'inherits_from': '2d', + 'patch_size': [512, 512] + } + + self.plans = plans + self.save_plans(plans) + return plans + + def save_plans(self, plans): + recursive_fix_for_json_export(plans) + + plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json') + + # we don't want to overwrite potentially existing custom configurations every time this is executed. So let's + # read the plans file if it already exists and keep any non-default configurations + if isfile(plans_file): + old_plans = load_json(plans_file) + old_configurations = old_plans['configurations'] + for c in plans['configurations'].keys(): + if c in old_configurations.keys(): + del (old_configurations[c]) + plans['configurations'].update(old_configurations) + + maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name)) + save_json(plans, plans_file, sort_keys=False) + print(f"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}") + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + return self.plans_identifier + '_' + configuration_name + + def load_plans(self, fname: str): + self.plans = load_json(fname) + + +if __name__ == '__main__': + ExperimentPlanner(2, 8).plan_experiment() diff --git a/docker/template/src/nnunetv2/experiment_planning/experiment_planners/network_topology.py b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/network_topology.py new file mode 100644 index 0000000..1ce6a46 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/network_topology.py @@ -0,0 +1,105 @@ +from copy import deepcopy +import numpy as np + + +def get_shape_must_be_divisible_by(net_numpool_per_axis): + return 2 ** np.array(net_numpool_per_axis) + + +def pad_shape(shape, must_be_divisible_by): + """ + pads shape so that it is divisible by must_be_divisible_by + :param shape: + :param must_be_divisible_by: + :return: + """ + if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)): + must_be_divisible_by = [must_be_divisible_by] * len(shape) + else: + assert len(must_be_divisible_by) == len(shape) + + new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))] + + for i in range(len(shape)): + if shape[i] % must_be_divisible_by[i] == 0: + new_shp[i] -= must_be_divisible_by[i] + new_shp = np.array(new_shp).astype(int) + return new_shp + + +def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool): + """ + this is the same as get_pool_and_conv_props_v2 from old nnunet + + :param spacing: + :param patch_size: + :param min_feature_map_size: min edge length of feature maps in bottleneck + :param max_numpool: + :return: + """ + # todo review this code + dim = len(spacing) + + current_spacing = deepcopy(list(spacing)) + current_size = deepcopy(list(patch_size)) + + pool_op_kernel_sizes = [[1] * len(spacing)] + conv_kernel_sizes = [] + + num_pool_per_axis = [0] * dim + kernel_size = [1] * dim + + while True: + # exclude axes that we cannot pool further because of min_feature_map_size constraint + valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size] + if len(valid_axes_for_pool) < 1: + break + + spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool] + + # find axis that are within factor of 2 within smallest spacing + min_spacing_of_valid = min(spacings_of_axes) + valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2] + + # max_numpool constraint + valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool] + + if len(valid_axes_for_pool) == 1: + if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size: + pass + else: + break + if len(valid_axes_for_pool) < 1: + break + + # now we need to find kernel sizes + # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within + # factor 2 of min_spacing. Once they are 3 they remain 3 + for d in range(dim): + if kernel_size[d] == 3: + continue + else: + if current_spacing[d] / min(current_spacing) < 2: + kernel_size[d] = 3 + + other_axes = [i for i in range(dim) if i not in valid_axes_for_pool] + + pool_kernel_sizes = [0] * dim + for v in valid_axes_for_pool: + pool_kernel_sizes[v] = 2 + num_pool_per_axis[v] += 1 + current_spacing[v] *= 2 + current_size[v] = np.ceil(current_size[v] / 2) + for nv in other_axes: + pool_kernel_sizes[nv] = 1 + + pool_op_kernel_sizes.append(pool_kernel_sizes) + conv_kernel_sizes.append(deepcopy(kernel_size)) + #print(conv_kernel_sizes) + + must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) + patch_size = pad_shape(patch_size, must_be_divisible_by) + + # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here + conv_kernel_sizes.append([3]*dim) + return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by diff --git a/docker/template/src/nnunetv2/experiment_planning/experiment_planners/readme.md b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/readme.md new file mode 100644 index 0000000..e2e4e18 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/readme.md @@ -0,0 +1,38 @@ +What do experiment planners need to do (these are notes for myself while rewriting nnU-Net, they are provided as is +without further explanations. These notes also include new features): +- (done) preprocessor name should be configurable via cli +- (done) gpu memory target should be configurable via cli +- (done) plans name should be configurable via cli +- (done) data name should be specified in plans (plans specify the data they want to use, this will allow us to manually + edit plans files without having to copy the data folders) +- plans must contain: + - (done) transpose forward/backward + - (done) preprocessor name (can differ for each config) + - (done) spacing + - (done) normalization scheme + - (done) target spacing + - (done) conv and pool op kernel sizes + - (done) base num features for architecture + - (done) data identifier + - num conv per stage? + - (done) use mask for norm + - [NO. Handled by LabelManager & dataset.json] num segmentation outputs + - [NO. Handled by LabelManager & dataset.json] ignore class + - [NO. Handled by LabelManager & dataset.json] list of regions or classes + - [NO. Handled by LabelManager & dataset.json] regions class order, if applicable + - (done) resampling function to be used + - (done) the image reader writer class that should be used + + +dataset.json +mandatory: +- numTraining +- labels (value 'ignore' has special meaning. Cannot have more than one ignore_label) +- modalities +- file_ending + +optional +- overwrite_image_reader_writer (if absent, auto) +- regions +- region_class_order +- \ No newline at end of file diff --git a/docker/template/src/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py new file mode 100644 index 0000000..52ca938 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -0,0 +1,54 @@ +from typing import Union, List, Tuple + +from torch import nn + +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet + + +class ResEncUNetPlanner(ExperimentPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + self.UNet_base_num_features = 32 + self.UNet_class = ResidualEncoderUNet + # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as + # much as possible + self.UNet_reference_val_3d = 680000000 + self.UNet_reference_val_2d = 135000000 + self.UNet_reference_com_nfeatures = 32 + self.UNet_reference_val_corresp_GB = 8 + self.UNet_reference_val_corresp_bs_2d = 12 + self.UNet_reference_val_corresp_bs_3d = 2 + self.UNet_featuremap_min_edge_length = 4 + self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) + self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) + self.UNet_min_batch_size = 2 + self.UNet_max_features_2d = 512 + self.UNet_max_features_3d = 320 + + +if __name__ == '__main__': + # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively + net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), + conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. The value you see above was finetuned + # from this one to match the regular nnunetplans more closely + + net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), + conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 + diff --git a/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_api.py new file mode 100644 index 0000000..7748572 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -0,0 +1,137 @@ +from typing import List, Type, Optional, Tuple, Union + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json + +from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager +from nnunetv2.configuration import default_num_processes +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +def extract_fingerprint_dataset(dataset_id: int, + fingerprint_extractor_class: Type[ + DatasetFingerprintExtractor] = DatasetFingerprintExtractor, + num_processes: int = default_num_processes, check_dataset_integrity: bool = False, + clean: bool = True, verbose: bool = True): + """ + Returns the fingerprint as a dictionary (additionally to saving it) + """ + dataset_name = convert_id_to_dataset_name(dataset_id) + print(dataset_name) + + if check_dataset_integrity: + verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes) + + fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose) + return fpe.run(overwrite_existing=clean) + + +def extract_fingerprints(dataset_ids: List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor', + num_processes: int = default_num_processes, check_dataset_integrity: bool = False, + clean: bool = True, verbose: bool = True): + """ + clean = False will not actually run this. This is just a switch for use with nnUNetv2_plan_and_preprocess where + we don't want to rerun fingerprint extraction every time. + """ + fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + fingerprint_extractor_class_name, + current_module="nnunetv2.experiment_planning") + for d in dataset_ids: + extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean, + verbose) + + +def plan_experiment_dataset(dataset_id: int, + experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner, + gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', + overwrite_target_spacing: Optional[Tuple[float, ...]] = None, + overwrite_plans_name: Optional[str] = None) -> dict: + """ + overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! + """ + kwargs = {} + if overwrite_plans_name is not None: + kwargs['plans_name'] = overwrite_plans_name + return experiment_planner_class(dataset_id, + gpu_memory_target_in_gb=gpu_memory_target_in_gb, + preprocessor_name=preprocess_class_name, + overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if + overwrite_target_spacing is not None else overwrite_target_spacing, + suppress_transpose=False, # might expose this later, + **kwargs + ).plan_experiment() + + +def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner', + gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', + overwrite_target_spacing: Optional[Tuple[float, ...]] = None, + overwrite_plans_name: Optional[str] = None): + """ + overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! + """ + experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + experiment_planner_class_name, + current_module="nnunetv2.experiment_planning") + for d in dataset_ids: + plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name, + overwrite_target_spacing, overwrite_plans_name) + + +def preprocess_dataset(dataset_id: int, + plans_identifier: str = 'nnUNetPlans', + configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), + num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), + verbose: bool = False) -> None: + if not isinstance(num_processes, list): + num_processes = list(num_processes) + if len(num_processes) == 1: + num_processes = num_processes * len(configurations) + if len(num_processes) != len(configurations): + raise RuntimeError( + f'The list provided with num_processes must either have len 1 or as many elements as there are ' + f'configurations (see --help). Number of configurations: {len(configurations)}, length ' + f'of num_processes: ' + f'{len(num_processes)}') + + dataset_name = convert_id_to_dataset_name(dataset_id) + print(f'Preprocessing dataset {dataset_name}') + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') + plans_manager = PlansManager(plans_file) + for n, c in zip(num_processes, configurations): + print(f'Configuration: {c}...') + if c not in plans_manager.available_configurations: + print( + f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of " + f"dataset {dataset_name}. Skipping.") + continue + configuration_manager = plans_manager.get_configuration(c) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + preprocessor.run(dataset_id, c, plans_identifier, num_processes=n) + + # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no + # longer there (useful for compute cluster where only the preprocessed data is available) + from distutils.file_util import copy_file + maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations')) + dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) + # only copy files that are newer than the ones already present + for k in dataset: + copy_file(dataset[k]['label'], + join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending']), + update=True) + + + +def preprocess(dataset_ids: List[int], + plans_identifier: str = 'nnUNetPlans', + configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), + num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), + verbose: bool = False): + for d in dataset_ids: + preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose) diff --git a/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py b/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py new file mode 100644 index 0000000..556f04a --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py @@ -0,0 +1,201 @@ +from nnunetv2.configuration import default_num_processes +from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess + + +def extract_fingerprint_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', + help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' + '\'DatasetFingerprintExtractor\'.') + parser.add_argument('-np', type=int, default=default_num_processes, required=False, + help=f'[OPTIONAL] Number of processes used for fingerprint extraction. ' + f'Default: {default_num_processes}') + parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", + help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for " + "each dataset!") + parser.add_argument("--clean", required=False, default=False, action="store_true", + help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a ' + 'fingerprint already exists, the fingerprint extractor will not run.') + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progress bar! ' + 'Recommended for cluster environments') + args, unrecognized_args = parser.parse_known_args() + extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose) + + +def plan_experiment_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, + help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' + '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. ' + 'It\'s an all in one solution now. Wuch. Such amazing.') + parser.add_argument('-gpu_memory_target', default=8, type=float, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will ' + 'affect patch and batch size and will ' + 'definitely affect your models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' + 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' + 'models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' + 'configurations. Default: None [no changes]. Changing this will affect image size and ' + 'potentially patch and batch ' + 'size. This will definitely affect your models performance! Only use this if you really ' + 'know what you are doing and NEVER use this without running the default nnU-Net first ' + '(as a baseline). Changing the target spacing for the other configurations is currently ' + 'not implemented. New target spacing must be a list of three numbers!') + parser.add_argument('-overwrite_plans_name', default=None, required=False, + help='[OPTIONAL] DANGER ZONE! If you used -gpu_memory_target, -preprocessor_name or ' + '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' + 'differently named plans file such that the nnunet default plans are not ' + 'overwritten. You will then need to specify your custom plans file with -p whenever ' + 'running other nnunet commands (training, inference etc)') + args, unrecognized_args = parser.parse_known_args() + plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, + args.overwrite_plans_name) + + +def preprocess_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-plans_name', default='nnUNetPlans', required=False, + help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated') + parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', + help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres ' + '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' + 'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.') + parser.add_argument('-np', type=int, nargs='+', default=[8, 4, 8], required=False, + help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " + "this number of processes is used for all configurations specified with -c. If it's a " + "list of numbers this list must have as many elements as there are configurations. We " + "then iterate over zip(configs, num_processes) to determine then umber of processes " + "used for each configuration. More processes is always faster (up to the number of " + "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't " + "know what that is then dont touch it, or at least don't increase it!). DANGER: More " + "often than not the number of processes that can be used is limited by the amount of " + "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " + "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH!. Default: 8 processes for 2d, 4 " + "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progress bar! ' + 'Recommended for cluster environments') + args, unrecognized_args = parser.parse_known_args() + if args.np is None: + default_np = { + '2d': 4, + '3d_lowres': 8, + '3d_fullres': 4 + } + np = {default_np[c] if c in default_np.keys() else 4 for c in args.c} + else: + np = args.np + preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose) + + +def plan_and_preprocess_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', + help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' + '\'DatasetFingerprintExtractor\'.') + parser.add_argument('-npfp', type=int, default=8, required=False, + help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8') + parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", + help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for " + "each dataset!") + parser.add_argument('--no_pp', default=False, action='store_true', required=False, + help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no ' + 'preprocesing). Useful for debugging.') + parser.add_argument("--clean", required=False, default=False, action="store_true", + help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a ' + 'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU ' + 'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!') + parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, + help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' + '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. ' + 'It\'s an all in one solution now. Wuch. Such amazing.') + parser.add_argument('-gpu_memory_target', default=8, type=int, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will ' + 'affect patch and batch size and will ' + 'definitely affect your models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' + 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' + 'models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' + 'configurations. Default: None [no changes]. Changing this will affect image size and ' + 'potentially patch and batch ' + 'size. This will definitely affect your models performance! Only use this if you really ' + 'know what you are doing and NEVER use this without running the default nnU-Net first ' + '(as a baseline). Changing the target spacing for the other configurations is currently ' + 'not implemented. New target spacing must be a list of three numbers!') + parser.add_argument('-overwrite_plans_name', default='nnUNetPlans', required=False, + help='[OPTIONAL] uSE A CUSTOM PLANS IDENTIFIER. If you used -gpu_memory_target, ' + '-preprocessor_name or ' + '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' + 'differently named plans file such that the nnunet default plans are not ' + 'overwritten. You will then need to specify your custom plans file with -p whenever ' + 'running other nnunet commands (training, inference etc)') + parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', + help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres ' + '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' + 'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.') + parser.add_argument('-np', type=int, nargs='+', default=None, required=False, + help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " + "this number of processes is used for all configurations specified with -c. If it's a " + "list of numbers this list must have as many elements as there are configurations. We " + "then iterate over zip(configs, num_processes) to determine then umber of processes " + "used for each configuration. More processes is always faster (up to the number of " + "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't " + "know what that is then dont touch it, or at least don't increase it!). DANGER: More " + "often than not the number of processes that can be used is limited by the amount of " + "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " + "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH!. Default: 8 processes for 2d, 4 " + "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progress bar! ' + 'Recommended for cluster environments') + args = parser.parse_args() + + # fingerprint extraction + print("Fingerprint extraction...") + extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose) + + # experiment planning + print('Experiment planning...') + plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) + + # manage default np + if args.np is None: + default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8} + np = [default_np[c] if c in default_np.keys() else 4 for c in args.c] + else: + np = args.np + # preprocessing + if not args.no_pp: + print('Preprocessing...') + preprocess(args.d, args.overwrite_plans_name, args.c, np, args.verbose) + + +if __name__ == '__main__': + plan_and_preprocess_entry() diff --git a/docker/template/src/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py b/docker/template/src/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py b/docker/template/src/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py new file mode 100644 index 0000000..7219ddc --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py @@ -0,0 +1,82 @@ +import argparse +from typing import Union + +from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json + +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +def move_plans_between_datasets( + source_dataset_name_or_id: Union[int, str], + target_dataset_name_or_id: Union[int, str], + source_plans_identifier: str, + target_plans_identifier: str = None): + source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id) + target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id) + + if target_plans_identifier is None: + target_plans_identifier = source_plans_identifier + + source_folder = join(nnUNet_preprocessed, source_dataset_name) + assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \ + f"Run nnUNetv2_plan_and_preprocess for source dataset first!" + + source_plans_file = join(source_folder, source_plans_identifier + '.json') + assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \ + f"Expected file: {source_plans_file}" + + source_plans = load_json(source_plans_file) + source_plans['dataset_name'] = target_dataset_name + + # we need to change data_identifier to use target_plans_identifier + if target_plans_identifier != source_plans_identifier: + for c in source_plans['configurations'].keys(): + if 'data_identifier' in source_plans['configurations'][c].keys(): + old_identifier = source_plans['configurations'][c]["data_identifier"] + if old_identifier.startswith(source_plans_identifier): + new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):] + else: + new_identifier = target_plans_identifier + '_' + old_identifier + source_plans['configurations'][c]["data_identifier"] = new_identifier + + # we need to change the reader writer class! + target_raw_data_dir = join(nnUNet_raw, target_dataset_name) + target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json')) + + # we may need to change the reader/writer + # pick any file from the source dataset + dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json) + example_image = dataset[dataset.keys().__iter__().__next__()]['images'][0] + rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True, + verbose=False) + + source_plans["image_reader_writer"] = rw.__name__ + if target_plans_identifier is not None: + source_plans["plans_name"] = target_plans_identifier + + save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'), + sort_keys=False) + + +def entry_point_move_plans_between_datasets(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', type=str, required=True, + help='Source dataset name or id') + parser.add_argument('-t', type=str, required=True, + help='Target dataset name or id') + parser.add_argument('-sp', type=str, required=True, + help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the ' + 'identifier would be nnUNetPlans') + parser.add_argument('-tp', type=str, required=False, default=None, + help='Target plans identifier. Default is None meaning the source plans identifier will ' + 'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier ' + 'such as nnUNetPlans!!!') + args = parser.parse_args() + move_plans_between_datasets(args.s, args.t, args.sp, args.tp) + + +if __name__ == '__main__': + move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2') diff --git a/docker/template/src/nnunetv2/experiment_planning/verify_dataset_integrity.py b/docker/template/src/nnunetv2/experiment_planning/verify_dataset_integrity.py new file mode 100644 index 0000000..8d646a2 --- /dev/null +++ b/docker/template/src/nnunetv2/experiment_planning/verify_dataset_integrity.py @@ -0,0 +1,231 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +from typing import Type + +import numpy as np +import pandas as pd +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool: + rw = readerclass() + seg, properties = rw.read_seg(label_file) + found_labels = np.sort(pd.unique(seg.ravel())) # np.unique(seg) + unexpected_labels = [i for i in found_labels if i not in expected_labels] + if len(found_labels) == 0 and found_labels[0] == 0: + print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, ' + 'up to you.' % label_file) + if len(unexpected_labels) > 0: + print("Error: Unexpected labels found in file %s.\nExpected: %s\nFound: %s" % (label_file, expected_labels, + found_labels)) + return False + return True + + +def check_cases(image_files: List[str], label_file: str, expected_num_channels: int, + readerclass: Type[BaseReaderWriter]) -> bool: + rw = readerclass() + ret = True + + images, properties_image = rw.read_images(image_files) + segmentation, properties_seg = rw.read_seg(label_file) + + # check for nans + if np.any(np.isnan(images)): + print(f'Images contain NaN pixel values. You need to fix that by ' + f'replacing NaN values with something that makes sense for your images!\nImages:\n{image_files}') + ret = False + if np.any(np.isnan(segmentation)): + print(f'Segmentation contains NaN pixel values. You need to fix that.\nSegmentation:\n{label_file}') + ret = False + + # check shapes + shape_image = images.shape[1:] + shape_seg = segmentation.shape[1:] + if shape_image != shape_seg: + print('Error: Shape mismatch between segmentation and corresponding images. \nShape images: %s. ' + '\nShape seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (shape_image, shape_seg, image_files, label_file)) + ret = False + + # check spacings + spacing_images = properties_image['spacing'] + spacing_seg = properties_seg['spacing'] + if not np.allclose(spacing_seg, spacing_images): + print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. ' + '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (shape_image, shape_seg, image_files, label_file)) + ret = False + + # check modalities + if not len(images) == expected_num_channels: + print('Error: Unexpected number of modalities. \nExpected: %d. \nGot: %d. \nImages: %s\n' + % (expected_num_channels, len(images), image_files)) + ret = False + + # nibabel checks + if 'nibabel_stuff' in properties_image.keys(): + # this image was read with NibabelIO + affine_image = properties_image['nibabel_stuff']['original_affine'] + affine_seg = properties_seg['nibabel_stuff']['original_affine'] + if not np.allclose(affine_image, affine_seg): + print('WARNING: Affine is not the same for image and seg! \nAffine image: %s \nAffine seg: %s\n' + 'Image files: %s. \nSeg file: %s.\nThis can be a problem but doesn\'t have to be. Please run ' + 'nnUNet_plot_dataset_pngs to verify if everything is OK!\n' + % (affine_image, affine_seg, image_files, label_file)) + + # sitk checks + if 'sitk_stuff' in properties_image.keys(): + # this image was read with SimpleITKIO + # spacing has already been checked, only check direction and origin + origin_image = properties_image['sitk_stuff']['origin'] + origin_seg = properties_seg['sitk_stuff']['origin'] + if not np.allclose(origin_image, origin_seg): + print('Warning: Origin mismatch between segmentation and corresponding images. \nOrigin images: %s. ' + '\nOrigin seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (origin_image, origin_seg, image_files, label_file)) + direction_image = properties_image['sitk_stuff']['direction'] + direction_seg = properties_seg['sitk_stuff']['direction'] + if not np.allclose(direction_image, direction_seg): + print('Warning: Direction mismatch between segmentation and corresponding images. \nDirection images: %s. ' + '\nDirection seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (direction_image, direction_seg, image_files, label_file)) + + return ret + + +def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None: + """ + folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json + checks if the expected number of training cases and labels are present + for each case, if possible, checks whether the pixel grids are aligned + checks whether the labels really only contain values they should + :param folder: + :return: + """ + assert isfile(join(folder, "dataset.json")), f"There needs to be a dataset.json file in folder, folder={folder}" + dataset_json = load_json(join(folder, "dataset.json")) + + if not 'dataset' in dataset_json.keys(): + assert isdir(join(folder, "imagesTr")), f"There needs to be a imagesTr subfolder in folder, folder={folder}" + assert isdir(join(folder, "labelsTr")), f"There needs to be a labelsTr subfolder in folder, folder={folder}" + + # make sure all required keys are there + dataset_keys = list(dataset_json.keys()) + required_keys = ['labels', "channel_names", "numTraining", "file_ending"] + assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' \ + '\n\nRequired: \n%s\n\nPresent: \n%s\n\nMissing: ' \ + '\n%s\n\nUnused by nnU-Net:\n%s' % \ + (str(required_keys), + str(dataset_keys), + str([i for i in required_keys if i not in dataset_keys]), + str([i for i in dataset_keys if i not in required_keys])) + + expected_num_training = dataset_json['numTraining'] + num_modalities = len(dataset_json['channel_names'].keys() + if 'channel_names' in dataset_json.keys() + else dataset_json['modality'].keys()) + file_ending = dataset_json['file_ending'] + + dataset = get_filenames_of_train_images_and_targets(folder, dataset_json) + + # check if the right number of training cases is present + assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \ + '(%d). Found %d instead.\nExamples: %s' % \ + (expected_num_training, len(dataset), + list(dataset.keys())[:5]) + + # check if corresponding labels are present + if 'dataset' in dataset_json.keys(): + # just check if everything is there + ok = True + missing_images = [] + missing_labels = [] + for k in dataset: + for i in dataset[k]['images']: + if not isfile(i): + missing_images.append(i) + ok = False + if not isfile(dataset[k]['label']): + missing_labels.append(dataset[k]['label']) + ok = False + if not ok: + raise FileNotFoundError(f"Some expected files were missing. Make sure you are properly referencing them " + f"in the dataset.json. Or use imagesTr & labelsTr folders!\nMissing images:" + f"\n{missing_images}\n\nMissing labels:\n{missing_labels}") + else: + # old code that uses imagestr and labelstr folders + labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False) + label_identifiers = [i[:-len(file_ending)] for i in labelfiles] + labels_present = [i in label_identifiers for i in dataset.keys()] + missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]] + assert all(labels_present), f'not all training cases have a label file in labelsTr. Fix that. Missing: {missing}' + + labelfiles = [v['label'] for v in dataset.values()] + image_files = [v['images'] for v in dataset.values()] + + # no plans exist yet, so we can't use PlansManager and gotta roll with the default. It's unlikely to cause + # problems anyway + label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order')) + expected_labels = label_manager.all_labels + if label_manager.has_ignore_label: + expected_labels.append(label_manager.ignore_label) + labels_valid_consecutive = np.ediff1d(expected_labels) == 1 + assert all( + labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction' + + # determine reader/writer class + reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, dataset[dataset.keys().__iter__().__next__()]['images'][0]) + + # check whether only the desired labels are present + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + result = p.starmap( + verify_labels, + zip([join(folder, 'labelsTr', i) for i in labelfiles], [reader_writer_class] * len(labelfiles), + [expected_labels] * len(labelfiles)) + ) + if not all(result): + raise RuntimeError( + 'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).') + + # check whether shapes and spacings match between images and labels + result = p.starmap( + check_cases, + zip(image_files, labelfiles, [num_modalities] * expected_num_training, + [reader_writer_class] * expected_num_training) + ) + if not all(result): + raise RuntimeError( + 'Some images have errors. Please check text output above to see which one(s) and what\'s going on.') + + # check for nans + # check all same orientation nibabel + print('\n####################') + print('verify_dataset_integrity Done. \nIf you didn\'t see any error messages then your dataset is most likely OK!') + print('####################\n') + + +if __name__ == "__main__": + # investigate geometry issues + example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0') + num_processes = 6 + verify_dataset_integrity(example_folder, num_processes) diff --git a/docker/template/src/nnunetv2/imageio/__init__.py b/docker/template/src/nnunetv2/imageio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/imageio/base_reader_writer.py b/docker/template/src/nnunetv2/imageio/base_reader_writer.py new file mode 100644 index 0000000..2847478 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/base_reader_writer.py @@ -0,0 +1,107 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Tuple, Union, List +import numpy as np + + +class BaseReaderWriter(ABC): + @staticmethod + def _check_all_same(input_list): + # compare all entries to the first + for i in input_list[1:]: + if i != input_list[0]: + return False + return True + + @staticmethod + def _check_all_same_array(input_list): + # compare all entries to the first + for i in input_list[1:]: + if i.shape != input_list[0].shape or not np.allclose(i, input_list[0]): + return False + return True + + @abstractmethod + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + """ + Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. The 4d array must have the + modalities (or color channels, or however you would like to call them) in its first axis, followed by the + spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)). + Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for + example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg + for exporting the predicted segmentations, so make sure you have everything you need in there! + + IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray. + Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and + preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So + if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y + and c the spacing of z. + + In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be + (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224), + spacing=(999, 1, 1) + + For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!) + + :param image_fnames: + :return: + 1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are + the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image). + 2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a + is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set + a=999 (largest spacing value! Make it larger than b and c) + + """ + pass + + @abstractmethod + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + """ + Same requirements as BaseReaderWriter.read_image. Returned segmentations must have shape 1,x,y,z. Multiple + segmentations are not (yet?) allowed + + If images and segmentations can be read the same way you can just `return self.read_image((image_fname,))` + :param seg_fname: + :return: + 1) a np.ndarray of shape (1, x, y, z) where x, y, z are + the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation). + 2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a + is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set + a=999 (largest spacing value! Make it larger than b and c) + """ + pass + + @abstractmethod + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + """ + Export the predicted segmentation to the desired file format. The given seg array will have the same shape and + orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-) + + properties is the same dictionary you created during read_images/read_seg so you can use the information here + to restore metadata + + IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape + 1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation + to 2d via seg = seg[0])! + + :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, y, z)! + :param output_fname: + :param properties: the dictionary that you created in read_images (the ones this segmentation is based on). + Use this to restore metadata + :return: + """ + pass \ No newline at end of file diff --git a/docker/template/src/nnunetv2/imageio/natural_image_reader_writer.py b/docker/template/src/nnunetv2/imageio/natural_image_reader_writer.py new file mode 100644 index 0000000..11946c3 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/natural_image_reader_writer.py @@ -0,0 +1,73 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from skimage import io + + +class NaturalImage2DIO(BaseReaderWriter): + """ + ONLY SUPPORTS 2D IMAGES!!! + """ + + # there are surely more we could add here. Everything that can be read by skimage.io should be supported + supported_file_endings = [ + '.png', + # '.jpg', + # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps! + '.bmp', + '.tif' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + for f in image_fnames: + npy_img = io.imread(f) + if npy_img.ndim == 3: + # rgb image, last dimension should be the color channel and the size of that channel should be 3 + # (or 4 if we have alpha) + assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \ + "dimension must have shape 3 or 4 " \ + f"(RGB or RGBA). Image shape here is {npy_img.shape}" + # move RGB(A) to front, add additional dim so that we have shape (1, c, X, Y), where c is either 3 or 4 + images.append(npy_img.transpose((2, 0, 1))[:, None]) + elif npy_img.ndim == 2: + # grayscale image + images.append(npy_img[None, None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + return np.vstack(images).astype(np.float32), {'spacing': (999, 1, 1)} + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + io.imsave(output_fname, seg[0].astype(np.uint8), check_contrast=False) + + +if __name__ == '__main__': + images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',) + segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png' + imgio = NaturalImage2DIO() + img, props = imgio.read_images(images) + seg, segprops = imgio.read_seg(segmentation) \ No newline at end of file diff --git a/docker/template/src/nnunetv2/imageio/nibabel_reader_writer.py b/docker/template/src/nnunetv2/imageio/nibabel_reader_writer.py new file mode 100644 index 0000000..8faafb7 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/nibabel_reader_writer.py @@ -0,0 +1,204 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union, List +import numpy as np +from nibabel import io_orientation + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import nibabel + + +class NibabelIO(BaseReaderWriter): + """ + Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be + consistent. This is of course considered properly in segmentation export as well. + + IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg! + """ + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + original_affines = [] + + spacings_for_nnunet = [] + for f in image_fnames: + nib_image = nibabel.load(f) + assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO' + original_affine = nib_image.affine + + original_affines.append(original_affine) + + # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) + spacings_for_nnunet.append( + [float(i) for i in nib_image.header.get_zooms()[::-1]] + ) + + # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. + images.append(nib_image.get_fdata().transpose((2, 1, 0))[None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same_array(original_affines): + print('WARNING! Not all input images have the same original_affines!') + print('Affines:') + print(original_affines) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not ' + 'having the same affine') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'nibabel_stuff': { + 'original_affine': original_affines[0], + }, + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # revert transpose + seg = seg.transpose((2, 1, 0)).astype(np.uint8) + seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine']) + nibabel.save(seg_nib, output_fname) + + +class NibabelIOWithReorient(BaseReaderWriter): + """ + Reorients images to RAS + + Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be + consistent. This is of course considered properly in segmentation export as well. + + IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg! + """ + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + original_affines = [] + reoriented_affines = [] + + spacings_for_nnunet = [] + for f in image_fnames: + nib_image = nibabel.load(f) + assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO' + original_affine = nib_image.affine + reoriented_image = nib_image.as_reoriented(io_orientation(original_affine)) + reoriented_affine = reoriented_image.affine + + original_affines.append(original_affine) + reoriented_affines.append(reoriented_affine) + + # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) + spacings_for_nnunet.append( + [float(i) for i in reoriented_image.header.get_zooms()[::-1]] + ) + + # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. + images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same_array(reoriented_affines): + print('WARNING! Not all input images have the same reoriented_affines!') + print('Affines:') + print(reoriented_affines) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not ' + 'having the same affine') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'nibabel_stuff': { + 'original_affine': original_affines[0], + 'reoriented_affine': reoriented_affines[0], + }, + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # revert transpose + seg = seg.transpose((2, 1, 0)).astype(np.uint8) + + seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine']) + seg_nib_reoriented = seg_nib.as_reoriented(io_orientation(properties['nibabel_stuff']['original_affine'])) + assert np.allclose(properties['nibabel_stuff']['original_affine'], seg_nib_reoriented.affine), \ + 'restored affine does not match original affine' + nibabel.save(seg_nib_reoriented, output_fname) + + +if __name__ == '__main__': + img_file = 'patient028_frame01_0000.nii.gz' + seg_file = 'patient028_frame01.nii.gz' + + nibio = NibabelIO() + images, dct = nibio.read_images([img_file]) + seg, dctseg = nibio.read_seg(seg_file) + + nibio_r = NibabelIOWithReorient() + images_r, dct_r = nibio_r.read_images([img_file]) + seg_r, dctseg_r = nibio_r.read_seg(seg_file) + + nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg) + nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r) + + s_orig = nibabel.load(seg_file).get_fdata() + s_nibio = nibabel.load('/home/isensee/seg_nibio.nii.gz').get_fdata() + s_nibio_r = nibabel.load('/home/isensee/seg_nibio_r.nii.gz').get_fdata() diff --git a/docker/template/src/nnunetv2/imageio/reader_writer_registry.py b/docker/template/src/nnunetv2/imageio/reader_writer_registry.py new file mode 100644 index 0000000..606334c --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/reader_writer_registry.py @@ -0,0 +1,79 @@ +import traceback +from typing import Type + +from batchgenerators.utilities.file_and_folder_operations import join + +import nnunetv2 +from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO +from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient +from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO +from nnunetv2.imageio.tif_reader_writer import Tiff3DIO +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + +LIST_OF_IO_CLASSES = [ + NaturalImage2DIO, + SimpleITKIO, + Tiff3DIO, + NibabelIO, + NibabelIOWithReorient +] + + +def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None, + allow_nonmatching_filename: bool = False, verbose: bool = True + ) -> Type[BaseReaderWriter]: + if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \ + dataset_json_content['overwrite_image_reader_writer'] != 'None': + ioclass_name = dataset_json_content['overwrite_image_reader_writer'] + # trying to find that class in the nnunetv2.imageio module + try: + ret = recursive_find_reader_writer_by_name(ioclass_name) + if verbose: print(f'Using {ret} reader/writer') + return ret + except RuntimeError: + if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}') + if verbose: print('Trying to automatically determine desired class') + return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file, + allow_nonmatching_filename, verbose) + + +def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False, + verbose: bool = True): + for rw in LIST_OF_IO_CLASSES: + if file_ending.lower() in rw.supported_file_endings: + if example_file is not None: + # if an example file is provided, try if we can actually read it. If not move on to the next reader + try: + tmp = rw() + _ = tmp.read_images((example_file,)) + if verbose: print(f'Using {rw} as reader/writer') + return rw + except: + if verbose: print(f'Failed to open file {example_file} with reader {rw}:') + traceback.print_exc() + pass + else: + if verbose: print(f'Using {rw} as reader/writer') + return rw + else: + if allow_nonmatching_filename and example_file is not None: + try: + tmp = rw() + _ = tmp.read_images((example_file,)) + if verbose: print(f'Using {rw} as reader/writer') + return rw + except: + if verbose: print(f'Failed to open file {example_file} with reader {rw}:') + if verbose: traceback.print_exc() + pass + raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (file None means no file provided).") + + +def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]: + ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio') + if ret is None: + raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the " + "nnunetv2.imageio module." % rw_class_name) + else: + return ret diff --git a/docker/template/src/nnunetv2/imageio/readme.md b/docker/template/src/nnunetv2/imageio/readme.md new file mode 100644 index 0000000..7819425 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/readme.md @@ -0,0 +1,7 @@ +- Derive your adapter from `BaseReaderWriter`. +- Reimplement all abstractmethods. +- make sure to support 2d and 3d input images (or raise some error). +- place it in this folder or nnU-Net won't find it! +- add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py` + +Bam, you're done! \ No newline at end of file diff --git a/docker/template/src/nnunetv2/imageio/simpleitk_reader_writer.py b/docker/template/src/nnunetv2/imageio/simpleitk_reader_writer.py new file mode 100644 index 0000000..6a9afc2 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/simpleitk_reader_writer.py @@ -0,0 +1,129 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import SimpleITK as sitk + + +class SimpleITKIO(BaseReaderWriter): + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + spacings = [] + origins = [] + directions = [] + + spacings_for_nnunet = [] + for f in image_fnames: + itk_image = sitk.ReadImage(f) + spacings.append(itk_image.GetSpacing()) + origins.append(itk_image.GetOrigin()) + directions.append(itk_image.GetDirection()) + npy_image = sitk.GetArrayFromImage(itk_image) + if npy_image.ndim == 2: + # 2d + npy_image = npy_image[None, None] + max_spacing = max(spacings[-1]) + spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1])) + elif npy_image.ndim == 3: + # 3d, as in original nnunet + npy_image = npy_image[None] + spacings_for_nnunet.append(list(spacings[-1])[::-1]) + elif npy_image.ndim == 4: + # 4d, multiple modalities in one file + spacings_for_nnunet.append(list(spacings[-1])[::-1][1:]) + pass + else: + raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}") + + images.append(npy_image) + spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1])) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same(spacings): + print('ERROR! Not all input images have the same spacing!') + print('Spacings:') + print(spacings) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same(origins): + print('WARNING! Not all input images have the same origin!') + print('Origins:') + print(origins) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(directions): + print('WARNING! Not all input images have the same direction!') + print('Directions:') + print(directions) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a ' + 'bug. Please report!') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'sitk_stuff': { + # this saves the sitk geometry information. This part is NOT used by nnU-Net! + 'spacing': spacings[0], + 'origin': origins[0], + 'direction': directions[0] + }, + # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays + # are returned x,y,z but spacing is returned z,y,x. Duh. + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y' + output_dimension = len(properties['sitk_stuff']['spacing']) + assert 1 < output_dimension < 4 + if output_dimension == 2: + seg = seg[0] + + itk_image = sitk.GetImageFromArray(seg.astype(np.uint8)) + itk_image.SetSpacing(properties['sitk_stuff']['spacing']) + itk_image.SetOrigin(properties['sitk_stuff']['origin']) + itk_image.SetDirection(properties['sitk_stuff']['direction']) + + sitk.WriteImage(itk_image, output_fname, True) diff --git a/docker/template/src/nnunetv2/imageio/tif_reader_writer.py b/docker/template/src/nnunetv2/imageio/tif_reader_writer.py new file mode 100644 index 0000000..19ad882 --- /dev/null +++ b/docker/template/src/nnunetv2/imageio/tif_reader_writer.py @@ -0,0 +1,100 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import tifffile +from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join + + +class Tiff3DIO(BaseReaderWriter): + """ + reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)! + + If you have 2D tiffs, use NaturalImage2DIO + + Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end + with .json and omit the channel identifier. So, for example, the corresponding of image image1_0000.tif is + expected to be image1.json)! + """ + supported_file_endings = [ + '.tif', + '.tiff', + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + # figure out file ending used here + ending = '.' + image_fnames[0].split('.')[-1] + assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' + ending_length = len(ending) + truncate_length = ending_length + 5 # 5 comes from len(_0000) + + images = [] + for f in image_fnames: + image = tifffile.imread(f) + if image.ndim != 3: + raise RuntimeError(f"Only 3D images are supported! File: {f}") + images.append(image[None]) + + # see if aux file can be found + expected_aux_file = image_fnames[0][:-truncate_length] + '.json' + if isfile(expected_aux_file): + spacing = load_json(expected_aux_file)['spacing'] + assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}' + else: + print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).') + spacing = (1, 1, 1) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + return np.vstack(images).astype(np.float32), {'spacing': spacing} + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha + tifffile.imwrite(output_fname, data=seg.astype(np.uint8), compression='zlib') + file = os.path.basename(output_fname) + out_dir = os.path.dirname(output_fname) + ending = file.split('.')[-1] + save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json')) + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + # figure out file ending used here + ending = '.' + seg_fname.split('.')[-1] + assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' + ending_length = len(ending) + + seg = tifffile.imread(seg_fname) + if seg.ndim != 3: + raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}") + seg = seg[None] + + # see if aux file can be found + expected_aux_file = seg_fname[:-ending_length] + '.json' + if isfile(expected_aux_file): + spacing = load_json(expected_aux_file)['spacing'] + assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}' + assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}" + else: + print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).') + spacing = (1, 1, 1) + + return seg.astype(np.float32), {'spacing': spacing} \ No newline at end of file diff --git a/docker/template/src/nnunetv2/inference/__init__.py b/docker/template/src/nnunetv2/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/inference/data_iterators.py b/docker/template/src/nnunetv2/inference/data_iterators.py new file mode 100644 index 0000000..a35e330 --- /dev/null +++ b/docker/template/src/nnunetv2/inference/data_iterators.py @@ -0,0 +1,316 @@ +import multiprocessing +import queue +from torch.multiprocessing import Event, Queue, Manager + +from time import sleep +from typing import Union, List + +import numpy as np +import torch +from batchgenerators.dataloading.data_loader import DataLoader + +from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + target_queue: Queue, + done_event: Event, + abort_event: Event, + verbose: bool = False): + try: + label_manager = plans_manager.get_label_manager(dataset_json) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + for idx in range(len(list_of_lists)): + data, seg, data_properties = preprocessor.run_case(list_of_lists[idx], + list_of_segs_from_prev_stage_files[ + idx] if list_of_segs_from_prev_stage_files is not None else None, + plans_manager, + configuration_manager, + dataset_json) + if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data).contiguous().float() + + item = {'data': data, 'data_properties': data_properties, + 'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None} + success = False + while not success: + try: + if abort_event.is_set(): + return + target_queue.put(item, timeout=0.01) + success = True + except queue.Full: + pass + done_event.set() + except Exception as e: + abort_event.set() + raise e + + +def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_processes: int, + pin_memory: bool = False, + verbose: bool = False): + context = multiprocessing.get_context('spawn') + manager = Manager() + num_processes = min(len(list_of_lists), num_processes) + assert num_processes >= 1 + processes = [] + done_events = [] + target_queues = [] + abort_event = manager.Event() + for i in range(num_processes): + event = manager.Event() + queue = Manager().Queue(maxsize=1) + pr = context.Process(target=preprocess_fromfiles_save_to_queue, + args=( + list_of_lists[i::num_processes], + list_of_segs_from_prev_stage_files[ + i::num_processes] if list_of_segs_from_prev_stage_files is not None else None, + output_filenames_truncated[ + i::num_processes] if output_filenames_truncated is not None else None, + plans_manager, + dataset_json, + configuration_manager, + queue, + event, + abort_event, + verbose + ), daemon=True) + pr.start() + target_queues.append(queue) + done_events.append(event) + processes.append(pr) + + worker_ctr = 0 + while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): + if not target_queues[worker_ctr].empty(): + item = target_queues[worker_ctr].get() + worker_ctr = (worker_ctr + 1) % num_processes + else: + all_ok = all( + [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() + if not all_ok: + raise RuntimeError('Background workers died. Look for the error message further up! If there is ' + 'none then your RAM was full and the worker was killed by the OS. Use fewer ' + 'workers or get more RAM in that case!') + sleep(0.01) + continue + if pin_memory: + [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] + yield item + [p.join() for p in processes] + +class PreprocessAdapter(DataLoader): + def __init__(self, list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + preprocessor: DefaultPreprocessor, + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_threads_in_multithreaded: int = 1): + self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \ + preprocessor, plans_manager, configuration_manager, dataset_json + + self.label_manager = plans_manager.get_label_manager(dataset_json) + + if list_of_segs_from_prev_stage_files is None: + list_of_segs_from_prev_stage_files = [None] * len(list_of_lists) + if output_filenames_truncated is None: + output_filenames_truncated = [None] * len(list_of_lists) + + super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)), + 1, num_threads_in_multithreaded, + seed_for_shuffle=1, return_incomplete=True, + shuffle=False, infinite=False, sampling_probabilities=None) + + self.indices = list(range(len(list_of_lists))) + + def generate_train_batch(self): + idx = self.get_indices()[0] + files = self._data[idx][0] + seg_prev_stage = self._data[idx][1] + ofile = self._data[idx][2] + # if we have a segmentation from the previous stage we have to process it together with the images so that we + # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after + # preprocessing and then there might be misalignments + data, seg, data_properties = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager, + self.configuration_manager, + self.dataset_json) + if seg_prev_stage is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data) + + return {'data': data, 'data_properties': data_properties, 'ofile': ofile} + + +class PreprocessAdapterFromNpy(DataLoader): + def __init__(self, list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, + num_threads_in_multithreaded: int = 1, verbose: bool = False): + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \ + preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames + + self.label_manager = plans_manager.get_label_manager(dataset_json) + + if list_of_segs_from_prev_stage is None: + list_of_segs_from_prev_stage = [None] * len(list_of_images) + if truncated_ofnames is None: + truncated_ofnames = [None] * len(list_of_images) + + super().__init__( + list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)), + 1, num_threads_in_multithreaded, + seed_for_shuffle=1, return_incomplete=True, + shuffle=False, infinite=False, sampling_probabilities=None) + + self.indices = list(range(len(list_of_images))) + + def generate_train_batch(self): + idx = self.get_indices()[0] + image = self._data[idx][0] + seg_prev_stage = self._data[idx][1] + props = self._data[idx][2] + ofname = self._data[idx][3] + # if we have a segmentation from the previous stage we have to process it together with the images so that we + # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after + # preprocessing and then there might be misalignments + data, seg = self.preprocessor.run_case_npy(image, seg_prev_stage, props, + self.plans_manager, + self.configuration_manager, + self.dataset_json) + if seg_prev_stage is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data) + + return {'data': data, 'data_properties': props, 'ofile': ofname} + + +def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + target_queue: Queue, + done_event: Event, + abort_event: Event, + verbose: bool = False): + try: + label_manager = plans_manager.get_label_manager(dataset_json) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + for idx in range(len(list_of_images)): + data, seg = preprocessor.run_case_npy(list_of_images[idx], + list_of_segs_from_prev_stage[ + idx] if list_of_segs_from_prev_stage is not None else None, + list_of_image_properties[idx], + plans_manager, + configuration_manager, + dataset_json) + if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data).contiguous().float() + + item = {'data': data, 'data_properties': list_of_image_properties[idx], + 'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None} + success = False + while not success: + try: + if abort_event.is_set(): + return + target_queue.put(item, timeout=0.01) + success = True + except queue.Full: + pass + done_event.set() + except Exception as e: + abort_event.set() + raise e + + +def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_processes: int, + pin_memory: bool = False, + verbose: bool = False): + context = multiprocessing.get_context('spawn') + manager = Manager() + num_processes = min(len(list_of_images), num_processes) + assert num_processes >= 1 + target_queues = [] + processes = [] + done_events = [] + abort_event = manager.Event() + for i in range(num_processes): + event = manager.Event() + queue = manager.Queue(maxsize=1) + pr = context.Process(target=preprocess_fromnpy_save_to_queue, + args=( + list_of_images[i::num_processes], + list_of_segs_from_prev_stage[ + i::num_processes] if list_of_segs_from_prev_stage is not None else None, + list_of_image_properties[i::num_processes], + truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None, + plans_manager, + dataset_json, + configuration_manager, + queue, + event, + abort_event, + verbose + ), daemon=True) + pr.start() + done_events.append(event) + processes.append(pr) + target_queues.append(queue) + + worker_ctr = 0 + while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): + if not target_queues[worker_ctr].empty(): + item = target_queues[worker_ctr].get() + worker_ctr = (worker_ctr + 1) % num_processes + else: + all_ok = all( + [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() + if not all_ok: + raise RuntimeError('Background workers died. Look for the error message further up! If there is ' + 'none then your RAM was full and the worker was killed by the OS. Use fewer ' + 'workers or get more RAM in that case!') + sleep(0.01) + continue + if pin_memory: + [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] + yield item + [p.join() for p in processes] diff --git a/docker/template/src/nnunetv2/inference/examples.py b/docker/template/src/nnunetv2/inference/examples.py new file mode 100644 index 0000000..a66d98f --- /dev/null +++ b/docker/template/src/nnunetv2/inference/examples.py @@ -0,0 +1,102 @@ +if __name__ == '__main__': + from nnunetv2.paths import nnUNet_results, nnUNet_raw + import torch + from batchgenerators.utilities.file_and_folder_operations import join + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + + # nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction + + # instantiate the nnUNetPredictor + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_device=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + # initializes the network architecture, loads the checkpoint + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0,), + checkpoint_name='checkpoint_final.pth', + ) + # variant 1: give input and output folders + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # variant 2, use list of files as inputs. Note how we use nested lists!!! + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres') + predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + [join(outdir, 'liver_152.nii.gz'), + join(outdir, 'liver_142.nii.gz')], + save_probabilities=False, overwrite=True, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # variant 2.5, returns segmentations + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + None, + save_probabilities=True, overwrite=True, + num_processes_preprocessing=2, + num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, + part_id=0) + + # predict several npy images + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + # we do not set output files so that the segmentations will be returned. You can of course also specify output + # files instead (no return value on that case) + ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4], + None, + [props, props2, props3, props4], + None, 2, save_probabilities=False, + num_processes_segmentation_export=2) + + # predict a single numpy array + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, True) + + # custom iterator + + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + + + # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys! + # If 'ofile' is None, the result will be returned instead of written to a file + # the iterator is responsible for performing the correct preprocessing! + # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread! + # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays + # (they both use predictor.predict_from_data_iterator) for inspiration! + def my_iterator(list_of_input_arrs, list_of_input_props): + preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) + for a, p in zip(list_of_input_arrs, list_of_input_props): + data, seg = preprocessor.run_case_npy(a, + None, + p, + predictor.plans_manager, + predictor.configuration_manager, + predictor.dataset_json) + yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None} + + + ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), + save_probabilities=False, num_processes_segmentation_export=3) diff --git a/docker/template/src/nnunetv2/inference/export_prediction.py b/docker/template/src/nnunetv2/inference/export_prediction.py new file mode 100644 index 0000000..3303567 --- /dev/null +++ b/docker/template/src/nnunetv2/inference/export_prediction.py @@ -0,0 +1,145 @@ +import os +from copy import deepcopy +from typing import Union, List + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice +from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle + +from nnunetv2.configuration import default_num_processes +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray], + plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + label_manager: LabelManager, + properties_dict: dict, + return_probabilities: bool = False, + num_threads_torch: int = default_num_processes): + old_threads = torch.get_num_threads() + torch.set_num_threads(num_threads_torch) + + # resample to original shape + current_spacing = configuration_manager.spacing if \ + len(configuration_manager.spacing) == \ + len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits, + properties_dict['shape_after_cropping_and_before_resampling'], + current_spacing, + properties_dict['spacing']) + # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because + # apply_inference_nonlin will convert to torch + predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits) + del predicted_logits + segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities) + + # segmentation may be torch.Tensor but we continue with numpy + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + + # put segmentation in bbox (revert cropping) + segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'], + dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) + slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping']) + segmentation_reverted_cropping[slicer] = segmentation + del segmentation + + # revert transpose + segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward) + if return_probabilities: + # revert cropping + predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities, + properties_dict[ + 'bbox_used_for_cropping'], + properties_dict[ + 'shape_before_cropping']) + predicted_probabilities = predicted_probabilities.cpu().numpy() + # revert transpose + predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in + plans_manager.transpose_backward]) + torch.set_num_threads(old_threads) + return segmentation_reverted_cropping, predicted_probabilities + else: + torch.set_num_threads(old_threads) + return segmentation_reverted_cropping + + +def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict, + configuration_manager: ConfigurationManager, + plans_manager: PlansManager, + dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str, + save_probabilities: bool = False): + # if isinstance(predicted_array_or_file, str): + # tmp = deepcopy(predicted_array_or_file) + # if predicted_array_or_file.endswith('.npy'): + # predicted_array_or_file = np.load(predicted_array_or_file) + # elif predicted_array_or_file.endswith('.npz'): + # predicted_array_or_file = np.load(predicted_array_or_file)['softmax'] + # os.remove(tmp) + + if isinstance(dataset_json_dict_or_file, str): + dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) + + label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) + ret = convert_predicted_logits_to_segmentation_with_correct_shape( + predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict, + return_probabilities=save_probabilities + ) + del predicted_array_or_file + + # save + if save_probabilities: + segmentation_final, probabilities_final = ret + np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final) + save_pickle(properties_dict, output_file_truncated + '.pkl') + del probabilities_final, ret + else: + segmentation_final = ret + del ret + + rw = plans_manager.image_reader_writer_class() + rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'], + properties_dict) + + +def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict, + dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \ + -> None: + # # needed for cascade + # if isinstance(predicted, str): + # assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \ + # "isfile(segmentation_softmax) must be True" + # del_file = deepcopy(predicted) + # predicted = np.load(predicted) + # os.remove(del_file) + old_threads = torch.get_num_threads() + torch.set_num_threads(num_threads_torch) + + if isinstance(dataset_json_dict_or_file, str): + dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) + + # resample to original shape + current_spacing = configuration_manager.spacing if \ + len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \ + len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted, + target_shape, + current_spacing, + target_spacing) + + # create segmentation (argmax, regions, etc) + label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) + segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file) + # segmentation may be torch.Tensor but we continue with numpy + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + np.savez_compressed(output_file, seg=segmentation.astype(np.uint8)) + torch.set_num_threads(old_threads) diff --git a/docker/template/src/nnunetv2/inference/predict_from_raw_data.py b/docker/template/src/nnunetv2/inference/predict_from_raw_data.py new file mode 100644 index 0000000..1b6543e --- /dev/null +++ b/docker/template/src/nnunetv2/inference/predict_from_raw_data.py @@ -0,0 +1,917 @@ +import inspect +import itertools +import multiprocessing +import os +from copy import deepcopy +from time import sleep +from typing import Tuple, Union, List, Optional + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.padding import pad_nd_image +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \ + save_json +from torch import nn +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm + +import nnunetv2 +from nnunetv2.configuration import default_num_processes +from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \ + preprocessing_iterator_fromnpy +from nnunetv2.inference.export_prediction import export_prediction_from_logits, \ + convert_predicted_logits_to_segmentation_with_correct_shape +from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ + compute_steps_for_sliding_window +from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder + + +class nnUNetPredictor(object): + def __init__(self, + tile_step_size: float = 0.5, + use_gaussian: bool = True, + use_mirroring: bool = True, + perform_everything_on_device: bool = True, + device: torch.device = torch.device('cuda'), + verbose: bool = False, + verbose_preprocessing: bool = False, + allow_tqdm: bool = True): + self.verbose = verbose + self.verbose_preprocessing = verbose_preprocessing + self.allow_tqdm = allow_tqdm + + self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \ + self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None + + self.tile_step_size = tile_step_size + self.use_gaussian = use_gaussian + self.use_mirroring = use_mirroring + if device.type == 'cuda': + # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES! + # why would I ever want to do that. Stupid dobby. This kills DDP inference... + pass + if device.type != 'cuda': + print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False') + perform_everything_on_device = False + self.device = device + self.perform_everything_on_device = perform_everything_on_device + + def initialize_from_trained_model_folder(self, model_training_output_dir: str, + use_folds: Union[Tuple[Union[int, str]], None], + checkpoint_name: str = 'checkpoint_final.pth'): + """ + This is used when making predictions with a trained model + """ + if use_folds is None: + use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name) + + dataset_json = load_json(join(model_training_output_dir, 'dataset.json')) + plans = load_json(join(model_training_output_dir, 'plans.json')) + plans_manager = PlansManager(plans) + + if isinstance(use_folds, str): + use_folds = [use_folds] + + parameters = [] + for i, f in enumerate(use_folds): + f = int(f) if f != 'all' else f + checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name), + map_location=torch.device('cpu')) + if i == 0: + trainer_name = checkpoint['trainer_name'] + configuration_name = checkpoint['init_args']['configuration'] + inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \ + 'inference_allowed_mirroring_axes' in checkpoint.keys() else None + + parameters.append(checkpoint['network_weights']) + + configuration_manager = plans_manager.get_configuration(configuration_name) + # restore network + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, 'nnunetv2.training.nnUNetTrainer') + network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager, + num_input_channels, enable_deep_supervision=False) + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ + and not isinstance(self.network, OptimizedModule): + print('Using torch.compile') + self.network = torch.compile(self.network) + + def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, + configuration_manager: ConfigurationManager, parameters: Optional[List[dict]], + dataset_json: dict, trainer_name: str, + inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]): + """ + This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation + """ + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + allow_compile = True + allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) + if isinstance(self.network, DistributedDataParallel): + allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) + if allow_compile: + print('Using torch.compile') + self.network = torch.compile(self.network) + + @staticmethod + def auto_detect_available_folds(model_training_output_dir, checkpoint_name): + print('use_folds is None, attempting to auto detect available folds') + fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False) + fold_folders = [i for i in fold_folders if i != 'fold_all'] + fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))] + use_folds = [int(i.split('_')[-1]) for i in fold_folders] + print(f'found the following folds: {use_folds}') + return use_folds + + def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]], + folder_with_segs_from_prev_stage: str = None, + overwrite: bool = True, + part_id: int = 0, + num_parts: int = 1, + save_probabilities: bool = False): + if isinstance(list_of_lists_or_source_folder, str): + list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder, + self.dataset_json['file_ending']) + print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder') + list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts] + caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in + list_of_lists_or_source_folder] + print( + f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)') + print(f'There are {len(caseids)} cases that I would like to predict') + + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids] + else: + output_filename_truncated = output_folder_or_list_of_truncated_output_files + + seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if + folder_with_segs_from_prev_stage is not None else None for i in caseids] + # remove already predicted files form the lists + if not overwrite and output_filename_truncated is not None: + tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated] + if save_probabilities: + tmp2 = [isfile(i + '.npz') for i in output_filename_truncated] + tmp = [i and j for i, j in zip(tmp, tmp2)] + not_existing_indices = [i for i, j in enumerate(tmp) if not j] + + output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices] + list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices] + seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices] + print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. ' + f'That\'s {len(not_existing_indices)} cases.') + return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files + + def predict_from_files(self, + list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]], + save_probabilities: bool = False, + overwrite: bool = True, + num_processes_preprocessing: int = default_num_processes, + num_processes_segmentation_export: int = default_num_processes, + folder_with_segs_from_prev_stage: str = None, + num_parts: int = 1, + part_id: int = 0): + """ + This is nnU-Net's default function for making predictions. It works best for batch predictions + (predicting many images at once). + """ + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_folder = output_folder_or_list_of_truncated_output_files + elif isinstance(output_folder_or_list_of_truncated_output_files, list): + output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0]) + else: + output_folder = None + + ######################## + # let's store the input arguments so that its clear what was used to generate the prediction + if output_folder is not None: + my_init_kwargs = {} + for k in inspect.signature(self.predict_from_files).parameters.keys(): + my_init_kwargs[k] = locals()[k] + my_init_kwargs = deepcopy( + my_init_kwargs) # let's not unintentionally change anything in-place. Take this as a + recursive_fix_for_json_export(my_init_kwargs) + maybe_mkdir_p(output_folder) + save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json')) + + # we need these two if we want to do things with the predictions like for example apply postprocessing + save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) + save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False) + ####################### + + # check if we need a prediction from the previous stage + if self.configuration_manager.previous_stage_name is not None: + assert folder_with_segs_from_prev_stage is not None, \ + f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \ + f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \ + f' they are located via folder_with_segs_from_prev_stage' + + # sort out input and output filenames + list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \ + self._manage_input_and_output_lists(list_of_lists_or_source_folder, + output_folder_or_list_of_truncated_output_files, + folder_with_segs_from_prev_stage, overwrite, part_id, num_parts, + save_probabilities) + if len(list_of_lists_or_source_folder) == 0: + return + + data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder, + seg_from_prev_stage_files, + output_filename_truncated, + num_processes_preprocessing) + + return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export) + + def _internal_get_data_iterator_from_lists_of_filenames(self, + input_list_of_lists: List[List[str]], + seg_from_prev_stage_files: Union[List[str], None], + output_filenames_truncated: Union[List[str], None], + num_processes: int): + return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files, + output_filenames_truncated, self.plans_manager, self.dataset_json, + self.configuration_manager, num_processes, self.device.type == 'cuda', + self.verbose_preprocessing) + # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing) + # # hijack batchgenerators, yo + # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This + # # way we don't have to reinvent the wheel here. + # num_processes = max(1, min(num_processes, len(input_list_of_lists))) + # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor, + # output_filenames_truncated, self.plans_manager, self.dataset_json, + # self.configuration_manager, num_processes) + # if num_processes == 0: + # mta = SingleThreadedAugmenter(ppa, None) + # else: + # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory) + # return mta + + def get_data_iterator_from_raw_npy_data(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3): + + list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \ + image_or_list_of_images + + if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray): + segs_from_prev_stage_or_list_of_segs_from_prev_stage = [ + segs_from_prev_stage_or_list_of_segs_from_prev_stage] + + if isinstance(truncated_ofname, str): + truncated_ofname = [truncated_ofname] + + if isinstance(properties_or_list_of_properties, dict): + properties_or_list_of_properties = [properties_or_list_of_properties] + + num_processes = min(num_processes, len(list_of_images)) + pp = preprocessing_iterator_fromnpy( + list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + self.plans_manager, + self.dataset_json, + self.configuration_manager, + num_processes, + self.device.type == 'cuda', + self.verbose_preprocessing + ) + + return pp + + def predict_from_list_of_npy_arrays(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + num_processes) + return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export) + + def predict_from_data_iterator(self, + data_iterator, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + """ + each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys! + If 'ofile' is None, the result will be returned instead of written to a file + """ + with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool: + worker_list = [i for i in export_pool._pool] + r = [] + for preprocessed in data_iterator: + data = preprocessed['data'] + if isinstance(data, str): + delfile = data + data = torch.from_numpy(np.load(data)) + os.remove(delfile) + + ofile = preprocessed['ofile'] + if ofile is not None: + print(f'\nPredicting {os.path.basename(ofile)}:') + else: + print(f'\nPredicting image of shape {data.shape}:') + + print(f'perform_everything_on_device: {self.perform_everything_on_device}') + + properties = preprocessed['data_properties'] + + # let's not get into a runaway situation where the GPU predicts so fast that the disk has to b swamped with + # npy files + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + while not proceed: + # print('sleeping') + sleep(0.1) + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + + prediction = self.predict_logits_from_preprocessed_data(data).cpu() + + if ofile is not None: + # this needs to go into background processes + # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager, + # dataset_json, ofile, save_probabilities) + print('sending off prediction to background worker for resampling and export') + r.append( + export_pool.starmap_async( + export_prediction_from_logits, + ((prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, ofile, save_probabilities),) + ) + ) + else: + # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager, + # configuration_manager, label_manager, + # properties, + # save_probabilities) + print('sending off prediction to background worker for resampling') + r.append( + export_pool.starmap_async( + convert_predicted_logits_to_segmentation_with_correct_shape, ( + (prediction, self.plans_manager, + self.configuration_manager, self.label_manager, + properties, + save_probabilities),) + ) + ) + if ofile is not None: + print(f'done with {os.path.basename(ofile)}') + else: + print(f'\nDone with image of shape {data.shape}:') + ret = [i.get()[0] for i in r] + + if isinstance(data_iterator, MultiThreadedAugmenter): + data_iterator._finish() + + # clear lru cache + compute_gaussian.cache_clear() + # clear device cache + empty_cache(self.device) + return ret + + def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict, + segmentation_previous_stage: np.ndarray = None, + output_file_truncated: str = None, + save_or_return_probabilities: bool = False): + """ + image_properties must only have a 'spacing' key! + """ + ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties], + [output_file_truncated], + self.plans_manager, self.dataset_json, self.configuration_manager, + num_threads_in_multithreaded=1, verbose=self.verbose) + if self.verbose: + print('preprocessing') + dct = next(ppa) + + if self.verbose: + print('predicting') + predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu() + + if self.verbose: + print('resampling to original shape') + if output_file_truncated is not None: + export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager, + self.plans_manager, self.dataset_json, output_file_truncated, + save_or_return_probabilities) + else: + ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager, + self.configuration_manager, + self.label_manager, + dct['data_properties'], + return_probabilities= + save_or_return_probabilities) + if save_or_return_probabilities: + return ret[0], ret[1] + else: + return ret + + def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor: + """ + IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON + TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE! + + RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. + SEE convert_predicted_logits_to_segmentation_with_correct_shape + """ + n_threads = torch.get_num_threads() + torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads) + with torch.no_grad(): + prediction = None + + for params in self.list_of_parameters: + + # messing with state dict names... + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) + else: + self.network._orig_mod.load_state_dict(params) + + # why not leave prediction on device if perform_everything_on_device? Because this may cause the + # second iteration to crash due to OOM. Grabbing tha twith try except cause way more bloated code than + # this actually saves computation time + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data).to('cpu') + else: + prediction += self.predict_sliding_window_return_logits(data).to('cpu') + + if len(self.list_of_parameters) > 1: + prediction /= len(self.list_of_parameters) + + if self.verbose: print('Prediction done') + prediction = prediction.to('cpu') + torch.set_num_threads(n_threads) + return prediction + + def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): + slicers = [] + if len(self.configuration_manager.patch_size) < len(image_size): + assert len(self.configuration_manager.patch_size) == len( + image_size) - 1, 'if tile_size has less entries than image_size, ' \ + 'len(tile_size) ' \ + 'must be one shorter than len(image_size) ' \ + '(only dimension ' \ + 'discrepancy of 1 allowed).' + steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is' + f' {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for d in range(image_size[0]): + for sx in steps[0]: + for sy in steps[1]: + slicers.append( + tuple([slice(None), d, *[slice(si, si + ti) for si, ti in + zip((sx, sy), self.configuration_manager.patch_size)]])) + else: + steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print( + f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for sx in steps[0]: + for sy in steps[1]: + for sz in steps[2]: + slicers.append( + tuple([slice(None), *[slice(si, si + ti) for si, ti in + zip((sx, sy, sz), self.configuration_manager.patch_size)]])) + return slicers + + def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: + mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None + if 'SAMed' in self.trainer_name: + prediction = self.network(x, True, self.configuration_manager.patch_size[-1])['masks'] + + else: + prediction = self.network(x) + + if mirror_axes is not None: + # check for invalid numbers in mirror_axes + # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 + assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' + + axes_combinations = [ + c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1) + ] + for axes in axes_combinations: + if 'SAMed' in self.trainer_name: + prediction += torch.flip(self.network(torch.flip(x, (*axes,)),True, self.configuration_manager.patch_size[-1])['masks'], (*axes,)) + else: + prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,)) + prediction /= (len(axes_combinations) + 1) + return prediction + + def _internal_predict_sliding_window_return_logits(self, + data: torch.Tensor, + slicers, + do_on_device: bool = True, + ): + results_device = self.device if do_on_device else torch.device('cpu') + empty_cache(self.device) + + # move data to device + if self.verbose: + print(f'move image to device {results_device}') + data = data.to(results_device) + + # preallocate arrays + if self.verbose: + print(f'preallocating results arrays on device {results_device}') + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=10, + device=results_device) + + if self.verbose: print('running prediction') + if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps') + for sl in tqdm(slicers, disable=not self.allow_tqdm): + workon = data[sl][None] + workon = workon.to(self.device, non_blocking=False) + + prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) + + predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) + n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) + + predicted_logits /= n_predictions + # check for infs + if torch.any(torch.isinf(predicted_logits)): + raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, ' + 'reduce value_scaling_factor in compute_gaussian or increase the dtype of ' + 'predicted_logits to fp32') + return predicted_logits + + def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ + -> Union[np.ndarray, torch.Tensor]: + assert isinstance(input_image, torch.Tensor) + self.network = self.network.to(self.device) + self.network.eval() + + empty_cache(self.device) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) + # and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False + # is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with torch.no_grad(): + with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)' + + if self.verbose: print(f'Input shape: {input_image.shape}') + if self.verbose: print("step_size:", self.tile_step_size) + if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None) + + # if input_image is smaller than tile_size we need to pad it to tile_size. + data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size, + 'constant', {'value': 0}, True, + None) + + slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) + + if self.perform_everything_on_device and self.device != 'cpu': + # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device + try: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) + except RuntimeError: + print('Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU') + empty_cache(self.device) + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False) + else: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) + + empty_cache(self.device) + # revert padding + predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + return predicted_logits + + +def predict_entry_point_modelfolder(): + import argparse + parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' + 'you want to manually specify a folder containing a trained nnU-Net ' + 'model. This is useful when the nnunet environment variables ' + '(nnUNet_results) are not set.') + parser.add_argument('-i', type=str, required=True, + help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' + 'File endings must be the same as the training dataset!') + parser.add_argument('-o', type=str, required=True, + help='Output folder. If it does not exist it will be created. Predicted segmentations will ' + 'have the same name as their source images.') + parser.add_argument('-m', type=str, required=True, + help='Folder in which the trained model is. Must have subfolders fold_X for the different ' + 'folds you trained') + parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), + help='Specify the folds of the trained model that should be used for prediction. ' + 'Default: (0, 1, 2, 3, 4)') + parser.add_argument('-step_size', type=float, required=False, default=0.5, + help='Step size for sliding window prediction. The larger it is the faster but less accurate ' + 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') + parser.add_argument('--disable_tta', action='store_true', required=False, default=False, + help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' + 'but less accurate inference. Not recommended.') + parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " + "to be a good listener/reader.") + parser.add_argument('--save_probabilities', action='store_true', + help='Set this to export predicted class "probabilities". Required if you want to ensemble ' + 'multiple configurations.') + parser.add_argument('--continue_prediction', '--c', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files)') + parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') + parser.add_argument('-npp', type=int, required=False, default=3, + help='Number of processes used for preprocessing. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-nps', type=int, required=False, default=3, + help='Number of processes used for segmentation export. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models.') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') + + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_device=True, + device=device, + verbose=args.verbose, + allow_tqdm=not args.disable_progress_bar) + predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=1, part_id=0) + + +def predict_entry_point(): + import argparse + parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' + 'you want to manually specify a folder containing a trained nnU-Net ' + 'model. This is useful when the nnunet environment variables ' + '(nnUNet_results) are not set.') + parser.add_argument('-i', type=str, required=True, + help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' + 'File endings must be the same as the training dataset!') + parser.add_argument('-o', type=str, required=True, + help='Output folder. If it does not exist it will be created. Predicted segmentations will ' + 'have the same name as their source images.') + parser.add_argument('-d', type=str, required=True, + help='Dataset with which you would like to predict. You can specify either dataset name or id') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='Plans identifier. Specify the plans in which the desired configuration is located. ' + 'Default: nnUNetPlans') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer') + parser.add_argument('-c', type=str, required=True, + help='nnU-Net configuration that should be used for prediction. Config must be located ' + 'in the plans specified with -p') + parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), + help='Specify the folds of the trained model that should be used for prediction. ' + 'Default: (0, 1, 2, 3, 4)') + parser.add_argument('-step_size', type=float, required=False, default=0.5, + help='Step size for sliding window prediction. The larger it is the faster but less accurate ' + 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') + parser.add_argument('--disable_tta', action='store_true', required=False, default=False, + help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' + 'but less accurate inference. Not recommended.') + parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " + "to be a good listener/reader.") + parser.add_argument('--save_probabilities', action='store_true', + help='Set this to export predicted class "probabilities". Required if you want to ensemble ' + 'multiple configurations.') + parser.add_argument('--continue_prediction', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files)') + parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') + parser.add_argument('-npp', type=int, required=False, default=3, + help='Number of processes used for preprocessing. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-nps', type=int, required=False, default=3, + help='Number of processes used for segmentation export. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models.') + parser.add_argument('-num_parts', type=int, required=False, default=1, + help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one ' + 'call predicts everything)') + parser.add_argument('-part_id', type=int, required=False, default=0, + help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with ' + 'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts ' + '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible ' + 'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + model_folder = get_output_folder(args.d, args.tr, args.p, args.c) + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + # slightly passive aggressive haha + assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.' + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_device=True, + device=device, + verbose=args.verbose, + verbose_preprocessing=False, + allow_tqdm=not args.disable_progress_bar) + predictor.initialize_from_trained_model_folder( + model_folder, + args.f, + checkpoint_name=args.chk + ) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=args.num_parts, + part_id=args.part_id) + # r = predict_from_raw_data(args.i, + # args.o, + # model_folder, + # args.f, + # args.step_size, + # use_gaussian=True, + # use_mirroring=not args.disable_tta, + # perform_everything_on_device=True, + # verbose=args.verbose, + # save_probabilities=args.save_probabilities, + # overwrite=not args.continue_prediction, + # checkpoint_name=args.chk, + # num_processes_preprocessing=args.npp, + # num_processes_segmentation_export=args.nps, + # folder_with_segs_from_prev_stage=args.prev_stage_predictions, + # num_parts=args.num_parts, + # part_id=args.part_id, + # device=device) + + +if __name__ == '__main__': + # predict a bunch of files + from nnunetv2.paths import nnUNet_results, nnUNet_raw + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_device=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0, ), + checkpoint_name='checkpoint_final.pth', + ) + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # predict a numpy array + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, False) + + iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) + ret = predictor.predict_from_data_iterator(iterator, False, 1) + + + # predictor = nnUNetPredictor( + # tile_step_size=0.5, + # use_gaussian=True, + # use_mirroring=True, + # perform_everything_on_device=True, + # device=torch.device('cuda', 0), + # verbose=False, + # allow_tqdm=True + # ) + # predictor.initialize_from_trained_model_folder( + # join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'), + # use_folds=(0,), + # checkpoint_name='checkpoint_final.pth', + # ) + # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + # join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'), + # save_probabilities=False, overwrite=False, + # num_processes_preprocessing=2, num_processes_segmentation_export=2, + # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres', + # num_parts=1, part_id=0) + diff --git a/docker/template/src/nnunetv2/inference/readme.md b/docker/template/src/nnunetv2/inference/readme.md new file mode 100644 index 0000000..4f832a1 --- /dev/null +++ b/docker/template/src/nnunetv2/inference/readme.md @@ -0,0 +1,205 @@ +The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into +your existing workflows. +This readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn +all the details! + +# Preface +In terms of speed, the most efficient inference strategy is the one done by the nnU-Net defaults! Images are read on +the fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and +sends the prediction off to another set of background workers which will resize the resulting logits, convert +them to a segmentation and export the segmentation. + +The reason the default setup is the best option is because + +1) loading and preprocessing as well as segmentation export are interlaced with the prediction. The main process can +focus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing. +This uses your resources as well as possible! +2) only the images and segmentation that are currently being needed are stored in RAM! Imaging predicting many images +and having to store all of them + the results in your system memory + +# nnUNetPredictor +The new nnUNetPredictor class encapsulates the inferencing code and makes it simple to switch between modes. Your +code can hold a nnUNetPredictor instance and perform prediction on the fly. Previously this was not possible and each +new prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal. + +The nnUNetPredictor must be ininitialized manually! You will want to use the +`predictor.initialize_from_trained_model_folder` function for 99% of use cases! + +New feature: If you do not specify an output folder / output files then the predicted segmentations will be +returned + + +## Recommended nnU-Net default: predict from source files + +tldr: +- loads images on the fly +- performs preprocessing in background workers +- main process focuses only on making predictions +- results are again given to background workers for resampling and (optional) export + +pros: +- best suited for predicting a large number of images +- nicer to your RAM + +cons: +- not ideal when single images are to be predicted +- requires images to be present as files + +Example: +```python + from nnunetv2.paths import nnUNet_results, nnUNet_raw + import torch + from batchgenerators.utilities.file_and_folder_operations import join + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + + # instantiate the nnUNetPredictor + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_device=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + # initializes the network architecture, loads the checkpoint + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0,), + checkpoint_name='checkpoint_final.pth', + ) + # variant 1: give input and output folders + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) +``` + +Instead if giving input and output folders you can also give concrete files. If you give concrete files, there is no +need for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames! +Remember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted +and the inner list contains all the files belonging to that case. There is just one file for datasets with just one +input modality (such as CT) but may be more files for others (such as MRI where there is sometimes T1, T2, Flair etc). +IMPORTANT: the order in which the files for each case are given must match the order of the channels as defined in the +dataset.json! + +If you give files as input, you need to give individual output files as output! + +```python + # variant 2, use list of files as inputs. Note how we use nested lists!!! + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres') + predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + [join(outdir, 'liver_152.nii.gz'), + join(outdir, 'liver_142.nii.gz')], + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) +``` + +Did you know? If you do not specify output files, the predicted segmentations will be returned: +```python + # variant 2.5, returns segmentations + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + None, + save_probabilities=False, overwrite=True, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) +``` + +## Prediction from npy arrays +tldr: +- you give images as a list of npy arrays +- performs preprocessing in background workers +- main process focuses only on making predictions +- results are again given to background workers for resampling and (optional) export + +pros: +- the correct variant for when you have images in RAM already +- well suited for predicting multiple images + +cons: +- uses more ram than the default +- unsuited for large number of images as all images must be held in RAM + +```python + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + # we do not set output files so that the segmentations will be returned. You can of course also specify output + # files instead (no return value on that case) + ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4], + None, + [props, props2, props3, props4], + None, 2, save_probabilities=False, + num_processes_segmentation_export=2) +``` + +## Predicting a single npy array + +tldr: +- you give one image as npy array +- everything is done in the main process: preprocessing, prediction, resampling, (export) +- no interlacing, slowest variant! +- ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON + +pros: +- no messing with multiprocessing +- no messing with data iterator blabla + +cons: +- slows as heck, yo +- never the right choice unless you can only give a single image at a time to nnU-Net + +```python + # predict a single numpy array + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, False) +``` + +## Predicting with a custom data iterator +tldr: +- highly flexible +- not for newbies + +pros: +- you can do everything yourself +- you have all the freedom you want +- really fast if you remember to use multiprocessing in your iterator + +cons: +- you need to do everything yourself +- harder than you might think + +```python + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys! + # If 'ofile' is None, the result will be returned instead of written to a file + # the iterator is responsible for performing the correct preprocessing! + # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread! + # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays + # (they both use predictor.predict_from_data_iterator) for inspiration! + def my_iterator(list_of_input_arrs, list_of_input_props): + preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) + for a, p in zip(list_of_input_arrs, list_of_input_props): + data, seg = preprocessor.run_case_npy(a, + None, + p, + predictor.plans_manager, + predictor.configuration_manager, + predictor.dataset_json) + yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None} + ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), + save_probabilities=False, num_processes_segmentation_export=3) +``` \ No newline at end of file diff --git a/docker/template/src/nnunetv2/inference/sliding_window_prediction.py b/docker/template/src/nnunetv2/inference/sliding_window_prediction.py new file mode 100644 index 0000000..a6f8ebb --- /dev/null +++ b/docker/template/src/nnunetv2/inference/sliding_window_prediction.py @@ -0,0 +1,67 @@ +from functools import lru_cache + +import numpy as np +import torch +from typing import Union, Tuple, List +from acvl_utils.cropping_and_padding.padding import pad_nd_image +from scipy.ndimage import gaussian_filter + + +@lru_cache(maxsize=2) +def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8, + value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \ + -> torch.Tensor: + tmp = np.zeros(tile_size) + center_coords = [i // 2 for i in tile_size] + sigmas = [i * sigma_scale for i in tile_size] + tmp[tuple(center_coords)] = 1 + gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) + + gaussian_importance_map = torch.from_numpy(gaussian_importance_map) + + gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor + gaussian_importance_map = gaussian_importance_map.type(dtype).to(device) + + # gaussian_importance_map cannot be 0, otherwise we may end up with nans! + gaussian_importance_map[gaussian_importance_map == 0] = torch.min( + gaussian_importance_map[gaussian_importance_map != 0]) + + return gaussian_importance_map + + +def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \ + List[List[int]]: + assert [i >= j for i, j in zip(image_size, tile_size)], "image size must be as large or larger than patch_size" + assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1' + + # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of + # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46 + target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size] + + num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)] + + steps = [] + for dim in range(len(tile_size)): + # the highest step value for this dimension is + max_step_value = image_size[dim] - tile_size[dim] + if num_steps[dim] > 1: + actual_step_size = max_step_value / (num_steps[dim] - 1) + else: + actual_step_size = 99999999999 # does not matter because there is only one step at 0 + + steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])] + + steps.append(steps_here) + + return steps + + +if __name__ == '__main__': + a = torch.rand((4, 2, 32, 23)) + a_npy = a.numpy() + + a_padded = pad_nd_image(a, new_shape=(48, 27)) + a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27)) + assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))]) + assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))]) + assert np.all(a_padded.numpy() == a_npy_padded) diff --git a/docker/template/src/nnunetv2/model_sharing/__init__.py b/docker/template/src/nnunetv2/model_sharing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/model_sharing/entry_points.py b/docker/template/src/nnunetv2/model_sharing/entry_points.py new file mode 100644 index 0000000..1ab7c93 --- /dev/null +++ b/docker/template/src/nnunetv2/model_sharing/entry_points.py @@ -0,0 +1,61 @@ +from nnunetv2.model_sharing.model_download import download_and_install_from_url +from nnunetv2.model_sharing.model_export import export_pretrained_model +from nnunetv2.model_sharing.model_import import install_model_from_zip_file + + +def print_license_warning(): + print('') + print('######################################################') + print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!') + print('######################################################') + print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some " + "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use " + "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!") + print('######################################################') + print('') + + +def download_by_url(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to download pretrained models. This script is intended to download models via url only. " + "CAREFUL: This script will overwrite " + "existing models (if they share the same trainer class and plans as " + "the pretrained model.") + parser.add_argument("url", type=str, help='URL of the pretrained model') + args = parser.parse_args() + url = args.url + download_and_install_from_url(url) + + +def install_from_zip_entry_point(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to install a zip file containing a pretrained model.") + parser.add_argument("zip", type=str, help='zip file') + args = parser.parse_args() + zip = args.zip + install_model_from_zip_file(zip) + + +def export_pretrained_model_entry(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to export a trained model as a zip file.") + parser.add_argument('-d', type=str, required=True, help='Dataset name or id') + parser.add_argument('-o', type=str, required=True, help='Output file name') + parser.add_argument('-c', nargs='+', type=str, required=False, + default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'), + help="List of configuration names") + parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class') + parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier') + parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids') + parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ), + help='Lis tof checkpoint names to export. Default: checkpoint_final.pth') + parser.add_argument('--not_strict', action='store_false', default=False, required=False, help='Set this to allow missing folds and/or configurations') + parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well') + args = parser.parse_args() + + export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr, + plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk, + export_crossval_predictions=args.exp_cv_preds) diff --git a/docker/template/src/nnunetv2/model_sharing/model_download.py b/docker/template/src/nnunetv2/model_sharing/model_download.py new file mode 100644 index 0000000..02dac5f --- /dev/null +++ b/docker/template/src/nnunetv2/model_sharing/model_download.py @@ -0,0 +1,47 @@ +from typing import Optional + +import requests +from batchgenerators.utilities.file_and_folder_operations import * +from time import time +from nnunetv2.model_sharing.model_import import install_model_from_zip_file +from nnunetv2.paths import nnUNet_results +from tqdm import tqdm + + +def download_and_install_from_url(url): + assert nnUNet_results is not None, "Cannot install model because network_training_output_dir is not " \ + "set (RESULTS_FOLDER missing as environment variable, see " \ + "Installation instructions)" + print('Downloading pretrained model from url:', url) + import http.client + http.client.HTTPConnection._http_vsn = 10 + http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0' + + import os + home = os.path.expanduser('~') + random_number = int(time() * 1e7) + tempfile = join(home, f'.nnunetdownload_{str(random_number)}') + + try: + download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16) + print("Download finished. Extracting...") + install_model_from_zip_file(tempfile) + print("Done") + except Exception as e: + raise e + finally: + if isfile(tempfile): + os.remove(tempfile) + + +def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str: + # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests + # NOTE the stream=True parameter below + with requests.get(url, stream=True, timeout=100) as r: + r.raise_for_status() + with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f: + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + return local_filename + + diff --git a/docker/template/src/nnunetv2/model_sharing/model_export.py b/docker/template/src/nnunetv2/model_sharing/model_export.py new file mode 100644 index 0000000..51eb455 --- /dev/null +++ b/docker/template/src/nnunetv2/model_sharing/model_export.py @@ -0,0 +1,124 @@ +import zipfile + +from nnunetv2.utilities.file_path_utilities import * + + +def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str, + configurations: Tuple[str] = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + trainer: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + folds: Tuple[int, ...] = (0, 1, 2, 3, 4), + strict: bool = True, + save_checkpoints: Tuple[str, ...] = ('checkpoint_final.pth',), + export_crossval_predictions: bool = False) -> None: + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + with(zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED)) as zipf: + for c in configurations: + print(f"Configuration {c}") + trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c) + + if not isdir(trainer_output_dir): + if strict: + raise RuntimeError(f"{dataset_name} is missing the trained model of configuration {c}") + else: + continue + + expected_fold_folder = [f"fold_{i}" if i != 'all' else 'fold_all' for i in folds] + assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \ + f"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}" + + assert isfile(join(trainer_output_dir, "plans.json")), f"plans.json missing, {dataset_name} {c}" + + for fold_folder in expected_fold_folder: + print(f"Exporting {fold_folder}") + # debug.json, does not exist yet + source_file = join(trainer_output_dir, fold_folder, "debug.json") + if isfile(source_file): + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # all requested checkpoints + for chk in save_checkpoints: + source_file = join(trainer_output_dir, fold_folder, chk) + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # progress.png + source_file = join(trainer_output_dir, fold_folder, "progress.png") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # if it exists, network architecture.png + source_file = join(trainer_output_dir, fold_folder, "network_architecture.pdf") + if isfile(source_file): + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # validation folder with all predicted segmentations etc + if export_crossval_predictions: + source_folder = join(trainer_output_dir, fold_folder, "validation") + files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')] + for f in files: + zipf.write(join(source_folder, f), os.path.relpath(join(source_folder, f), nnUNet_results)) + # just the summary.json file from the validation + else: + source_file = join(trainer_output_dir, fold_folder, "validation", "summary.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + if isdir(source_folder): + if export_crossval_predictions: + source_files = subfiles(source_folder, join=True) + else: + source_files = [ + join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in + ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] + ] + for s in source_files: + if isfile(s): + zipf.write(s, os.path.relpath(s, nnUNet_results)) + # plans + source_file = join(trainer_output_dir, "plans.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + # fingerprint + source_file = join(trainer_output_dir, "dataset_fingerprint.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + # dataset + source_file = join(trainer_output_dir, "dataset.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles') + + if not isdir(ensemble_dir): + print("No ensemble directory found for task", dataset_name_or_id) + return + subd = subdirs(ensemble_dir, join=False) + # figure out whether the models in the ensemble are all within the exported models here + for ens in subd: + identifiers, folds = convert_ensemble_folder_to_model_identifiers_and_folds(ens) + ok = True + for i in identifiers: + tr, pl, c = convert_identifier_to_trainer_plans_config(i) + if tr == trainer and pl == plans_identifier and c in configurations: + pass + else: + ok = False + if ok: + print(f'found matching ensemble: {ens}') + source_folder = join(ensemble_dir, ens) + if export_crossval_predictions: + source_files = subfiles(source_folder, join=True) + else: + source_files = [ + join(source_folder, i) for i in + ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] if isfile(join(source_folder, i)) + ] + for s in source_files: + zipf.write(s, os.path.relpath(s, nnUNet_results)) + inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json') + if isfile(inference_information_file): + zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results)) + inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt') + if isfile(inference_information_txt_file): + zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results)) + print('Done') + + +if __name__ == '__main__': + export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, )) diff --git a/docker/template/src/nnunetv2/model_sharing/model_import.py b/docker/template/src/nnunetv2/model_sharing/model_import.py new file mode 100644 index 0000000..0356e90 --- /dev/null +++ b/docker/template/src/nnunetv2/model_sharing/model_import.py @@ -0,0 +1,8 @@ +import zipfile + +from nnunetv2.paths import nnUNet_results + + +def install_model_from_zip_file(zip_file: str): + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(nnUNet_results) \ No newline at end of file diff --git a/docker/template/src/nnunetv2/nets/LightMUNet.py b/docker/template/src/nnunetv2/nets/LightMUNet.py new file mode 100644 index 0000000..a26c029 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/LightMUNet.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.segresnet_block import ResBlock, get_conv_layer, get_upsample_layer +from monai.networks.layers.factories import Dropout +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import UpsampleMode + +from mamba_ssm import Mamba + + +def get_dwconv_layer( + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False +): + depth_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, + strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True, groups=in_channels) + point_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, + strides=stride, kernel_size=1, bias=bias, conv_only=True, groups=1) + return torch.nn.Sequential(depth_conv, point_conv) + +class RVMLayer(nn.Module): + def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.norm = nn.LayerNorm(input_dim) + self.mamba = Mamba( + d_model=input_dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.proj = nn.Linear(input_dim, output_dim) + self.skip_scale= nn.Parameter(torch.ones(1)) + + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + B, C = x.shape[:2] + assert C == self.input_dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + self.skip_scale * x_flat + x_mamba = self.norm(x_mamba) + x_mamba = self.proj(x_mamba) + out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) + return out + + +def get_rvm_layer( + spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1 +): + mamba_layer = RVMLayer(input_dim=in_channels, output_dim=out_channels) + if stride != 1: + if spatial_dims==2: + return nn.Sequential(mamba_layer, nn.MaxPool2d(kernel_size=stride, stride=stride)) + if spatial_dims==3: + return nn.Sequential(mamba_layer, nn.MaxPool3d(kernel_size=stride, stride=stride)) + return mamba_layer + + +class ResMambaBlock(nn.Module): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm: tuple | str, + kernel_size: int = 3, + act: tuple | str = ("RELU", {"inplace": True}), + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions, could be 1, 2 or 3. + in_channels: number of input channels. + norm: feature normalization type and arguments. + kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3. + act: activation type and arguments. Defaults to ``RELU``. + """ + + super().__init__() + + if kernel_size % 2 != 1: + raise AssertionError("kernel_size should be an odd number.") + + self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.act = get_act_layer(act) + self.conv1 = get_rvm_layer( + spatial_dims, in_channels=in_channels, out_channels=in_channels + ) + self.conv2 = get_rvm_layer( + spatial_dims, in_channels=in_channels, out_channels=in_channels + ) + + def forward(self, x): + identity = x + + x = self.norm1(x) + x = self.act(x) + x = self.conv1(x) + + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + + x += identity + + return x + + +class ResUpBlock(nn.Module): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm: tuple | str, + kernel_size: int = 3, + act: tuple | str = ("RELU", {"inplace": True}), + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions, could be 1, 2 or 3. + in_channels: number of input channels. + norm: feature normalization type and arguments. + kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3. + act: activation type and arguments. Defaults to ``RELU``. + """ + + super().__init__() + + if kernel_size % 2 != 1: + raise AssertionError("kernel_size should be an odd number.") + + self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.act = get_act_layer(act) + self.conv = get_dwconv_layer( + spatial_dims, in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size + ) + self.skip_scale= nn.Parameter(torch.ones(1)) + + def forward(self, x): + identity = x + + x = self.norm1(x) + x = self.act(x) + x = self.conv(x) + self.skip_scale * identity + x = self.norm2(x) + x = self.act(x) + return x + + +class LightMUNet(nn.Module): + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 8, + in_channels: int = 1, + out_channels: int = 2, + dropout_prob: float | None = None, + act: tuple | str = ("RELU", {"inplace": True}), + norm: tuple | str = ("GROUP", {"num_groups": 8}), + norm_name: str = "", + num_groups: int = 8, + use_conv_final: bool = True, + blocks_down: list = [1, 2, 2, 4], + blocks_up: list = [1, 1, 1], + upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE, + ): + super().__init__() + + if spatial_dims not in (2, 3): + raise ValueError("`spatial_dims` can only be 2 or 3.") + + self.spatial_dims = spatial_dims + self.init_filters = init_filters + self.in_channels = in_channels + self.blocks_down = blocks_down + self.blocks_up = blocks_up + self.dropout_prob = dropout_prob + self.act = act # input options + self.act_mod = get_act_layer(act) + if norm_name: + if norm_name.lower() != "group": + raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.") + norm = ("group", {"num_groups": num_groups}) + self.norm = norm + self.upsample_mode = UpsampleMode(upsample_mode) + self.use_conv_final = use_conv_final + self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters) + self.down_layers = self._make_down_layers() + self.up_layers, self.up_samples = self._make_up_layers() + self.conv_final = self._make_final_conv(out_channels) + + if dropout_prob is not None: + self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob) + + def _make_down_layers(self): + down_layers = nn.ModuleList() + blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm) + for i, item in enumerate(blocks_down): + layer_in_channels = filters * 2**i + downsample_mamba = ( + get_rvm_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2) + if i > 0 + else nn.Identity() + ) + down_layer = nn.Sequential( + downsample_mamba, *[ResMambaBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(item)] + ) + down_layers.append(down_layer) + return down_layers + + def _make_up_layers(self): + up_layers, up_samples = nn.ModuleList(), nn.ModuleList() + upsample_mode, blocks_up, spatial_dims, filters, norm = ( + self.upsample_mode, + self.blocks_up, + self.spatial_dims, + self.init_filters, + self.norm, + ) + n_up = len(blocks_up) + for i in range(n_up): + sample_in_channels = filters * 2 ** (n_up - i) + up_layers.append( + nn.Sequential( + *[ + ResBlock(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act) + for _ in range(blocks_up[i]) + ] + ) + ) + up_samples.append( + nn.Sequential( + *[ + get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1), + get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode), + ] + ) + ) + return up_layers, up_samples + + def _make_final_conv(self, out_channels: int): + return nn.Sequential( + get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters), + self.act_mod, + get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True), + ) + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: + x = self.convInit(x) + if self.dropout_prob is not None: + x = self.dropout(x) + down_x = [] + + for down in self.down_layers: + x = down(x) + down_x.append(x) + + return x, down_x + + def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor: + for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): + x = up(x) + down_x[i + 1] + x = upl(x) + + if self.use_conv_final: + x = self.conv_final(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, down_x = self.encode(x) + down_x.reverse() + + x = self.decode(x, down_x) + return x \ No newline at end of file diff --git a/docker/template/src/nnunetv2/nets/UMambaBot.py b/docker/template/src/nnunetv2/nets/UMambaBot.py new file mode 100644 index 0000000..863ee05 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/UMambaBot.py @@ -0,0 +1,269 @@ +import numpy as np +import torch +from torch import nn +from typing import Union, Type, List, Tuple + +from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp +from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder +from dynamic_network_architectures.building_blocks.residual import StackedResidualBlocks +from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder +from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim + +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from nnunetv2.utilities.network_initialization import InitWeights_He +from mamba_ssm import Mamba + +class UNetResDecoder(nn.Module): + def __init__(self, + encoder: Union[PlainConvEncoder, ResidualEncoder], + num_classes: int, + n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], + deep_supervision, nonlin_first: bool = False): + """ + This class needs the skips of the encoder as input in its forward. + + the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder + are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck + features and the lowest skip as inputs + the decoder has two (three) parts in each stage: + 1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage) + 2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge + 3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits? + :param encoder: + :param num_classes: + :param n_conv_per_stage: + :param deep_supervision: + """ + super().__init__() + self.deep_supervision = deep_supervision + self.encoder = encoder + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op) + + # we start with the bottleneck and work out way up + stages = [] + transpconvs = [] + seg_layers = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_transpconv = encoder.strides[-s] + transpconvs.append(transpconv_op( + input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv, + bias=encoder.conv_bias + )) + # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output) + stages.append(StackedResidualBlocks( + n_blocks = n_conv_per_stage[s-1], + conv_op = encoder.conv_op, + input_channels = 2 * input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + initial_stride = 1, + conv_bias = encoder.conv_bias, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + dropout_op = encoder.dropout_op, + dropout_op_kwargs = encoder.dropout_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + )) + + # we always build the deep supervision outputs so that we can always load parameters. If we don't do this + # then a model trained with deep_supervision=True could not easily be loaded at inference time where + # deep supervision is not needed. It's just a convenience thing + seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True)) + + self.stages = nn.ModuleList(stages) + self.transpconvs = nn.ModuleList(transpconvs) + self.seg_layers = nn.ModuleList(seg_layers) + + def forward(self, skips): + """ + we expect to get the skips in the order they were computed, so the bottleneck should be the last entry + :param skips: + :return: + """ + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.transpconvs[s](lres_input) + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + if self.deep_supervision: + seg_outputs.append(self.seg_layers[s](x)) + elif s == (len(self.stages) - 1): + seg_outputs.append(self.seg_layers[-1](x)) + lres_input = x + + # invert seg outputs so that the largest segmentation prediction is returned first + seg_outputs = seg_outputs[::-1] + + if not self.deep_supervision: + r = seg_outputs[0] + else: + r = seg_outputs + return r + + def compute_conv_feature_map_size(self, input_size): + """ + IMPORTANT: input_size is the input_size of the encoder! + :param input_size: + :return: + """ + # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at + # least have the size of the skip above that (therefore -1) + skip_sizes = [] + for s in range(len(self.encoder.strides) - 1): + skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])]) + input_size = skip_sizes[-1] + # print(skip_sizes) + + assert len(skip_sizes) == len(self.stages) + + # our ops are the other way around, so let's match things up + output = np.int64(0) + for s in range(len(self.stages)): + # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)]) + # conv blocks + output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)]) + # trans conv + output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64) + # segmentation + if self.deep_supervision or (s == (len(self.stages) - 1)): + output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64) + return output + +class UMambaBot(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + num_classes: int, + n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + deep_supervision: bool = False, + block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, + bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None, + stem_channels: int = None + ): + super().__init__() + n_blocks_per_stage = n_conv_per_stage + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, + n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, + dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels, + return_skips=True, disable_default_stem=False, stem_channels=stem_channels) + # layer norm + self.ln = nn.LayerNorm(features_per_stage[-1]) + self.mamba = Mamba( + d_model=features_per_stage[-1], + d_state=16, + d_conv=4, + expand=2, + ) + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def forward(self, x): + skips = self.encoder(x) + middle_feature = skips[-1] + B, C = middle_feature.shape[:2] + n_tokens = middle_feature.shape[2:].numel() + img_dims = middle_feature.shape[2:] + middle_feature_flat = middle_feature.view(B, C, n_tokens).transpose(-1, -2) + middle_feature_flat = self.ln(middle_feature_flat) + out = self.mamba(middle_feature_flat) + out = out.transpose(-1, -2).view(B, C, *img_dims) + skips[-1] = out + + return self.decoder(skips) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +def get_umamba_bot_from_plans(plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_input_channels: int, + deep_supervision: bool = True): + """ + we may have to change this in the future to accommodate other plans -> network mappings + + num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the + trainer rather than inferring it again from the plans here. + """ + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = 'UMambaBot' + network_class = UMambaBot + kwargs = { + 'UMambaBot': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + + conv_or_blocks_per_stage = { + 'n_conv_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + + return model diff --git a/docker/template/src/nnunetv2/nets/UMambaEnc.py b/docker/template/src/nnunetv2/nets/UMambaEnc.py new file mode 100644 index 0000000..a5ec3c4 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/UMambaEnc.py @@ -0,0 +1,414 @@ +import numpy as np +import torch +from torch import nn +from typing import Union, Type, List, Tuple + +from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp +from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder + +from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks +from dynamic_network_architectures.building_blocks.residual import StackedResidualBlocks + +from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op +from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from torch.cuda.amp import autocast +from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 +from nnunetv2.utilities.network_initialization import InitWeights_He +from mamba_ssm import Mamba + +class MambaLayer(nn.Module): + def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2): + super().__init__() + self.dim = dim + self.norm = nn.LayerNorm(dim) + self.mamba = Mamba( + d_model=dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + + @autocast(enabled=False) + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + B, C = x.shape[:2] + assert C == self.dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims) + + return out + + +class ResidualMambaEncoder(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, + bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None, + return_skips: bool = False, + disable_default_stem: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + stochastic_depth_p: float = 0.0, + squeeze_excitation: bool = False, + squeeze_excitation_reduction_ratio: float = 1. / 16 + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + if bottleneck_channels is None or isinstance(bottleneck_channels, int): + bottleneck_channels = [bottleneck_channels] * n_stages + assert len( + bottleneck_channels) == n_stages, "bottleneck_channels must be None or have as many entries as we have resolution stages (n_stages)" + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + # build a stem, Todo maybe we need more flexibility for this in the future. For now, if you need a custom + # stem you can just disable the stem and build your own. + # THE STEM DOES NOT DO STRIDE/POOLING IN THIS IMPLEMENTATION + if not disable_default_stem: + if stem_channels is None: + stem_channels = features_per_stage[0] + self.stem = StackedConvBlocks(1, conv_op, input_channels, stem_channels, kernel_sizes[0], 1, conv_bias, + norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs) + input_channels = stem_channels + else: + self.stem = None + + # now build the network + stages = [] + mamba_layers = [] + for s in range(n_stages): + stride_for_conv = strides[s] if pool_op is None else 1 + + stage = StackedResidualBlocks( + n_blocks_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], stride_for_conv, + conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, + block=block, bottleneck_channels=bottleneck_channels[s], stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, + squeeze_excitation_reduction_ratio=squeeze_excitation_reduction_ratio + ) + + if pool_op is not None: + stage = nn.Sequential(pool_op(strides[s]), stage) + + stages.append(stage) + input_channels = features_per_stage[s] + + mamba_layers.append(MambaLayer(input_channels)) + + #self.stages = nn.Sequential(*stages) + self.stages = nn.ModuleList(stages) + self.output_channels = features_per_stage + self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + self.return_skips = return_skips + + # we store some things that a potential decoder needs + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + self.mamba_layers = nn.ModuleList(mamba_layers) + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + #for s in self.stages: + for s in range(len(self.stages)): + #x = s(x) + x = self.stages[s](x) + x = self.mamba_layers[s](x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + +class UNetResDecoder(nn.Module): + def __init__(self, + encoder: Union[PlainConvEncoder, ResidualMambaEncoder], + num_classes: int, + n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], + deep_supervision, nonlin_first: bool = False): + """ + This class needs the skips of the encoder as input in its forward. + + the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder + are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck + features and the lowest skip as inputs + the decoder has two (three) parts in each stage: + 1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage) + 2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge + 3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits? + :param encoder: + :param num_classes: + :param n_conv_per_stage: + :param deep_supervision: + """ + super().__init__() + self.deep_supervision = deep_supervision + self.encoder = encoder + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op) + + # we start with the bottleneck and work out way up + stages = [] + transpconvs = [] + seg_layers = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_transpconv = encoder.strides[-s] + transpconvs.append(transpconv_op( + input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv, + bias=encoder.conv_bias + )) + # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output) + stages.append(StackedResidualBlocks( + n_blocks = n_conv_per_stage[s-1], + conv_op = encoder.conv_op, + input_channels = 2 * input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + initial_stride = 1, + conv_bias = encoder.conv_bias, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + dropout_op = encoder.dropout_op, + dropout_op_kwargs = encoder.dropout_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + )) + # we always build the deep supervision outputs so that we can always load parameters. If we don't do this + # then a model trained with deep_supervision=True could not easily be loaded at inference time where + # deep supervision is not needed. It's just a convenience thing + seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True)) + + self.stages = nn.ModuleList(stages) + self.transpconvs = nn.ModuleList(transpconvs) + self.seg_layers = nn.ModuleList(seg_layers) + + def forward(self, skips): + """ + we expect to get the skips in the order they were computed, so the bottleneck should be the last entry + :param skips: + :return: + """ + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.transpconvs[s](lres_input) + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + if self.deep_supervision: + seg_outputs.append(self.seg_layers[s](x)) + elif s == (len(self.stages) - 1): + seg_outputs.append(self.seg_layers[-1](x)) + lres_input = x + + # invert seg outputs so that the largest segmentation prediction is returned first + seg_outputs = seg_outputs[::-1] + + if not self.deep_supervision: + r = seg_outputs[0] + else: + r = seg_outputs + return r + + def compute_conv_feature_map_size(self, input_size): + """ + IMPORTANT: input_size is the input_size of the encoder! + :param input_size: + :return: + """ + # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at + # least have the size of the skip above that (therefore -1) + skip_sizes = [] + for s in range(len(self.encoder.strides) - 1): + skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])]) + input_size = skip_sizes[-1] + # print(skip_sizes) + + assert len(skip_sizes) == len(self.stages) + + # our ops are the other way around, so let's match things up + output = np.int64(0) + for s in range(len(self.stages)): + # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)]) + # conv blocks + output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)]) + # trans conv + output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64) + # segmentation + if self.deep_supervision or (s == (len(self.stages) - 1)): + output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64) + return output + +class UMambaEnc(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + num_classes: int, + n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + deep_supervision: bool = False, + block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, + bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None, + stem_channels: int = None + ): + super().__init__() + n_blocks_per_stage = n_conv_per_stage + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualMambaEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, + n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, + dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels, + return_skips=True, disable_default_stem=False, stem_channels=stem_channels) + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def forward(self, x): + skips = self.encoder(x) + return self.decoder(skips) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +def get_umamba_enc_from_plans(plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_input_channels: int, + deep_supervision: bool = True): + """ + we may have to change this in the future to accommodate other plans -> network mappings + + num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the + trainer rather than inferring it again from the plans here. + """ + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = 'UMambaEnc' + network_class = UMambaEnc + kwargs = { + 'UMambaEnc': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + + conv_or_blocks_per_stage = { + 'n_conv_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == UMambaEnc: + model.apply(init_last_bn_before_add_to_0) + + return model diff --git a/docker/template/src/nnunetv2/nets/__init__.py b/docker/template/src/nnunetv2/nets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/nets/mednextv1/MedNextV1.py b/docker/template/src/nnunetv2/nets/mednextv1/MedNextV1.py new file mode 100644 index 0000000..3f8cb83 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/mednextv1/MedNextV1.py @@ -0,0 +1,432 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from nnunetv2.nets.mednextv1.blocks import * + +class MedNeXt(nn.Module): + + def __init__(self, + in_channels: int, + n_channels: int, + n_classes: int, + exp_r: int = 4, # Expansion ratio as in Swin Transformers + kernel_size: int = 7, # Ofcourse can test kernel_size + enc_kernel_size: int = None, + dec_kernel_size: int = None, + deep_supervision: bool = False, # Can be used to test deep supervision + do_res: bool = False, # Can be used to individually test residual connection + do_res_up_down: bool = False, # Additional 'res' connection on up and down convs + checkpoint_style: bool = None, # Either inside block or outside block + block_counts: list = [2,2,2,2,2,2,2,2,2], # Can be used to test staging ratio: + # [3,3,9,3] in Swin as opposed to [2,2,2,2,2] in nnUNet + norm_type = 'group', + dim = '3d', # 2d or 3d + grn = False + ): + + super().__init__() + + self.do_ds = deep_supervision + assert checkpoint_style in [None, 'outside_block'] + self.inside_block_checkpointing = False + self.outside_block_checkpointing = False + if checkpoint_style == 'outside_block': + self.outside_block_checkpointing = True + assert dim in ['2d', '3d'] + + if kernel_size is not None: + enc_kernel_size = kernel_size + dec_kernel_size = kernel_size + + if dim == '2d': + conv = nn.Conv2d + elif dim == '3d': + conv = nn.Conv3d + + self.stem = conv(in_channels, n_channels, kernel_size=1) + if type(exp_r) == int: + exp_r = [exp_r for i in range(len(block_counts))] + + self.enc_block_0 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[0], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[0])] + ) + + self.down_0 = MedNeXtDownBlock( + in_channels=n_channels, + out_channels=2*n_channels, + exp_r=exp_r[1], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim + ) + + self.enc_block_1 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*2, + out_channels=n_channels*2, + exp_r=exp_r[1], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[1])] + ) + + self.down_1 = MedNeXtDownBlock( + in_channels=2*n_channels, + out_channels=4*n_channels, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.enc_block_2 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*4, + out_channels=n_channels*4, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[2])] + ) + + self.down_2 = MedNeXtDownBlock( + in_channels=4*n_channels, + out_channels=8*n_channels, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.enc_block_3 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*8, + out_channels=n_channels*8, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[3])] + ) + + self.down_3 = MedNeXtDownBlock( + in_channels=8*n_channels, + out_channels=16*n_channels, + exp_r=exp_r[4], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.bottleneck = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*16, + out_channels=n_channels*16, + exp_r=exp_r[4], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[4])] + ) + + self.up_3 = MedNeXtUpBlock( + in_channels=16*n_channels, + out_channels=8*n_channels, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_3 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*8, + out_channels=n_channels*8, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[5])] + ) + + self.up_2 = MedNeXtUpBlock( + in_channels=8*n_channels, + out_channels=4*n_channels, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_2 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*4, + out_channels=n_channels*4, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[6])] + ) + + self.up_1 = MedNeXtUpBlock( + in_channels=4*n_channels, + out_channels=2*n_channels, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_1 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels*2, + out_channels=n_channels*2, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[7])] + ) + + self.up_0 = MedNeXtUpBlock( + in_channels=2*n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_0 = nn.Sequential(*[ + MedNeXtBlock( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[8])] + ) + + self.out_0 = OutBlock(in_channels=n_channels, n_classes=n_classes, dim=dim) + + # Used to fix PyTorch checkpointing bug + self.dummy_tensor = nn.Parameter(torch.tensor([1.]), requires_grad=True) + + if deep_supervision: + self.out_1 = OutBlock(in_channels=n_channels*2, n_classes=n_classes, dim=dim) + self.out_2 = OutBlock(in_channels=n_channels*4, n_classes=n_classes, dim=dim) + self.out_3 = OutBlock(in_channels=n_channels*8, n_classes=n_classes, dim=dim) + self.out_4 = OutBlock(in_channels=n_channels*16, n_classes=n_classes, dim=dim) + + self.block_counts = block_counts + + + def iterative_checkpoint(self, sequential_block, x): + """ + This simply forwards x through each block of the sequential_block while + using gradient_checkpointing. This implementation is designed to bypass + the following issue in PyTorch's gradient checkpointing: + https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/9 + """ + for l in sequential_block: + x = checkpoint.checkpoint(l, x, self.dummy_tensor) + return x + + + def forward(self, x): + + x = self.stem(x) + if self.outside_block_checkpointing: + x_res_0 = self.iterative_checkpoint(self.enc_block_0, x) + x = checkpoint.checkpoint(self.down_0, x_res_0, self.dummy_tensor) + x_res_1 = self.iterative_checkpoint(self.enc_block_1, x) + x = checkpoint.checkpoint(self.down_1, x_res_1, self.dummy_tensor) + x_res_2 = self.iterative_checkpoint(self.enc_block_2, x) + x = checkpoint.checkpoint(self.down_2, x_res_2, self.dummy_tensor) + x_res_3 = self.iterative_checkpoint(self.enc_block_3, x) + x = checkpoint.checkpoint(self.down_3, x_res_3, self.dummy_tensor) + + x = self.iterative_checkpoint(self.bottleneck, x) + if self.do_ds: + x_ds_4 = checkpoint.checkpoint(self.out_4, x, self.dummy_tensor) + + x_up_3 = checkpoint.checkpoint(self.up_3, x, self.dummy_tensor) + dec_x = x_res_3 + x_up_3 + x = self.iterative_checkpoint(self.dec_block_3, dec_x) + if self.do_ds: + x_ds_3 = checkpoint.checkpoint(self.out_3, x, self.dummy_tensor) + del x_res_3, x_up_3 + + x_up_2 = checkpoint.checkpoint(self.up_2, x, self.dummy_tensor) + dec_x = x_res_2 + x_up_2 + x = self.iterative_checkpoint(self.dec_block_2, dec_x) + if self.do_ds: + x_ds_2 = checkpoint.checkpoint(self.out_2, x, self.dummy_tensor) + del x_res_2, x_up_2 + + x_up_1 = checkpoint.checkpoint(self.up_1, x, self.dummy_tensor) + dec_x = x_res_1 + x_up_1 + x = self.iterative_checkpoint(self.dec_block_1, dec_x) + if self.do_ds: + x_ds_1 = checkpoint.checkpoint(self.out_1, x, self.dummy_tensor) + del x_res_1, x_up_1 + + x_up_0 = checkpoint.checkpoint(self.up_0, x, self.dummy_tensor) + dec_x = x_res_0 + x_up_0 + x = self.iterative_checkpoint(self.dec_block_0, dec_x) + del x_res_0, x_up_0, dec_x + + x = checkpoint.checkpoint(self.out_0, x, self.dummy_tensor) + + else: + x_res_0 = self.enc_block_0(x) + x = self.down_0(x_res_0) + x_res_1 = self.enc_block_1(x) + x = self.down_1(x_res_1) + x_res_2 = self.enc_block_2(x) + x = self.down_2(x_res_2) + x_res_3 = self.enc_block_3(x) + x = self.down_3(x_res_3) + + x = self.bottleneck(x) + if self.do_ds: + x_ds_4 = self.out_4(x) + + x_up_3 = self.up_3(x) + dec_x = x_res_3 + x_up_3 + x = self.dec_block_3(dec_x) + + if self.do_ds: + x_ds_3 = self.out_3(x) + del x_res_3, x_up_3 + + x_up_2 = self.up_2(x) + dec_x = x_res_2 + x_up_2 + x = self.dec_block_2(dec_x) + if self.do_ds: + x_ds_2 = self.out_2(x) + del x_res_2, x_up_2 + + x_up_1 = self.up_1(x) + dec_x = x_res_1 + x_up_1 + x = self.dec_block_1(dec_x) + if self.do_ds: + x_ds_1 = self.out_1(x) + del x_res_1, x_up_1 + + x_up_0 = self.up_0(x) + dec_x = x_res_0 + x_up_0 + x = self.dec_block_0(dec_x) + del x_res_0, x_up_0, dec_x + + x = self.out_0(x) + + if self.do_ds: + return [x, x_ds_1, x_ds_2, x_ds_3, x_ds_4] + else: + return x + + +if __name__ == "__main__": + + network = MedNeXt( + in_channels = 1, + n_channels = 32, + n_classes = 13, + exp_r=[2,3,4,4,4,4,4,3,2], # Expansion ratio as in Swin Transformers + # exp_r = 2, + kernel_size=3, # Can test kernel_size + deep_supervision=True, # Can be used to test deep supervision + do_res=True, # Can be used to individually test residual connection + do_res_up_down = True, + # block_counts = [2,2,2,2,2,2,2,2,2], + block_counts = [3,4,8,8,8,8,8,4,3], + checkpoint_style = None, + dim = '2d', + grn=True + + ).cuda() + + # network = MedNeXt_RegularUpDown( + # in_channels = 1, + # n_channels = 32, + # n_classes = 13, + # exp_r=[2,3,4,4,4,4,4,3,2], # Expansion ratio as in Swin Transformers + # kernel_size=3, # Can test kernel_size + # deep_supervision=True, # Can be used to test deep supervision + # do_res=True, # Can be used to individually test residual connection + # block_counts = [2,2,2,2,2,2,2,2,2], + # + # ).cuda() + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + print(count_parameters(network)) + + from fvcore.nn import FlopCountAnalysis + from fvcore.nn import parameter_count_table + + # model = ResTranUnet(img_size=128, in_channels=1, num_classes=14, dummy=False).cuda() + x = torch.zeros((1,1,64,64,64), requires_grad=False).cuda() + flops = FlopCountAnalysis(network, x) + print(flops.total()) + + with torch.no_grad(): + print(network) + x = torch.zeros((1, 1, 128, 128, 128)).cuda() + print(network(x)[0].shape) diff --git a/docker/template/src/nnunetv2/nets/mednextv1/__init__.py b/docker/template/src/nnunetv2/nets/mednextv1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/nets/mednextv1/blocks.py b/docker/template/src/nnunetv2/nets/mednextv1/blocks.py new file mode 100644 index 0000000..f8fd4d7 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/mednextv1/blocks.py @@ -0,0 +1,265 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MedNeXtBlock(nn.Module): + + def __init__(self, + in_channels:int, + out_channels:int, + exp_r:int=4, + kernel_size:int=7, + do_res:int=True, + norm_type:str = 'group', + n_groups:int or None = None, + dim = '3d', + grn = False + ): + + super().__init__() + + self.do_res = do_res + + assert dim in ['2d', '3d'] + self.dim = dim + if self.dim == '2d': + conv = nn.Conv2d + elif self.dim == '3d': + conv = nn.Conv3d + + # First convolution layer with DepthWise Convolutions + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 1, + padding = kernel_size//2, + groups = in_channels if n_groups is None else n_groups, + ) + + # Normalization Layer. GroupNorm is used by default. + if norm_type=='group': + self.norm = nn.GroupNorm( + num_groups=in_channels, + num_channels=in_channels + ) + elif norm_type=='layer': + self.norm = LayerNorm( + normalized_shape=in_channels, + data_format='channels_first' + ) + + # Second convolution (Expansion) layer with Conv3D 1x1x1 + self.conv2 = conv( + in_channels = in_channels, + out_channels = exp_r*in_channels, + kernel_size = 1, + stride = 1, + padding = 0 + ) + + # GeLU activations + self.act = nn.GELU() + + # Third convolution (Compression) layer with Conv3D 1x1x1 + self.conv3 = conv( + in_channels = exp_r*in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 1, + padding = 0 + ) + + self.grn = grn + if grn: + if dim == '3d': + self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True) + elif dim == '2d': + self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True) + + + def forward(self, x, dummy_tensor=None): + + x1 = x + x1 = self.conv1(x1) + x1 = self.act(self.conv2(self.norm(x1))) + if self.grn: + # gamma, beta: learnable affine transform parameters + # X: input of shape (N,C,H,W,D) + if self.dim == '3d': + gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) + elif self.dim == '2d': + gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True)+1e-6) + x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1 + x1 = self.conv3(x1) + if self.do_res: + x1 = x + x1 + return x1 + + +class MedNeXtDownBlock(MedNeXtBlock): + + def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, + do_res=False, norm_type = 'group', dim='3d', grn=False): + + super().__init__(in_channels, out_channels, exp_r, kernel_size, + do_res = False, norm_type = norm_type, dim=dim, + grn=grn) + + if dim == '2d': + conv = nn.Conv2d + elif dim == '3d': + conv = nn.Conv3d + self.resample_do_res = do_res + if do_res: + self.res_conv = conv( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 2 + ) + + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 2, + padding = kernel_size//2, + groups = in_channels, + ) + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + + if self.resample_do_res: + res = self.res_conv(x) + x1 = x1 + res + + return x1 + + +class MedNeXtUpBlock(MedNeXtBlock): + + def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, + do_res=False, norm_type = 'group', dim='3d', grn = False): + super().__init__(in_channels, out_channels, exp_r, kernel_size, + do_res=False, norm_type = norm_type, dim=dim, + grn=grn) + + self.resample_do_res = do_res + + self.dim = dim + if dim == '2d': + conv = nn.ConvTranspose2d + elif dim == '3d': + conv = nn.ConvTranspose3d + if do_res: + self.res_conv = conv( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 2 + ) + + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 2, + padding = kernel_size//2, + groups = in_channels, + ) + + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + # Asymmetry but necessary to match shape + + if self.dim == '2d': + x1 = torch.nn.functional.pad(x1, (1,0,1,0)) + elif self.dim == '3d': + x1 = torch.nn.functional.pad(x1, (1,0,1,0,1,0)) + + if self.resample_do_res: + res = self.res_conv(x) + if self.dim == '2d': + res = torch.nn.functional.pad(res, (1,0,1,0)) + elif self.dim == '3d': + res = torch.nn.functional.pad(res, (1,0,1,0,1,0)) + x1 = x1 + res + + return x1 + + +class OutBlock(nn.Module): + + def __init__(self, in_channels, n_classes, dim): + super().__init__() + + if dim == '2d': + conv = nn.ConvTranspose2d + elif dim == '3d': + conv = nn.ConvTranspose3d + self.conv_out = conv(in_channels, n_classes, kernel_size=1) + + def forward(self, x, dummy_tensor=None): + return self.conv_out(x) + + +class LayerNorm(nn.Module): + """ LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta + self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x, dummy_tensor=False): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] + return x + + +if __name__ == "__main__": + + + # network = nnUNeXtBlock(in_channels=12, out_channels=12, do_res=False).cuda() + + # with torch.no_grad(): + # print(network) + # x = torch.zeros((2, 12, 8, 8, 8)).cuda() + # print(network(x).shape) + + # network = DownsampleBlock(in_channels=12, out_channels=24, do_res=False) + + # with torch.no_grad(): + # print(network) + # x = torch.zeros((2, 12, 128, 128, 128)) + # print(network(x).shape) + + network = MedNeXtBlock(in_channels=12, out_channels=12, do_res=True, grn=True, norm_type='group').cuda() + # network = LayerNorm(normalized_shape=12, data_format='channels_last').cuda() + # network.eval() + with torch.no_grad(): + print(network) + x = torch.zeros((2, 12, 64, 64, 64)).cuda() + print(network(x).shape) diff --git a/docker/template/src/nnunetv2/nets/mednextv1/create_mednext_v1.py b/docker/template/src/nnunetv2/nets/mednextv1/create_mednext_v1.py new file mode 100644 index 0000000..84d619c --- /dev/null +++ b/docker/template/src/nnunetv2/nets/mednextv1/create_mednext_v1.py @@ -0,0 +1,83 @@ +from nnunetv2.nets.mednextv1.MedNextV1 import MedNeXt + +def create_mednextv1_small(num_input_channels, num_classes, kernel_size=3, ds=False): + + return MedNeXt( + in_channels = num_input_channels, + n_channels = 32, + n_classes = num_classes, + exp_r=2, + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [2,2,2,2,2,2,2,2,2] + ) + + +def create_mednextv1_base(num_input_channels, num_classes, kernel_size=3, ds=False): + + return MedNeXt( + in_channels = num_input_channels, + n_channels = 32, + n_classes = num_classes, + exp_r=[2,3,4,4,4,4,4,3,2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [2,2,2,2,2,2,2,2,2] + ) + + +def create_mednextv1_medium(num_input_channels, num_classes, kernel_size=3, ds=False): + + return MedNeXt( + in_channels = num_input_channels, + n_channels = 32, + n_classes = num_classes, + exp_r=[2,3,4,4,4,4,4,3,2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [3,4,4,4,4,4,4,4,3], + checkpoint_style = 'outside_block' + ) + + +def create_mednextv1_large(num_input_channels, num_classes, kernel_size=3, ds=False): + + return MedNeXt( + in_channels = num_input_channels, + n_channels = 32, + n_classes = num_classes, + exp_r=[3,4,8,8,8,8,8,4,3], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [3,4,8,8,8,8,8,4,3], + checkpoint_style = 'outside_block' + ) + + +def create_mednext_v1(num_input_channels, num_classes, model_id, kernel_size=3, + deep_supervision=False): + + model_dict = { + 'S': create_mednextv1_small, + 'B': create_mednextv1_base, + 'M': create_mednextv1_medium, + 'L': create_mednextv1_large, + } + + return model_dict[model_id]( + num_input_channels, num_classes, kernel_size, deep_supervision + ) + + +if __name__ == "__main__": + + model = create_mednextv1_large(1, 3, 3, False) + print(model) \ No newline at end of file diff --git a/docker/template/src/nnunetv2/nets/sam_lora_image_encoder.py b/docker/template/src/nnunetv2/nets/sam_lora_image_encoder.py new file mode 100644 index 0000000..6f1331a --- /dev/null +++ b/docker/template/src/nnunetv2/nets/sam_lora_image_encoder.py @@ -0,0 +1,206 @@ +from typing import Mapping, Any + +from nnunetv2.nets.segment_anything import sam_model_registry + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter +# from segment_anything.modeling import Sam +from safetensors import safe_open +from safetensors.torch import save_file +# from icecream import ic +from nnunetv2.nets.segment_anything.modeling import Sam +from torch._dynamo import OptimizedModule + +# from segment_anything import build_sam, SamPredictor + +class _LoRA_qkv(nn.Module): + """In Sam it is implemented as + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + """ + + def __init__( + self, + qkv: nn.Module, + linear_a_q: nn.Module, + linear_b_q: nn.Module, + linear_a_v: nn.Module, + linear_b_v: nn.Module, + ): + super().__init__() + self.qkv = qkv + self.linear_a_q = linear_a_q + self.linear_b_q = linear_b_q + self.linear_a_v = linear_a_v + self.linear_b_v = linear_b_v + self.dim = qkv.in_features + self.w_identity = torch.eye(qkv.in_features) + + def forward(self, x): + qkv = self.qkv(x) # B,N,N,3*org_C + new_q = self.linear_b_q(self.linear_a_q(x)) + new_v = self.linear_b_v(self.linear_a_v(x)) + qkv[:, :, :, : self.dim] += new_q + qkv[:, :, :, -self.dim:] += new_v + return qkv + + +class LoRA_Sam(nn.Module): + """Applies low-rank adaptation to a Sam model's image encoder. + + Args: + sam_model: a vision transformer model, see base_vit.py + r: rank of LoRA + num_classes: how many classes the model output, default to the vit model + lora_layer: which layer we apply LoRA. + + Examples:: + >>> model = ViT('B_16_imagenet1k') + >>> lora_model = LoRA_ViT(model, r=4) + >>> preds = lora_model(img) + >>> print(preds.shape) + torch.Size([1, 1000]) + """ + + def __init__(self, sam_model: Sam, r: int, lora_layer=None): + super(LoRA_Sam, self).__init__() + + assert r > 0 + # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels + # dim = base_vit_dim + if lora_layer: + self.lora_layer = lora_layer + else: + self.lora_layer = list( + range(len(sam_model.image_encoder.blocks))) # Only apply lora to the image encoder by default + # create for storage, then we can init them or load weights + self.w_As = [] # These are linear layers + self.w_Bs = [] + + # lets freeze first + for param in sam_model.image_encoder.parameters(): + param.requires_grad = False + + # Here, we do the surgery + for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): + # If we only want few lora layer instead of all + if t_layer_i not in self.lora_layer: + continue + w_qkv_linear = blk.attn.qkv + self.dim = w_qkv_linear.in_features + w_a_linear_q = nn.Linear(self.dim, r, bias=False) + w_b_linear_q = nn.Linear(r, self.dim, bias=False) + w_a_linear_v = nn.Linear(self.dim, r, bias=False) + w_b_linear_v = nn.Linear(r, self.dim, bias=False) + self.w_As.append(w_a_linear_q) + self.w_Bs.append(w_b_linear_q) + self.w_As.append(w_a_linear_v) + self.w_Bs.append(w_b_linear_v) + blk.attn.qkv = _LoRA_qkv( + w_qkv_linear, + w_a_linear_q, + w_b_linear_q, + w_a_linear_v, + w_b_linear_v, + ) + self.reset_parameters() + self.sam = sam_model + + def get_lora_parameters(self) -> None: + r"""Only safetensors is supported now. + + pip install safetensor if you do not have one installed yet. + + save both lora and fc parameters. + """ + + # assert filename.endswith(".pt") or filename.endswith('.pth') + + num_layer = len(self.w_As) # actually, it is half + a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} + b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} + prompt_encoder_tensors = {} + mask_decoder_tensors = {} + + # save prompt encoder, only `state_dict`, the `named_parameter` is not permitted + if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam, torch.nn.parallel.DistributedDataParallel): + state_dict = self.sam.module.state_dict() + else: + state_dict = self.sam.state_dict() + for key, value in state_dict.items(): + if 'prompt_encoder' in key: + prompt_encoder_tensors[key] = value + if 'mask_decoder' in key: + mask_decoder_tensors[key] = value + + merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors} + # torch.save(merged_dict, filename) + return merged_dict + def load_state_dict(self, state_dict: Mapping[str, Any], + strict: bool = True, assign: bool = False): + self.load_lora_parameters(state_dict) + + def load_lora_parameters(self, state_dict) -> None: + r"""Only safetensors is supported now. + + pip install safetensor if you do not have one installed yet.\ + + load both lora and fc parameters. + """ + + # assert filename.endswith(".pt") or filename.endswith('.pth') + # if torch.cuda.is_available(): + # state_dict = torch.load(filename, map_location='cuda') + # else: + # state_dict = torch.load(filename, map_location='cpu') + # + for i, w_A_linear in enumerate(self.w_As): + saved_key = f"w_a_{i:03d}" + saved_tensor = state_dict[saved_key] + w_A_linear.weight = Parameter(saved_tensor) + + for i, w_B_linear in enumerate(self.w_Bs): + saved_key = f"w_b_{i:03d}" + saved_tensor = state_dict[saved_key] + w_B_linear.weight = Parameter(saved_tensor) + + sam_dict = self.sam.state_dict() + sam_keys = sam_dict.keys() + + # load prompt encoder + prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k] + prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys] + prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)} + sam_dict.update(prompt_encoder_new_state_dict) + + # load mask decoder + mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k] + mask_decoder_values = [state_dict[k] for k in mask_decoder_keys] + mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)} + sam_dict.update(mask_decoder_new_state_dict) + self.sam.load_state_dict(sam_dict) + + def reset_parameters(self) -> None: + for w_A in self.w_As: + nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) + for w_B in self.w_Bs: + nn.init.zeros_(w_B.weight) + + def forward(self, batched_input, multimask_output, image_size): + return self.sam(batched_input, multimask_output, image_size) + + + # def forward(self, x: Tensor) -> Tensor: + # return self.lora_vit(x) + + +if __name__ == "__main__": + sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") + lora_sam = LoRA_Sam(sam, 4) + lora_sam.sam.image_encoder(torch.rand(size=(1, 3, 1024, 1024))) diff --git a/docker/template/src/nnunetv2/nets/segment_anything/__init__.py b/docker/template/src/nnunetv2/nets/segment_anything/__init__.py new file mode 100644 index 0000000..34383d8 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/docker/template/src/nnunetv2/nets/segment_anything/automatic_mask_generator.py b/docker/template/src/nnunetv2/nets/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000..2326497 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crops_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crops_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros(len(data["boxes"])), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros(len(data["boxes"])), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros(len(boxes)), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/docker/template/src/nnunetv2/nets/segment_anything/build_sam.py b/docker/template/src/nnunetv2/nets/segment_anything/build_sam.py new file mode 100644 index 0000000..9ba5bb3 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/build_sam.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.nn import functional as F +# from icecream import ic + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], + checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + num_classes=num_classes, + image_size=image_size, + pixel_mean=pixel_mean, + pixel_std=pixel_std + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], + checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + num_classes=num_classes, + image_size=image_size, + pixel_mean=pixel_mean, + pixel_std=pixel_std + ) + + +def build_sam_vit_b(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], + checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + # adopt global attention at [3, 6, 9, 12] transform layer, else window attention layer + checkpoint=checkpoint, + num_classes=num_classes, + image_size=image_size, + pixel_mean=pixel_mean, + pixel_std=pixel_std + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + num_classes, + image_size, + pixel_mean, + pixel_std, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = image_size + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size # Divide by 16 here + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + # num_multimask_outputs=3, + num_multimask_outputs=num_classes, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + # pixel_mean=[123.675, 116.28, 103.53], + # pixel_std=[58.395, 57.12, 57.375], + pixel_mean=pixel_mean, + pixel_std=pixel_std + ) + # sam.eval() + sam.train() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + sam.load_state_dict(state_dict) + except: + new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size, encoder_global_attn_indexes) + sam.load_state_dict(new_state_dict) + return sam, image_embedding_size + + +def load_from(sam, state_dict, image_size, vit_patch_size, encoder_global_attn_indexes): + ega = encoder_global_attn_indexes + sam_dict = sam.state_dict() + except_keys = ['mask_tokens', 'output_hypernetworks_mlps', 'iou_prediction_head'] + new_state_dict = {k: v for k, v in state_dict.items() if + k in sam_dict.keys() and except_keys[0] not in k and except_keys[1] not in k and except_keys[2] not in k} + pos_embed = new_state_dict['image_encoder.pos_embed'] + token_size = int(image_size // vit_patch_size) + if pos_embed.shape[1] != token_size: + # resize pos embedding, which may sacrifice the performance, but I have no better idea + pos_embed = pos_embed.permute(0, 3, 1, 2) # [b, c, h, w] + pos_embed = F.interpolate(pos_embed, (token_size, token_size), mode='bilinear', align_corners=False) + pos_embed = pos_embed.permute(0, 2, 3, 1) # [b, h, w, c] + new_state_dict['image_encoder.pos_embed'] = pos_embed + rel_pos_keys = [k for k in sam_dict.keys() if 'rel_pos' in k] + global_rel_pos_keys = [] + for rel_pos_key in rel_pos_keys: + num = int(rel_pos_key.split('.')[2]) + if num in encoder_global_attn_indexes: + global_rel_pos_keys.append(rel_pos_key) + # global_rel_pos_keys = [k for k in rel_pos_keys if '2' in k or '5' in k or '8' in k or '11' in k] + for k in global_rel_pos_keys: + rel_pos_params = new_state_dict[k] + h, w = rel_pos_params.shape + rel_pos_params = rel_pos_params.unsqueeze(0).unsqueeze(0) + rel_pos_params = F.interpolate(rel_pos_params, (token_size * 2 - 1, w), mode='bilinear', align_corners=False) + new_state_dict[k] = rel_pos_params[0, 0, ...] + sam_dict.update(new_state_dict) + return sam_dict diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/__init__.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/common.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/image_encoder.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000..9e382c1 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/image_encoder.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +# from icecream import ic + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) # pre embed: [1, 3, 1024, 1024], post embed: [1, 64, 64, 768] + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) # [b, c, h, w], [1, 256, 64, 64] + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) # [B * num_windows, window_size, window_size, C] + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/mask_decoder.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000..2f4f184 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F +# from icecream import ic + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + # if multimask_output: + # mask_slice = slice(1, None) + # else: + # mask_slice = slice(0, 1) + # masks = masks[:, mask_slice, :, :] + # iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) # [b, c, token_num] + + b, c, h, w = upscaled_embedding.shape # [h, token_num, h, w] + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 256, 256], 256 = 4 * 64, the size of image embeddings + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/prompt_encoder.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000..5989635 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) # downsample to 1/4 + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/sam.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/sam.py new file mode 100644 index 0000000..50f5088 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/sam.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F +# from icecream import ic + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward(self, batched_input, multimask_output, image_size): + if isinstance(batched_input, list): + outputs = self.forward_test(batched_input, multimask_output) + else: + outputs = self.forward_train(batched_input, multimask_output, image_size) + return outputs + + def forward_train(self, batched_input, multimask_output, image_size): + input_images = self.preprocess(batched_input) + image_embeddings = self.image_encoder(input_images) + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=None, boxes=None, masks=None + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=image_embeddings, + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=(image_size, image_size), + original_size=(image_size, image_size) + ) + outputs = { + 'masks': masks, + 'iou_predictions': iou_predictions, + 'low_res_logits': low_res_masks + } + return outputs + + @torch.no_grad() + def forward_test( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + diff --git a/docker/template/src/nnunetv2/nets/segment_anything/modeling/transformer.py b/docker/template/src/nnunetv2/nets/segment_anything/modeling/transformer.py new file mode 100644 index 0000000..f1a2812 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/docker/template/src/nnunetv2/nets/segment_anything/predictor.py b/docker/template/src/nnunetv2/nets/segment_anything/predictor.py new file mode 100644 index 0000000..5af7540 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from nnunetv2.nets.segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/docker/template/src/nnunetv2/nets/segment_anything/utils/__init__.py b/docker/template/src/nnunetv2/nets/segment_anything/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/docker/template/src/nnunetv2/nets/segment_anything/utils/amg.py b/docker/template/src/nnunetv2/nets/segment_anything/utils/amg.py new file mode 100644 index 0000000..3a13777 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecesary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/docker/template/src/nnunetv2/nets/segment_anything/utils/onnx.py b/docker/template/src/nnunetv2/nets/segment_anything/utils/onnx.py new file mode 100644 index 0000000..4297b31 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) + masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/docker/template/src/nnunetv2/nets/segment_anything/utils/transforms.py b/docker/template/src/nnunetv2/nets/segment_anything/utils/transforms.py new file mode 100644 index 0000000..3ad3466 --- /dev/null +++ b/docker/template/src/nnunetv2/nets/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/docker/template/src/nnunetv2/paths.py b/docker/template/src/nnunetv2/paths.py new file mode 100644 index 0000000..f2b65bc --- /dev/null +++ b/docker/template/src/nnunetv2/paths.py @@ -0,0 +1,63 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +join = os.path.join +""" +Please make sure your data is organized as follows: + +data/ +├── nnUNet_raw/ +│ ├── Dataset701_AbdomenCT/ +│ │ ├── imagesTr +│ │ │ ├── FLARE22_Tr_0001_0000.nii.gz +│ │ │ ├── FLARE22_Tr_0002_0000.nii.gz +│ │ │ ├── ... +│ │ ├── labelsTr +│ │ │ ├── FLARE22_Tr_0001.nii.gz +│ │ │ ├── FLARE22_Tr_0002.nii.gz +│ │ │ ├── ... +│ │ ├── dataset.json +│ ├── Dataset702_AbdomenMR/ +│ │ ├── imagesTr +│ │ │ ├── amos_0507_0000.nii.gz +│ │ │ ├── amos_0508_0000.nii.gz +│ │ │ ├── ... +│ │ ├── labelsTr +│ │ │ ├── amos_0507.nii.gz +│ │ │ ├── amos_0508.nii.gz +│ │ │ ├── ... +│ │ ├── dataset.json +│ ├── ... +""" +base = join(os.sep.join(__file__.split(os.sep)[:-3]), 'data') +nnUNet_raw = os.environ.get('nnUNet_raw') +nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed') +nnUNet_results = os.environ.get('nnUNet_results') + +if nnUNet_raw is None: + print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files " + "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like " + "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set " + "this up properly.") + +if nnUNet_preprocessed is None: + print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing " + "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how " + "to set this up.") + +if nnUNet_results is None: + print("nnUNet_results is not defined and nnU-Net cannot be used for training or " + "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information " + "on how to set this up.") diff --git a/docker/template/src/nnunetv2/postprocessing/__init__.py b/docker/template/src/nnunetv2/postprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/postprocessing/remove_connected_components.py b/docker/template/src/nnunetv2/postprocessing/remove_connected_components.py new file mode 100644 index 0000000..c8021ac --- /dev/null +++ b/docker/template/src/nnunetv2/postprocessing/remove_connected_components.py @@ -0,0 +1,361 @@ +import argparse +import multiprocessing +import shutil +from typing import Union, Tuple, List, Callable + +import numpy as np +from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component +from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \ + isdir, save_pickle, load_pickle, save_json +from nnunetv2.configuration import default_num_processes +from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results +from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \ + load_summary_json, label_or_region_to_key +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray, + labels_or_regions: Union[int, Tuple[int, ...], + List[Union[int, Tuple[int, ...]]]], + background_label: int = 0) -> np.ndarray: + mask = np.zeros_like(segmentation, dtype=bool) + if not isinstance(labels_or_regions, list): + labels_or_regions = [labels_or_regions] + for l_or_r in labels_or_regions: + mask |= region_or_label_to_mask(segmentation, l_or_r) + mask_keep = remove_all_but_largest_component(mask) + ret = np.copy(segmentation) # do not modify the input! + ret[mask & ~mask_keep] = background_label + return ret + + +def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]): + for fn, kwargs in zip(pp_fns, pp_fn_kwargs): + segmentation = fn(segmentation, **kwargs) + return segmentation + + +def load_postprocess_save(segmentation_file: str, + output_fname: str, + image_reader_writer: BaseReaderWriter, + pp_fns: List[Callable], + pp_fn_kwargs: List[dict]): + seg, props = image_reader_writer.read_seg(segmentation_file) + seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs) + image_reader_writer.write_seg(seg, output_fname, props) + + +def determine_postprocessing(folder_predictions: str, + folder_ref: str, + plans_file_or_dict: Union[str, dict], + dataset_json_file_or_dict: Union[str, dict], + num_processes: int = default_num_processes, + keep_postprocessed_files: bool = True): + """ + Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be + used with apply_postprocessing_to_folder. + + Postprocessed files are saved in folder_predictions/postprocessed. Set + keep_postprocessed_files=False to delete these files after this function is done (temp files will eb created + and deleted regardless). + + If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder + """ + output_folder = join(folder_predictions, 'postprocessed') + + if plans_file_or_dict is None: + expected_plans_file = join(folder_predictions, 'plans.json') + if not isfile(expected_plans_file): + raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans files should have been " + f"created while running nnUNetv2_predict. Sadge.") + plans_file_or_dict = load_json(expected_plans_file) + plans_manager = PlansManager(plans_file_or_dict) + + if dataset_json_file_or_dict is None: + expected_dataset_json_file = join(folder_predictions, 'dataset.json') + if not isfile(expected_dataset_json_file): + raise RuntimeError( + f"Expected plans file missing: {expected_dataset_json_file}. The plans files should have been " + f"created while running nnUNetv2_predict. Sadge.") + dataset_json_file_or_dict = load_json(expected_dataset_json_file) + + if not isinstance(dataset_json_file_or_dict, dict): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + + rw = plans_manager.image_reader_writer_class() + label_manager = plans_manager.get_label_manager(dataset_json) + labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels + + predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False) + ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False) + # we should print a warning if not all files from folder_ref are present in folder_predictions + if not all([i in predicted_files for i in ref_files]): + print(f'WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing ' + f'should always be done on the entire dataset!') + + # before we start we should evaluate the imaegs in the source folder + if not isfile(join(folder_predictions, 'summary.json')): + compute_metrics_on_folder(folder_ref, + folder_predictions, + join(folder_predictions, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + + # we save the postprocessing functions in here + pp_fns = [] + pp_fn_kwargs = [] + + # pool party! + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + # now let's see whether removing all but the largest foreground region improves the scores + output_here = join(output_folder, 'temp', 'keep_largest_fg') + maybe_mkdir_p(output_here) + pp_fn = remove_all_but_largest_component_from_segmentation + kwargs = { + 'labels_or_regions': label_manager.foreground_labels, + } + + pool.starmap( + load_postprocess_save, + zip( + [join(folder_predictions, i) for i in predicted_files], + [join(output_here, i) for i in predicted_files], + [rw] * len(predicted_files), + [[pp_fn]] * len(predicted_files), + [[kwargs]] * len(predicted_files) + ) + ) + compute_metrics_on_folder(folder_ref, + output_here, + join(output_here, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far + # that if a single class got worse as a result we won't do this. We can change this in the future but right now I + # prefer to do it this way + baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) + pp_results = load_summary_json(join(output_here, 'summary.json')) + do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice'] + if do_this: + for class_id in pp_results['mean'].keys(): + if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']: + do_this = False + break + if do_this: + print(f'Results were improved by removing all but the largest foreground region. ' + f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} ' + f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}') + source = output_here + pp_fns.append(pp_fn) + pp_fn_kwargs.append(kwargs) + else: + print(f'Removing all but the largest foreground region did not improve results!') + source = folder_predictions + + # in the old nnU-Net we could just apply all-but-largest component removal to all classes at the same time and + # then evaluate for each class whether this improved results. This is no longer possible because we now support + # region-based predictions and regions can overlap, causing interactions + # in principle the order with which the postprocessing is applied to the regions matter as well and should be + # investigated, but due to some things that I am too lazy to explain right now it's going to be alright (I think) + # to stick to the order in which they are declared in dataset.json (if you want to think about it then think about + # region_class_order) + # 2023_02_06: I hate myself for the comment above. Thanks past me + if len(labels_or_regions) > 1: + for label_or_region in labels_or_regions: + pp_fn = remove_all_but_largest_component_from_segmentation + kwargs = { + 'labels_or_regions': label_or_region, + } + + output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion') + maybe_mkdir_p(output_here) + + pool.starmap( + load_postprocess_save, + zip( + [join(source, i) for i in predicted_files], + [join(output_here, i) for i in predicted_files], + [rw] * len(predicted_files), + [[pp_fn]] * len(predicted_files), + [[kwargs]] * len(predicted_files) + ) + ) + compute_metrics_on_folder(folder_ref, + output_here, + join(output_here, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + baseline_results = load_summary_json(join(source, 'summary.json')) + pp_results = load_summary_json(join(output_here, 'summary.json')) + do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice'] + if do_this: + print(f'Results were improved by removing all but the largest component for {label_or_region}. ' + f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' + f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') + if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')): + shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')) + shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), ) + source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest') + pp_fns.append(pp_fn) + pp_fn_kwargs.append(kwargs) + else: + print(f'Removing all but the largest component for {label_or_region} did not improve results! ' + f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' + f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') + [shutil.copy(join(source, i), join(output_folder, i)) for i in subfiles(source, join=False)] + save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl')) + + baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) + final_results = load_summary_json(join(output_folder, 'summary.json')) + tmp = { + 'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']}, + 'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']}, + 'postprocessing_fns': [i.__name__ for i in pp_fns], + 'postprocessing_kwargs': pp_fn_kwargs, + } + # json is a very annoying little bi###. Can't handle tuples as dict keys. + tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in + tmp['input_folder']['mean'].keys()} + tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in + tmp['postprocessed']['mean'].keys()} + # did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable" You retarded bro? + recursive_fix_for_json_export(tmp) + save_json(tmp, join(folder_predictions, 'postprocessing.json')) + + shutil.rmtree(join(output_folder, 'temp')) + + if not keep_postprocessed_files: + shutil.rmtree(output_folder) + return pp_fns, pp_fn_kwargs + + +def apply_postprocessing_to_folder(input_folder: str, + output_folder: str, + pp_fns: List[Callable], + pp_fn_kwargs: List[dict], + plans_file_or_dict: Union[str, dict] = None, + dataset_json_file_or_dict: Union[str, dict] = None, + num_processes=8) -> None: + """ + If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder + """ + if plans_file_or_dict is None: + expected_plans_file = join(input_folder, 'plans.json') + if not isfile(expected_plans_file): + raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been " + f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply " + f"postprocessing to was create from an ensemble then just specify one of the " + f"plans files of the ensemble members in plans_file_or_dict") + plans_file_or_dict = load_json(expected_plans_file) + plans_manager = PlansManager(plans_file_or_dict) + + if dataset_json_file_or_dict is None: + expected_dataset_json_file = join(input_folder, 'dataset.json') + if not isfile(expected_dataset_json_file): + raise RuntimeError( + f"Expected plans file missing: {expected_dataset_json_file}. The dataset.json should have been " + f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.") + dataset_json_file_or_dict = load_json(expected_dataset_json_file) + + if not isinstance(dataset_json_file_or_dict, dict): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + + rw = plans_manager.image_reader_writer_class() + + maybe_mkdir_p(output_folder) + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False) + + _ = p.starmap(load_postprocess_save, + zip( + [join(input_folder, i) for i in files], + [join(output_folder, i) for i in files], + [rw] * len(files), + [pp_fns] * len(files), + [pp_fn_kwargs] * len(files) + ) + ) + + +def entry_point_determine_postprocessing_folder(): + parser = argparse.ArgumentParser('Writes postprocessing.pkl and postprocessing.json in input_folder.') + parser.add_argument('-i', type=str, required=True, help='Input folder') + parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels') + parser.add_argument('-plans_json', type=str, required=False, default=None, + help="plans file to use. If not specified we will look for the plans.json file in the " + "input folder (input_folder/plans.json)") + parser.add_argument('-dataset_json', type=str, required=False, default=None, + help="dataset.json file to use. If not specified we will look for the dataset.json file in the " + "input folder (input_folder/dataset.json)") + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"number of processes to use. Default: {default_num_processes}") + parser.add_argument('--remove_postprocessed', action='store_true', required=False, + help='set this is you don\'t want to keep the postprocessed files') + + args = parser.parse_args() + determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np, + not args.remove_postprocessed) + + +def entry_point_apply_postprocessing(): + parser = argparse.ArgumentParser('Apples postprocessing specified in pp_pkl_file to input folder.') + parser.add_argument('-i', type=str, required=True, help='Input folder') + parser.add_argument('-o', type=str, required=True, help='Output folder') + parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"number of processes to use. Default: {default_num_processes}") + parser.add_argument('-plans_json', type=str, required=False, default=None, + help="plans file to use. If not specified we will look for the plans.json file in the " + "input folder (input_folder/plans.json)") + parser.add_argument('-dataset_json', type=str, required=False, default=None, + help="dataset.json file to use. If not specified we will look for the dataset.json file in the " + "input folder (input_folder/dataset.json)") + args = parser.parse_args() + pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file) + apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np) + + +if __name__ == '__main__': + trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres' + labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr') + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + dataset_json = load_json(join(trained_model_folder, 'dataset.json')) + folds = (0, 1, 2, 3, 4) + label_manager = plans_manager.get_label_manager(dataset_json) + + merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False) + + fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans, + dataset_json, 8, keep_postprocessed_files=True) + save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl')) + fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl')) + + apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs, + plans_manager.plans, dataset_json, + 8) + compute_metrics_on_folder(labelstr, + merged_output_folder + '_pp', + join(merged_output_folder + '_pp', 'summary.json'), + plans_manager.image_reader_writer_class(), + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels, + label_manager.ignore_label, + 8) diff --git a/docker/template/src/nnunetv2/preprocessing/__init__.py b/docker/template/src/nnunetv2/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/preprocessing/cropping/__init__.py b/docker/template/src/nnunetv2/preprocessing/cropping/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/preprocessing/cropping/cropping.py b/docker/template/src/nnunetv2/preprocessing/cropping/cropping.py new file mode 100644 index 0000000..96fe7b7 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/cropping/cropping.py @@ -0,0 +1,51 @@ +import numpy as np + + +# Hello! crop_to_nonzero is the function you are looking for. Ignore the rest. +from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice + + +def create_nonzero_mask(data): + """ + + :param data: + :return: the mask is True where the data is nonzero + """ + from scipy.ndimage import binary_fill_holes + assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)" + nonzero_mask = np.zeros(data.shape[1:], dtype=bool) + for c in range(data.shape[0]): + this_mask = data[c] != 0 + nonzero_mask = nonzero_mask | this_mask + nonzero_mask = binary_fill_holes(nonzero_mask) + return nonzero_mask + + +def crop_to_nonzero(data, seg=None, nonzero_label=-1): + """ + + :param data: + :param seg: + :param nonzero_label: this will be written into the segmentation map + :return: + """ + nonzero_mask = create_nonzero_mask(data) + bbox = get_bbox_from_mask(nonzero_mask) + + slicer = bounding_box_to_slice(bbox) + data = data[tuple([slice(None), *slicer])] + + if seg is not None: + seg = seg[tuple([slice(None), *slicer])] + + nonzero_mask = nonzero_mask[slicer][None] + if seg is not None: + seg[(seg == 0) & (~nonzero_mask)] = nonzero_label + else: + nonzero_mask = nonzero_mask.astype(np.int8) + nonzero_mask[nonzero_mask == 0] = nonzero_label + nonzero_mask[nonzero_mask > 0] = 0 + seg = nonzero_mask + return data, seg, bbox + + diff --git a/docker/template/src/nnunetv2/preprocessing/normalization/__init__.py b/docker/template/src/nnunetv2/preprocessing/normalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/preprocessing/normalization/default_normalization_schemes.py b/docker/template/src/nnunetv2/preprocessing/normalization/default_normalization_schemes.py new file mode 100644 index 0000000..3c90a91 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/normalization/default_normalization_schemes.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import Type + +import numpy as np +from numpy import number + + +class ImageNormalization(ABC): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None + + def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None, + target_dtype: Type[number] = np.float32): + assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool) + self.use_mask_for_norm = use_mask_for_norm + assert isinstance(intensityproperties, dict) + self.intensityproperties = intensityproperties + self.target_dtype = target_dtype + + @abstractmethod + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + """ + Image and seg must have the same shape. Seg is not always used + """ + pass + + +class ZScoreNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + """ + here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by + default. + """ + image = image.astype(self.target_dtype) + if self.use_mask_for_norm is not None and self.use_mask_for_norm: + # negative values in the segmentation encode the 'outside' region (think zero values around the brain as + # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. + # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially + # reduced the image size. + mask = seg >= 0 + mean = image[mask].mean() + std = image[mask].std() + image[mask] = (image[mask] - mean) / (max(std, 1e-8)) + else: + mean = image.mean() + std = image.std() + image = (image - mean) / (max(std, 1e-8)) + return image + + +class CTNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + assert self.intensityproperties is not None, "CTNormalization requires intensity properties" + image = image.astype(self.target_dtype) + mean_intensity = self.intensityproperties['mean'] + std_intensity = self.intensityproperties['std'] + lower_bound = self.intensityproperties['percentile_00_5'] + upper_bound = self.intensityproperties['percentile_99_5'] + image = np.clip(image, lower_bound, upper_bound) + image = (image - mean_intensity) / max(std_intensity, 1e-8) + return image + + +class NoNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + return image.astype(self.target_dtype) + + +class RescaleTo01Normalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + image = image.astype(self.target_dtype) + image = image - image.min() + image = image / np.clip(image.max(), a_min=1e-8, a_max=None) + return image + + +class RGBTo01Normalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + assert image.min() >= 0, "RGB images are uint 8, for whatever reason I found pixel values smaller than 0. " \ + "Your images do not seem to be RGB images" + assert image.max() <= 255, "RGB images are uint 8, for whatever reason I found pixel values greater than 255" \ + ". Your images do not seem to be RGB images" + image = image.astype(self.target_dtype) + image = image / 255. + return image + diff --git a/docker/template/src/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py b/docker/template/src/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py new file mode 100644 index 0000000..18f027b --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py @@ -0,0 +1,24 @@ +from typing import Type + +from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \ + ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization + +channel_name_to_normalization_mapping = { + 'CT': CTNormalization, + 'noNorm': NoNormalization, + 'zscore': ZScoreNormalization, + 'rescale_to_0_1': RescaleTo01Normalization, + 'rgb_to_0_1': RGBTo01Normalization +} + + +def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]: + """ + If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is + not found, use the default (ZScoreNormalization) + """ + norm_scheme = channel_name_to_normalization_mapping.get(channel_name) + if norm_scheme is None: + norm_scheme = ZScoreNormalization + # print('Using %s for image normalization' % norm_scheme.__name__) + return norm_scheme diff --git a/docker/template/src/nnunetv2/preprocessing/normalization/readme.md b/docker/template/src/nnunetv2/preprocessing/normalization/readme.md new file mode 100644 index 0000000..7b54396 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/normalization/readme.md @@ -0,0 +1,5 @@ +The channel_names entry in dataset.json only determines the normlaization scheme. So if you want to use something different +then you can just +- create a new subclass of ImageNormalization +- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping +- run plan and preprocess again with your custom normlaization scheme \ No newline at end of file diff --git a/docker/template/src/nnunetv2/preprocessing/preprocessors/__init__.py b/docker/template/src/nnunetv2/preprocessing/preprocessors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/docker/template/src/nnunetv2/preprocessing/preprocessors/default_preprocessor.py new file mode 100644 index 0000000..ae71059 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/preprocessors/default_preprocessor.py @@ -0,0 +1,295 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import shutil +from time import sleep +from typing import Union, Tuple + +import nnunetv2 +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero +from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets +from tqdm import tqdm + + +class DefaultPreprocessor(object): + def __init__(self, verbose: bool = True): + self.verbose = verbose + """ + Everything we need is in the plans. Those are given when run() is called + """ + + def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + # let's not mess up the inputs! + data = np.copy(data) + if seg is not None: + assert data.shape[1:] == seg.shape[1:], "Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct" + seg = np.copy(seg) + + has_seg = seg is not None + + # apply transpose_forward, this also needs to be applied to the spacing! + data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + if seg is not None: + seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward] + + # crop, remember to store size before cropping! + shape_before_cropping = data.shape[1:] + properties['shape_before_cropping'] = shape_before_cropping + # this command will generate a segmentation. This is important because of the nonzero mask which we may need + data, seg, bbox = crop_to_nonzero(data, seg) + properties['bbox_used_for_cropping'] = bbox + # print(data.shape, seg.shape) + properties['shape_after_cropping_and_before_resampling'] = data.shape[1:] + + # resample + target_spacing = configuration_manager.spacing # this should already be transposed + + if len(target_spacing) < len(data.shape[1:]): + # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d + # in 2d configuration we do not change the spacing between slices + target_spacing = [original_spacing[0]] + target_spacing + new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing) + + # normalize + # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no + # longer fitting the images perfectly! + data = self._normalize(data, seg, configuration_manager, + plans_manager.foreground_intensity_properties_per_channel) + + # print('current shape', data.shape[1:], 'current_spacing', original_spacing, + # '\ntarget shape', new_shape, 'target_spacing', target_spacing) + old_shape = data.shape[1:] + data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing) + seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing) + if self.verbose: + print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, ' + f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}') + + # if we have a segmentation, sample foreground locations for oversampling and add those to properties + if has_seg: + # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument + # with a LabelManager Instance in this function because that's all its used for. Dunno what's better. + # LabelManager is pretty light computation-wise. + label_manager = plans_manager.get_label_manager(dataset_json) + collect_for_this = label_manager.foreground_regions if label_manager.has_regions \ + else label_manager.foreground_labels + + # when using the ignore label we want to sample only from annotated regions. Therefore we also need to + # collect samples uniformly from all classes (incl background) + if label_manager.has_ignore_label: + collect_for_this.append(label_manager.all_labels) + + # no need to filter background in regions because it is already filtered in handle_labels + # print(all_labels, regions) + properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this, + verbose=self.verbose) + seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager) + if np.max(seg) > 127: + seg = seg.astype(np.int16) + else: + seg = seg.astype(np.int8) + return data, seg + + def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + """ + seg file can be none (test cases) + + order of operations is: transpose -> crop -> resample + so when we export we need to run the following order: resample -> crop -> transpose (we could also run + transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner) + """ + if isinstance(dataset_json, str): + dataset_json = load_json(dataset_json) + + rw = plans_manager.image_reader_writer_class() + + # load image(s) + data, data_properties = rw.read_images(image_files) + + # if possible, load seg + if seg_file is not None: + seg, _ = rw.read_seg(seg_file) + else: + seg = None + + data, seg = self.run_case_npy(data, seg, data_properties, plans_manager, configuration_manager, + dataset_json) + return data, seg, data_properties + + def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json) + # print('dtypes', data.dtype, seg.dtype) + np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg) + write_pickle(properties, output_filename_truncated + '.pkl') + + @staticmethod + def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]], + seed: int = 1234, verbose: bool = False): + num_samples = 10000 + min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too + # sparse + rndst = np.random.RandomState(seed) + class_locs = {} + for c in classes_or_regions: + k = c if not isinstance(c, list) else tuple(c) + if isinstance(c, (tuple, list)): + mask = seg == c[0] + for cc in c[1:]: + mask = mask | (seg == cc) + all_locs = np.argwhere(mask) + else: + all_locs = np.argwhere(seg == c) + if len(all_locs) == 0: + class_locs[k] = [] + continue + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)] + class_locs[k] = selected + if verbose: + print(c, target_num_samples) + return class_locs + + def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager, + foreground_intensity_properties_per_channel: dict) -> np.ndarray: + for c in range(data.shape[0]): + scheme = configuration_manager.normalization_schemes[c] + normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"), + scheme, + 'nnunetv2.preprocessing.normalization') + if normalizer_class is None: + raise RuntimeError(f'Unable to locate class \'{scheme}\' for normalization') + normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c], + intensityproperties=foreground_intensity_properties_per_channel[str(c)]) + data[c] = normalizer.run(data[c], seg[0]) + return data + + def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str, + num_processes: int): + """ + data identifier = configuration name in plans. EZ. + """ + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + + assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw" + + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') + assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \ + "first." % plans_file + plans = load_json(plans_file) + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration(configuration_name) + + if self.verbose: + print(f'Preprocessing the following configuration: {configuration_name}') + if self.verbose: + print(configuration_manager) + + dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json') + dataset_json = load_json(dataset_json_file) + + output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier) + + if isdir(output_directory): + shutil.rmtree(output_directory) + + maybe_mkdir_p(output_directory) + + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) + + # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames] + # output_filenames_truncated = [join(output_directory, i) for i in identifiers] + + # multiprocessing magic. + r = [] + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + for k in dataset.keys(): + r.append(p.starmap_async(self.run_case_save, + ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'], + plans_manager, configuration_manager, + dataset_json),))) + remaining = list(range(len(dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar: + while len(remaining) > 0: + all_alive = all([j.is_alive() for j in workers]) + if not all_alive: + raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' + 'OK jokes aside.\n' + 'One of your background processes is missing. This could be because of ' + 'an error (look for an error message) or because it was killed ' + 'by your OS due to running out of RAM. If you don\'t see ' + 'an error message, out of RAM is likely the problem. In that case ' + 'reducing the number of workers might help') + done = [i for i in remaining if r[i].ready()] + for _ in done: + pbar.update() + remaining = [i for i in remaining if i not in done] + sleep(0.1) + + def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict, + configuration_manager: ConfigurationManager) -> np.ndarray: + # this function will be called at the end of self.run_case. Can be used to change the segmentation + # after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling + # and don't have to create a new dataset each time I modify my experiments + return seg + + +def example_test_case_preprocessing(): + # (paths to files may need adaptations) + plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json' + dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json' + input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz'] + + configuration = '3d_fullres' + pp = DefaultPreprocessor() + + # _ because this position would be the segmentation if seg_file was not None (training case) + # even if you have the segmentation, don't put the file there! You should always evaluate in the original + # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as + # specified by plans) + plans_manager = PlansManager(plans_file) + data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager, + configuration_manager=plans_manager.get_configuration(configuration), + dataset_json=dataset_json_file) + + # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO) + return data + + +if __name__ == '__main__': + example_test_case_preprocessing() + # pp = DefaultPreprocessor() + # pp.run(2, '2d', 'nnUNetPlans', 8) + + ########################################################################################################### + # how to process a test cases? This is an example: + # example_test_case_preprocessing() diff --git a/docker/template/src/nnunetv2/preprocessing/resampling/__init__.py b/docker/template/src/nnunetv2/preprocessing/resampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/preprocessing/resampling/default_resampling.py b/docker/template/src/nnunetv2/preprocessing/resampling/default_resampling.py new file mode 100644 index 0000000..e83f614 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/resampling/default_resampling.py @@ -0,0 +1,216 @@ +from collections import OrderedDict +from typing import Union, Tuple, List + +import numpy as np +import pandas as pd +import torch +from batchgenerators.augmentations.utils import resize_segmentation +from scipy.ndimage.interpolation import map_coordinates +from skimage.transform import resize +from nnunetv2.configuration import ANISO_THRESHOLD + + +def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD): + do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold + return do_separate_z + + +def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]): + axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic + return axis + + +def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray], + old_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray: + assert len(old_spacing) == len(old_shape) + assert len(old_shape) == len(new_spacing) + new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)]) + return new_shape + + +def resample_data_or_seg_to_spacing(data: np.ndarray, + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + if force_separate_z is not None: + do_separate_z = force_separate_z + if force_separate_z: + axis = get_lowres_axis(current_spacing) + else: + axis = None + else: + if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(current_spacing) + elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(new_spacing) + else: + do_separate_z = False + axis = None + + if axis is not None: + if len(axis) == 3: + # every axis has the same spacing, this should never happen, why is this code here? + do_separate_z = False + elif len(axis) == 2: + # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample + # separately in the out of plane axis + do_separate_z = False + else: + pass + + if data is not None: + assert data.ndim == 4, "data must be c x y z" + + shape = np.array(data[0].shape) + new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing) + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], + new_shape: Union[Tuple[int, ...], List[int], np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + """ + needed for segmentation export. Stupid, I know. Maybe we can fix that with Leos new resampling functions + """ + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + if force_separate_z is not None: + do_separate_z = force_separate_z + if force_separate_z: + axis = get_lowres_axis(current_spacing) + else: + axis = None + else: + if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(current_spacing) + elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(new_spacing) + else: + do_separate_z = False + axis = None + + if axis is not None: + if len(axis) == 3: + # every axis has the same spacing, this should never happen, why is this code here? + do_separate_z = False + elif len(axis) == 2: + # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample + # separately in the out of plane axis + do_separate_z = False + else: + pass + + if data is not None: + assert data.ndim == 4, "data must be c x y z" + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, axis: Union[None, int] = None, order: int = 3, + do_separate_z: bool = False, order_z: int = 0): + """ + separate_z=True will resample with order 0 along z + :param data: + :param new_shape: + :param is_seg: + :param axis: + :param order: + :param do_separate_z: + :param order_z: only applies if do_separate_z is True + :return: + """ + assert data.ndim == 4, "data must be (c, x, y, z)" + assert len(new_shape) == data.ndim - 1 + + if is_seg: + resize_fn = resize_segmentation + kwargs = OrderedDict() + else: + resize_fn = resize + kwargs = {'mode': 'edge', 'anti_aliasing': False} + dtype_data = data.dtype + shape = np.array(data[0].shape) + new_shape = np.array(new_shape) + if np.any(shape != new_shape): + data = data.astype(float) + if do_separate_z: + # print("separate z, order in z is", order_z, "order inplane is", order) + assert len(axis) == 1, "only one anisotropic axis supported" + axis = axis[0] + if axis == 0: + new_shape_2d = new_shape[1:] + elif axis == 1: + new_shape_2d = new_shape[[0, 2]] + else: + new_shape_2d = new_shape[:-1] + + reshaped_final_data = [] + for c in range(data.shape[0]): + reshaped_data = [] + for slice_id in range(shape[axis]): + if axis == 0: + reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)) + elif axis == 1: + reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)) + else: + reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)) + reshaped_data = np.stack(reshaped_data, axis) + if shape[axis] != new_shape[axis]: + + # The following few lines are blatantly copied and modified from sklearn's resize() + rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] + orig_rows, orig_cols, orig_dim = reshaped_data.shape + + row_scale = float(orig_rows) / rows + col_scale = float(orig_cols) / cols + dim_scale = float(orig_dim) / dim + + map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] + map_rows = row_scale * (map_rows + 0.5) - 0.5 + map_cols = col_scale * (map_cols + 0.5) - 0.5 + map_dims = dim_scale * (map_dims + 0.5) - 0.5 + + coord_map = np.array([map_rows, map_cols, map_dims]) + if not is_seg or order_z == 0: + reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z, + mode='nearest')[None]) + else: + unique_labels = np.sort(pd.unique(reshaped_data.ravel())) # np.unique(reshaped_data) + reshaped = np.zeros(new_shape, dtype=dtype_data) + + for i, cl in enumerate(unique_labels): + reshaped_multihot = np.round( + map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z, + mode='nearest')) + reshaped[reshaped_multihot > 0.5] = cl + reshaped_final_data.append(reshaped[None]) + else: + reshaped_final_data.append(reshaped_data[None]) + reshaped_final_data = np.vstack(reshaped_final_data) + else: + # print("no separate z, order", order) + reshaped = [] + for c in range(data.shape[0]): + reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None]) + reshaped_final_data = np.vstack(reshaped) + return reshaped_final_data.astype(dtype_data) + else: + # print("no resampling necessary") + return data diff --git a/docker/template/src/nnunetv2/preprocessing/resampling/utils.py b/docker/template/src/nnunetv2/preprocessing/resampling/utils.py new file mode 100644 index 0000000..0bff719 --- /dev/null +++ b/docker/template/src/nnunetv2/preprocessing/resampling/utils.py @@ -0,0 +1,15 @@ +from typing import Callable + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import join +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + + +def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable: + ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn, + 'nnunetv2.preprocessing.resampling') + if ret is None: + raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the " + "nnunetv2.preprocessing.resampling module." % resampling_fn) + else: + return ret diff --git a/docker/template/src/nnunetv2/run/__init__.py b/docker/template/src/nnunetv2/run/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/run/load_pretrained_weights.py b/docker/template/src/nnunetv2/run/load_pretrained_weights.py new file mode 100644 index 0000000..bb26e41 --- /dev/null +++ b/docker/template/src/nnunetv2/run/load_pretrained_weights.py @@ -0,0 +1,66 @@ +import torch +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel as DDP + + +def load_pretrained_weights(network, fname, verbose=False): + """ + Transfers all weights between matching keys in state_dicts. matching is done by name and we only transfer if the + shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps) + identified by keys ending with '.seg_layers') are not transferred! + + If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.optimize was used, + you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.optim adds + '_orig_mod'. You DO NOT need to worry about this if pretraining was done with nnU-Net as + nnUNetTrainer.save_checkpoint takes care of that! + + """ + saved_model = torch.load(fname) + pretrained_dict = saved_model['network_weights'] + + skip_strings_in_pretrained = [ + '.seg_layers.', + ] + + if isinstance(network, DDP): + mod = network.module + else: + mod = network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + model_dict = mod.state_dict() + # verify that all but the segmentation layers have the same shape + for key, _ in model_dict.items(): + if all([i not in key for i in skip_strings_in_pretrained]): + assert key in pretrained_dict, \ + f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ + f"compatible with your network." + assert model_dict[key].shape == pretrained_dict[key].shape, \ + f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ + f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \ + f"does not seem to be compatible with your network." + + # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained + # encoders. Not supported by this function though (see assertions above) + + # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' + # pretrained_dict = {'module.' + k if is_ddp else k: v + # for k, v in pretrained_dict.items() + # if (('module.' + k if is_ddp else k) in model_dict) and + # all([i not in k for i in skip_strings_in_pretrained])} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} + + model_dict.update(pretrained_dict) + + print("################### Loading pretrained weights from file ", fname, '###################') + if verbose: + print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") + for key, value in pretrained_dict.items(): + print(key, 'shape', value.shape) + print("################### Done ###################") + mod.load_state_dict(model_dict) + + diff --git a/docker/template/src/nnunetv2/run/run_training.py b/docker/template/src/nnunetv2/run/run_training.py new file mode 100644 index 0000000..93dd759 --- /dev/null +++ b/docker/template/src/nnunetv2/run/run_training.py @@ -0,0 +1,274 @@ +import os +import socket +from typing import Union, Optional + +import nnunetv2 +import torch.cuda +import torch.distributed as dist +import torch.multiprocessing as mp +from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json +from nnunetv2.paths import nnUNet_preprocessed +from nnunetv2.run.load_pretrained_weights import load_pretrained_weights +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from torch.backends import cudnn + + +def find_free_network_port() -> int: + """Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + +def get_trainer_from_args(dataset_name_or_id: Union[int, str], + configuration: str, + fold: int, + trainer_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + use_compressed: bool = False, + device: torch.device = torch.device('cuda')): + # load nnunet class and do sanity checks + nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, 'nnunetv2.training.nnUNetTrainer') + if nnunet_trainer is None: + raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in ' + f'nnunetv2.training.nnUNetTrainer (' + f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere ' + f'else, please move it there.') + assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \ + 'nnUNetTrainer' + + # handle dataset input. If it's an ID we need to convert to int from string + if dataset_name_or_id.startswith('Dataset'): + pass + else: + try: + dataset_name_or_id = int(dataset_name_or_id) + except ValueError: + raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern ' + f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your ' + f'input: {dataset_name_or_id}') + + # initialize nnunet trainer + preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id)) + plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json') + plans = load_json(plans_file) + dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json')) + nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold, + dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device) + return nnunet_trainer + + +def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool, + pretrained_weights_file: str = None): + if continue_training and pretrained_weights_file is not None: + raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only ' + 'be used at the beginning of the training.') + if continue_training: + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') + if not isfile(expected_checkpoint_file): + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth') + # special case where --c is used to run a previously aborted validation + if not isfile(expected_checkpoint_file): + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth') + if not isfile(expected_checkpoint_file): + print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to " + f"continue from. Starting a new training...") + expected_checkpoint_file = None + elif validation_only: + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') + if not isfile(expected_checkpoint_file): + raise RuntimeError(f"Cannot run validation because the training is not finished yet!") + else: + if pretrained_weights_file is not None: + if not nnunet_trainer.was_initialized: + nnunet_trainer.initialize() + load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True) + expected_checkpoint_file = None + + if expected_checkpoint_file is not None: + nnunet_trainer.load_checkpoint(expected_checkpoint_file) + + +def setup_ddp(rank, world_size): + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup_ddp(): + dist.destroy_process_group() + + +def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val, + pretrained_weights, npz, val_with_best, world_size): + setup_ddp(rank, world_size) + torch.cuda.set_device(torch.device('cuda', dist.get_rank())) + + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, + use_compressed) + + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.' + + maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights) + + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if not val: + nnunet_trainer.run_training() + + if val_with_best: + nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth')) + nnunet_trainer.perform_actual_validation(npz) + cleanup_ddp() + + +def run_training(dataset_name_or_id: Union[str, int], + configuration: str, fold: Union[int, str], + trainer_class_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + pretrained_weights: Optional[str] = None, + num_gpus: int = 1, + use_compressed_data: bool = False, + export_validation_probabilities: bool = False, + continue_training: bool = False, + only_run_validation: bool = False, + disable_checkpointing: bool = False, + val_with_best: bool = False, + device: torch.device = torch.device('cuda')): + if isinstance(fold, str): + if fold != 'all': + try: + fold = int(fold) + except ValueError as e: + print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!') + raise e + + if val_with_best: + assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing' + + if num_gpus > 1: + assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}" + + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ.keys(): + port = str(find_free_network_port()) + print(f"using port {port}") + os.environ['MASTER_PORT'] = port # str(port) + + mp.spawn(run_ddp, + args=( + dataset_name_or_id, + configuration, + fold, + trainer_class_name, + plans_identifier, + use_compressed_data, + disable_checkpointing, + continue_training, + only_run_validation, + pretrained_weights, + export_validation_probabilities, + val_with_best, + num_gpus), + nprocs=num_gpus, + join=True) + else: + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, + plans_identifier, use_compressed_data, device=device) + + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.' + + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) + + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if not only_run_validation: + nnunet_trainer.run_training() + + if val_with_best: + nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth')) + nnunet_trainer.perform_actual_validation(export_validation_probabilities) + + +def run_training_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('dataset_name_or_id', type=str, + help="Dataset name or ID to train with") + parser.add_argument('configuration', type=str, + help="Configuration that should be trained") + parser.add_argument('fold', type=str, + help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans') + parser.add_argument('-pretrained_weights', type=str, required=False, default=None, + help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only ' + 'be used when actually training. Beta. Use with caution.') + parser.add_argument('-num_gpus', type=int, default=1, required=False, + help='Specify the number of GPUs to use for training') + parser.add_argument("--use_compressed", default=False, action="store_true", required=False, + help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed " + "data is much more CPU and (potentially) RAM intensive and should only be used if you " + "know what you are doing") + parser.add_argument('--npz', action='store_true', required=False, + help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted ' + 'segmentations). Needed for finding the best ensemble.') + parser.add_argument('--c', action='store_true', required=False, + help='[OPTIONAL] Continue training from latest checkpoint') + parser.add_argument('--val', action='store_true', required=False, + help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.') + parser.add_argument('--val_best', action='store_true', required=False, + help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead ' + 'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! ' + 'WARNING: This will use the same \'validation\' folder as the regular validation ' + 'with no way of distinguishing the two!') + parser.add_argument('--disable_checkpointing', action='store_true', required=False, + help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and ' + 'you dont want to flood your hard drive with checkpoints.') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the training should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!") + args = parser.parse_args() + + assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights, + args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best, + device=device) + + +if __name__ == '__main__': + run_training_entry() diff --git a/docker/template/src/nnunetv2/tests/__init__.py b/docker/template/src/nnunetv2/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/tests/integration_tests/__init__.py b/docker/template/src/nnunetv2/tests/integration_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py b/docker/template/src/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py new file mode 100644 index 0000000..a1b4df1 --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py @@ -0,0 +1,33 @@ +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.paths import nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids') + args = parser.parse_args() + + for d in args.d: + dataset_name = maybe_convert_to_dataset_name(d) + plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json')) + plans['configurations']['3d_lowres'] = { + "data_identifier": "nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time + 'inherits_from': '3d_fullres', + "patch_size": [20, 28, 20], + "median_image_size_in_voxels": [18.0, 25.0, 18.0], + "spacing": [2.0, 2.0, 2.0], + "n_conv_per_stage_encoder": [2, 2, 2], + "n_conv_per_stage_decoder": [2, 2], + "num_pool_per_axis": [2, 2, 2], + "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], + "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], + "next_stage": "3d_cascade_fullres" + } + plans['configurations']['3d_cascade_fullres'] = { + 'inherits_from': '3d_fullres', + "previous_stage": "3d_lowres" + } + save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False) \ No newline at end of file diff --git a/docker/template/src/nnunetv2/tests/integration_tests/cleanup_integration_test.py b/docker/template/src/nnunetv2/tests/integration_tests/cleanup_integration_test.py new file mode 100644 index 0000000..c9fca95 --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/cleanup_integration_test.py @@ -0,0 +1,19 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join + +from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed + +if __name__ == '__main__': + # deletes everything! + dataset_names = [ + 'Dataset996_IntegrationTest_Hippocampus_regions_ignore', + 'Dataset997_IntegrationTest_Hippocampus_regions', + 'Dataset998_IntegrationTest_Hippocampus_ignore', + 'Dataset999_IntegrationTest_Hippocampus', + ] + for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]: + for d in dataset_names: + if isdir(join(fld, d)): + shutil.rmtree(join(fld, d)) + diff --git a/docker/template/src/nnunetv2/tests/integration_tests/lsf_commands.sh b/docker/template/src/nnunetv2/tests/integration_tests/lsf_commands.sh new file mode 100644 index 0000000..3888c1a --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/lsf_commands.sh @@ -0,0 +1,10 @@ +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999" + + +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999" diff --git a/docker/template/src/nnunetv2/tests/integration_tests/prepare_integration_tests.sh b/docker/template/src/nnunetv2/tests/integration_tests/prepare_integration_tests.sh new file mode 100644 index 0000000..b5dda42 --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/prepare_integration_tests.sh @@ -0,0 +1,18 @@ +# assumes you are in the nnunet repo! + +# prepare raw datasets +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py + +# now run experiment planning without preprocessing +nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp + +# now add 3d lowres and cascade +python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999 + +# now preprocess everything +nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess cascade as its the same data as 3d_fullres + +# done \ No newline at end of file diff --git a/docker/template/src/nnunetv2/tests/integration_tests/readme.md b/docker/template/src/nnunetv2/tests/integration_tests/readme.md new file mode 100644 index 0000000..2a44f13 --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/readme.md @@ -0,0 +1,58 @@ +# Preface + +I am just a mortal with many tasks and limited time. Aint nobody got time for unittests. + +HOWEVER, at least some integration tests should be performed testing nnU-Net from start to finish. + +# Introduction - What the heck is happening? +This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with +ignore labels). It runs the entire nnU-Net pipeline from start to finish: + +- fingerprint extraction +- experiment planning +- preprocessing +- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV +- automatically find the best model or ensemble +- determine the postprocessing used for this +- predict some test set +- apply postprocessing to the test set + +To speed things up, we do the following: +- pick Dataset004_Hippocampus because it is quadratisch praktisch gut. MNIST of medical image segmentation +- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more! +- we use nnUNetTrainer_5epochs for a short training + +# How to run it? + +Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!) + +Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004! +```commandline +bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh +``` + +Now you can run the integration test for each of the datasets: +```commandline +bash nnunetv2/tests/integration_tests/run_integration_test.sh DATSET_ID +``` +use DATSET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up. +This will take i dunno like 10-30 Minutes!? + +Also run +```commandline +bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATSET_ID +``` +to verify DDP is working (needs 2 GPUs!) + +# How to check if the test was successful? +If I was not as lazy as I am I would have programmed some automatism that checks if Dice scores etc are in an acceptable range. +So you need to do the following: +1) check that none of your runs crashed (duh) +2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file. +Does it make sense? If so: NICE! + +Once the integration test is completed you can delete all the temporary files associated with it by running: + +```commandline +python nnunetv2/tests/integration_tests/cleanup_integration_test.py +``` \ No newline at end of file diff --git a/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test.sh b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test.sh new file mode 100644 index 0000000..ff0426c --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test.sh @@ -0,0 +1,27 @@ + + +nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz + +python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1 \ No newline at end of file diff --git a/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py new file mode 100644 index 0000000..89e783e --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py @@ -0,0 +1,75 @@ +import argparse + +import torch +from batchgenerators.utilities.file_and_folder_operations import join, load_pickle + +from nnunetv2.ensembling.ensemble import ensemble_folders +from nnunetv2.evaluation.find_best_configuration import find_best_configuration, \ + dumb_trainer_config_plans_to_trained_models_dict +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.paths import nnUNet_raw, nnUNet_results +from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +if __name__ == '__main__': + """ + Predicts the imagesTs folder with the best configuration and applies postprocessing + """ + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + parser = argparse.ArgumentParser() + parser.add_argument('-d', type=int, help='dataset id') + args = parser.parse_args() + d = args.d + + dataset_name = maybe_convert_to_dataset_name(d) + source_dir = join(nnUNet_raw, dataset_name, 'imagesTs') + target_dir_base = join(nnUNet_results, dataset_name) + + models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'], + ['2d', + '3d_lowres', + '3d_cascade_fullres', + '3d_fullres'], + ['nnUNetPlans']) + ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True, + folds=(0, 1, 2, 3, 4), strict=True) + + has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1 + + # we don't use all folds to speed stuff up + used_folds = (0, 3) + output_folders = [] + for im in ret['best_model_or_ensemble']['selected_model_or_models']: + output_dir = join(target_dir_base, f"pred_{im['configuration']}") + model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration']) + # note that if the best model is the enseble of 3d_lowres and 3d cascade then 3d_lowres will be predicted + # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the + # prediction here. Proper way would be to check for that and + # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in + # predict_from_raw_data. Since we allow for + # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an + # integration test after all. Take a closer look at how this in handled in predict_from_raw_data + predictor = nnUNetPredictor(verbose=False, allow_tqdm=False) + predictor.initialize_from_trained_model_folder(model_folder, used_folds) + predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True) + # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir, + # model_training_output_dir=model_folder, use_folds=used_folds, + # save_probabilities=has_ensemble, verbose=False, overwrite=True) + output_folders.append(output_dir) + + # if we have an ensemble, we need to ensemble the results + if has_ensemble: + ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False) + folder_for_pp = join(target_dir_base, 'ensemble_predictions') + else: + folder_for_pp = output_folders[0] + + # apply postprocessing + pp_fns, pp_fn_kwargs = load_pickle(ret['best_model_or_ensemble']['postprocessing_file']) + apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'), + pp_fns, + pp_fn_kwargs, plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file']) diff --git a/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh new file mode 100644 index 0000000..5199247 --- /dev/null +++ b/docker/template/src/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh @@ -0,0 +1 @@ +nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2 diff --git a/docker/template/src/nnunetv2/training/__init__.py b/docker/template/src/nnunetv2/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/data_augmentation/__init__.py b/docker/template/src/nnunetv2/training/data_augmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/data_augmentation/compute_initial_patch_size.py b/docker/template/src/nnunetv2/training/data_augmentation/compute_initial_patch_size.py new file mode 100644 index 0000000..a772bc2 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/compute_initial_patch_size.py @@ -0,0 +1,24 @@ +import numpy as np + + +def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): + if isinstance(rot_x, (tuple, list)): + rot_x = max(np.abs(rot_x)) + if isinstance(rot_y, (tuple, list)): + rot_y = max(np.abs(rot_y)) + if isinstance(rot_z, (tuple, list)): + rot_z = max(np.abs(rot_z)) + rot_x = min(90 / 360 * 2. * np.pi, rot_x) + rot_y = min(90 / 360 * 2. * np.pi, rot_y) + rot_z = min(90 / 360 * 2. * np.pi, rot_z) + from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d + coords = np.array(final_patch_size) + final_shape = np.copy(coords) + if len(coords) == 3: + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) + elif len(coords) == 2: + final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) + final_shape /= min(scale_range) + return final_shape.astype(int) diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/__init__.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py new file mode 100644 index 0000000..378bab2 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py @@ -0,0 +1,136 @@ +from typing import Union, List, Tuple, Callable + +import numpy as np +from acvl_utils.morphology.morphology_helper import label_with_component_sizes +from batchgenerators.transforms.abstract_transforms import AbstractTransform +from skimage.morphology import ball +from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening + + +class MoveSegAsOneHotToData(AbstractTransform): + def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]], + key_origin="seg", key_target="data", remove_from_origin=True): + """ + Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to + data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg]. + """ + self.remove_from_origin = remove_from_origin + self.all_labels = all_labels + self.key_target = key_target + self.key_origin = key_origin + self.index_in_origin = index_in_origin + + def __call__(self, **data_dict): + seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1] + + seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]), + dtype=data_dict[self.key_target].dtype) + for i, l in enumerate(self.all_labels): + seg_onehot[:, i][seg[:, 0] == l] = 1 + + data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1) + + if self.remove_from_origin: + remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin] + data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels] + + return data_dict + + +class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform): + def __init__(self, channel_idx: Union[int, List[int]], key: str = "data", p_per_sample: float = 0.2, + fill_with_other_class_p: float = 0.25, + dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1): + """ + Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components + smaller than dont_do_if_covers_more_than_X_percent of the sample. Also has the option of simulating + misclassification as another class (fill_with_other_class_p) + """ + self.p_per_label = p_per_label + self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent + self.fill_with_other_class_p = fill_with_other_class_p + self.p_per_sample = p_per_sample + self.key = key + if not isinstance(channel_idx, (list, tuple)): + channel_idx = [channel_idx] + self.channel_idx = channel_idx + + def __call__(self, **data_dict): + data = data_dict.get(self.key) + for b in range(data.shape[0]): + if np.random.uniform() < self.p_per_sample: + for c in self.channel_idx: + if np.random.uniform() < self.p_per_label: + # print(np.unique(data[b, c])) ## should be [0, 1] + workon = data[b, c].astype(bool) + if not np.any(workon): + continue + num_voxels = np.prod(workon.shape, dtype=np.uint64) + lab, component_sizes = label_with_component_sizes(workon.astype(bool)) + if len(component_sizes) > 0: + valid_component_ids = [i for i, j in component_sizes.items() if j < + num_voxels*self.dont_do_if_covers_more_than_x_percent] + # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c, + # np.unique(data[b, c]), len(component_sizes), valid_component_ids, + # len(valid_component_ids)) + if len(valid_component_ids) > 0: + random_component = np.random.choice(valid_component_ids) + data[b, c][lab == random_component] = 0 + if np.random.uniform() < self.fill_with_other_class_p: + other_ch = [i for i in self.channel_idx if i != c] + if len(other_ch) > 0: + other_class = np.random.choice(other_ch) + data[b, other_class][lab == random_component] = 1 + data_dict[self.key] = data + return data_dict + + +class ApplyRandomBinaryOperatorTransform(AbstractTransform): + def __init__(self, + channel_idx: Union[int, List[int], Tuple[int, ...]], + p_per_sample: float = 0.3, + any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening), + key: str = "data", + strel_size: Tuple[int, int] = (1, 10), + p_per_label: float = 1): + """ + Applies random binary operations (specified by any_of_these) with random ball size (radius is uniformly sampled + from interval strel_size) to specified channels. Expects the channel_idx to correspond to a hone hot encoded + segmentation (see for example MoveSegAsOneHotToData) + """ + self.p_per_label = p_per_label + self.strel_size = strel_size + self.key = key + self.any_of_these = any_of_these + self.p_per_sample = p_per_sample + + if not isinstance(channel_idx, (list, tuple)): + channel_idx = [channel_idx] + self.channel_idx = channel_idx + + def __call__(self, **data_dict): + for b in range(data_dict[self.key].shape[0]): + if np.random.uniform() < self.p_per_sample: + # this needs to be applied in random order to the channels + np.random.shuffle(self.channel_idx) + for c in self.channel_idx: + if np.random.uniform() < self.p_per_label: + operation = np.random.choice(self.any_of_these) + selem = ball(np.random.uniform(*self.strel_size)) + workon = data_dict[self.key][b, c].astype(bool) + if not np.any(workon): + continue + # print(np.unique(workon)) + res = operation(workon, selem).astype(data_dict[self.key].dtype) + # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res)) + data_dict[self.key][b, c] = res + + # if class was added, we need to remove it in ALL other channels to keep one hot encoding + # properties + other_ch = [i for i in self.channel_idx if i != c] + if len(other_ch) > 0: + was_added_mask = (res - workon) > 0 + for oc in other_ch: + data_dict[self.key][b, oc][was_added_mask] = 0 + # if class was removed, leave it at background + return data_dict diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py new file mode 100644 index 0000000..d31881f --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py @@ -0,0 +1,55 @@ +from typing import Tuple, Union, List + +from batchgenerators.augmentations.utils import resize_segmentation +from batchgenerators.transforms.abstract_transforms import AbstractTransform +import numpy as np + + +class DownsampleSegForDSTransform2(AbstractTransform): + ''' + data_dict['output_key'] will be a list of segmentations scaled according to ds_scales + ''' + def __init__(self, ds_scales: Union[List, Tuple], + order: int = 0, input_key: str = "seg", + output_key: str = "seg", axes: Tuple[int] = None): + """ + Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specified one deep supervision + output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape. + ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling + for each axis independently + """ + self.axes = axes + self.output_key = output_key + self.input_key = input_key + self.order = order + self.ds_scales = ds_scales + + def __call__(self, **data_dict): + if self.axes is None: + axes = list(range(2, data_dict[self.input_key].ndim)) + else: + axes = self.axes + + output = [] + for s in self.ds_scales: + if not isinstance(s, (tuple, list)): + s = [s] * len(axes) + else: + assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \ + f'for each axis) then the number of entried in that tuple (here ' \ + f'{len(s)}) must be the same as the number of axes (here {len(axes)}).' + + if all([i == 1 for i in s]): + output.append(data_dict[self.input_key]) + else: + new_shape = np.array(data_dict[self.input_key].shape).astype(float) + for i, a in enumerate(axes): + new_shape[a] *= s[i] + new_shape = np.round(new_shape).astype(int) + out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype) + for b in range(data_dict[self.input_key].shape[0]): + for c in range(data_dict[self.input_key].shape[1]): + out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order) + output.append(out_seg) + data_dict[self.output_key] = output + return data_dict diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py new file mode 100644 index 0000000..dd8368c --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py @@ -0,0 +1,10 @@ +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter + + +class LimitedLenWrapper(NonDetMultiThreadedAugmenter): + def __init__(self, my_imaginary_length, *args, **kwargs): + super().__init__(*args, **kwargs) + self.len = my_imaginary_length + + def __len__(self): + return self.len diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py new file mode 100644 index 0000000..587acd7 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py @@ -0,0 +1,10 @@ +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class RemoveKeyTransform(AbstractTransform): + def __init__(self, key_to_remove: str): + self.key_to_remove = key_to_remove + + def __call__(self, **data_dict): + _ = data_dict.pop(self.key_to_remove, None) + return data_dict diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/masking.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/masking.py new file mode 100644 index 0000000..b009993 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/masking.py @@ -0,0 +1,22 @@ +from typing import List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class MaskTransform(AbstractTransform): + def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0, + data_key: str = "data", seg_key: str = "seg"): + """ + Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!! + """ + self.apply_to_channels = apply_to_channels + self.seg_key = seg_key + self.data_key = data_key + self.set_outside_to = set_outside_to + self.mask_idx_in_seg = mask_idx_in_seg + + def __call__(self, **data_dict): + mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0 + for c in self.apply_to_channels: + data_dict[self.data_key][:, c][mask] = self.set_outside_to + return data_dict diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py new file mode 100644 index 0000000..52d2fc0 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -0,0 +1,38 @@ +from typing import List, Tuple, Union + +from batchgenerators.transforms.abstract_transforms import AbstractTransform +import numpy as np + + +class ConvertSegmentationToRegionsTransform(AbstractTransform): + def __init__(self, regions: Union[List, Tuple], + seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0): + """ + regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, + example: + regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2 + :param regions: + :param seg_key: + :param output_key: + """ + self.seg_channel = seg_channel + self.output_key = output_key + self.seg_key = seg_key + self.regions = regions + + def __call__(self, **data_dict): + seg = data_dict.get(self.seg_key) + num_regions = len(self.regions) + if seg is not None: + seg_shp = seg.shape + output_shape = list(seg_shp) + output_shape[1] = num_regions + region_output = np.zeros(output_shape, dtype=seg.dtype) + for b in range(seg_shp[0]): + for region_id, region_source_labels in enumerate(self.regions): + if not isinstance(region_source_labels, (list, tuple)): + region_source_labels = (region_source_labels, ) + for label_value in region_source_labels: + region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1 + data_dict[self.output_key] = region_output + return data_dict diff --git a/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py new file mode 100644 index 0000000..340fce7 --- /dev/null +++ b/docker/template/src/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py @@ -0,0 +1,45 @@ +from typing import Tuple, Union, List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class Convert3DTo2DTransform(AbstractTransform): + def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): + """ + Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel + """ + self.apply_to_keys = apply_to_keys + + def __call__(self, **data_dict): + for k in self.apply_to_keys: + shp = data_dict[k].shape + assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.' + data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) + shape_key = f'orig_shape_{k}' + assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \ + f'It does that using the {shape_key} key. That key is ' \ + f'already taken. Bummer.' + data_dict[shape_key] = shp + return data_dict + + +class Convert2DTo3DTransform(AbstractTransform): + def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): + """ + Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z) + """ + self.apply_to_keys = apply_to_keys + + def __call__(self, **data_dict): + for k in self.apply_to_keys: + shape_key = f'orig_shape_{k}' + assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \ + f'Convert2DTo3DTransform only works in tandem with ' \ + f'Convert3DTo2DTransform and you probably forgot to add ' \ + f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \ + f'is where the missing key is generated)' + original_shape = data_dict[shape_key] + current_shape = data_dict[k].shape + data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2], + current_shape[-2], current_shape[-1])) + return data_dict diff --git a/docker/template/src/nnunetv2/training/dataloading/__init__.py b/docker/template/src/nnunetv2/training/dataloading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/dataloading/base_data_loader.py b/docker/template/src/nnunetv2/training/dataloading/base_data_loader.py new file mode 100644 index 0000000..6a6a49f --- /dev/null +++ b/docker/template/src/nnunetv2/training/dataloading/base_data_loader.py @@ -0,0 +1,139 @@ +from typing import Union, Tuple + +from batchgenerators.dataloading.data_loader import DataLoader +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.utilities.label_handling.label_handling import LabelManager + + +class nnUNetDataLoaderBase(DataLoader): + def __init__(self, + data: nnUNetDataset, + batch_size: int, + patch_size: Union[List[int], Tuple[int, ...], np.ndarray], + final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray], + label_manager: LabelManager, + oversample_foreground_percent: float = 0.0, + sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None, + pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None, + probabilistic_oversampling: bool = False): + super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities) + assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports dictionaries as data' + self.indices = list(data.keys()) + + self.oversample_foreground_percent = oversample_foreground_percent + self.final_patch_size = final_patch_size + self.patch_size = patch_size + self.list_of_keys = list(self._data.keys()) + # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size + # (which is what the network will get) these patches will also cover the border of the images + self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int) + if pad_sides is not None: + if not isinstance(pad_sides, np.ndarray): + pad_sides = np.array(pad_sides) + self.need_to_pad += pad_sides + self.num_channels = None + self.pad_sides = pad_sides + self.data_shape, self.seg_shape = self.determine_shapes() + self.sampling_probabilities = sampling_probabilities + self.annotated_classes_key = tuple(label_manager.all_labels) + self.has_ignore = label_manager.has_ignore_label + self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \ + else self._probabilistic_oversampling + + def _oversample_last_XX_percent(self, sample_idx: int) -> bool: + """ + determines whether sample sample_idx in a minibatch needs to be guaranteed foreground + """ + return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent)) + + def _probabilistic_oversampling(self, sample_idx: int) -> bool: + # print('YEAH BOIIIIII') + return np.random.uniform() < self.oversample_foreground_percent + + def determine_shapes(self): + # load one case + data, seg, properties = self._data.load_case(self.indices[0]) + num_color_channels = data.shape[0] + + data_shape = (self.batch_size, num_color_channels, *self.patch_size) + seg_shape = (self.batch_size, seg.shape[0], *self.patch_size) + return data_shape, seg_shape + + def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None], + overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False): + # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have + # locations for the given slice + need_to_pad = self.need_to_pad.copy() + dim = len(data_shape) + + for d in range(dim): + # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides + # always + if need_to_pad[d] + data_shape[d] < self.patch_size[d]: + need_to_pad[d] = self.patch_size[d] - data_shape[d] + + # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we + # define what the upper and lower bound can be to then sample form them with np.random.randint + lbs = [- need_to_pad[i] // 2 for i in range(dim)] + ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)] + + # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get + # at least one of the foreground classes in the patch + if not force_fg and not self.has_ignore: + bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)] + # print('I want a random location') + else: + if not force_fg and self.has_ignore: + selected_class = self.annotated_classes_key + if len(class_locations[selected_class]) == 0: + # no annotated pixels in this case. Not good. But we can hardly skip it here + print('Warning! No annotated pixels in image!') + selected_class = None + # print(f'I have ignore labels and want to pick a labeled area. annotated_classes_key: {self.annotated_classes_key}') + elif force_fg: + assert class_locations is not None, 'if force_fg is set class_locations cannot be None' + if overwrite_class is not None: + assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \ + 'have class_locations (missing key)' + # this saves us a np.unique. Preprocessing already did that for all cases. Neat. + # class_locations keys can also be tuple + eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0] + + # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list + # strange formulation needed to circumvent + # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() + tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] + if any(tmp): + if len(eligible_classes_or_regions) > 1: + eligible_classes_or_regions.pop(np.where(tmp)[0][0]) + + if len(eligible_classes_or_regions) == 0: + # this only happens if some image does not contain foreground voxels at all + selected_class = None + if verbose: + print('case does not contain any foreground classes') + else: + # I hate myself. Future me aint gonna be happy to read this + # 2022_11_25: had to read it today. Wasn't too bad + selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ + (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class + # print(f'I want to have foreground, selected class: {selected_class}') + else: + raise RuntimeError('lol what!?') + voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None + + if voxels_of_that_class is not None and len(voxels_of_that_class) > 0: + selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))] + # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel. + # Make sure it is within the bounds of lb and ub + # i + 1 because we have first dimension 0! + bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)] + else: + # If the image does not contain any foreground classes, we fall back to random cropping + bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)] + + bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)] + + return bbox_lbs, bbox_ubs diff --git a/docker/template/src/nnunetv2/training/dataloading/data_loader_2d.py b/docker/template/src/nnunetv2/training/dataloading/data_loader_2d.py new file mode 100644 index 0000000..aab8438 --- /dev/null +++ b/docker/template/src/nnunetv2/training/dataloading/data_loader_2d.py @@ -0,0 +1,94 @@ +import numpy as np +from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset + + +class nnUNetDataLoader2D(nnUNetDataLoaderBase): + def generate_train_batch(self): + selected_keys = self.get_indices() + # preallocate memory for data and seg + data_all = np.zeros(self.data_shape, dtype=np.float32) + seg_all = np.zeros(self.seg_shape, dtype=np.int16) + case_properties = [] + + for j, current_key in enumerate(selected_keys): + # oversampling foreground will improve stability of model training, especially if many patches are empty + # (Lung for example) + force_fg = self.get_do_oversample(j) + data, seg, properties = self._data.load_case(current_key) + case_properties.append(properties) + + # select a class/region first, then a slice where this class is present, then crop to that area + if not force_fg: + if self.has_ignore: + selected_class_or_region = self.annotated_classes_key + else: + selected_class_or_region = None + else: + # filter out all classes that are not present here + eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0] + + # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list + # strange formulation needed to circumvent + # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() + tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] + if any(tmp): + if len(eligible_classes_or_regions) > 1: + eligible_classes_or_regions.pop(np.where(tmp)[0][0]) + + selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ + len(eligible_classes_or_regions) > 0 else None + if selected_class_or_region is not None: + selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1]) + else: + selected_slice = np.random.choice(len(data[0])) + + data = data[:, selected_slice] + seg = seg[:, selected_slice] + + # the line of death lol + # this needs to be a separate variable because we could otherwise permanently overwrite + # properties['class_locations'] + # selected_class_or_region is: + # - None if we do not have an ignore label and force_fg is False OR if force_fg is True but there is no foreground in the image + # - A tuple of all (non-ignore) labels if there is an ignore label and force_fg is False + # - a class or region if force_fg is True + class_locations = { + selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)] + } if (selected_class_or_region is not None) else None + + # print(properties) + shape = data.shape[1:] + dim = len(shape) + bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None, + class_locations, overwrite_class=selected_class_or_region) + + # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the + # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. + # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size + # later + valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] + valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] + + # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. + # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to + # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also + # remove label -1 in the data augmentation but this way it is less error prone) + this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + data = data[this_slice] + + this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + seg = seg[this_slice] + + padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] + data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) + seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) + + return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} + + +if __name__ == '__main__': + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset004_Hippocampus/2d' + ds = nnUNetDataset(folder, None, 1000) # this should not load the properties! + dl = nnUNetDataLoader2D(ds, 366, (65, 65), (56, 40), 0.33, None, None) + a = next(dl) diff --git a/docker/template/src/nnunetv2/training/dataloading/data_loader_3d.py b/docker/template/src/nnunetv2/training/dataloading/data_loader_3d.py new file mode 100644 index 0000000..e8345f8 --- /dev/null +++ b/docker/template/src/nnunetv2/training/dataloading/data_loader_3d.py @@ -0,0 +1,56 @@ +import numpy as np +from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset + + +class nnUNetDataLoader3D(nnUNetDataLoaderBase): + def generate_train_batch(self): + selected_keys = self.get_indices() + # preallocate memory for data and seg + data_all = np.zeros(self.data_shape, dtype=np.float32) + seg_all = np.zeros(self.seg_shape, dtype=np.int16) + case_properties = [] + + for j, i in enumerate(selected_keys): + # oversampling foreground will improve stability of model training, especially if many patches are empty + # (Lung for example) + force_fg = self.get_do_oversample(j) + + data, seg, properties = self._data.load_case(i) + case_properties.append(properties) + + # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by + # self._data.load_case(i) (see nnUNetDataset.load_case) + shape = data.shape[1:] + dim = len(shape) + bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations']) + + # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the + # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. + # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size + # later + valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] + valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] + + # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. + # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to + # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also + # remove label -1 in the data augmentation but this way it is less error prone) + this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + data = data[this_slice] + + this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + seg = seg[this_slice] + + padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] + data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) + seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) + + return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} + + +if __name__ == '__main__': + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/3d_fullres' + ds = nnUNetDataset(folder, 0) # this should not load the properties! + dl = nnUNetDataLoader3D(ds, 5, (16, 16, 16), (16, 16, 16), 0.33, None, None) + a = next(dl) diff --git a/docker/template/src/nnunetv2/training/dataloading/nnunet_dataset.py b/docker/template/src/nnunetv2/training/dataloading/nnunet_dataset.py new file mode 100644 index 0000000..153a005 --- /dev/null +++ b/docker/template/src/nnunetv2/training/dataloading/nnunet_dataset.py @@ -0,0 +1,146 @@ +import os +from typing import List + +import numpy as np +import shutil + +from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile +from nnunetv2.training.dataloading.utils import get_case_identifiers + + +class nnUNetDataset(object): + def __init__(self, folder: str, case_identifiers: List[str] = None, + num_images_properties_loading_threshold: int = 0, + folder_with_segs_from_previous_stage: str = None): + """ + This does not actually load the dataset. It merely creates a dictionary where the keys are training case names and + the values are dictionaries containing the relevant information for that case. + dataset[training_case] -> info + Info has the following key:value pairs: + - dataset[case_identifier]['properties']['data_file'] -> the full path to the npz file associated with the training case + - dataset[case_identifier]['properties']['properties_file'] -> the pkl file containing the case properties + + In addition, if the total number of cases is < num_images_properties_loading_threshold we load all the pickle files + (containing auxiliary information). This is done for small datasets so that we don't spend too much CPU time on + reading pkl files on the fly during training. However, for large datasets storing all the aux info (which also + contains locations of foreground voxels in the images) can cause too much RAM utilization. In that + case is it better to load on the fly. + + If properties are loaded into the RAM, the info dicts each will have an additional entry: + - dataset[case_identifier]['properties'] -> pkl file content + + IMPORTANT! THIS CLASS ITSELF IS READ-ONLY. YOU CANNOT ADD KEY:VALUE PAIRS WITH nnUNetDataset[key] = value + USE THIS INSTEAD: + nnUNetDataset.dataset[key] = value + (not sure why you'd want to do that though. So don't do it) + """ + super().__init__() + # print('loading dataset') + if case_identifiers is None: + case_identifiers = get_case_identifiers(folder) + case_identifiers.sort() + + self.dataset = {} + for c in case_identifiers: + self.dataset[c] = {} + self.dataset[c]['data_file'] = join(folder, f"{c}.npz") + self.dataset[c]['properties_file'] = join(folder, f"{c}.pkl") + if folder_with_segs_from_previous_stage is not None: + self.dataset[c]['seg_from_prev_stage_file'] = join(folder_with_segs_from_previous_stage, f"{c}.npz") + + if len(case_identifiers) <= num_images_properties_loading_threshold: + for i in self.dataset.keys(): + self.dataset[i]['properties'] = load_pickle(self.dataset[i]['properties_file']) + + self.keep_files_open = ('nnUNet_keep_files_open' in os.environ.keys()) and \ + (os.environ['nnUNet_keep_files_open'].lower() in ('true', '1', 't')) + # print(f'nnUNetDataset.keep_files_open: {self.keep_files_open}') + + def __getitem__(self, key): + ret = {**self.dataset[key]} + if 'properties' not in ret.keys(): + ret['properties'] = load_pickle(ret['properties_file']) + return ret + + def __setitem__(self, key, value): + return self.dataset.__setitem__(key, value) + + def keys(self): + return self.dataset.keys() + + def __len__(self): + return self.dataset.__len__() + + def items(self): + return self.dataset.items() + + def values(self): + return self.dataset.values() + + def load_case(self, key): + entry = self[key] + if 'open_data_file' in entry.keys(): + data = entry['open_data_file'] + # print('using open data file') + elif isfile(entry['data_file'][:-4] + ".npy"): + data = np.load(entry['data_file'][:-4] + ".npy", 'r') + if self.keep_files_open: + self.dataset[key]['open_data_file'] = data + # print('saving open data file') + else: + data = np.load(entry['data_file'])['data'] + + if 'open_seg_file' in entry.keys(): + seg = entry['open_seg_file'] + # print('using open data file') + elif isfile(entry['data_file'][:-4] + "_seg.npy"): + seg = np.load(entry['data_file'][:-4] + "_seg.npy", 'r') + if self.keep_files_open: + self.dataset[key]['open_seg_file'] = seg + # print('saving open seg file') + else: + seg = np.load(entry['data_file'])['seg'] + + if 'seg_from_prev_stage_file' in entry.keys(): + if isfile(entry['seg_from_prev_stage_file'][:-4] + ".npy"): + seg_prev = np.load(entry['seg_from_prev_stage_file'][:-4] + ".npy", 'r') + else: + seg_prev = np.load(entry['seg_from_prev_stage_file'])['seg'] + seg = np.vstack((seg, seg_prev[None])) + + return data, seg, entry['properties'] + + +if __name__ == '__main__': + # this is a mini test. Todo: We can move this to tests in the future (requires simulated dataset) + + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset003_Liver/3d_lowres' + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) # this should not load the properties! + # this SHOULD HAVE the properties + ks = ds['liver_0'].keys() + assert 'properties' in ks + # amazing. I am the best. + + # this should have the properties + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=1000) + # now rename the properties file so that it does not exist anymore + shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl')) + # now we should still be able to access the properties because they have already been loaded + ks = ds['liver_0'].keys() + assert 'properties' in ks + # move file back + shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl')) + + # this should not have the properties + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) + # now rename the properties file so that it does not exist anymore + shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl')) + # now this should crash + try: + ks = ds['liver_0'].keys() + raise RuntimeError('we should not have come here') + except FileNotFoundError: + print('all good') + # move file back + shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl')) + diff --git a/docker/template/src/nnunetv2/training/dataloading/utils.py b/docker/template/src/nnunetv2/training/dataloading/utils.py new file mode 100644 index 0000000..352d182 --- /dev/null +++ b/docker/template/src/nnunetv2/training/dataloading/utils.py @@ -0,0 +1,128 @@ +from __future__ import annotations +import multiprocessing +import os +from typing import List +from pathlib import Path +from warnings import warn + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles +from nnunetv2.configuration import default_num_processes + + +def find_broken_image_and_labels( + path_to_data_dir: str | Path, +) -> tuple[set[str], set[str]]: + """ + Iterates through all numpys and tries to read them once to see if a ValueError is raised. + If so, the case id is added to the respective set and returned for potential fixing. + + :path_to_data_dir: Path/str to the preprocessed directory containing the npys and npzs. + :returns: Tuple of a set containing the case ids of the broken npy images and a set of the case ids of broken npy segmentations. + """ + content = os.listdir(path_to_data_dir) + unique_ids = [c[:-4] for c in content if c.endswith(".npz")] + failed_data_ids = set() + failed_seg_ids = set() + for unique_id in unique_ids: + # Try reading data + try: + np.load(path_to_data_dir / (unique_id + ".npy"), "r") + except ValueError: + failed_data_ids.add(unique_id) + # Try reading seg + try: + np.load(path_to_data_dir / (unique_id + "_seg.npy"), "r") + except ValueError: + failed_seg_ids.add(unique_id) + + return failed_data_ids, failed_seg_ids + + +def try_fix_broken_npy(path_do_data_dir: Path, case_ids: set[str], fix_image: bool): + """ + Receives broken case ids and tries to fix them by re-extracting the npz file (up to 5 times). + + :param case_ids: Set of case ids that are broken. + :param path_do_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + """ + for case_id in case_ids: + for i in range(5): + try: + key = "data" if fix_image else "seg" + suffix = ".npy" if fix_image else "_seg.npy" + read_npz = np.load(path_do_data_dir / (case_id + ".npz"), "r")[key] + np.save(path_do_data_dir / (case_id + suffix), read_npz) + # Try loading the just saved image. + np.load(path_do_data_dir / (case_id + suffix), "r") + break + except ValueError: + if i == 4: + raise ValueError( + f"Could not unpack {case_id + suffix} after 5 tries!" + ) + continue + + +def verify_or_stratify_npys(path_to_data_dir: str | Path) -> None: + """ + This re-reads the npy files after unpacking. Should there be a loading issue with any, it will try to unpack this file again and overwrites the existing. + If the new file does not get saved correctly 5 times, it will raise an error with the file name to the user. Does the same for images and segmentations. + :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + Otherwise an obscured error will be raised later during training (depending when the broken file is sampled) + """ + path_to_data_dir = Path(path_to_data_dir) + # Check for broken image and segmentation npys + failed_data_ids, failed_seg_ids = find_broken_image_and_labels(path_to_data_dir) + + if len(failed_data_ids) != 0 or len(failed_seg_ids) != 0: + warn( + f"Found {len(failed_data_ids)} faulty data npys and {len(failed_seg_ids)}!\n" + + f"Faulty images: {failed_data_ids}; Faulty segmentations: {failed_seg_ids})\n" + + "Trying to fix them now." + ) + # Try to fix the broken npys by reextracting the npz. If that fails, raise error + try_fix_broken_npy(path_to_data_dir, failed_data_ids, fix_image=True) + try_fix_broken_npy(path_to_data_dir, failed_seg_ids, fix_image=False) + + +def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None: + try: + a = np.load(npz_file) # inexpensive, no compression is done here. This just reads metadata + if overwrite_existing or not isfile(npz_file[:-3] + "npy"): + np.save(npz_file[:-3] + "npy", a['data']) + if unpack_segmentation and (overwrite_existing or not isfile(npz_file[:-4] + "_seg.npy")): + np.save(npz_file[:-4] + "_seg.npy", a['seg']) + except KeyboardInterrupt: + if isfile(npz_file[:-3] + "npy"): + os.remove(npz_file[:-3] + "npy") + if isfile(npz_file[:-4] + "_seg.npy"): + os.remove(npz_file[:-4] + "_seg.npy") + raise KeyboardInterrupt + + +def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False, + num_processes: int = default_num_processes): + """ + all npz files in this folder belong to the dataset, unpack them all + """ + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + npz_files = subfiles(folder, True, None, ".npz", True) + p.starmap(_convert_to_npy, zip(npz_files, + [unpack_segmentation] * len(npz_files), + [overwrite_existing] * len(npz_files)) + ) + + +def get_case_identifiers(folder: str) -> List[str]: + """ + finds all npz files in the given folder and reconstructs the training case names from them + """ + case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz") and (i.find("segFromPrevStage") == -1)] + return case_identifiers + + +if __name__ == '__main__': + unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d') \ No newline at end of file diff --git a/docker/template/src/nnunetv2/training/logging/__init__.py b/docker/template/src/nnunetv2/training/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/logging/nnunet_logger.py b/docker/template/src/nnunetv2/training/logging/nnunet_logger.py new file mode 100644 index 0000000..8409738 --- /dev/null +++ b/docker/template/src/nnunetv2/training/logging/nnunet_logger.py @@ -0,0 +1,103 @@ +import matplotlib +from batchgenerators.utilities.file_and_folder_operations import join + +matplotlib.use('agg') +import seaborn as sns +import matplotlib.pyplot as plt + + +class nnUNetLogger(object): + """ + This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems + arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a + little + + YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP + """ + def __init__(self, verbose: bool = False): + self.my_fantastic_logging = { + 'mean_fg_dice': list(), + 'ema_fg_dice': list(), + 'dice_per_class_or_region': list(), + 'train_losses': list(), + 'val_losses': list(), + 'lrs': list(), + 'epoch_start_timestamps': list(), + 'epoch_end_timestamps': list() + } + self.verbose = verbose + # shut up, this logging is great + + def log(self, key, value, epoch: int): + """ + sometimes shit gets messed up. We try to catch that here + """ + assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \ + 'This function is only intended to log stuff to lists and to have one entry per epoch' + + if self.verbose: print(f'logging {key}: {value} for epoch {epoch}') + + if len(self.my_fantastic_logging[key]) < (epoch + 1): + self.my_fantastic_logging[key].append(value) + else: + assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \ + 'lists length is off by more than 1' + print(f'maybe some logging issue!? logging {key} and {value}') + self.my_fantastic_logging[key][epoch] = value + + # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice + if key == 'mean_fg_dice': + new_ema_pseudo_dice = self.my_fantastic_logging['ema_fg_dice'][epoch - 1] * 0.9 + 0.1 * value \ + if len(self.my_fantastic_logging['ema_fg_dice']) > 0 else value + self.log('ema_fg_dice', new_ema_pseudo_dice, epoch) + + def plot_progress_png(self, output_folder): + # we infer the epoch form our internal logging + epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1 + sns.set(font_scale=2.5) + fig, ax_all = plt.subplots(3, 1, figsize=(30, 54)) + # regular progress.png as we are used to from previous nnU-Net versions + ax = ax_all[0] + ax2 = ax.twinx() + x_values = list(range(epoch + 1)) + ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4) + ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4) + ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice", + linewidth=3) + ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. avg.)", + linewidth=4) + ax.set_xlabel("epoch") + ax.set_ylabel("loss") + ax2.set_ylabel("pseudo dice") + ax.legend(loc=(0, 1)) + ax2.legend(loc=(0.2, 1)) + + # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs + # clogging up the system) + ax = ax_all[1] + ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1], + self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b', + ls='-', label="epoch duration", linewidth=4) + ylim = [0] + [ax.get_ylim()[1]] + ax.set(ylim=ylim) + ax.set_xlabel("epoch") + ax.set_ylabel("time [s]") + ax.legend(loc=(0, 1)) + + # learning rate + ax = ax_all[2] + ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4) + ax.set_xlabel("epoch") + ax.set_ylabel("learning rate") + ax.legend(loc=(0, 1)) + + plt.tight_layout() + + fig.savefig(join(output_folder, "progress.png")) + plt.close() + + def get_checkpoint(self): + return self.my_fantastic_logging + + def load_checkpoint(self, checkpoint: dict): + self.my_fantastic_logging = checkpoint diff --git a/docker/template/src/nnunetv2/training/loss/__init__.py b/docker/template/src/nnunetv2/training/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/loss/compound_losses.py b/docker/template/src/nnunetv2/training/loss/compound_losses.py new file mode 100644 index 0000000..eaeb5d8 --- /dev/null +++ b/docker/template/src/nnunetv2/training/loss/compound_losses.py @@ -0,0 +1,150 @@ +import torch +from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss +from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss +from nnunetv2.utilities.helpers import softmax_helper_dim1 +from torch import nn + + +class DC_and_CE_loss(nn.Module): + def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None, + dice_class=SoftDiceLoss): + """ + Weights for CE and Dice do not need to sum to one. You can set whatever you want. + :param soft_dice_kwargs: + :param ce_kwargs: + :param aggregate: + :param square_dice: + :param weight_ce: + :param weight_dice: + """ + super(DC_and_CE_loss, self).__init__() + if ignore_label is not None: + ce_kwargs['ignore_index'] = ignore_label + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.ignore_label = ignore_label + + self.ce = RobustCrossEntropyLoss(**ce_kwargs) + self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + """ + target must be b, c, x, y(, z) with c=1 + :param net_output: + :param target: + :return: + """ + if self.ignore_label is not None: + assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ + '(DC_and_CE_loss)' + mask = target != self.ignore_label + # remove ignore label from target, replace with one of the known labels. It doesn't matter because we + # ignore gradients in those areas anyway + target_dice = torch.where(mask, target, 0) + num_fg = mask.sum() + else: + target_dice = target + mask = None + + dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ + if self.weight_dice != 0 else 0 + ce_loss = self.ce(net_output, target[:, 0]) \ + if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 + + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result + + +class DC_and_BCE_loss(nn.Module): + def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False, + dice_class=MemoryEfficientSoftDiceLoss): + """ + DO NOT APPLY NONLINEARITY IN YOUR NETWORK! + + target mut be one hot encoded + IMPORTANT: We assume use_ignore_label is located in target[:, -1]!!! + + :param soft_dice_kwargs: + :param bce_kwargs: + :param aggregate: + """ + super(DC_and_BCE_loss, self).__init__() + if use_ignore_label: + bce_kwargs['reduction'] = 'none' + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.use_ignore_label = use_ignore_label + + self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) + self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + if self.use_ignore_label: + # target is one hot encoded here. invert it so that it is True wherever we can compute the loss + mask = (1 - target[:, -1:]).bool() + # remove ignore channel now that we have the mask + target_regions = torch.clone(target[:, :-1]) + else: + target_regions = target + mask = None + + dc_loss = self.dc(net_output, target_regions, loss_mask=mask) + if mask is not None: + ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8) + else: + ce_loss = self.ce(net_output, target_regions) + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result + + +class DC_and_topk_loss(nn.Module): + def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None): + """ + Weights for CE and Dice do not need to sum to one. You can set whatever you want. + :param soft_dice_kwargs: + :param ce_kwargs: + :param aggregate: + :param square_dice: + :param weight_ce: + :param weight_dice: + """ + super().__init__() + if ignore_label is not None: + ce_kwargs['ignore_index'] = ignore_label + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.ignore_label = ignore_label + + self.ce = TopKLoss(**ce_kwargs) + self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + """ + target must be b, c, x, y(, z) with c=1 + :param net_output: + :param target: + :return: + """ + if self.ignore_label is not None: + assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ + '(DC_and_CE_loss)' + mask = (target != self.ignore_label).bool() + # remove ignore label from target, replace with one of the known labels. It doesn't matter because we + # ignore gradients in those areas anyway + target_dice = torch.clone(target) + target_dice[target == self.ignore_label] = 0 + num_fg = mask.sum() + else: + target_dice = target + mask = None + + dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ + if self.weight_dice != 0 else 0 + ce_loss = self.ce(net_output, target) \ + if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 + + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result diff --git a/docker/template/src/nnunetv2/training/loss/deep_supervision.py b/docker/template/src/nnunetv2/training/loss/deep_supervision.py new file mode 100644 index 0000000..952e3f7 --- /dev/null +++ b/docker/template/src/nnunetv2/training/loss/deep_supervision.py @@ -0,0 +1,30 @@ +import torch +from torch import nn + + +class DeepSupervisionWrapper(nn.Module): + def __init__(self, loss, weight_factors=None): + """ + Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of + inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then + applied to each entry like this: + l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... + If weights are None, all w will be 1. + """ + super(DeepSupervisionWrapper, self).__init__() + assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" + self.weight_factors = tuple(weight_factors) + self.loss = loss + + def forward(self, *args): + assert all([isinstance(i, (tuple, list)) for i in args]), \ + f"all args must be either tuple or list, got {[type(i) for i in args]}" + # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because + # this code is executed a lot of times! + + if self.weight_factors is None: + weights = (1, ) * len(args[0]) + else: + weights = self.weight_factors + + return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) diff --git a/docker/template/src/nnunetv2/training/loss/dice.py b/docker/template/src/nnunetv2/training/loss/dice.py new file mode 100644 index 0000000..5744357 --- /dev/null +++ b/docker/template/src/nnunetv2/training/loss/dice.py @@ -0,0 +1,192 @@ +from typing import Callable + +import torch +from nnunetv2.utilities.ddp_allgather import AllGatherGrad +from torch import nn + + +class SoftDiceLoss(nn.Module): + def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., + ddp: bool = True, clip_tp: float = None): + """ + """ + super(SoftDiceLoss, self).__init__() + + self.do_bg = do_bg + self.batch_dice = batch_dice + self.apply_nonlin = apply_nonlin + self.smooth = smooth + self.clip_tp = clip_tp + self.ddp = ddp + + def forward(self, x, y, loss_mask=None): + shp_x = x.shape + + if self.batch_dice: + axes = [0] + list(range(2, len(shp_x))) + else: + axes = list(range(2, len(shp_x))) + + if self.apply_nonlin is not None: + x = self.apply_nonlin(x) + + tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) + + if self.ddp and self.batch_dice: + tp = AllGatherGrad.apply(tp).sum(0) + fp = AllGatherGrad.apply(fp).sum(0) + fn = AllGatherGrad.apply(fn).sum(0) + + if self.clip_tp is not None: + tp = torch.clip(tp, min=self.clip_tp , max=None) + + nominator = 2 * tp + denominator = 2 * tp + fp + fn + + dc = (nominator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8)) + + if not self.do_bg: + if self.batch_dice: + dc = dc[1:] + else: + dc = dc[:, 1:] + dc = dc.mean() + + return -dc + + +class MemoryEfficientSoftDiceLoss(nn.Module): + def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., + ddp: bool = True): + """ + saves 1.6 GB on Dataset017 3d_lowres + """ + super(MemoryEfficientSoftDiceLoss, self).__init__() + + self.do_bg = do_bg + self.batch_dice = batch_dice + self.apply_nonlin = apply_nonlin + self.smooth = smooth + self.ddp = ddp + + def forward(self, x, y, loss_mask=None): + if self.apply_nonlin is not None: + x = self.apply_nonlin(x) + + # make everything shape (b, c) + axes = tuple(range(2, x.ndim)) + + with torch.no_grad(): + if x.ndim != y.ndim: + y = y.view((y.shape[0], 1, *y.shape[1:])) + + if x.shape == y.shape: + # if this is the case then gt is probably already a one hot encoding + y_onehot = y + else: + y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool) + y_onehot.scatter_(1, y.long(), 1) + + if not self.do_bg: + y_onehot = y_onehot[:, 1:] + + sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes) + + # this one MUST be outside the with torch.no_grad(): context. Otherwise no gradients for you + if not self.do_bg: + x = x[:, 1:] + + if loss_mask is None: + intersect = (x * y_onehot).sum(axes) + sum_pred = x.sum(axes) + else: + intersect = (x * y_onehot * loss_mask).sum(axes) + sum_pred = (x * loss_mask).sum(axes) + + if self.batch_dice: + if self.ddp: + intersect = AllGatherGrad.apply(intersect).sum(0) + sum_pred = AllGatherGrad.apply(sum_pred).sum(0) + sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + + intersect = intersect.sum(0) + sum_pred = sum_pred.sum(0) + sum_gt = sum_gt.sum(0) + + dc = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8)) + + dc = dc.mean() + return -dc + + +def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): + """ + net_output must be (b, c, x, y(, z))) + gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) + if mask is provided it must have shape (b, 1, x, y(, z))) + :param net_output: + :param gt: + :param axes: can be (, ) = no summation + :param mask: mask must be 1 for valid pixels and 0 for invalid pixels + :param square: if True then fp, tp and fn will be squared before summation + :return: + """ + if axes is None: + axes = tuple(range(2, net_output.ndim)) + + with torch.no_grad(): + if net_output.ndim != gt.ndim: + gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) + + if net_output.shape == gt.shape: + # if this is the case then gt is probably already a one hot encoding + y_onehot = gt + else: + y_onehot = torch.zeros(net_output.shape, device=net_output.device) + y_onehot.scatter_(1, gt.long(), 1) + + tp = net_output * y_onehot + fp = net_output * (1 - y_onehot) + fn = (1 - net_output) * y_onehot + tn = (1 - net_output) * (1 - y_onehot) + + if mask is not None: + with torch.no_grad(): + mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)])) + tp *= mask_here + fp *= mask_here + fn *= mask_here + tn *= mask_here + # benchmark whether tiling the mask would be faster (torch.tile). It probably is for large batch sizes + # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram + # (using nnUNetv2_train 998 3d_fullres 0) + # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) + # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) + # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) + # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) + + if square: + tp = tp ** 2 + fp = fp ** 2 + fn = fn ** 2 + tn = tn ** 2 + + if len(axes) > 0: + tp = tp.sum(dim=axes, keepdim=False) + fp = fp.sum(dim=axes, keepdim=False) + fn = fn.sum(dim=axes, keepdim=False) + tn = tn.sum(dim=axes, keepdim=False) + + return tp, fp, fn, tn + + +if __name__ == '__main__': + from nnunetv2.utilities.helpers import softmax_helper_dim1 + pred = torch.rand((2, 3, 32, 32, 32)) + ref = torch.randint(0, 3, (2, 32, 32, 32)) + + dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) + dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) + res_old = dl_old(pred, ref) + res_new = dl_new(pred, ref) + print(res_old, res_new) diff --git a/docker/template/src/nnunetv2/training/loss/robust_ce_loss.py b/docker/template/src/nnunetv2/training/loss/robust_ce_loss.py new file mode 100644 index 0000000..3399e3a --- /dev/null +++ b/docker/template/src/nnunetv2/training/loss/robust_ce_loss.py @@ -0,0 +1,32 @@ +import torch +from torch import nn, Tensor +import numpy as np + + +class RobustCrossEntropyLoss(nn.CrossEntropyLoss): + """ + this is just a compatibility layer because my target tensor is float and has an extra dimension + + input must be logits, not probabilities! + """ + def forward(self, input: Tensor, target: Tensor) -> Tensor: + if target.ndim == input.ndim: + assert target.shape[1] == 1 + target = target[:, 0] + return super().forward(input, target.long()) + + +class TopKLoss(RobustCrossEntropyLoss): + """ + input must be logits, not probabilities! + """ + def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): + self.k = k + super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) + + def forward(self, inp, target): + target = target[:, 0].long() + res = super(TopKLoss, self).forward(inp, target) + num_voxels = np.prod(res.shape, dtype=np.int64) + res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) + return res.mean() diff --git a/docker/template/src/nnunetv2/training/lr_scheduler/__init__.py b/docker/template/src/nnunetv2/training/lr_scheduler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/lr_scheduler/polylr.py b/docker/template/src/nnunetv2/training/lr_scheduler/polylr.py new file mode 100644 index 0000000..44857b5 --- /dev/null +++ b/docker/template/src/nnunetv2/training/lr_scheduler/polylr.py @@ -0,0 +1,20 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyLRScheduler(_LRScheduler): + def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): + self.optimizer = optimizer + self.initial_lr = initial_lr + self.max_steps = max_steps + self.exponent = exponent + self.ctr = 0 + super().__init__(optimizer, current_step if current_step is not None else -1, False) + + def step(self, current_step=None): + if current_step is None or current_step == -1: + current_step = self.ctr + self.ctr += 1 + + new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr diff --git a/docker/template/src/nnunetv2/training/lr_scheduler/samedlr.py b/docker/template/src/nnunetv2/training/lr_scheduler/samedlr.py new file mode 100644 index 0000000..239417e --- /dev/null +++ b/docker/template/src/nnunetv2/training/lr_scheduler/samedlr.py @@ -0,0 +1,22 @@ +import torch +from torch.optim.lr_scheduler import _LRScheduler + +# Custom LR Scheduler Implementation +class CustomWarmupDecayLR(_LRScheduler): + def __init__(self, optimizer, warmup_period, max_iterations, base_lr, weight_decay, last_epoch=-1, verbose=False): + self.warmup_period = warmup_period + self.max_iterations = max_iterations + self.base_lr = base_lr + self.weight_decay = weight_decay + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if self.last_epoch < self.warmup_period: + return [self.base_lr * ((self.last_epoch + 1) / self.warmup_period) for _ in self.optimizer.param_groups] + else: + if self.warmup_period: + shift_iter = self.last_epoch - self.warmup_period + else: + shift_iter = self.last_epoch + return [self.base_lr * (1.0 - shift_iter / self.max_iterations) ** self.weight_decay for _ in self.optimizer.param_groups] + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py new file mode 100644 index 0000000..821a4e0 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -0,0 +1,1270 @@ +import inspect +import multiprocessing +import os +import shutil +import sys +import warnings +from copy import deepcopy +from datetime import datetime +from time import time, sleep +from typing import Union, Tuple, List + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ + ContrastAugmentationTransform, GammaTransform +from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform +from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p +from torch._dynamo import OptimizedModule + +from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.inference.sliding_window_prediction import compute_gaussian +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \ + Convert3DTo2DTransform +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger +from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from sklearn.model_selection import KFold +from torch import autocast, nn +from torch import distributed as dist +from torch.cuda import device_count +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP + + +class nnUNetTrainer(object): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + # From https://grugbrain.dev/. Worth a read ya big brains ;-) + + # apex predator of grug is complexity + # complexity bad + # say again: + # complexity very bad + # you say now: + # complexity very, very bad + # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex + # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime + # one day code base understandable and grug can get work done, everything good! + # next day impossible: complexity demon spirit has entered code and very dangerous situation! + + # OK OK I am guilty. But I tried. + # https://www.osnews.com/images/comics/wtfm.jpg + # https://i.pinimg.com/originals/26/b2/50/26b250a738ea4abc7a5af4d42ad93af0.jpg + + self.is_ddp = dist.is_available() and dist.is_initialized() + self.local_rank = 0 if not self.is_ddp else dist.get_rank() + + self.device = device + + # print what device we are using + if self.is_ddp: # implicitly it's clear that we use cuda in this case + print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is " + f"{dist.get_world_size()}." + f"Setting device to {self.device}") + self.device = torch.device(type='cuda', index=self.local_rank) + else: + if self.device.type == 'cuda': + # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X + self.device = torch.device(type='cuda', index=0) + print(f"Using device: {self.device}") + + # loading and saving this class for continuing from checkpoint should not happen based on pickling. This + # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we + # need. So let's save the init args + self.my_init_kwargs = {} + for k in inspect.signature(self.__init__).parameters.keys(): + self.my_init_kwargs[k] = locals()[k] + + ### Saving all the init args into class variables for later access + self.plans_manager = PlansManager(plans) + self.configuration_manager = self.plans_manager.get_configuration(configuration) + self.configuration_name = configuration + self.dataset_json = dataset_json + self.fold = fold + self.unpack_dataset = unpack_dataset + + ### Setting all the folder names. We need to make sure things don't crash in case we are just running + # inference and some of the folders may not be defined! + self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ + if nnUNet_preprocessed is not None else None + self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ + if nnUNet_results is not None else None + self.output_folder = join(self.output_folder_base, f'fold_{fold}') + + self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, + self.configuration_manager.data_identifier) + # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to + # be a different configuration in the same plans + # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using + # "previous_stage" and "next_stage"). Otherwise it won't work! + self.is_cascaded = self.configuration_manager.previous_stage_name is not None + self.folder_with_segs_from_previous_stage = \ + join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ + if self.is_cascaded else None + + ### Some hyperparameters for you to fiddle with + self.initial_lr = 1e-2 + self.weight_decay = 3e-5 + self.oversample_foreground_percent = 0.33 + self.num_iterations_per_epoch = 250 + self.num_val_iterations_per_epoch = 50 + self.num_epochs = 1000 + self.current_epoch = 0 + self.enable_deep_supervision = True + + ### Dealing with labels/regions + self.label_manager = self.plans_manager.get_label_manager(dataset_json) + # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) + # needed for predictions. We do sigmoid in case of (overlapping) regions + + self.num_input_channels = None # -> self.initialize() + self.network = None # -> self._get_network() + self.optimizer = self.lr_scheduler = None # -> self.initialize + self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None + self.loss = None # -> self.initialize + + ### Simple logging. Don't take that away from me! + # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning + # logging + timestamp = datetime.now() + maybe_mkdir_p(self.output_folder) + self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % + (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, + timestamp.second)) + self.logger = nnUNetLogger() + + ### placeholders + self.dataloader_train = self.dataloader_val = None # see on_train_start + + ### initializing stuff for remembering things and such + self._best_ema = None + + ### inference things + self.inference_allowed_mirroring_axes = None # this variable is set in + # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints + + ### checkpoint saving stuff + self.save_every = 50 + self.disable_checkpointing = False + + ## DDP batch size and oversampling can differ between workers and needs adaptation + # we need to change the batch size in DDP because we don't use any of those distributed samplers + self._set_batch_size_and_oversample() + + self.was_initialized = False + + self.print_to_log_file("\n#######################################################################\n" + "Please cite the following paper when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n" + "#######################################################################\n", + also_print_to_console=True, add_timestamp=False) + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture( + self.plans_manager, + self.dataset_json, + self.configuration_manager, + self.num_input_channels, + self.enable_deep_supervision, + ).to(self.device) + # compile network for free speedup + if self._do_i_compile(): + self.print_to_log_file('Using torch.compile...') + self.network = torch.compile(self.network) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " + "That should not happen.") + + def _do_i_compile(self): + return ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + + def _save_debug_information(self): + # saving some debug information + if self.local_rank == 0: + dct = {} + for k in self.__dir__(): + if not k.startswith("__"): + if not callable(getattr(self, k)) or k in ['loss', ]: + dct[k] = str(getattr(self, k)) + elif k in ['network', ]: + dct[k] = str(getattr(self, k).__class__.__name__) + else: + # print(k) + pass + if k in ['dataloader_train', 'dataloader_val']: + if hasattr(getattr(self, k), 'generator'): + dct[k + '.generator'] = str(getattr(self, k).generator) + if hasattr(getattr(self, k), 'num_processes'): + dct[k + '.num_processes'] = str(getattr(self, k).num_processes) + if hasattr(getattr(self, k), 'transform'): + dct[k + '.transform'] = str(getattr(self, k).transform) + import subprocess + hostname = subprocess.getoutput(['hostname']) + dct['hostname'] = hostname + torch_version = torch.__version__ + if self.device.type == 'cuda': + gpu_name = torch.cuda.get_device_name() + dct['gpu_name'] = gpu_name + cudnn_version = torch.backends.cudnn.version() + else: + cudnn_version = 'None' + dct['device'] = str(self.device) + dct['torch_version'] = torch_version + dct['cudnn_version'] = cudnn_version + save_json(dct, join(self.output_folder, "debug.json")) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + """ + This is where you build the architecture according to the plans. There is no obligation to use + get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what + you want. Even ignore the plans and just return something static (as long as it can process the requested + patch size) + but don't bug us with your bugs arising from fiddling with this :-P + This is the function that is called in inference as well! This is needed so that all network architecture + variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for + training, so if you change the network architecture during training by deriving a new trainer class then + inference will know about it). + + If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: + > label_manager = plans_manager.get_label_manager(dataset_json) + > label_manager.num_segmentation_heads + (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, + the number of outputs is != the number of classes. Also there is the ignore label for which no output + should be generated. label_manager takes care of all that for you.) + + """ + return get_network_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + def _get_deep_supervision_scales(self): + if self.enable_deep_supervision: + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + else: + deep_supervision_scales = None # for train and val_transforms + return deep_supervision_scales + + def _set_batch_size_and_oversample(self): + if not self.is_ddp: + # set batch size to what the plan says, leave oversample untouched + self.batch_size = self.configuration_manager.batch_size + else: + # batch size is distributed over DDP workers and we need to change oversample_percent for each worker + batch_sizes = [] + oversample_percents = [] + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + global_batch_size = self.configuration_manager.batch_size + assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ + 'GPUs... Duh.' + + batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) + + for rank in range(world_size): + if (rank + 1) * batch_size_per_GPU > global_batch_size: + batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) + else: + batch_size = batch_size_per_GPU + + batch_sizes.append(batch_size) + + sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) + sample_id_high = np.sum(batch_sizes) + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percents.append(0.0) + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percents.append(1.0) + else: + percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size + oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - + sample_id_low / global_batch_size) / percent_covered_by_this_rank) + oversample_percents.append(oversample_percent_here) + + print("worker", my_rank, "oversample", oversample_percents[my_rank]) + print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) + # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) + + self.batch_size = batch_sizes[my_rank] + self.oversample_foreground_percent = oversample_percents[my_rank] + + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.85, 1.25)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): + if self.local_rank == 0: + timestamp = time() + dt_object = datetime.fromtimestamp(timestamp) + + if add_timestamp: + args = (f"{dt_object}:", *args) + + successful = False + max_attempts = 5 + ctr = 0 + while not successful and ctr < max_attempts: + try: + with open(self.log_file, 'a+') as f: + for a in args: + f.write(str(a)) + f.write(" ") + f.write("\n") + successful = True + except IOError: + print(f"{datetime.fromtimestamp(timestamp)}: failed to log: ", sys.exc_info()) + sleep(0.5) + ctr += 1 + if also_print_to_console: + print(*args) + elif also_print_to_console: + print(*args) + + def print_plans(self): + if self.local_rank == 0: + dct = deepcopy(self.plans_manager.plans) + del dct['configurations'] + self.print_to_log_file(f"\nThis is the configuration used by this " + f"training:\nConfiguration name: {self.configuration_name}\n", + self.configuration_manager, '\n', add_timestamp=False) + self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + def plot_network_architecture(self): + if self._do_i_compile(): + self.print_to_log_file("Unable to plot network architecture: nnUNet_compile is enabled!") + return + + if self.local_rank == 0: + try: + # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') + # pip install git+https://github.com/saugatkandel/hiddenlayer.git + + # from torchviz import make_dot + # # not viable. + # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, + # *self.configuration_manager.patch_size), + # device=self.device)))).render( + # join(self.output_folder, "network_architecture.pdf"), format='pdf') + # self.optimizer.zero_grad() + + # broken. + + import hiddenlayer as hl + g = hl.build_graph(self.network, + torch.rand((1, self.num_input_channels, + *self.configuration_manager.patch_size), + device=self.device), + transforms=None) + g.save(join(self.output_folder, "network_architecture.pdf")) + del g + except Exception as e: + self.print_to_log_file("Unable to plot network architecture:") + self.print_to_log_file(e) + + # self.print_to_log_file("\nprinting the network instead:\n") + # self.print_to_log_file(self.network) + # self.print_to_log_file("\n") + finally: + empty_cache(self.device) + + def do_split(self): + """ + The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, + so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + Sometimes you may want to create your own split for various reasons. For this you will need to create your own + splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) + and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to + use a random 80:20 data split. + :return: + """ + if self.fold == "all": + # if fold==all then we use all images for training and validation + case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder) + tr_keys = case_identifiers + val_keys = tr_keys + else: + splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") + dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None, + num_images_properties_loading_threshold=0, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) + # if the split file does not exist we need to create it + if not isfile(splits_file): + self.print_to_log_file("Creating new 5-fold cross-validation split...") + splits = [] + all_keys_sorted = np.sort(list(dataset.keys())) + kfold = KFold(n_splits=5, shuffle=True, random_state=12345) + for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): + train_keys = np.array(all_keys_sorted)[train_idx] + test_keys = np.array(all_keys_sorted)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + save_json(splits, splits_file) + + else: + self.print_to_log_file("Using splits from existing split file:", splits_file) + splits = load_json(splits_file) + self.print_to_log_file(f"The split file contains {len(splits)} splits.") + + self.print_to_log_file("Desired fold for training: %d" % self.fold) + if self.fold < len(splits): + tr_keys = splits[self.fold]['train'] + val_keys = splits[self.fold]['val'] + self.print_to_log_file("This split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + else: + self.print_to_log_file("INFO: You requested fold %d for training but splits " + "contain only %d folds. I am now creating a " + "random (but seeded) 80:20 split!" % (self.fold, len(splits))) + # if we request a fold that is not in the split file, create a random 80:20 split + rnd = np.random.RandomState(seed=12345 + self.fold) + keys = np.sort(list(dataset.keys())) + idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) + idx_val = [i for i in range(len(keys)) if i not in idx_tr] + tr_keys = [keys[i] for i in idx_tr] + val_keys = [keys[i] for i in idx_val] + self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + if any([i in val_keys for i in tr_keys]): + self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' + 'splits.json or ignore if this is intentional.') + return tr_keys, val_keys + + def get_tr_and_val_datasets(self): + # create dataset split + tr_keys, val_keys = self.do_split() + + # load the datasets for training and validation. Note that we always draw random samples so we really don't + # care about distributing training cases across GPUs. + dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + return dataset_tr, dataset_val + + def get_dataloaders(self): + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + + deep_supervision_scales = self._get_deep_supervision_scales() + + ( + rotation_for_DA, + do_dummy_2d_data_aug, + initial_patch_size, + mirror_axes, + ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=1, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, + num_processes=allowed_num_processes, num_cached=6, seeds=None, + pin_memory=self.device.type == 'cuda', wait_time=0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, + transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', + wait_time=0.02) + return mt_gen_train, mt_gen_val + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + return dl_tr, dl_val + + @staticmethod + def get_training_transforms( + patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: + tr_transforms = [] + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append(SpatialTransform( + patch_size_spatial, patch_center_dist_from_border=None, + do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0), + do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'], + p_rot_per_axis=1, # todo experiment with this + do_scale=True, scale=(0.7, 1.4), + border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data, + border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg, + random_crop=False, # random cropping is part of our dataloaders + p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2, + independent_scale_for_each_axis=False # todo experiment with this + )) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, + p_per_channel=0.5)) + tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) + tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) + tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, + p_per_channel=0.5, + order_downsample=0, order_upsample=3, p_per_sample=0.25, + ignore_axes=ignore_axes)) + tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' + tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + + @staticmethod + def get_validation_transforms( + deep_supervision_scales: Union[List, Tuple, None], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: + val_transforms = [] + val_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + + val_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + + val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + val_transforms = Compose(val_transforms) + return val_transforms + + def set_deep_supervision_enabled(self, enabled: bool): + """ + This function is specific for the default architecture in nnU-Net. If you change the architecture, there are + chances you need to change this as well! + """ + if self.is_ddp: + self.network.module.decoder.deep_supervision = enabled + else: + self.network.decoder.deep_supervision = enabled + + def on_train_start(self): + if not self.was_initialized: + self.initialize() + + maybe_mkdir_p(self.output_folder) + + # make sure deep supervision is on in the network + self.set_deep_supervision_enabled(self.enable_deep_supervision) + + self.print_plans() + empty_cache(self.device) + + # maybe unpack + if self.unpack_dataset and self.local_rank == 0: + self.print_to_log_file('unpacking dataset...') + unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, + num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) + self.print_to_log_file('unpacking done...') + + if self.is_ddp: + dist.barrier() + + # dataloaders must be instantiated here because they need access to the training data which may not be present + # when doing inference + self.dataloader_train, self.dataloader_val = self.get_dataloaders() + + # copy plans and dataset.json so that they can be used for restoring everything we need for inference + save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) + save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) + + # we don't really need the fingerprint but its still handy to have it with the others + shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), + join(self.output_folder_base, 'dataset_fingerprint.json')) + + # produces a pdf in output folder + self.plot_network_architecture() + + self._save_debug_information() + + # print(f"batch size: {self.batch_size}") + # print(f"oversample: {self.oversample_foreground_percent}") + + def on_train_end(self): + # dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards. + # This will lead to the wrong current epoch to be stored + self.current_epoch -= 1 + self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) + self.current_epoch += 1 + + # now we can delete latest + if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): + os.remove(join(self.output_folder, "checkpoint_latest.pth")) + + # shut down dataloaders + old_stdout = sys.stdout + with open(os.devnull, 'w') as f: + sys.stdout = f + if self.dataloader_train is not None: + self.dataloader_train._finish() + if self.dataloader_val is not None: + self.dataloader_val._finish() + sys.stdout = old_stdout + + empty_cache(self.device) + self.print_to_log_file("Training done.") + + def on_train_epoch_start(self): + self.network.train() + self.lr_scheduler.step(self.current_epoch) + self.print_to_log_file('') + self.print_to_log_file(f'Epoch {self.current_epoch}') + self.print_to_log_file( + f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") + # lrs are the same for all workers so we don't need to gather them in case of DDP training + self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + # del data + l = self.loss(output, target) + + if self.grad_scaler is not None: + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + return {'loss': l.detach().cpu().numpy()} + + def on_train_epoch_end(self, train_outputs: List[dict]): + outputs = collate_outputs(train_outputs) + + if self.is_ddp: + losses_tr = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(losses_tr, outputs['loss']) + loss_here = np.vstack(losses_tr).mean() + else: + loss_here = np.mean(outputs['loss']) + + self.logger.log('train_losses', loss_here, self.current_epoch) + + def on_validation_epoch_start(self): + self.network.eval() + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # we only need the output with the highest output resolution (if DS enabled) + if self.enable_deep_supervision: + output = output[0] + target = target[0] + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def on_validation_epoch_end(self, val_outputs: List[dict]): + outputs_collated = collate_outputs(val_outputs) + tp = np.sum(outputs_collated['tp_hard'], 0) + fp = np.sum(outputs_collated['fp_hard'], 0) + fn = np.sum(outputs_collated['fn_hard'], 0) + + if self.is_ddp: + world_size = dist.get_world_size() + + tps = [None for _ in range(world_size)] + dist.all_gather_object(tps, tp) + tp = np.vstack([i[None] for i in tps]).sum(0) + + fps = [None for _ in range(world_size)] + dist.all_gather_object(fps, fp) + fp = np.vstack([i[None] for i in fps]).sum(0) + + fns = [None for _ in range(world_size)] + dist.all_gather_object(fns, fn) + fn = np.vstack([i[None] for i in fns]).sum(0) + + losses_val = [None for _ in range(world_size)] + dist.all_gather_object(losses_val, outputs_collated['loss']) + loss_here = np.vstack(losses_val).mean() + else: + loss_here = np.mean(outputs_collated['loss']) + + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]] + mean_fg_dice = np.nanmean(global_dc_per_class) + self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) + self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) + self.logger.log('val_losses', loss_here, self.current_epoch) + + def on_epoch_start(self): + self.logger.log('epoch_start_timestamps', time(), self.current_epoch) + + def on_epoch_end(self): + self.logger.log('epoch_end_timestamps', time(), self.current_epoch) + + self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) + self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) + self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in + self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) + self.print_to_log_file( + f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") + + # handling periodic checkpointing + current_epoch = self.current_epoch + if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): + self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) + + # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this + if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: + self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] + self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") + self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) + + if self.local_rank == 0: + self.logger.plot_progress_png(self.output_folder) + + self.current_epoch += 1 + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.state_dict(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. + if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_state_dict(new_state_dict) + else: + self.network.module.load_state_dict(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_state_dict(new_state_dict) + else: + self.network.load_state_dict(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + def perform_actual_validation(self, save_probabilities: bool = False): + self.set_deep_supervision_enabled(False) + self.network.eval() + + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + perform_everything_on_device=True, device=self.device, verbose=False, + verbose_preprocessing=False, allow_tqdm=False) + predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, + self.dataset_json, self.__class__.__name__, + self.inference_allowed_mirroring_axes) + + with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: + worker_list = [i for i in segmentation_export_pool._pool] + validation_output_folder = join(self.output_folder, 'validation') + maybe_mkdir_p(validation_output_folder) + + # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute + # the validation keys across the workers. + _, val_keys = self.do_split() + if self.is_ddp: + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + + next_stages = self.configuration_manager.next_stage_names + + if next_stages is not None: + _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] + + results = [] + + for k in dataset_val.keys(): + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + while not proceed: + sleep(0.1) + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + + self.print_to_log_file(f"predicting {k}") + data, seg, properties = dataset_val.load_case(k) + + if self.is_cascaded: + data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, + output_dtype=data.dtype))) + with warnings.catch_warnings(): + # ignore 'The given NumPy array is not writable' warning + warnings.simplefilter("ignore") + data = torch.from_numpy(data) + + output_filename_truncated = join(validation_output_folder, k) + + try: + prediction = predictor.predict_sliding_window_return_logits(data) + except RuntimeError: + predictor.perform_everything_on_device = False + prediction = predictor.predict_sliding_window_return_logits(data) + predictor.perform_everything_on_device = True + + prediction = prediction.cpu() + + # this needs to go into background processes + results.append( + segmentation_export_pool.starmap_async( + export_prediction_from_logits, ( + (prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, output_filename_truncated, save_probabilities), + ) + ) + ) + # for debug purposes + # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json, + # output_filename_truncated, save_probabilities) + + # if needed, export the softmax prediction for the next stage + if next_stages is not None: + for n in next_stages: + next_stage_config_manager = self.plans_manager.get_configuration(n) + expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, + next_stage_config_manager.data_identifier) + + try: + # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented + tmp = nnUNetDataset(expected_preprocessed_folder, [k], + num_images_properties_loading_threshold=0) + d, s, p = tmp.load_case(k) + except FileNotFoundError: + self.print_to_log_file( + f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " + f"Run the preprocessing for this configuration first!") + continue + + target_shape = d.shape[1:] + output_folder = join(self.output_folder_base, 'predicted_next_stage', n) + output_file = join(output_folder, k + '.npz') + + # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, + # self.dataset_json) + results.append(segmentation_export_pool.starmap_async( + resample_and_save, ( + (prediction, target_shape, output_file, self.plans_manager, + self.configuration_manager, + properties, + self.dataset_json), + ) + )) + + _ = [r.get() for r in results] + + if self.is_ddp: + dist.barrier() + + if self.local_rank == 0: + metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), + validation_output_folder, + join(validation_output_folder, 'summary.json'), + self.plans_manager.image_reader_writer_class(), + self.dataset_json["file_ending"], + self.label_manager.foreground_regions if self.label_manager.has_regions else + self.label_manager.foreground_labels, + self.label_manager.ignore_label, chill=True) + self.print_to_log_file("Validation complete", also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + + self.set_deep_supervision_enabled(True) + compute_gaussian.cache_clear() + + def run_training(self): + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(next(self.dataloader_train))) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(next(self.dataloader_val))) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerLightMUNet.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerLightMUNet.py new file mode 100644 index 0000000..ae41d42 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerLightMUNet.py @@ -0,0 +1,141 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from torch import nn +import torch + +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn + +from nnunetv2.nets.LightMUNet import LightMUNet +from torch.optim import Adam + +class nnUNetTrainerLightMUNet(nnUNetTrainerNoDeepSupervision): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device('cuda') + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.grad_scaler = None + self.initial_lr = 1e-4 + self.weight_decay = 1e-5 + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = LightMUNet( + spatial_dims = len(configuration_manager.patch_size), + init_filters = 32, + in_channels=num_input_channels, + out_channels=label_manager.num_segmentation_heads, + blocks_down=[1, 2, 2, 4], + blocks_up=[1, 1, 1], + ) + + return model + + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + l = self.loss(output, target) + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + del data + l = self.loss(output, target) + + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def configure_optimizers(self): + + optimizer = Adam(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, eps=1e-5) + scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=0.9) + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + + +class nnUNetTrainerLightMUNet_100epochs(nnUNetTrainerLightMUNet): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device('cuda') + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerMedNext.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerMedNext.py new file mode 100644 index 0000000..c4051a4 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerMedNext.py @@ -0,0 +1,259 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn +import torch +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch import nn +from nnunetv2.nets.mednextv1.MedNextV1 import MedNeXt + + +class nnUNetTrainerMedNext(nnUNetTrainerNoDeepSupervision): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + original_patch_size = self.configuration_manager.patch_size + new_patch_size = [-1] * len(original_patch_size) + for i in range(len(original_patch_size)): + if (original_patch_size[i] / 2**5) < 1 or ((original_patch_size[i] / 2**5) % 1) != 0: + new_patch_size[i] = round(original_patch_size[i] / 2**5 + 0.5) * 2**5 + else: + new_patch_size[i] = original_patch_size[i] + self.configuration_manager.configuration['patch_size'] = new_patch_size + self.print_to_log_file("Patch size changed from {} to {}".format(original_patch_size, new_patch_size)) + self.plans_manager.plans['configurations'][self.configuration_name]['patch_size'] = new_patch_size + + self.grad_scaler = None + self.initial_lr = 1e-3 + self.weight_decay = 0.01 + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + l = self.loss(output, target) + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + output = self.network(data) + del data + l = self.loss(output, target) + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def configure_optimizers(self): + + optimizer = AdamW(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, eps=1e-5) + scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs, eta_min=1e-6) + + self.print_to_log_file(f"Using optimizer {optimizer}") + self.print_to_log_file(f"Using scheduler {scheduler}") + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + +class nnUNetTrainerV2_MedNeXt_L_kernel5(nnUNetTrainerMedNext): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = create_mednextv1_large(num_input_channels, label_manager.num_segmentation_heads, 5, False) + + return model +class nnUNetTrainerV2_MedNeXt_L_kernel5_100epochs(nnUNetTrainerV2_MedNeXt_L_kernel5): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 + +class nnUNetTrainerV2_MedNeXt_B_kernel5(nnUNetTrainerMedNext): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = create_mednextv1_base(num_input_channels, label_manager.num_segmentation_heads, 5, False) + + return model + +class nnUNetTrainerV2_MedNeXt_B_kernel5_100epochs(nnUNetTrainerV2_MedNeXt_B_kernel5): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 + + + +class nnUNetTrainerV2_MedNeXt_M_kernel5(nnUNetTrainerMedNext): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = create_mednextv1_medium(num_input_channels, label_manager.num_segmentation_heads, 5, False) + + return model + +def create_mednextv1_small(num_input_channels, num_classes, kernel_size=3, ds=False): + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=2, + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down=True, + block_counts=[2, 2, 2, 2, 2, 2, 2, 2, 2] + ) + + +def create_mednextv1_base(num_input_channels, num_classes, kernel_size=3, ds=False): + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=[2, 3, 4, 4, 4, 4, 4, 3, 2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down=True, + block_counts=[2, 2, 2, 2, 2, 2, 2, 2, 2] + ) + + +def create_mednextv1_medium(num_input_channels, num_classes, kernel_size=3, ds=False): + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=[2, 3, 4, 4, 4, 4, 4, 3, 2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down=True, + block_counts=[3, 4, 4, 4, 4, 4, 4, 4, 3], + checkpoint_style='outside_block' + ) + + +def create_mednextv1_large(num_input_channels, num_classes, kernel_size=3, ds=False): + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=[3, 4, 8, 8, 8, 8, 8, 4, 3], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down=True, + block_counts=[3, 4, 8, 8, 8, 8, 8, 4, 3], + checkpoint_style='outside_block' + ) + + +def create_mednext_v1(num_input_channels, num_classes, model_id, kernel_size=3, + deep_supervision=False): + model_dict = { + 'S': create_mednextv1_small, + 'B': create_mednextv1_base, + 'M': create_mednextv1_medium, + 'L': create_mednextv1_large, + } + + return model_dict[model_id]( + num_input_channels, num_classes, kernel_size, deep_supervision + ) diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSAMed.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSAMed.py new file mode 100644 index 0000000..e7b08f7 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSAMed.py @@ -0,0 +1,306 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn +import torch +from torch.optim import AdamW +from torch import nn +from nnunetv2.nets.sam_lora_image_encoder import LoRA_Sam +from nnunetv2.nets.segment_anything.modeling.mask_decoder import MaskDecoder +from nnunetv2.nets.segment_anything import sam_model_registry +from nnunetv2.training.lr_scheduler.samedlr import CustomWarmupDecayLR +from monai.transforms import ( + Resize, + +) +from torch._dynamo import OptimizedModule + +from typing import Union + + +class nnUNetTrainerSAMed(nnUNetTrainerNoDeepSupervision): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + original_patch_size = self.configuration_manager.patch_size + new_patch_size = [-1] * len(original_patch_size) + for i in range(len(original_patch_size)): + if (original_patch_size[i] / 2 ** 5) < 1 or ((original_patch_size[i] / 2 ** 5) % 1) != 0: + new_patch_size[i] = round(original_patch_size[i] / 2 ** 5 + 0.5) * 2 ** 5 + else: + new_patch_size[i] = original_patch_size[i] + self.configuration_manager.configuration['patch_size'] = new_patch_size + self.print_to_log_file("Patch size changed from {} to {}".format(original_patch_size, new_patch_size)) + self.plans_manager.plans['configurations'][self.configuration_name]['patch_size'] = new_patch_size + self.initial_lr = 1e-3 + self.weight_decay = 0.01 + self.lr_decay=0.9 + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + low_res_label_batch = [self.resize(i.to(self.device, non_blocking=True).squeeze()) for i in target] + else: + target = target.to(self.device, non_blocking=True) + low_res_label_batch = self.resize(target.squeeze()) + + self.optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True): + outputs = self.network(data, True, self.patch_size) + # print(outputs['low_res_logits'].size(), low_res_label_batch.size(),self.label_manager.has_regions) + # print(torch.unique(low_res_label_batch),) + l = self.loss(outputs['low_res_logits'], low_res_label_batch.unsqueeze(1)) + + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + + + return {'loss': l.detach().cpu().numpy()} + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + + if isinstance(target, list): + low_res_label_batch = [self.resize(i.to(self.device, non_blocking=True).squeeze()) for i in target] + else: + target = target.to(self.device, non_blocking=True) + low_res_label_batch = self.resize(target.squeeze()) + + self.optimizer.zero_grad(set_to_none=True) + + # Autocast is a little ****. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + output = self.network(data,True, self.patch_size) + del data + + l = self.loss(output['low_res_logits'], low_res_label_batch.unsqueeze(1)) + output_masks = output['masks'] + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output_masks.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output_masks) > 0.5).long() + else: + # no need for softmax + output_seg = output_masks.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output_masks.shape, device=output_masks.device, + dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + # def calc_loss(self,outputs, low_res_label_batch, ce_loss, dice_loss, dice_weight: float = 0.8): + # low_res_logits = outputs['low_res_logits'] + # loss_ce = ce_loss(low_res_logits, low_res_label_batch.long()) + # loss_dice = dice_loss(low_res_logits, low_res_label_batch, softmax=True) + # loss = (1 - dice_weight) * loss_ce + dice_weight * loss_dice + # return loss, loss_ce, loss_dice + + # %% + + def configure_optimizers(self): + + # Custom scheduler setup + optimizer = AdamW(filter(lambda p: p.requires_grad, self.network.parameters()), lr=self.initial_lr, + betas=(0.9, 0.999), + weight_decay=0.1) + scheduler = CustomWarmupDecayLR(optimizer, warmup_period=10, max_iterations=self.num_epochs, + base_lr=self.initial_lr, weight_decay=self.lr_decay) + + self.print_to_log_file(f"Using optimizer {optimizer}") + self.print_to_log_file(f"Using scheduler {scheduler}") + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.get_lora_parameters(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. + if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_lora_parameters(new_state_dict) + else: + self.network.module.load_lora_parameters(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_lora_parameters(new_state_dict) + else: + self.network.load_lora_parameters(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + +class nnUNetTrainerV2_SAMed_h_r_4(nnUNetTrainerSAMed): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.patch_size = 512 + self.resize = Resize(spatial_size=(128, 128), mode='nearest') + # self.configuration_manager.patch_size=[self.patch_size, self.patch_size] + self.lr_decay=7 + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + + sam, img_embedding_size = sam_model_registry['vit_h'](image_size=512, + num_classes=8, # To load LoRA weights + checkpoint='checkpoints/sam_vit_h_4b8939.pth', + pixel_mean=[0, 0, 0], + pixel_std=[1, 1, 1]) + model = LoRA_Sam(sam, 4) + # net.load_lora_parameters('checkpoints/epoch_299.pth') + model.sam.mask_decoder = MaskDecoder(transformer=model.sam.mask_decoder.transformer, + transformer_dim=model.sam.mask_decoder.transformer_dim, + num_multimask_outputs=label_manager.num_segmentation_heads-1 #remove bg + ) + + return model + +class nnUNetTrainerV2_SAMed_h_r_4_100epochs(nnUNetTrainerV2_SAMed_h_r_4): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + + self.num_epochs = 100 + +class nnUNetTrainerV2_SAMed_b_r_4(nnUNetTrainerSAMed): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.patch_size = 256 + self.resize = Resize(spatial_size=(64, 64), mode='nearest') + + # self.configuration_manager.patch_size=[self.patch_size, self.patch_size] + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + + sam, img_embedding_size = sam_model_registry['vit_b'](image_size=256, + num_classes=8, # To load LoRA weights + checkpoint='checkpoints/sam_vit_b_01ec64.pth', + pixel_mean=[0, 0, 0], + pixel_std=[1, 1, 1]) + model = LoRA_Sam(sam, 4) + # net.load_lora_parameters('checkpoints/epoch_299.pth') + model.sam.mask_decoder = MaskDecoder(transformer=model.sam.mask_decoder.transformer, + transformer_dim=model.sam.mask_decoder.transformer_dim, + num_multimask_outputs=label_manager.num_segmentation_heads-1 #remove bg + ) + return model + +class nnUNetTrainerV2_SAMed_b_r_4_100epochs(nnUNetTrainerV2_SAMed_b_r_4): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSegResNet.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSegResNet.py new file mode 100644 index 0000000..5f9b8f7 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSegResNet.py @@ -0,0 +1,154 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from torch import nn +import torch + +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn + +from monai.networks.nets import SegResNet +from torch.optim import Adam + +class nnUNetTrainerSegResNet(nnUNetTrainerNoDeepSupervision): + + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device('cuda') + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.grad_scaler = None + self.initial_lr = 1e-4 + self.weight_decay = 1e-5 + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = SegResNet( + spatial_dims = len(configuration_manager.patch_size), + init_filters = 32, + in_channels=num_input_channels, + out_channels=label_manager.num_segmentation_heads, + blocks_down=[1, 2, 2, 4], + blocks_up=[1, 1, 1], + ) + + return model + + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + l = self.loss(output, target) + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + output = self.network(data) + del data + l = self.loss(output, target) + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def configure_optimizers(self): + + optimizer = Adam(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, eps=1e-5) + scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=0.9) + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + +class nnUNetTrainerSegResNet_100epochs(nnUNetTrainerSegResNet): + + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device('cuda') + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSwinUNETR.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSwinUNETR.py new file mode 100644 index 0000000..919a3d4 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerSwinUNETR.py @@ -0,0 +1,159 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn +import torch +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch import nn + +from monai.networks.nets import SwinUNETR + +class nnUNetTrainerSwinUNETR(nnUNetTrainerNoDeepSupervision): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + original_patch_size = self.configuration_manager.patch_size + new_patch_size = [-1] * len(original_patch_size) + for i in range(len(original_patch_size)): + if (original_patch_size[i] / 2**5) < 1 or ((original_patch_size[i] / 2**5) % 1) != 0: + new_patch_size[i] = round(original_patch_size[i] / 2**5 + 0.5) * 2**5 + else: + new_patch_size[i] = original_patch_size[i] + self.configuration_manager.configuration['patch_size'] = new_patch_size + self.print_to_log_file("Patch size changed from {} to {}".format(original_patch_size, new_patch_size)) + self.plans_manager.plans['configurations'][self.configuration_name]['patch_size'] = new_patch_size + + self.grad_scaler = None + self.initial_lr = 8e-4 + self.weight_decay = 0.01 + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = SwinUNETR( + in_channels = num_input_channels, + out_channels = label_manager.num_segmentation_heads, + img_size = configuration_manager.patch_size, + depths = (2, 2, 2, 2), + num_heads = (3, 6, 12, 24), + feature_size = 48, ## + norm_name = "instance", + drop_rate = 0.0, + attn_drop_rate = 0.0, + dropout_path_rate = 0.0, + normalize = True, + use_checkpoint = False, + spatial_dims = len(configuration_manager.patch_size), + downsample = "merging", + use_v2 = False, + ) + + return model + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + l = self.loss(output, target) + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + output = self.network(data) + del data + l = self.loss(output, target) + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def configure_optimizers(self): + + optimizer = AdamW(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, eps=1e-5) + scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs, eta_min=1e-6) + + self.print_to_log_file(f"Using optimizer {optimizer}") + self.print_to_log_file(f"Using scheduler {scheduler}") + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + +class nnUNetTrainerSwinUNETR_100epochs(nnUNetTrainerSwinUNETR): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaBot.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaBot.py new file mode 100644 index 0000000..b8aa30a --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaBot.py @@ -0,0 +1,30 @@ +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn +from nnunetv2.nets.UMambaBot import get_umamba_bot_from_plans +import torch + +class nnUNetTrainerUMambaBot(nnUNetTrainer): + """ + Residual Encoder + UMmaba Bottleneck + Residual Decoder + Skip Connections + """ + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + + model = get_umamba_bot_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + print("UMambaBot: {}".format(model)) + + return model + +class nnUNetTrainerUMambaBot_100epochs(nnUNetTrainerUMambaBot): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaEnc.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaEnc.py new file mode 100644 index 0000000..a0347ad --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaEnc.py @@ -0,0 +1,28 @@ +import torch +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn + +from nnunetv2.nets.UMambaEnc import get_umamba_enc_from_plans + +class nnUNetTrainerUMambaEnc(nnUNetTrainer): + """ + UMmaba Encoder + Residual Decoder + Skip Connections + """ + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + + model = get_umamba_enc_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + print("UMambaEnc: {}".format(model)) + + return model diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUNETR.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUNETR.py new file mode 100644 index 0000000..ad82fea --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUNETR.py @@ -0,0 +1,149 @@ +from nnunetv2.training.nnUNetTrainer.variants.network_architecture.nnUNetTrainerNoDeepSupervision import \ + nnUNetTrainerNoDeepSupervision +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn + +import torch +from torch.optim import AdamW +from torch import nn + +from monai.networks.nets import UNETR + +class nnUNetTrainerUNETR(nnUNetTrainerNoDeepSupervision): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + original_patch_size = self.configuration_manager.patch_size + new_patch_size = [-1] * len(original_patch_size) + for i in range(len(original_patch_size)): + ## 16 is ViT's fixed patch size + if (original_patch_size[i] / 16) < 1 or ((original_patch_size[i] / 16) % 1) != 0: + new_patch_size[i] = round(original_patch_size[i] / 16 + 0.5) * 16 + else: + new_patch_size[i] = original_patch_size[i] + self.configuration_manager.configuration['patch_size'] = new_patch_size + self.print_to_log_file("Patch size changed from {} to {}".format(original_patch_size, new_patch_size)) + self.plans_manager.plans['configurations'][self.configuration_name]['patch_size'] = new_patch_size + + self.initial_lr = 1e-4 + self.grad_scaler = None + self.weight_decay = 0.01 + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + + label_manager = plans_manager.get_label_manager(dataset_json) + + model = UNETR( + in_channels = num_input_channels, + out_channels = label_manager.num_segmentation_heads, + img_size = configuration_manager.patch_size, + feature_size=16, + hidden_size=768, + mlp_dim = 3072, + num_heads = 12, + proj_type = "conv", + norm_name="instance", + res_block=True, + dropout_rate=0.0, + spatial_dims = len(configuration_manager.patch_size), + qkv_bias = False, + save_attn = False, + ) + + return model + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + l = self.loss(output, target) + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad(set_to_none=True) + + output = self.network(data) + del data + l = self.loss(output, target) + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, output.ndim)) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def configure_optimizers(self): + + optimizer = AdamW(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, eps=1e-5) + scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=1.0) + + self.print_to_log_file(f"Using optimizer {optimizer}") + self.print_to_log_file(f"Using scheduler {scheduler}") + + return optimizer, scheduler + + def set_deep_supervision_enabled(self, enabled: bool): + pass \ No newline at end of file diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py new file mode 100644 index 0000000..fad1fff --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py @@ -0,0 +1,65 @@ +import torch +from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from torch import distributed as dist + + +class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results." + self.disable_checkpointing = True + self.num_epochs = 5 + assert torch.cuda.is_available(), "This only works on GPU" + self.crashed_with_runtime_error = False + + def perform_actual_validation(self, save_probabilities: bool = False): + pass + + def save_checkpoint(self, filename: str) -> None: + # do not trust people to remember that self.disable_checkpointing must be True for this trainer + pass + + def run_training(self): + try: + super().run_training() + except RuntimeError: + self.crashed_with_runtime_error = True + + def on_train_end(self): + super().on_train_end() + + if not self.is_ddp or self.local_rank == 0: + torch_version = torch.__version__ + cudnn_version = torch.backends.cudnn.version() + gpu_name = torch.cuda.get_device_name() + if self.crashed_with_runtime_error: + fastest_epoch = 'Not enough VRAM!' + else: + epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'], + self.logger.my_fantastic_logging['epoch_start_timestamps'])] + fastest_epoch = min(epoch_times) + + if self.is_ddp: + num_gpus = dist.get_world_size() + else: + num_gpus = 1 + + benchmark_result_file = join(self.output_folder, 'benchmark_result.json') + if isfile(benchmark_result_file): + old_results = load_json(benchmark_result_file) + else: + old_results = {} + # generate some unique key + my_key = f"{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__gpus_{num_gpus}" + old_results[my_key] = { + 'torch_version': torch_version, + 'cudnn_version': cudnn_version, + 'gpu_name': gpu_name, + 'fastest_epoch': fastest_epoch, + 'num_gpus': num_gpus, + } + save_json(old_results, + join(self.output_folder, 'benchmark_result.json')) diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py new file mode 100644 index 0000000..e7de92c --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py @@ -0,0 +1,65 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import ( + nnUNetTrainerBenchmark_5epochs, +) +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels + + +class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self._set_batch_size_and_oversample() + num_input_channels = determine_num_input_channels( + self.plans_manager, self.configuration_manager, self.dataset_json + ) + patch_size = self.configuration_manager.patch_size + dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) + if self.enable_deep_supervision: + dummy_target = [ + torch.round( + torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) + * max(self.label_manager.all_labels) + ) + for k in self._get_deep_supervision_scales() + ] + else: + raise NotImplementedError("This trainer does not support deep supervision") + self.dummy_batch = {"data": dummy_data, "target": dummy_target} + + def get_dataloaders(self): + return None, None + + def run_training(self): + try: + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(self.dummy_batch)) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(self.dummy_batch)) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() + except RuntimeError: + self.crashed_with_runtime_error = True diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py new file mode 100644 index 0000000..7250fb8 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -0,0 +1,422 @@ +from typing import List, Union, Tuple + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \ + GammaTransform +from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform +from batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \ + GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \ + MirrorTransform +from batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \ + NumpyToTensor + +from nnunetv2.configuration import ANISO_THRESHOLD +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \ + Convert2DTo3DTransform +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + + +class nnUNetTrainerDA5(nnUNetTrainer): + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.7, 1.43)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size]) + valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0]) + + tr_transforms = [] + + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append( + SpatialTransform( + patch_size_spatial, + patch_center_dist_from_border=None, + do_elastic_deform=False, + do_rotation=True, + angle_x=rotation_for_DA['x'], + angle_y=rotation_for_DA['y'], + angle_z=rotation_for_DA['z'], + p_rot_per_axis=0.5, + do_scale=True, + scale=(0.7, 1.43), + border_mode_data="constant", + border_cval_data=0, + order_data=order_resampling_data, + border_mode_seg="constant", + border_cval_seg=-1, + order_seg=order_resampling_seg, + random_crop=False, + p_el_per_sample=0.2, + p_scale_per_sample=0.2, + p_rot_per_sample=0.4, + independent_scale_for_each_axis=True, + ) + ) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + if np.any(matching_axes > 1): + tr_transforms.append( + Rot90Transform( + (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5 + ), + ) + + if np.any(matching_axes > 1): + tr_transforms.append( + TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5) + ) + + tr_transforms.append(OneOfTransform([ + MedianFilterTransform( + (2, 8), + same_for_each_channel=False, + p_per_sample=0.2, + p_per_channel=0.5 + ), + GaussianBlurTransform((0.3, 1.5), + different_sigma_per_channel=True, + p_per_sample=0.2, + p_per_channel=0.5) + ])) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + + tr_transforms.append(BrightnessTransform(0, + 0.5, + per_channel=True, + p_per_sample=0.1, + p_per_channel=0.5 + ) + ) + + tr_transforms.append(OneOfTransform( + [ + ContrastAugmentationTransform( + contrast_range=(0.5, 2), + preserve_range=True, + per_channel=True, + data_key='data', + p_per_sample=0.2, + p_per_channel=0.5 + ), + ContrastAugmentationTransform( + contrast_range=(0.5, 2), + preserve_range=False, + per_channel=True, + data_key='data', + p_per_sample=0.2, + p_per_channel=0.5 + ), + ] + )) + + tr_transforms.append( + SimulateLowResolutionTransform(zoom_range=(0.25, 1), + per_channel=True, + p_per_channel=0.5, + order_downsample=0, + order_upsample=3, + p_per_sample=0.15, + ignore_axes=ignore_axes + ) + ) + + tr_transforms.append( + GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append( + GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + tr_transforms.append( + BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size], + rectangle_value=np.mean, + num_rectangles=(1, 5), + force_square=False, + p_per_sample=0.4, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + BrightnessGradientAdditiveTransform( + _brightnessadditive_localgamma_transform_scale, + (-0.5, 1.5), + max_strength=_brightness_gradient_additive_max_strength, + mean_centered=False, + same_for_all_channels=False, + p_per_sample=0.3, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + LocalGammaTransform( + _brightnessadditive_localgamma_transform_scale, + (-0.5, 1.5), + _local_gamma_gamma, + same_for_all_channels=False, + p_per_sample=0.3, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + SharpeningTransform( + strength=(0.1, 1), + same_for_each_channel=False, + p_per_sample=0.2, + p_per_channel=0.5 + ) + ) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + if ignore_label is not None: + raise NotImplementedError('ignore label not yet supported in cascade') + assert foreground_labels is not None, 'We need all_labels for cascade augmentations' + use_labels = [i for i in foreground_labels if i != 0] + tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(use_labels), 0)), + p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(use_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + + +class nnUNetTrainerDA5ord0(nnUNetTrainerDA5): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=0, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +def _brightnessadditive_localgamma_transform_scale(x, y): + return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))) + + +def _brightness_gradient_additive_max_strength(_x, _y): + return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5) + + +def _local_gamma_gamma(): + return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4) + + +class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 10 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py new file mode 100644 index 0000000..e87ff8f --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py @@ -0,0 +1,104 @@ +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter + +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + + +class nnUNetTrainerDAOrd0(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=0, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +class nnUNetTrainer_DASegOrd0(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py new file mode 100644 index 0000000..17f3586 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py @@ -0,0 +1,40 @@ +from typing import Union, Tuple, List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np + + +class nnUNetTrainerNoDA(nnUNetTrainer): + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 1, + order_resampling_seg: int = 0, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels, + regions, ignore_label) + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + return super().get_plain_dataloaders( + initial_patch_size=self.configuration_manager.patch_size, + dim=dim + ) + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + # we need to disable mirroring here so that no mirroring will be applied in inferene! + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py new file mode 100644 index 0000000..18ea1ea --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py @@ -0,0 +1,28 @@ +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerNoMirroring(nnUNetTrainer): + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_onlyMirror01(nnUNetTrainer): + """ + Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D + """ + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + if dim == 2: + mirror_axes = (0, ) + else: + mirror_axes = (0, 1) + self.inference_allowed_mirroring_axes = mirror_axes + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py new file mode 100644 index 0000000..fdc0fea --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py @@ -0,0 +1,41 @@ +import torch +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss +import numpy as np + + +class nnUNetTrainerCELoss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = RobustCrossEntropyLoss( + weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 + ) + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 5 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py new file mode 100644 index 0000000..b139286 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py @@ -0,0 +1,60 @@ +import numpy as np +import torch + +from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.helpers import softmax_helper_dim1 + + +class nnUNetTrainerDiceLoss(nnUNetTrainer): + def _build_loss(self): + loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, + apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) + + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer): + def _build_loss(self): + # set smooth to 0 + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, + dice_class=MemoryEfficientSoftDiceLoss) + + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py new file mode 100644 index 0000000..5eff10e --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py @@ -0,0 +1,76 @@ +from nnunetv2.training.loss.compound_losses import DC_and_topk_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np +from nnunetv2.training.loss.robust_ce_loss import TopKLoss + + +class nnUNetTrainerTopk10Loss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10 + ) + + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, + k=10, + label_smoothing=0.1, + ) + + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = DC_and_topk_loss( + {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp}, + {"k": 10, "label_smoothing": 0.0}, + weight_ce=1, + weight_dice=1, + ignore_label=self.label_manager.ignore_label, + ) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py new file mode 100644 index 0000000..60455f2 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py @@ -0,0 +1,13 @@ +import torch +from torch.optim.lr_scheduler import CosineAnnealingLR + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerCosAnneal(nnUNetTrainer): + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) + return optimizer, lr_scheduler + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py new file mode 100644 index 0000000..5f6190c --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py @@ -0,0 +1,73 @@ +from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn + + +class nnUNetTrainerBN(nnUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = configuration_manager.UNet_class_name + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_batchnorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_batchnorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accommodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=enable_deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + return model diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py new file mode 100644 index 0000000..1152fbe --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -0,0 +1,16 @@ +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import torch + + +class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.enable_deep_supervision = False diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py new file mode 100644 index 0000000..be5a7f4 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py @@ -0,0 +1,58 @@ +import torch +from torch.optim import Adam, AdamW + +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerAdam(nnUNetTrainer): + def configure_optimizers(self): + optimizer = AdamW(self.network.parameters(), + lr=self.initial_lr, + weight_decay=self.weight_decay, + amsgrad=True) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerVanillaAdam(nnUNetTrainer): + def configure_optimizers(self): + optimizer = Adam(self.network.parameters(), + lr=self.initial_lr, + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 + + +class nnUNetTrainerAdam1en3(nnUNetTrainerAdam): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerAdam3en4(nnUNetTrainerAdam): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py new file mode 100644 index 0000000..8747f47 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py @@ -0,0 +1,66 @@ +import torch + +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from torch.optim.lr_scheduler import CosineAnnealingLR +try: + from adan_pytorch import Adan +except ImportError: + Adan = None + + +class nnUNetTrainerAdan(nnUNetTrainer): + def configure_optimizers(self): + if Adan is None: + raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') + optimizer = Adan(self.network.parameters(), + lr=self.initial_lr, + # betas=(0.02, 0.08, 0.01), defaults + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerAdan1en3(nnUNetTrainerAdan): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerAdan3en4(nnUNetTrainerAdan): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 + + +class nnUNetTrainerAdan1en1(nnUNetTrainerAdan): + # this trainer makes no sense -> nan! + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-1 + + +class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan): + # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + # device: torch.device = torch.device('cuda')): + # super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + # self.num_epochs = 15 + + def configure_optimizers(self): + if Adan is None: + raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') + optimizer = Adan(self.network.parameters(), + lr=self.initial_lr, + # betas=(0.02, 0.08, 0.01), defaults + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) + return optimizer, lr_scheduler + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py new file mode 100644 index 0000000..89fef48 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import torch + +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np + + +class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer): + """ + sampling of foreground happens randomly and not for the last 33% of samples in a batch + since most trainings happen with batch size 2 and nnunet guarantees at least one fg sample, effectively this can + be 50% + Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible. + If we switch to this oversampling then we can keep it at a constant 0.33 or whatever. + """ + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = float(np.mean( + [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent)) + for sample_idx in range(self.configuration_manager.batch_size)])) + self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}") + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, + self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + dl_val = nnUNetDataLoader2D(dataset_val, + self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, + self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + dl_val = nnUNetDataLoader3D(dataset_val, + self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + return dl_tr, dl_val + + +class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = 0.33 + + +class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = 0.1 + + diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py new file mode 100644 index 0000000..990ce7e --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py @@ -0,0 +1,77 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainer_5epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 5 + + +class nnUNetTrainer_1epoch(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 1 + + +class nnUNetTrainer_10epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 10 + + +class nnUNetTrainer_20epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 20 + + +class nnUNetTrainer_50epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 50 + + + + +class nnUNetTrainer_250epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 250 + + +class nnUNetTrainer_100epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 + +class nnUNetTrainer_2000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 2000 + + +class nnUNetTrainer_4000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 4000 + + +class nnUNetTrainer_8000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 8000 diff --git a/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py new file mode 100644 index 0000000..c16b885 --- /dev/null +++ b/docker/template/src/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py @@ -0,0 +1,60 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 250 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 2000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 4000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 8000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/docker/template/src/nnunetv2/utilities/__init__.py b/docker/template/src/nnunetv2/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/utilities/collate_outputs.py b/docker/template/src/nnunetv2/utilities/collate_outputs.py new file mode 100644 index 0000000..c9d6798 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/collate_outputs.py @@ -0,0 +1,24 @@ +from typing import List + +import numpy as np + + +def collate_outputs(outputs: List[dict]): + """ + used to collate default train_step and validation_step outputs. If you want something different then you gotta + extend this + + we expect outputs to be a list of dictionaries where each of the dict has the same set of keys + """ + collated = {} + for k in outputs[0].keys(): + if np.isscalar(outputs[0][k]): + collated[k] = [o[k] for o in outputs] + elif isinstance(outputs[0][k], np.ndarray): + collated[k] = np.vstack([o[k][None] for o in outputs]) + elif isinstance(outputs[0][k], list): + collated[k] = [item for o in outputs for item in o[k]] + else: + raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' + f'Modify collate_outputs to add this functionality') + return collated \ No newline at end of file diff --git a/docker/template/src/nnunetv2/utilities/dataset_name_id_conversion.py b/docker/template/src/nnunetv2/utilities/dataset_name_id_conversion.py new file mode 100644 index 0000000..29ea58a --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/dataset_name_id_conversion.py @@ -0,0 +1,74 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results +from batchgenerators.utilities.file_and_folder_operations import * +import numpy as np + + +def find_candidate_datasets(dataset_id: int): + startswith = "Dataset%03.0d" % dataset_id + if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed): + candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False) + else: + candidates_preprocessed = [] + + if nnUNet_raw is not None and isdir(nnUNet_raw): + candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False) + else: + candidates_raw = [] + + candidates_trained_models = [] + if nnUNet_results is not None and isdir(nnUNet_results): + candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False) + + all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models + unique_candidates = np.unique(all_candidates) + return unique_candidates + + +def convert_id_to_dataset_name(dataset_id: int): + unique_candidates = find_candidate_datasets(dataset_id) + if len(unique_candidates) > 1: + raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the " + "following folders:\n%s\n%s\n%s" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results)) + if len(unique_candidates) == 0: + raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID " + f"exists and that nnU-Net knows where raw and preprocessed data are located " + f"(see Documentation - Installation). Here are your currently defined folders:\n" + f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\n" + f"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\n" + f"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\n" + f"If something is not right, adapt your environment variables.") + return unique_candidates[0] + + +def convert_dataset_name_to_id(dataset_name: str): + assert dataset_name.startswith("Dataset") + dataset_id = int(dataset_name[7:10]) + return dataset_id + + +def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str: + if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"): + return dataset_name_or_id + if isinstance(dataset_name_or_id, str): + try: + dataset_name_or_id = int(dataset_name_or_id) + except ValueError: + raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to " + "convert it to a dataset ID (int). That failed, however. Please give an integer number " + "('1', '2', etc) or a correct dataset name. Your input: %s" % dataset_name_or_id) + return convert_id_to_dataset_name(dataset_name_or_id) diff --git a/docker/template/src/nnunetv2/utilities/ddp_allgather.py b/docker/template/src/nnunetv2/utilities/ddp_allgather.py new file mode 100644 index 0000000..c42b3ef --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/ddp_allgather.py @@ -0,0 +1,49 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Tuple + +import torch +from torch import distributed + + +def print_if_rank0(*args): + if distributed.get_rank() == 0: + print(*args) + + +class AllGatherGrad(torch.autograd.Function): + # stolen from pytorch lightning + @staticmethod + def forward( + ctx: Any, + tensor: torch.Tensor, + group: Optional["torch.distributed.ProcessGroup"] = None, + ) -> torch.Tensor: + ctx.group = group + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor, group=group) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + grad_output = torch.cat(grad_output) + + torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + + return grad_output[torch.distributed.get_rank()], None + diff --git a/docker/template/src/nnunetv2/utilities/default_n_proc_DA.py b/docker/template/src/nnunetv2/utilities/default_n_proc_DA.py new file mode 100644 index 0000000..3ecc922 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/default_n_proc_DA.py @@ -0,0 +1,44 @@ +import subprocess +import os + + +def get_allowed_n_proc_DA(): + """ + This function is used to set the number of processes used on different Systems. It is specific to our cluster + infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed. + + IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script + (see first line). + + Interpret the output as the number of processes used for data augmentation PER GPU. + + The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our + systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per + GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use. + """ + + if 'nnUNet_n_proc_DA' in os.environ.keys(): + use_this = int(os.environ['nnUNet_n_proc_DA']) + else: + hostname = subprocess.getoutput(['hostname']) + if hostname in ['Fabian', ]: + use_this = 12 + elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']: + use_this = 16 + elif hostname.startswith('e230-dgx1'): + use_this = 10 + elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'): + use_this = 16 + elif hostname.startswith('e230-dgx2'): + use_this = 6 + elif hostname.startswith('e230-dgxa100-'): + use_this = 28 + elif hostname.startswith('lsf22-gpu'): + use_this = 28 + elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'): + use_this = 12 + else: + use_this = 12 # default value + + use_this = min(use_this, os.cpu_count()) + return use_this diff --git a/docker/template/src/nnunetv2/utilities/file_path_utilities.py b/docker/template/src/nnunetv2/utilities/file_path_utilities.py new file mode 100644 index 0000000..a1c9622 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/file_path_utilities.py @@ -0,0 +1,123 @@ +from multiprocessing import Pool +from typing import Union, Tuple +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.configuration import default_num_processes +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + +def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration): + return f'{trainer_name}__{plans_identifier}__{configuration}' + + +def convert_identifier_to_trainer_plans_config(identifier: str): + return os.path.basename(identifier).split('__') + + +def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres', + fold: Union[str, int] = None) -> str: + tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id), + convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration)) + if fold is not None: + tmp = join(tmp, f'fold_{fold}') + return tmp + + +def parse_dataset_trainer_plans_configuration_from_path(path: str): + folders = split_path(path) + # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol + + # safer to make this depend on two conditions, the fold_x and the DatasetXXX + # first let's see if some fold_X is present + fold_x_present = [i.startswith('fold_') for i in folders] + if any(fold_x_present): + idx = fold_x_present.index(True) + # OK now two entries before that there should be DatasetXXX + assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + if folders[idx - 2].startswith('Dataset'): + split = folders[idx - 1].split('__') + assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + return folders[idx - 2], *split + else: + # we can only check for dataset followed by a string that is separable into three strings by splitting with '__' + # look for DatasetXXX + dataset_folder = [i.startswith('Dataset') for i in folders] + if any(dataset_folder): + idx = dataset_folder.index(True) + assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + split = folders[idx + 1].split('__') + assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + return folders[idx], *split + + +def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]): + identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \ + os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds) + return identifier + + +def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]): + model1_folder = get_output_folder(dataset, tr1, p1, c1) + model2_folder = get_output_folder(dataset, tr2, p2, c2) + + get_ensemble_name(model1_folder, model2_folder, folds) + + +def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str): + prefix, *models, folds = os.path.basename(ensemble_folder).split('___') + return models, folds + + +def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]): + s = str(folds[0]) + for f in folds[1:]: + s += f"_{f}" + return s + + +def folds_string_to_tuple(folds_string: str): + folds = folds_string.split('_') + res = [] + for f in folds: + try: + res.append(int(f)) + except ValueError: + res.append(f) + return res + + +def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0): + """ + + returns True if the number of results that are not ready is greater than the number of available workers + allowed_num_queued + """ + alive = [i.is_alive() for i in worker_list] + if not all(alive): + raise RuntimeError('Some background workers are no longer alive') + + not_ready = [not i.ready() for i in results_list] + if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued): + return True + return False + + +if __name__ == '__main__': + ### well at this point I could just write tests... + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + try: + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + except AssertionError: + print('yayy, assertion works') diff --git a/docker/template/src/nnunetv2/utilities/find_class_by_name.py b/docker/template/src/nnunetv2/utilities/find_class_by_name.py new file mode 100644 index 0000000..a345d99 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/find_class_by_name.py @@ -0,0 +1,24 @@ +import importlib +import pkgutil + +from batchgenerators.utilities.file_and_folder_operations import * + + +def recursive_find_python_class(folder: str, class_name: str, current_module: str): + tr = None + for importer, modname, ispkg in pkgutil.iter_modules([folder]): + # print(modname, ispkg) + if not ispkg: + m = importlib.import_module(current_module + "." + modname) + if hasattr(m, class_name): + tr = getattr(m, class_name) + break + + if tr is None: + for importer, modname, ispkg in pkgutil.iter_modules([folder]): + if ispkg: + next_current_module = current_module + "." + modname + tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) + if tr is not None: + break + return tr \ No newline at end of file diff --git a/docker/template/src/nnunetv2/utilities/get_network_from_plans.py b/docker/template/src/nnunetv2/utilities/get_network_from_plans.py new file mode 100644 index 0000000..1dd1dd2 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/get_network_from_plans.py @@ -0,0 +1,77 @@ +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 +from nnunetv2.utilities.network_initialization import InitWeights_He +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn + + +def get_network_from_plans(plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_input_channels: int, + deep_supervision: bool = True): + """ + we may have to change this in the future to accommodate other plans -> network mappings + + num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the + trainer rather than inferring it again from the plans here. + """ + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = configuration_manager.UNet_class_name + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accommodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + return model diff --git a/docker/template/src/nnunetv2/utilities/helpers.py b/docker/template/src/nnunetv2/utilities/helpers.py new file mode 100644 index 0000000..42448e3 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/helpers.py @@ -0,0 +1,27 @@ +import torch + + +def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor: + return torch.softmax(x, 0) + + +def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: + return torch.softmax(x, 1) + + +def empty_cache(device: torch.device): + if device.type == 'cuda': + torch.cuda.empty_cache() + elif device.type == 'mps': + from torch import mps + mps.empty_cache() + else: + pass + + +class dummy_context(object): + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass diff --git a/docker/template/src/nnunetv2/utilities/json_export.py b/docker/template/src/nnunetv2/utilities/json_export.py new file mode 100644 index 0000000..5ea463c --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/json_export.py @@ -0,0 +1,59 @@ +from collections.abc import Iterable + +import numpy as np +import torch + + +def recursive_fix_for_json_export(my_dict: dict): + # json is stupid. 'cannot serialize object of type bool_/int64/float64'. Come on bro. + keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... + for k in keys: + if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): + tmp = my_dict[k] + del my_dict[k] + my_dict[int(k)] = tmp + del tmp + k = int(k) + + if isinstance(my_dict[k], dict): + recursive_fix_for_json_export(my_dict[k]) + elif isinstance(my_dict[k], np.ndarray): + assert my_dict[k].ndim == 1, 'only 1d arrays are supported' + my_dict[k] = fix_types_iterable(my_dict[k], output_type=list) + elif isinstance(my_dict[k], (np.bool_,)): + my_dict[k] = bool(my_dict[k]) + elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)): + my_dict[k] = int(my_dict[k]) + elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)): + my_dict[k] = float(my_dict[k]) + elif isinstance(my_dict[k], list): + my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k])) + elif isinstance(my_dict[k], tuple): + my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple) + elif isinstance(my_dict[k], torch.device): + my_dict[k] = str(my_dict[k]) + else: + pass # pray it can be serialized + + +def fix_types_iterable(iterable, output_type): + # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep you hands off of this. + out = [] + for i in iterable: + if type(i) in (np.int64, np.int32, np.int8, np.uint8): + out.append(int(i)) + elif isinstance(i, dict): + recursive_fix_for_json_export(i) + out.append(i) + elif type(i) in (np.float32, np.float64, np.float16): + out.append(float(i)) + elif type(i) in (np.bool_,): + out.append(bool(i)) + elif isinstance(i, str): + out.append(i) + elif isinstance(i, Iterable): + # print('recursive call on', i, type(i)) + out.append(fix_types_iterable(i, type(i))) + else: + out.append(i) + return output_type(out) diff --git a/docker/template/src/nnunetv2/utilities/label_handling/__init__.py b/docker/template/src/nnunetv2/utilities/label_handling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/utilities/label_handling/label_handling.py b/docker/template/src/nnunetv2/utilities/label_handling/label_handling.py new file mode 100644 index 0000000..58b2513 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/label_handling/label_handling.py @@ -0,0 +1,322 @@ +from __future__ import annotations +from time import time +from typing import Union, List, Tuple, Type + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice +from batchgenerators.utilities.file_and_folder_operations import join + +import nnunetv2 +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.helpers import softmax_helper_dim0 + +from typing import TYPE_CHECKING + +# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ +if TYPE_CHECKING: + from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +class LabelManager(object): + def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False, + inference_nonlin=None): + self._sanity_check(label_dict) + self.label_dict = label_dict + self.regions_class_order = regions_class_order + self._force_use_labels = force_use_labels + + if force_use_labels: + self._has_regions = False + else: + self._has_regions: bool = any( + [isinstance(i, (tuple, list)) and len(i) > 1 for i in self.label_dict.values()]) + + self._ignore_label: Union[None, int] = self._determine_ignore_label() + self._all_labels: List[int] = self._get_all_labels() + + self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions() + + if self.has_ignore_label: + assert self.ignore_label == max( + self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \ + 'label value! It cannot be 0 or in between other labels. ' \ + 'Sorry bro.' + + if inference_nonlin is None: + self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0 + else: + self.inference_nonlin = inference_nonlin + + def _sanity_check(self, label_dict: dict): + if not 'background' in label_dict.keys(): + raise RuntimeError('Background label not declared (remember that this should be label 0!)') + bg_label = label_dict['background'] + if isinstance(bg_label, (tuple, list)): + raise RuntimeError(f"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}") + assert int(bg_label) == 0, f"Background label must be 0. Your background label: {bg_label}" + # not sure if we want to allow regions that contain background. I don't immediately see how this could cause + # problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this + # just crashes. + + def _get_all_labels(self) -> List[int]: + all_labels = [] + for k, r in self.label_dict.items(): + # ignore label is not going to be used, hence the name. Duh. + if k == 'ignore': + continue + if isinstance(r, (tuple, list)): + for ri in r: + all_labels.append(int(ri)) + else: + all_labels.append(int(r)) + all_labels = list(np.unique(all_labels)) + all_labels.sort() + return all_labels + + def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: + if not self._has_regions or self._force_use_labels: + return None + else: + assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ + 'define regions_class_order!' + regions = [] + for k, r in self.label_dict.items(): + # ignore ignore label + if k == 'ignore': + continue + # ignore regions that are background + if (np.isscalar(r) and r == 0) \ + or \ + (isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0): + continue + if isinstance(r, list): + r = tuple(r) + regions.append(r) + assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \ + 'many entries as there are ' \ + 'regions' + return regions + + def _determine_ignore_label(self) -> Union[None, int]: + ignore_label = self.label_dict.get('ignore') + if ignore_label is not None: + assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \ + f'(list/tuple). Got {type(ignore_label)}.' + return ignore_label + + @property + def has_regions(self) -> bool: + return self._has_regions + + @property + def has_ignore_label(self) -> bool: + return self.ignore_label is not None + + @property + def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: + return self._regions + + @property + def all_labels(self) -> List[int]: + return self._all_labels + + @property + def ignore_label(self) -> Union[None, int]: + return self._ignore_label + + def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + """ + logits has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if isinstance(logits, np.ndarray): + logits = torch.from_numpy(logits) + + with torch.no_grad(): + # softmax etc is not implemented for half + logits = logits.float() + probabilities = self.inference_nonlin(logits) + + return probabilities + + def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + """ + assumes that inference_nonlinearity was already applied! + + predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)): + raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor," + f" got {type(predicted_probabilities)}") + + if self.has_regions: + assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ + 'define regions_class_order!' + # check correct number of outputs + assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \ + f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \ + f'got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape ' \ + f'(c, x, y(, z)).' + + if self.has_regions: + if isinstance(predicted_probabilities, np.ndarray): + segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16) + else: + # no uint16 in torch + segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16, + device=predicted_probabilities.device) + for i, c in enumerate(self.regions_class_order): + segmentation[predicted_probabilities[i] > 0.5] = c + else: + segmentation = predicted_probabilities.argmax(0) + + return segmentation + + def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + input_is_numpy = isinstance(predicted_logits, np.ndarray) + probabilities = self.apply_inference_nonlin(predicted_logits) + if input_is_numpy and isinstance(probabilities, torch.Tensor): + probabilities = probabilities.cpu().numpy() + return self.convert_probabilities_to_segmentation(probabilities) + + def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray], + bbox: List[List[int]], + original_shape: Union[List[int], Tuple[int, ...]]): + """ + ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!! + + predicted_probabilities must be (c, x, y(, z)) + + Why do we do this here? Well if we pad probabilities we need to make sure that convert_logits_to_segmentation + correctly returns background in the padded areas. Also we want to ba able to look at the padded probabilities + and not have strange artifacts. + Only LabelManager knows how this needs to be done. So let's let him/her do it, ok? + """ + # revert cropping + probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape), + dtype=predicted_probabilities.dtype) \ + if isinstance(predicted_probabilities, np.ndarray) else \ + torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype) + + if not self.has_regions: + probs_reverted_cropping[0] = 1 + + slicer = bounding_box_to_slice(bbox) + probs_reverted_cropping[tuple([slice(None)] + list(slicer))] = predicted_probabilities + return probs_reverted_cropping + + @staticmethod + def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]): + # heck yeah + # This is definitely taking list comprehension too far. Enjoy. + return [i for i in classes_or_regions if + ((not isinstance(i, (tuple, list))) and i != 0) + or + (isinstance(i, (tuple, list)) and not ( + len(np.unique(i)) == 1 and np.unique(i)[0] == 0))] + + @property + def foreground_regions(self): + return self.filter_background(self.all_regions) + + @property + def foreground_labels(self): + return self.filter_background(self.all_labels) + + @property + def num_segmentation_heads(self): + if self.has_regions: + return len(self.foreground_regions) + else: + return len(self.all_labels) + + +def get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]: + if 'label_manager' not in plans.keys(): + print('No label manager specified in plans. Using default: LabelManager') + return LabelManager + else: + labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], "utilities", "label_handling"), + plans['label_manager'], + current_module="nnunetv2.utilities.label_handling") + return labelmanager_class + + +def convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor], + all_labels: Union[List, torch.Tensor, np.ndarray, tuple], + output_dtype=None) -> Union[np.ndarray, torch.Tensor]: + """ + if output_dtype is None then we use np.uint8/torch.uint8 + if input is torch.Tensor then output will be on the same device + + np.ndarray is faster than torch.Tensor + + if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is somethine else we have + to cast which takes time. + + IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ... + DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo) + """ + if isinstance(segmentation, torch.Tensor): + result = torch.zeros((len(all_labels), *segmentation.shape), + dtype=output_dtype if output_dtype is not None else torch.uint8, + device=segmentation.device) + # variant 1, 2x faster than 2 + result.scatter_(0, segmentation[None].long(), 1) # why does this have to be long!? + # variant 2, slower than 1 + # for i, l in enumerate(all_labels): + # result[i] = segmentation == l + else: + result = np.zeros((len(all_labels), *segmentation.shape), + dtype=output_dtype if output_dtype is not None else np.uint8) + # variant 1, fastest in my testing + for i, l in enumerate(all_labels): + result[i] = segmentation == l + # variant 2. Takes about twice as long so nah + # result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2)) + return result + + +def determine_num_input_channels(plans_manager: PlansManager, + configuration_or_config_manager: Union[str, ConfigurationManager], + dataset_json: dict) -> int: + if isinstance(configuration_or_config_manager, str): + config_manager = plans_manager.get_configuration(configuration_or_config_manager) + else: + config_manager = configuration_or_config_manager + + label_manager = plans_manager.get_label_manager(dataset_json) + num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names']) + + # cascade has different number of input channels + if config_manager.previous_stage_name is not None: + num_label_inputs = len(label_manager.foreground_labels) + num_input_channels = num_modalities + num_label_inputs + else: + num_input_channels = num_modalities + return num_input_channels + + +if __name__ == '__main__': + # this code used to be able to differentiate variant 1 and 2 to measure time. + num_labels = 7 + seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8) + seg_torch = torch.from_numpy(seg) + st = time() + onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels)) + time_1 = time() + onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels)) + time_2 = time() + onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels)) + time_torch = time() + onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels)) + time_torch2 = time() + print( + f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}') + onehot_torch = onehot_torch.numpy() + onehot_torch2 = onehot_torch2.numpy() + print(np.all(onehot_torch == onehot_npy)) + print(np.all(onehot_torch2 == onehot_npy)) diff --git a/docker/template/src/nnunetv2/utilities/network_initialization.py b/docker/template/src/nnunetv2/utilities/network_initialization.py new file mode 100644 index 0000000..1ead271 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/network_initialization.py @@ -0,0 +1,12 @@ +from torch import nn + + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) diff --git a/docker/template/src/nnunetv2/utilities/overlay_plots.py b/docker/template/src/nnunetv2/utilities/overlay_plots.py new file mode 100644 index 0000000..66a3b67 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/overlay_plots.py @@ -0,0 +1,273 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +from typing import Tuple, Union + +import numpy as np +import pandas as pd +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + +color_cycle = ( + "000000", + "4363d8", + "f58231", + "3cb44b", + "e6194B", + "911eb4", + "ffe119", + "bfef45", + "42d4f4", + "f032e6", + "000075", + "9A6324", + "808000", + "800000", + "469990", +) + + +def hex_to_rgb(hex: str): + assert len(hex) == 6 + return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4)) + + +def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, + color_cycle: Tuple[str, ...] = color_cycle, + overlay_intensity: float = 0.6): + """ + image can be 2d greyscale or 2d RGB (color channel in last dimension!) + + Segmentation must be label map of same shape as image (w/o color channels) + + mapping can be label_id -> idx_in_cycle or None + + returned image is scaled to [0, 255] (uint8)!!! + """ + # create a copy of image + image = np.copy(input_image) + + if image.ndim == 2: + image = np.tile(image[:, :, None], (1, 1, 3)) + elif image.ndim == 3: + if image.shape[2] == 1: + image = np.tile(image, (1, 1, 3)) + else: + raise RuntimeError(f'if 3d image is given the last dimension must be the color channels (3 channels). ' + f'Only 2D images are supported. Your image shape: {image.shape}') + else: + raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in " + "last dimension) are supported") + + # rescale image to [0, 255] + image = image - image.min() + image = image / image.max() * 255 + + # create output + if mapping is None: + uniques = np.sort(pd.unique(segmentation.ravel())) # np.unique(segmentation) + mapping = {i: c for c, i in enumerate(uniques)} + + for l in mapping.keys(): + image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]])) + + # rescale result to [0, 255] + image = image / image.max() * 255 + return image.astype(np.uint8) + + +def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int: + """ + image and segmentation are expected to be 3D + + selects the slice with the largest amount of fg (regardless of label) + + we give image so that we can easily replace this function if needed + """ + fg_mask = segmentation != 0 + fg_per_slice = fg_mask.sum((1, 2)) + selected_slice = int(np.argmax(fg_per_slice)) + return selected_slice + + +def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int: + """ + image and segmentation are expected to be 3D (or 1, x, y) + + selects the slice with the largest amount of fg (how much percent of each class are in each slice? pick slice + with highest avg percent) + + we give image so that we can easily replace this function if needed + """ + classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i != 0] + fg_per_slice = np.zeros((image.shape[0], len(classes))) + for i, c in enumerate(classes): + fg_mask = segmentation == c + fg_per_slice[:, i] = fg_mask.sum((1, 2)) + fg_per_slice[:, i] /= fg_per_slice.sum() + fg_per_slice = fg_per_slice.mean(1) + return int(np.argmax(fg_per_slice)) + + +def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str, + overlay_intensity: float = 0.6): + import matplotlib.pyplot as plt + + image, props = image_reader_writer.read_images((image_file, )) + image = image[0] + seg, props_seg = image_reader_writer.read_seg(segmentation_file) + seg = seg[0] + + assert image.shape == seg.shape, "image and seg do not have the same shape: %s, %s" % ( + image_file, segmentation_file) + + assert image.ndim == 3, 'only 3D images/segs are supported' + + selected_slice = select_slice_to_plot2(image, seg) + # print(image.shape, selected_slice) + + overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) + + plt.imsave(output_file, overlay) + + +def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, channel_idx=0): + import matplotlib.pyplot as plt + data = np.load(case_file)['data'] + seg = np.load(case_file)['seg'][0] + + assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1) + + image = data[channel_idx] + seg[seg < 0] = 0 + + selected_slice = select_slice_to_plot2(image, seg) + + overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) + + plt.imsave(output_file, overlay) + + +def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer, + list_of_output_files, overlay_intensity, + num_processes=8): + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + r = p.starmap_async(plot_overlay, zip( + list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files), + list_of_output_files, [overlay_intensity] * len(list_of_output_files) + )) + r.get() + + +def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity, + num_processes=8, channel_idx=0): + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + r = p.starmap_async(plot_overlay_preprocessed, zip( + list_of_case_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files), + [channel_idx] * len(list_of_output_files) + )) + r.get() + + +def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str, + num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + folder = join(nnUNet_raw, dataset_name) + dataset_json = load_json(join(folder, 'dataset.json')) + dataset = get_filenames_of_train_images_and_targets(folder, dataset_json) + + image_files = [v['images'][channel_idx] for v in dataset.values()] + seg_files = [v['label'] for v in dataset.values()] + + assert all([isfile(i) for i in image_files]) + assert all([isfile(i) for i in seg_files]) + + maybe_mkdir_p(output_folder) + output_files = [join(output_folder, i + '.png') for i in dataset.keys()] + + image_reader_writer = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])() + multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity, num_processes) + + +def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str, + num_processes: int = 8, channel_idx: int = 0, + configuration: str = None, + plans_identifier: str = 'nnUNetPlans', + overlay_intensity: float = 0.6): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + folder = join(nnUNet_preprocessed, dataset_name) + if not isdir(folder): raise RuntimeError("run preprocessing for that task first") + + plans = load_json(join(folder, plans_identifier + '.json')) + if configuration is None: + if '3d_fullres' in plans['configurations'].keys(): + configuration = '3d_fullres' + else: + configuration = '2d' + data_identifier = plans['configurations'][configuration]["data_identifier"] + preprocessed_folder = join(folder, data_identifier) + + if not isdir(preprocessed_folder): + raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier " + f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this " + f"configuration first!") + + identifiers = [i[:-4] for i in subfiles(preprocessed_folder, suffix='.npz', join=False)] + + output_files = [join(output_folder, i + '.png') for i in identifiers] + image_files = [join(preprocessed_folder, i + ".npz") for i in identifiers] + + maybe_mkdir_p(output_folder) + multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=overlay_intensity, + num_processes=num_processes, channel_idx=channel_idx) + + +def entry_point_generate_overlay(): + import argparse + parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this " + "disregards spacing information!") + parser.add_argument('-d', type=str, help="Dataset name or id", required=True) + parser.add_argument('-o', type=str, help="output folder", required=True) + parser.add_argument('-np', type=int, default=default_num_processes, required=False, + help=f"number of processes used. Default: {default_num_processes}") + parser.add_argument('-channel_idx', type=int, default=0, required=False, + help="channel index used (0 = _0000). Default: 0") + parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else " + "we use preprocessed") + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans') + parser.add_argument('-c', type=str, required=False, default=None, + help='configuration name. Only used if --use_raw is not set! Default: None = ' + '3d_fullres if available, else 2d') + parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6, + help='overlay intensity. Higher = brighter/less transparent') + + + args = parser.parse_args() + + if args.use_raw: + generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx, + overlay_intensity=args.overlay_intensity) + else: + generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p, + overlay_intensity=args.overlay_intensity) + + +if __name__ == '__main__': + entry_point_generate_overlay() \ No newline at end of file diff --git a/docker/template/src/nnunetv2/utilities/plans_handling/__init__.py b/docker/template/src/nnunetv2/utilities/plans_handling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker/template/src/nnunetv2/utilities/plans_handling/plans_handler.py b/docker/template/src/nnunetv2/utilities/plans_handling/plans_handler.py new file mode 100644 index 0000000..6c39fd1 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/plans_handling/plans_handler.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import dynamic_network_architectures +from copy import deepcopy +from functools import lru_cache, partial +from typing import Union, Tuple, List, Type, Callable + +import numpy as np +import torch + +from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name +from torch import nn + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import load_json, join + +from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans + + +# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from nnunetv2.utilities.label_handling.label_handling import LabelManager + from nnunetv2.imageio.base_reader_writer import BaseReaderWriter + from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor + from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner + + +class ConfigurationManager(object): + def __init__(self, configuration_dict: dict): + self.configuration = configuration_dict + + def __repr__(self): + return self.configuration.__repr__() + + @property + def data_identifier(self) -> str: + return self.configuration['data_identifier'] + + @property + def preprocessor_name(self) -> str: + return self.configuration['preprocessor_name'] + + @property + @lru_cache(maxsize=1) + def preprocessor_class(self) -> Type[DefaultPreprocessor]: + preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"), + self.preprocessor_name, + current_module="nnunetv2.preprocessing") + return preprocessor_class + + @property + def batch_size(self) -> int: + return self.configuration['batch_size'] + + @property + def patch_size(self) -> List[int]: + return self.configuration['patch_size'] + + @property + def median_image_size_in_voxels(self) -> List[int]: + return self.configuration['median_image_size_in_voxels'] + + @property + def spacing(self) -> List[float]: + return self.configuration['spacing'] + + @property + def normalization_schemes(self) -> List[str]: + return self.configuration['normalization_schemes'] + + @property + def use_mask_for_norm(self) -> List[bool]: + return self.configuration['use_mask_for_norm'] + + @property + def UNet_class_name(self) -> str: + return self.configuration['UNet_class_name'] + + @property + @lru_cache(maxsize=1) + def UNet_class(self) -> Type[nn.Module]: + unet_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), + self.UNet_class_name, + current_module="dynamic_network_architectures.architectures") + if unet_class is None: + raise RuntimeError('The network architecture specified by the plans file ' + 'is non-standard (maybe your own?). Fix this by not using ' + 'ConfigurationManager.UNet_class to instantiate ' + 'it (probably just overwrite build_network_architecture of your trainer.') + return unet_class + + @property + def UNet_base_num_features(self) -> int: + return self.configuration['UNet_base_num_features'] + + @property + def n_conv_per_stage_encoder(self) -> List[int]: + return self.configuration['n_conv_per_stage_encoder'] + + @property + def n_conv_per_stage_decoder(self) -> List[int]: + return self.configuration['n_conv_per_stage_decoder'] + + @property + def num_pool_per_axis(self) -> List[int]: + return self.configuration['num_pool_per_axis'] + + @property + def pool_op_kernel_sizes(self) -> List[List[int]]: + return self.configuration['pool_op_kernel_sizes'] + + @property + def conv_kernel_sizes(self) -> List[List[int]]: + return self.configuration['conv_kernel_sizes'] + + @property + def unet_max_num_features(self) -> int: + return self.configuration['unet_max_num_features'] + + @property + @lru_cache(maxsize=1) + def resampling_fn_data(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data']) + fn = partial(fn, **self.configuration['resampling_fn_data_kwargs']) + return fn + + @property + @lru_cache(maxsize=1) + def resampling_fn_probabilities(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities']) + fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs']) + return fn + + @property + @lru_cache(maxsize=1) + def resampling_fn_seg(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg']) + fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs']) + return fn + + @property + def batch_dice(self) -> bool: + return self.configuration['batch_dice'] + + @property + def next_stage_names(self) -> Union[List[str], None]: + ret = self.configuration.get('next_stage') + if ret is not None: + if isinstance(ret, str): + ret = [ret] + return ret + + @property + def previous_stage_name(self) -> Union[str, None]: + return self.configuration.get('previous_stage') + + +class PlansManager(object): + def __init__(self, plans_file_or_dict: Union[str, dict]): + """ + Why do we need this? + 1) resolve inheritance in configurations + 2) expose otherwise annoying stuff like getting the label manager or IO class from a string + 3) clearly expose the things that are in the plans instead of hiding them in a dict + 4) cache shit + + This class does not prevent you from going wild. You can still use the plans directly if you prefer + (PlansHandler.plans['key']) + """ + self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict) + + def __repr__(self): + return self.plans.__repr__() + + def _internal_resolve_configuration_inheritance(self, configuration_name: str, + visited: Tuple[str, ...] = None) -> dict: + if configuration_name not in self.plans['configurations'].keys(): + raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid ' + f'configuration names are {list(self.plans["configurations"].keys())}.') + configuration = deepcopy(self.plans['configurations'][configuration_name]) + if 'inherits_from' in configuration: + parent_config_name = configuration['inherits_from'] + + if visited is None: + visited = (configuration_name,) + else: + if parent_config_name in visited: + raise RuntimeError(f"Circular dependency detected. The following configurations were visited " + f"while solving inheritance (in that order!): {visited}. " + f"Current configuration: {configuration_name}. Its parent configuration " + f"is {parent_config_name}.") + visited = (*visited, configuration_name) + + base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited) + base_config.update(configuration) + configuration = base_config + return configuration + + @lru_cache(maxsize=10) + def get_configuration(self, configuration_name: str): + if configuration_name not in self.plans['configurations'].keys(): + raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. " + f"Available configurations: {list(self.plans['configurations'].keys())}") + + configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name) + return ConfigurationManager(configuration_dict) + + @property + def dataset_name(self) -> str: + return self.plans['dataset_name'] + + @property + def plans_name(self) -> str: + return self.plans['plans_name'] + + @property + def original_median_spacing_after_transp(self) -> List[float]: + return self.plans['original_median_spacing_after_transp'] + + @property + def original_median_shape_after_transp(self) -> List[float]: + return self.plans['original_median_shape_after_transp'] + + @property + @lru_cache(maxsize=1) + def image_reader_writer_class(self) -> Type[BaseReaderWriter]: + return recursive_find_reader_writer_by_name(self.plans['image_reader_writer']) + + @property + def transpose_forward(self) -> List[int]: + return self.plans['transpose_forward'] + + @property + def transpose_backward(self) -> List[int]: + return self.plans['transpose_backward'] + + @property + def available_configurations(self) -> List[str]: + return list(self.plans['configurations'].keys()) + + @property + @lru_cache(maxsize=1) + def experiment_planner_class(self) -> Type[ExperimentPlanner]: + planner_name = self.experiment_planner_name + experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + planner_name, + current_module="nnunetv2.experiment_planning") + return experiment_planner + + @property + def experiment_planner_name(self) -> str: + return self.plans['experiment_planner_used'] + + @property + @lru_cache(maxsize=1) + def label_manager_class(self) -> Type[LabelManager]: + return get_labelmanager_class_from_plans(self.plans) + + def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager: + return self.label_manager_class(label_dict=dataset_json['labels'], + regions_class_order=dataset_json.get('regions_class_order'), + **kwargs) + + @property + def foreground_intensity_properties_per_channel(self) -> dict: + if 'foreground_intensity_properties_per_channel' not in self.plans.keys(): + if 'foreground_intensity_properties_by_modality' in self.plans.keys(): + return self.plans['foreground_intensity_properties_by_modality'] + return self.plans['foreground_intensity_properties_per_channel'] + + +if __name__ == '__main__': + from nnunetv2.paths import nnUNet_preprocessed + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json')) + # build new configuration that inherits from 3d_fullres + plans['configurations']['3d_fullres_bs4'] = { + 'batch_size': 4, + 'inherits_from': '3d_fullres' + } + # now get plans and configuration managers + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration('3d_fullres_bs4') + print(configuration_manager) # look for batch size 4 diff --git a/docker/template/src/nnunetv2/utilities/utils.py b/docker/template/src/nnunetv2/utilities/utils.py new file mode 100644 index 0000000..b0c16a2 --- /dev/null +++ b/docker/template/src/nnunetv2/utilities/utils.py @@ -0,0 +1,69 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from functools import lru_cache +from typing import Union + +from batchgenerators.utilities.file_and_folder_operations import * +import numpy as np +import re + +from nnunetv2.paths import nnUNet_raw + + +def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str): + files = subfiles(folder, suffix=file_ending, join=False) + # all files have a 4 digit channel index (_XXXX) + crop = len(file_ending) + 5 + files = [i[:-crop] for i in files] + # only unique image ids + files = np.unique(files) + return files + + +def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None) -> List[ + List[str]]: + """ + does not rely on dataset.json + """ + if identifiers is None: + identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending) + files = subfiles(folder, suffix=file_ending, join=False, sort=True) + list_of_lists = [] + for f in identifiers: + p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending)) + list_of_lists.append([join(folder, i) for i in files if p.fullmatch(i)]) + return list_of_lists + + +def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None): + if dataset_json is None: + dataset_json = load_json(join(raw_dataset_folder, 'dataset.json')) + + if 'dataset' in dataset_json.keys(): + dataset = dataset_json['dataset'] + for k in dataset.keys(): + dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, dataset[k]['label'])) if not os.path.isabs(dataset[k]['label']) else dataset[k]['label'] + dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, i)) if not os.path.isabs(i) else i for i in dataset[k]['images']] + else: + identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending']) + images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers) + segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers] + dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)} + return dataset + + +if __name__ == '__main__': + print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart'))) diff --git a/docker/template/src/run.sh b/docker/template/src/run.sh new file mode 100644 index 0000000..2ec089a --- /dev/null +++ b/docker/template/src/run.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# $1 is the csv file containing sample identifiers of test images +# $2 is the input path where test images are located +# $3 is the output path where predicted masks will be stored as images. +python main.py $1 $2 $3 \ No newline at end of file diff --git a/docker/template/src/setup.py b/docker/template/src/setup.py new file mode 100644 index 0000000..d5af773 --- /dev/null +++ b/docker/template/src/setup.py @@ -0,0 +1,67 @@ +from setuptools import setup, find_namespace_packages + +setup(name='nnunetv2', + packages=find_namespace_packages(include=["nnunetv2", "nnunetv2.*"]), + version='2.1.1', + description='nnU-Net. Framework for out-of-the box biomedical image segmentation.', + url='https://github.com/MIC-DKFZ/nnUNet', + author='Helmholtz Imaging Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center', + author_email='f.isensee@dkfz-heidelberg.de', + license='Apache License Version 2.0, January 2004', + python_requires=">=3.10", + install_requires=[ + "torch>=2.0.0", + "acvl-utils>=0.2", + "dynamic-network-architectures>=0.2", + "tqdm", + "mamba-ssm==1.2.0.post1", + "dicom2nifti", + "gdown", + "scikit-image>=0.14", + "medpy", + "scipy", + "batchgenerators>=0.25", + "numpy", + "scikit-learn", + "scikit-image>=0.19.3", + "SimpleITK>=2.2.1", + "pandas", + "graphviz", + 'tifffile', + 'requests', + "nibabel", + "matplotlib", + "seaborn", + "imagecodecs", + "yacs", + "monai==1.3.0", + "opencv-python" + ], + entry_points={ + 'console_scripts': [ + 'nnUNetv2_plan_and_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry', # api available + 'nnUNetv2_extract_fingerprint = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry', # api available + 'nnUNetv2_plan_experiment = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry', # api available + 'nnUNetv2_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry', # api available + 'nnUNetv2_train = nnunetv2.run.run_training:run_training_entry', # api available + 'nnUNetv2_predict_from_modelfolder = nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder', # api available + 'nnUNetv2_predict = nnunetv2.inference.predict_from_raw_data:predict_entry_point', # api available + 'nnUNetv2_convert_old_nnUNet_dataset = nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point', # api available + 'nnUNetv2_find_best_configuration = nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point', # api available + 'nnUNetv2_determine_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder', # api available + 'nnUNetv2_apply_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing', # api available + 'nnUNetv2_ensemble = nnunetv2.ensembling.ensemble:entry_point_ensemble_folders', # api available + 'nnUNetv2_accumulate_crossval_results = nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point', # api available + 'nnUNetv2_plot_overlay_pngs = nnunetv2.utilities.overlay_plots:entry_point_generate_overlay', # api available + 'nnUNetv2_download_pretrained_model_by_url = nnunetv2.model_sharing.entry_points:download_by_url', # api available + 'nnUNetv2_install_pretrained_model_from_zip = nnunetv2.model_sharing.entry_points:install_from_zip_entry_point', # api available + 'nnUNetv2_export_model_to_zip = nnunetv2.model_sharing.entry_points:export_pretrained_model_entry', # api available + 'nnUNetv2_move_plans_between_datasets = nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets', # api available + 'nnUNetv2_evaluate_folder = nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point', # api available + 'nnUNetv2_evaluate_simple = nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point', # api available + 'nnUNetv2_convert_MSD_dataset = nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point' # api available + ], + }, + keywords=['deep learning', 'image segmentation', 'medical image analysis', + 'medical image segmentation', 'nnU-Net', 'nnunet'] + )