Commit 323a7b9

Sbozzolo authored and hughcars committed
Reduce memory usage by avoiding eager allocation of H
With the tools I added in PR #708, I can now look at optimizing how memory is used in Palace. One of the first places I found is the allocation of H. Currently, we allocate the full Hessenberg matrix up front, with (max_dim + 1) * max_dim entries, even when far fewer iterations are required. For cases where max_dim (= MaxIts) is large, this can be very significant. Here, I allocate H incrementally as the GMRES iterations proceed. In my test case, incremental allocation reduced the total memory from 600 GB to 200 GB.
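To make the size argument concrete, here is a rough sketch of the arithmetic, not code from Palace. It assumes H is a column-major buffer with leading dimension max_dim + 1 and that each entry is a double; the function names are hypothetical, and add_size = 10 is taken from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Bytes held by a column-major Hessenberg buffer with leading dimension max_dim + 1
// and `cols` allocated columns. scalar_bytes is an assumption (8 for double).
inline std::size_t HessenbergBytes(std::size_t max_dim, std::size_t cols,
                                   std::size_t scalar_bytes = sizeof(double))
{
  assert(cols <= max_dim);  // never more columns than the restart dimension
  return (max_dim + 1) * cols * scalar_bytes;
}

// Eager scheme: all max_dim columns are allocated up front.
inline std::size_t EagerBytes(std::size_t max_dim)
{
  return HessenbergBytes(max_dim, max_dim);
}

// Incremental scheme: bytes held after an Update(j)-style call, assuming growth
// in blocks of add_size columns and a cap at max_dim.
inline std::size_t IncrementalBytes(std::size_t max_dim, std::size_t j,
                                    std::size_t add_size = 10)
{
  return HessenbergBytes(max_dim, std::min(j + 1 + add_size, max_dim));
}
```

With these assumptions, the eager cost grows quadratically in max_dim, while the incremental cost grows linearly in the number of iterations actually taken and only reaches the eager cost if the solver uses the full restart dimension.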
1 parent dd86424 commit 323a7b9

1 file changed

Lines changed: 20 additions & 5 deletions

palace/linalg/iterative.cpp

@@ -507,22 +507,37 @@ void GmresSolver<OperType>::Initialize() const
     V[j].SetSize(A->Height());
     V[j].UseDevice(true);
   }
-  H.resize((max_dim + 1) * max_dim);
-  s.resize(max_dim + 1);
-  cs.resize(max_dim + 1);
-  sn.resize(max_dim + 1);
+  // H, s, cs, sn are allocated incrementally in Update() to avoid O(max_dim^2) upfront
+  // cost when max_dim is large (e.g., defaulting to max_it).
+  s.resize(std::min(init_size + 1, max_dim + 1));
+  cs.resize(std::min(init_size + 1, max_dim + 1));
+  sn.resize(std::min(init_size + 1, max_dim + 1));
+  H.resize(static_cast<std::size_t>(max_dim + 1) * std::min(init_size, max_dim));
 }
 
 template <typename OperType>
 void GmresSolver<OperType>::Update(int j) const
 {
-  // Add storage for basis vectors in increments.
+  // Add storage for basis vectors, Hessenberg columns, and rotations in increments.
   constexpr int add_size = 10;
   for (int k = j + 1; k < std::min(j + 1 + add_size, max_dim + 1); k++)
   {
     V[k].SetSize(A->Height());
     V[k].UseDevice(true);
   }
+  int needed_cols = std::min(j + 1 + add_size, max_dim);
+  int current_cols = static_cast<int>(H.size()) / (max_dim + 1);
+  if (needed_cols > current_cols)
+  {
+    H.resize(static_cast<std::size_t>(max_dim + 1) * needed_cols);
+  }
+  auto needed_size = static_cast<std::size_t>(std::min(j + 2 + add_size, max_dim + 1));
+  if (needed_size > s.size())
+  {
+    s.resize(needed_size);
+    cs.resize(needed_size);
+    sn.resize(needed_size);
+  }
 }
 
 template <typename OperType>
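
The growth logic in Update() can be illustrated in isolation. The sketch below is a hypothetical stand-in, not the Palace class: GrowableHessenberg and its members are invented for illustration, and the column-major layout with leading dimension max_dim + 1 is an assumption inferred from the `H.size() / (max_dim + 1)` expression in the diff. It shows why a plain resize is enough under that assumption: with a fixed leading dimension, appending columns leaves every existing (i, j) entry at the same offset.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the solver's H buffer: column-major, with a fixed leading
// dimension of max_dim + 1 (an assumption inferred from the diff).
class GrowableHessenberg
{
public:
  GrowableHessenberg(int max_dim, int init_cols)
    : max_dim_(max_dim),
      data_(static_cast<std::size_t>(max_dim + 1) * std::min(init_cols, max_dim))
  {
  }

  // Number of columns currently allocated.
  int Cols() const { return static_cast<int>(data_.size() / (max_dim_ + 1)); }

  // Make sure columns [0, j + 1 + add_size) exist, capped at max_dim. Never shrinks.
  void Reserve(int j, int add_size = 10)
  {
    const int needed_cols = std::min(j + 1 + add_size, max_dim_);
    if (needed_cols > Cols())
    {
      // std::vector::resize keeps existing elements and value-initializes new ones.
      data_.resize(static_cast<std::size_t>(max_dim_ + 1) * needed_cols);
    }
  }

  // Entry (i, j); the offset depends only on max_dim, not on how many columns exist.
  double &operator()(int i, int j)
  {
    assert(i >= 0 && i <= max_dim_ && j >= 0 && j < Cols());
    return data_[static_cast<std::size_t>(j) * (max_dim_ + 1) + i];
  }

private:
  int max_dim_;
  std::vector<double> data_;
};
```

A row-major layout would not have this property, since adding columns would change the row stride and require copying entries to new offsets. Note that resize may still reallocate and copy the buffer, so growing in blocks (add_size columns at a time) keeps the number of reallocations small.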
