From e33a18f33c8940297520c793fae28c9e2dfa7bfb Mon Sep 17 00:00:00 2001 From: "Juan S. Mrad" Date: Mon, 13 Apr 2026 21:22:05 -0500 Subject: [PATCH] [HMA] Add explicit connection close for signal ops and clear temp refs --- .../background_tasks/build_index.py | 5 + .../storage/postgres/database.py | 97 +++++++++++-------- 2 files changed, 60 insertions(+), 42 deletions(-) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/background_tasks/build_index.py b/hasher-matcher-actioner/src/OpenMediaMatch/background_tasks/build_index.py index 20c18aa9a..47fb59358 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/background_tasks/build_index.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/background_tasks/build_index.py @@ -84,6 +84,8 @@ def build_index( index_store.store_signal_type_index(for_signal_type, built_index, checkpoint) finally: # Force garbage collection to reclaim memory and attempt to free pages + # explicitly free the built index before reclaiming memory + built_index = None trim_process_memory(logger, "Indexer") logger.info( @@ -115,6 +117,9 @@ def _prepare_index( index_cls = for_signal_type.get_index_cls() built_index = index_cls.build(signal_list) + # explicitly free the signal list before returning + signal_list.clear() + # Create checkpoint checkpoint = SignalTypeIndexBuildCheckpoint.get_empty() if last_cs is not None: diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/database.py b/hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/database.py index 7e6c933ac..ebccc46ee 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/database.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/database.py @@ -495,28 +495,34 @@ def commit_signal_index( store_start_time = time.time() # Deep dark magic - direct access postgres large object API raw_conn = db.engine.raw_connection() - l_obj = raw_conn.lobject(0, "wb", 0, tmpfile.name) - self._log( - "imported tmpfile as lobject oid %d took %s", - l_obj.oid, - duration_to_human_str(time.time() - store_start_time), - ) - if self.serialized_index_large_object_oid is not None: - if self.index_lobj_exists(): - old_obj = raw_conn.lobject(self.serialized_index_large_object_oid, "n") - self._log("deallocating old lobject %d", old_obj.oid) - old_obj.unlink() - else: - self._log( - "old lobject %d doesn't exist? " - + "This might be a previous partial failure", - self.serialized_index_large_object_oid, - level=logging.WARNING, - ) - - self.serialized_index_large_object_oid = l_obj.oid - db.session.add(self) - raw_conn.commit() + try: + l_obj = raw_conn.lobject(0, "wb", 0, tmpfile.name) + self._log( + "imported tmpfile as lobject oid %d took %s", + l_obj.oid, + duration_to_human_str(time.time() - store_start_time), + ) + if self.serialized_index_large_object_oid is not None: + if self.index_lobj_exists(): + old_obj = raw_conn.lobject( + self.serialized_index_large_object_oid, "n" + ) + self._log("deallocating old lobject %d", old_obj.oid) + old_obj.unlink() + else: + self._log( + "old lobject %d doesn't exist? " + + "This might be a previous partial failure", + self.serialized_index_large_object_oid, + level=logging.WARNING, + ) + + self.serialized_index_large_object_oid = l_obj.oid + db.session.add(self) + raw_conn.commit() + finally: + # explicitly close the raw connection to free memory + raw_conn.close() self._log( "commited new index, %d signals %s took %s", @@ -546,28 +552,35 @@ def load_signal_index(self) -> SignalTypeIndex[int]: # I'm sorry future debugger finding this comment. load_start_time = time.time() raw_conn = db.engine.raw_connection() - l_obj = raw_conn.lobject(oid, "rb") + try: + l_obj = raw_conn.lobject(oid, "rb") - with tempfile.NamedTemporaryFile("rb") as tmpfile: - self._log("importing lobject oid %d to tmpfile %s", l_obj.oid, tmpfile.name) - l_obj.export(tmpfile.name) - tmpfile.seek(0, io.SEEK_END) - self._log( - "downloading %s to tmpfile took %s", - _human_friendly_bytesize(tmpfile.tell()), - duration_to_human_str(time.time() - load_start_time), - ) - tmpfile.seek(0) + with tempfile.NamedTemporaryFile("rb") as tmpfile: + self._log( + "importing lobject oid %d to tmpfile %s", l_obj.oid, tmpfile.name + ) + l_obj.export(tmpfile.name) + tmpfile.seek(0, io.SEEK_END) + self._log( + "downloading %s to tmpfile took %s", + _human_friendly_bytesize(tmpfile.tell()), + duration_to_human_str(time.time() - load_start_time), + ) + tmpfile.seek(0) + + deserialize_start = time.time() + index = t.cast( + SignalTypeIndex[int], + SignalTypeIndex.deserialize(t.cast(t.BinaryIO, tmpfile.file)), + ) + self._log( + "deserialize took %s", + duration_to_human_str(time.time() - deserialize_start), + ) + finally: + # explicitly close the raw connection to free memory + raw_conn.close() - deserialize_start = time.time() - index = t.cast( - SignalTypeIndex[int], - SignalTypeIndex.deserialize(t.cast(t.BinaryIO, tmpfile.file)), - ) - self._log( - "deserialize took %s", - duration_to_human_str(time.time() - deserialize_start), - ) self._log( "loading signal index took %s", duration_to_human_str(time.time() - load_start_time),