import sys

# BUG FIX: the path tweak must run before any project-local import —
# previously `from model.trainer import Trainer` executed first, defeating it.
sys.path.insert(0, '.')

import os
import time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim.lr_scheduler
from PIL import Image
from torch.nn.parallel import gather
from tqdm import tqdm

import dataset.dataset as myDataLoader
import dataset.Transforms as myTransforms
from model.metric_tool import ConfuseMatrixMeter
from model.trainer import Trainer
from model.utils import BCEDiceLoss, init_seed
|
| |
|
@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    """Run one full evaluation pass over ``val_loader``.

    Args:
        args: namespace providing at least ``onGPU`` and ``batch_size``;
            ``savedir`` is additionally required when ``save_masks`` is True.
        val_loader: DataLoader yielding ``(img, target)`` batches where
            ``img`` stacks the pre- and post-change RGB images along the
            channel axis (6 channels total).
        model: change-detection network called as ``model(pre_img, post_img)``.
        save_masks: when True, write each predicted binary mask as a PNG
            under ``<args.savedir>/pred_masks``.

    Returns:
        tuple: ``(average_loss, scores)`` — the mean BCE-Dice loss over all
        batches and the metric dict from ``ConfuseMatrixMeter.get_scores()``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (avoids a silent
            ZeroDivisionError when averaging).
    """
    model.eval()

    # NOTE(review): forcing track_running_stats=True on BN layers during eval
    # looks intentional (use the running stats accumulated at training time)
    # — confirm against the training script.
    for m in model.modules():
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.track_running_stats = True
            m.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    if save_masks:
        mask_dir = f"{args.savedir}/pred_masks"
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")

    for batch_idx, batched_inputs in pbar:
        img, target = batched_inputs

        # First 3 channels: pre-change image; last 3: post-change image.
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()

        target = target.float()
        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        # Output is treated as a probability map; binarize at 0.5.
        pred = (output > 0.5).long()

        if save_masks:
            # Recover source file names for this batch. Only dereference the
            # loader internals when masks are actually written, so loaders
            # without `sampler.data_source.file_list` still validate fine.
            batch_file_names = val_loader.sampler.data_source.file_list[
                batch_idx * args.batch_size : (batch_idx + 1) * args.batch_size
            ]
            _save_batch_masks(pred, batch_file_names, mask_dir, batch_idx)

        if args.onGPU and torch.cuda.device_count() > 1:
            pred = gather(pred, 0, dim=0)

        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        epoch_loss.append(loss.item())

        pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'F1': f"{f1:.4f}"})

    # Fail loudly on an empty loader instead of raising ZeroDivisionError.
    if not epoch_loss:
        raise ValueError("validate() received an empty val_loader")

    average_loss = sum(epoch_loss) / len(epoch_loss)
    scores = salEvalVal.get_scores()
    return average_loss, scores


def _save_batch_masks(pred, batch_file_names, mask_dir, batch_idx):
    """Write each predicted mask in ``pred`` (N,1,H,W) to ``mask_dir`` as PNG.

    Best-effort: errors are reported with diagnostic context but do not
    abort validation (preserves the original swallow-and-log behavior).
    """
    pred_np = pred.cpu().numpy().astype(np.uint8)

    print(f"\nDebug - Batch {batch_idx}: {len(batch_file_names)} files, Mask shape: {pred_np.shape}")

    try:
        for i in range(pred_np.shape[0]):
            if i >= len(batch_file_names):
                # More masks than file names — fall back to a synthetic name.
                print(f"Warning: Missing filename for mask {i}, using default")
                base_name = f"batch_{batch_idx}_mask_{i}"
            else:
                base_name = os.path.splitext(os.path.basename(batch_file_names[i]))[0]

            # Drop the channel axis; assumes masks are (N, 1, H, W).
            single_mask = pred_np[i, 0]

            if single_mask.ndim != 2:
                raise ValueError(f"Invalid mask shape: {single_mask.shape}")

            mask_path = f"{mask_dir}/{base_name}_pred.png"
            # Scale {0,1} labels to {0,255} so the PNG is human-viewable.
            Image.fromarray(single_mask * 255).save(mask_path)
            print(f"Saved: {mask_path}")

    except Exception as e:
        print(f"\nError saving batch {batch_idx}: {str(e)}")
        print(f"Current mask shape: {single_mask.shape if 'single_mask' in locals() else 'N/A'}")
        print(f"Current file: {base_name if 'base_name' in locals() else 'N/A'}")
| | |
def ValidateSegmentation(args):
    """Main entry point: load the best checkpoint and evaluate on the test set.

    Side effects: creates ``args.savedir`` (mutated to a run-specific
    subdirectory), appends a result row to the log file, and prints metrics.

    Raises:
        FileNotFoundError: if ``best_model.pth`` is absent from the run dir.
    """
    # Restrict visible devices before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    torch.backends.cudnn.benchmark = True
    init_seed(args.seed)

    # Results live in a run-specific subdirectory keyed by dataset/iters/lr,
    # matching the layout the training script writes its checkpoint into.
    args.savedir = os.path.join(args.savedir,
                                f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    # Map dataset aliases to on-disk roots; unknown values pass through as
    # literal paths.
    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256'
    }
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Per-channel stats for the 6-channel (pre, post) stacked input.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )

    # BUG FIX: decide whether the header row is needed *before* opening the
    # file — open(..., 'w') creates it, so the original post-open existence
    # check was always True and the header was never written.
    logFileLoc = os.path.join(args.savedir, args.logFile)
    log_already_exists = os.path.exists(logFileLoc)
    logger = open(logFileLoc, 'a' if log_already_exists else 'w')
    if not log_already_exists:
        logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                     ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
    logger.flush()

    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")

    # map_location keeps loading functional on CPU-only hosts / onGPU=False,
    # even if the checkpoint was saved from CUDA tensors.
    state_dict = torch.load(model_file_name,
                            map_location='cuda' if args.onGPU else 'cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    print(f"Kappa: {score_test['Kappa']:.4f}")
    print(f"IoU: {score_test['IoU']:.4f}")
    print(f"F1: {score_test['F1']:.4f}")
    print(f"Recall: {score_test['recall']:.4f}")
    print(f"Precision: {score_test['precision']:.4f}")
    print(f"OA: {score_test['OA']:.4f}")

    logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                 ('Test', score_test['Kappa'], score_test['IoU'], score_test['F1'],
                  score_test['recall'], score_test['precision'], score_test['OA']))
    logger.close()
| |
|
| |
|
def _build_arg_parser():
    """Construct the command-line parser for standalone validation runs."""
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR",
                        help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of input image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of input image')
    parser.add_argument('--max_steps', type=int, default=80000,
                        help='Max. number of iterations (for path naming)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--model_type', type=str, default='small',
                        help='Model type | tiny | small')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size for validation')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate (for path naming)')
    parser.add_argument('--seed', type=int, default=16,
                        help='Random seed for reproducibility')
    parser.add_argument('--savedir', default='./results',
                        help='Base directory to save results')
    parser.add_argument('--logFile', default='testLog.txt',
                        help='File to save validation logs')
    parser.add_argument('--onGPU', default=True,
                        type=lambda x: (str(x).lower() == 'true'),
                        help='Run on GPU if True')
    parser.add_argument('--gpu_id', type=int, default=0,
                        help='GPU device id')
    parser.add_argument('--save_masks', action='store_true',
                        help='Save predicted masks to disk')
    return parser


if __name__ == '__main__':
    cli_args = _build_arg_parser().parse_args()
    print('Validation with args:')
    print(cli_args)
    ValidateSegmentation(cli_args)
| |
|
| |
|