@@ -170,22 +170,51 @@ def read_image(image_path: str) -> np.ndarray:
170170 return image
171171
172172
173- def read_image_as_pil (image : Image .Image | str | np .ndarray , exif_fix : bool = True ) -> Image .Image :
174- """Loads an image as PIL.Image.Image.
173+ def _to_hwc (arr : np .ndarray ) -> np .ndarray :
174+ """Return a HWC array, transposing CHW to HWC when necessary.
175+
176+ Uses channel-count heuristic (1, 3, or 4) instead of a size threshold so
177+ small images (height < 5 px) are handled correctly. Channel order is NOT
178+ changed — callers are responsible for BGR/RGB correctness before passing in.
179+
180+ Args:
181+ arr (numpy.ndarray): The input array to be converted to HWC format.
182+
183+ Returns:
184+ numpy.ndarray: The input array converted to HWC format if necessary, otherwise the original array
185+
186+ """
187+ a = np .asarray (arr )
188+ if a .ndim == 3 and a .shape [0 ] in (1 , 3 , 4 ) and a .shape [- 1 ] not in (1 , 3 , 4 ):
189+ return np .transpose (a , (1 , 2 , 0 ))
190+ return a
191+
192+
193+ def read_image_as_pil (
194+ image : Image .Image | str | np .ndarray ,
195+ exif_fix : bool = True ,
196+ return_arr : bool = False ,
197+ ) -> Image .Image | np .ndarray :
198+ """Loads an image as PIL.Image.Image (or np.ndarray when return_arr=True).
175199
176200 Args:
177201 image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
178202 a numpy image (np.ndarray), or a PIL.Image object.
179- exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to False.
203+ exif_fix (bool): Whether to apply an EXIF fix to the image. Defaults to False.
204+ return_arr (bool): When True and the input is already a numpy array, skip the
205+ costly PIL conversion and return an HWC RGB ndarray directly. For PIL/str inputs the
206+ PIL image is converted to ndarray before returning. Defaults to False.
180207
181208 Returns:
182- PIL.Image.Image: The loaded image as a PIL.Image object .
209+ PIL.Image.Image | np.ndarray : The loaded image.
183210 """
184211 # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
185212 Image .MAX_IMAGE_PIXELS = None
186213
187214 if isinstance (image , Image .Image ):
188- image_pil = image
215+ if return_arr :
216+ return np .asarray (image )
217+ return image
189218 elif isinstance (image , str ):
190219 # read image if str image path is provided
191220 try :
@@ -211,17 +240,16 @@ def read_image_as_pil(image: Image.Image | str | np.ndarray, exif_fix: bool = Tr
211240 image_pil = Image .fromarray (image_sk , mode = "RGB" )
212241 else :
213242 raise TypeError (f"image with shape: { image_sk .shape [3 ]} is not supported." )
243+ if return_arr :
244+ return np .asarray (image_pil )
245+ return image_pil
214246 elif isinstance (image , np .ndarray ):
215- # check if image is in CHW format (Channels, Height, Width)
216- # heuristic: 3 dimensions, first dim (channels) < 5, last dim (width) > 4
217- if image .ndim == 3 and image .shape [0 ] < 5 : # image in CHW
218- if image .shape [2 ] > 4 :
219- # convert CHW to HWC (Height, Width, Channels)
220- image = np .transpose (image , (1 , 2 , 0 ))
221- image_pil = Image .fromarray (image )
247+ arr = _to_hwc (image )
248+ if return_arr :
249+ return arr
250+ return Image .fromarray (arr )
222251 else :
223252 raise TypeError ("read image with 'pillow' using 'Image.open()'" )
224- return image_pil
225253
226254
227255def select_random_color () -> list [int ]:
0 commit comments