Skip to content

Commit e009504

Browse files
tamirmsclaude
andcommitted
Go generator: emit output size tracking and capped pre-allocation
Update the Go code generator to emit TrackOutputBytesOf calls before each heap allocation site (union arms, optional fields, array elements) and cap initial array allocation at 256 elements with append-based growth. This works with the new MaxOutputBytes option in go-xdr to allow callers to limit cumulative decoded output size. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1636575 commit e009504

5 files changed

Lines changed: 131 additions & 14 deletions

File tree

lib/xdrgen/generators/go.rb

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,9 @@ def render_union_decode_from_interface(out, union)
571571
else
572572
mn = name(arm)
573573
type = arm.type
574+
out2.puts " if err = xdr.TrackOutputBytesOf[#{reference arm.type}](d); err != nil {"
575+
out2.puts " return n, fmt.Errorf(\"decoding #{reference arm.type}: %w\", err)"
576+
out2.puts " }"
574577
out2.puts " u.#{mn} = new(#{reference arm.type})"
575578
render_decode_from_body(out2, "(*u.#{mn})",type, declared_variables: [], self_encode: false)
576579
end
@@ -662,6 +665,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
662665
out.puts tail
663666
out.puts " #{var} = nil"
664667
out.puts " if b {"
668+
out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {"
669+
out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)"
670+
out.puts " }"
665671
out.puts " #{var} = new(#{name type})"
666672
end
667673
case type
@@ -704,6 +710,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
704710
out.puts tail
705711
out.puts " #{var} = nil"
706712
out.puts " if b {"
713+
out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {"
714+
out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)"
715+
out.puts " }"
707716
out.puts " #{var} = new(#{name type.resolved_type.declaration.type})"
708717
end
709718
var = "(*#{name type})(#{var})" if self_encode
@@ -744,16 +753,35 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
744753
out.puts " if il, ok := d.InputLen(); ok && uint(il) < uint(l) {"
745754
out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, il)"
746755
out.puts " }"
747-
out.puts " #{var} = make([]#{name type}, l)"
756+
# Cap pre-allocation to avoid memory amplification from untrusted inputs.
757+
# The InputLen check above compares element count against remaining
758+
# input bytes, but each element may be much larger in memory than on
759+
# the wire. Capping initial allocation and growing via append ensures
760+
# memory usage stays proportional to data actually decoded.
761+
slice_var = var # save before optional handling may reassign var
762+
out.puts " {"
763+
out.puts " initialCap := l"
764+
out.puts " if initialCap > xdr.MaxPrealloc {"
765+
out.puts " initialCap = xdr.MaxPrealloc"
766+
out.puts " }"
767+
out.puts " #{slice_var} = make([]#{name type}, 0, initialCap)"
768+
out.puts " var empty #{name type}"
748769
out.puts " for i := uint32(0); i < l; i++ {"
749-
element_var = "#{var}[i]"
770+
out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {"
771+
out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)"
772+
out.puts " }"
773+
out.puts " #{slice_var} = append(#{slice_var}, empty)"
774+
element_var = "#{slice_var}[i]"
750775
optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional
751776
if optional_within
752777
out.puts " var eb bool"
753778
out.puts " eb, nTmp, err = d.DecodeBool()"
754779
out.puts tail
755780
out.puts " #{element_var} = nil"
756781
out.puts " if eb {"
782+
out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {"
783+
out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)"
784+
out.puts " }"
757785
out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})"
758786
var = "(*#{element_var})"
759787
end
@@ -763,6 +791,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
763791
out.puts " }"
764792
end
765793
out.puts " }"
794+
out.puts " }"
766795
out.puts " }"
767796
else
768797
raise "Unknown sub_type: #{type.sub_type}"

