| from constants import ( |
| TAESD_MODEL, |
| TAESDXL_MODEL, |
| TAESD_MODEL_OPENVINO, |
| TAESDXL_MODEL_OPENVINO, |
| ) |
|
|
|
|
| def get_tiny_decoder_vae_model(pipeline_class) -> str: |
| print(f"Pipeline class : {pipeline_class}") |
| if ( |
| pipeline_class == "LatentConsistencyModelPipeline" |
| or pipeline_class == "StableDiffusionPipeline" |
| or pipeline_class == "StableDiffusionImg2ImgPipeline" |
| or pipeline_class == "StableDiffusionControlNetPipeline" |
| or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline" |
| ): |
| return TAESD_MODEL |
| elif ( |
| pipeline_class == "StableDiffusionXLPipeline" |
| or pipeline_class == "StableDiffusionXLImg2ImgPipeline" |
| ): |
| return TAESDXL_MODEL |
| elif ( |
| pipeline_class == "OVStableDiffusionPipeline" |
| or pipeline_class == "OVStableDiffusionImg2ImgPipeline" |
| ): |
| return TAESD_MODEL_OPENVINO |
| elif pipeline_class == "OVStableDiffusionXLPipeline": |
| return TAESDXL_MODEL_OPENVINO |
| else: |
| raise Exception("No valid pipeline class found!") |
|
|