# wan2_generate_video.py
import argparse
from datetime import datetime
import gc
import random
import os
import re
import time
import math
import logging
from typing import Tuple, Optional, List, Union, Any
from pathlib import Path # Added for glob_images in V2V
# Set PyTorch CUDA allocator to reduce memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import torch
import accelerate
from accelerate import Accelerator
from functools import partial
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from PIL import Image
import cv2 # Added for V2V video loading/resizing
import numpy as np # Added for V2V video processing
import torchvision.transforms.functional as TF
import torchvision
from tqdm import tqdm
from networks import lora_wan
from utils.safetensors_utils import mem_eff_save_file, load_safetensors
from utils.lora_utils import filter_lora_state_dict
from Wan2_2.wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
import wan
from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
from wan.modules.vae import WanVAE
from Wan2_2.wan.modules.vae2_2 import Wan2_2_VAE
from wan.modules.t5 import T5EncoderModel
from wan.modules.clip import CLIPModel
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.fm_solvers_euler import EulerScheduler
from wan.utils.step_distill_scheduler import StepDistillScheduler
from blissful_tuner.latent_preview import LatentPreviewer
try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # LyCORIS is optional; it is only needed when --lycoris is specified
    pass
from utils.model_utils import str_to_dtype
from utils.device_utils import clean_memory_on_device
# Context Windows imports (optional dependency)
try:
    from Wan2_2.context_windows import (
        WanContextWindowsHandler,
        IndexListContextHandler,
        ContextSchedules,
        ContextFuseMethods,
    )
    CONTEXT_WINDOWS_AVAILABLE = True
except ImportError:
    CONTEXT_WINDOWS_AVAILABLE = False
    logging.getLogger(__name__).info(
        "Context windows module not available. Install required dependencies to enable context windows."
    )
# Local implementations to avoid xformers/flash-attention dependency
# (cv2 and PIL.Image are already imported above)
import av
import glob
from einops import rearrange
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == "*":
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths)) # remove duplicates
img_paths.sort()
return img_paths
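# Illustrative example (comment only; "frames_dir" is a placeholder): glob_images returns a sorted,
# de-duplicated list of image paths matching IMAGE_EXTENSIONS, e.g.
#   glob_images("frames_dir")  ->  ["frames_dir/0001.png", "frames_dir/0002.png", ...]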
def resize_image_to_bucket(image, bucket_reso):
is_pil_image = isinstance(image, Image.Image)
if is_pil_image:
image_width, image_height = image.size
else:
image_height, image_width = image.shape[:2]
if bucket_reso == (image_width, image_height):
return np.array(image) if is_pil_image else image
bucket_width, bucket_height = bucket_reso
if bucket_width == image_width or bucket_height == image_height:
image = np.array(image) if is_pil_image else image
else:
# resize the image to the bucket resolution to match the short side
scale_width = bucket_width / image_width
scale_height = bucket_height / image_height
scale = max(scale_width, scale_height)
image_width = int(image_width * scale + 0.5)
image_height = int(image_height * scale + 0.5)
if scale > 1:
image = Image.fromarray(image) if not is_pil_image else image
image = image.resize((image_width, image_height), Image.LANCZOS)
image = np.array(image)
else:
image = np.array(image) if is_pil_image else image
image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
# crop the image to the bucket resolution
crop_left = (image_width - bucket_width) // 2
crop_top = (image_height - bucket_height) // 2
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
return image
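# Illustrative example of resize_image_to_bucket (comment only; the sizes are hypothetical): the image
# is scaled so it covers the bucket resolution, then center-cropped. A 1920x1080 frame resized into an
# (832, 480) bucket is first scaled to 853x480, then cropped to 832x480:
#
#   frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
#   resized = resize_image_to_bucket(frame, (832, 480))
#   assert resized.shape == (480, 832, 3)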
def hv_load_images(image_dir, video_length, bucket_reso):
image_files = glob_images(image_dir)
if len(image_files) == 0:
raise ValueError(f"No image files found in {image_dir}")
if len(image_files) < video_length:
raise ValueError(f"Number of images in {image_dir} is less than {video_length}")
image_files.sort()
images = []
for image_file in image_files[:video_length]:
image = Image.open(image_file)
image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array
images.append(image)
return images
def hv_load_video(video_path, start_frame, end_frame, bucket_reso):
container = av.open(video_path)
video = []
for i, frame in enumerate(container.decode(video=0)):
if start_frame is not None and i < start_frame:
continue
if end_frame is not None and i >= end_frame:
break
frame = frame.to_image()
if bucket_reso is not None:
frame = resize_image_to_bucket(frame, bucket_reso)
else:
frame = np.array(frame)
video.append(frame)
container.close()
return video
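# Illustrative usage of hv_load_video (comment only; "input.mp4" is a placeholder path): load the
# first 81 frames of a clip, resized and center-cropped to 832x480. Each returned element is an
# HxWx3 uint8 numpy array in RGB order:
#
#   frames = hv_load_video("input.mp4", 0, 81, bucket_reso=(832, 480))
#   assert all(f.shape == (480, 832, 3) for f in frames)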
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
from einops import rearrange # Local import to avoid scope issues
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = torch.clamp(x, 0, 1)
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
height, width, _ = outputs[0].shape
# create output container
container = av.open(path, mode="w")
# create video stream
codec = "libx264"
pixel_format = "yuv420p"
stream = container.add_stream(codec, rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = pixel_format
stream.bit_rate = 4000000 # 4Mbit/s
for frame_array in outputs:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
packets = stream.encode(frame)
for packet in packets:
container.mux(packet)
for packet in stream.encode():
container.mux(packet)
container.close()
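# Illustrative usage of save_videos_grid (comment only): write a video tensor of shape
# (batch, channels, frames, height, width) in [-1, 1] as an H.264 MP4 at 16 fps. Width and height
# should be even to satisfy the yuv420p pixel format:
#
#   video = torch.rand(1, 3, 81, 480, 832) * 2.0 - 1.0   # dummy tensor in [-1, 1]
#   save_videos_grid(video, "output/sample.mp4", rescale=True, fps=16)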
def save_images_grid(videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, save_individually=True):
from einops import rearrange # Local import to avoid scope issues
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = torch.clamp(x, 0, 1)
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
if save_individually:
output_dir = os.path.join(parent_dir, image_name)
else:
output_dir = parent_dir
os.makedirs(output_dir, exist_ok=True)
for i, x in enumerate(outputs):
image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
image = Image.fromarray(x)
image.save(image_path)
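# Illustrative usage of save_images_grid (comment only): dump every frame of a [-1, 1] video tensor
# as PNGs under <parent_dir>/<image_name>/, named sample_000.png, sample_001.png, ...:
#
#   save_images_grid(video, "output/frames", "sample", rescale=True)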
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def parse_args() -> argparse.Namespace:
"""parse command line arguments"""
parser = argparse.ArgumentParser(description="Wan 2.2 inference script with new model architecture support")
# WAN arguments
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
parser.add_argument("--task", type=str, default="t2v-A14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
parser.add_argument(
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla", "euler", "step_distill"], help="The solver used to sample."
)
parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
parser.add_argument("--dit_low_noise", type=str, default=None, help="DiT low noise checkpoint path (for dual-dit models)")
parser.add_argument("--dit_high_noise", type=str, default=None, help="DiT high noise checkpoint path (for dual-dit models)")
parser.add_argument("--dual_dit_boundary", type=float, default=None, help="Override boundary for dual-dit models (0.0-1.0). Low noise model used above after threshold. Default: 0.875 for t2v-A14B, 0.900 for i2v-A14B")
parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
# LoRA
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
# LoRA for high noise model (dual-dit models only)
parser.add_argument("--lora_weight_high", type=str, nargs="*", required=False, default=None,
help="LoRA weight path for high noise model (dual-dit models only)")
parser.add_argument("--lora_multiplier_high", type=float, nargs="*", default=1.0,
help="LoRA multiplier for high noise model")
parser.add_argument("--include_patterns_high", type=str, nargs="*", default=None, help="LoRA module include patterns for high noise model")
parser.add_argument("--exclude_patterns_high", type=str, nargs="*", default=None, help="LoRA module exclude patterns for high noise model")
parser.add_argument(
"--save_merged_model",
type=str,
default=None,
help="Save merged model to path. If specified, no inference will be performed.",
)
# inference
parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
parser.add_argument(
"--negative_prompt",
type=str,
default=None,
help="negative prompt for generation, use default negative prompt if not specified",
)
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
parser.add_argument(
    "--cpu_noise", action="store_true", help="Use CPU to generate noise (for reproducible noise independent of the GPU). Default is False."
)
parser.add_argument(
"--guidance_scale",
type=float,
default=5.0,
help="Guidance scale for classifier free guidance. Default is 5.0.",
)
# V2V arguments
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference (standard Wan V2V)")
parser.add_argument("--strength", type=float, default=0.75, help="Strength for video2video inference (0.0-1.0)")
parser.add_argument("--v2v_low_noise_only", action="store_true", help="For V2V with dual-dit models, use only the low noise model")
parser.add_argument(
"--v2v_use_i2v", action="store_true",
help="Use i2v model for V2V (extracts first frame for CLIP conditioning). Recommended for i2v-A14B."
)
# I2V arguments
parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
# Fun-Control arguments (NEW/MODIFIED)
parser.add_argument(
"--control_path", # Keep this argument name
type=str,
default=None,
help="path to control video for inference with Fun-Control model. video file or directory with images",
)
parser.add_argument(
"--control_start",
type=float,
default=0.0,
help="Start point (0.0-1.0) in the timeline where control influence is full (after fade-in)",
)
parser.add_argument(
"--control_end",
type=float,
default=1.0,
help="End point (0.0-1.0) in the timeline where control influence starts to fade out",
)
parser.add_argument(
"--control_falloff_percentage", # NEW name
type=float,
default=0.3,
help="Falloff percentage (0.0-0.49) for smooth transitions at start/end of control influence region",
)
parser.add_argument(
"--control_weight", # NEW name
type=float,
default=1.0,
help="Overall weight/strength of control video influence for Fun-Control (0.0 to high values)",
)
parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
parser.add_argument(
"--cfg_skip_mode",
type=str,
default="none",
choices=["early", "late", "middle", "early_late", "alternate", "none"],
help="CFG skip mode. each mode skips different parts of the CFG. "
" early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
)
parser.add_argument(
"--cfg_apply_ratio",
type=float,
default=None,
help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
)
parser.add_argument(
"--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
)
parser.add_argument(
"--slg_scale",
type=float,
default=3.0,
help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
)
parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
parser.add_argument(
"--slg_mode",
type=str,
default=None,
choices=["original", "uncond"],
help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
)
# Flow Matching
parser.add_argument(
"--flow_shift",
type=float,
default=None,
help="Shift factor for flow matching schedulers. Default depends on task.",
)
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
parser.add_argument("--mixed_dtype", action="store_true", help="use model with mixed weight dtypes (preserves original dtypes, e.g. mixed fp16/fp32)")
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
parser.add_argument(
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
)
parser.add_argument(
"--attn_mode",
type=str,
default="torch",
choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
help="attention mode",
)
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
parser.add_argument(
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
parser.add_argument(
"--compile_args",
nargs=4,
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
help="Torch.compile settings",
)
parser.add_argument("--preview", type=int, default=None, metavar="N",
help="Enable latent preview every N steps. Generates previews in 'previews' subdirectory.",
)
parser.add_argument("--preview_suffix", type=str, default=None,
help="Unique suffix for preview files to avoid conflicts in concurrent runs.",
)
# Video extension arguments (multitalk-style)
parser.add_argument("--extend_video", type=str, default=None, help="Path to video to extend using multitalk-style iterative generation")
parser.add_argument("--extend_frames", type=int, default=200, help="Total number of frames to generate when extending video")
parser.add_argument("--frames_to_check", type=int, default=30, help="Number of frames from the end to analyze for best transition point (clean i2v-based extension)")
parser.add_argument("--motion_frames", type=int, default=25, help="Number of frames to use for motion conditioning in each chunk")
# Model selection for extension
parser.add_argument("--force_low_noise", action="store_true", help="Force use of low noise model for video extension")
parser.add_argument("--force_high_noise", action="store_true", help="Force use of high noise model for video extension")
parser.add_argument("--extension_dual_dit_boundary", type=float, default=None, help="Custom dual-dit boundary for video extension (0.0-1.0). Overrides force_low_noise/force_high_noise")
# Latent injection timing controls
parser.add_argument("--inject_motion_timesteps", type=str, default="all", choices=["all", "high_only", "low_only", "none"],
help="When to inject motion frames: 'all'=every timestep, 'high_only'=high noise timesteps only, 'low_only'=low noise timesteps only, 'none'=no injection")
parser.add_argument("--injection_strength", type=float, default=1.0, help="Strength of motion frame injection (0.0-1.0, 1.0=full replacement)")
parser.add_argument("--motion_noise_ratio", type=float, default=0.3,
help="Noise ratio for motion frames in extension (0.0-1.0, lower=less noise/more preservation)")
parser.add_argument("--color_match", type=str, default="hm",
choices=["disabled", "hm", "mkl", "reinhard", "mvgd", "hm-mvgd-hm", "hm-mkl-hm"],
help="Color matching method for video extension (default: histogram matching)")
# Context Windows Arguments
parser.add_argument("--use_context_windows", action="store_true",
help="Enable sliding context windows for long video generation")
parser.add_argument("--context_length", type=int, default=81,
help="Length of context window in frames (default: 81)")
parser.add_argument("--context_overlap", type=int, default=30,
help="Overlap between context windows in frames (default: 30)")
parser.add_argument("--context_schedule", type=str, default="standard_static",
choices=["standard_static", "standard_uniform", "looped_uniform", "batched"],
help="Context window scheduling method (default: standard_static)")
parser.add_argument("--context_stride", type=int, default=1,
help="Stride for uniform context schedules (default: 1)")
parser.add_argument("--context_closed_loop", action="store_true",
help="Enable closed loop for cyclic videos")
parser.add_argument("--context_fuse_method", type=str, default="pyramid",
choices=["pyramid", "flat", "overlap-linear", "relative"],
help="Method for fusing context window results (default: pyramid)")
parser.add_argument("--context_dim", type=int, default=2,
help="Dimension to apply context windows (2=temporal for video, default: 2)")
args = parser.parse_args()
assert (args.latent_path is None or len(args.latent_path) == 0) or (
args.output_type == "images" or args.output_type == "video"
), "latent_path is only supported for images or video output"
# Add checks for mutually exclusive arguments
if args.video_path is not None and args.image_path is not None and not args.v2v_use_i2v:
raise ValueError("--video_path and --image_path cannot be used together unless --v2v_use_i2v is specified.")
if args.v2v_use_i2v and args.video_path is None:
raise ValueError("--v2v_use_i2v requires --video_path to be specified.")
if args.v2v_use_i2v and "i2v" not in args.task:
logger.warning("--v2v_use_i2v is recommended for i2v models. Current task: %s", args.task)
if args.video_path is not None and args.control_path is not None:
raise ValueError("--video_path (standard V2V) and --control_path (Fun-Control) cannot be used together.")
if args.image_path is not None and "t2v" in args.task:
logger.warning("--image_path is provided, but task is set to t2v. Task type does not directly affect I2V mode.")
if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control:
raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.")
if not (0.0 <= args.control_falloff_percentage <= 0.49):
raise ValueError("--control_falloff_percentage must be between 0.0 and 0.49")
if args.mixed_dtype and args.fp8:
raise ValueError("--mixed_dtype and --fp8 cannot be used together")
if args.mixed_dtype and args.fp8_scaled:
raise ValueError("--mixed_dtype and --fp8_scaled cannot be used together")
if args.mixed_dtype and args.lora_weight:
logger.warning("--mixed_dtype with LoRA: LoRA weights will be merged at the model's original precision")
if args.task == "i2v-14B-FC-1.1" and args.image_path is None:
logger.warning(f"Task '{args.task}' typically uses --image_path as the reference image for ref_conv. Proceeding without it.")
return args
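# Example invocation of this script (illustrative; checkpoint paths and prompt are placeholders):
#   python wan2_generate_video.py --task t2v-A14B --prompt "a cat surfing a wave" \
#       --video_size 480 832 --video_length 81 --infer_steps 20 --fps 16 \
#       --dit_high_noise wan2.2_high_noise.safetensors --dit_low_noise wan2.2_low_noise.safetensors \
#       --vae wan_vae.safetensors --t5 t5_encoder.safetensors \
#       --attn_mode sdpa --blocks_to_swap 20 --save_path output/cat.mp4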
class DynamicModelManager:
"""Manages dynamic loading and unloading of models during inference."""
def __init__(self, config, device, dit_dtype, dit_weight_dtype, args):
self.config = config
self.device = device
self.dit_dtype = dit_dtype
self.dit_weight_dtype = dit_weight_dtype
self.args = args
self.current_model = None
self.current_model_type = None # 'low' or 'high'
self.model_paths = {}
self.lora_weights_list_low = None
self.lora_multipliers_low = None
self.lora_weights_list_high = None
self.lora_multipliers_high = None
def has_model_loaded(self):
"""Check if any model is currently loaded."""
return self.current_model is not None
def set_model_paths(self, low_path: str, high_path: str):
"""Set the paths for low and high noise models."""
self.model_paths['low'] = low_path
self.model_paths['high'] = high_path
def set_lora_weights(self, lora_weights_list_low, lora_multipliers_low,
lora_weights_list_high, lora_multipliers_high):
"""Save LoRA weights to apply to dynamically loaded models."""
self.lora_weights_list_low = lora_weights_list_low
self.lora_multipliers_low = lora_multipliers_low
self.lora_weights_list_high = lora_weights_list_high
self.lora_multipliers_high = lora_multipliers_high
def get_model(self, model_type: str) -> WanModel:
"""Load the requested model if not already loaded."""
if self.current_model_type == model_type:
return self.current_model
# Unload current model if exists
if self.current_model is not None:
logger.info(f"Unloading {self.current_model_type} noise model (GPU + CPU blocks)...")
# Get memory usage before unloading (GPU + CPU monitoring)
if torch.cuda.is_available():
memory_before = torch.cuda.memory_allocated(self.device) / 1024**3
logger.info(f"GPU memory before unload: {memory_before:.2f} GB")
# Monitor CPU memory for debugging model deletion
import psutil
cpu_memory_before = psutil.Process().memory_info().rss / 1024**3
logger.info(f"CPU memory before model unload: {cpu_memory_before:.2f} GB")
# Handle block swapping cleanup if enabled
if hasattr(self.current_model, 'blocks_to_swap') and self.current_model.blocks_to_swap is not None:
if self.current_model.blocks_to_swap > 0:
logger.info(f"Cleaning up block swapping for {self.current_model_type} model...")
# First, ensure all blocks are moved back from swap
if hasattr(self.current_model, 'offloader') and self.current_model.offloader is not None:
# Wait for any pending operations on all blocks
for idx in range(len(self.current_model.blocks)):
try:
self.current_model.offloader.wait_for_block(idx)
except Exception as e:
logger.warning(f"Error waiting for block {idx}: {e}")
# Move all blocks back to CPU to free GPU memory
for idx in range(len(self.current_model.blocks)):
try:
# Move block to CPU if it's not already there
self.current_model.blocks[idx] = self.current_model.blocks[idx].cpu()
except Exception as e:
logger.warning(f"Error moving block {idx} to CPU: {e}")
# Clean up the offloader properly - FIX BACKWARD HOOK CLOSURE LEAK
try:
# 1. Clear backward hook handles (prevents closure reference leak)
if hasattr(self.current_model.offloader, 'remove_handles'):
for handle in self.current_model.offloader.remove_handles:
handle.remove()
self.current_model.offloader.remove_handles.clear()
# 2. Shutdown ThreadPoolExecutor and clear futures to prevent memory leaks
if hasattr(self.current_model.offloader, 'thread_pool'):
self.current_model.offloader.thread_pool.shutdown(wait=True)
if hasattr(self.current_model.offloader, 'futures'):
self.current_model.offloader.futures.clear()
del self.current_model.offloader
self.current_model.offloader = None
except Exception as e:
logger.warning(f"Error cleaning up offloader: {e}")
# Enhanced block reference clearing before deletion
try:
# 1. Clear torch.compile cache if model was compiled
if self.args.compile:
logger.info("Clearing torch.compile cache for model deletion")
torch._dynamo.reset()
# 2. Clear individual block references (prevents compiled block retention)
if hasattr(self.current_model, 'blocks') and self.current_model.blocks is not None:
for i in range(len(self.current_model.blocks)):
self.current_model.blocks[i] = None
self.current_model.blocks.clear()
self.current_model.blocks = None
# 3. Move any remaining parameters to CPU
self.current_model = self.current_model.cpu()
except Exception as e:
logger.warning(f"Error during enhanced cleanup: {e}")
# 4. Force model deletion with verification
del self.current_model
self.current_model = None
self.current_model_type = None
# Aggressive cleanup for both GPU and CPU memory
torch.cuda.empty_cache() # Clear GPU cache
torch.cuda.synchronize() # Wait for all GPU operations to complete
gc.collect() # Force Python garbage collection (clears CPU memory)
torch.cuda.empty_cache() # Second GPU cache clear
clean_memory_on_device(self.device)
# Additional cleanup for fragmented memory
if torch.cuda.is_available():
torch.cuda.ipc_collect()
# Log memory usage after unloading (GPU + CPU verification)
if torch.cuda.is_available():
memory_after = torch.cuda.memory_allocated(self.device) / 1024**3
logger.info(f"GPU memory after unload: {memory_after:.2f} GB (freed: {memory_before - memory_after:.2f} GB)")
# Verify CPU memory cleanup
cpu_memory_after = psutil.Process().memory_info().rss / 1024**3
cpu_freed = cpu_memory_before - cpu_memory_after
logger.info(f"CPU memory after model unload: {cpu_memory_after:.2f} GB (freed: {cpu_freed:.2f} GB)")
# Load new model
logger.info(f"Loading {model_type} noise model...")
loading_device = "cpu"
if self.args.blocks_to_swap == 0 and self.lora_weights_list_low is None and not self.args.fp8_scaled:
loading_device = self.device
loading_weight_dtype = self.dit_weight_dtype
if self.args.mixed_dtype:
# For mixed dtype, load weights as-is without conversion
loading_weight_dtype = None
elif self.args.fp8_scaled or self.args.lora_weight is not None:
loading_weight_dtype = self.dit_dtype
# Select appropriate LoRA weights for this model type
lora_weights_list = None
lora_multipliers = None
if model_type == 'low':
lora_weights_list = self.lora_weights_list_low
lora_multipliers = self.lora_multipliers_low
else: # 'high'
lora_weights_list = self.lora_weights_list_high
lora_multipliers = self.lora_multipliers_high
# DEBUG: Print full LoRA list and weights being applied to this DiT model
if lora_weights_list is not None:
logger.info(f"DEBUG: Loading {model_type} noise DiT model with {len(lora_weights_list)} LoRA(s)")
for i, lora_sd in enumerate(lora_weights_list):
multiplier = lora_multipliers[i] if lora_multipliers and i < len(lora_multipliers) else 1.0
lora_keys = list(lora_sd.keys())[:5] # Show first 5 keys
logger.info(f"DEBUG: LoRA {i+1}/{len(lora_weights_list)} for {model_type} noise model - Multiplier: {multiplier}, Keys sample: {lora_keys}")
else:
logger.info(f"DEBUG: Loading {model_type} noise DiT model with NO LoRA weights")
# Load model with LoRA weights if available
model = load_wan_model(
self.config, self.device, self.model_paths[model_type],
self.args.attn_mode, False, loading_device, loading_weight_dtype, False,
lora_weights_list=lora_weights_list, lora_multipliers=lora_multipliers
)
# Optimize model
optimize_model(model, self.args, self.device, self.dit_dtype, self.dit_weight_dtype)
self.current_model = model
self.current_model_type = model_type
return model
def cleanup(self):
"""Clean up any loaded models."""
if self.current_model is not None:
logger.info(f"Final cleanup of {self.current_model_type} noise model...")
# Handle block swapping cleanup
if hasattr(self.current_model, 'blocks_to_swap') and self.current_model.blocks_to_swap is not None:
if self.current_model.blocks_to_swap > 0 and hasattr(self.current_model, 'offloader'):
if self.current_model.offloader is not None:
# Wait for all blocks using the correct method
for idx in range(len(self.current_model.blocks)):
try:
self.current_model.offloader.wait_for_block(idx)
except Exception as e:
logger.warning(f"Error waiting for block {idx}: {e}")
# Move all blocks to CPU
for idx in range(len(self.current_model.blocks)):
try:
self.current_model.blocks[idx] = self.current_model.blocks[idx].cpu()
except Exception as e:
logger.warning(f"Error moving block {idx} to CPU: {e}")
try:
# 1. Clear backward hook handles (prevents closure reference leak)
if hasattr(self.current_model.offloader, 'remove_handles'):
for handle in self.current_model.offloader.remove_handles:
handle.remove()
self.current_model.offloader.remove_handles.clear()
# 2. Shutdown ThreadPoolExecutor and clear futures to prevent memory leaks
if hasattr(self.current_model.offloader, 'thread_pool'):
self.current_model.offloader.thread_pool.shutdown(wait=True)
if hasattr(self.current_model.offloader, 'futures'):
self.current_model.offloader.futures.clear()
del self.current_model.offloader
self.current_model.offloader = None
except Exception as e:
logger.warning(f"Error cleaning up offloader: {e}")
# Enhanced block reference clearing before final deletion
try:
# 1. Clear torch.compile cache if model was compiled
if self.args.compile:
logger.info("Final cleanup: Clearing torch.compile cache")
torch._dynamo.reset()
# 2. Clear individual block references (prevents compiled block retention)
if hasattr(self.current_model, 'blocks') and self.current_model.blocks is not None:
for i in range(len(self.current_model.blocks)):
self.current_model.blocks[i] = None
self.current_model.blocks.clear()
self.current_model.blocks = None
# 3. Move model to CPU before deletion
self.current_model = self.current_model.cpu()
except Exception as e:
logger.warning(f"Error during final enhanced cleanup: {e}")
# 4. Force final model deletion
del self.current_model
self.current_model = None
self.current_model_type = None
# Aggressive cleanup for both GPU and CPU
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
clean_memory_on_device(self.device)
def unload_all(self):
"""Alias for cleanup method to ensure compatibility."""
self.cleanup()
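# Illustrative usage sketch for DynamicModelManager (comment only; checkpoint names are placeholders).
# The manager keeps at most one DiT expert resident and unloads the other before switching:
#
#   manager = DynamicModelManager(config, device, dit_dtype, dit_weight_dtype, args)
#   manager.set_model_paths(low_path="low_noise.safetensors", high_path="high_noise.safetensors")
#   manager.set_lora_weights(lora_low, mult_low, lora_high, mult_high)
#   model = manager.get_model("high")   # high-noise expert for the early (noisier) timesteps
#   ...run high-noise denoising steps...
#   model = manager.get_model("low")    # unloads the high-noise expert, then loads the low-noise one
#   manager.cleanup()                   # release the remaining model when sampling is done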
def create_funcontrol_conditioning_latent(
args: argparse.Namespace,
config,
vae: WanVAE,
device: torch.device,
lat_f: int,
lat_h: int,
lat_w: int,
pixel_height: int, # Actual pixel height for resizing
pixel_width: int # Actual pixel width for resizing
) -> Optional[torch.Tensor]:
"""
Creates the conditioning latent tensor 'y' for FunControl models,
replicating the logic from WanWeightedControlToVideo node.
Args:
args: Command line arguments.
config: Model configuration.
vae: Loaded VAE model instance.
device: Target computation device.
lat_f: Number of latent frames.
lat_h: Latent height.
lat_w: Latent width.
pixel_height: Target pixel height for image/video processing.
pixel_width: Target pixel width for image/video processing.
Returns:
torch.Tensor: The final conditioning latent 'y' [1, 32, lat_f, lat_h, lat_w],
or None if VAE is missing when required.
"""
logger.info("Creating FunControl conditioning latent 'y'...")
if vae is None:
# Should not happen if called correctly, but check anyway
logger.error("VAE is required to create FunControl conditioning latent but was not provided.")
return None
batch_size = 1 # Hardcoded for script execution
total_latent_frames = lat_f
vae_dtype = vae.dtype # Use VAE's dtype for encoding
# Initialize the two parts of the concat latent
# Control part (first 16 channels) - will be filled later
control_latent_part = torch.zeros([batch_size, 16, total_latent_frames, lat_h, lat_w],
device=device, dtype=vae_dtype).contiguous()
# Image guidance part (last 16 channels)
image_guidance_latent = torch.zeros([batch_size, 16, total_latent_frames, lat_h, lat_w],
device=device, dtype=vae_dtype).contiguous()
# --- Image Guidance Processing (Start/End Images) ---
timeline_mask = torch.zeros([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous()
has_start_image = args.image_path is not None
has_end_image = args.end_image_path is not None
# Process start image if provided
start_latent = None
if has_start_image:
logger.info(f"Processing start image: {args.image_path}")
try:
img = Image.open(args.image_path).convert("RGB")
img_np = np.array(img)
# Resize to target pixel dimensions
interpolation = cv2.INTER_AREA if pixel_height < img_np.shape[0] else cv2.INTER_CUBIC
img_resized_np = cv2.resize(img_np, (pixel_width, pixel_height), interpolation=interpolation)
# Convert to tensor CFHW, range [-1, 1]
img_tensor = TF.to_tensor(img_resized_np).sub_(0.5).div_(0.5).to(device)
img_tensor = img_tensor.unsqueeze(1) # Add frame dim: C,F,H,W
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype):
# vae.encode expects a list, returns a list. Take first element.
# Result shape [C', F', H', W'] - needs batch dim for processing here
start_latent = vae.encode([img_tensor])[0].unsqueeze(0).to(device).contiguous() # [1, 16, 1, lat_h, lat_w]
# Calculate influence and falloff
start_frames_influence = min(start_latent.shape[2], total_latent_frames) # Usually 1
if start_frames_influence > 0:
# Use falloff_percentage for smooth transition *away* from start image
falloff_len_frames = max(1, int(total_latent_frames * args.control_falloff_percentage))
start_influence_mask = torch.ones([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous()
# Apply falloff starting *after* the first frame
if total_latent_frames > 1 + falloff_len_frames:
# Falloff from frame 1 to 1+falloff_len_frames
t = torch.linspace(0, 1, falloff_len_frames, device=device)
falloff = 0.5 + 0.5 * torch.cos(t * math.pi) # 1 -> 0
start_influence_mask[0, 0, 1:1+falloff_len_frames] = falloff
# Set influence to 0 after falloff
start_influence_mask[0, 0, 1+falloff_len_frames:] = 0.0
elif total_latent_frames > 1:
# Shorter falloff if video is too short
t = torch.linspace(0, 1, total_latent_frames - 1, device=device)
falloff = 0.5 + 0.5 * torch.cos(t * math.pi) # 1 -> 0
start_influence_mask[0, 0, 1:] = falloff
# Place start latent in the image guidance part, weighted by mask
# Since start_latent is only frame 0, we just place it there.
# The mask influences how other elements (like end image) blend *in*.
image_guidance_latent[:, :, 0:1, :, :] = start_latent[:, :, 0:1, :, :] # Take first frame
# Update the main timeline mask
timeline_mask = torch.max(timeline_mask, start_influence_mask) # Start image dominates beginning
logger.info(f"Start image processed. Latent shape: {start_latent.shape}")
except Exception as e:
logger.error(f"Error processing start image: {e}")
# Continue without start image guidance
# Process end image if provided
end_latent = None
if has_end_image:
logger.info(f"Processing end image: {args.end_image_path}")
try:
img = Image.open(args.end_image_path).convert("RGB")
img_np = np.array(img)
# Resize to target pixel dimensions
interpolation = cv2.INTER_AREA if pixel_height < img_np.shape[0] else cv2.INTER_CUBIC
img_resized_np = cv2.resize(img_np, (pixel_width, pixel_height), interpolation=interpolation)
# Convert to tensor CFHW, range [-1, 1]
img_tensor = TF.to_tensor(img_resized_np).sub_(0.5).div_(0.5).to(device)
img_tensor = img_tensor.unsqueeze(1) # Add frame dim: C,F,H,W
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype):
# vae.encode expects a list, returns a list. Take first element.
# Result shape [C', F', H', W'] - needs batch dim for processing here
end_latent = vae.encode([img_tensor])[0].unsqueeze(0).to(device).contiguous() # [1, 16, 1, lat_h, lat_w]
# Determine if using looped schedule
using_looped = (hasattr(args, 'use_context_windows') and args.use_context_windows and
hasattr(args, 'context_schedule') and args.context_schedule == "looped_uniform")
# Calculate end image influence transition (S-curve / cubic)
end_influence_mask = torch.zeros([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous()
falloff_len_frames = max(1, int(total_latent_frames * args.control_falloff_percentage))
# Determine when the end image influence should start ramping up
# More sophisticated start point based on control_end if control video exists
if args.control_path and args.control_end < 1.0:
# Start fade-in just before control video fades out significantly
influence_start_frame = max(0, int(total_latent_frames * args.control_end) - falloff_len_frames // 2)
else:
# Default: start influence around 60% mark if no control or control runs full length
# For looped mode: start earlier for stronger end image conditioning
if using_looped:
influence_start_frame = max(0, int(total_latent_frames * 0.5))
logger.info("Looped mode: Starting end image influence earlier (50% mark)")
else:
influence_start_frame = max(0, int(total_latent_frames * 0.6))
# Ensure start frame isn't too close to the beginning if start image exists
if has_start_image:
influence_start_frame = max(influence_start_frame, 1 + falloff_len_frames) # Ensure it starts after start_img falloff
transition_length = total_latent_frames - influence_start_frame
if transition_length > 0:
logger.info(f"End image influence transition: frames {influence_start_frame} to {total_latent_frames-1}")
curve_positions = torch.linspace(0, 1, transition_length, device=device)
for i, pos in enumerate(curve_positions):
idx = influence_start_frame + i
if idx < total_latent_frames:
# Cubic ease-in-out curve (smoother than cosine)
if pos < 0.5: influence = 4 * pos * pos * pos
else: p = pos - 1; influence = 1 + 4 * p * p * p
# Ensure full influence near the end - stronger for looped mode
frames_from_end = total_latent_frames - 1 - idx
if using_looped and frames_from_end < 5:
influence = 1.0 # Force full influence in last 5 frames for looped
elif frames_from_end < 3:
influence = 1.0 # Force full influence in last 3 frames for non-looped
end_influence_mask[0, 0, idx] = influence
# Blending logic (similar to base_nodes)
blend_start_frame = influence_start_frame
blend_length = total_latent_frames - blend_start_frame
if blend_length > 0:
# Create reference end latent (just the single frame repeated conceptually)
# Blend existing content with end latent based on influence weight
for i in range(blend_length):
idx = blend_start_frame + i
if idx < total_latent_frames:
weight = end_influence_mask[0, 0, idx].item()
if weight > 0:
# Blend: (1-w)*current + w*end_latent
image_guidance_latent[:, :, idx] = (
(1.0 - weight) * image_guidance_latent[:, :, idx] +
weight * end_latent[:, :, 0] # Use the single frame end_latent
)
# Ensure final frames are exactly the end image latent
# Use more frames for looped mode to ensure clean loop
if using_looped:
last_frames_exact = min(5, total_latent_frames) # Last 5 frames for looped
logger.info("Looped mode: Forcing last 5 frames to exact end image")
else:
last_frames_exact = min(3, total_latent_frames) # Last 3 frames for non-looped
if last_frames_exact > 0:
end_offset = total_latent_frames - last_frames_exact
if end_offset >= 0:
image_guidance_latent[:, :, end_offset:] = end_latent[:, :, 0:1].repeat(1, 1, last_frames_exact, 1, 1)
# Update the main timeline mask
timeline_mask = torch.max(timeline_mask, end_influence_mask)
logger.info(f"End image processed. Latent shape: {end_latent.shape}")
except Exception as e:
logger.error(f"Error processing end image: {e}")
# Continue without end image guidance
# --- Control Video Processing ---
control_video_latent = None
if args.control_path:
logger.info(f"Processing control video: {args.control_path}")
try:
# Load control video frames (use helper from hv_generate_video for consistency)
# Use args.video_length for the number of frames
if os.path.isfile(args.control_path):
video_frames_np = hv_load_video(args.control_path, 0, args.video_length, bucket_reso=(pixel_width, pixel_height))
elif os.path.isdir(args.control_path):
video_frames_np = hv_load_images(args.control_path, args.video_length, bucket_reso=(pixel_width, pixel_height))
else:
raise FileNotFoundError(f"Control path not found: {args.control_path}")
if not video_frames_np:
raise ValueError("No frames loaded from control path.")
num_control_frames_loaded = len(video_frames_np)
if num_control_frames_loaded < args.video_length:
logger.warning(f"Control video loaded {num_control_frames_loaded} frames, less than target {args.video_length}. Padding with last frame.")
# Pad with the last frame
last_frame = video_frames_np[-1]
padding = [last_frame] * (args.video_length - num_control_frames_loaded)
video_frames_np.extend(padding)
# Stack and convert to tensor: F, H, W, C -> B, C, F, H, W, range [-1, 1]
video_frames_np = np.stack(video_frames_np[:args.video_length], axis=0) # Ensure correct length
control_tensor = torch.from_numpy(video_frames_np).permute(0, 3, 1, 2).float() / 127.5 - 1.0 # F,C,H,W
control_tensor = control_tensor.permute(1, 0, 2, 3) # C,F,H,W
control_tensor = control_tensor.unsqueeze(0).to(device) # B,C,F,H,W
# Encode control video
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype):
# vae.encode expects list of [C, F, H, W], returns list of [C', F', H', W']
control_video_latent = vae.encode([control_tensor[0]])[0].unsqueeze(0).to(device).contiguous() # [1, 16, lat_f, lat_h, lat_w]
# Calculate weighted control mask (replicating base_nodes logic)
control_frames_latent = control_video_latent.shape[2] # Should match total_latent_frames
control_mask = torch.zeros([1, 1, control_frames_latent], device=device, dtype=torch.float32).contiguous()
start_frame_idx = max(0, min(control_frames_latent - 1, int(control_frames_latent * args.control_start)))
end_frame_idx = max(start_frame_idx + 1, min(control_frames_latent, int(control_frames_latent * args.control_end)))
falloff_len_frames = max(2, int(control_frames_latent * args.control_falloff_percentage))
# Main active region