Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions d2renderers/d2svg/d2svg.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ type RenderOpts struct {
OmitVersion *bool
}

// escapeClassNames escapes HTML special characters in class names to produce valid SVG attributes.
func escapeClassNames(classes []string) []string {
escaped := make([]string, len(classes))
for i, class := range classes {
escaped[i] = html.EscapeString(class)
}
return escaped
}

func dimensions(diagram *d2target.Diagram, pad int) (left, top, width, height int) {
tl, br := diagram.BoundingBox()
left = tl.X - pad
Expand Down Expand Up @@ -1010,7 +1019,7 @@ func drawConnection(writer io.Writer, diagramHash string, connection d2target.Co
}

classes := []string{base64.URLEncoding.EncodeToString([]byte(svg.EscapeText(connection.ID)))}
classes = append(classes, connection.Classes...)
classes = append(classes, escapeClassNames(connection.Classes)...)
classStr := fmt.Sprintf(` class="%s"`, strings.Join(classes, " "))

fmt.Fprintf(writer, `<g%s%s>`, classStr, opacityStyle)
Expand Down Expand Up @@ -1628,7 +1637,7 @@ func drawShape(writer, appendixWriter io.Writer, diagramHash string, targetShape
if targetShape.Animated {
classes = append(classes, "animated-shape")
}
classes = append(classes, targetShape.Classes...)
classes = append(classes, escapeClassNames(targetShape.Classes)...)
classStr := fmt.Sprintf(` class="%s"`, strings.Join(classes, " "))
fmt.Fprintf(writer, `<g%s%s>`, classStr, opacityStyle)
tl := geo.NewPoint(float64(targetShape.Pos.X), float64(targetShape.Pos.Y))
Expand Down
53 changes: 53 additions & 0 deletions d2renderers/d2svg/d2svg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,56 @@ func TestSortObjects(t *testing.T) {
}
}
}

func TestEscapeClassNames(t *testing.T) {
tests := []struct {
name string
input []string
expected []string
}{
{
name: "empty slice",
input: []string{},
expected: []string{},
},
{
name: "no special characters",
input: []string{"foo", "bar", "baz"},
expected: []string{"foo", "bar", "baz"},
},
{
name: "with double quotes",
input: []string{"test label: \"Hello World\""},
expected: []string{"test label: &#34;Hello World&#34;"},
},
{
name: "with single quotes",
input: []string{"test's value"},
expected: []string{"test&#39;s value"},
},
{
name: "with angle brackets",
input: []string{"<script>", "foo&bar"},
expected: []string{"&lt;script&gt;", "foo&amp;bar"},
},
{
name: "mixed classes",
input: []string{"normal", "has \"quotes\"", "also-normal"},
expected: []string{"normal", "has &#34;quotes&#34;", "also-normal"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := escapeClassNames(tt.input)
if len(result) != len(tt.expected) {
t.Fatalf("length mismatch: got %d, expected %d", len(result), len(tt.expected))
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("at index %d: got %q, expected %q", i, result[i], tt.expected[i])
}
}
})
}
}