HATEBIN
>
import asyncio
from io import BytesIO

import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from PIL import Image, ImageFilter
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class Foo:
    """Minimal Stable Diffusion text-to-image pipeline.

    Loads the tokenizer/text-encoder/VAE/UNet components of an SD model,
    runs classifier-free-guidance denoising on CUDA, and periodically emits
    blurred PNG previews of the predicted final image during the run.
    """

    def __init__(self):
        # Generation defaults.
        self.default_height = 512
        self.default_width = 512
        self.default_inference_steps = 50
        self.default_guidance_scale = 7.5
        self.negative_prompts = ["watermark"]
        self.model_name = "stabilityai/stable-diffusion-xl-base-0.9"

        # BUG FIX: the original loaded the tokenizer twice; load it once.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.model_name, subfolder="tokenizer")
        self.vae = AutoencoderKL.from_pretrained(
            self.model_name, subfolder="vae").to("cuda")
        self.vae.enable_xformers_memory_efficient_attention()
        self.vae.enable_tiling()
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.model_name, subfolder="text_encoder").to("cuda")
        self.unet = UNet2DConditionModel.from_pretrained(
            self.model_name, subfolder="unet").to("cuda")
        self.unet.enable_xformers_memory_efficient_attention()
        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000)
        # Record the seed so a run is reproducible.
        self.seed = torch.seed()
        self.generator = torch.manual_seed(self.seed)

    def latents_to_pil(self, latents):
        """Decode VAE latents into a list of PIL images (one per batch item)."""
        latents = (1 / 0.18215) * latents  # undo the SD latent scaling factor
        # NOTE(review): assumes the VAE runs in fp16 — confirm the weights'
        # dtype; a float32 VAE would reject half-precision input.
        latents = latents.half()
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] -> [0, 1]
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return [Image.fromarray(im) for im in images]

    def convert_image_bytes64(self, image, compress_level):
        """Serialize a PIL image to an in-memory PNG buffer, rewound to 0.

        compress_level: PNG zlib level (0 = no compression, fastest).
        """
        buffered = BytesIO()
        image.save(buffered, format="PNG", compress_level=compress_level)
        buffered.seek(0)
        return buffered

    def async_blocking_function_runner(self, func, *args, **kwargs):
        """Schedule coroutine ``func(*args, **kwargs)`` on the stored loop.

        NOTE(review): this is invoked from a ``run_in_executor`` worker
        thread, but ``loop.create_task`` is not thread-safe;
        ``asyncio.run_coroutine_threadsafe(func(...), self.loop)`` is the
        safe equivalent. Kept as a task to preserve the original return type.
        Also note the loop is never actually running in this script, so the
        scheduled coroutine will not execute — TODO confirm intent.
        """
        res = self.loop.create_task(func(*args, **kwargs))
        return res

    def blocking_code(self):
        """Run the full denoising loop synchronously; return the last preview grid."""
        self.batch_size = 4
        self.default_height = 512
        self.default_width = 512

        prompt = "Capybara holding a sword whilst wearing a knights costume, photo"
        text_input = self.tokenizer(
            [prompt] * self.batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt")
        with torch.autocast("cuda"):
            text_embeddings = self.text_encoder(
                text_input.input_ids.to("cuda"))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            self.negative_prompts * self.batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt")
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to("cuda"))[0]
        # Classifier-free guidance: stack [uncond; cond] into one batch so a
        # single UNet forward pass serves both branches.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep scheduler.
        self.scheduler.set_timesteps(self.default_inference_steps)

        # Prep latents: latent space is 8x smaller than pixel space.
        latents = torch.randn(
            (self.batch_size, self.unet.in_channels,
             self.default_height // 8, self.default_width // 8),
            generator=self.generator,
        )
        latents = latents.to("cuda")
        latents = latents * self.scheduler.init_noise_sigma

        # ceil(steps / 10) preview updates over the run.
        # BUG FIX: the original computed update_interval with true division,
        # making "i % update_interval" a float modulo; use integer division.
        num_updates = -(-self.default_inference_steps // 10)
        update_interval = max(1, self.default_inference_steps // num_updates)

        # BUG FIX: ensure a defined return value even when no preview is
        # ever produced (steps <= 10 previously raised UnboundLocalError).
        im = None

        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # Duplicate latents so one UNet pass covers uncond + cond.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input, t,
                    encoder_hidden_states=text_embeddings)["sample"]
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.default_guidance_scale * (
                noise_pred_text - noise_pred_uncond)
            # BUG FIX: the original called scheduler.step() twice per
            # iteration with identical arguments; one call yields both the
            # predicted x0 and the next latents.
            step_out = self.scheduler.step(noise_pred, t, latents)
            latents_x0 = step_out.pred_original_sample
            latents = step_out.prev_sample

            if i % update_interval == 0 or i == self.default_inference_steps - 1:
                if i < 10:
                    continue  # earliest previews are pure noise — skip them
                im_t0 = self.latents_to_pil(latents_x0)
                im = self.image_grid(im_t0, 2, 2)  # 2x2 grid for batch_size 4
                # Fade the blur out as denoising progresses.
                blur_level = 10 - (i / self.default_inference_steps) * 10
                if blur_level <= 1:
                    blur_level = 0
                im = im.filter(ImageFilter.GaussianBlur(radius=blur_level))
                image = self.convert_image_bytes64(im, compress_level=0)
                self.loop.run_in_executor(
                    None, self.async_blocking_function_runner,
                    self.send_iteration_image, (image, i + 1))
        return im

    def image_grid(self, imgs, rows, cols):
        """Paste ``imgs`` into a rows x cols grid (assumes uniform sizes)."""
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    async def send_iteration_image(self, data):
        """Placeholder hook for streaming preview images to a consumer."""
        print("Not implemented yet...")

    def generate_image(self):
        """Entry point: capture the event loop, run generation, save the grid."""
        self.loop = asyncio.get_event_loop()
        foo = self.blocking_code()
        foo.save("grid.png")


foo = Foo()
foo.generate_image()