@@ -1,5 +1,6 @@
 import torch
 import comfy.utils
+import comfy.model_management
 import numpy as np
 import math
 import colorsys
@@ -410,7 +411,9 @@ def _zeros(n):
             pose_outputs.append(canvas)

         pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
-        final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
+        final_pose_output = torch.from_numpy(pose_outputs_np).to(
+            device=comfy.model_management.intermediate_device(),
+            dtype=comfy.model_management.intermediate_dtype()) / 255.0
         return io.NodeOutput(final_pose_output)

 class SDPoseKeypointExtractor(io.ComfyNode):
@@ -459,6 +462,27 @@ def output_patch(h, hsp, transformer_options):
         model_h = int(head.heatmap_size[0]) * 4  # e.g. 192 * 4 = 768
         model_w = int(head.heatmap_size[1]) * 4  # e.g. 256 * 4 = 1024

+        def _resize_to_model(imgs):
+            """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
+            h, w = imgs.shape[-3], imgs.shape[-2]
+            scale = min(model_h / h, model_w / w)
+            sh, sw = int(round(h * scale)), int(round(w * scale))
+            pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
+            chw = imgs.permute(0, 3, 1, 2).float()
+            scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
+            padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
+            padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
+            return padded.permute(0, 2, 3, 1), scale, pt, pl
+
+        def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
+            """Remap keypoints from model space back to original image space."""
+            kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
+            invalid = kp[..., 0] < 0
+            kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
+            kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
+            kp[invalid] = -1
+            return kp
+
         def _run_on_latent(latent_batch):
             """Run one forward pass and return (keypoints_list, scores_list) for the batch."""
             nonlocal captured_feat
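
The two helpers added above are inverses on valid keypoints: _resize_to_model maps image pixels into the letterboxed model canvas, and _remap_keypoints maps model-space coordinates back out. A minimal standalone sketch of the round trip, assuming the 768x1024 canvas from the "# e.g." comments above (letterbox_params is a hypothetical stand-in that repeats the arithmetic of _resize_to_model):

import numpy as np

model_h, model_w = 768, 1024  # assumed canvas size, per the "# e.g." comments

def letterbox_params(h, w):
    # Hypothetical stand-in repeating _resize_to_model's arithmetic:
    # fit inside the canvas, center with zero padding.
    scale = min(model_h / h, model_w / w)
    sh, sw = int(round(h * scale)), int(round(w * scale))
    return scale, (model_h - sh) // 2, (model_w - sw) // 2

scale, pad_top, pad_left = letterbox_params(500, 500)  # -> 1.536, 0, 128

kp_img = np.array([[100.0, 200.0], [-1.0, -1.0]])  # one valid, one invalid (-1) keypoint
kp_model = kp_img.copy()
valid = kp_model[:, 0] >= 0
kp_model[valid, 0] = kp_model[valid, 0] * scale + pad_left  # forward: image -> model space
kp_model[valid, 1] = kp_model[valid, 1] * scale + pad_top

# Inverse, exactly as _remap_keypoints does it: flag invalid points first,
# subtract the padding, undo the scale, then restore the -1 markers.
back = kp_model.copy()
invalid = back[:, 0] < 0
back[:, 0] = (back[:, 0] - pad_left) / scale
back[:, 1] = (back[:, 1] - pad_top) / scale
back[invalid] = -1

assert np.allclose(back[valid], kp_img[valid])
assert (back[~valid] == -1).all()
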
@@ -504,56 +528,36 @@ def _run_on_latent(latent_batch):
                         if x2 <= x1 or y2 <= y1:
                             continue

-                        crop_h_px, crop_w_px = y2 - y1, x2 - x1
                         crop = img[:, y1:y2, x1:x2, :]  # (1, crop_h, crop_w, C)
-
-                        # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
-                        scale = min(model_h / crop_h_px, model_w / crop_w_px)
-                        scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
-                        pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
-
-                        crop_chw = crop.permute(0, 3, 1, 2).float()  # BHWC → BCHW
-                        scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
-                        padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
-                        padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
-                        crop_resized = padded.permute(0, 2, 3, 1)  # BCHW → BHWC
+                        crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)

                         latent_crop = vae.encode(crop_resized)
                         kp_batch, sc_batch = _run_on_latent(latent_crop)
-                        kp, sc = kp_batch[0], sc_batch[0]  # (K, 2), coords in model pixel space
-
-                        # remove padding offset, undo scale, offset to full-image coordinates.
-                        kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
-                        kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
-                        kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
-
+                        kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
                         img_keypoints.append(kp)
-                        img_scores.append(sc)
+                        img_scores.append(sc_batch[0])
                 else:
-                    # No bboxes for this image – run on the full image
-                    latent_img = vae.encode(img)
+                    img_resized, scale, pad_top, pad_left = _resize_to_model(img)
+                    latent_img = vae.encode(img_resized)
                     kp_batch, sc_batch = _run_on_latent(latent_img)
-                    img_keypoints.append(kp_batch[0])
+                    img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
                     img_scores.append(sc_batch[0])

                 all_keypoints.append(img_keypoints)
                 all_scores.append(img_scores)
                 pbar.update(1)

         else:  # full-image mode, batched
-            tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
-            for batch_start in range(0, total_images, batch_size):
-                batch_end = min(batch_start + batch_size, total_images)
-                latent_batch = vae.encode(image[batch_start:batch_end])
-
+            for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
+                batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
+                latent_batch = vae.encode(batch_resized)
                 kp_batch, sc_batch = _run_on_latent(latent_batch)

                 for kp, sc in zip(kp_batch, sc_batch):
-                    all_keypoints.append([kp])
+                    all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
                     all_scores.append([sc])
-                    tqdm_pbar.update(1)

-                pbar.update(batch_end - batch_start)
+                pbar.update(len(kp_batch))

         openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
         return io.NodeOutput(openpose_frames)
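
One detail in the batched path above: image[batch_start:batch_start + batch_size] clamps at the end of the tensor, so the final batch may be short, and updating the progress bar by len(kp_batch) rather than a fixed batch_size keeps the count exact. A quick sketch with hypothetical sizes:

total_images, batch_size = 10, 4
processed = 0
for batch_start in range(0, total_images, batch_size):
    # Slicing clamps past-the-end indices, so the last batch has 2 items, not 4.
    batch = list(range(total_images))[batch_start:batch_start + batch_size]
    processed += len(batch)  # mirrors pbar.update(len(kp_batch))
assert processed == total_images  # 4 + 4 + 2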