Skip to content

Commit ff4e106

Browse files
authored: [HMA] Add explicit connection close for signal ops and clear temp refs (#1968)
1 parent c85e723 commit ff4e106

2 files changed

Lines changed: 60 additions & 42 deletions

File tree

hasher-matcher-actioner/src/OpenMediaMatch/background_tasks/build_index.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def build_index(
8484
index_store.store_signal_type_index(for_signal_type, built_index, checkpoint)
8585
finally:
8686
# Force garbage collection to reclaim memory and attempt to free pages
87+
# explicitly free the built index before reclaiming memory
88+
built_index = None
8789
trim_process_memory(logger, "Indexer")
8890

8991
logger.info(
@@ -115,6 +117,9 @@ def _prepare_index(
115117
index_cls = for_signal_type.get_index_cls()
116118
built_index = index_cls.build(signal_list)
117119

120+
# explicitly free the signal list before returning
121+
signal_list.clear()
122+
118123
# Create checkpoint
119124
checkpoint = SignalTypeIndexBuildCheckpoint.get_empty()
120125
if last_cs is not None:

hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/database.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -495,28 +495,34 @@ def commit_signal_index(
495495
store_start_time = time.time()
496496
# Deep dark magic - direct access postgres large object API
497497
raw_conn = db.engine.raw_connection()
498-
l_obj = raw_conn.lobject(0, "wb", 0, tmpfile.name)
499-
self._log(
500-
"imported tmpfile as lobject oid %d took %s",
501-
l_obj.oid,
502-
duration_to_human_str(time.time() - store_start_time),
503-
)
504-
if self.serialized_index_large_object_oid is not None:
505-
if self.index_lobj_exists():
506-
old_obj = raw_conn.lobject(self.serialized_index_large_object_oid, "n")
507-
self._log("deallocating old lobject %d", old_obj.oid)
508-
old_obj.unlink()
509-
else:
510-
self._log(
511-
"old lobject %d doesn't exist? "
512-
+ "This might be a previous partial failure",
513-
self.serialized_index_large_object_oid,
514-
level=logging.WARNING,
515-
)
516-
517-
self.serialized_index_large_object_oid = l_obj.oid
518-
db.session.add(self)
519-
raw_conn.commit()
498+
try:
499+
l_obj = raw_conn.lobject(0, "wb", 0, tmpfile.name)
500+
self._log(
501+
"imported tmpfile as lobject oid %d took %s",
502+
l_obj.oid,
503+
duration_to_human_str(time.time() - store_start_time),
504+
)
505+
if self.serialized_index_large_object_oid is not None:
506+
if self.index_lobj_exists():
507+
old_obj = raw_conn.lobject(
508+
self.serialized_index_large_object_oid, "n"
509+
)
510+
self._log("deallocating old lobject %d", old_obj.oid)
511+
old_obj.unlink()
512+
else:
513+
self._log(
514+
"old lobject %d doesn't exist? "
515+
+ "This might be a previous partial failure",
516+
self.serialized_index_large_object_oid,
517+
level=logging.WARNING,
518+
)
519+
520+
self.serialized_index_large_object_oid = l_obj.oid
521+
db.session.add(self)
522+
raw_conn.commit()
523+
finally:
524+
# explicitly close the raw connection to free memory
525+
raw_conn.close()
520526

521527
self._log(
522528
"commited new index, %d signals %s took %s",
@@ -546,28 +552,35 @@ def load_signal_index(self) -> SignalTypeIndex[int]:
546552
# I'm sorry future debugger finding this comment.
547553
load_start_time = time.time()
548554
raw_conn = db.engine.raw_connection()
549-
l_obj = raw_conn.lobject(oid, "rb")
555+
try:
556+
l_obj = raw_conn.lobject(oid, "rb")
550557

551-
with tempfile.NamedTemporaryFile("rb") as tmpfile:
552-
self._log("importing lobject oid %d to tmpfile %s", l_obj.oid, tmpfile.name)
553-
l_obj.export(tmpfile.name)
554-
tmpfile.seek(0, io.SEEK_END)
555-
self._log(
556-
"downloading %s to tmpfile took %s",
557-
_human_friendly_bytesize(tmpfile.tell()),
558-
duration_to_human_str(time.time() - load_start_time),
559-
)
560-
tmpfile.seek(0)
558+
with tempfile.NamedTemporaryFile("rb") as tmpfile:
559+
self._log(
560+
"importing lobject oid %d to tmpfile %s", l_obj.oid, tmpfile.name
561+
)
562+
l_obj.export(tmpfile.name)
563+
tmpfile.seek(0, io.SEEK_END)
564+
self._log(
565+
"downloading %s to tmpfile took %s",
566+
_human_friendly_bytesize(tmpfile.tell()),
567+
duration_to_human_str(time.time() - load_start_time),
568+
)
569+
tmpfile.seek(0)
570+
571+
deserialize_start = time.time()
572+
index = t.cast(
573+
SignalTypeIndex[int],
574+
SignalTypeIndex.deserialize(t.cast(t.BinaryIO, tmpfile.file)),
575+
)
576+
self._log(
577+
"deserialize took %s",
578+
duration_to_human_str(time.time() - deserialize_start),
579+
)
580+
finally:
581+
# explicitly close the raw connection to free memory
582+
raw_conn.close()
561583

562-
deserialize_start = time.time()
563-
index = t.cast(
564-
SignalTypeIndex[int],
565-
SignalTypeIndex.deserialize(t.cast(t.BinaryIO, tmpfile.file)),
566-
)
567-
self._log(
568-
"deserialize took %s",
569-
duration_to_human_str(time.time() - deserialize_start),
570-
)
571584
self._log(
572585
"loading signal index took %s",
573586
duration_to_human_str(time.time() - load_start_time),

0 commit comments

Comments (0)