Fix: make module cloning efficient for CPU devices#4703

Merged
antimora merged 6 commits into tracel-ai:main from antimora:fix/efficient-module-clone-cpu
Apr 14, 2026

Conversation

@antimora
Collaborator

Summary

  • Fix Param::clone() to preserve lazy (uninitialized) state instead of eagerly triggering initialization, which allocated tensor memory unnecessarily on CPU backends
  • Change Uninitialized.init from Box<dyn FnOnce> to Arc<dyn Fn> so the init function can be cloned without consuming it
  • Add Clone impl for Uninitialized (just Arc refcount bumps, no memory allocation)
  • Update init_mapper() to use Arc::clone instead of the mem::swap + panic-placeholder hack

Benchmark results (312 MB model, NdArray CPU backend) confirm max allocation is ~344 MB with no memory doubling.
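A minimal sketch of the Box-to-Arc change described above. The types here (`Device`, `Tensor`, the field layout of `Uninitialized`) are simplified stand-ins, not burn's real definitions; the point is that cloning an `Arc<dyn Fn>` bumps a refcount instead of running the initializer:

```rust
use std::sync::Arc;

// Hypothetical stand-ins; the real burn types differ.
#[derive(Clone, Default)]
struct Device;
type Tensor = Vec<f32>;

// Before: `init: Box<dyn FnOnce(..)>` could not be cloned without consuming it.
// After: `Arc<dyn Fn(..)>` clones by bumping a refcount; no tensor is allocated.
struct Uninitialized {
    init: Arc<dyn Fn(&Device, bool) -> Tensor + Send + Sync>,
    device: Device,
    is_require_grad: bool,
}

impl Clone for Uninitialized {
    fn clone(&self) -> Self {
        Self {
            init: Arc::clone(&self.init), // refcount bump only
            device: self.device.clone(),
            is_require_grad: self.is_require_grad,
        }
    }
}

fn main() {
    let lazy = Uninitialized {
        init: Arc::new(|_: &Device, _: bool| vec![0.0; 1024]),
        device: Device,
        is_require_grad: true,
    };
    let copy = lazy.clone(); // does not call `init`, allocates no tensor data
    assert_eq!(Arc::strong_count(&lazy.init), 2);

    // Each clone can still initialize independently on first access.
    let tensor = (copy.init.as_ref())(&copy.device, copy.is_require_grad);
    assert_eq!(tensor.len(), 1024);
}
```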

Test plan

  • Existing burn-core, burn-nn, burn-store tests pass (3 existing tests updated for new clone semantics)
  • 7 new lazy_clone tests covering:
    • Cloning uninitialized param does not trigger init
    • Cloning initialized param produces equal values
    • Lazy clone produces valid tensor with correct shape on access
    • Lazy clones produce independent values (assert_ne)
    • Deref on lazy clone triggers init independently
    • init_mapper on lazy clone does not affect original
    • Loading weights into a fresh module works correctly
  • cargo bench --bench unified_loading confirms no memory doubling

Closes #3754

Cloning an uninitialized Param previously triggered eager initialization
via val(), allocating tensor memory unnecessarily. This was especially
costly on CPU backends (NdArray) where tensor cloning involves real
memory allocation rather than lightweight GPU handle copies.

Changes:
- Change Uninitialized.init from Box<dyn FnOnce> to Arc<dyn Fn> so the
  init function can be cloned without consuming it
- Implement Clone for Uninitialized (just Arc refcount bumps)
- Param::clone() now preserves lazy state for uninitialized params
  instead of forcing initialization
- Update Param::uninitialized() bound from FnOnce to Fn + Sync
- Update init_mapper() to use Arc::clone instead of mem::swap hack
- Update Initializer::init_with closure to clone captures for Fn compat
- Fix tests that relied on clone triggering initialization

Closes tracel-ai#3754

Verify that:
- Cloning an uninitialized param does not trigger initialization
- Cloning an initialized param shares the same tensor values
- A lazy clone produces a valid tensor with the correct shape on access
- Loading weights into a cloned module does not initialize the original

Comments:
- Clarify that clones produce independent values, not shared ones
- Document initialize() single-call contract via take()
- Note that clone gets its own OnceCell/RwLock
- Fix "Already initialized" to also cover consumed-init case
- Use "have the same values" instead of "share" in test comments
- Clarify forward pass triggers init via tensor access

Tests:
- Add test: lazy clones produce independent values (assert_ne)
- Add test: Deref on lazy clone triggers init independently
- Add test: init_mapper on lazy clone does not affect original
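The clone semantics the commits describe can be sketched with a heavily reduced `Param`. Everything below (`Param`, `InitFn`, `Tensor`, the `OnceLock` state field) is a hypothetical reduction for illustration, not burn's actual implementation:

```rust
use std::sync::{Arc, OnceLock};

#[derive(Clone, Default)]
struct Device;
type Tensor = Vec<f32>;
type InitFn = Arc<dyn Fn(&Device) -> Tensor + Send + Sync>;

// Reduced Param: a lazily-filled slot plus a cloneable init function.
struct Param {
    state: OnceLock<Tensor>,
    init: Option<InitFn>,
    device: Device,
}

impl Param {
    fn lazy(init: InitFn, device: Device) -> Self {
        Self { state: OnceLock::new(), init: Some(init), device }
    }
    fn is_initialized(&self) -> bool {
        self.state.get().is_some()
    }
    fn val(&self) -> &Tensor {
        // First access runs the init fn; later accesses reuse the tensor.
        self.state
            .get_or_init(|| (self.init.as_ref().unwrap().as_ref())(&self.device))
    }
}

impl Clone for Param {
    fn clone(&self) -> Self {
        match self.state.get() {
            // Initialized: the clone copies the tensor values.
            Some(t) => {
                let state = OnceLock::new();
                let _ = state.set(t.clone());
                Self { state, init: self.init.clone(), device: self.device.clone() }
            }
            // Uninitialized: clone only the Arc'd init fn. The clone gets its
            // own empty slot and initializes independently on first access.
            None => Self {
                state: OnceLock::new(),
                init: self.init.clone(),
                device: self.device.clone(),
            },
        }
    }
}

fn main() {
    let p = Param::lazy(Arc::new(|_: &Device| vec![1.0; 4]), Device);
    let q = p.clone();
    assert!(!p.is_initialized() && !q.is_initialized()); // clone stayed lazy
    assert_eq!(q.val().len(), 4); // access on the clone triggers its own init
    assert!(!p.is_initialized()); // the original is still lazy
}
```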
Contributor

Copilot AI left a comment


Pull request overview

Improves module cloning on CPU devices by preserving Param lazy (uninitialized) state during clone, avoiding eager initialization and unnecessary tensor allocations. This aligns with Burn’s lazy-init parameter system and addresses the reported CPU memory-doubling behavior when cloning modules before loading weights.

Changes:

  • Make Param::clone() preserve lazy state by cloning the uninitialized initializer (Uninitialized) rather than calling val().
  • Change lazy initializer storage from FnOnce to clonable Arc<dyn Fn> and simplify init_mapper() accordingly.
  • Update/extend tests to reflect the new clone semantics and cover lazy-clone behaviors.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File / Description
  • crates/burn-nn/src/activation/activation_wrapper.rs: Updates SwiGlu activation test to trigger lazy init before cloning/loading weights.
  • crates/burn-core/tests/test_derive_module.rs: Updates existing record-load tests for new clone semantics and adds a lazy_clone test module.
  • crates/burn-core/src/module/param/base.rs: Core implementation: initializer made clonable via Arc<dyn Fn>, init_mapper refactor, and lazy-preserving Clone for Param.
  • crates/burn-core/src/module/initializer.rs: Adjusts initializer closure captures to satisfy the new Fn + Send + Sync requirements.


Comment on lines 1 to 5
  use super::ParamId;
- use alloc::{boxed::Box, format};
+ use alloc::format;
  use burn_std::stub::RwLock;
  use burn_tensor::Shape;
  use core::cell::OnceCell;

Copilot AI Mar 31, 2026


The removal of alloc::boxed::Box from the imports breaks compilation for the #[cfg(not(target_has_atomic = "ptr"))] branch below, where Mapper<T> is defined as Arc<Box<dyn Fn...>>. Consider reintroducing the Box import (possibly cfg-gated) or qualifying it as alloc::boxed::Box in the type alias.

Collaborator Author


Already addressed in commit a6a1104 - added cfg-gated use alloc::boxed::Box import along with InitFn<P> / new_init_fn() type alias and constructor that mirror the existing Mapper / new_mapper pattern.

Comment on lines 201 to 211
  /// Create a new parameter that is not already initialized.
  pub fn uninitialized<F>(
      id: ParamId,
      init: F,
      device: T::Device,
      is_require_grad: bool,
      shape: Shape,
  ) -> Self
  where
-     F: FnOnce(&T::Device, bool) -> T + Send + 'static,
+     F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
  {

Copilot AI Mar 31, 2026


Param::uninitialized changed from accepting an FnOnce to requiring Fn + Send + Sync. Since this is a public API, the new bounds are a breaking change for downstream users (e.g., init closures that capture non-Sync state or need FnOnce). If backward compatibility is important, consider keeping the old signature (e.g., via an overload/new method) or providing a compatibility wrapper internally (such as storing an Arc<Mutex<Option<F>>> to support FnOnce while still enabling cloning).

Collaborator Author


This is an intentional change. Param::uninitialized is primarily used internally (by Initializer::init_with and test helpers). The Fn + Sync bound is inherent to the design: for the init function to be cloneable via Arc, it must be callable multiple times. All existing callers already use Fn-compatible closures (they only read their captures). The version is pre-release (0.21.0-pre.2) so this is acceptable as a breaking change.
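The "clone captures for Fn compat" adjustment mentioned in the commit list can be illustrated as follows. This is a hedged sketch with a placeholder `Device` type, not burn's `Initializer` code: a closure that moves its capture out is only `FnOnce`, while cloning the capture on each call keeps it `Fn`, so it can live behind `Arc<dyn Fn>` and be invoked by every clone of the param:

```rust
use std::sync::Arc;

// Hypothetical device stand-in.
#[derive(Clone, Default)]
struct Device;

fn main() {
    let weights = vec![0.5f32; 3]; // state captured by the init closure

    // `move |_| weights` would be FnOnce: calling it moves `weights` out.
    // Cloning the capture on each call keeps the closure Fn, callable
    // repeatedly from any clone of the Arc'd initializer.
    let init: Arc<dyn Fn(&Device) -> Vec<f32> + Send + Sync> =
        Arc::new(move |_d: &Device| weights.clone());

    // Callable more than once, each call producing the same values.
    assert_eq!((init.as_ref())(&Device), (init.as_ref())(&Device));
}
```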

Comment on lines 311 to 315
  /// Execute the given function on the inner value.
- pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
+ pub fn init_mapper<F: Fn(T) -> T + Send + Sync + 'static>(self, func: F) -> Self
  where
      T: 'static,
  {

Copilot AI Mar 31, 2026


init_mapper now requires F: Fn(T) -> T + Send + Sync, which is also a tightening vs the previous FnOnce + Send and can be a breaking change even when the parameter is already initialized (where thread-safety isn’t necessarily required). If possible, consider splitting into two methods (one for lazy init mapping requiring Fn + Sync, and one for eager mapping that can accept FnOnce) to keep the more flexible API for the initialized path.

Collaborator Author


Same reasoning as the uninitialized change above. init_mapper has a single callsite (linear.rs) which already uses a Fn-compatible closure. Splitting into two methods adds API surface complexity for no practical benefit since all init closures are inherently Fn (they just transform tensors without consuming unique state). The map() fallback path (for already-initialized params) still accepts FnOnce since it consumes self.
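The lazy path of an init_mapper-style method can be sketched as composition over the cloneable initializer. All names here (`InitFn`, `compose_mapper`, `Device`, `Tensor`) are illustrative assumptions, not burn's API; the sketch shows why the mapper must be `Fn + Send + Sync` rather than `FnOnce`:

```rust
use std::sync::Arc;

// Hypothetical stand-ins; names are illustrative, not burn's API.
#[derive(Clone, Default)]
struct Device;
type Tensor = Vec<f32>;
type InitFn = Arc<dyn Fn(&Device) -> Tensor + Send + Sync>;

// Compose `func` over the cloned init fn so nothing is allocated until first
// access. Because the composed closure may be invoked from any clone of the
// param, `func` must itself be Fn + Send + Sync.
fn compose_mapper<F>(init: InitFn, func: F) -> InitFn
where
    F: Fn(Tensor) -> Tensor + Send + Sync + 'static,
{
    Arc::new(move |device: &Device| func((init.as_ref())(device)))
}

fn main() {
    let init: InitFn = Arc::new(|_: &Device| vec![1.0, 2.0, 3.0]);
    let mapped = compose_mapper(Arc::clone(&init), |t| {
        t.into_iter().map(|x| x * 2.0).collect::<Tensor>()
    });
    assert_eq!((mapped.as_ref())(&Device), vec![2.0, 4.0, 6.0]);
    // The original initializer is untouched.
    assert_eq!((init.as_ref())(&Device), vec![1.0, 2.0, 3.0]);
}
```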

Comment on lines +471 to +491
#[test]
fn lazy_clone_then_load_should_not_init_original() {
    let device = <TestBackend as Backend>::Device::default();
    let module_1 = ModuleBasic::<TestBackend>::new(&device);

    // Initialize module_1 so we have a record to load.
    let _ = module_1.weight_basic.to_data();
    let record = module_1.clone().into_record();

    // Create a fresh uninitialized module and load weights into it.
    let module_2 = ModuleBasic::<TestBackend>::new(&device);
    assert!(!module_2.weight_basic.is_initialized());

    let module_2 = module_2.load_record(record);

    // After loading, the param should be initialized with the loaded values.
    assert_eq!(
        module_1.weight_basic.to_data(),
        module_2.weight_basic.to_data()
    );
}

Copilot AI Mar 31, 2026


The test name lazy_clone_then_load_should_not_init_original doesn’t match what the test currently asserts (it initializes module_1, creates a fresh module_2, loads the record, then checks equality). Consider renaming it to reflect the behavior under test (e.g., that loading into a fresh module initializes it correctly), or adjust the assertions to actually verify that cloning + loading does not initialize the original lazy module.

Collaborator Author


Good catch. Renamed the test to load_record_into_uninitialized_module_should_work which accurately describes what it verifies. Fixed in 4fff91a.

…rection

On targets without atomics, portable_atomic_util::Arc cannot directly
hold unsized types like dyn Fn(...). Add InitFn<P> type alias and
new_init_fn() constructor mirroring the existing Mapper/new_mapper
pattern, with Box indirection on no_std targets.
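The cfg-gated alias pattern the commit describes might look like the following. The `InitFn` / `new_init_fn` names come from the commit message; `Device` and the exact signature are placeholder assumptions:

```rust
use std::sync::Arc;

// Hypothetical device stand-in; the real alias takes burn's device/tensor types.
#[derive(Clone, Default)]
struct Device;

// On targets with atomic pointers, Arc holds the unsized `dyn Fn` directly.
#[cfg(target_has_atomic = "ptr")]
type InitFn<P> = Arc<dyn Fn(&Device, bool) -> P + Send + Sync>;

// On targets without atomics, portable_atomic_util::Arc cannot hold unsized
// types, so a Box indirection is added, as the commit describes.
#[cfg(not(target_has_atomic = "ptr"))]
type InitFn<P> = portable_atomic_util::Arc<Box<dyn Fn(&Device, bool) -> P + Send + Sync>>;

// Constructor hides the cfg difference from callers.
fn new_init_fn<P, F>(f: F) -> InitFn<P>
where
    F: Fn(&Device, bool) -> P + Send + Sync + 'static,
{
    #[cfg(target_has_atomic = "ptr")]
    {
        Arc::new(f)
    }
    #[cfg(not(target_has_atomic = "ptr"))]
    {
        portable_atomic_util::Arc::new(Box::new(f))
    }
}

fn main() {
    let init: InitFn<Vec<f32>> = new_init_fn(|_d: &Device, _grad: bool| vec![0.0; 4]);
    let clone = init.clone(); // refcount bump, no tensor data allocated
    assert_eq!((clone.as_ref())(&Device, false).len(), 4);
}
```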
@codecov

codecov bot commented Mar 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 63.26%. Comparing base (e779b63) to head (4f307d2).
⚠️ Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4703      +/-   ##
==========================================
+ Coverage   63.24%   63.26%   +0.02%     
==========================================
  Files        1108     1108              
  Lines      146074   146167      +93     
==========================================
+ Hits        92381    92474      +93     
  Misses      53693    53693              

☔ View full report in Codecov by Sentry.


@antimora antimora self-assigned this Apr 1, 2026
Member

@laggui laggui left a comment


Makes sense!

This enforces the lazy initialization contract even more.

@antimora antimora merged commit 8cc356e into tracel-ai:main Apr 14, 2026
11 checks passed


Development

Successfully merging this pull request may close these issues.

Module cloning is inefficient for CPU devices

3 participants