Skip to content

Commit 7e0c618

Browse files
committed
SDPose: resize input always
1 parent e6be419 commit 7e0c618

3 files changed

Lines changed: 42 additions & 36 deletions

File tree

comfy/ldm/modules/sdpose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def forward(self, x): # Decode heatmaps to keypoints
9090
origin_max = np.max(hm[k])
9191
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
9292
dr[border:-border, border:-border] = hm[k].copy()
93-
dr = gaussian_filter(dr, sigma=2.0)
93+
dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
9494
hm[k] = dr[border:-border, border:-border].copy()
9595
cur_max = np.max(hm[k])
9696
if cur_max > 0:

comfy_extras/nodes_rtdetr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ def define_schema(cls):
3232
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
3333
B, H, W, C = image.shape
3434

35-
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
36-
3735
comfy.model_management.load_model_gpu(model)
38-
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
36+
results = []
37+
for i in range(0, B, 32):
38+
batch = image[i:i + 32]
39+
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
40+
results.extend(model.model.diffusion_model(image_in, (W, H)))
3941

4042
all_bbox_dicts = []
4143

comfy_extras/nodes_sdpose.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import comfy.utils
3+
import comfy.model_management
34
import numpy as np
45
import math
56
import colorsys
@@ -410,7 +411,9 @@ def _zeros(n):
410411
pose_outputs.append(canvas)
411412

412413
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
413-
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
414+
final_pose_output = torch.from_numpy(pose_outputs_np).to(
415+
device=comfy.model_management.intermediate_device(),
416+
dtype=comfy.model_management.intermediate_dtype()) / 255.0
414417
return io.NodeOutput(final_pose_output)
415418

416419
class SDPoseKeypointExtractor(io.ComfyNode):
@@ -459,6 +462,27 @@ def output_patch(h, hsp, transformer_options):
459462
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
460463
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
461464

465+
def _resize_to_model(imgs):
466+
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
467+
h, w = imgs.shape[-3], imgs.shape[-2]
468+
scale = min(model_h / h, model_w / w)
469+
sh, sw = int(round(h * scale)), int(round(w * scale))
470+
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
471+
chw = imgs.permute(0, 3, 1, 2).float()
472+
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
473+
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
474+
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
475+
return padded.permute(0, 2, 3, 1), scale, pt, pl
476+
477+
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
478+
"""Remap keypoints from model space back to original image space."""
479+
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
480+
invalid = kp[..., 0] < 0
481+
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
482+
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
483+
kp[invalid] = -1
484+
return kp
485+
462486
def _run_on_latent(latent_batch):
463487
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
464488
nonlocal captured_feat
@@ -504,56 +528,36 @@ def _run_on_latent(latent_batch):
504528
if x2 <= x1 or y2 <= y1:
505529
continue
506530

507-
crop_h_px, crop_w_px = y2 - y1, x2 - x1
508531
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
509-
510-
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
511-
scale = min(model_h / crop_h_px, model_w / crop_w_px)
512-
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
513-
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
514-
515-
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
516-
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
517-
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
518-
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
519-
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
532+
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
520533

521534
latent_crop = vae.encode(crop_resized)
522535
kp_batch, sc_batch = _run_on_latent(latent_crop)
523-
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
524-
525-
# remove padding offset, undo scale, offset to full-image coordinates.
526-
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
527-
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
528-
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
529-
536+
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
530537
img_keypoints.append(kp)
531-
img_scores.append(sc)
538+
img_scores.append(sc_batch[0])
532539
else:
533-
# No bboxes for this image – run on the full image
534-
latent_img = vae.encode(img)
540+
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
541+
latent_img = vae.encode(img_resized)
535542
kp_batch, sc_batch = _run_on_latent(latent_img)
536-
img_keypoints.append(kp_batch[0])
543+
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
537544
img_scores.append(sc_batch[0])
538545

539546
all_keypoints.append(img_keypoints)
540547
all_scores.append(img_scores)
541548
pbar.update(1)
542549

543550
else: # full-image mode, batched
544-
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
545-
for batch_start in range(0, total_images, batch_size):
546-
batch_end = min(batch_start + batch_size, total_images)
547-
latent_batch = vae.encode(image[batch_start:batch_end])
548-
551+
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
552+
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
553+
latent_batch = vae.encode(batch_resized)
549554
kp_batch, sc_batch = _run_on_latent(latent_batch)
550555

551556
for kp, sc in zip(kp_batch, sc_batch):
552-
all_keypoints.append([kp])
557+
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
553558
all_scores.append([sc])
554-
tqdm_pbar.update(1)
555559

556-
pbar.update(batch_end - batch_start)
560+
pbar.update(len(kp_batch))
557561

558562
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
559563
return io.NodeOutput(openpose_frames)

0 commit comments

Comments
 (0)