1414import warnings
1515import threading
1616
17+
1718def _detect_image (
1819 image ,
1920 model_path ,
2021 scale ,
21- segment_fn ,
22+ segment_fn ,
2223 segment_args : dict ,
2324 num_thread = None ,
2425 bands = None ,
@@ -31,7 +32,7 @@ def _detect_image(
3132 min_area = 0 ,
3233 max_tile_size = None ,
3334 max_tile_dim = None ,
34- ** kwargs
35+ ** kwargs ,
3536):
3637 """
3738 Perform object detection on an Earth Engine image using a deep learning model.
@@ -46,7 +47,7 @@ def _detect_image(
4647 all are resampled to this scale. Defaults to the minimum scale of image bands.
4748 segment_fn (Callable): A segmentation post-processing function that processes the model's raw output.
4849 segment_args (dict): A dictionary of keyword arguments to pass to the segmentation function.
49- num_threads (int, optional): Number of tiles to download concurrently. Defaults to a sensible auto value.
50+ num_threads (int, optional): Number of tiles to download concurrently. Defaults to a sensible auto value.
5051 bands (list, optional): A list of band names to use in the timelapse.
5152 window_size (int): Size of sliding window for inference.
5253 overlap (int): Overlap between adjacent windows.
@@ -63,20 +64,22 @@ def _detect_image(
6364
6465 Returns:
6566 ee.FeatureCollection: A FeatureCollection containing vectorized object detection results, reprojected to WGS84.
66-
67+
6768 Note:
6869 This function relies on the `geedim`, `rasterio`, and `geoai` packages for tiling, downloading, inference,
6970 and raster-to-vector conversion. Ensure the Earth Engine image is processed and scaled appropriately before passing to this function.
7071 """
71- logging .basicConfig (level = logging .INFO , format = ' [%(levelname)s] %(message)s' )
72+ logging .basicConfig (level = logging .INFO , format = " [%(levelname)s] %(message)s" )
7273 warnings .filterwarnings ("ignore" , category = UserWarning , module = "rasterio._env" )
73- logging .getLogger (' rasterio' ).setLevel (logging .ERROR )
74+ logging .getLogger (" rasterio" ).setLevel (logging .ERROR )
7475
7576 out_lock = threading .Lock ()
7677
77- with tempfile .NamedTemporaryFile (suffix = ".tif" , delete = False ) as temp_merged , \
78- tempfile .NamedTemporaryFile (suffix = ".tif" , delete = False ) as temp_mask , \
79- tempfile .NamedTemporaryFile (suffix = ".geojson" , delete = False ) as temp_geojson :
78+ with (
79+ tempfile .NamedTemporaryFile (suffix = ".tif" , delete = False ) as temp_merged ,
80+ tempfile .NamedTemporaryFile (suffix = ".tif" , delete = False ) as temp_mask ,
81+ tempfile .NamedTemporaryFile (suffix = ".geojson" , delete = False ) as temp_geojson ,
82+ ):
8083
8184 merged_path = temp_merged .name
8285 masks_path = temp_mask .name
@@ -92,31 +95,31 @@ def _detect_image(
9295
9396 baseImg = gd .download .BaseImage (image )
9497
95- if kwargs .get ('set_nodata' ) is None :
96- kwargs ['set_nodata' ] = True
97- if kwargs .get ('crs' ) is None :
98- kwargs ['crs' ] = "EPSG:4326"
99- if kwargs .get ('region' ) is None :
100- kwargs ['region' ] = image .geometry ().bounds ()
101- if kwargs .get ('epsilon' ) is None :
102- kwargs ['epsilon' ] = 0.2
103- if kwargs .get ('min_segments' ) is None :
104- kwargs ['min_segments' ] = 4
105- if kwargs .get ('area_tolerance' ) is None :
106- kwargs ['area_tolerance' ] = 0.7
107- if kwargs .get ('detect_triangles' ) is None :
108- kwargs ['detect_triangles' ] = True
109-
110-
98+ if kwargs .get ("set_nodata" ) is None :
99+ kwargs ["set_nodata" ] = True
100+ if kwargs .get ("crs" ) is None :
101+ kwargs ["crs" ] = "EPSG:4326"
102+ if kwargs .get ("region" ) is None :
103+ kwargs ["region" ] = image .geometry ().bounds ()
104+ if kwargs .get ("epsilon" ) is None :
105+ kwargs ["epsilon" ] = 0.2
106+ if kwargs .get ("min_segments" ) is None :
107+ kwargs ["min_segments" ] = 4
108+ if kwargs .get ("area_tolerance" ) is None :
109+ kwargs ["area_tolerance" ] = 0.7
110+ if kwargs .get ("detect_triangles" ) is None :
111+ kwargs ["detect_triangles" ] = True
111112
112113 exp_img , profile = baseImg ._prepare_for_download (
113- set_nodata = kwargs [' set_nodata' ],
114- crs = kwargs [' crs' ],
114+ set_nodata = kwargs [" set_nodata" ],
115+ crs = kwargs [" crs" ],
115116 scale = scale ,
116- region = kwargs [' region' ]
117+ region = kwargs [" region" ],
117118 )
118119 print (exp_img .shape )
119- tile_shape , num_tiles = exp_img ._get_tile_shape (max_tile_size = max_tile_size , max_tile_dim = max_tile_dim )
120+ tile_shape , num_tiles = exp_img ._get_tile_shape (
121+ max_tile_size = max_tile_size , max_tile_dim = max_tile_dim
122+ )
120123 tiles_list = list (exp_img ._tiles (tile_shape ))
121124
122125 dtype_size = np .dtype (exp_img .dtype ).itemsize
@@ -125,25 +128,32 @@ def _detect_image(
125128 raw_download_size = tile_pixel_size * len (tiles_list )
126129
127130 bar = tqdm (
128- desc = "🧩 Processing tiles" , total = raw_download_size ,
129- bar_format = '{desc}: |{bar}| {n_fmt}/{total_fmt} (raw) [{percentage:5.1f}%] in {elapsed:>5s} (eta: {remaining:>5s})' ,
130- dynamic_ncols = True , unit_scale = True , unit = 'B'
131+ desc = "🧩 Processing tiles" ,
132+ total = raw_download_size ,
133+ bar_format = "{desc}: |{bar}| {n_fmt}/{total_fmt} (raw) [{percentage:5.1f}%] in {elapsed:>5s} (eta: {remaining:>5s})" ,
134+ dynamic_ncols = True ,
135+ unit_scale = True ,
136+ unit = "B" ,
131137 )
132138
133- with rasterio .Env (GDAL_NUM_THREADS = ' ALL_CPUs' , GTIFF_FORCE_RGBA = False ):
139+ with rasterio .Env (GDAL_NUM_THREADS = " ALL_CPUs" , GTIFF_FORCE_RGBA = False ):
134140 with rasterio .open (merged_path , "w" , ** profile ) as out_ds :
135141
136142 def process_tile (tile ):
137143 try :
138- tile_array = tile .download (session = gd .utils .retry_session (),bar = bar )
144+ tile_array = tile .download (
145+ session = gd .utils .retry_session (), bar = bar
146+ )
139147 with out_lock :
140148 out_ds .write (tile_array , window = tile .window )
141149 return True
142150 except Exception as e :
143151 return False
144152
145153 with ThreadPoolExecutor (max_workers = max_threads ) as executor :
146- futures = [executor .submit (process_tile , tile ) for tile in tiles_list ]
154+ futures = [
155+ executor .submit (process_tile , tile ) for tile in tiles_list
156+ ]
147157 for future in as_completed (futures ):
148158 future .result ()
149159 bar .close ()
@@ -157,16 +167,24 @@ def process_tile(tile):
157167 batch_size = batch_size ,
158168 num_channels = num_channels ,
159169 num_classes = num_classes ,
160- ** segment_args
170+ ** segment_args ,
161171 )
162172
163173 def silent (func , * args , ** kwargs ):
164174 f = io .StringIO ()
165175 with redirect_stdout (f ), redirect_stderr (f ):
166176 return func (* args , ** kwargs )
167-
168- gdf = silent (geoai .orthogonalize ,masks_path ,output_path ,
169- min_area = min_area ,epsilon = kwargs ['epsilon' ],min_segments = kwargs ['min_segments' ],area_tolerance = kwargs ['area_tolerance' ],detect_triangles = kwargs ['detect_triangles' ])
177+
178+ gdf = silent (
179+ geoai .orthogonalize ,
180+ masks_path ,
181+ output_path ,
182+ min_area = min_area ,
183+ epsilon = kwargs ["epsilon" ],
184+ min_segments = kwargs ["min_segments" ],
185+ area_tolerance = kwargs ["area_tolerance" ],
186+ detect_triangles = kwargs ["detect_triangles" ],
187+ )
170188 gdf = geoai .add_geometric_properties (gdf )
171189 gdf_wgs84 = gdf .to_crs (epsg = 4326 )
172190 geojson_dict = gdf_wgs84 .__geo_interface__
@@ -178,7 +196,7 @@ def silent(func, *args, **kwargs):
178196 logging .warning (f"Failed to remove { path } : { e } " )
179197
180198 return geojson_to_ee (geojson_dict )
181-
199+
182200
183201def detect_instance_segmentation (
184202 image ,
@@ -195,7 +213,7 @@ def detect_instance_segmentation(
195213 min_area = 0 ,
196214 max_tile_size = None ,
197215 max_tile_dim = None ,
198- ** kwargs
216+ ** kwargs ,
199217):
200218 """
201219 Perform instance segmentation on an Earth Engine image using a pre-trained Mask R-CNN model.
@@ -204,7 +222,7 @@ def detect_instance_segmentation(
204222 image: ee.Image object.
205223 model_path: Path to the trained model.
206224 scale: Resolution in meters.
207- num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
225+ num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
208226 bands: List of image bands to use.
209227 window_size: Size of the sliding window in pixels .
210228 overlap: Overlap between windows in pixels.
@@ -221,7 +239,7 @@ def detect_instance_segmentation(
221239 model_path = model_path ,
222240 scale = scale ,
223241 segment_fn = geoai .instance_segmentation ,
224- segment_args = {},
242+ segment_args = {},
225243 num_thread = num_thread ,
226244 bands = bands ,
227245 window_size = window_size ,
@@ -233,9 +251,10 @@ def detect_instance_segmentation(
233251 min_area = min_area ,
234252 max_tile_size = max_tile_size ,
235253 max_tile_dim = max_tile_dim ,
236- ** kwargs
254+ ** kwargs ,
237255 )
238256
257+
239258def detect_semantic_segmentation (
240259 image ,
241260 model_path ,
@@ -255,7 +274,7 @@ def detect_semantic_segmentation(
255274 max_tile_dim = None ,
256275 device = None ,
257276 quiet = False ,
258- ** kwargs
277+ ** kwargs ,
259278):
260279 """
261280 Perform semantic segmentation on an Earth Engine image.
@@ -266,7 +285,7 @@ def detect_semantic_segmentation(
266285 scale: Resolution in meters.
267286 architecture: Model architecture used for training.
268287 encoder_name: Encoder backbone name used for training.
269- num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
288+ num_threads: Number of tiles to download concurrently. Defaults to a sensible auto value.
270289 bands: List of image bands to use.
271290 window_size: Size of the sliding window in pixels.
272291 overlap: Overlap between windows in pixels.
@@ -290,7 +309,7 @@ def detect_semantic_segmentation(
290309 "encoder_name" : encoder_name ,
291310 "device" : device ,
292311 "quiet" : quiet ,
293- ** kwargs
312+ ** kwargs ,
294313 },
295314 num_thread = num_thread ,
296315 bands = bands ,
@@ -303,5 +322,5 @@ def detect_semantic_segmentation(
303322 min_area = min_area ,
304323 max_tile_size = max_tile_size ,
305324 max_tile_dim = max_tile_dim ,
306- ** kwargs
307- )
325+ ** kwargs ,
326+ )
0 commit comments