| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import copy |
| | import json |
| | import math |
| | import os |
| | import zipfile |
| | from argparse import Namespace |
| | from datetime import timedelta |
| | from typing import Any, Sequence |
| |
|
| | import numpy as np |
| | import skimage |
| | import torch |
| | import torch.distributed as dist |
| | from monai.bundle import ConfigParser |
| | from monai.config import DtypeLike, NdarrayOrTensor |
| | from monai.data import CacheDataset, DataLoader, partition_dataset |
| | from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd |
| | from monai.transforms.utils_morphological_ops import dilate, erode |
| | from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype |
| | from scipy import stats |
| | from torch import Tensor |
| |
|
| |
|
def unzip_dataset(dataset_dir):
    """
    Ensure ``dataset_dir`` exists by extracting ``dataset_dir + ".zip"`` if needed.

    In a distributed run only rank 0 performs the extraction; all ranks then
    synchronize on a barrier so no process reads the data before it exists.

    Args:
        dataset_dir (str): target dataset directory; the zip archive is expected
            at the same path with a ``.zip`` suffix.

    Raises:
        ValueError: if the directory is missing and the zip archive is not found.
    """
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    if rank == 0 and not os.path.exists(dataset_dir):
        zip_file_path = dataset_dir + ".zip"
        if not os.path.isfile(zip_file_path):
            raise ValueError(f"Please download {zip_file_path}.")
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(path=os.path.dirname(dataset_dir))
        print(f"Unzipped {zip_file_path} to {dataset_dir}.")

    # re-check rather than caching: keeps the original call pattern intact
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return
| |
|
| |
|
| | def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: |
| | """ |
| | Read a list of data dictionary. |
| | |
| | Args: |
| | list_files (list): input data to load and transform to generate dataset for model. |
| | data_dir (str): directory of files. |
| | fold (int, optional): fold index for cross validation. Defaults to None. |
| | |
| | Returns: |
| | tuple[list, list]: A tuple of two arrays (training, validation). |
| | """ |
| | new_list_files = copy.deepcopy(list_files) |
| | if fold is not None: |
| | new_list_files_train = [] |
| | new_list_files_val = [] |
| | for d in new_list_files: |
| | d["image"] = os.path.join(data_dir, d["image"]) |
| |
|
| | if "label" in d: |
| | d["label"] = os.path.join(data_dir, d["label"]) |
| |
|
| | if fold is not None: |
| | if d["fold"] == fold: |
| | new_list_files_val.append(copy.deepcopy(d)) |
| | else: |
| | new_list_files_train.append(copy.deepcopy(d)) |
| |
|
| | if fold is not None: |
| | return new_list_files_train, new_list_files_val |
| | else: |
| | return new_list_files, [] |
| |
|
| |
|
def maisi_datafold_read(json_list, data_base_dir, fold=None):
    """
    Load the "training" section of a JSON datalist and split it into train/val files.

    Args:
        json_list (str): path to the JSON datalist file.
        data_base_dir (str): directory prefix for the listed files.
        fold (int, optional): validation fold index; None disables the split.

    Returns:
        tuple[list, list]: (training files, validation files).
    """
    with open(json_list, "r") as datalist_file:
        training_entries = json.load(datalist_file)["training"]

    train_files, val_files = add_data_dir2path(training_entries, data_base_dir, fold=fold)
    print(f"dataset: {data_base_dir}, num_training_files: {len(train_files)}, num_val_files: {len(val_files)}")
    return train_files, val_files
| |
|
| |
|
def remap_labels(mask, label_dict_remap_json):
    """
    Remap the label values of a mask according to a JSON mapping file.

    The JSON file holds a dict whose values are ``[source_label, target_label]``
    pairs; every source label found in the mask is replaced by its target.

    Args:
        mask (Tensor): input mask tensor to be remapped (channel-first).
        label_dict_remap_json (str): path to the JSON file with the label mapping.

    Returns:
        Tensor: remapped mask tensor on the same device as the input.
    """
    with open(label_dict_remap_json, "r") as f:
        mapping_dict = json.load(f)
    source_labels = [pair[0] for pair in mapping_dict.values()]
    destination_labels = [pair[1] for pair in mapping_dict.values()]
    mapper = MapLabelValue(orig_labels=source_labels, target_labels=destination_labels, dtype=torch.uint8)
    remapped = mapper(mask[0, ...])[None, ...]
    return remapped.to(mask.device)
