Skip to content

Commit 66c2f63

Browse files
committed
refactor(ssh): construct host key callback only when dialing
knownhosts.New() caches the known hosts database, which means that it will not be picked up across reconnects. As such, this commit delays the instantiation of the host key callback so that this caching does not happen and the user isn't reprompted to add a key they've already added during the ACK process before magic rollback.
1 parent 1056a66 commit 66c2f63

6 files changed

Lines changed: 43 additions & 20 deletions

File tree

cmd/apply/apply.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ func applyMain(cmd *cobra.Command, opts *cmdOpts.ApplyOpts) error {
646646
log.Printf("\n")
647647

648648
var confirm bool
649-
confirm, err = cmdUtils.ConfirmationInput("Activate this configuration?", cmdUtils.ConfirmationPromptOptions{
649+
confirm, err = cmdUtils.ConfirmationInput(stopCtx, "Activate this configuration?", cmdUtils.ConfirmationPromptOptions{
650650
InvalidBehavior: cfg.Confirmation.Invalid,
651651
EmptyBehavior: cfg.Confirmation.Empty,
652652
})

cmd/generation/delete/delete.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package delete
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -160,7 +161,7 @@ func generationDeleteMain(cmd *cobra.Command, genOpts *cmdOpts.GenerationOpts, o
160161

161162
if !opts.AlwaysConfirm && !cfg.Confirmation.Always {
162163
var confirm bool
163-
confirm, err = cmdUtils.ConfirmationInput("Proceed?", cmdUtils.ConfirmationPromptOptions{
164+
confirm, err = cmdUtils.ConfirmationInput(context.Background(), "Proceed?", cmdUtils.ConfirmationPromptOptions{
164165
InvalidBehavior: cfg.Confirmation.Invalid,
165166
EmptyBehavior: cfg.Confirmation.Empty,
166167
})

cmd/generation/rollback/rollback.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package rollback
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -98,7 +99,7 @@ func generationRollbackMain(cmd *cobra.Command, genOpts *cmdOpts.GenerationOpts,
9899
log.Printf("\n")
99100

100101
var confirm bool
101-
confirm, err = cmdUtils.ConfirmationInput("Activate the previous generation?", cmdUtils.ConfirmationPromptOptions{
102+
confirm, err = cmdUtils.ConfirmationInput(context.Background(), "Activate the previous generation?", cmdUtils.ConfirmationPromptOptions{
102103
InvalidBehavior: cfg.Confirmation.Invalid,
103104
EmptyBehavior: cfg.Confirmation.Empty,
104105
})

cmd/generation/switch/switch.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package switch_cmd
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -150,7 +151,7 @@ func generationSwitchMain(cmd *cobra.Command, genOpts *cmdOpts.GenerationOpts, o
150151
log.Printf("\n")
151152

152153
var confirm bool
153-
confirm, err = cmdUtils.ConfirmationInput("Activate this generation?", cmdUtils.ConfirmationPromptOptions{
154+
confirm, err = cmdUtils.ConfirmationInput(context.Background(), "Activate this generation?", cmdUtils.ConfirmationPromptOptions{
154155
InvalidBehavior: cfg.Confirmation.Invalid,
155156
EmptyBehavior: cfg.Confirmation.Empty,
156157
})

internal/cmd/utils/confirmation.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmdUtils
22

33
import (
44
"bufio"
5+
"context"
56
"fmt"
67
"os"
78
"strings"
@@ -15,9 +16,26 @@ type ConfirmationPromptOptions struct {
1516
EmptyBehavior settings.ConfirmationPromptBehavior
1617
}
1718

18-
func ConfirmationInput(msg string, opts ConfirmationPromptOptions) (bool, error) {
19+
// This operation can be cancelled using the provided context,
20+
// but any errors returned from here MAY have the potential
21+
// to keep consuming stdin until another character is typed
22+
// if a duplicate instance of stdin cannot be opened, so
23+
// any errors here will result in potentiall undefined behavior
24+
// for stdin input.
25+
func ConfirmationInput(ctx context.Context, msg string, opts ConfirmationPromptOptions) (bool, error) {
26+
var stdin *os.File
27+
if dupStdin, err := os.OpenFile("/dev/stdin", os.O_RDONLY, 0); err == nil {
28+
defer dupStdin.Close()
29+
stdin = dupStdin
30+
} else {
31+
// NOTE: falling back to stdin will make context
32+
// cancellation behavior a bit unclear, as mentioned
33+
// in the doc comment.
34+
stdin = os.Stdin
35+
}
36+
1937
var input string
20-
scanner := bufio.NewScanner(os.Stdin)
38+
scanner := bufio.NewScanner(stdin)
2139

2240
for {
2341
fmt.Fprintf(os.Stderr, "%s\n[y/n]: ", color.GreenString("|> %s", msg))

internal/system/ssh.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ type SSHConfig struct {
4141
Address string
4242
Port int
4343

44-
AuthMethods []ssh.AuthMethod
45-
HostKeyCallback ssh.HostKeyCallback
44+
AuthMethods []ssh.AuthMethod
45+
HostKeyVerification settings.HostKeyVerificationType
46+
KnownHostsFiles []string
4647

4748
password []byte
4849

@@ -167,13 +168,9 @@ func NewSSHConfig(ctx context.Context, host string, log logger.Logger, options S
167168
})
168169
auth = append(auth, passwordCallback)
169170

170-
hostKeyCallback, err := knownHostsCallback(log, options.HostKeyVerification, options.KnownHostsFiles)
171-
if err != nil {
172-
return nil, err
173-
}
174-
175171
cfg.AuthMethods = auth
176-
cfg.HostKeyCallback = hostKeyCallback
172+
cfg.HostKeyVerification = options.HostKeyVerification
173+
cfg.KnownHostsFiles = options.KnownHostsFiles
177174

178175
return cfg, nil
179176
}
@@ -229,7 +226,7 @@ func knownHostsCallback(
229226
}
230227

231228
func NewSSHSystem(cfg *SSHConfig, log logger.Logger) (*SSHSystem, error) {
232-
client, sftpClient, err := dialClient(cfg)
229+
client, sftpClient, err := dialClient(log, cfg)
233230
if err != nil {
234231
return nil, err
235232
}
@@ -244,11 +241,16 @@ func NewSSHSystem(cfg *SSHConfig, log logger.Logger) (*SSHSystem, error) {
244241
}, nil
245242
}
246243

247-
func dialClient(cfg *SSHConfig) (*ssh.Client, *sftp.Client, error) {
244+
func dialClient(log logger.Logger, cfg *SSHConfig) (*ssh.Client, *sftp.Client, error) {
245+
hostKeyCallback, err := knownHostsCallback(log, cfg.HostKeyVerification, cfg.KnownHostsFiles)
246+
if err != nil {
247+
return nil, nil, err
248+
}
249+
248250
client, err := ssh.Dial("tcp", net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port)), &ssh.ClientConfig{
249251
User: cfg.User,
250252
Auth: cfg.AuthMethods,
251-
HostKeyCallback: cfg.HostKeyCallback,
253+
HostKeyCallback: hostKeyCallback,
252254
Timeout: 30 * time.Second,
253255
})
254256
if err != nil {
@@ -268,7 +270,7 @@ func (s *SSHSystem) Reconnect() error {
268270
_ = s.sftp.Close()
269271
_ = s.client.Close()
270272

271-
client, sftpClient, err := dialClient(s.cfg)
273+
client, sftpClient, err := dialClient(s.logger, s.cfg)
272274
if err != nil {
273275
return fmt.Errorf("failed to reconnect to %s: %w", s.Address(), err)
274276
}
@@ -284,7 +286,7 @@ func (s *SSHSystem) Reconnect() error {
284286
//
285287
// Caller must keep the `cfg` field alive until ALL clones are closed.
286288
func (s *SSHSystem) Clone() (*SSHSystem, error) {
287-
client, sftpClient, err := dialClient(s.cfg)
289+
client, sftpClient, err := dialClient(s.logger, s.cfg)
288290
if err != nil {
289291
return nil, fmt.Errorf("failed to clone connection to %s: %w", s.Address(), err)
290292
}
@@ -379,7 +381,7 @@ func addKeyToKnownHostsCallback(
379381
log.Infof("SHA256 fingerprint: %s", fingerprint)
380382

381383
var confirm bool
382-
confirm, err = cmdUtils.ConfirmationInput("Are you sure you want to continue connecting?", cmdUtils.ConfirmationPromptOptions{
384+
confirm, err = cmdUtils.ConfirmationInput(context.Background(), "Are you sure you want to continue connecting?", cmdUtils.ConfirmationPromptOptions{
383385
// Copy the default SSH behavior of retrying for invalid input.
384386
// Disregard user configuration in this case, since this is mimicking
385387
// OpenSSH's behavior.

0 commit comments

Comments
 (0)