Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 11 additions & 47 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"net/http"
"os"
"path/filepath"

"golang.org/x/oauth2"
)

// FleetCredentials holds Mercedes-Benz Fleet API credentials (OAuth2 + Kafka).
Expand Down Expand Up @@ -36,10 +38,10 @@ type VehicleSpecCredentialStore interface {
Clear() error
}

// Store reads and writes JSON-serializable data.
type Store interface {
Read(target any) error
Write(data any) error
// TokenStore reads and writes OAuth2 tokens.
type TokenStore interface {
Load() (*oauth2.Token, error)
Save(*oauth2.Token) error
Clear() error
}

Expand All @@ -49,7 +51,7 @@ type Option func(*config)
type config struct {
fleetCredentialStore FleetCredentialStore
vehicleSpecCredentialStore VehicleSpecCredentialStore
tokenStore Store
tokenStore TokenStore
httpClient *http.Client
}

Expand All @@ -64,7 +66,7 @@ func WithVehicleSpecCredentialStore(s VehicleSpecCredentialStore) Option {
}

// WithTokenStore sets the token store.
func WithTokenStore(s Store) Option {
func WithTokenStore(s TokenStore) Option {
return func(c *config) { c.tokenStore = s }
}

Expand Down Expand Up @@ -122,45 +124,7 @@ func NewVehicleSpecCredentialFileStore(path string) VehicleSpecCredentialStore {
return &fileStore[VehicleSpecCredentials]{path: path}
}

// FileStore is a file-backed store that uses encoding/json.
type FileStore struct {
path string
}

// NewFileStore creates a new file-backed store at the given path.
func NewFileStore(path string) *FileStore {
return &FileStore{path: path}
}

// Read unmarshals the file contents into target.
func (s *FileStore) Read(target any) error {
data, err := os.ReadFile(s.path)
if err != nil {
return fmt.Errorf("read store: %w", err)
}
if err := json.Unmarshal(data, target); err != nil {
return fmt.Errorf("unmarshal store: %w", err)
}
return nil
}

// Write marshals data and writes it to the file.
func (s *FileStore) Write(data any) error {
bytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Errorf("marshal store: %w", err)
}
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
return fmt.Errorf("create store dir: %w", err)
}
return os.WriteFile(s.path, bytes, 0o600)
}

// Clear removes the file.
func (s *FileStore) Clear() error {
err := os.Remove(s.path)
if err != nil && os.IsNotExist(err) {
return nil
}
return err
// NewTokenFileStore creates a file-backed token store.
func NewTokenFileStore(path string) TokenStore {
return &fileStore[oauth2.Token]{path: path}
}
27 changes: 18 additions & 9 deletions cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func newLoginFleetCommand(cfg *config) *cobra.Command {
return err
}
if cfg.tokenStore != nil {
if err := cfg.tokenStore.Write(token); err != nil {
if err := cfg.tokenStore.Save(token); err != nil {
return fmt.Errorf("write token: %w", err)
}
}
Expand Down Expand Up @@ -617,9 +617,10 @@ func newConsumeVehicleSignalsCommand(cfg *config) *cobra.Command {
if err != nil {
return err
}
var token oauth2.Token
var token *oauth2.Token
if cfg.tokenStore != nil {
if err := cfg.tokenStore.Read(&token); err != nil {
token, err = cfg.tokenStore.Load()
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("no credentials found, please login using `mbz auth login fleet`")
}
Expand Down Expand Up @@ -662,8 +663,12 @@ func newConsumeVehicleSignalsCommand(cfg *config) *cobra.Command {
kgo.ConsumerGroup(consumerGroup),
kgo.ConsumeTopics(topic),
kgo.SASL(oauth.Oauth(func(_ context.Context) (oauth.Auth, error) {
var accessToken string
if token != nil {
accessToken = token.AccessToken
}
return oauth.Auth{
Token: token.AccessToken,
Token: accessToken,
}, nil
})),
}
Expand Down Expand Up @@ -761,16 +766,17 @@ func newOAuth2Client(cmd *cobra.Command, cfg *config) (*mbz.Client, error) {
if err != nil {
return nil, err
}
var token oauth2.Token
var token *oauth2.Token
if cfg.tokenStore != nil {
if err := cfg.tokenStore.Read(&token); err != nil {
token, err = cfg.tokenStore.Load()
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("no credentials found, please login using `mbz auth login fleet`")
}
return nil, fmt.Errorf("read token: %w", err)
}
}
if token.Expiry.Before(time.Now()) {
if token == nil || token.Expiry.Before(time.Now()) {
return nil, fmt.Errorf("invalid token, please login using `mbz auth login fleet`")
}
region, err := resolveOAuth2Region(creds, token)
Expand All @@ -779,7 +785,7 @@ func newOAuth2Client(cmd *cobra.Command, cfg *config) (*mbz.Client, error) {
}
opts := []mbz.ClientOption{
mbz.WithRegion(region),
mbz.WithOAuth2TokenSource(oauth2.StaticTokenSource(&token)),
mbz.WithOAuth2TokenSource(oauth2.StaticTokenSource(token)),
}
if cfg.httpClient != nil {
opts = append(opts, mbz.WithHTTPClient(cfg.httpClient))
Expand Down Expand Up @@ -818,10 +824,13 @@ func promptSecret(cmd *cobra.Command, prompt string) (string, error) {
return string(input), nil
}

func resolveOAuth2Region(creds *FleetCredentials, token oauth2.Token) (mbz.Region, error) {
func resolveOAuth2Region(creds *FleetCredentials, token *oauth2.Token) (mbz.Region, error) {
if creds.Region != "" {
return mbz.Region(creds.Region), nil
}
if token == nil {
return "", fmt.Errorf("missing region and token")
}
return inferRegionFromAccessToken(token.AccessToken)
}

Expand Down
4 changes: 2 additions & 2 deletions cli/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestResolveOAuth2RegionUsesStoredRegionFirst(t *testing.T) {
}
region, err := resolveOAuth2Region(
creds,
oauth2.Token{AccessToken: testJWT("https://ssoalpha.dvb.corpinter.net/v1")},
&oauth2.Token{AccessToken: testJWT("https://ssoalpha.dvb.corpinter.net/v1")},
)
if err != nil {
t.Fatalf("resolve region: %v", err)
Expand All @@ -43,7 +43,7 @@ func TestResolveOAuth2RegionInfersFromTokenIssuer(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

region, err := resolveOAuth2Region(&FleetCredentials{}, oauth2.Token{
region, err := resolveOAuth2Region(&FleetCredentials{}, &oauth2.Token{
AccessToken: testJWT(tt.iss),
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion cmd/mbz/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func main() {
cmd := cli.NewCommand(
cli.WithFleetCredentialStore(cli.NewFleetCredentialFileStore(fleetCredPath)),
cli.WithVehicleSpecCredentialStore(cli.NewVehicleSpecCredentialFileStore(vspecCredPath)),
cli.WithTokenStore(cli.NewFileStore(tokenPath)),
cli.WithTokenStore(cli.NewTokenFileStore(tokenPath)),
cli.WithHTTPClient(&http.Client{
Transport: &mbz.DebugTransport{Enabled: &debug},
}),
Expand Down
Loading