Fix: make module cloning efficient for CPU devices #4703
antimora merged 6 commits into tracel-ai:main
Conversation
Cloning an uninitialized Param previously triggered eager initialization via val(), allocating tensor memory unnecessarily. This was especially costly on CPU backends (NdArray), where tensor cloning involves real memory allocation rather than lightweight GPU handle copies.

Changes:
- Change Uninitialized.init from Box&lt;dyn FnOnce&gt; to Arc&lt;dyn Fn&gt; so the init function can be cloned without consuming it
- Implement Clone for Uninitialized (just Arc refcount bumps)
- Param::clone() now preserves lazy state for uninitialized params instead of forcing initialization
- Update the Param::uninitialized() bound from FnOnce to Fn + Sync
- Update init_mapper() to use Arc::clone instead of the mem::swap hack
- Update the Initializer::init_with closure to clone captures for Fn compatibility
- Fix tests that relied on clone triggering initialization

Closes tracel-ai#3754
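The storage change described above can be sketched in a few lines. This is a hypothetical, simplified model (a `String` stands in for `T::Device`, a plain generic `T` stands in for the tensor type, and `Uninitialized::new` / `initialize` are illustrative names), not the actual Burn code:

```rust
use std::sync::Arc;

// Init function stored behind Arc<dyn Fn> so cloning the uninitialized state
// only bumps a refcount instead of consuming a Box<dyn FnOnce>.
type InitFn<T> = Arc<dyn Fn(&str, bool) -> T + Send + Sync>;

struct Uninitialized<T> {
    init: InitFn<T>,
    device: String, // stand-in for T::Device
    is_require_grad: bool,
}

impl<T> Clone for Uninitialized<T> {
    fn clone(&self) -> Self {
        Self {
            init: Arc::clone(&self.init), // refcount bump only, no tensor allocation
            device: self.device.clone(),
            is_require_grad: self.is_require_grad,
        }
    }
}

impl<T> Uninitialized<T> {
    fn new<F>(init: F, device: String, is_require_grad: bool) -> Self
    where
        F: Fn(&str, bool) -> T + Send + Sync + 'static,
    {
        Self { init: Arc::new(init), device, is_require_grad }
    }

    // Callable any number of times, unlike the old FnOnce-based design.
    fn initialize(&self) -> T {
        (self.init)(&self.device, self.is_require_grad)
    }
}
```

Because the `Fn` is shared rather than consumed, both the original and the clone can still initialize later, which is exactly what lets `Param::clone()` stay lazy.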
Verify that:
- Cloning an uninitialized param does not trigger initialization
- Cloning an initialized param shares the same tensor values
- A lazy clone produces a valid tensor with the correct shape on access
- Loading weights into a cloned module does not initialize the original
Comments:
- Clarify that clones produce independent values, not shared ones
- Document the initialize() single-call contract via take()
- Note that a clone gets its own OnceCell/RwLock
- Fix "Already initialized" to also cover the consumed-init case
- Use "have the same values" instead of "share" in test comments
- Clarify that the forward pass triggers init via tensor access

Tests:
- Add test: lazy clones produce independent values (assert_ne)
- Add test: Deref on a lazy clone triggers init independently
- Add test: init_mapper on a lazy clone does not affect the original
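The "clone gets its own OnceCell" point above can be illustrated with a small mock. This is a hypothetical stand-in (a `u32` instead of a tensor, `LazyParam` is an invented name), not Burn's `Param` itself:

```rust
use std::cell::OnceCell;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};

// Each clone owns a fresh OnceCell, so initializing one clone does not
// initialize the other, and two lazy clones can yield independent values.
struct LazyParam {
    cell: OnceCell<u32>,
    init: Arc<dyn Fn() -> u32 + Send + Sync>,
}

impl Clone for LazyParam {
    fn clone(&self) -> Self {
        Self {
            cell: OnceCell::new(),          // fresh, uninitialized cell
            init: Arc::clone(&self.init),   // shared init function
        }
    }
}

impl LazyParam {
    fn get(&self) -> u32 {
        // First access runs the init function; later accesses reuse the value.
        *self.cell.get_or_init(|| (self.init)())
    }
    fn is_initialized(&self) -> bool {
        self.cell.get().is_some()
    }
}
```

With an init function that returns a different value on each call (e.g. a random initializer), the two clones end up with independent values, which is why the review asks for assert_ne rather than equality.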
Pull request overview
Improves module cloning on CPU devices by preserving Param lazy (uninitialized) state during clone, avoiding eager initialization and unnecessary tensor allocations. This aligns with Burn’s lazy-init parameter system and addresses the reported CPU memory-doubling behavior when cloning modules before loading weights.
Changes:
- Make Param::clone() preserve lazy state by cloning the uninitialized initializer (Uninitialized) rather than calling val().
- Change lazy initializer storage from FnOnce to a clonable Arc&lt;dyn Fn&gt; and simplify init_mapper() accordingly.
- Update/extend tests to reflect the new clone semantics and cover lazy-clone behaviors.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| crates/burn-nn/src/activation/activation_wrapper.rs | Updates SwiGlu activation test to trigger lazy init before cloning/loading weights. |
| crates/burn-core/tests/test_derive_module.rs | Updates existing record-load tests for new clone semantics and adds a lazy_clone test module. |
| crates/burn-core/src/module/param/base.rs | Core implementation: initializer made clonable via Arc<dyn Fn>, init_mapper refactor, and lazy-preserving Clone for Param. |
| crates/burn-core/src/module/initializer.rs | Adjusts initializer closure captures to satisfy the new Fn + Send + Sync requirements. |
```diff
  use super::ParamId;
- use alloc::{boxed::Box, format};
+ use alloc::format;
  use burn_std::stub::RwLock;
  use burn_tensor::Shape;
  use core::cell::OnceCell;
```
The removal of alloc::boxed::Box from the imports breaks compilation for the #[cfg(not(target_has_atomic = "ptr"))] branch below, where Mapper<T> is defined as Arc<Box<dyn Fn...>>. Consider reintroducing the Box import (possibly cfg-gated) or qualifying it as alloc::boxed::Box in the type alias.
Already addressed in commit a6a1104 - added cfg-gated use alloc::boxed::Box import along with InitFn<P> / new_init_fn() type alias and constructor that mirror the existing Mapper / new_mapper pattern.
```diff
  /// Create a new parameter that is not already initialized.
  pub fn uninitialized<F>(
      id: ParamId,
      init: F,
      device: T::Device,
      is_require_grad: bool,
      shape: Shape,
  ) -> Self
  where
-     F: FnOnce(&T::Device, bool) -> T + Send + 'static,
+     F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
  {
```
Param::uninitialized changed from accepting an FnOnce to requiring Fn + Send + Sync. Since this is a public API, the new bounds are a breaking change for downstream users (e.g., init closures that capture non-Sync state or need FnOnce). If backward compatibility is important, consider keeping the old signature (e.g., via an overload/new method) or providing a compatibility wrapper internally (such as storing an Arc<Mutex<Option<F>>> to support FnOnce while still enabling cloning).
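The compatibility wrapper the reviewer suggests can be sketched as follows. This is a hypothetical illustration of the Arc&lt;Mutex&lt;Option&lt;F&gt;&gt;&gt; idea, not code from the PR (`fn_once_adapter` is an invented name):

```rust
use std::sync::{Arc, Mutex};

// Wrap an FnOnce so the result is Fn + Clone: the closure is stored in an
// Option behind a shared Mutex and take()n out on the first call. Calling
// any clone a second time panics, since the FnOnce has been consumed.
fn fn_once_adapter<T, F>(f: F) -> impl Fn() -> T + Clone
where
    F: FnOnce() -> T,
{
    let slot = Arc::new(Mutex::new(Some(f)));
    move || {
        let f = slot
            .lock()
            .unwrap()
            .take()
            .expect("init function already consumed");
        f()
    }
}
```

This keeps an FnOnce-accepting API cloneable, at the cost of moving the "called at most once" contract from the type system to a runtime panic, which is part of why the PR opts for a plain `Fn + Sync` bound instead.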
This is an intentional change. Param::uninitialized is primarily used internally (by Initializer::init_with and test helpers). The Fn + Sync bound is inherent to the design: for the init function to be cloneable via Arc, it must be callable multiple times. All existing callers already use Fn-compatible closures (they only read their captures). The version is pre-release (0.21.0-pre.2) so this is acceptable as a breaking change.
```diff
  /// Execute the given function on the inner value.
- pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
+ pub fn init_mapper<F: Fn(T) -> T + Send + Sync + 'static>(self, func: F) -> Self
  where
      T: 'static,
  {
```
init_mapper now requires F: Fn(T) -> T + Send + Sync, which is also a tightening vs the previous FnOnce + Send and can be a breaking change even when the parameter is already initialized (where thread-safety isn’t necessarily required). If possible, consider splitting into two methods (one for lazy init mapping requiring Fn + Sync, and one for eager mapping that can accept FnOnce) to keep the more flexible API for the initialized path.
Same reasoning as the uninitialized change above. init_mapper has a single callsite (linear.rs) which already uses a Fn-compatible closure. Splitting into two methods adds API surface complexity for no practical benefit since all init closures are inherently Fn (they just transform tensors without consuming unique state). The map() fallback path (for already-initialized params) still accepts FnOnce since it consumes self.
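The Arc-based composition that replaces the old mem::swap trick can be sketched like this. A hypothetical, simplified shape (an `i32`-producing init function with no arguments stands in for the real tensor initializer; `with_mapper` is an invented name):

```rust
use std::sync::Arc;

// Cloneable init function, as in the PR's Arc<dyn Fn> storage.
type InitFn = Arc<dyn Fn() -> i32 + Send + Sync>;

// Compose a mapper over the init function. Arc::clone lets the new closure
// call the original initializer without consuming it, so no placeholder has
// to be swapped in via mem::swap and no panic path is needed.
fn with_mapper(init: &InitFn, func: impl Fn(i32) -> i32 + Send + Sync + 'static) -> InitFn {
    let inner = Arc::clone(init);
    Arc::new(move || func(inner()))
}
```

The old FnOnce design had to `mem::swap` the stored initializer out (leaving a panicking placeholder behind) before it could be wrapped; with `Fn` behind an `Arc`, the original stays intact and the wrapped version is just another handle.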
```rust
#[test]
fn lazy_clone_then_load_should_not_init_original() {
    let device = <TestBackend as Backend>::Device::default();
    let module_1 = ModuleBasic::<TestBackend>::new(&device);

    // Initialize module_1 so we have a record to load.
    let _ = module_1.weight_basic.to_data();
    let record = module_1.clone().into_record();

    // Create a fresh uninitialized module and load weights into it.
    let module_2 = ModuleBasic::<TestBackend>::new(&device);
    assert!(!module_2.weight_basic.is_initialized());

    let module_2 = module_2.load_record(record);

    // After loading, the param should be initialized with the loaded values.
    assert_eq!(
        module_1.weight_basic.to_data(),
        module_2.weight_basic.to_data()
    );
}
```
The test name lazy_clone_then_load_should_not_init_original doesn’t match what the test currently asserts (it initializes module_1, creates a fresh module_2, loads the record, then checks equality). Consider renaming it to reflect the behavior under test (e.g., that loading into a fresh module initializes it correctly), or adjust the assertions to actually verify that cloning + loading does not initialize the original lazy module.
Good catch. Renamed the test to load_record_into_uninitialized_module_should_work which accurately describes what it verifies. Fixed in 4fff91a.
…rection

On targets without atomics, portable_atomic_util::Arc cannot directly hold unsized types like dyn Fn(...). Add an InitFn&lt;P&gt; type alias and a new_init_fn() constructor mirroring the existing Mapper / new_mapper pattern, with Box indirection on no_std targets.
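The cfg-gated alias pattern this commit describes can be sketched as below. This is a hedged approximation: `std::sync::Arc` stands in for `portable_atomic_util::Arc`, a `&str` stands in for the device type, and the signatures are simplified relative to the real `InitFn<P>` / `new_init_fn` in the PR:

```rust
use std::boxed::Box;
use std::sync::Arc; // real code uses portable_atomic_util::Arc on no-atomics targets

// On targets with atomic pointers, Arc can hold the unsized dyn Fn directly.
#[cfg(target_has_atomic = "ptr")]
type InitFn<P> = Arc<dyn Fn(&str, bool) -> P + Send + Sync>;

// Without atomics, the Arc substitute cannot hold an unsized type,
// so a Box adds the needed indirection.
#[cfg(not(target_has_atomic = "ptr"))]
type InitFn<P> = Arc<Box<dyn Fn(&str, bool) -> P + Send + Sync>>;

#[cfg(target_has_atomic = "ptr")]
fn new_init_fn<P, F>(f: F) -> InitFn<P>
where
    F: Fn(&str, bool) -> P + Send + Sync + 'static,
{
    Arc::new(f)
}

#[cfg(not(target_has_atomic = "ptr"))]
fn new_init_fn<P, F>(f: F) -> InitFn<P>
where
    F: Fn(&str, bool) -> P + Send + Sync + 'static,
{
    Arc::new(Box::new(f) as Box<dyn Fn(&str, bool) -> P + Send + Sync>)
}
```

Callers only ever see `InitFn<P>` and `new_init_fn`, so the extra Box on no_std targets is invisible at the call sites, mirroring how the existing Mapper / new_mapper pair hides the same distinction.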
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #4703    +/-  ##
=======================================
+ Coverage   63.24%   63.26%   +0.02%
=======================================
  Files        1108     1108
  Lines      146074   146167      +93
=======================================
+ Hits        92381    92474      +93
  Misses      53693    53693
```

View full report in Codecov by Sentry.
laggui left a comment:
Makes sense!
This enforces the lazy initialization contract even more.
Summary
- Change Param::clone() to preserve lazy (uninitialized) state instead of eagerly triggering initialization, which allocated tensor memory unnecessarily on CPU backends
- Change Uninitialized.init from Box&lt;dyn FnOnce&gt; to Arc&lt;dyn Fn&gt; so the init function can be cloned without consuming it
- Add a Clone impl for Uninitialized (just Arc refcount bumps, no memory allocation)
- Refactor init_mapper() to use Arc::clone instead of the mem::swap + panic-placeholder hack

Benchmark results (312 MB model, NdArray CPU backend) confirm max allocation is ~344 MB with no memory doubling.
Test plan
burn-core,burn-nn,burn-storetests pass (3 existing tests updated for new clone semantics)lazy_clonetests covering:cargo bench --bench unified_loadingconfirms no memory doublingCloses #3754