Skip to content

Commit 0eb82b5

Browse files
authored
Feat/implement fusion for irfft (#4736)
* ignore my tags file in .gitignore * add fusion support for rfft * remove warnings for missing documentation * update router runner, add rfft call * fix typos in doc for rfft * implement fusion for irfft * update router runner, add irfft call * add fusion support for rfft * remove warnings for missing documentation * implement fusion for irfft * update router runner, add irfft call
1 parent ab00b38 commit 0eb82b5

4 files changed

Lines changed: 81 additions & 4 deletions

File tree

crates/burn-fusion/src/ops/module.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,10 +1593,36 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
15931593
}
15941594

15951595
fn irfft(
    spectrum_re: FloatTensor<Fusion<B>>,
    spectrum_im: FloatTensor<Fusion<B>>,
    dim: usize,
) -> FloatTensor<Fusion<B>> {
    // Deferred execution body: when the fused stream actually runs, pull the
    // real/imaginary spectrum handles, invoke the backend irfft, and register
    // the reconstructed signal under the pre-allocated output id.
    make_ops!(IRfftOps, IRfftOpIr, |desc: &IRfftOpIr,
                                    handles: &mut HandleContainer<
        B::Handle,
    >| {
        let input_re = handles.get_float_tensor::<B>(&desc.input_re);
        let input_im = handles.get_float_tensor::<B>(&desc.input_im);

        let signal = B::irfft(input_re, input_im, desc.dim);
        handles.register_float_tensor::<B>(&desc.out_signal.id, signal);
    });

    // Both spectrum tensors are inputs to the op, so both streams must be
    // tracked for ordering; the client is taken from the real part.
    let streams = OperationStreams::with_inputs([&spectrum_re, &spectrum_im]);
    let client = spectrum_re.client.clone();

    // Build the IR node; `create` derives the output shape from the spectrum
    // shape and allocates a fresh (empty) handle for the output tensor.
    let desc = IRfftOpIr::create(spectrum_re.into_ir(), spectrum_im.into_ir(), dim, || {
        client.create_empty_handle()
    });

    // Register the op in the fusion stream; it yields exactly one output
    // (the real-valued signal), so `next().unwrap()` cannot fail here.
    let mut outputs = client
        .register(
            streams,
            OperationIr::Module(ModuleOperationIr::IRfft(desc.clone())),
            IRfftOps::<B>::new(desc),
        )
        .into_iter();

    outputs.next().unwrap()
}
16021628
}

