[SDXL] Add SDXL pipeline to SHARK #1941
Conversation
Force-pushed adf077d to c94081c
Cool, can't wait to try this. Let me run it quick and review.
Tried this out and had a few notes. PTAL
Also, we will eventually need this on top of @aviator19941's SD implementation in Turbine. The implementation may be slightly cleaner that way. Either way, I've suggested a few changes that can make this a little more user-friendly.
Try it out yourself with a custom model, switch between XL models and SD1.5 models, stress test it a bit... and put cushions on the rough edges and this should be ready to go.
Happy to chat and set up meetings about my comments and/or the impending Turbine integration
apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py (Outdated, resolved)
Force-pushed c94081c to d5823a6
Hi @monorimet - As discussed yesterday, I have addressed the following in the latest push to this revision. It is now ready to be re-reviewed. For CLI:
Awesome. Thank you @Abhishek-Varma. I'll give it a try later today.
Cool, seems to work. What needs to be done to support custom SDXL models/LoRAs? We can merge once lint is fixed.
Force-pushed d548683 to ee2e213
Hi @monorimet - I've addressed your review comments. Regarding custom SDXL models/LoRAs - someone implementing that can follow the same approach used for non-SDXL models, with the assumption that the HF repo contains all the sub-models and that the base model being used is stabilityai/stable-diffusion-xl-base-1.0 (since that's what we use in
Force-pushed ee2e213 to 560350a
```python
# We need to scale down the height/width by vae_scale_factor, which
# happens to be 8 in this case.
self.height = height // 8
self.width = width // 8
```
This seems unintentional. The original SD needs `height // 8` and `width // 8`, and this changes the default while assigning the old behavior to SDXL.
If I try SD1.4 from scratch on this commit, I get a 4TB CPU allocation error. If I add `not` to `if is_sdxl` on line 170, the problem seems fixed for SD1.5/2.1, but now the SDXL VAE issue starts to make more sense...
I'll update after a bit more debugging.
Oh, it looks like the `vae_scale_factor` handling is baked in already for SD1.X / 2.1, so no need to change it for SDXL... unless I'm missing something.
Perhaps the following would help clarify why `// 8` (`vae_scale_factor`) needs to be done for SDXL:
- The user only specifies the final image size to be generated, be it SD1.X/2.1/XL.
- In the case of SDXL we currently only deal with 1024x1024 image generation.
- But the inputs to UNET/VAE for a 1024x1024 image need to be 128x128, which is nothing but `1024 // 8`.
- This is something I observed in normal Python execution by printing the input sizes expected by individual models right before their "forward" method is invoked.
- The inputs to each model are currently worked out via `base_model.json`, and for SD1.X/2.1 it substitutes the final height/width, but for SDXL it needs to be the scaled-down version (as seen in point 3 above).
- So, I added that and guarded it with the `if is_sdxl` flag.
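The latent-size computation described above can be sketched as follows. This is a minimal illustration, not the actual SHARK code; the helper name `latent_dims` is hypothetical, and `vae_scale_factor = 8` is the value the comments above assume for the Stable Diffusion family.

```python
# Minimal sketch of the scale-down discussed above (hypothetical helper,
# not from the SHARK codebase).
def latent_dims(height: int, width: int, vae_scale_factor: int = 8) -> tuple[int, int]:
    # The UNet/VAE operate in latent space, which is smaller than the
    # output image by vae_scale_factor along each spatial dimension.
    return height // vae_scale_factor, width // vae_scale_factor

# For SDXL's 1024x1024 output, the UNet/VAE inputs are 128x128:
print(latent_dims(1024, 1024))  # (128, 128)
```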
Before this commit, the SD1.X / 2.1 models already do `height // 8` and `width // 8` here. This commit changes the behavior for those models at this point to `height` and `width` (introducing the CPUAllocator issue) and then defines a special case for XL with the old default behavior. Unless we are reworking the pipeline for SD1.X/2.1, let's keep the non-SDXL path as it was beforehand, to reduce breakages.
-- This commit adds SDXL pipeline to SHARK. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Force-pushed 560350a to 7a6d096
* [SDXL] Add SDXL pipeline to SHARK -- This commit adds SDXL pipeline to SHARK. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
* (SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.
---------
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>