2024-08-17
Flux.1
Flux is a generative diffusion model from Black Forest Labs. Diffusion is a process by which particles spread out from a region of high concentration to one of low concentration. The model uses a series of transformations to generate images, starting from a simple noise vector and gradually denoising it into an image over some number of steps.
DiffusionKit provides a simple interface to run the model on a MacBook:
diffusionkit-cli --prompt "prompt" \
  --steps 3 --output-path "featured.png"
And in Python:
from diffusionkit.mlx import FluxPipeline

pipeline = FluxPipeline(
    shift=1.0,
    model_version="argmaxinc/mlx-FLUX.1-schnell",
    low_memory_mode=True,
    a16=True,
    w16=True,
)
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 4  # 4 for FLUX.1-schnell, 50 for SD3
CFG_WEIGHT = 0.  # for FLUX.1-schnell, 5. for SD3
image, _ = pipeline.generate_image(
    "a photo of a cat",
    cfg_weight=CFG_WEIGHT,
    num_steps=NUM_STEPS,
    latent_size=(HEIGHT // 8, WIDTH // 8),
)
image.save("image.png")
Generated Flux output image omitted in the migrated site because the original PNG asset was not present in source control.
Quite impressive. I thought diffusion would require more steps. In thermodynamics, diffusion is a slow process, the long-time limit of how much a system can change. Certainly, something exciting is going on under the hood. The model is based on the architecture published in “Scaling Rectified Flow Transformers for High-Resolution Image Synthesis” by Stability AI (paper).
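The answer is rectified flow. Instead of simulating a slow stochastic diffusion, the model learns a velocity field whose trajectories from noise to data are trained to be (approximately) straight lines, and straight lines can be integrated in very few Euler steps. A toy sketch of the idea (these are standard rectified-flow constructions, not DiffusionKit code):

import numpy as np

# Rectified flow trains v(x_t, t) to predict the constant velocity
# (x1 - x0) along the straight path x_t = (1 - t) * x0 + t * x1.
def toy_velocity(x_t, t, x0, x1):
    # What a perfectly trained model would return for this pair:
    return x1 - x0

x0 = np.random.randn(4)              # noise sample at t = 0
x1 = np.array([1.0, 2.0, 3.0, 4.0])  # "data" sample at t = 1
x = x0.copy()
num_steps = 4  # few steps suffice when paths are straight
for i in range(num_steps):
    t = i / num_steps
    x = x + toy_velocity(x, t, x0, x1) * (1.0 / num_steps)
print(np.allclose(x, x1))  # True: with straight paths even one step lands exactly

With a curved trajectory, as in ordinary diffusion, the same Euler integrator would need many small steps to stay on the path; that is what FLUX.1-schnell's few-step sampling avoids.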
Let’s look a bit deeper into the code.
The text prompt is transformed into an embedding, which is used to condition the image generation. The encode_text method produces this embedding:
conditioning, pooled_conditioning = self.encode_text(
    text, cfg_weight, negative_text
)
mx.eval(conditioning)
mx.eval(pooled_conditioning)
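The first return value is the token-level embedding sequence and the second a single pooled vector; FLUX-style models typically obtain these from T5 and CLIP text encoders respectively. A quick way to inspect them (the call pattern and the shapes in the comments are my assumptions, not verified output):

conditioning, pooled_conditioning = pipeline.encode_text("a photo of a cat", 0.0, "")
print(conditioning.shape)         # e.g. (1, seq_len, token_level_text_embed_dim)
print(pooled_conditioning.shape)  # e.g. (1, pooled_text_embed_dim)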
The diffusion steps take place inside the FluxPipeline class, in the generate_image method, which delegates the sampling loop to denoise_latents:
def denoise_latents(
    self,
    conditioning,
    pooled_conditioning,
    num_steps: int = 2,
    cfg_weight: float = 0.0,
    latent_size: Tuple[int] = (64, 64),
    seed=None,
    image_path: Optional[str] = None,
    denoise: float = 1.0,
):
    # Set the PRNG state
    seed = int(time.time()) if seed is None else seed
    logger.info(f"Seed: {seed}")
    mx.random.seed(seed)

    x_T = self.get_empty_latent(*latent_size)
    if image_path is None:
        denoise = 1.0
    else:
        x_T = self.encode_image_to_latents(image_path, seed=seed)
        x_T = self.latent_format.process_in(x_T)
    noise = self.get_noise(seed, x_T)
    sigmas = self.get_sigmas(self.sampler, num_steps)
    sigmas = sigmas[int(num_steps * (1 - denoise)) :]
    extra_args = {
        "conditioning": conditioning,
        "cfg_weight": cfg_weight,
        "pooled_conditioning": pooled_conditioning,
    }
    noise_scaled = self.sampler.noise_scaling(
        sigmas[0], noise, x_T, self.max_denoise(sigmas)
    )
    latent, iter_time = sample_euler(
        CFGDenoiser(self), noise_scaled, sigmas, extra_args=extra_args
    )

    latent = self.latent_format.process_out(latent)

    return latent, iter_time
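sample_euler is a standard Karras-style Euler sampler: at each noise level it asks the CFG-wrapped denoiser for an estimate of the clean latent, converts that into a derivative, and takes one Euler step to the next sigma. A minimal sketch of the loop (my own simplification; DiffusionKit's version also measures per-iteration time, hence the iter_time return value above):

def sample_euler_sketch(denoiser, x, sigmas, extra_args=None):
    extra_args = extra_args or {}
    for i in range(len(sigmas) - 1):
        denoised = denoiser(x, sigmas[i], **extra_args)  # estimate of the clean latent
        d = (x - denoised) / sigmas[i]                   # derivative dx/dsigma
        x = x + d * (sigmas[i + 1] - sigmas[i])          # Euler step to the next noise level
    return x

Since the schedule ends at sigma = 0, the last step returns the denoiser's final clean estimate directly.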
The decode_latents_to_image method decodes the latent representation into an image: it runs the decoder, then maps the pixel values from [-1, 1] to [0, 1] and clips them (mx.clip here clamps array values; it is not the CLIP model). The code:
def decode_latents_to_image(self, x_t):
    x = self.decoder(x_t)
    x = mx.clip(x / 2 + 0.5, 0, 1)
    return x
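To go from that [0, 1] array to the PIL image that image.save writes above, something along these lines works (my own sketch, assuming the usual MLX (batch, height, width, channel) layout):

import numpy as np
from PIL import Image

x = pipeline.decode_latents_to_image(latent)   # values in [0, 1]
arr = (np.array(x[0]) * 255).astype(np.uint8)  # drop batch dim, scale to 8-bit
Image.fromarray(arr).save("image.png")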
The architecture
Architecture figure omitted in the migrated site because the original image asset was not present in source control.
class MMDiT(nn.Module):
    def __init__(self, config: MMDiTConfig):
        super().__init__()
        self.config = config

        # Input adapters and embeddings
        self.x_embedder = LatentImageAdapter(config)

        if config.pos_embed_type == PositionalEncoding.LearnedInputEmbedding:
            self.x_pos_embedder = LatentImagePositionalEmbedding(config)
            self.pre_sdpa_rope = nn.Identity()
        elif config.pos_embed_type == PositionalEncoding.PreSDPARope:
            self.pre_sdpa_rope = RoPE(
                theta=10000,
                axes_dim=config.rope_axes_dim,
            )
        else:
            raise ValueError(
                f"Unsupported positional encoding type: {config.pos_embed_type}"
            )

        self.y_embedder = PooledTextEmbeddingAdapter(config)
        self.t_embedder = TimestepAdapter(config)
        self.context_embedder = nn.Linear(
            config.token_level_text_embed_dim,
            config.hidden_size,
        )

        self.multimodal_transformer_blocks = [
            MultiModalTransformerBlock(
                config,
                skip_text_post_sdpa=(i == config.depth_multimodal - 1)
                and (config.depth_unified < 1),
            )
            for i in range(config.depth_multimodal)
        ]

        if config.depth_unified > 0:
            self.unified_transformer_blocks = [
                UnifiedTransformerBlock(config) for _ in range(config.depth_unified)
            ]

        self.final_layer = FinalLayer(config)
The forward pass:
    def __call__(
        self,
        latent_image_embeddings: mx.array,
        token_level_text_embeddings: mx.array,
        timestep: mx.array,
    ) -> mx.array:
        batch, latent_height, latent_width, _ = latent_image_embeddings.shape
        token_level_text_embeddings = self.context_embedder(token_level_text_embeddings)

        if hasattr(self, "x_pos_embedder"):
            latent_image_embeddings = self.x_embedder(
                latent_image_embeddings
            ) + self.x_pos_embedder(latent_image_embeddings)
        else:
            latent_image_embeddings = self.x_embedder(latent_image_embeddings)

        latent_image_embeddings = latent_image_embeddings.reshape(
            batch, -1, 1, self.config.hidden_size
        )

        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
            positional_encodings = self.pre_sdpa_rope(
                text_sequence_length=token_level_text_embeddings.shape[1],
                latent_image_resolution=(
                    latent_height // self.config.patch_size,
                    latent_width // self.config.patch_size,
                ),
            )
        else:
            positional_encodings = None

        # MultiModalTransformer layers
        if self.config.depth_multimodal > 0:
            for bidx, block in enumerate(self.multimodal_transformer_blocks):
                latent_image_embeddings, token_level_text_embeddings = block(
                    latent_image_embeddings,
                    token_level_text_embeddings,
                    timestep,
                    positional_encodings=positional_encodings,
                )

        # UnifiedTransformerBlock layers
        if self.config.depth_unified > 0:
            latent_unified_embeddings = mx.concatenate(
                (token_level_text_embeddings, latent_image_embeddings), axis=1
            )

            for bidx, block in enumerate(self.unified_transformer_blocks):
                latent_unified_embeddings = block(
                    latent_unified_embeddings,
                    timestep,
                    positional_encodings=positional_encodings,
                )

            latent_image_embeddings = latent_unified_embeddings[
                :, token_level_text_embeddings.shape[1] :, ...
            ]

        # Final layer
        latent_image_embeddings = self.final_layer(
            latent_image_embeddings,
            timestep,
        )

        if self.config.patchify_via_reshape:
            latent_image_embeddings = self.x_embedder.unpack(
                latent_image_embeddings, (latent_height, latent_width)
            )
        else:
            latent_image_embeddings = unpatchify(
                latent_image_embeddings,
                patch_size=self.config.patch_size,
                target_height=latent_height,
                target_width=latent_width,
                vae_latent_dim=self.config.vae_latent_dim,
            )
        return latent_image_embeddings
The model is composed of several stages. First come the input adapters and embeddings. The first loop runs the multi-modal transformer blocks, which keep separate image and text streams that attend to each other; the second loop runs the unified transformer blocks over a single concatenated sequence, before the final layer projects back to the latent image.
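For a sense of scale, here is a hypothetical config matching the published FLUX.1 architecture as I understand it: 19 joint image-text blocks followed by 38 unified blocks at width 3072 with 24 heads. Only the field names come from the code above; the values are my assumptions:

flux_schnell_config = MMDiTConfig(
    depth_multimodal=19,  # joint image-text ("double-stream") blocks
    depth_unified=38,     # unified ("single-stream") blocks
    hidden_size=3072,
    num_heads=24,         # 3072 / 24 = 128 dims per head
    patch_size=2,         # 2x2 patches of the VAE latent
    vae_latent_dim=16,    # 16-channel VAE latent
    pos_embed_type=PositionalEncoding.PreSDPARope,
)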
The multi-modal transformer block
class MultiModalTransformerBlock(nn.Module):
    def __init__(self, config: MMDiTConfig, skip_text_post_sdpa: bool = False):
        super().__init__()
        self.image_transformer_block = TransformerBlock(config)
        self.text_transformer_block = TransformerBlock(
            config, skip_post_sdpa=skip_text_post_sdpa
        )

        sdpa_impl = mx.fast.scaled_dot_product_attention
        self.sdpa = partial(sdpa_impl)

        self.config = config
        self.per_head_dim = config.hidden_size // config.num_heads

    def __call__(
        self,
        latent_image_embeddings: mx.array,  # latent image embeddings
        token_level_text_embeddings: mx.array,  # token-level text embeddings
        timestep: mx.array,  # pooled text embeddings + timestep embeddings
        positional_encodings: mx.array = None,  # positional encodings for rope
    ):
        # Prepare multi-modal SDPA inputs
        image_intermediates = self.image_transformer_block.pre_sdpa(
            latent_image_embeddings,
            timestep=timestep,
        )

        text_intermediates = self.text_transformer_block.pre_sdpa(
            token_level_text_embeddings,
            timestep=timestep,
        )

        batch = latent_image_embeddings.shape[0]

        def rearrange_for_sdpa(t):
            # Target data layout: (batch, head, seq_len, channel)
            return t.reshape(
                batch, -1, self.config.num_heads, self.per_head_dim
            ).transpose(0, 2, 1, 3)

        if self.config.depth_unified > 0:
            multimodal_sdpa_inputs = {
                "q": mx.concatenate(
                    [text_intermediates["q"], image_intermediates["q"]], axis=2
                ),
                "k": mx.concatenate(
                    [text_intermediates["k"], image_intermediates["k"]], axis=2
                ),
                "v": mx.concatenate(
                    [text_intermediates["v"], image_intermediates["v"]], axis=2
                ),
                "scale": 1.0 / np.sqrt(self.per_head_dim),
            }
        else:
            multimodal_sdpa_inputs = {
                "q": rearrange_for_sdpa(
                    mx.concatenate(
                        [image_intermediates["q"], text_intermediates["q"]], axis=1
                    )
                ),
                "k": rearrange_for_sdpa(
                    mx.concatenate(
                        [image_intermediates["k"], text_intermediates["k"]], axis=1
                    )
                ),
                "v": rearrange_for_sdpa(
                    mx.concatenate(
                        [image_intermediates["v"], text_intermediates["v"]], axis=1
                    )
                ),
                "scale": 1.0 / np.sqrt(self.per_head_dim),
            }

        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
            assert positional_encodings is not None
            multimodal_sdpa_inputs["q"] = RoPE.apply(
                multimodal_sdpa_inputs["q"], positional_encodings
            )
            multimodal_sdpa_inputs["k"] = RoPE.apply(
                multimodal_sdpa_inputs["k"], positional_encodings
            )

        if self.config.low_memory_mode:
            multimodal_sdpa_inputs[
                "memory_efficient_threshold"
            ] = SDPA_FLASH_ATTN_THRESHOLD

        # Compute multi-modal SDPA
        sdpa_outputs = (
            self.sdpa(**multimodal_sdpa_inputs)
            .transpose(0, 2, 1, 3)
            .reshape(batch, -1, 1, self.config.hidden_size)
        )

        # Split into image-text sequences for post-SDPA layers
        img_seq_len = latent_image_embeddings.shape[1]
        txt_seq_len = token_level_text_embeddings.shape[1]

        if self.config.depth_unified > 0:
            text_sdpa_output = sdpa_outputs[:, :txt_seq_len, :, :]
            image_sdpa_output = sdpa_outputs[:, txt_seq_len:, :, :]
        else:
            image_sdpa_output = sdpa_outputs[:, :img_seq_len, :, :]
            text_sdpa_output = sdpa_outputs[:, -txt_seq_len:, :, :]

        # Post-SDPA layers
        latent_image_embeddings = self.image_transformer_block.post_sdpa(
            residual=latent_image_embeddings,
            sdpa_output=image_sdpa_output,
            **image_intermediates,
        )
        if self.text_transformer_block.skip_post_sdpa:
            # Text token related outputs from the final layer do not impact the model output
            token_level_text_embeddings = None
        else:
            token_level_text_embeddings = self.text_transformer_block.post_sdpa(
                residual=token_level_text_embeddings,
                sdpa_output=text_sdpa_output,
                **text_intermediates,
            )

        return latent_image_embeddings, token_level_text_embeddings
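The rearrange_for_sdpa helper just splits the flattened hidden dimension into (head, channel) and moves heads ahead of the sequence axis, the layout mx.fast.scaled_dot_product_attention expects. The same trick in plain NumPy, with toy sizes of my choosing:

import numpy as np

batch, seq_len, num_heads, per_head_dim = 2, 5, 24, 128
t = np.zeros((batch, seq_len, num_heads * per_head_dim))
out = t.reshape(batch, -1, num_heads, per_head_dim).transpose(0, 2, 1, 3)
print(out.shape)  # (2, 24, 5, 128): (batch, head, seq_len, channel)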
The unified transformer block
class UnifiedTransformerBlock(nn.Module):
    def __init__(self, config: MMDiTConfig):
        super().__init__()
        self.transformer_block = TransformerBlock(
            config,
            num_modulation_params=3 if config.parallel_mlp_for_unified_blocks else 6,
            parallel_mlp=config.parallel_mlp_for_unified_blocks,
        )

        sdpa_impl = mx.fast.scaled_dot_product_attention
        self.sdpa = partial(sdpa_impl)

        self.config = config
        self.per_head_dim = config.hidden_size // config.num_heads

    def __call__(
        self,
        latent_unified_embeddings: mx.array,  # concatenated text + image embeddings
        timestep: mx.array,  # pooled text embeddings + timestep embeddings
        positional_encodings: mx.array = None,  # positional encodings for rope
    ):
        # Prepare SDPA inputs
        intermediates = self.transformer_block.pre_sdpa(
            latent_unified_embeddings,
            timestep=timestep,
        )

        batch = latent_unified_embeddings.shape[0]

        def rearrange_for_sdpa(t):
            # Target data layout: (batch, head, seq_len, channel)
            return t.reshape(
                batch, -1, self.config.num_heads, self.per_head_dim
            ).transpose(0, 2, 1, 3)

        multimodal_sdpa_inputs = {
            "q": intermediates["q"],
            "k": intermediates["k"],
            "v": intermediates["v"],
            "scale": 1.0 / np.sqrt(self.per_head_dim),
        }

        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
            assert positional_encodings is not None
            multimodal_sdpa_inputs["q"] = RoPE.apply(
                multimodal_sdpa_inputs["q"], positional_encodings
            )
            multimodal_sdpa_inputs["k"] = RoPE.apply(
                multimodal_sdpa_inputs["k"], positional_encodings
            )

        if self.config.low_memory_mode:
            multimodal_sdpa_inputs[
                "memory_efficient_threshold"
            ] = SDPA_FLASH_ATTN_THRESHOLD

        # Compute SDPA
        sdpa_outputs = (
            self.sdpa(**multimodal_sdpa_inputs)
            .transpose(0, 2, 1, 3)
            .reshape(batch, -1, 1, self.config.hidden_size)
        )

        # o_proj and mlp.fc2 use the same bias; zero out the mlp.fc2 bias
        self.transformer_block.mlp.fc2.bias = self.transformer_block.mlp.fc2.bias * 0.0

        # Post-SDPA layers
        latent_unified_embeddings = self.transformer_block.post_sdpa(
            residual=latent_unified_embeddings,
            sdpa_output=sdpa_outputs,
            **intermediates,
        )

        return latent_unified_embeddings
The final layer
The final transformation of the latent image embeddings.
class FinalLayer(nn.Module):
    def __init__(self, config: MMDiTConfig):
        super().__init__()
        self.norm_final = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.linear = nn.Linear(
            config.hidden_size,
            (config.patch_size**2) * config.vae_latent_dim,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 2 * config.hidden_size),
        )

    def __call__(
        self,
        latent_image_embeddings: mx.array,
        timestep: mx.array,
    ) -> mx.array:
        if timestep.size > 1:
            timestep = timestep[0]
        modulation_params = self._modulation_params[timestep.item()]

        shift, residual_scale = mx.split(modulation_params, 2, axis=-1)
        latent_image_embeddings = affine_transform(
            latent_image_embeddings,
            shift=shift,
            residual_scale=residual_scale,
            norm_module=self.norm_final,
        )
        return self.linear(latent_image_embeddings)
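affine_transform is the usual adaLN modulation: normalize, then scale and shift with parameters derived from the timestep (plus pooled text) embedding. A sketch of what it presumably computes; the 1 + residual_scale form is my reading of the parameter name, not confirmed from the source:

def affine_transform_sketch(x, shift, residual_scale, norm_module):
    # adaLN: y = norm(x) * (1 + scale) + shift
    return norm_module(x) * (1 + residual_scale) + shift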