crates/burn-fusion/src/stream/context.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ impl RelativeOps for ModuleOperationIr {
491491
out_re: desc.out_re.to_relative(converter),
// BUG fix: this previously converted `desc.out_re` a second time
// (copy-paste slip), so both relative outputs of Rfft pointed at the
// same tensor and the imaginary output id was lost in the relative IR.
out_im: desc.out_im.to_relative(converter),
}),
494+
// Rewrite an IRfft node with stream-relative tensor ids so fused streams
// can be matched/cached independently of absolute tensor ids.
ModuleOperationIr::IRfft(desc) => ModuleOperationIr::IRfft(IRfftOpIr {
    input_re: desc.input_re.to_relative(converter),
    input_im: desc.input_im.to_relative(converter),
    dim: desc.dim,
    out_signal: desc.out_signal.to_relative(converter),
}),
494500
ModuleOperationIr::Attention(desc) => ModuleOperationIr::Attention(AttentionOpIr {
495501
query: desc.query.to_relative(converter),
496502
key: desc.key.to_relative(converter),

crates/burn-ir/src/operation.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ pub enum ModuleOperationIr {
252252
InterpolateBackward(InterpolateBackwardOpIr),
253253
/// Operation corresponding to [Rfft](burn_backend::ops::ModuleOps::rfft)
254254
Rfft(RfftOpIr),
255+
/// Operation corresponding to [IRfft](burn_backend::ops::ModuleOps::irfft)
256+
IRfft(IRfftOpIr),
255257
/// Operation corresponding to [attention](burn_backend::ops::ModuleOps::attention).
256258
Attention(AttentionOpIr),
257259
}
@@ -1599,6 +1601,15 @@ pub struct RfftOpIr {
15991601
pub out_im: TensorIr,
16001602
}
16011603

1604+
/// Intermediate representation of an inverse real FFT
/// ([irfft](burn_backend::ops::ModuleOps::irfft)) operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct IRfftOpIr {
    /// Real part of the input spectrum.
    pub input_re: TensorIr,
    /// Imaginary part of the input spectrum.
    pub input_im: TensorIr,
    /// Dimension along which the inverse transform is applied.
    pub dim: usize,
    /// Reconstructed real-valued time-domain output.
    pub out_signal: TensorIr,
}
1612+
16021613
#[allow(missing_docs)]
16031614
impl RfftOpIr {
16041615
pub fn create<F>(signal: TensorIr, dim: usize, mut new_id: F) -> Self
@@ -1618,6 +1629,25 @@ impl RfftOpIr {
16181629
}
16191630
}
16201631

1632+
#[allow(missing_docs)]
impl IRfftOpIr {
    /// Builds the IR node for an irfft call.
    ///
    /// The output shape is derived from the spectrum shape: a spectrum of
    /// length `m` along `dim` maps back to a signal of length `(m - 1) * 2`
    /// (the usual irfft convention, which assumes the original signal length
    /// was even — an odd-length original cannot be distinguished here).
    ///
    /// `new_id` supplies a fresh tensor id for the uninitialized output.
    ///
    /// NOTE(review): `shape[dim] - 1` underflows `usize` if the spectrum is
    /// empty along `dim` — presumably callers never pass an empty spectrum;
    /// confirm upstream validation.
    pub fn create<F>(input_re: TensorIr, input_im: TensorIr, dim: usize, mut new_id: F) -> Self
    where
        F: FnMut() -> crate::TensorId,
    {
        let mut shape = input_re.shape.clone();
        shape[dim] = (shape[dim] - 1) * 2;
        // Output dtype follows the real input spectrum.
        let dtype = input_re.dtype;

        Self {
            input_re,
            input_im,
            dim,
            out_signal: TensorIr::uninit(new_id(), shape, dtype),
        }
    }
}
1650+
16211651
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
16221652
#[allow(missing_docs)]
16231653
pub struct AttentionOptionsIr {
@@ -2703,6 +2733,9 @@ impl ModuleOperationIr {
27032733
Box::new([&repr.x, &repr.grad].into_iter())
27042734
}
27052735
ModuleOperationIr::Rfft(repr) => Box::new([&repr.signal].into_iter()),
2736+
// irfft reads two input tensors: the real and imaginary spectrum halves.
ModuleOperationIr::IRfft(repr) => {
    Box::new([&repr.input_re, &repr.input_im].into_iter())
}
27062739
ModuleOperationIr::Attention(repr) => {
27072740
if let Some(mask) = &repr.mask {
27082741
if let Some(attn_bias) = &repr.attn_bias {
@@ -2798,6 +2831,7 @@ impl ModuleOperationIr {
27982831
ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
27992832
ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
28002833
ModuleOperationIr::Rfft(repr) => Box::new([&repr.out_re, &repr.out_im].into_iter()),
2834+
// irfft produces a single real-valued output tensor.
ModuleOperationIr::IRfft(repr) => Box::new([&repr.out_signal].into_iter()),
28012835
ModuleOperationIr::Attention(repr) => Box::new([&repr.out].into_iter()),
28022836
}
28032837
}
@@ -2998,6 +3032,10 @@ impl ModuleOperationIr {
29983032
ModuleOperationIr::Rfft(repr) => {
29993033
repr.signal.mark_read_only(nodes, &mut output);
30003034
}
3035+
ModuleOperationIr::IRfft(repr) => {
    // Both spectrum inputs are only read by irfft; mark them read-only
    // so their handles are not treated as consumed/mutated.
    repr.input_re.mark_read_only(nodes, &mut output);
    repr.input_im.mark_read_only(nodes, &mut output);
}
30013039
ModuleOperationIr::Attention(repr) => {
30023040
repr.query.mark_read_only(nodes, &mut output);
30033041
repr.key.mark_read_only(nodes, &mut output);

crates/burn-router/src/runner.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,13 @@ impl<B: BackendIr> RunnerClient for Runner<B> {
15241524
handles.register_float_tensor::<B>(&desc.out_re.id, out_re);
15251525
handles.register_float_tensor::<B>(&desc.out_im.id, out_im);
15261526
}
1527+
// Eager (router) execution of an IRfft node: fetch both spectrum halves,
// run the backend inverse transform, and register the output signal under
// the id the IR pre-allocated.
ModuleOperationIr::IRfft(desc) => {
    let spectrum_re = handles.get_float_tensor::<B>(&desc.input_re);
    let spectrum_im = handles.get_float_tensor::<B>(&desc.input_im);
    let signal = B::irfft(spectrum_re, spectrum_im, desc.dim);

    handles.register_float_tensor::<B>(&desc.out_signal.id, signal);
}
15271534
ModuleOperationIr::Attention(desc) => {
15281535
let query = handles.get_float_tensor::<B>(&desc.query);
15291536
let key = handles.get_float_tensor::<B>(&desc.key);

0 commit comments

Comments
 (0)