| import os |
| import uuid |
| import gc |
| from typing import Optional |
| from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks |
| from fastapi.responses import FileResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| import uvicorn |
| import logging |
| import traceback |
| import shutil |
| import torch |
|
|
| from deblur_module import NAFNetDeblur, setup_logger |
|
|
| |
# Module-wide logger, built by the project's setup_logger helper.
logger = setup_logger(__name__)

# Base URL of this service. The port matches run_server()'s default; the
# constant is not referenced anywhere else in this file.
# NOTE(review): presumably imported by client code - confirm before removing.
API_URL = "http://localhost:8001"

# FastAPI application object; every route below is registered on it.
app = FastAPI(
    title="NAFNet Deblurring API",
    description="API for deblurring images using NAFNet deep learning model",
    version="1.0.0",
)

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): the FastAPI CORS documentation states that allow_origins,
# allow_methods and allow_headers must not be ["*"] when
# allow_credentials=True. If no client relies on cookies/credentials, set
# allow_credentials=False; otherwise list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared NAFNetDeblur instance; stays None until get_model() first creates it.
model = None

# Busy flag: True while a /deblur/ request is being processed. Set by
# deblur_image() and cleared by cleanup_resources() / clear_memory().
processing_lock = False
|
|
def get_model():
    """Return the shared NAFNetDeblur instance, creating it on first use.

    The instance is stored in the module-level ``model`` variable so later
    calls reuse it instead of constructing a new one.

    Returns:
        The module-level ``model`` object.

    Raises:
        RuntimeError: If constructing ``NAFNetDeblur`` fails. The original
            exception is attached as the cause.
    """
    global model
    if model is not None:
        return model

    logger.info("Initializing NAFNet deblurring model...")
    try:
        model = NAFNetDeblur()
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        logger.error(traceback.format_exc())
        # Chain explicitly so the root cause stays visible to callers.
        raise RuntimeError(f"Could not initialize NAFNet model: {str(e)}") from e

    logger.info("Model initialized successfully")
    return model
|
|
def cleanup_resources(input_path=None):
    """Clean up resources after processing and release the busy flag.

    Runs a garbage-collection pass, empties the CUDA cache when CUDA is
    available, deletes the temporary input file (best effort) and finally
    sets the module-level ``processing_lock`` back to ``False``.

    Args:
        input_path: Path of the temporary input file to delete, or ``None``
            (or empty) when there is nothing to delete.
    """
    global processing_lock

    try:
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if input_path:
            try:
                os.remove(input_path)
                logger.info(f"Removed temporary input file: {input_path}")
            except FileNotFoundError:
                # Nothing to delete (never written or already removed).
                pass
            except Exception as e:
                # Deleting the temp file is best effort; never fail cleanup on it.
                logger.warning(f"Could not remove temporary input file: {str(e)}")
    finally:
        # Release the flag even if freeing memory raised; otherwise every
        # later /deblur/ request would be rejected as "busy".
        processing_lock = False

    logger.info("Resources cleaned up")
| |
@app.get("/")
async def root():
    """Return a fixed message showing that the API process is up."""
    return {"message": "NAFNet Deblurring API is running"}
