[Breaking] Remove Tensor backend generic and add high-level Device struct#4717
Draft
[Breaking] Remove Tensor backend generic and add high-level Device struct#4717
Tensor backend generic and add high-level Device struct#4717Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Pull Request Template
Checklist
cargo run-checkscommand has been executed.Related Issues/PRs
<B: Backend>generic (core motivation, by @nathanielsimard)Dispatch#4508 (introducesburn-dispatch/Dispatchbackend),Dispatchautodiff checkpointing strategy support #4629 (autodiff checkpointing as device property), [Breaking] Use device settings to provide output dtype #4653 (dtype from device settings), Refactor backend tests to set device settings at initialization + useDispatch#4666 (tests migrated toDispatch)Changes
Problem: Every piece of user code that uses tensors must carry a
B: Backendtype parameter, which propagates through every struct, function, and trait impl in a project:This creates three concrete problems:
<B: Backend>generics or lock users into one backend.TensorDatamanipulations (which is really meant to be a data representation, not a device-transfer mechanism).Autodiff<B>wrapper type, making the backend type even more complex.Solution: Two changes working together:
burn-dispatchcrate (landed in [Feat] Global backendDispatch#4508) provides a single concreteDispatchbackend that implements theBackendtrait via compile-time enum dispatch over all enabled backends. Backends are still behind feature flags, so users enable only what they need.DispatchDeviceandDispatchTensorare enums over per-backend device/tensor types, so the actual backend is selected at runtime from the enum variant while the type system sees onlyDispatch.Remove
BfromTensor: SinceDispatchis the one backend for user-facing code,Tensor<B, D, K>becomesTensor<D, K>. Autodiff is now a property of theDevicerather than a type parameter — call.autodiff()on any device to opt into gradient tracking.The
DispatchDeviceenum dispatches based on which Cargo feature flags are enabled (cuda,ndarray,vulkan, etc.). When only one backend feature is enabled the compiler optimizes thematchaway entirely; with multiple backends enabled the overhead is minimal enum dispatch rather than vtable dispatch.Key benefits this unlocks:
TensorDataround-trips.Testing
Backend tests were migrated to use
Dispatchin #4666 and validate correctness across all backends.