diff --git a/config/config.go b/config/config.go index 3a002faecd..fd06e119e3 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "golang.org/x/crypto/acme/autocert" "gopkg.in/yaml.v2" "github.com/prometheus/client_golang/prometheus" @@ -252,6 +253,14 @@ type Config struct { // TLS Config KubernetesEnableTLS bool `yaml:"kubernetes-enable-tls"` + // Letsencrypt + EnableLetsencrypt bool `yaml:"enable-letsencrypt"` + LetsencryptCache string `yaml:"letsencrypt-cache"` + LetsencryptEmail string `yaml:"letsencrypt-email"` + LetsencryptDomains *listFlag `yaml:"letsencrypt-domains"` + LetsencryptDirectoryURL string `yaml:"letsencrypt-directory-url"` + LetsencryptUserAgent string `yaml:"letsencrypt-user-agent"` + // API Monitoring ApiUsageMonitoringEnable bool `yaml:"enable-api-usage-monitoring"` ApiUsageMonitoringRealmKeys string `yaml:"api-usage-monitoring-realm-keys"` @@ -359,6 +368,7 @@ func NewConfig() *Config { cfg.LuaModules = commaListFlag() cfg.LuaSources = commaListFlag() cfg.Oauth2GrantTokeninfoKeys = commaListFlag() + cfg.LetsencryptDomains = commaListFlag() flag := flag.NewFlagSet("", flag.ExitOnError) flag.StringVar(&cfg.ConfigFile, "config-file", "", "if provided the flags will be loaded/overwritten by the values on the file (yaml)") @@ -585,6 +595,14 @@ func NewConfig() *Config { // Exclude insecure cipher suites flag.BoolVar(&cfg.ExcludeInsecureCipherSuites, "exclude-insecure-cipher-suites", false, "excludes insecure cipher suites") + // Letsencrypt + flag.BoolVar(&cfg.EnableLetsencrypt, "enable-letsencrypt", false, "enables letsencrypt autocert handling on the proxy") + flag.StringVar(&cfg.LetsencryptCache, "letsencrypt-cache", "", "Configure the autocert cert cache ") + flag.StringVar(&cfg.LetsencryptEmail, "letsencrypt-email", "", "Sets letsencrypt email address such that you can be reached by letsencrypt if something goes wrong") + flag.Var(cfg.LetsencryptDomains, "letsencrypt-domains", "An allow list of domains for autocert handling") + flag.StringVar(&cfg.LetsencryptDirectoryURL, "letsencrypt-directory-url", "", "Sets directory URL for testing") + flag.StringVar(&cfg.LetsencryptUserAgent, "letsencrypt-user-agent", "", "Sets httpclient useragent that calls letsencrypt that enables letsencrypt to limit you if something goes wrong") + // API Monitoring: flag.BoolVar(&cfg.ApiUsageMonitoringEnable, "enable-api-usage-monitoring", false, "enables the apiUsageMonitoring filter") flag.StringVar(&cfg.ApiUsageMonitoringRealmKeys, "api-usage-monitoring-realm-keys", "", "name of the property in the JWT payload that contains the authority realm") @@ -1138,9 +1156,35 @@ func (c *Config) ToOptions() skipper.Options { }) } + if c.EnableLetsencrypt { + wrappers = append(wrappers, func(handler http.Handler) http.Handler { + return net.NewLetsencrypt( + c.getLetsencryptCache(), + c.LetsencryptEmail, + c.LetsencryptDirectoryURL, + c.LetsencryptUserAgent, + c.LetsencryptDomains.values, + ).Handler(handler) + }) + + } + return options } +func (c *Config) getLetsencryptCache() autocert.Cache { + switch c.LetsencryptCache { + case "directory": + return autocert.DirCache(os.TempDir()) + case "remote": + return &net.RemoteCache{ + Client: &net.RedisRingClient{}, + } + default: + return &net.InmemoryCache{} + } +} + func (c *Config) getMinTLSVersion() uint16 { tlsVersionTable := map[string]uint16{ "1.3": tls.VersionTLS13, diff --git a/config/config_test.go b/config/config_test.go index 48a27745d9..116c0aef18 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -166,6 +166,7 @@ func defaultConfig(with func(*Config)) *Config { ClusterRatelimitMaxGroupShards: 1, ValidateQuery: true, ValidateQueryLog: true, + LetsencryptDomains: commaListFlag(), LuaModules: commaListFlag(), LuaSources: commaListFlag(), OpenPolicyAgentCleanerInterval: openpolicyagent.DefaultCleanIdlePeriod, diff --git a/net/letsencrypt.go b/net/letsencrypt.go new file mode 100644 index 0000000000..973096e320 --- /dev/null +++ b/net/letsencrypt.go @@ -0,0 +1,137 @@ +package net + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "regexp" + "strings" + "sync" + + "golang.org/x/crypto/acme" + "golang.org/x/crypto/acme/autocert" +) + +type InmemoryCache struct { + m sync.Map +} + +func (ic *InmemoryCache) Get(ctx context.Context, key string) ([]byte, error) { + if dat, ok := ic.m.Load(key); !ok { + return nil, fmt.Errorf("missing key %q", key) + } else { + if data, ok := dat.([]byte); !ok { + return nil, fmt.Errorf("failed to convert %q to []byte", dat) + } else { + return data, nil + } + } +} + +func (ic *InmemoryCache) Put(ctx context.Context, key string, data []byte) error { + ic.m.Store(key, data) + return nil +} + +func (ic *InmemoryCache) Delete(ctx context.Context, key string) error { + ic.m.Delete(key) + return nil +} + +type RemoteCache struct { + Client *RedisRingClient +} + +func (rc *RemoteCache) Get(ctx context.Context, key string) ([]byte, error) { + res, err := rc.Client.Get(ctx, key) + if err != nil { + return nil, err + } + return []byte(res), nil +} + +func (rc *RemoteCache) Delete(ctx context.Context, key string) error { + return rc.Client.Del(ctx, key) +} + +func (rc *RemoteCache) Put(ctx context.Context, key string, val []byte) error { + _, err := rc.Client.Set(ctx, key, val, 0) + return err +} + +func (rc *RemoteCache) Close() { + rc.Client.Close() +} + +type Letsencrypt struct { + manager *autocert.Manager +} + +// NewLetsencrypt creates a letsencrypt handler to automatically handle CSR challenges. +// +// The cache argument can be either +// +// - autocert.DirCache for a filesystem cache +// - inmemoryCache for in memory cache +// - remoteCache for redis based production cache to be shared between multiple skipper processes +func NewLetsencrypt(cache autocert.Cache, email, directoryURL, userAgent string, proposedDomains []string) *Letsencrypt { + domains := make([]string, 0, len(proposedDomains)) + for _, s := range proposedDomains { + if validateDomain(s) { + domains = append(domains, s) + } + } + + manager := &autocert.Manager{ + Cache: cache, + Email: email, + HostPolicy: autocert.HostWhitelist(domains...), + Prompt: autocert.AcceptTOS, + Client: &acme.Client{ + DirectoryURL: directoryURL, + UserAgent: userAgent, + HTTPClient: http.DefaultClient, + }, + } + + return &Letsencrypt{ + manager: manager, + } +} + +func (le *Letsencrypt) Handler(fallback http.Handler) http.Handler { + return le.manager.HTTPHandler(fallback) +} + +func (le *Letsencrypt) TLSConfig() *tls.Config { + return le.manager.TLSConfig() +} + +// Listener returns a net.Listener that need to be closed on exit or +// you leak a goroutine +func (le *Letsencrypt) Listener() net.Listener { + return le.manager.Listener() +} + +func (le *Letsencrypt) Client() *acme.Client { + return le.manager.Client +} + +func (le *Letsencrypt) Close() { + le.Listener().Close() +} + +var domainRegex = regexp.MustCompile("^[a-z0-9]+$") + +func validateDomain(s string) bool { + i := 0 + for w := range strings.SplitSeq(s, ".") { + if !domainRegex.MatchString(w) { + return false + } + i++ + } + return i > 1 +} diff --git a/net/letsencrypt_test.go b/net/letsencrypt_test.go new file mode 100644 index 0000000000..610633aeb3 --- /dev/null +++ b/net/letsencrypt_test.go @@ -0,0 +1,116 @@ +package net + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zalando/skipper/net/redistest" +) + +func TestRemoteCache(t *testing.T) { + t.Logf("create redis..") + redisAddr, done := redistest.NewTestRedis(t) + defer done() + if redisAddr == "" { + t.Fatal("Failed to create redis 1") + } + + redisAddr2, done2 := redistest.NewTestRedis(t) + defer done2() + if redisAddr2 == "" { + t.Fatal("Failed to create redis 2") + } + + rc := RemoteCache{ + Client: NewRedisRingClient(&RedisOptions{ + Addrs: []string{redisAddr, redisAddr2}, + }), + } + defer rc.Close() + + if err := rc.Put(context.Background(), "foo", []byte("bar")); err != nil { + t.Fatalf("Failed to put: %v", err) + } + + if v, err := rc.Get(context.Background(), "foo"); err != nil { + t.Fatalf("Failed to get: %v", err) + } else { + t.Logf("%T %v %s", v, v, v) + if string(v) != "bar" { + t.Fatalf("Failed to get result, got: %q", string(v)) + } + } + + if err := rc.Delete(context.Background(), "foo"); err != nil { + t.Fatalf("Failed to delete: %v", err) + } +} + +func TestInmemoryCache(t *testing.T) { + rc := &InmemoryCache{} + + if _, err := rc.Get(context.Background(), "foo"); err == nil { + t.Fatal(`Failed can not get "foo" on empty cache`) + } + + if err := rc.Put(context.Background(), "foo", []byte("bar")); err != nil { + t.Fatalf("Failed to put: %v", err) + } + + if v, err := rc.Get(context.Background(), "foo"); err != nil { + t.Fatalf("Failed to get: %v", err) + } else { + t.Logf("%T %v %s", v, v, v) + } + + if err := rc.Delete(context.Background(), "foo"); err != nil { + t.Fatalf("Failed to delete: %v", err) + } + + if err := rc.Put(context.Background(), "foo2", []byte("ΓΌ")); err != nil { + t.Fatalf("Failed to put: %v", err) + } + + if v, err := rc.Get(context.Background(), "foo2"); err != nil { + t.Fatalf("Failed to get: %v", err) + } else { + t.Logf("%T %v %s", v, v, v) + } + +} + +func TestLetsencrypt(t *testing.T) { + invalidDomain := "s_.example.org" + if validateDomain(invalidDomain) { + t.Fatalf("Failed to validate invalid domain %q", invalidDomain) + } + validDomain := "example.org" + if !validateDomain(validDomain) { + t.Fatalf("Failed to validate valid domain %q", validDomain) + } + + le := NewLetsencrypt(&InmemoryCache{}, "skipper@example.org", "https://acme-staging-v02.api.letsencrypt.org/directory", "skipper-test TestLetsencrypt", []string{validDomain}) + defer le.Close() + if le.manager.Client != nil { + dir, err := le.manager.Client.Discover(context.TODO()) + if err != nil { + t.Fatalf("Failed to discover: %v", err) + } + t.Logf("order: %s", dir.OrderURL) + + defer func() { + if le.manager.Client.HTTPClient != nil { + le.manager.Client.HTTPClient.CloseIdleConnections() + } + }() + } + + require.NotNil(t, le.Client(), "client should not be nil") + require.NotNil(t, le.TLSConfig(), "TLSConfig should not be nil") + require.NotNil(t, le.Handler(nil), "http.Handler should not be nil") + + li := le.Listener() + defer li.Close() + t.Logf("listener %v", li.Addr()) +} diff --git a/net/redisclient.go b/net/redisclient.go index 5523683fa2..83463e4a0f 100644 --- a/net/redisclient.go +++ b/net/redisclient.go @@ -391,6 +391,11 @@ func (r *RedisRingClient) SetAddrs(ctx context.Context, addrs []string) { r.ring.SetAddrs(createAddressMap(addrs)) } +func (r *RedisRingClient) Del(ctx context.Context, key string) error { + res := r.ring.Del(ctx, key) + return res.Err() +} + func (r *RedisRingClient) Get(ctx context.Context, key string) (string, error) { res := r.ring.Get(ctx, key) return res.Val(), res.Err()