"""
Utility functions for Qwen Image Edit ComfyUI plugin
"""
import torch
from PIL import Image
import numpy as np
from typing import Union, List, Tuple, Optional


def tensor_to_pil(tensor: torch.Tensor, batch_idx: int = 0) -> Image.Image:
    """Convert a ComfyUI tensor to a PIL Image.

    Args:
        tensor: ComfyUI IMAGE tensor (BHWC format, float32 in [0, 1])
        batch_idx: Which image in the batch to convert

    Returns:
        PIL Image in RGB format
    """
    if tensor.dim() == 4:
        tensor = tensor[batch_idx]
    elif tensor.dim() == 3:
        pass  # already a single HWC image
    else:
        raise ValueError(f"Expected 3D or 4D tensor, got {tensor.dim()}D")
    # ComfyUI images are already HWC, so only scale [0, 1] floats to [0, 255]
    image_np = (tensor.clamp(0, 1).detach().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(image_np)


def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL Image to ComfyUI tensor format.

    Args:
        image: PIL Image in RGB format

    Returns:
        ComfyUI IMAGE tensor (1HWC format, float32 in [0, 1])
    """
    # Convert to a numpy array and normalize to [0, 1]
    arr = np.array(image).astype(np.float32) / 255.0
    # Handle grayscale by replicating the single channel to RGB
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    # Drop the alpha channel if present
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # Add the batch dimension: HWC -> 1HWC (ComfyUI layout)
    return torch.from_numpy(arr).unsqueeze(0).contiguous()


def batch_tensor_to_pil(tensor: torch.Tensor) -> List[Image.Image]:
    """Convert a batch tensor to a list of PIL Images.

    Args:
        tensor: ComfyUI IMAGE tensor (BHWC format)

    Returns:
        List of PIL Images, one per batch entry
    """
    if tensor.dim() != 4:
        raise ValueError(f"Expected 4D batch tensor, got {tensor.dim()}D")
    return [tensor_to_pil(tensor, i) for i in range(tensor.shape[0])]


def pil_list_to_tensor(images: List[Image.Image]) -> torch.Tensor:
    """Convert a list of PIL Images to a batch tensor.

    Args:
        images: List of PIL Images (all must share the same dimensions)

    Returns:
        ComfyUI IMAGE tensor (BHWC format)
    """
    if not images:
        raise ValueError("Empty image list")
    # torch.cat requires matching H/W/C across all images
    tensors = [pil_to_tensor(img) for img in images]
    return torch.cat(tensors, dim=0)
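

# Round-trip sketch for debugging tensor layouts: BHWC -> PIL -> BHWC should
# preserve shape and, up to uint8 quantization, values. A hypothetical helper,
# not used by the nodes themselves.
def _roundtrip_check(tensor: torch.Tensor, atol: float = 1.0 / 255.0) -> bool:
    """Return True if tensor -> PIL -> tensor round-trips within atol."""
    restored = pil_list_to_tensor(batch_tensor_to_pil(tensor))
    return restored.shape == tensor.shape and bool(
        torch.allclose(restored, tensor.clamp(0, 1), atol=atol)
    )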


def validate_gpu_memory() -> Tuple[bool, str]:
    """Check whether sufficient GPU memory is available.

    Returns:
        Tuple of (is_sufficient, message)
    """
    if not torch.cuda.is_available():
        return False, "CUDA not available. CPU mode will be used but may be very slow."
    try:
        memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if memory_gb < 8:
            return False, f"Only {memory_gb:.1f}GB GPU memory available. 8GB+ recommended for optimal performance."
        return True, f"GPU memory: {memory_gb:.1f}GB - sufficient for Qwen Image Edit"
    except Exception as e:
        return False, f"Could not check GPU memory: {e}"


def safe_model_load(model_path: str, device: str = "auto") -> Tuple[bool, str]:
    """Validate device selection before a model load is attempted.

    Args:
        model_path: Path or name of the model
        device: Target device ("auto", "cuda", "cpu")

    Returns:
        Tuple of (success, message/error)
    """
    if device == "auto":
        if torch.cuda.is_available():
            gpu_ok, gpu_msg = validate_gpu_memory()
            if not gpu_ok:
                return False, f"Falling back to CPU: {gpu_msg}"
            device = "cuda"
        else:
            device = "cpu"
    # This is only a validation helper - the actual loading happens in the node
    return True, f"Model loading prepared for device: {device}"


def prepare_generation_params(params: Optional[dict], defaults: Optional[dict] = None) -> dict:
    """Prepare and validate generation parameters.

    Args:
        params: User-provided parameters
        defaults: Default parameter values

    Returns:
        Merged and validated parameters
    """
    if defaults is None:
        defaults = {
            "true_cfg_scale": 4.0,
            "negative_prompt": " ",
            "num_inference_steps": 50,
            "seed": 0,
        }
    if params is None:
        params = {}
    # Merge user parameters over the defaults
    result = defaults.copy()
    result.update(params)
    # Clamp values to sane ranges
    result["true_cfg_scale"] = max(0.0, min(20.0, float(result["true_cfg_scale"])))
    result["num_inference_steps"] = max(1, min(200, int(result["num_inference_steps"])))
    result["seed"] = max(0, min(2**31 - 1, int(result["seed"])))
    return result
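

# Example: out-of-range values are clamped, missing keys fall back to the
# defaults above, and unknown keys pass through untouched.
#
#     >>> prepare_generation_params({"true_cfg_scale": 99, "seed": -5})
#     {'true_cfg_scale': 20.0, 'negative_prompt': ' ', 'num_inference_steps': 50, 'seed': 0}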


def split_prompts(prompt_text: str) -> List[str]:
    """Split multi-line prompt text into individual prompts.

    Args:
        prompt_text: Multi-line string with one prompt per line

    Returns:
        List of individual prompt strings (never empty)
    """
    lines = [line.strip() for line in prompt_text.splitlines() if line.strip()]
    return lines if lines else [""]


def format_error_message(error: Exception, context: str = "") -> str:
    """Format an exception for user-friendly display.

    Args:
        error: The exception that occurred
        context: Additional context about where the error occurred

    Returns:
        Formatted error message
    """
    error_type = type(error).__name__
    error_msg = str(error)
    if context:
        return f"[{context}] {error_type}: {error_msg}"
    return f"{error_type}: {error_msg}"