Stable Cascade? #95
Comments
This generates a recognizable image, but the quality is poor enough that something is clearly still missing somewhere. It does seem at least somewhat possible, though.
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
device = 'cuda'
prompt = "an image of a (shiba inu)1.5 donning a spacesuit++"
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", variant='bf16',
    torch_dtype=torch.bfloat16).to(device)
prior_compel = Compel(
    tokenizer=prior.tokenizer,
    text_encoder=prior.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True, device=device)
conditioning, pooled = prior_compel(prompt)
prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled.unsqueeze(1))
prior.to('cpu')
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant='bf16',
    torch_dtype=torch.float16).to(device)
decoder_compel = Compel(
    tokenizer=decoder.tokenizer,
    text_encoder=decoder.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
    device=device)
conditioning, pooled = decoder_compel(prompt)
decoder(
    num_inference_steps=10,
    guidance_scale=0.0,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled.unsqueeze(1),
    image_embeddings=prior_output.image_embeddings.half()).images[0].save('test.png')
Hi @Teriks, have you resolved the issue using the prior + decoder setup in the snippet?
you might want to confirm if this
The embeddings provider probably needs some alternate logic to handle Stable Cascade. I decided to sit down and mess with it a little, and I think it needs something like this.
You would probably need to implement a new
Here is a monkey patch demo that produces a decent quality image. I might have time for a PR next week, though it would be very simple to add.
from typing import List, Optional
import torch
import compel
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
class SCascadeEmbeddingsProvider(compel.EmbeddingsProvider):
    def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
                                        attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        text_encoder_output = self.text_encoder(token_ids,
                                                attention_mask,
                                                output_hidden_states=True,
                                                return_dict=True)
        # use the last hidden state for Stable Cascade
        return text_encoder_output.hidden_states[-1]

    def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch.Tensor] = None,
                              device: Optional[str] = None) -> Optional[torch.Tensor]:
        device = device or self.device
        token_ids = self.get_token_ids(texts, padding="max_length", truncation_override=True)
        token_ids = torch.tensor(token_ids, dtype=torch.long).to(device)
        text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)
        pooled = text_encoder_output.text_embeds
        # add the extra dimension here so callers can pass the result straight through
        return pooled.unsqueeze(1)

# monkey patch in the correct behavior for this example
def patch_compel(compel_obj: compel.Compel):
    compel_obj.conditioning_provider.__class__ = SCascadeEmbeddingsProvider
# Do generation
device = 'cuda'
prompt = "an image of a shiba inu with (blue eyes)1.4, donning a green+ spacesuit"
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", variant='bf16',
    torch_dtype=torch.bfloat16).to(device)
prior_compel = compel.Compel(
    tokenizer=prior.tokenizer,
    text_encoder=prior.text_encoder,
    requires_pooled=True, device=device)
# patch prior
patch_compel(prior_compel)
conditioning, pooled = prior_compel(prompt)
prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled)
prior.to('cpu')
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant='bf16',
    torch_dtype=torch.float16).to(device)
decoder_compel = compel.Compel(
    tokenizer=decoder.tokenizer,
    text_encoder=decoder.text_encoder,
    requires_pooled=True, device=device)
# patch decoder
patch_compel(decoder_compel)
conditioning, pooled = decoder_compel(prompt)
image = decoder(
    num_inference_steps=10,
    guidance_scale=0.0,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled,
    image_embeddings=prior_output.image_embeddings.half()).images[0]
image.save('test.png')
decoder.to('cpu')
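The `patch_compel` helper above works by reassigning `__class__` on an already-constructed object, so the instance keeps its state while method lookup switches to the subclass. Here is a minimal, dependency-free sketch of that technique. `BaseProvider`, `CascadeProvider`, and `patch_provider` are hypothetical stand-ins for illustration only and are not part of compel.

```python
class BaseProvider:
    """Hypothetical stand-in for a provider with default behavior."""

    def __init__(self, name: str):
        self.name = name  # instance state survives the class swap

    def encode(self, text: str) -> str:
        return f"{self.name}:penultimate:{text}"


class CascadeProvider(BaseProvider):
    """Hypothetical subclass that overrides a single method."""

    def encode(self, text: str) -> str:
        return f"{self.name}:last:{text}"


def patch_provider(obj: BaseProvider) -> None:
    # Reassigning __class__ changes method resolution for this one instance.
    # This is safe here because the subclass adds no new attributes,
    # so the existing instance already has everything the subclass needs.
    obj.__class__ = CascadeProvider


provider = BaseProvider("clip")
before = provider.encode("shiba inu")
patch_provider(provider)
after = provider.encode("shiba inu")
print(before)
print(after)
```

In library code, constructing the subclass directly would probably be cleaner; the class swap is just a quick way to try the behavior without modifying the library itself.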
Stable Cascade support, new ReturnedEmbeddingsType #104
It seems like it might be possible for this to work with Stable Cascade?
I am wondering if there is a working snippet for prior + decoder, or if it is incompatible at the moment.