# Spaces: Runtime error
| from diffusers import StableDiffusionPipeline, UNet2DConditionModel | |
| import torch | |
| import copy | |
| import time | |
# Model identifiers passed to `from_pretrained`: the base Stable Diffusion v1.4
# pipeline, and the BK-SDM-Small compressed U-Net used in place of its U-Net.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
COMPRESSED_UNET_ID = "nota-ai/bk-sdm-small"

# Prefer the GPU when one is available, otherwise fall back to the CPU rather
# than requiring CUDA (a hard-coded 'cuda' fails on CPU-only hardware).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class SdmCompressionDemo:
    """Compare the original SD v1.4 pipeline with a compressed-U-Net variant.

    Both pipelines share the original text encoder, tokenizer, VAE, safety
    checker and feature extractor; they differ only in the U-Net (and each
    owns its scheduler instance).
    """

    # Prompts offered as examples by `get_example_list`.
    _EXAMPLE_PROMPTS = (
        'a tropical bird sitting on a branch of a tree',
        'many decorative umbrellas hanging up',
        'an orange cat staring off with pretty eyes',
        'beautiful woman face with fancy makeup',
        'a decorated living room with a stylish feel',
        'a black vase holding a bouquet of roses',
        'very elegant bedroom featuring natural wood',
        'buffet-style food including cake and cheese',
        'a tall castle sitting under a cloudy sky',
        'closeup of a brown bear sitting in a grassy area',
        'a large basket with many fresh vegetables',
        'house being built with lots of wood',
        'a close up of a pizza with several toppings',
        'a golden vase with many different flows',
        'a statue of a lion face attached to brick wall',
        'something that looks particularly interesting',
        'table filled with a variety of different dishes',
        'a cinematic view of a large snowy peak',
        'a grand city in the year 2100, hyper realistic',
        'a blue eyed baby girl looking at the camera',
    )

    def __init__(self, device) -> None:
        """Load both pipelines; move them to *device* when it is a CUDA device.

        Args:
            device: Device identifier such as ``'cuda'`` or ``'cpu'`` (a
                ``torch.device`` is accepted too).
        """
        self.device = device
        is_cuda = 'cuda' in str(device)
        # Half precision is used only on CUDA; CPU runs use float32.
        self.torch_dtype = torch.float16 if is_cuda else torch.float32
        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID, torch_dtype=self.torch_dtype)
        compressed_unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_ID, subfolder="unet", torch_dtype=self.torch_dtype)
        # Reuse the original pipeline's modules instead of deep-copying the
        # whole pipeline: a deepcopy duplicated every module, including the
        # original U-Net that was immediately replaced. The scheduler gets its
        # own copy because schedulers keep per-run timestep state.
        components = dict(self.pipe_original.components)
        components["unet"] = compressed_unet
        components["scheduler"] = copy.deepcopy(self.pipe_original.scheduler)
        self.pipe_compressed = StableDiffusionPipeline(**components)
        if is_cuda:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)
        self.device_msg = 'Tested on GPU.' if is_cuda else 'Tested on CPU.'

    def _count_params(self, model):
        """Return the total number of elements over ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def get_sdm_params(self, pipe):
        """Summarize parameter counts (in millions) of the modules used at inference.

        Counts the U-Net, the text encoder and the VAE *decoder* only.

        Returns:
            A string like ``"Total 1040.0M (U-Net 860.0M)"``.
        """
        params_unet = self._count_params(pipe.unet)
        params_text_enc = self._count_params(pipe.text_encoder)
        params_image_dec = self._count_params(pipe.vae.decoder)
        params_total = params_unet + params_text_enc + params_image_dec
        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Generate one image with *pipe* and time the call.

        Args:
            pipe: A callable pipeline returning an object with ``images`` and
                ``nsfw_content_detected`` attributes.
            text: Prompt.
            negative: Negative prompt.
            guidance_scale: Classifier-free guidance scale.
            steps: Number of inference steps (cast to ``int``).
            seed: Random seed for the generator (cast to ``int``).

        Returns:
            ``(image, nsfw_detected, seconds)`` where *seconds* is the elapsed
            time formatted with two decimals (a string).
        """
        # Callers such as UI number widgets may pass floats; manual_seed and
        # the step count need integers.
        generator = torch.Generator(self.device).manual_seed(int(seed))
        start = time.perf_counter()  # monotonic clock for interval timing
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale, num_inference_steps=int(steps))
        test_time = time.perf_counter() - start
        image = result.images[0]
        # The flag list may be None (e.g. when no safety checker is loaded).
        flags = result.nsfw_content_detected
        nsfw_detected = flags[0] if flags else False
        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")
        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the device message, plus a hint when NSFW content was flagged."""
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        return self.device_msg

    def check_invalid_input(self, text):
        """Return True when *text* is missing, empty or whitespace-only."""
        return not text or not text.strip()

    def _infer(self, pipe, label, text, negative, guidance_scale, steps, seed):
        """Shared body of the ``infer_*`` entry points.

        Returns:
            ``(image, message, seconds)``, or ``(None, prompt-error, None)``
            when the prompt is invalid.
        """
        print(f"=== {label} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed)
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Run the original pipeline; see `_infer` for the return value."""
        return self._infer(self.pipe_original, "ORIG",
                           text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Run the compressed pipeline; see `_infer` for the return value."""
        return self._infer(self.pipe_compressed, "COMPRESSED",
                           text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return a fresh list of example prompts."""
        return list(self._EXAMPLE_PROMPTS)