spec/output/generator_spec_go/nesting.x/MyXDR_generated.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,15 +505,21 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
505505
}
506506
switch UnionKey(u.Type) {
507507
case UnionKeyOne:
508-
u.One = new(MyUnionOne)
508+
if err = xdr.TrackOutputBytesOf[MyUnionOne](d); err != nil {
509+
return n, fmt.Errorf("decoding MyUnionOne: %w", err)
510+
}
511+
u.One = new(MyUnionOne)
509512
nTmp, err = (*u.One).DecodeFrom(d, maxDepth)
510513
n += nTmp
511514
if err != nil {
512515
return n, fmt.Errorf("decoding MyUnionOne: %w", err)
513516
}
514517
return n, nil
515518
case UnionKeyTwo:
516-
u.Two = new(MyUnionTwo)
519+
if err = xdr.TrackOutputBytesOf[MyUnionTwo](d); err != nil {
520+
return n, fmt.Errorf("decoding MyUnionTwo: %w", err)
521+
}
522+
u.Two = new(MyUnionTwo)
517523
nTmp, err = (*u.Two).DecodeFrom(d, maxDepth)
518524
n += nTmp
519525
if err != nil {

spec/output/generator_spec_go/optional.x/MyXDR_generated.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
186186
}
187187
s.FirstOption = nil
188188
if b {
189+
if err = xdr.TrackOutputBytesOf[Int](d); err != nil {
190+
return n, fmt.Errorf("decoding Int: %w", err)
191+
}
189192
s.FirstOption = new(Int)
190193
s.FirstOption, nTmp, err = d.DecodeInt()
191194
n += nTmp
@@ -200,6 +203,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
200203
}
201204
s.SecondOption = nil
202205
if b {
206+
if err = xdr.TrackOutputBytesOf[Int](d); err != nil {
207+
return n, fmt.Errorf("decoding Int: %w", err)
208+
}
203209
s.SecondOption = new(Int)
204210
s.SecondOption, nTmp, err = d.DecodeInt()
205211
n += nTmp
@@ -214,6 +220,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
214220
}
215221
s.ThirdOption = nil
216222
if b {
223+
if err = xdr.TrackOutputBytesOf[Arr](d); err != nil {
224+
return n, fmt.Errorf("decoding Arr: %w", err)
225+
}
217226
s.ThirdOption = new(Arr)
218227
nTmp, err = s.ThirdOption.DecodeFrom(d, maxDepth)
219228
n += nTmp

spec/output/generator_spec_go/test.x/MyXDR_generated.go

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -552,14 +552,25 @@ func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
552552
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
553553
return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il)
554554
}
555-
(*s) = make([]Hash, l)
555+
{
556+
initialCap := l
557+
if initialCap > xdr.MaxPrealloc {
558+
initialCap = xdr.MaxPrealloc
559+
}
560+
(*s) = make([]Hash, 0, initialCap)
561+
var empty Hash
556562
for i := uint32(0); i < l; i++ {
563+
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
564+
return n, fmt.Errorf("decoding Hash: %w", err)
565+
}
566+
(*s) = append((*s), empty)
557567
nTmp, err = (*s)[i].DecodeFrom(d, maxDepth)
558568
n += nTmp
559569
if err != nil {
560570
return n, fmt.Errorf("decoding Hash: %w", err)
561571
}
562572
}
573+
}
563574
}
564575
return n, nil
565576
}
@@ -631,14 +642,25 @@ func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
631642
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
632643
return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il)
633644
}
634-
(*s) = make([]Hash, l)
645+
{
646+
initialCap := l
647+
if initialCap > xdr.MaxPrealloc {
648+
initialCap = xdr.MaxPrealloc
649+
}
650+
(*s) = make([]Hash, 0, initialCap)
651+
var empty Hash
635652
for i := uint32(0); i < l; i++ {
653+
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
654+
return n, fmt.Errorf("decoding Hash: %w", err)
655+
}
656+
(*s) = append((*s), empty)
636657
nTmp, err = (*s)[i].DecodeFrom(d, maxDepth)
637658
n += nTmp
638659
if err != nil {
639660
return n, fmt.Errorf("decoding Hash: %w", err)
640661
}
641662
}
663+
}
642664
}
643665
return n, nil
644666
}
@@ -1006,6 +1028,9 @@ func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
10061028
}
10071029
s.Field2 = nil
10081030
if b {
1031+
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
1032+
return n, fmt.Errorf("decoding Hash: %w", err)
1033+
}
10091034
s.Field2 = new(Hash)
10101035
nTmp, err = s.Field2.DecodeFrom(d, maxDepth)
10111036
n += nTmp
@@ -1114,14 +1139,25 @@ func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
11141139
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
11151140
return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, il)
11161141
}
1117-
s.Members = make([]MyStruct, l)
1142+
{
1143+
initialCap := l
1144+
if initialCap > xdr.MaxPrealloc {
1145+
initialCap = xdr.MaxPrealloc
1146+
}
1147+
s.Members = make([]MyStruct, 0, initialCap)
1148+
var empty MyStruct
11181149
for i := uint32(0); i < l; i++ {
1150+
if err = xdr.TrackOutputBytesOf[MyStruct](d); err != nil {
1151+
return n, fmt.Errorf("decoding MyStruct: %w", err)
1152+
}
1153+
s.Members = append(s.Members, empty)
11191154
nTmp, err = s.Members[i].DecodeFrom(d, maxDepth)
11201155
n += nTmp
11211156
if err != nil {
11221157
return n, fmt.Errorf("decoding MyStruct: %w", err)
11231158
}
11241159
}
1160+
}
11251161
}
11261162
return n, nil
11271163
}
@@ -1571,7 +1607,10 @@ switch Color(u.Color) {
15711607
// Void
15721608
return n, nil
15731609
default:
1574-
u.Blah2 = new(int32)
1610+
if err = xdr.TrackOutputBytesOf[int32](d); err != nil {
1611+
return n, fmt.Errorf("decoding int32: %w", err)
1612+
}
1613+
u.Blah2 = new(int32)
15751614
(*u.Blah2), nTmp, err = d.DecodeInt()
15761615
n += nTmp
15771616
if err != nil {

spec/output/generator_spec_go/union.x/MyXDR_generated.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,15 +417,21 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
417417
}
418418
switch UnionKey(u.Type) {
419419
case UnionKeyError:
420-
u.Error = new(Error)
420+
if err = xdr.TrackOutputBytesOf[Error](d); err != nil {
421+
return n, fmt.Errorf("decoding Error: %w", err)
422+
}
423+
u.Error = new(Error)
421424
nTmp, err = (*u.Error).DecodeFrom(d, maxDepth)
422425
n += nTmp
423426
if err != nil {
424427
return n, fmt.Errorf("decoding Error: %w", err)
425428
}
426429
return n, nil
427430
case UnionKeyMulti:
428-
u.Things = new([]Multi)
431+
if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil {
432+
return n, fmt.Errorf("decoding []Multi: %w", err)
433+
}
434+
u.Things = new([]Multi)
429435
var l uint32
430436
l, nTmp, err = d.DecodeUint()
431437
n += nTmp
@@ -437,14 +443,25 @@ switch UnionKey(u.Type) {
437443
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
438444
return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il)
439445
}
440-
(*u.Things) = make([]Multi, l)
446+
{
447+
initialCap := l
448+
if initialCap > xdr.MaxPrealloc {
449+
initialCap = xdr.MaxPrealloc
450+
}
451+
(*u.Things) = make([]Multi, 0, initialCap)
452+
var empty Multi
441453
for i := uint32(0); i < l; i++ {
454+
if err = xdr.TrackOutputBytesOf[Multi](d); err != nil {
455+
return n, fmt.Errorf("decoding Multi: %w", err)
456+
}
457+
(*u.Things) = append((*u.Things), empty)
442458
nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth)
443459
n += nTmp
444460
if err != nil {
445461
return n, fmt.Errorf("decoding Multi: %w", err)
446462
}
447463
}
464+
}
448465
}
449466
return n, nil
450467
}
@@ -626,15 +643,21 @@ func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
626643
}
627644
switch int32(u.Type) {
628645
case 0:
629-
u.Error = new(Error)
646+
if err = xdr.TrackOutputBytesOf[Error](d); err != nil {
647+
return n, fmt.Errorf("decoding Error: %w", err)
648+
}
649+
u.Error = new(Error)
630650
nTmp, err = (*u.Error).DecodeFrom(d, maxDepth)
631651
n += nTmp
632652
if err != nil {
633653
return n, fmt.Errorf("decoding Error: %w", err)
634654
}
635655
return n, nil
636656
case 1:
637-
u.Things = new([]Multi)
657+
if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil {
658+
return n, fmt.Errorf("decoding []Multi: %w", err)
659+
}
660+
u.Things = new([]Multi)
638661
var l uint32
639662
l, nTmp, err = d.DecodeUint()
640663
n += nTmp
@@ -646,14 +669,25 @@ switch int32(u.Type) {
646669
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
647670
return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il)
648671
}
649-
(*u.Things) = make([]Multi, l)
672+
{
673+
initialCap := l
674+
if initialCap > xdr.MaxPrealloc {
675+
initialCap = xdr.MaxPrealloc
676+
}
677+
(*u.Things) = make([]Multi, 0, initialCap)
678+
var empty Multi
650679
for i := uint32(0); i < l; i++ {
680+
if err = xdr.TrackOutputBytesOf[Multi](d); err != nil {
681+
return n, fmt.Errorf("decoding Multi: %w", err)
682+
}
683+
(*u.Things) = append((*u.Things), empty)
651684
nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth)
652685
n += nTmp
653686
if err != nil {
654687
return n, fmt.Errorf("decoding Multi: %w", err)
655688
}
656689
}
690+
}
657691
}
658692
return n, nil
659693
}

0 commit comments

Comments
 (0)