| |
|
| |
|
def get_index_arr(img):
    """
    Build a voxel-coordinate array for the given 3D image.

    The result has shape ``img.shape[:3] + (3,)`` where element ``[i, j, k]``
    is ``[i, j, k]`` — i.e. each voxel stores its own index triple.

    Args:
        img (ndarray): input image array (at least 3 dimensions).

    Returns:
        ndarray: integer coordinate array of shape (M, N, P, 3).
    """
    # 'ij' indexing keeps matrix order, so no axis shuffling is needed afterwards
    axes = (np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
| |
|
| |
|
def supress_non_largest_components(img, target_label, default_val=0):
    """
    Suppress all connected components except the largest one(s) for the given labels.

    For each label in ``target_label`` the two most-voxel-heavy connected
    components are kept (one of them is typically the background component of
    the binary mask) and every smaller component is overwritten with
    ``default_val``. (The name keeps its historical misspelling for API
    compatibility.)

    Args:
        img (ndarray): the input label image.
        target_label (list): label values to process.
        default_val (int, optional): value assigned to suppressed voxels. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - ndarray: modified image with non-largest components suppressed.
            - int: number of voxels whose value strictly decreased (computed as
              ``sum((img - img_mod) > 0)``, so raises replaced by a larger
              ``default_val`` are not counted).
    """
    img_mod = copy.deepcopy(img)
    new_background = np.zeros(img.shape, dtype=np.bool_)
    for label in target_label:
        label_cc = skimage.measure.label(img == label, connectivity=3)
        uv, uc = np.unique(label_cc, return_counts=True)
        # ids of the two largest components by voxel count
        dominant_vals = uv[np.argsort(uc)[::-1][:2]]
        if len(dominant_vals) >= 2:
            new_background = np.logical_or(
                new_background,
                np.logical_not(np.logical_or(label_cc == dominant_vals[0], label_cc == dominant_vals[1])),
            )

    # vectorized boolean assignment replaces the original per-voxel Python loop
    # over get_index_arr(img)[new_background] — identical result, far faster
    img_mod[new_background] = default_val
    diff = np.sum((img - img_mod) > 0)

    return img_mod, diff
| |
|
| |
|
def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
    """
    Erode a single 2D/3D binary mask stored as a torch tensor.

    Args:
        mask_t: [M,N] or [M,N,P] binary mask tensor.
        filter_size: erosion filter size, must be odd; defaults to 3.
        pad_value: fill value used to pad the input before filtering so the
            output keeps the same size. Usually left at the default.

    Return:
        Tensor: eroded mask with the same shape as the input.
    """
    # add batch and channel dims expected by the morphological op, then strip them
    batched = mask_t.float().unsqueeze(0).unsqueeze(0)
    eroded = erode(batched, filter_size, pad_value=pad_value)
    return eroded.squeeze(0).squeeze(0)
| |
|
| |
|
def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
    """
    Dilate a single 2D/3D binary mask stored as a torch tensor.

    Args:
        mask_t: [M,N] or [M,N,P] binary mask tensor.
        filter_size: dilation filter size, must be odd; defaults to 3.
        pad_value: fill value used to pad the input before filtering so the
            output keeps the same size. Usually left at the default.

    Return:
        Tensor: dilated mask with the same shape as the input.
    """
    # add batch and channel dims expected by the morphological op, then strip them
    batched = mask_t.float().unsqueeze(0).unsqueeze(0)
    dilated = dilate(batched, filter_size, pad_value=pad_value)
    return dilated.squeeze(0).squeeze(0)
| |
|
| |
|
def binarize_labels(x: Tensor, bits: int = 8) -> Tensor:
    """
    Expand integer labels into per-bit binary channels.

    Each voxel's integer value is decomposed into its binary digits
    (least-significant bit first along the channel axis).

    Args:
        x (Tensor): integer-typed input with shape (B, 1, H, W, D).
        bits (int, optional): number of bits to extract. Defaults to 8.

    Returns:
        Tensor: byte tensor of shape (B, bits, H, W, D) holding 0/1 values.
    """
    # powers of two [1, 2, 4, ...] on the input's device/dtype
    weights = 2 ** torch.arange(bits).to(x.device, x.dtype)
    expanded = x.unsqueeze(-1)  # (B, 1, H, W, D, 1) broadcasts against (bits,)
    bit_planes = (expanded.bitwise_and(weights) != 0).byte()
    return bit_planes.squeeze(1).permute(0, 4, 1, 2, 3)
| |
|
| |
|
def setup_ddp(rank: int, world_size: int) -> torch.device:
    """
    Initialize the NCCL distributed process group and pick this rank's GPU.

    Args:
        rank (int): rank of the current process.
        world_size (int): number of processes participating in the job.

    Returns:
        torch.device: CUDA device assigned to the current process.
    """
    # very long timeout (10h) so slow data preparation on one rank does not kill the job
    dist.init_process_group(
        backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size
    )
    dist.barrier()
    return torch.device(f"cuda:{rank}")
| |
|
| |
|
def define_instance(args: Namespace, instance_def_key: str) -> Any:
    """
    Instantiate the object described by ``instance_def_key`` in the parsed config.

    The argparse namespace is treated as a MONAI bundle config: its entries may
    reference each other, and the requested key is resolved and instantiated.

    Args:
        args: namespace holding the configuration entries.
        instance_def_key (str): key of the definition to instantiate.

    Returns:
        The instantiated object defined by ``instance_def_key``.
    """
    config = ConfigParser(vars(args))
    config.parse(True)
    instance = config.get_parsed_content(instance_def_key, instantiate=True)
    return instance
| |
|
| |
|
def prepare_maisi_controlnet_json_dataloader(
    json_data_list: list | str,
    data_base_dir: list | str,
    batch_size: int = 1,
    fold: int = 0,
    cache_rate: float = 0.0,
    rank: int = 0,
    world_size: int = 1,
) -> tuple[DataLoader, DataLoader]:
    """
    Prepare dataloaders for training and validation.

    Args:
        json_data_list (list | str): the name of JSON files listing the data.
        data_base_dir (list | str): directory of files.
        batch_size (int, optional): how many samples per batch to load . Defaults to 1.
        fold (int, optional): fold index for cross validation. Defaults to 0.
        cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0.
        rank (int, optional): rank of the current process. Defaults to 0.
        world_size (int, optional): number of processes participating in the job. Defaults to 1.

    Returns:
        tuple[DataLoader, DataLoader]: A tuple of two dataloaders (training, validation).
    """
    use_ddp = world_size > 1
    # Several datalists may be supplied; each is paired with its own data root.
    if isinstance(json_data_list, list):
        assert isinstance(data_base_dir, list)
        list_train = []
        list_valid = []
        for data_list, data_root in zip(json_data_list, data_base_dir):
            with open(data_list, "r") as f:
                json_data = json.load(f)["training"]
            train, val = add_data_dir2path(json_data, data_root, fold)
            list_train += train
            list_valid += val
    else:
        with open(json_data_list, "r") as f:
            json_data = json.load(f)["training"]
        list_train, list_valid = add_data_dir2path(json_data, data_base_dir, fold)

    # Shared transform chain: load image+label, reorient label to RAS, and turn
    # the scalar metadata fields into float tensors scaled by 1e2.
    common_transform = [
        LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
        Orientationd(keys=["label"], axcodes="RAS"),
        EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
        # scale factor presumably matches the model's conditioning convention — TODO confirm
        Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2),
    ]
    train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

    train_loader = None

    # Under DDP each rank sees only its own shard; training shards are padded
    # (even_divisible=True) so every rank gets the same number of samples.
    if use_ddp:
        list_train = partition_dataset(data=list_train, shuffle=True, num_partitions=world_size, even_divisible=True)[
            rank
        ]
    train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    if use_ddp:
        # validation shards are NOT padded, so no sample is evaluated twice
        list_valid = partition_dataset(data=list_valid, shuffle=True, num_partitions=world_size, even_divisible=False)[
            rank
        ]
    val_ds = CacheDataset(data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
    return train_loader, val_loader
| |
|
| |
|
def organ_fill_by_closing(data, target_label, device, close_times=2, filter_size=3, pad_value=0.0):
    """
    Fill holes in an organ mask via repeated morphological closing.

    Runs ``close_times`` rounds of dilation followed by erosion on the binary
    mask of ``target_label``.

    Args:
        data (ndarray): input label volume.
        target_label (int): label of the organ to process.
        device (str): torch device the morphology runs on (e.g. 'cuda:0').
        close_times (int, optional): number of closing rounds. Defaults to 2.
        filter_size (int, optional): dilation/erosion kernel size. Defaults to 3.
        pad_value (float, optional): padding value for the morphology ops. Defaults to 0.0.

    Returns:
        ndarray: boolean mask of the closed organ.
    """
    organ_mask = torch.from_numpy((data == target_label).astype(np.uint8)).to(device)
    for _ in range(close_times):
        organ_mask = dilate_one_img(organ_mask, filter_size=filter_size, pad_value=pad_value)
        organ_mask = erode_one_img(organ_mask, filter_size=filter_size, pad_value=pad_value)
    return organ_mask.cpu().numpy().astype(np.bool_)
| |
|
| |
|
def organ_fill_by_removed_mask(data, target_label, remove_mask, device):
    """
    Re-grow an organ mask into regions where it was previously removed.

    The organ's binary mask is dilated three times and intersected with
    ``remove_mask``, marking removed voxels adjacent to the organ for refill.

    Args:
        data (ndarray): input label volume.
        target_label (int): label of the organ to process.
        remove_mask (ndarray): boolean mask of the removed regions.
        device (str): torch device the morphology runs on (e.g. 'cuda:0').

    Returns:
        ndarray: boolean mask of the organ within the removed regions.
    """
    organ_mask = torch.from_numpy((data == target_label).astype(np.uint8)).to(device)
    # three successive 3x3 dilations expand the organ into its neighborhood
    for _ in range(2):
        organ_mask = dilate_one_img(organ_mask, filter_size=3, pad_value=0.0)
    grown_organ = dilate_one_img(organ_mask, filter_size=3, pad_value=0.0).cpu().numpy()
    return (grown_organ * remove_mask).astype(np.bool_)
| |
|
| |
|
def get_body_region_index_from_mask(input_mask):
    """
    Determine the top and bottom body-region indices present in a label mask.

    Four body regions are defined by fixed label groups (presumably ordered
    superior to inferior — TODO confirm the label-to-anatomy mapping). The
    topmost and bottommost regions whose labels appear in the mask are returned
    as one-hot lists.

    Args:
        input_mask (Tensor): mask tensor containing body-region label values.

    Returns:
        tuple: two length-4 lists of ints, the one-hot encodings of the top and
        bottom regions.

    Raises:
        ValueError: if none of the known region labels occur in the mask
            (``np.amin`` on an empty index array).
    """
    # fixed label groups defining each body region
    region_indices = {
        "region_0": [22, 120],
        "region_1": [28, 29, 30, 31, 32],
        "region_2": [1, 2, 3, 4, 5, 14],
        "region_3": [93, 94],
    }

    nda = input_mask.cpu().numpy().squeeze()
    # np.unique replaces np.lib.arraysetops.unique, which is gone from the
    # public namespace in NumPy 2.0
    present_labels = set(np.unique(nda).tolist())

    overlap_array = np.zeros(len(region_indices), dtype=np.uint8)
    for _j in range(len(region_indices)):
        overlap_array[_j] = np.uint8(bool(present_labels.intersection(region_indices[f"region_{_j}"])))
    overlap_array_indices = np.nonzero(overlap_array)[0]

    eye = np.eye(len(region_indices), dtype=np.uint8)
    top_region_index = [int(_k) for _k in eye[np.amin(overlap_array_indices), ...]]
    bottom_region_index = [int(_k) for _k in eye[np.amax(overlap_array_indices), ...]]

    return top_region_index, bottom_region_index
| |
|
| |
|
def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"):
    """
    Perform post-processing on a generated mask volume.

    This function applies various refinement steps to improve the quality of the generated mask,
    including body mask refinement, tumor prediction refinement, and organ-specific processing.

    Args:
        volume_t (ndarray): Input volume containing organ and tumor labels.
        target_tumor_label (int, optional): Label of the target tumor. Defaults to None.
        device (str, optional): Device to perform operations on. Defaults to "cuda:0".

    Returns:
        ndarray: Post-processed volume with refined organ and tumor labels.
    """
    # remember hepatic vessel (25) and airway (132) voxels; they are restored at the end
    hepatic_vessel = volume_t == 25
    airway = volume_t == 132

    # ---- refine body mask: erode, keep largest component, dilate back,
    # then zero out everything outside the body
    body_region_mask = (
        erode_one_img(torch.from_numpy((volume_t > 0)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
    )
    body_region_mask, _ = supress_non_largest_components(body_region_mask, [1])
    body_region_mask = (
        dilate_one_img(torch.from_numpy(body_region_mask).to(device), filter_size=3, pad_value=0.0)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    volume_t = volume_t * body_region_mask

    # ---- refine tumor prediction: non-target tumor labels are absorbed into
    # their host organ label; the target tumor is hole-filled by closing (twice)
    tumor_organ_dict = {23: 28, 24: 4, 26: 1, 27: 62, 128: 200}
    for t in [23, 24, 26, 27, 128]:
        if t != target_tumor_label:
            volume_t[volume_t == t] = tumor_organ_dict[t]
        else:
            volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t
            volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t
    # keep only the largest target-tumor component (except labels 26 and 128)
    if target_tumor_label != 26 and target_tumor_label != 128:
        volume_t, _ = supress_non_largest_components(volume_t, [target_tumor_label], default_val=200)
    target_tumor = volume_t == target_tumor_label

    # ---- refine organ predictions: keep largest components of the listed
    # organ labels (suppressed voxels become 200)
    oran_list = [1, 4, 10, 12, 3, 28, 29, 30, 31, 32, 5, 14, 13, 6, 7, 8, 9, 10]
    if target_tumor_label != 128:
        oran_list += list(range(33, 60))
    data, _ = supress_non_largest_components(volume_t, oran_list, default_val=200)
    organ_remove_mask = (volume_t - data).astype(np.bool_)
    # intestinal structures (12, 13, 19, 62) are treated as one connected system
    intestinal_mask_ = (
        (data == 12).astype(np.uint8)
        + (data == 13).astype(np.uint8)
        + (data == 19).astype(np.uint8)
        + (data == 62).astype(np.uint8)
    )
    intestinal_mask, _ = supress_non_largest_components(intestinal_mask_, [1], default_val=0)
    # small bowel (19) / colon (62) voxels outside the main intestinal component get removed
    small_bowel_remove_mask = (data == 19).astype(np.uint8) - (data == 19).astype(np.uint8) * intestinal_mask
    colon_remove_mask = (data == 62).astype(np.uint8) - (data == 62).astype(np.uint8) * intestinal_mask
    intestinal_remove_mask = (small_bowel_remove_mask + colon_remove_mask).astype(np.bool_)
    data[intestinal_remove_mask] = 200

    # fill holes inside each surviving organ
    for organ_label in oran_list:
        data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    if target_tumor_label == 23 and np.sum(target_tumor) > 0:
        # lung tumor (23): find the dominant neighboring lung lobe around the
        # tumor and re-grow it, then re-grow the tumor into removed regions
        dia_lung_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == 23)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        # labels on the tumor's rim (tumor itself excluded); zeros are NaN'd out
        tmp = (
            (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten()
        )
        tmp[tmp == 0] = float("nan")
        mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0])
        if mode in [28, 29, 30, 31, 32]:
            dia_lung_tumor_mask = (
                dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0)
                .cpu()
                .numpy()
            )
            lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8).astype(np.uint8)
            data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = (
                mode
            )
        dia_lung_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        data[
            organ_fill_by_removed_mask(
                data, target_label=23, remove_mask=dia_lung_tumor_mask * organ_remove_mask, device=device
            )
        ] = 23
        # repeated closing passes over all lung lobes
        for organ_label in [28, 29, 30, 31, 32]:
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    if target_tumor_label == 26 and np.sum(target_tumor) > 0:
        # hepatic tumor (26): re-grow liver (1) into removed intestinal regions
        # and spleen(?) label 3 into removed organ regions — TODO confirm label 3 semantics
        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
            .cpu()
            .numpy()
        )
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        data[
            organ_fill_by_removed_mask(
                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
            )
        ] = target_tumor_label
        # keep only the largest connected tumor+vessel+liver complex
        hepatic_tumor_vessel_liver_mask_ = (
            (data == 26).astype(np.uint8) + (data == 25).astype(np.uint8) + (data == 1).astype(np.uint8)
        )
        hepatic_tumor_vessel_liver_mask_ = (hepatic_tumor_vessel_liver_mask_ > 1).astype(np.uint8)
        hepatic_tumor_vessel_liver_mask, _ = supress_non_largest_components(
            hepatic_tumor_vessel_liver_mask_, [1], default_val=0
        )
        removed_region = (hepatic_tumor_vessel_liver_mask_ - hepatic_tumor_vessel_liver_mask).astype(np.bool_)
        data[removed_region] = 200
        target_tumor = (target_tumor * hepatic_tumor_vessel_liver_mask).astype(np.bool_)
        # final liver hole-filling passes
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1

    if target_tumor_label == 27 and np.sum(target_tumor) > 0:
        # pancreatic(?) tumor (27): re-grow the tumor into removed organ regions — TODO confirm label
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
            .cpu()
            .numpy()
        )
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        data[
            organ_fill_by_removed_mask(
                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
            )
        ] = target_tumor_label

    if target_tumor_label == 129 and np.sum(target_tumor) > 0:
        # repeated closing for labels 5 and 14 when tumor 129 is present
        for organ_label in [5, 14]:
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    # hepatic vessels are folded into the liver label; airway and target tumor
    # voxels are restored from the masks captured at the start
    print(
        "Current model does not support hepatic vessel by size control, "
        "so we treat generated hepatic vessel as part of liver for better visiaulization."
    )
    data[hepatic_vessel] = 1
    data[airway] = 132
    if target_tumor_label is not None:
        data[target_tumor] = target_tumor_label

    return data
