Skip to content

Commit 9310bec

Browse files
authored
perf: ⚡️ enhance read_image_as_pil read speed for better slice speed (#1353)
Signed-off-by: Onuralp SEZER <thunderbirdtr@gmail.com>
1 parent b832c2c commit 9310bec

2 files changed

Lines changed: 48 additions & 16 deletions

File tree

sahi/predict.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,19 @@ def get_prediction(
133133

134134
try:
135135
durations_in_seconds = dict()
136-
image_as_pil = read_image_as_pil(image)
136+
image_as_arr = read_image_as_pil(image, return_arr=True)
137137

138138
if shift_amount is None:
139139
shift_amount = [0, 0]
140140
if full_shape is None:
141-
full_shape = [image_as_pil.height, image_as_pil.width]
141+
if image_as_arr.ndim == 2: # type: ignore[union-attr]
142+
h, w = image_as_arr.shape # type: ignore[misc]
143+
else:
144+
h, w = image_as_arr.shape[:2] # type: ignore[union-attr]
145+
full_shape = [h, w]
142146

143147
time_start = time.perf_counter()
144-
detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
148+
detection_model.perform_inference(np.ascontiguousarray(image_as_arr))
145149
durations_in_seconds["prediction"] = time.perf_counter() - time_start
146150

147151
time_start = time.perf_counter()

sahi/utils/cv.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,51 @@ def read_image(image_path: str) -> np.ndarray:
170170
return image
171171

172172

173-
def read_image_as_pil(image: Image.Image | str | np.ndarray, exif_fix: bool = True) -> Image.Image:
174-
"""Loads an image as PIL.Image.Image.
173+
def _to_hwc(arr: np.ndarray) -> np.ndarray:
174+
"""Return a HWC array, transposing CHW to HWC when necessary.
175+
176+
Uses channel-count heuristic (1, 3, or 4) instead of a size threshold so
177+
small images (height < 5 px) are handled correctly. Channel order is NOT
178+
changed — callers are responsible for BGR/RGB correctness before passing in.
179+
180+
Args:
181+
arr (numpy.ndarray): The input array to be converted to HWC format.
182+
183+
Returns:
184+
numpy.ndarray: The input array converted to HWC format if necessary, otherwise the original array
185+
186+
"""
187+
a = np.asarray(arr)
188+
if a.ndim == 3 and a.shape[0] in (1, 3, 4) and a.shape[-1] not in (1, 3, 4):
189+
return np.transpose(a, (1, 2, 0))
190+
return a
191+
192+
193+
def read_image_as_pil(
194+
image: Image.Image | str | np.ndarray,
195+
exif_fix: bool = True,
196+
return_arr: bool = False,
197+
) -> Image.Image | np.ndarray:
198+
"""Loads an image as PIL.Image.Image (or np.ndarray when return_arr=True).
175199
176200
Args:
177201
image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
178202
a numpy image (np.ndarray), or a PIL.Image object.
179-
exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to False.
203+
exif_fix (bool): Whether to apply an EXIF fix to the image. Defaults to False.
204+
return_arr (bool): When True and the input is already a numpy array, skip the
205+
costly PIL conversion and return an HWC RGB ndarray directly. For PIL/str inputs the
206+
PIL image is converted to ndarray before returning. Defaults to False.
180207
181208
Returns:
182-
PIL.Image.Image: The loaded image as a PIL.Image object.
209+
PIL.Image.Image | np.ndarray: The loaded image.
183210
"""
184211
# https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
185212
Image.MAX_IMAGE_PIXELS = None
186213

187214
if isinstance(image, Image.Image):
188-
image_pil = image
215+
if return_arr:
216+
return np.asarray(image)
217+
return image
189218
elif isinstance(image, str):
190219
# read image if str image path is provided
191220
try:
@@ -211,17 +240,16 @@ def read_image_as_pil(image: Image.Image | str | np.ndarray, exif_fix: bool = Tr
211240
image_pil = Image.fromarray(image_sk, mode="RGB")
212241
else:
213242
raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.")
243+
if return_arr:
244+
return np.asarray(image_pil)
245+
return image_pil
214246
elif isinstance(image, np.ndarray):
215-
# check if image is in CHW format (Channels, Height, Width)
216-
# heuristic: 3 dimensions, first dim (channels) < 5, last dim (width) > 4
217-
if image.ndim == 3 and image.shape[0] < 5: # image in CHW
218-
if image.shape[2] > 4:
219-
# convert CHW to HWC (Height, Width, Channels)
220-
image = np.transpose(image, (1, 2, 0))
221-
image_pil = Image.fromarray(image)
247+
arr = _to_hwc(image)
248+
if return_arr:
249+
return arr
250+
return Image.fromarray(arr)
222251
else:
223252
raise TypeError("read image with 'pillow' using 'Image.open()'")
224-
return image_pil
225253

226254

227255
def select_random_color() -> list[int]:

0 commit comments

Comments
 (0)