LTX2 distilled checkpoint support #12934
Conversation
Running results: output.mp4
Converted ckpt: ltx2_sample.mp4
sayakpaul
left a comment
Thanks for shipping this so quickly! Left some comments, LMK if they make sense.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the PR! Left some comments about the distilled sigmas schedule.
If I print out the timesteps for the Stage 1 distilled pipeline, I get (for commit faeccc5):
Distilled timesteps: tensor([1000.0000, 999.6502, 999.2961, 998.9380, 998.5754, 994.4882,
979.3755, 929.6974, 100.0000], device='cuda:0')
Here the sigmas (and thus the timesteps) are shifted toward a terminal value of 0.1, and use_dynamic_shifting is applied as well. However, I believe the distilled sigmas are used as-is in the original LTX 2.0 code:
So I think when creating the distilled scheduler we need to disable use_dynamic_shifting and shift_terminal so that the distilled sigmas are used without changes.
Can you check whether the final distilled sigmas match up with those of the original implementation?
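For reference, here is a minimal sketch of that scheduler change, mirroring what the scripts later in this thread do (it assumes pipe is an already-loaded LTX2Pipeline):
# Recreate the scheduler from the pipeline's config with dynamic shifting and the
# terminal shift disabled, so the distilled sigmas are passed through unchanged.
from diffusers import FlowMatchEulerDiscreteScheduler

new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler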
The original test script didn't work for me, but I was able to get a working version as follows:
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
device = "cuda:0"
width = 768
height = 512
pipe = LTX2Pipeline.from_pretrained(
"rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)
prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
frame_rate = 24.0
video_latent, audio_latent = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=8,
sigmas=DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
"rootonchair/LTX-2-19b-distilled",
subfolder="upsample_pipeline/latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
negative_prompt=negative_prompt,
width=width * 2,
height=height * 2,
num_inference_steps=3,
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_distilled_sample.mp4",
)
The necessary changes were to create
Not jeopardizing this PR at all, but while we're at the two-stage pipeline stuff, it could also be cool to verify it with the distilled LoRA that we have in place (PR already merged: #12933). So, what we would do is:
Once we're close to merging the PR, we could document all of these to inform the community.
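As a rough sketch, loading and activating that distilled LoRA could look like the following (the adapter name and weight filename here are taken from the scripts further down this thread, so treat them as assumptions until verified):
# Load the Stage 2 distilled LoRA from the Lightricks/LTX-2 repo and make it the active adapter.
pipe.load_lora_weights(
    "Lightricks/LTX-2",
    adapter_name="stage_2_distilled",
    weight_name="ltx-2-19b-distilled-lora-384.safetensors",
)
pipe.set_adapters("stage_2_distilled", 1.0)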
@dg845 thank you for your detailed reviews. Let me take a closer look at that.
That sounds interesting. Let's have a quick test of the two-stage distilled LoRA too.
For two-stage inference with the Stage 2 distilled LoRA, I think this script should work:
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
device = "cuda:0"
width = 768
height = 512
pipe = LTX2Pipeline.from_pretrained(
"rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
# This scheduler should use distilled sigmas without any changes
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
pipe.enable_model_cpu_offload(device=device)
prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
frame_rate = 24.0
video_latent, audio_latent = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=8,
sigmas=DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
"rootonchair/LTX-2-19b-distilled",
subfolder="upsample_pipeline/latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
"Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
negative_prompt=negative_prompt,
width=width * 2,
height=height * 2,
num_inference_steps=3,
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_distilled_sample.mp4",
)
Sample with LoRA and scheduler fix: ltx2_distilled_sample_lora_fix.mp4
I think we should run with the original LTX2 weight and not the distilled checkpoint. WDYT? |
Yes, the first stage, in this case, should use the non-distilled ckpt. |
Fixed script (I think):
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
device = "cuda:0"
width = 768
height = 512
pipe = LTX2Pipeline.from_pretrained(
"Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)
prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=40,
sigmas=None,
guidance_scale=4.0,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
"Lightricks/LTX-2",
subfolder="latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
"Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
# Change scheduler to use Stage 2 distilled sigmas as is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with distilled LoRA and sigmas
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
negative_prompt=negative_prompt,
width=width * 2,
height=height * 2,
num_inference_steps=3,
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_distilled_sample.mp4",
)
If I test the distilled pipeline with the prompt
ltx2_distilled_sample_dog_edm.mp4
I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs in the video processing.)
@dg845 should the second stage inference with LoRA be run with 4?
width=width * 2,
height=height * 2,
I believe it should be run with 3, as
It is currently necessary as otherwise we'd get a shape error, but we could modify the code to infer the
diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py Lines 922 to 925 in 3c70440
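A hypothetical sketch of that inference (attribute names such as vae_spatial_compression_ratio are illustrative assumptions, not quoted from pipeline_ltx2.py):
# Hypothetical: derive height/width from the provided latents so the caller doesn't have to
# pass width * 2 / height * 2 explicitly. Names below are illustrative only.
if latents is not None:
    latent_height, latent_width = latents.shape[-2], latents.shape[-1]
    height = height or latent_height * self.vae_spatial_compression_ratio
    width = width or latent_width * self.vae_spatial_compression_ratio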
@tin2tin you can use it with any LTX2 weights; it's not limited to the distilled weights.
I have updated the document and the test. I'm not sure if the test is enough or whether we need to add more for all the new params. If all good, let's check the two-stage LoRA generation result (still downloading the original repo) before merging.
sayakpaul
left a comment
Thanks a lot for working on this! Only a few nits to go before merging.
pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)
Suggested change:
- pipe.enable_sequential_cpu_offload(device=device)
+ pipe.enable_model_cpu_offload(device=device)
Sorry to interject, but I've tested this quite a lot (RTX 4090), and using enable_sequential_cpu_offload is actually a very good default considering how big this model is, since it makes it possible to run locally.
pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)
Suggested change:
- pipe.enable_sequential_cpu_offload(device=device)
+ pipe.enable_model_cpu_offload(device=device)
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"
@dg845 let's get these transferred to the Lightricks org. Would you be able to check internally?
Removed one-stage generation example code and added comments for noise scale in two-stage generation.
Merging as the CI failures are unrelated to the PR.
@rootonchair thank you so much for your contributions! You are now officially an MVP. Please use this link to create your certificate. Please also let us know your HF profile ID so that we can provide some credits to it. Looking forward to collaborating more with you!
Thank you @sayakpaul. I really appreciate having the chance to become a Diffusers MVP. I'm eager to contribute more, so I hope we will have more collaborations soon 😄
@rootonchair you should have access to HF Pro membership for 6 months and some credits within your account :)
@sayakpaul awesome ❤️ |
What does this PR do?
Fixes #12925
Test script t2i
Who can review?
@sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.