
[Breaking] Remove Tensor backend generic and add high-level Device struct#4717

Draft
laggui wants to merge 3 commits into main from refactor/backend/tensor

Conversation

laggui (Member) commented Apr 2, 2026

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Changes

Problem: Every piece of user code that uses tensors must carry a B: Backend type parameter, which propagates through every struct, function, and trait impl in a project:

```rust
// Before: B infects the entire call stack
pub struct Model<B: Backend> { layer: nn::Linear<B>, ... }
impl<B: Backend> Model<B> {
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> { ... }
}
fn run<B: AutodiffBackend>(device: B::Device) { ... }
```

This creates three concrete problems:

  1. Boilerplate: Every library/app must expose <B: Backend> generics or lock users into one backend.
  2. No runtime dispatch: Backend is compile-time only; can't fall back from GPU→CPU when hardware is unavailable, and switching devices requires going through TensorData manipulations (which is really meant to be a data representation, not a device-transfer mechanism).
  3. Autodiff coupling: Training requires a separate Autodiff<B> wrapper type, making the backend type even more complex.

Solution: Two changes working together:

  • burn-dispatch crate (landed in [Feat] Global backend Dispatch #4508) provides a single concrete Dispatch backend that implements the Backend trait via compile-time enum dispatch over all enabled backends. Backends are still behind feature flags, so users enable only what they need. DispatchDevice and DispatchTensor are enums over per-backend device/tensor types, so the actual backend is selected at runtime from the enum variant while the type system sees only Dispatch.

  • Remove B from Tensor: Since Dispatch is the one backend for user-facing code, Tensor<B, D, K> becomes Tensor<D, K>. Autodiff is now a property of the Device rather than a type parameter — call .autodiff() on any device to opt into gradient tracking.

```rust
// After: no backend generic anywhere in user code
let device = Device::default();            // auto-selects best available backend
let device = Device::default().autodiff(); // enables gradient tracking

let x = Tensor::<2>::zeros([3, 4], &device);
```
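As a rough mental model of autodiff becoming a device property rather than a type parameter, one might sketch something like the following (illustrative names and fields only, not the PR's actual Device implementation):

```rust
// Illustrative stand-in only: models autodiff as a boolean property of the
// device instead of an Autodiff<B> wrapper type. None of these fields are
// the actual burn API.
#[derive(Debug, Clone, PartialEq)]
pub struct Device {
    pub backend: &'static str,
    pub autodiff: bool,
}

impl Device {
    pub fn default_device() -> Self {
        // Stand-in for backend auto-selection.
        Device { backend: "cpu", autodiff: false }
    }

    // Opting into gradient tracking flips a flag on the device;
    // the tensor type no longer changes.
    pub fn autodiff(mut self) -> Self {
        self.autodiff = true;
        self
    }
}

fn main() {
    let device = Device::default_device().autodiff();
    println!("{} autodiff={}", device.backend, device.autodiff);
}
```

The point of the sketch is that training vs. inference is now a runtime choice on the device, so the same model code serves both.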

The DispatchDevice enum dispatches based on which Cargo feature flags are enabled (cuda, ndarray, vulkan, etc.). When only one backend feature is enabled, the compiler optimizes the match away entirely; with multiple backends enabled, the overhead is a cheap enum-tag match rather than vtable dispatch.
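This dispatch pattern can be illustrated with a self-contained sketch (variant names here are simplified stand-ins for the actual burn-dispatch types, and the real variants sit behind Cargo feature flags such as #[cfg(feature = "cuda")]):

```rust
// Simplified model of the DispatchDevice pattern: one enum variant per
// enabled backend. Variant names are illustrative, not the real types.
#[derive(Debug, Clone, PartialEq)]
pub enum DispatchDevice {
    NdArray,
    Cuda(usize), // device index
}

impl DispatchDevice {
    // Dispatch is a plain `match` on the enum tag. With only one backend
    // feature enabled the enum has a single variant and the branch is
    // optimized away; with several, it is a tag compare, not a vtable call.
    pub fn name(&self) -> String {
        match self {
            DispatchDevice::NdArray => "ndarray".to_string(),
            DispatchDevice::Cuda(i) => format!("cuda:{i}"),
        }
    }
}

fn main() {
    println!("{}", DispatchDevice::Cuda(0).name());
}
```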

Key benefits this unlocks:

  • Easy runtime switching between backend devices (e.g. CPU ↔ GPU) without TensorData round-trips.
  • Simpler development cycles — feature-gate the primitive to keep compile times fast while iterating.
  • A path toward making the primitive opaque, further improving compile times.
  • Docs and book will be updated separately to guide existing users through the migration.
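The first benefit, runtime fallback, could look like this in spirit (a hypothetical probe and hypothetical names; the real selection logic lives in burn-dispatch):

```rust
// Hypothetical sketch of GPU-to-CPU fallback at runtime. `gpu_available`
// stands in for a real device probe; neither it nor these variants are
// actual burn APIs.
#[derive(Debug, PartialEq)]
pub enum DeviceKind {
    Cpu,
    Gpu,
}

fn gpu_available() -> bool {
    // A real probe would query the driver; pretend none was found.
    false
}

// Because the backend is chosen from an enum variant at runtime, falling
// back is an ordinary conditional, with no TensorData round-trip.
pub fn select_device() -> DeviceKind {
    if gpu_available() {
        DeviceKind::Gpu
    } else {
        DeviceKind::Cpu
    }
}

fn main() {
    println!("{:?}", select_device());
}
```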

Testing

Backend tests were migrated to use Dispatch in #4666 and validate correctness across all backends.


Development

Successfully merging this pull request may close these issues.

Default backend implementation