|
|
@app.post("/deblur/", response_class=FileResponse)
async def deblur_image(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """
    Deblur an uploaded image and return the processed image file.

    Only one request is handled at a time: the module-level
    ``processing_lock`` flag is set while a request is in flight and is
    released by ``cleanup_resources`` (after the response on success,
    immediately on failure).

    Args:
        background_tasks: Task queue used to schedule ``cleanup_resources``
            once the response has been produced.
        file: The uploaded image.

    Returns:
        A ``FileResponse`` for the deblurred PNG.

    Raises:
        HTTPException: 429 when another request is being processed, 400 when
            the upload is not declared as an image, 500 when saving the
            upload or running the model fails.
    """
    global processing_lock
    input_path = None

    if processing_lock:
        logger.warning("Another deblurring request is already in progress")
        raise HTTPException(
            status_code=429,
            detail="Server is busy processing another image. Please try again shortly."
        )

    # Validate before taking the lock so a rejected upload never has to
    # release it. content_type can be None when the client sends no
    # Content-Type for the part; treat that as "not an image".
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        logger.warning(f"Invalid file type: {file.content_type}")
        raise HTTPException(status_code=400, detail="File must be an image")

    # There is no await between the check above and this assignment, so two
    # requests cannot both pass the check on the same event loop.
    processing_lock = True

    try:
        logger.info(f"Processing image: {file.filename}, size: {file.size} bytes, type: {file.content_type}")

        # Working directories live next to this module.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        input_dir = os.path.join(module_dir, 'inputs')
        output_dir = os.path.join(module_dir, 'outputs')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # A random id keeps file names of different requests from colliding.
        unique_id = uuid.uuid4().hex
        input_filename = f"input_{unique_id}.png"
        input_path = os.path.join(input_dir, input_filename)
        output_filename = f"deblurred_{unique_id}.png"
        output_path = os.path.join(output_dir, output_filename)

        try:
            # Persist the upload so the model can read it from disk.
            file_contents = await file.read()
            with open(input_path, "wb") as buffer:
                buffer.write(file_contents)

            logger.info(f"Input image saved to: {input_path}")

            await file.close()

            deblur_model = get_model()

            logger.info("Starting deblurring process...")
            deblurred_img = deblur_model.deblur_image(input_path)

            # NOTE(review): save_image only receives the file name; this code
            # assumes it writes into <module_dir>/outputs - confirm against
            # deblur_module.
            deblur_model.save_image(deblurred_img, output_filename)

            # Fail here, inside the error handling, rather than while the
            # response is being sent: at that point the cleanup task would
            # not run and the lock would stay set.
            if not os.path.isfile(output_path):
                raise FileNotFoundError(
                    f"Deblurred image not found at expected path: {output_path}"
                )

            logger.info(f"Image deblurred successfully, saved to: {output_path}")

            # Free memory, delete the input file and release the lock once
            # the response has been sent.
            background_tasks.add_task(cleanup_resources, input_path)

            return FileResponse(
                output_path,
                media_type="image/png",
                filename=f"deblurred_{file.filename}"
            )
        except Exception as e:
            logger.error(f"Error in deblurring process: {str(e)}")
            logger.error(traceback.format_exc())

            # No response will be sent, so clean up (and unlock) right away.
            cleanup_resources(input_path)
            raise HTTPException(status_code=500, detail=f"Deblurring failed: {str(e)}") from e

    except HTTPException:
        # Already a proper HTTP error; the lock was released by the cleanup above.
        raise
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        processing_lock = False
        raise HTTPException(status_code=500, detail=error_msg) from e
|
|
@app.get("/status/")
async def status():
    """Report whether the API is idle, busy or failing to load the model.

    Returns a "busy" payload while an image is being processed. Otherwise it
    makes sure the model is loaded (via ``get_model``) and returns an "ok"
    payload that includes CUDA memory figures when CUDA is available. Any
    exception is turned into an "error" payload instead of an HTTP error.
    """
    try:
        logger.info("Checking model status")

        # While a request is in flight, answer immediately without touching
        # the model.
        if processing_lock:
            busy_payload = {
                "status": "busy",
                "model_loaded": True,
                "message": "Currently processing an image"
            }
            return busy_payload

        # Called for its side effect: creates the model on first use and
        # raises if that fails.
        get_model()

        memory_info = {}
        if torch.cuda.is_available():
            cuda_readers = (
                ("cuda_memory_allocated", torch.cuda.memory_allocated),
                ("cuda_memory_reserved", torch.cuda.memory_reserved),
                ("cuda_max_memory", torch.cuda.max_memory_allocated),
            )
            for key, read_bytes in cuda_readers:
                # Byte counts are reported in MB with two decimals.
                memory_info[key] = f"{read_bytes() / 1024**2:.2f} MB"

        logger.info("Model is loaded and ready")
        ok_payload = {
            "status": "ok",
            "model_loaded": True,
            "processing": processing_lock,
            "memory": memory_info
        }
        return ok_payload
    except Exception as e:
        error_msg = f"Error checking model status: {str(e)}"
        logger.error(error_msg)
        return {"status": "error", "model_loaded": False, "error": str(e)}
|
|
@app.get("/clear-memory/")
async def clear_memory():
    """Force clear memory and release resources.

    Runs a garbage-collection pass, empties the CUDA cache when CUDA is
    available and unconditionally resets the ``processing_lock`` flag.

    Returns:
        An "ok" payload whose ``lock_released`` field tells whether the flag
        was set before the call, or an "error" payload if anything raised.
    """
    global processing_lock

    try:
        gc.collect()

        if torch.cuda.is_available():
            # empty_cache() returns cached-but-unused blocks to the driver,
            # which shows up in memory_reserved(); memory_allocated() only
            # counts live tensors and would read the same before and after.
            before = torch.cuda.memory_reserved() / 1024**2
            torch.cuda.empty_cache()
            after = torch.cuda.memory_reserved() / 1024**2
            logger.info(f"CUDA memory cleared: {before:.2f} MB → {after:.2f} MB")

        # Remember the previous state so the caller can see whether a stuck
        # lock was actually cleared.
        was_locked = processing_lock
        processing_lock = False

        return {
            "status": "ok",
            "message": "Memory cleared successfully",
            "lock_released": was_locked
        }
    except Exception as e:
        error_msg = f"Error clearing memory: {str(e)}"
        logger.error(error_msg)
        return {"status": "error", "error": str(e)}
|
|
@app.get("/diagnostics/")
async def diagnostics():
    """Get diagnostic information about the system.

    Collects platform, Python and library versions, CUDA details when CUDA
    is available, whether the model has been loaded and, on POSIX systems,
    disk usage of the root filesystem.

    Returns:
        ``{"status": "ok", "system_info": {...}}`` on success, or
        ``{"status": "error", "error": "..."}`` if anything raised.
    """
    try:
        # Imported here rather than at module level; a failing import is
        # reported through the error payload below.
        import platform
        import sys
        import cv2
        import numpy as np

        system_info = {
            "platform": platform.platform(),
            "python_version": sys.version,
            "torch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "opencv_version": cv2.__version__,
            "numpy_version": np.__version__,
        }

        if torch.cuda.is_available():
            current_device = torch.cuda.current_device()
            system_info["cuda_version"] = torch.version.cuda
            system_info["cuda_device_count"] = torch.cuda.device_count()
            system_info["cuda_current_device"] = current_device
            # Name the device the other figures refer to, not always device 0.
            system_info["cuda_device_name"] = torch.cuda.get_device_name(current_device)
            system_info["cuda_memory_allocated"] = f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            system_info["cuda_memory_reserved"] = f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB"

        # Only reports the current state; does not trigger model loading.
        system_info["model_loaded"] = model is not None

        if os.name == 'posix':
            total, used, free = shutil.disk_usage("/")
            # Whole gibibytes (integer division by 2**30), labelled "GB".
            system_info["disk_total"] = f"{total // (2**30)} GB"
            system_info["disk_used"] = f"{used // (2**30)} GB"
            system_info["disk_free"] = f"{free // (2**30)} GB"

        return {
            "status": "ok",
            "system_info": system_info
        }
    except Exception as e:
        error_msg = f"Error gathering diagnostics: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}
|
|
def run_server(host="0.0.0.0", port=8001):
    """Serve the FastAPI ``app`` with uvicorn.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
    """
    bind_options = {"host": host, "port": port}
    uvicorn.run(app, **bind_options)


# Start the server with the default host and port when run as a script.
if __name__ == "__main__":
    run_server()