Skip to content

Commit ee6d075

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 040656e commit ee6d075

2 files changed

Lines changed: 69 additions & 50 deletions

File tree

geemap/dl.py

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
import warnings
1515
import threading
1616

17+
1718
def _detect_image(
1819
image,
1920
model_path,
2021
scale,
21-
segment_fn,
22+
segment_fn,
2223
segment_args: dict,
2324
num_thread=None,
2425
bands=None,
@@ -31,7 +32,7 @@ def _detect_image(
3132
min_area=0,
3233
max_tile_size=None,
3334
max_tile_dim=None,
34-
**kwargs
35+
**kwargs,
3536
):
3637
"""
3738
Perform object detection on an Earth Engine image using a deep learning model.
@@ -46,7 +47,7 @@ def _detect_image(
4647
all are resampled to this scale. Defaults to the minimum scale of image bands.
4748
segment_fn (Callable): A segmentation post-processing function that processes the model's raw output.
4849
segment_args (dict): A dictionary of keyword arguments to pass to the segmentation function.
49-
num_threads (int, optional): Number of tiles to download concurrently. Defaults to a sensible auto value.
50+
num_threads (int, optional): Number of tiles to download concurrently. Defaults to a sensible auto value.
5051
bands (list, optional): A list of band names to use in the timelapse.
5152
window_size (int): Size of sliding window for inference.
5253
overlap (int): Overlap between adjacent windows.
@@ -63,20 +64,22 @@ def _detect_image(
6364
6465
Returns:
6566
ee.FeatureCollection: A FeatureCollection containing vectorized object detection results, reprojected to WGS84.
66-
67+
6768
Note:
6869
This function relies on the `geedim`, `rasterio`, and `geoai` packages for tiling, downloading, inference,
6970
and raster-to-vector conversion. Ensure the Earth Engine image is processed and scaled appropriately before passing to this function.
7071
"""
71-
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
72+
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
7273
warnings.filterwarnings("ignore", category=UserWarning, module="rasterio._env")
73-
logging.getLogger('rasterio').setLevel(logging.ERROR)
74+
logging.getLogger("rasterio").setLevel(logging.ERROR)
7475

7576
out_lock = threading.Lock()
7677

77-
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as temp_merged, \
78-
tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as temp_mask, \
79-
tempfile.NamedTemporaryFile(suffix=".geojson", delete=False) as temp_geojson:
78+
with (
79+
tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as temp_merged,
80+
tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as temp_mask,
81+
tempfile.NamedTemporaryFile(suffix=".geojson", delete=False) as temp_geojson,
82+
):
8083

8184
merged_path = temp_merged.name
8285
masks_path = temp_mask.name
@@ -92,31 +95,31 @@ def _detect_image(
9295

9396
baseImg = gd.download.BaseImage(image)
9497

95-
if kwargs.get('set_nodata') is None:
96-
kwargs['set_nodata'] = True
97-
if kwargs.get('crs') is None:
98-
kwargs['crs'] = "EPSG:4326"
99-
if kwargs.get('region') is None:
100-
kwargs['region'] = image.geometry().bounds()
101-
if kwargs.get('epsilon') is None:
102-
kwargs['epsilon'] = 0.2
103-
if kwargs.get('min_segments') is None:
104-
kwargs['min_segments'] = 4
105-
if kwargs.get('area_tolerance') is None:
106-
kwargs['area_tolerance'] = 0.7
107-
if kwargs.get('detect_triangles') is None:
108-
kwargs['detect_triangles'] = True
109-
110-
98+
if kwargs.get("set_nodata") is None:
99+
kwargs["set_nodata"] = True
100+
if kwargs.get("crs") is None:
101+
kwargs["crs"] = "EPSG:4326"
102+
if kwargs.get("region") is None:
103+
kwargs["region"] = image.geometry().bounds()
104+
if kwargs.get("epsilon") is None:
105+
kwargs["epsilon"] = 0.2
106+
if kwargs.get("min_segments") is None:
107+
kwargs["min_segments"] = 4
108+
if kwargs.get("area_tolerance") is None:
109+
kwargs["area_tolerance"] = 0.7
110+
if kwargs.get("detect_triangles") is None:
111+
kwargs["detect_triangles"] = True
111112

112113
exp_img, profile = baseImg._prepare_for_download(
113-
set_nodata=kwargs['set_nodata'],
114-
crs=kwargs['crs'],
114+
set_nodata=kwargs["set_nodata"],
115+
crs=kwargs["crs"],
115116
scale=scale,
116-
region=kwargs['region']
117+
region=kwargs["region"],
117118
)
118119
print(exp_img.shape)
119-
tile_shape, num_tiles = exp_img._get_tile_shape(max_tile_size=max_tile_size, max_tile_dim=max_tile_dim)
120+
tile_shape, num_tiles = exp_img._get_tile_shape(
121+
max_tile_size=max_tile_size, max_tile_dim=max_tile_dim
122+
)
120123
tiles_list = list(exp_img._tiles(tile_shape))
121124

122125
dtype_size = np.dtype(exp_img.dtype).itemsize
@@ -125,25 +128,32 @@ def _detect_image(
125128
raw_download_size = tile_pixel_size * len(tiles_list)
126129

127130
bar = tqdm(
128-
desc="🧩 Processing tiles", total=raw_download_size,
129-
bar_format='{desc}: |{bar}| {n_fmt}/{total_fmt} (raw) [{percentage:5.1f}%] in {elapsed:>5s} (eta: {remaining:>5s})',
130-
dynamic_ncols=True, unit_scale=True, unit='B'
131+
desc="🧩 Processing tiles",
132+
total=raw_download_size,
133+
bar_format="{desc}: |{bar}| {n_fmt}/{total_fmt} (raw) [{percentage:5.1f}%] in {elapsed:>5s} (eta: {remaining:>5s})",
134+
dynamic_ncols=True,
135+
unit_scale=True,
136+
unit="B",
131137
)
132138

133-
with rasterio.Env(GDAL_NUM_THREADS='ALL_CPUs', GTIFF_FORCE_RGBA=False):
139+
with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUs", GTIFF_FORCE_RGBA=False):
134140
with rasterio.open(merged_path, "w", **profile) as out_ds:
135141

136142
def process_tile(tile):
137143
try:
138-
tile_array = tile.download(session=gd.utils.retry_session(),bar=bar)
144+
tile_array = tile.download(
145+
session=gd.utils.retry_session(), bar=bar
146+
)
139147
with out_lock:
140148
out_ds.write(tile_array, window=tile.window)
141149
return True
142150
except Exception as e:
143151
return False
144152

145153
with ThreadPoolExecutor(max_workers=max_threads) as executor:
146-
futures = [executor.submit(process_tile, tile) for tile in tiles_list]
154+
futures = [
155+
executor.submit(process_tile, tile) for tile in tiles_list
156+
]
147157
for future in as_completed(futures):
148158
future.result()
149159
bar.close()
@@ -157,16 +167,24 @@ def process_tile(tile):
157167
batch_size=batch_size,
158168
num_channels=num_channels,
159169
num_classes=num_classes,
160-
**segment_args
170+
**segment_args,
161171
)
162172

163173
def silent(func, *args, **kwargs):
164174
f = io.StringIO()
165175
with redirect_stdout(f), redirect_stderr(f):
166176
return func(*args, **kwargs)
167-
168-
gdf = silent(geoai.orthogonalize,masks_path,output_path,
169-
min_area=min_area,epsilon=kwargs['epsilon'],min_segments=kwargs['min_segments'],area_tolerance=kwargs['area_tolerance'],detect_triangles=kwargs['detect_triangles'])
177+
178+
gdf = silent(
179+
geoai.orthogonalize,
180+
masks_path,
181+
output_path,
182+
min_area=min_area,
183+
epsilon=kwargs["epsilon"],
184+
min_segments=kwargs["min_segments"],
185+
area_tolerance=kwargs["area_tolerance"],
186+
detect_triangles=kwargs["detect_triangles"],
187+
)
170188
gdf = geoai.add_geometric_properties(gdf)
171189
gdf_wgs84 = gdf.to_crs(epsg=4326)
172190
geojson_dict = gdf_wgs84.__geo_interface__
@@ -178,7 +196,7 @@ def silent(func, *args, **kwargs):
178196
logging.warning(f"Failed to remove {path}: {e}")
179197

180198
return geojson_to_ee(geojson_dict)
181-
199+
182200

183201
def detect_instance_segmentation(
184202
image,
@@ -195,7 +213,7 @@ def detect_instance_segmentation(
195213
min_area=0,
196214
max_tile_size=None,
197215
max_tile_dim=None,
198-
**kwargs
216+
**kwargs,
199217
):
200218
"""
201219
Perform instance segmentation on an Earth Engine image using a pre-trained Mask R-CNN model.
@@ -204,7 +222,7 @@ def detect_instance_segmentation(
204222
image: ee.Image object.
205223
model_path: Path to the trained model.
206224
scale: Resolution in meters.
207-
num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
225+
num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
208226
bands: List of image bands to use.
209227
window_size: Size of the sliding window in pixels .
210228
overlap: Overlap between windows in pixels.
@@ -221,7 +239,7 @@ def detect_instance_segmentation(
221239
model_path=model_path,
222240
scale=scale,
223241
segment_fn=geoai.instance_segmentation,
224-
segment_args={},
242+
segment_args={},
225243
num_thread=num_thread,
226244
bands=bands,
227245
window_size=window_size,
@@ -233,9 +251,10 @@ def detect_instance_segmentation(
233251
min_area=min_area,
234252
max_tile_size=max_tile_size,
235253
max_tile_dim=max_tile_dim,
236-
**kwargs
254+
**kwargs,
237255
)
238256

257+
239258
def detect_semantic_segmentation(
240259
image,
241260
model_path,
@@ -255,7 +274,7 @@ def detect_semantic_segmentation(
255274
max_tile_dim=None,
256275
device=None,
257276
quiet=False,
258-
**kwargs
277+
**kwargs,
259278
):
260279
"""
261280
Perform semantic segmentation on an Earth Engine image.
@@ -266,7 +285,7 @@ def detect_semantic_segmentation(
266285
scale: Resolution in meters.
267286
architecture: Model architecture used for training.
268287
encoder_name: Encoder backbone name used for training.
269-
num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
288+
num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
270289
bands: List of image bands to use.
271290
window_size: Size of the sliding window in pixels.
272291
overlap: Overlap between windows in pixels.
@@ -290,7 +309,7 @@ def detect_semantic_segmentation(
290309
"encoder_name": encoder_name,
291310
"device": device,
292311
"quiet": quiet,
293-
**kwargs
312+
**kwargs,
294313
},
295314
num_thread=num_thread,
296315
bands=bands,
@@ -303,5 +322,5 @@ def detect_semantic_segmentation(
303322
min_area=min_area,
304323
max_tile_size=max_tile_size,
305324
max_tile_dim=max_tile_dim,
306-
**kwargs
307-
)
325+
**kwargs,
326+
)

geemap/geemap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .legends import builtin_legends
3636
from . import examples
3737

38-
from .dl import detect_instance_segmentation,detect_semantic_segmentation
38+
from .dl import detect_instance_segmentation, detect_semantic_segmentation
3939

4040

4141
basemaps = Box(xyz_to_leaflet(), frozen_box=True)

0 commit comments

Comments
 (0)