| |
|
| |
|
class MapLabelValue:
    """
    Utility to map label values to another set of values.
    For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2],
    [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc.
    The label data must be numpy array or array-like data and the output data will be numpy array.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
        """
        Args:
            orig_labels: original labels that map to others.
            target_labels: expected label values, 1: 1 map to the `orig_labels`.
            dtype: convert the output data to dtype, default to float32.
                if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.

        """
        if len(orig_labels) != len(target_labels):
            raise ValueError("orig_labels and target_labels must have the same length.")

        self.orig_labels = orig_labels
        self.target_labels = target_labels
        # only keep the pairs that actually change a value
        self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t)
        # a torch dtype selects the torch backend; anything else uses numpy
        if getattr(type(dtype), "__module__", "") == "torch":
            self.use_numpy = False
            self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)
        else:
            self.use_numpy = True
            self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)

    def __call__(self, img: NdarrayOrTensor):
        """
        Apply the label mapping to the input image.

        Args:
            img (NdarrayOrTensor): Input image to be remapped.

        Returns:
            NdarrayOrTensor: Remapped image.
        """
        if self.use_numpy:
            img_np, *_ = convert_data_type(img, np.ndarray)
            original_shape = img_np.shape
            flat_src = img_np.flatten()
            try:
                flat_out = flat_src.astype(self.dtype)
            except ValueError:
                # target dtype cannot hold the unchanged labels (e.g. strings ->
                # numbers): start from zeros and rely on the pair mapping below
                flat_out = np.zeros(shape=flat_src.shape, dtype=self.dtype)
            for src_val, dst_val in self.pair:
                flat_out[flat_src == src_val] = dst_val
            mapped = flat_out.reshape(original_shape)
        else:
            img_t, *_ = convert_data_type(img, torch.Tensor)
            mapped = img_t.detach().clone().to(self.dtype)
            for src_val, dst_val in self.pair:
                mapped[img_t == src_val] = dst_val
        out, *_ = convert_to_dst_type(src=mapped, dst=img, dtype=self.dtype)
        return out
| |
|
| |
|
def dynamic_infer(inferer, model, images):
    """
    Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer.

    If a single sample/channel of the input is smaller than the inferer's ROI,
    the model is called directly; otherwise the inferer is used with its ROI
    temporarily clipped to the image extent.

    Args:
        inferer: An inference object with a ``roi_size`` attribute, typically a
            monai SlidingWindowInferer, which handles patch-based inference.
        model (torch.nn.Module): The model used for inference.
        images (torch.Tensor): The input images, shape [N,C,H,W,D] or [N,C,H,W].

    Returns:
        torch.Tensor: The output from the model or the inferer, depending on the input size.

    Raises:
        ValueError: if the ROI rank does not match the image's spatial rank.
    """
    # small input: spatial voxel count below the ROI volume -> direct forward pass
    if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size):
        return model(images)

    spatial_dims = images.shape[2:]
    orig_roi = inferer.roi_size

    if len(orig_roi) != len(spatial_dims):
        raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")

    # clip each ROI dim to the image extent; restore the inferer's ROI even if
    # inference raises, so a failure cannot leave the shared inferer corrupted
    adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
    inferer.roi_size = adjusted_roi
    try:
        return inferer(network=model, inputs=images)
    finally:
        inferer.roi_size = orig_roi
| |
|