Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions lib/xdrgen/generators/go.rb
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ def render_union_decode_from_interface(out, union)
else
mn = name(arm)
type = arm.type
out2.puts " if err = xdr.TrackOutputBytesOf[#{reference arm.type}](d); err != nil {"
out2.puts " return n, fmt.Errorf(\"decoding #{reference arm.type}: %w\", err)"
out2.puts " }"
Comment thread
tamirms marked this conversation as resolved.
out2.puts " u.#{mn} = new(#{reference arm.type})"
render_decode_from_body(out2, "(*u.#{mn})",type, declared_variables: [], self_encode: false)
end
Expand Down Expand Up @@ -662,6 +665,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts tail
out.puts " #{var} = nil"
out.puts " if b {"
out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {"
out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)"
out.puts " }"
out.puts " #{var} = new(#{name type})"
Comment thread
tamirms marked this conversation as resolved.
end
case type
Expand Down Expand Up @@ -704,6 +710,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts tail
out.puts " #{var} = nil"
out.puts " if b {"
out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {"
out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)"
out.puts " }"
out.puts " #{var} = new(#{name type.resolved_type.declaration.type})"
end
var = "(*#{name type})(#{var})" if self_encode
Expand Down Expand Up @@ -744,16 +753,35 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " if il, ok := d.InputLen(); ok && uint(il) < uint(l) {"
out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, il)"
out.puts " }"
out.puts " #{var} = make([]#{name type}, l)"
# Cap pre-allocation to avoid memory amplification from untrusted inputs.
# The InputLen check above compares element count against remaining
# input bytes, but each element may be much larger in memory than on
# the wire. Capping initial allocation and growing via append ensures
# memory usage stays proportional to data actually decoded.
slice_var = var # save before optional handling may reassign var
out.puts " {"
out.puts " initialCap := l"
out.puts " if initialCap > xdr.MaxPrealloc {"
out.puts " initialCap = xdr.MaxPrealloc"
out.puts " }"
out.puts " #{slice_var} = make([]#{name type}, 0, initialCap)"
out.puts " var empty #{name type}"
out.puts " for i := uint32(0); i < l; i++ {"
element_var = "#{var}[i]"
out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {"
out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)"
out.puts " }"
out.puts " #{slice_var} = append(#{slice_var}, empty)"
element_var = "#{slice_var}[i]"
optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional
if optional_within
out.puts " var eb bool"
out.puts " eb, nTmp, err = d.DecodeBool()"
out.puts tail
out.puts " #{element_var} = nil"
out.puts " if eb {"
out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {"
out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)"
out.puts " }"
out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})"
Comment thread
tamirms marked this conversation as resolved.
var = "(*#{element_var})"
end
Expand All @@ -763,6 +791,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " }"
end
out.puts " }"
out.puts " }"
out.puts " }"
else
raise "Unknown sub_type: #{type.sub_type}"
Expand Down
10 changes: 8 additions & 2 deletions spec/output/generator_spec_go/nesting.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,15 +505,21 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
switch UnionKey(u.Type) {
case UnionKeyOne:
u.One = new(MyUnionOne)
if err = xdr.TrackOutputBytesOf[MyUnionOne](d); err != nil {
return n, fmt.Errorf("decoding MyUnionOne: %w", err)
}
u.One = new(MyUnionOne)
nTmp, err = (*u.One).DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding MyUnionOne: %w", err)
}
return n, nil
case UnionKeyTwo:
u.Two = new(MyUnionTwo)
if err = xdr.TrackOutputBytesOf[MyUnionTwo](d); err != nil {
return n, fmt.Errorf("decoding MyUnionTwo: %w", err)
}
u.Two = new(MyUnionTwo)
nTmp, err = (*u.Two).DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions spec/output/generator_spec_go/optional.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
s.FirstOption = nil
if b {
if err = xdr.TrackOutputBytesOf[Int](d); err != nil {
return n, fmt.Errorf("decoding Int: %w", err)
}
s.FirstOption = new(Int)
s.FirstOption, nTmp, err = d.DecodeInt()
Comment thread
tamirms marked this conversation as resolved.
n += nTmp
Expand All @@ -200,6 +203,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
s.SecondOption = nil
if b {
if err = xdr.TrackOutputBytesOf[Int](d); err != nil {
return n, fmt.Errorf("decoding Int: %w", err)
}
s.SecondOption = new(Int)
s.SecondOption, nTmp, err = d.DecodeInt()
Comment thread
tamirms marked this conversation as resolved.
n += nTmp
Expand All @@ -214,6 +220,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
s.ThirdOption = nil
if b {
if err = xdr.TrackOutputBytesOf[Arr](d); err != nil {
return n, fmt.Errorf("decoding Arr: %w", err)
}
s.ThirdOption = new(Arr)
nTmp, err = s.ThirdOption.DecodeFrom(d, maxDepth)
n += nTmp
Expand Down
47 changes: 43 additions & 4 deletions spec/output/generator_spec_go/test.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,14 +552,25 @@ func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il)
}
(*s) = make([]Hash, l)
{
initialCap := l
if initialCap > xdr.MaxPrealloc {
initialCap = xdr.MaxPrealloc
}
(*s) = make([]Hash, 0, initialCap)
var empty Hash
for i := uint32(0); i < l; i++ {
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
return n, fmt.Errorf("decoding Hash: %w", err)
}
(*s) = append((*s), empty)
nTmp, err = (*s)[i].DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Hash: %w", err)
}
}
}
}
return n, nil
}
Expand Down Expand Up @@ -631,14 +642,25 @@ func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il)
}
(*s) = make([]Hash, l)
{
initialCap := l
if initialCap > xdr.MaxPrealloc {
initialCap = xdr.MaxPrealloc
}
(*s) = make([]Hash, 0, initialCap)
var empty Hash
for i := uint32(0); i < l; i++ {
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
return n, fmt.Errorf("decoding Hash: %w", err)
}
(*s) = append((*s), empty)
nTmp, err = (*s)[i].DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Hash: %w", err)
}
}
}
}
return n, nil
}
Expand Down Expand Up @@ -1006,6 +1028,9 @@ func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
s.Field2 = nil
if b {
if err = xdr.TrackOutputBytesOf[Hash](d); err != nil {
return n, fmt.Errorf("decoding Hash: %w", err)
}
s.Field2 = new(Hash)
nTmp, err = s.Field2.DecodeFrom(d, maxDepth)
n += nTmp
Expand Down Expand Up @@ -1114,14 +1139,25 @@ func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, il)
}
s.Members = make([]MyStruct, l)
{
initialCap := l
if initialCap > xdr.MaxPrealloc {
initialCap = xdr.MaxPrealloc
}
s.Members = make([]MyStruct, 0, initialCap)
var empty MyStruct
for i := uint32(0); i < l; i++ {
if err = xdr.TrackOutputBytesOf[MyStruct](d); err != nil {
return n, fmt.Errorf("decoding MyStruct: %w", err)
}
s.Members = append(s.Members, empty)
nTmp, err = s.Members[i].DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding MyStruct: %w", err)
}
}
}
}
return n, nil
}
Expand Down Expand Up @@ -1571,7 +1607,10 @@ switch Color(u.Color) {
// Void
return n, nil
default:
u.Blah2 = new(int32)
if err = xdr.TrackOutputBytesOf[int32](d); err != nil {
return n, fmt.Errorf("decoding int32: %w", err)
}
u.Blah2 = new(int32)
(*u.Blah2), nTmp, err = d.DecodeInt()
n += nTmp
if err != nil {
Expand Down
46 changes: 40 additions & 6 deletions spec/output/generator_spec_go/union.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,21 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
switch UnionKey(u.Type) {
case UnionKeyError:
u.Error = new(Error)
if err = xdr.TrackOutputBytesOf[Error](d); err != nil {
return n, fmt.Errorf("decoding Error: %w", err)
}
u.Error = new(Error)
nTmp, err = (*u.Error).DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Error: %w", err)
}
return n, nil
case UnionKeyMulti:
u.Things = new([]Multi)
if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil {
return n, fmt.Errorf("decoding []Multi: %w", err)
}
u.Things = new([]Multi)
var l uint32
l, nTmp, err = d.DecodeUint()
n += nTmp
Expand All @@ -437,14 +443,25 @@ switch UnionKey(u.Type) {
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il)
}
(*u.Things) = make([]Multi, l)
{
initialCap := l
if initialCap > xdr.MaxPrealloc {
initialCap = xdr.MaxPrealloc
}
(*u.Things) = make([]Multi, 0, initialCap)
var empty Multi
for i := uint32(0); i < l; i++ {
if err = xdr.TrackOutputBytesOf[Multi](d); err != nil {
return n, fmt.Errorf("decoding Multi: %w", err)
}
(*u.Things) = append((*u.Things), empty)
nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Multi: %w", err)
}
}
}
}
return n, nil
}
Expand Down Expand Up @@ -626,15 +643,21 @@ func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
}
switch int32(u.Type) {
case 0:
u.Error = new(Error)
if err = xdr.TrackOutputBytesOf[Error](d); err != nil {
return n, fmt.Errorf("decoding Error: %w", err)
}
u.Error = new(Error)
nTmp, err = (*u.Error).DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Error: %w", err)
}
return n, nil
case 1:
u.Things = new([]Multi)
if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil {
return n, fmt.Errorf("decoding []Multi: %w", err)
}
u.Things = new([]Multi)
var l uint32
l, nTmp, err = d.DecodeUint()
n += nTmp
Expand All @@ -646,14 +669,25 @@ switch int32(u.Type) {
if il, ok := d.InputLen(); ok && uint(il) < uint(l) {
return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il)
}
(*u.Things) = make([]Multi, l)
{
initialCap := l
if initialCap > xdr.MaxPrealloc {
initialCap = xdr.MaxPrealloc
}
(*u.Things) = make([]Multi, 0, initialCap)
var empty Multi
for i := uint32(0); i < l; i++ {
if err = xdr.TrackOutputBytesOf[Multi](d); err != nil {
return n, fmt.Errorf("decoding Multi: %w", err)
}
(*u.Things) = append((*u.Things), empty)
nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth)
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Multi: %w", err)
}
}
}
}
return n, nil
}
Expand Down
Loading