@@ -41,8 +41,9 @@ type SSHConfig struct {
4141 Address string
4242 Port int
4343
44- AuthMethods []ssh.AuthMethod
45- HostKeyCallback ssh.HostKeyCallback
44+ AuthMethods []ssh.AuthMethod
45+ HostKeyVerification settings.HostKeyVerificationType
46+ KnownHostsFiles []string
4647
4748 password []byte
4849
@@ -167,13 +168,9 @@ func NewSSHConfig(ctx context.Context, host string, log logger.Logger, options S
167168 })
168169 auth = append (auth , passwordCallback )
169170
170- hostKeyCallback , err := knownHostsCallback (log , options .HostKeyVerification , options .KnownHostsFiles )
171- if err != nil {
172- return nil , err
173- }
174-
175171 cfg .AuthMethods = auth
176- cfg .HostKeyCallback = hostKeyCallback
172+ cfg .HostKeyVerification = options .HostKeyVerification
173+ cfg .KnownHostsFiles = options .KnownHostsFiles
177174
178175 return cfg , nil
179176}
@@ -229,7 +226,7 @@ func knownHostsCallback(
229226}
230227
231228func NewSSHSystem (cfg * SSHConfig , log logger.Logger ) (* SSHSystem , error ) {
232- client , sftpClient , err := dialClient (cfg )
229+ client , sftpClient , err := dialClient (log , cfg )
233230 if err != nil {
234231 return nil , err
235232 }
@@ -244,11 +241,16 @@ func NewSSHSystem(cfg *SSHConfig, log logger.Logger) (*SSHSystem, error) {
244241 }, nil
245242}
246243
247- func dialClient (cfg * SSHConfig ) (* ssh.Client , * sftp.Client , error ) {
244+ func dialClient (log logger.Logger , cfg * SSHConfig ) (* ssh.Client , * sftp.Client , error ) {
245+ hostKeyCallback , err := knownHostsCallback (log , cfg .HostKeyVerification , cfg .KnownHostsFiles )
246+ if err != nil {
247+ return nil , nil , err
248+ }
249+
248250 client , err := ssh .Dial ("tcp" , net .JoinHostPort (cfg .Address , strconv .Itoa (cfg .Port )), & ssh.ClientConfig {
249251 User : cfg .User ,
250252 Auth : cfg .AuthMethods ,
251- HostKeyCallback : cfg . HostKeyCallback ,
253+ HostKeyCallback : hostKeyCallback ,
252254 Timeout : 30 * time .Second ,
253255 })
254256 if err != nil {
@@ -268,7 +270,7 @@ func (s *SSHSystem) Reconnect() error {
268270 _ = s .sftp .Close ()
269271 _ = s .client .Close ()
270272
271- client , sftpClient , err := dialClient (s .cfg )
273+ client , sftpClient , err := dialClient (s .logger , s . cfg )
272274 if err != nil {
273275 return fmt .Errorf ("failed to reconnect to %s: %w" , s .Address (), err )
274276 }
@@ -284,7 +286,7 @@ func (s *SSHSystem) Reconnect() error {
284286//
285287// Caller must keep the `cfg` field alive until ALL clones are closed.
286288func (s * SSHSystem ) Clone () (* SSHSystem , error ) {
287- client , sftpClient , err := dialClient (s .cfg )
289+ client , sftpClient , err := dialClient (s .logger , s . cfg )
288290 if err != nil {
289291 return nil , fmt .Errorf ("failed to clone connection to %s: %w" , s .Address (), err )
290292 }
@@ -379,7 +381,7 @@ func addKeyToKnownHostsCallback(
379381 log .Infof ("SHA256 fingerprint: %s" , fingerprint )
380382
381383 var confirm bool
382- confirm , err = cmdUtils .ConfirmationInput ("Are you sure you want to continue connecting?" , cmdUtils.ConfirmationPromptOptions {
384+ confirm , err = cmdUtils .ConfirmationInput (context . Background (), "Are you sure you want to continue connecting?" , cmdUtils.ConfirmationPromptOptions {
383385 // Copy the default SSH behavior of retrying for invalid input.
384386 // Disregard user configuration in this case, since this is mimicking
385387 // OpenSSH's behavior.
0 commit comments