| import os |
| import cv2 |
| import numpy as np |
| import torch |
| import yaml |
| from typing import Optional, Tuple, Union |
| from io import BytesIO |
| from PIL import Image |
| import logging |
| import traceback |
|
|
| from basicsr.models import create_model |
| from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite |
| from basicsr.utils.options import parse |
|
|
| |
def setup_logger(name, log_level=logging.INFO):
    """Return the logger called ``name`` at ``log_level`` with a console handler.

    A ``StreamHandler`` with a timestamped format is attached only when the
    logger has no handlers yet, so calling this more than once for the same
    name (e.g. when the module is re-imported) does not duplicate log lines.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_level: Level applied to the logger on every call.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(log_level)

    # logging.getLogger returns the same object for the same name, so an
    # unconditional addHandler would stack one handler per call.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    return logger
|
|
# Module-wide logger shared by the model wrapper and the CLI entry point.
logger = setup_logger(__name__, log_level=logging.INFO)
|
|
class NAFNetDeblur:
    """Wrapper around a BasicSR-built NAFNet model for single-image deblurring.

    Channel order is RGB throughout the public API: inputs are converted to RGB
    before inference, ``deblur_image`` returns an RGB ``uint8`` array, and
    ``save_image`` converts RGB to BGR only when writing the file with OpenCV.
    """

    def __init__(self, config_path: str = 'options/test/REDS/NAFNet-width64.yml'):
        """
        Initialize the NAFNet deblurring model.

        A relative ``config_path`` is resolved against the directory containing
        this module, not the current working directory. Also creates ``inputs``
        and ``outputs`` directories next to this module.

        Args:
            config_path: Path to the model configuration YAML file.

        Raises:
            FileNotFoundError: If the resolved config file does not exist.
            Exception: Errors from config parsing or model creation are logged
                and re-raised unchanged.
        """
        try:
            logger.info(f"Initializing NAFNet with config: {config_path}")

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if not os.path.isabs(config_path):
                config_path = os.path.join(module_dir, config_path)

            if not os.path.exists(config_path):
                error_msg = f"Config file not found: {config_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            opt = parse(config_path, is_train=False)
            # Single-process inference: force non-distributed mode regardless of the YAML.
            opt["dist"] = False

            logger.info("Creating model")
            self.model = create_model(opt)

            # NOTE(review): self.device is not used anywhere in this class; the
            # model appears to choose its own device inside create_model().
            try:
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                logger.info(f"Using device: {self.device}")
            except Exception as e:
                logger.warning(f"Failed to set device. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.device = torch.device('cpu')

            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')

            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def imread(self, img_path):
        """Read an image file and return it in RGB order.

        Returns:
            The decoded RGB array, or ``None`` when OpenCV cannot read the file
            (missing, unreadable, or unsupported) -- mirroring ``cv2.imread``
            so callers can test for ``None``.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.cvtColor(None, ...) would raise an opaque cv2.error instead.
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img2tensor(self, img, bgr2rgb=False, float32=True):
        """Scale an image by 1/255 and convert it with BasicSR's ``img2tensor``.

        Args:
            img: Image array; values are divided by 255, so ``uint8`` data in
                [0, 255] is expected.
            bgr2rgb: Forwarded to BasicSR; ``False`` because images handled by
                this class are already RGB.
            float32: Forwarded to BasicSR.
        """
        img = img.astype(np.float32) / 255.0
        return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)

    def _load_image(self, image):
        """Decode ``image`` (path, encoded bytes, or array) into an RGB array."""
        if isinstance(image, (str, os.PathLike)):
            logger.info(f"Loading image from path: {image}")
            img = self.imread(os.fspath(image))
            if img is None:
                raise ValueError(f"Failed to read image from {image}")
            return img

        if isinstance(image, bytes):
            logger.info("Loading image from bytes")
            if not image:
                raise ValueError("Image is empty or invalid")
            img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # OpenCV could not decode the bytes; fall back to PIL. The context
            # manager closes the PIL image once its pixels are copied out.
            with Image.open(BytesIO(image)) as pil_img:
                return np.array(pil_img.convert('RGB'))

        if isinstance(image, np.ndarray):
            logger.info("Processing image from numpy array")
            # Channel order cannot be inferred reliably from pixel values, so
            # arrays are taken as RGB; convert BGR data (e.g. from cv2.imread)
            # with cv2.cvtColor before calling.
            return image.copy()

        raise ValueError(f"Unsupported image type: {type(image)}")

    @staticmethod
    def _limit_size(img, max_dim):
        """Downscale ``img`` so its longer side is at most ``max_dim`` pixels."""
        longest = max(img.shape[0], img.shape[1])
        if longest <= max_dim:
            return img
        scale_factor = max_dim / longest
        # max(1, ...) keeps extreme aspect ratios from collapsing an axis to 0.
        new_h = max(1, int(img.shape[0] * scale_factor))
        new_w = max(1, int(img.shape[1] * scale_factor))
        logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
        return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _run_model(self, img_tensor):
        """Run the model on one image tensor (as produced by ``img2tensor``) and return an RGB array."""
        with torch.no_grad():
            try:
                self.model.feed_data(data={'lq': img_tensor.unsqueeze(dim=0)})

                use_grids = self.model.opt['val'].get('grids', False)
                if use_grids:
                    self.model.grids()
                self.model.test()
                if use_grids:
                    self.model.grids_inverse()

                visuals = self.model.get_current_visuals()
                # tensor2img swaps RGB->BGR by default. The model input was RGB
                # and this class returns RGB (save_image converts for OpenCV),
                # so the swap is disabled here.
                return tensor2img([visuals['result']], rgb2bgr=False)
            except Exception as e:
                logger.error(f"Error during model inference: {str(e)}")
                logger.error(traceback.format_exc())
                raise

    def deblur_image(
        self,
        image: Union[str, os.PathLike, np.ndarray, bytes],
        max_dim: Optional[int] = 2000,
    ) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, encoded image bytes (anything
                OpenCV or PIL can decode), or a 3-channel RGB numpy array
                (copied, values expected in [0, 255]).
            max_dim: Images whose longer side exceeds this many pixels are
                downscaled (aspect ratio kept) before inference; ``None``
                disables the cap. The result has the downscaled size.

        Returns:
            Deblurred image as an RGB ``uint8`` numpy array.

        Raises:
            ValueError: If the image cannot be read or decoded, is empty, does
                not have exactly 3 channels, or has an unsupported type.
        """
        try:
            img = self._load_image(image)

            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            if max_dim is not None:
                img = self._limit_size(img, max_dim)

            logger.info("Converting image to tensor")
            img_tensor = self.img2tensor(img)

            logger.info("Running inference with model")
            result = self._run_model(img_tensor)

            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to ``output_path`` and return the path written.

        A relative ``output_path`` is resolved inside ``self.outputs_dir``;
        missing parent directories are created.

        Raises:
            IOError: If OpenCV reports that the file could not be written.
        """
        try:
            # OpenCV writes BGR, while images in this class are RGB.
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            if not os.path.isabs(output_path):
                output_path = os.path.join(self.outputs_dir, output_path)

            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # cv2.imwrite signals failure by returning False rather than raising.
            if not cv2.imwrite(output_path, save_img):
                raise IOError(f"Failed to write image to {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise
|
|
def main():
    """
    Deblur every supported image in the model's inputs directory.

    Each image is written to the outputs directory as ``deblurred_<name>``.
    A failure on one file is logged and processing continues with the next.

    Returns:
        Process exit status: 0 when every file was processed (or there was
        nothing to process), 1 when model setup failed or any file failed.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    try:
        deblur_model = NAFNetDeblur()

        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        # Sorted so runs are reproducible; os.listdir order is arbitrary.
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if os.path.isfile(os.path.join(inputs_dir, f))
            and f.lower().endswith(image_extensions)
        )

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return 0

        logger.info(f"Found {len(input_files)} images to process")

        failed = []
        for input_file in input_files:
            try:
                input_path = os.path.join(inputs_dir, input_file)
                output_file = f"deblurred_{input_file}"
                output_path = os.path.join(outputs_dir, output_file)

                logger.info(f"Processing {input_file}...")

                deblurred_img = deblur_model.deblur_image(input_path)
                deblur_model.save_image(deblurred_img, output_path)

                logger.info(f"Saved result to {output_path}")
            except Exception as e:
                # Best-effort batch: record the failure and move on.
                logger.error(f"Error processing {input_file}: {str(e)}")
                logger.error(traceback.format_exc())
                failed.append(input_file)

        if failed:
            logger.warning(f"{len(failed)} of {len(input_files)} images failed: {', '.join(failed)}")
        logger.info("Processing complete!")
        return 1 if failed else 0

    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        logger.error(traceback.format_exc())
        return 1


if __name__ == "__main__":
    raise SystemExit(main())