From f575a68fbd4d18a486ee669221f377b656ab8e85 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 21:43:25 +1100 Subject: [PATCH 01/11] feat: add error handling for initializers BREAKING CHANGE: adds err to several functions that previously just panicked --- README.md | 21 +++++++--- examples/keypress/main.go | 13 ++++-- examples/lightshow/main.go | 21 +++++++--- examples/reset/main.go | 13 ++++-- examples/volume_wheel/main.go | 17 ++++++-- input/battery_report.go | 8 ++-- input/input_report.go | 8 ++-- input/keypress_report.go | 8 ++-- speededitor.go | 77 +++++++++++++++++++++++------------ 9 files changed, 127 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 1496ca3..0f7b8d9 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,11 @@ Don't forget to defer a call to `Exit` to avoid memory leaks. Next we can initialize the client: -``` -client := speedEditor.NewClient() +```go +client, err := speedEditor.NewClient() +if err != nil { + log.Fatal(err) +} ``` This connects to the Speed Editor, requests the manufacturer info, and device info such as serial number, and sets up the default event handlers. @@ -118,8 +121,12 @@ keysByName := keys.ByName() leds := []uint32{keysByName[keys.CAM7.Led], keysByName[keys.CAM5.Led], keysByName[keys.CAM3.Led]} jogLeds := []uint8{keysByName[keys.SHTL.JogLed], keysByName[keys.JOG.JogLed], keysByName[keys.SCROLL.JogLed]} -client.SetLeds(leds) -client.SetJogLeds(jogLeds) +if err := client.SetLeds(leds); err != nil { + log.Fatal(err) +} +if err := client.SetJogLeds(jogLeds); err != nil { + log.Fatal(err) +} ``` `JOG`/`SCRL`/`SHTL` are on a different system to the other LEDs, so a different function is required to light them. @@ -137,8 +144,10 @@ Finally, there are a few different jog modes available on the device: You can switch modes via the client: -``` -client.SetJogMode(jogModes.ABSOLUTE) +```go +if err := client.SetJogMode(jogModes.ABSOLUTE); err != nil { + log.Fatal(err) +} ``` You will have to handle lighting the buttons yourself, if you want the modes to work like they do with the editor connected to Davinci. diff --git a/examples/keypress/main.go b/examples/keypress/main.go index e04ab75..ccf762c 100644 --- a/examples/keypress/main.go +++ b/examples/keypress/main.go @@ -16,7 +16,10 @@ func main() { } defer hid.Exit() - client := speedEditor.NewClient() + client, err := speedEditor.NewClient() + if err != nil { + log.Fatal(err) + } deviceInfo := client.GetDeviceInfo() @@ -32,10 +35,14 @@ func main() { func customKeyPressHandler(client speedEditor.SpeedEditorInt, report input.KeyPressReport) { for _, key := range report.Keys { if key.Led != keys.LED_NONE { - client.SetLeds([]uint32{key.Led}) + if err := client.SetLeds([]uint32{key.Led}); err != nil { + log.Printf("error setting LEDs: %v", err) + } } if key.JogLed != keys.LED_NONE { - client.SetJogLeds([]uint8{key.JogLed}) + if err := client.SetJogLeds([]uint8{key.JogLed}); err != nil { + log.Printf("error setting jog LEDs: %v", err) + } } } } diff --git a/examples/lightshow/main.go b/examples/lightshow/main.go index a1c97b3..2f1082d 100644 --- a/examples/lightshow/main.go +++ b/examples/lightshow/main.go @@ -16,7 +16,10 @@ func main() { } defer hid.Exit() - client := speedEditor.NewClient() + client, err := speedEditor.NewClient() + if err != nil { + log.Fatal(err) + } deviceInfo := client.GetDeviceInfo() @@ -38,8 +41,12 @@ func main() { leds = append(leds, key.Led) } } - client.SetLeds(leds) - client.SetJogLeds(jogLeds) + if err := client.SetLeds(leds); err != nil { + log.Printf("error setting LEDs: %v", err) + } + if err := client.SetJogLeds(jogLeds); err != nil { + log.Printf("error setting jog LEDs: %v", err) + } time.Sleep(75 * time.Millisecond) } @@ -54,8 +61,12 @@ func main() { leds = append(leds, key.Led) } } - client.SetLeds(leds) - client.SetJogLeds(jogLeds) + if err := client.SetLeds(leds); err != nil { + log.Printf("error setting LEDs: %v", err) + } + if err := client.SetJogLeds(jogLeds); err != nil { + log.Printf("error setting jog LEDs: %v", err) + } time.Sleep(75 * time.Millisecond) } } diff --git a/examples/reset/main.go b/examples/reset/main.go index b2c520c..7715bf4 100644 --- a/examples/reset/main.go +++ b/examples/reset/main.go @@ -14,7 +14,10 @@ func main() { } defer hid.Exit() - client := speedEditor.NewClient() + client, err := speedEditor.NewClient() + if err != nil { + log.Fatal(err) + } deviceInfo := client.GetDeviceInfo() @@ -22,6 +25,10 @@ func main() { client.Authenticate() - client.SetLeds([]uint32{}) - client.SetJogLeds([]uint8{}) + if err := client.SetLeds([]uint32{}); err != nil { + log.Fatal(err) + } + if err := client.SetJogLeds([]uint8{}); err != nil { + log.Fatal(err) + } } diff --git a/examples/volume_wheel/main.go b/examples/volume_wheel/main.go index 603bf92..8c369ff 100644 --- a/examples/volume_wheel/main.go +++ b/examples/volume_wheel/main.go @@ -23,7 +23,10 @@ func main() { } defer hid.Exit() - client := speedEditor.NewClient() + client, err := speedEditor.NewClient() + if err != nil { + log.Fatal(err) + } deviceInfo := client.GetDeviceInfo() @@ -31,7 +34,9 @@ func main() { client.Authenticate() - client.SetJogMode(jogModes.ID_ABSOLUTE) + if err := client.SetJogMode(jogModes.ID_ABSOLUTE); err != nil { + log.Fatal(err) + } client.SetJogHandler(customJogHandler) client.SetKeyPressHandler(speedEditor.NullKeyPressHandler) @@ -61,6 +66,10 @@ func setLeds(client speedEditor.SpeedEditorInt, percent float64) { } } - client.SetLeds(leds) - client.SetJogLeds(jogLeds) + if err := client.SetLeds(leds); err != nil { + log.Printf("error setting LEDs: %v", err) + } + if err := client.SetJogLeds(jogLeds); err != nil { + log.Printf("error setting jog LEDs: %v", err) + } } diff --git a/input/battery_report.go b/input/battery_report.go index 8d1114b..a59dc30 100644 --- a/input/battery_report.go +++ b/input/battery_report.go @@ -1,19 +1,19 @@ package input import ( - "log" + "fmt" ) -func NewBatteryReport(id byte, payload []byte) BatteryReport { +func NewBatteryReport(id byte, payload []byte) (BatteryReport, error) { if id != ReportBattery { - log.Fatalf("malformed battery stats report id: %v payload: %v", id, payload) + return BatteryReport{}, fmt.Errorf("malformed battery stats report id: %v payload: %v", id, payload) } return BatteryReport{ Id: id, Charging: payload[0] == 1, Battery: float32(payload[1]) / 255, - } + }, nil } type BatteryReport struct { diff --git a/input/input_report.go b/input/input_report.go index df398c8..222b98c 100644 --- a/input/input_report.go +++ b/input/input_report.go @@ -1,6 +1,8 @@ package input import ( + "fmt" + "github.com/JamesBalazs/speed-editor-client/keys" ) @@ -22,19 +24,19 @@ var ( type ReportBytes []byte -func (byt ReportBytes) ToReport() any { +func (byt ReportBytes) ToReport() (any, error) { id := byt[0] payload := byt[1:] switch id { case ReportJog: - return NewJogReport(id, payload) + return NewJogReport(id, payload), nil case ReportKeyPress: return NewKeyPressReport(id, payload) case ReportBattery: return NewBatteryReport(id, payload) default: - return nil + return nil, fmt.Errorf("unknown report id: %v", id) } } diff --git a/input/keypress_report.go b/input/keypress_report.go index 2742a51..b3e0e62 100644 --- a/input/keypress_report.go +++ b/input/keypress_report.go @@ -2,14 +2,14 @@ package input import ( "encoding/binary" - "log" + "fmt" "github.com/JamesBalazs/speed-editor-client/keys" ) -func NewKeyPressReport(id byte, payload []byte) KeyPressReport { +func NewKeyPressReport(id byte, payload []byte) (KeyPressReport, error) { if id != ReportKeyPress { - log.Fatalf("malformed keypress input report id: %v payload: %v", id, payload) + return KeyPressReport{}, fmt.Errorf("malformed keypress input report id: %v payload: %v", id, payload) } keys := []keys.Key{} @@ -24,7 +24,7 @@ func NewKeyPressReport(id byte, payload []byte) KeyPressReport { return KeyPressReport{ Id: id, Keys: keys, - } + }, nil } type KeyPressReport struct { diff --git a/speededitor.go b/speededitor.go index cafb2bf..c8df6e0 100644 --- a/speededitor.go +++ b/speededitor.go @@ -3,7 +3,6 @@ package speedEditor import ( "encoding/binary" "fmt" - "log" "time" "github.com/JamesBalazs/speed-editor-client/input" @@ -26,19 +25,22 @@ const ( // creating the Speed Editor client, with `hid.Init()`. // // Ensure to use `defer hid.Exit()` to avoid memory leaks. -func NewClient() SpeedEditorInt { +func NewClient() (SpeedEditorInt, error) { device, err := hid.OpenFirst(VID, PID) if err != nil { - log.Fatal(err) + return nil, fmt.Errorf("failed to create client: %w", err) } speedEditor := &SpeedEditor{ device: device, AuthHandler: AuthHandler{device}, } - speedEditor.initialize() - return speedEditor + if err = speedEditor.initialize(); err != nil { + return nil, fmt.Errorf("failed to initialize: %w", err) + } + + return speedEditor, nil } type SpeedEditorInt interface { @@ -56,7 +58,7 @@ type SpeedEditorInt interface { // the data length. // // The first byte indicates which report type was received. - Read() ([]byte, int) + Read() ([]byte, int, error) // Poll starts a Read loop, parses each input report and calls Handle on each // via the Report interface. @@ -67,14 +69,14 @@ type SpeedEditorInt interface { // // SetLeds does not keep any state, so it will reset any previously enabled // LEDs if they aren't included in the next call. - SetLeds(leds []uint32) + SetLeds(leds []uint32) error // SetJogMode switches between the 4 jog modes: // RELATIVE - Relative position // ABSOLUTE - Absolute position from -4096 to 4096 // RELATIVE2 - Relative position, I think this is used to enable a faster scroll mode when the SCRL button is pressed twice in Resolve: https://www.reddit.com/r/blackmagicdesign/comments/1dv56d4/speed_editor_firmware_update_dial_speed_change/ // ABSOLUTE_0 - Absolute position from -4096 to 4096 with deadzone around 0 - SetJogMode(mode uint8) + SetJogMode(mode uint8) error // SetJogLeds accepts the bitmask for a list of LEDs, and binary ORs the bitmask // to enable all LEDs in the mask. Jog LEDs are on a separate system, and overlap @@ -82,7 +84,7 @@ type SpeedEditorInt interface { // // SetJogLeds does not keep any state, so it will reset any previously enabled // LEDs if they aren't included in the next call. - SetJogLeds(leds []uint8) + SetJogLeds(leds []uint8) error // SetJogHandler allows replacing the handler function that will be called on Poll() // when a JogReport is received. @@ -109,10 +111,10 @@ type SpeedEditor struct { // initialize grabs the device's serial number, manufacturer string etc via HID. // The handshake is not required before this step can take place. -func (se *SpeedEditor) initialize() { +func (se *SpeedEditor) initialize() error { deviceInfo, err := se.device.GetDeviceInfo() if err != nil { - log.Fatal(err) + return fmt.Errorf("failed to get device info: %w", err) } se.deviceInfo = *deviceInfo @@ -120,6 +122,8 @@ func (se *SpeedEditor) initialize() { se.SetJogHandler(defaultJogHandler) se.SetBatteryHandler(defaultBatteryHandler) se.SetKeyPressHandler(defaultKeyPressHandler) + + return nil } func (se SpeedEditor) Authenticate() { @@ -127,12 +131,8 @@ func (se SpeedEditor) Authenticate() { // Do not read or update this outside of the goroutine to avoid a data race. reAuthSeconds := se.AuthHandler.Authenticate() - fmt.Printf("Initial handshake\n") - go func() { for { - fmt.Printf("Sleeping %s\n", reAuthSeconds) - time.Sleep(reAuthSeconds) reAuthSeconds = se.AuthHandler.Authenticate() @@ -144,21 +144,29 @@ func (se SpeedEditor) GetDeviceInfo() hid.DeviceInfo { return se.deviceInfo } -func (se SpeedEditor) Read() ([]byte, int) { +func (se SpeedEditor) Read() ([]byte, int, error) { data := make([]byte, 9) - len, err := se.device.Read(data) + n, err := se.device.Read(data) if err != nil { - log.Fatal(err) + return nil, 0, fmt.Errorf("failed to read from device: %w", err) } - return data, len + return data, n, nil } func (se SpeedEditor) Poll() { for { - data, _ := se.Read() - report := input.ReportBytes(data).ToReport() + data, _, err := se.Read() + if err != nil { + fmt.Printf("error reading from device: %v\n", err) + continue + } + report, parseErr := input.ReportBytes(data).ToReport() + if parseErr != nil { + fmt.Printf("error parsing report: %v\n", parseErr) + continue + } se.HandleReport(report) } } @@ -174,7 +182,7 @@ func (se SpeedEditor) HandleReport(genericReport any) { } // TODO handle unknown reports (log error and continue) } -func (se SpeedEditor) SetLeds(leds []uint32) { +func (se SpeedEditor) SetLeds(leds []uint32) error { payload := make([]byte, 5) payload[0] = LedReportId @@ -184,20 +192,30 @@ func (se SpeedEditor) SetLeds(leds []uint32) { } binary.LittleEndian.PutUint32(payload[1:], bitField) - se.device.Write(payload) + _, err := se.device.Write(payload) + if err != nil { + return fmt.Errorf("failed to set LEDs: %w", err) + } + + return nil } -func (se SpeedEditor) SetJogMode(mode uint8) { +func (se SpeedEditor) SetJogMode(mode uint8) error { payload := make([]byte, 7) payload[0] = JogModeReportId payload[1] = mode // bytes 3-6 are zero payload[6] = 255 // byte 7 has unknown purpose - se.device.Write(payload) + _, err := se.device.Write(payload) + if err != nil { + return fmt.Errorf("failed to set jog mode: %w", err) + } + + return nil } -func (se SpeedEditor) SetJogLeds(leds []uint8) { +func (se SpeedEditor) SetJogLeds(leds []uint8) error { payload := make([]byte, 2) payload[0] = JogLedReportId @@ -207,7 +225,12 @@ func (se SpeedEditor) SetJogLeds(leds []uint8) { } payload[1] = bitField - se.device.Write(payload) + _, err := se.device.Write(payload) + if err != nil { + return fmt.Errorf("failed to set jog LEDs: %w", err) + } + + return nil } func (se *SpeedEditor) SetJogHandler(handler func(SpeedEditorInt, input.JogReport)) { From df6bea73e248578a44ff86c85e58e81d48165501 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 21:49:38 +1100 Subject: [PATCH 02/11] docs: add language to all code blocks for syntax highlighting --- README.md | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 0f7b8d9..4713224 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ View docs on pkg.go.dev: To import as a dependency: -``` +```go go get github.com/JamesBalazs/speed-editor-client ``` @@ -36,7 +36,7 @@ The project depends on the [go-hid](https://github.com/sstallion/go-hid) library Before creating a Speed Editor client, we need to initialize the HID library: -``` +```go if err := hid.Init(); err != nil { log.Fatal(err) } @@ -58,7 +58,7 @@ This connects to the Speed Editor, requests the manufacturer info, and device in Device info is cached on initialize, since it will never change once the device is connected:= -``` +```go deviceInfo := client.GetDeviceInfo() fmt.Printf("Manufacturer: %s\nProduct: %s\nSerial: %s\n", deviceInfo.MfrStr, deviceInfo.ProductStr, deviceInfo.SerialNbr) @@ -75,13 +75,13 @@ I re-implemented his authentication algorithm in Go, and exported the underlying When using the client, you just need to call `Authenticate` before sending / receiving any messages, and the handshake will be handled for you: -``` +```go client.Authenticate() ``` Finally, to receive messages from the Speed Editor, you can call `Poll`. This will start a loop which does a blocking read, waiting for either a keypress, battery report, or jog wheel movement from the device: -``` +```go client.Poll() ``` @@ -89,7 +89,7 @@ When any of the aforementioned events happen, the corresponding Handler function The event handlers can be overridden by the user to implement custom functionality: -``` +```go func customJogHandler(client speedEditor.SpeedEditorInt, report input.JogReport) { fmt.Printf("Jog wheel position: %d\n", report.Value) } @@ -115,7 +115,7 @@ This helps light LEDs based on their position such as in the lightshow and volum You can light any combination of LEDs on the board: -``` +```go keysByName := keys.ByName() leds := []uint32{keysByName[keys.CAM7.Led], keysByName[keys.CAM5.Led], keysByName[keys.CAM3.Led]} @@ -154,13 +154,13 @@ You will have to handle lighting the buttons yourself, if you want the modes to You can also get a list of keys and their attributes, with deterministic ordering: -``` +```go keys.Get() ``` And maps / indexes of keys by name, ID, LED ID etc so you can easily retrieve key details given only a single attribute, in constant time: -``` +```go keys.ById() keys.ByName() keys.ByLedId() @@ -173,7 +173,7 @@ keys.ByRow() The same goes for jog modes: -``` +```go jogModes.Get() jogModes.ById() @@ -189,44 +189,44 @@ This is done since manipulating the maps could mess up the underlying key data i My setup is weird (WSL remote via Zed) so some extra steps are required to pass the Speed Editor through to WSL Installing [usbipd](https://github.com/dorssel/usbipd-win): -``` +```bash winget install usbipd ``` Listing devices: -``` +```bash usbipd list ``` Binding the Speed Editor (persists reboot, your BUSID will be different to mine): -``` +```bash sudo usbipd bind --busid=4-9 ``` Attaching to WSL (does not persist reboot): -``` +```bash sudo usbipd attach --wsl --busid=4-9 ``` To confirm w/ [lshid](https://github.com/FFY00/lshid) within WSL: -``` +```bash $HOME/go/bin/lshid ``` Should output something like `/dev/hidraw0: ID 1edb:da0e Blackmagic Design DaVinci Resolve Speed Editor` ### Deps -``` +```bash sudo dnf install systemd-devel ``` To get `libudev.h` on Fedora (required for lshid) I then had permission issues reading from `/dev/hidraw0` so had to create a [udev](https://wiki.archlinux.org/title/Udev) rule: -``` +```bash KERNEL=="hidraw*", SUBSYSTEM=="hidraw", MODE="0660", GROUP="plugdev" ``` in `/etc/udev/rules.d/99-hidraw-permissions.rules`, then: -``` +```bash sudo groupadd plugdev sudo usermod -a -G plugdev james sudo udevadm control --reload @@ -238,12 +238,12 @@ After this `stat /dev/hidraw0` should list the new plugdev group. ### Cross platform builds for Windows mingw-w64 is required to compile the HID library on Linux for Windows with CGO. Installation: -``` +```bash sudo dnf install mingw64-gcc ``` To build the examples: -``` +```bash GOOS=windows GOARCH=amd64 CGO_ENABLED=1 CXX=x86_64-w64-mingw32-g++ CC=x86_64-w64-mingw32-gcc go build main.go ``` From 288b65441ff8d269bc47e340fdb431fb964d4721 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 22:28:25 +1100 Subject: [PATCH 03/11] test: add unit tests for client --- auth.go | 3 +- go.mod | 13 +- go.sum | 12 + speededitor.go | 47 ++- speededitor_test.go | 777 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 845 insertions(+), 7 deletions(-) create mode 100644 speededitor_test.go diff --git a/auth.go b/auth.go index b21226c..3882e1c 100644 --- a/auth.go +++ b/auth.go @@ -11,7 +11,6 @@ import ( "time" "github.com/JamesBalazs/speed-editor-client/auth" - "github.com/sstallion/go-hid" ) var ( @@ -38,7 +37,7 @@ type AuthHandlerInt interface { } type AuthHandler struct { - device *hid.Device + device deviceInterface } // Authenticate handles the entire handshake between the host and the Speed Editor. diff --git a/go.mod b/go.mod index 38ddcd4..8d8a8cd 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,15 @@ module github.com/JamesBalazs/speed-editor-client go 1.26.1 -require github.com/sstallion/go-hid v0.15.0 +require ( + github.com/sstallion/go-hid v0.15.0 + github.com/stretchr/testify v1.11.1 +) -require golang.org/x/sys v0.8.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + golang.org/x/sys v0.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 5995c77..5be9a2a 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,16 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sstallion/go-hid v0.15.0 h1:WERW/VW3Us6N73V2qa7HjdqWQvwHd0CoRDOP/N707/w= github.com/sstallion/go-hid v0.15.0/go.mod h1:fPKp4rqx0xuoTV94gwKojsPG++KNKhxuU88goGuGM7I= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/speededitor.go b/speededitor.go index c8df6e0..81d10a5 100644 --- a/speededitor.go +++ b/speededitor.go @@ -18,6 +18,45 @@ const ( JogLedReportId = 4 ) +// deviceInterface defines the HID device operations for testability +type deviceInterface interface { + Close() error + Read(buf []byte) (int, error) + Write(buf []byte) (int, error) + GetDeviceInfo() (*hid.DeviceInfo, error) + GetFeatureReport(buf []byte) (int, error) + SendFeatureReport(buf []byte) (int, error) +} + +// hidDeviceWrapper wraps the hid.Device to implement deviceInterface +type hidDeviceWrapper struct { + device *hid.Device +} + +func (w *hidDeviceWrapper) Close() error { + return w.device.Close() +} + +func (w *hidDeviceWrapper) Read(buf []byte) (int, error) { + return w.device.Read(buf) +} + +func (w *hidDeviceWrapper) Write(buf []byte) (int, error) { + return w.device.Write(buf) +} + +func (w *hidDeviceWrapper) GetDeviceInfo() (*hid.DeviceInfo, error) { + return w.device.GetDeviceInfo() +} + +func (w *hidDeviceWrapper) GetFeatureReport(buf []byte) (int, error) { + return w.device.GetFeatureReport(buf) +} + +func (w *hidDeviceWrapper) SendFeatureReport(buf []byte) (int, error) { + return w.device.SendFeatureReport(buf) +} + // NewClient connects to a Speed Editor via the HID library // and returns a SpeedEditorInt to interact with the device. // @@ -31,9 +70,11 @@ func NewClient() (SpeedEditorInt, error) { return nil, fmt.Errorf("failed to create client: %w", err) } + wrapper := &hidDeviceWrapper{device: device} + speedEditor := &SpeedEditor{ - device: device, - AuthHandler: AuthHandler{device}, + device: wrapper, + AuthHandler: AuthHandler{device: wrapper}, } if err = speedEditor.initialize(); err != nil { @@ -98,7 +139,7 @@ type SpeedEditorInt interface { } type SpeedEditor struct { - device *hid.Device + device deviceInterface deviceInfo hid.DeviceInfo activeLeds []uint32 diff --git a/speededitor_test.go b/speededitor_test.go new file mode 100644 index 0000000..219aa8d --- /dev/null +++ b/speededitor_test.go @@ -0,0 +1,777 @@ +package speedEditor + +import ( + "encoding/binary" + "errors" + "testing" + "time" + + "github.com/JamesBalazs/speed-editor-client/input" + "github.com/JamesBalazs/speed-editor-client/keys" + "github.com/sstallion/go-hid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockHIDDevice is a mock implementation of deviceInterface for testing +type MockHIDDevice struct { + mock.Mock +} + +func (m *MockHIDDevice) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockHIDDevice) Read(buf []byte) (int, error) { + args := m.Called(buf) + return args.Int(0), args.Error(1) +} + +func (m *MockHIDDevice) Write(buf []byte) (int, error) { + args := m.Called(buf) + return args.Int(0), args.Error(1) +} + +func (m *MockHIDDevice) GetDeviceInfo() (*hid.DeviceInfo, error) { + args := m.Called() + return args.Get(0).(*hid.DeviceInfo), args.Error(1) +} + +func (m *MockHIDDevice) GetFeatureReport(buf []byte) (int, error) { + args := m.Called(buf) + return args.Int(0), args.Error(1) +} + +func (m *MockHIDDevice) SendFeatureReport(buf []byte) (int, error) { + args := m.Called(buf) + return args.Int(0), args.Error(1) +} + +// setupFixture creates a SpeedEditor instance with a mocked HID device +func setupFixture(t *testing.T) (*SpeedEditor, *MockHIDDevice) { + mockDevice := new(MockHIDDevice) + deviceInfo := &hid.DeviceInfo{ + MfrStr: "Test Manufacturer", + ProductStr: "Test Product", + SerialNbr: "TEST123", + } + + se := &SpeedEditor{ + device: mockDevice, + deviceInfo: *deviceInfo, + } + + return se, mockDevice +} + +// setupFixtureWithDeviceInfo creates a SpeedEditor with custom device info +func setupFixtureWithDeviceInfo(t *testing.T, deviceInfo *hid.DeviceInfo) (*SpeedEditor, *MockHIDDevice) { + mockDevice := new(MockHIDDevice) + + se := &SpeedEditor{ + device: mockDevice, + deviceInfo: *deviceInfo, + } + + return se, mockDevice +} + +// TestNewClient tests the NewClient function +func TestNewClient(t *testing.T) { + t.Run("connects to device or returns error", func(t *testing.T) { + // This test will succeed if a physical device is connected, + // or fail with an error if no device is present + client, err := NewClient() + + if err != nil { + // No device connected - verify error message + assert.Contains(t, err.Error(), "failed to create client") + assert.Nil(t, client) + } else { + // Device connected - verify client is valid + assert.NotNil(t, client) + deviceInfo := client.GetDeviceInfo() + assert.NotEmpty(t, deviceInfo.SerialNbr) + } + }) +} + +// TestInitialize tests the initialize method +func TestInitialize(t *testing.T) { + t.Run("success", func(t *testing.T) { + mockDevice := new(MockHIDDevice) + deviceInfo := &hid.DeviceInfo{ + MfrStr: "Test Manufacturer", + ProductStr: "Test Product", + SerialNbr: "TEST123", + } + mockDevice.On("GetDeviceInfo").Return(deviceInfo, nil).Once() + + se := &SpeedEditor{ + device: mockDevice, + AuthHandler: AuthHandler{}, + } + + err := se.initialize() + + require.NoError(t, err) + assert.NotNil(t, se.jogHandler) + assert.NotNil(t, se.batteryHandler) + assert.NotNil(t, se.keyPressHandler) + mockDevice.AssertExpectations(t) + }) + + t.Run("error getting device info", func(t *testing.T) { + mockDevice := new(MockHIDDevice) + mockDevice.On("GetDeviceInfo").Return((*hid.DeviceInfo)(nil), errors.New("device info error")).Once() + + se := &SpeedEditor{ + device: mockDevice, + AuthHandler: AuthHandler{}, + } + + err := se.initialize() + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get device info") + assert.Contains(t, err.Error(), "device info error") + mockDevice.AssertExpectations(t) + }) +} + +// TestGetDeviceInfo tests the GetDeviceInfo method +func TestGetDeviceInfo(t *testing.T) { + t.Run("success returns cached device info", func(t *testing.T) { + deviceInfo := &hid.DeviceInfo{ + MfrStr: "Test Manufacturer", + ProductStr: "Test Product", + SerialNbr: "TEST123", + } + + se := &SpeedEditor{ + deviceInfo: *deviceInfo, + } + + result := se.GetDeviceInfo() + assert.Equal(t, *deviceInfo, result) + }) +} + +// TestRead tests the Read method +func TestRead(t *testing.T) { + t.Run("success", func(t *testing.T) { + se, mockDevice := setupFixture(t) + expectedData := []byte{0x04, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00} + + mockDevice.On("Read", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + copy(buf, expectedData) + }).Return(len(expectedData), nil).Once() + + data, n, err := se.Read() + + require.NoError(t, err) + assert.Equal(t, len(expectedData), n) + assert.Equal(t, expectedData, data) + mockDevice.AssertExpectations(t) + }) + + t.Run("error reading from device", func(t *testing.T) { + se, mockDevice := setupFixture(t) + mockDevice.On("Read", mock.AnythingOfType("[]uint8")).Return(0, errors.New("read error")).Once() + + data, n, err := se.Read() + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read from device") + assert.Contains(t, err.Error(), "read error") + assert.Nil(t, data) + assert.Equal(t, 0, n) + mockDevice.AssertExpectations(t) + }) +} + +// TestSetLeds tests the SetLeds method +func TestSetLeds(t *testing.T) { + t.Run("success single led", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(LedReportId), buf[0]) + expectedBitmask := uint32(0x00000001) + actualBitmask := binary.LittleEndian.Uint32(buf[1:5]) + assert.Equal(t, expectedBitmask, actualBitmask) + }).Return(5, nil).Once() + + leds := []uint32{0x01} + err := se.SetLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("success multiple leds", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(LedReportId), buf[0]) + // Verify OR operation: 0x01 | 0x02 | 0x04 = 0x07 + expectedBitmask := uint32(0x07) + actualBitmask := binary.LittleEndian.Uint32(buf[1:5]) + assert.Equal(t, expectedBitmask, actualBitmask) + }).Return(5, nil).Once() + + leds := []uint32{0x01, 0x02, 0x04} + err := se.SetLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("success empty leds", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(LedReportId), buf[0]) + expectedBitmask := uint32(0x00000000) + actualBitmask := binary.LittleEndian.Uint32(buf[1:5]) + assert.Equal(t, expectedBitmask, actualBitmask) + }).Return(5, nil).Once() + + leds := []uint32{} + err := se.SetLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("error writing to device", func(t *testing.T) { + se, mockDevice := setupFixture(t) + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Return(0, errors.New("write error")).Once() + + leds := []uint32{0x01} + err := se.SetLeds(leds) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to set LEDs") + assert.Contains(t, err.Error(), "write error") + mockDevice.AssertExpectations(t) + }) +} + +// TestSetJogMode tests the SetJogMode method +func TestSetJogMode(t *testing.T) { + t.Run("success absolute mode", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(JogModeReportId), buf[0]) + assert.Equal(t, byte(1), buf[1]) // ABSOLUTE mode + assert.Equal(t, byte(0), buf[2]) + assert.Equal(t, byte(255), buf[6]) + }).Return(7, nil).Once() + + err := se.SetJogMode(1) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("success relative mode", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(JogModeReportId), buf[0]) + assert.Equal(t, byte(0), buf[1]) // RELATIVE mode + }).Return(7, nil).Once() + + err := se.SetJogMode(0) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("error writing to device", func(t *testing.T) { + se, mockDevice := setupFixture(t) + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Return(0, errors.New("write error")).Once() + + err := se.SetJogMode(1) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to set jog mode") + assert.Contains(t, err.Error(), "write error") + mockDevice.AssertExpectations(t) + }) +} + +// TestSetJogLeds tests the SetJogLeds method +func TestSetJogLeds(t *testing.T) { + t.Run("success single jog led", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(JogLedReportId), buf[0]) + assert.Equal(t, byte(0x01), buf[1]) + }).Return(2, nil).Once() + + leds := []uint8{0x01} + err := se.SetJogLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("success multiple jog leds", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(JogLedReportId), buf[0]) + // Verify OR operation: 0x01 | 0x02 | 0x04 = 0x07 + expectedBitmask := uint8(0x07) + actualBitmask := buf[1] + assert.Equal(t, expectedBitmask, actualBitmask) + }).Return(2, nil).Once() + + leds := []uint8{0x01, 0x02, 0x04} + err := se.SetJogLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("success empty jog leds", func(t *testing.T) { + se, mockDevice := setupFixture(t) + + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + assert.Equal(t, byte(JogLedReportId), buf[0]) + assert.Equal(t, byte(0x00), buf[1]) + }).Return(2, nil).Once() + + leds := []uint8{} + err := se.SetJogLeds(leds) + + require.NoError(t, err) + mockDevice.AssertExpectations(t) + }) + + t.Run("error writing to device", func(t *testing.T) { + se, mockDevice := setupFixture(t) + mockDevice.On("Write", mock.AnythingOfType("[]uint8")).Return(0, errors.New("write error")).Once() + + leds := []uint8{0x01} + err := se.SetJogLeds(leds) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to set jog LEDs") + assert.Contains(t, err.Error(), "write error") + mockDevice.AssertExpectations(t) + }) +} + +// TestSetJogHandler tests the SetJogHandler method +func TestSetJogHandler(t *testing.T) { + t.Run("success", func(t *testing.T) { + se := &SpeedEditor{} + called := false + handler := func(client SpeedEditorInt, report input.JogReport) { + called = true + } + + se.SetJogHandler(handler) + assert.NotNil(t, se.jogHandler) + + se.jogHandler(se, input.JogReport{}) + assert.True(t, called) + }) + + t.Run("success replace handler", func(t *testing.T) { + se := &SpeedEditor{} + firstCalled := false + secondCalled := false + + firstHandler := func(client SpeedEditorInt, report input.JogReport) { + firstCalled = true + } + secondHandler := func(client SpeedEditorInt, report input.JogReport) { + secondCalled = true + } + + se.SetJogHandler(firstHandler) + se.SetJogHandler(secondHandler) + + se.jogHandler(se, input.JogReport{}) + assert.False(t, firstCalled) + assert.True(t, secondCalled) + }) +} + +// TestSetBatteryHandler tests the SetBatteryHandler method +func TestSetBatteryHandler(t *testing.T) { + t.Run("success", func(t *testing.T) { + se := &SpeedEditor{} + called := false + handler := func(client SpeedEditorInt, report input.BatteryReport) { + called = true + } + + se.SetBatteryHandler(handler) + assert.NotNil(t, se.batteryHandler) + + se.batteryHandler(se, input.BatteryReport{}) + assert.True(t, called) + }) + + t.Run("success replace handler", func(t *testing.T) { + se := &SpeedEditor{} + firstCalled := false + secondCalled := false + + firstHandler := func(client SpeedEditorInt, report input.BatteryReport) { + firstCalled = true + } + secondHandler := func(client SpeedEditorInt, report input.BatteryReport) { + secondCalled = true + } + + se.SetBatteryHandler(firstHandler) + se.SetBatteryHandler(secondHandler) + + se.batteryHandler(se, input.BatteryReport{}) + assert.False(t, firstCalled) + assert.True(t, secondCalled) + }) +} + +// TestSetKeyPressHandler tests the SetKeyPressHandler method +func TestSetKeyPressHandler(t *testing.T) { + t.Run("success", func(t *testing.T) { + se := &SpeedEditor{} + called := false + handler := func(client SpeedEditorInt, report input.KeyPressReport) { + called = true + } + + se.SetKeyPressHandler(handler) + assert.NotNil(t, se.keyPressHandler) + + se.keyPressHandler(se, input.KeyPressReport{}) + assert.True(t, called) + }) + + t.Run("success replace handler", func(t *testing.T) { + se := &SpeedEditor{} + firstCalled := false + secondCalled := false + + firstHandler := func(client SpeedEditorInt, report input.KeyPressReport) { + firstCalled = true + } + secondHandler := func(client SpeedEditorInt, report input.KeyPressReport) { + secondCalled = true + } + + se.SetKeyPressHandler(firstHandler) + se.SetKeyPressHandler(secondHandler) + + se.keyPressHandler(se, input.KeyPressReport{}) + assert.False(t, firstCalled) + assert.True(t, secondCalled) + }) +} + +// TestHandleReport tests the HandleReport method +func TestHandleReport(t *testing.T) { + t.Run("jog report", func(t *testing.T) { + se := &SpeedEditor{} + jogCalled := false + se.SetJogHandler(func(client SpeedEditorInt, report input.JogReport) { + jogCalled = true + assert.Equal(t, int32(100), report.Value) + }) + + report := input.JogReport{Value: 100} + se.HandleReport(report) + assert.True(t, jogCalled) + }) + + t.Run("battery report", func(t *testing.T) { + se := &SpeedEditor{} + batteryCalled := false + se.SetBatteryHandler(func(client SpeedEditorInt, report input.BatteryReport) { + batteryCalled = true + assert.Equal(t, float32(0.5), report.Battery) + }) + + report := input.BatteryReport{Battery: 0.5} + se.HandleReport(report) + assert.True(t, batteryCalled) + }) + + t.Run("keypress report", func(t *testing.T) { + se := &SpeedEditor{} + keyPressCalled := false + se.SetKeyPressHandler(func(client SpeedEditorInt, report input.KeyPressReport) { + keyPressCalled = true + assert.Len(t, report.Keys, 1) + }) + + keysByName := keys.ByName() + report := input.KeyPressReport{Keys: []keys.Key{keysByName[keys.CAM1]}} + se.HandleReport(report) + assert.True(t, keyPressCalled) + }) +} + +// TestHandleJog tests the HandleJog method +func TestHandleJog(t *testing.T) { + t.Run("success", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.JogReport + + se.SetJogHandler(func(client SpeedEditorInt, report input.JogReport) { + handlerCalled = true + receivedReport = report + }) + + expectedReport := input.JogReport{ + Id: 3, + Value: 500, + Unknown: 0, + } + + se.HandleJog(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) +} + +// TestHandleBattery tests the HandleBattery method +func TestHandleBattery(t *testing.T) { + t.Run("success charging", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.BatteryReport + + se.SetBatteryHandler(func(client SpeedEditorInt, report input.BatteryReport) { + handlerCalled = true + receivedReport = report + }) + + expectedReport := input.BatteryReport{ + Id: 7, + Charging: true, + Battery: 0.75, + } + + se.HandleBattery(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) + + t.Run("success not charging", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.BatteryReport + + se.SetBatteryHandler(func(client SpeedEditorInt, report input.BatteryReport) { + handlerCalled = true + receivedReport = report + }) + + expectedReport := input.BatteryReport{ + Id: 7, + Charging: false, + Battery: 0.25, + } + + se.HandleBattery(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) +} + +// TestHandleKeyPress tests the HandleKeyPress method +func TestHandleKeyPress(t *testing.T) { + t.Run("single key", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.KeyPressReport + + se.SetKeyPressHandler(func(client SpeedEditorInt, report input.KeyPressReport) { + handlerCalled = true + receivedReport = report + }) + + keysByName := keys.ByName() + expectedReport := input.KeyPressReport{ + Id: 4, + Keys: []keys.Key{keysByName[keys.CAM1]}, + } + + se.HandleKeyPress(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) + + t.Run("multiple keys", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.KeyPressReport + + se.SetKeyPressHandler(func(client SpeedEditorInt, report input.KeyPressReport) { + handlerCalled = true + receivedReport = report + }) + + keysByName := keys.ByName() + expectedReport := input.KeyPressReport{ + Id: 4, + Keys: []keys.Key{keysByName[keys.CAM1], keysByName[keys.CAM2], keysByName[keys.CAM3]}, + } + + se.HandleKeyPress(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) + + t.Run("no keys", func(t *testing.T) { + se := &SpeedEditor{} + handlerCalled := false + var receivedReport input.KeyPressReport + + se.SetKeyPressHandler(func(client SpeedEditorInt, report input.KeyPressReport) { + handlerCalled = true + receivedReport = report + }) + + expectedReport := input.KeyPressReport{ + Id: 4, + Keys: []keys.Key{}, + } + + se.HandleKeyPress(expectedReport) + assert.True(t, handlerCalled) + assert.Equal(t, expectedReport, receivedReport) + }) +} + +// TestAuthenticate tests the Authenticate method +func TestAuthenticate(t *testing.T) { + t.Run("success starts goroutine", func(t *testing.T) { + mockDevice := new(MockHIDDevice) + + // Setup mock expectations for the auth handshake + // The auth flow has multiple GetFeatureReport calls with different expected headers + callCount := 0 + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil) + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + callCount++ + switch callCount { + case 1: + // Keyboard challenge response + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + case 2: + // Host challenge response + buf[0] = 0x06 + buf[1] = 0x02 + case 3: + // Auth challenge result + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:], uint16(65535)) + } + }).Return(10, nil) + + se := &SpeedEditor{ + device: mockDevice, + AuthHandler: AuthHandler{device: mockDevice}, + } + + assert.NotPanics(t, func() { + se.Authenticate() + }) + + time.Sleep(100 * time.Millisecond) + }) +} + +// TestDefaultHandlers tests the default handler functions +func TestDefaultHandlers(t *testing.T) { + t.Run("default jog handler", func(t *testing.T) { + se := &SpeedEditor{} + report := input.JogReport{Value: 100} + + assert.NotPanics(t, func() { + defaultJogHandler(se, report) + }) + }) + + t.Run("default battery handler", func(t *testing.T) { + se := &SpeedEditor{} + report := input.BatteryReport{ + Charging: true, + Battery: 0.5, + } + + assert.NotPanics(t, func() { + defaultBatteryHandler(se, report) + }) + }) + + t.Run("default keypress handler", func(t *testing.T) { + se := &SpeedEditor{} + keysByName := keys.ByName() + report := input.KeyPressReport{ + Keys: []keys.Key{keysByName[keys.CAM1]}, + } + + assert.NotPanics(t, func() { + defaultKeyPressHandler(se, report) + }) + }) +} + +// TestNullHandlers tests the null handler functions +func TestNullHandlers(t *testing.T) { + t.Run("null jog handler", func(t *testing.T) { + se := &SpeedEditor{} + report := input.JogReport{Value: 100} + + assert.NotPanics(t, func() { + NullJogHandler(se, report) + }) + }) + + t.Run("null battery handler", func(t *testing.T) { + se := &SpeedEditor{} + report := input.BatteryReport{Battery: 0.5} + + assert.NotPanics(t, func() { + NullBatteryHandler(se, report) + }) + }) + + t.Run("null keypress handler", func(t *testing.T) { + se := &SpeedEditor{} + keysByName := keys.ByName() + report := input.KeyPressReport{Keys: []keys.Key{keysByName[keys.CAM1]}} + + assert.NotPanics(t, func() { + NullKeyPressHandler(se, report) + }) + }) +} From 2648534a147df6ee94270670e6a62919c5378f33 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 22:35:25 +1100 Subject: [PATCH 04/11] test: add unit tests for auth handler --- auth_test.go | 634 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 634 insertions(+) create mode 100644 auth_test.go diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..ca718ac --- /dev/null +++ b/auth_test.go @@ -0,0 +1,634 @@ +package speedEditor + +import ( + "encoding/binary" + "errors" + "testing" + "time" + + "github.com/JamesBalazs/speed-editor-client/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// setupAuthFixture creates an AuthHandler with a mocked HID device +func setupAuthFixture(t *testing.T) (*AuthHandler, *MockHIDDevice) { + mockDevice := new(MockHIDDevice) + + ah := AuthHandler{ + device: mockDevice, + } + + return &ah, mockDevice +} + +// TestResetAuthState tests the ResetAuthState method +func TestResetAuthState(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + + assert.NotPanics(t, func() { + ah.ResetAuthState() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("error sending feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.ResetAuthState() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestGetKeyboardChallenge tests the GetKeyboardChallenge method +func TestGetKeyboardChallenge(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + expectedChallenge := uint64(1234567890) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], expectedChallenge) + }).Return(10, nil).Once() + + var challenge uint64 + assert.NotPanics(t, func() { + challenge = ah.GetKeyboardChallenge() + }) + + assert.Equal(t, expectedChallenge, challenge) + mockDevice.AssertExpectations(t) + }) + + t.Run("error getting feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.GetKeyboardChallenge() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("unexpected header panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Wrong header + buf[0] = 0x06 + buf[1] = 0x99 + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.GetKeyboardChallenge() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("unexpected header with different wrong values panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Wrong header - different values + buf[0] = 0x00 + buf[1] = 0x00 + }).Return(10, nil).Once() + + assert.PanicsWithValue(t, "Unexpected auth response header: [0 0 0 0 0 0 0 0 0 0]", func() { + ah.GetKeyboardChallenge() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestSendHostChallenge tests the SendHostChallenge method +func TestSendHostChallenge(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + + assert.NotPanics(t, func() { + ah.SendHostChallenge() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("error sending feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.SendHostChallenge() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestGetHostChallengeResponse tests the GetHostChallengeResponse method +func TestGetHostChallengeResponse(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + buf[2] = 0x01 + buf[3] = 0x02 + }).Return(10, nil).Once() + + var response []byte + assert.NotPanics(t, func() { + response = ah.GetHostChallengeResponse() + }) + + assert.Len(t, response, 10) + assert.Equal(t, byte(0x06), response[0]) + assert.Equal(t, byte(0x02), response[1]) + mockDevice.AssertExpectations(t) + }) + + t.Run("error getting feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.GetHostChallengeResponse() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("unexpected header panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Wrong header + buf[0] = 0x06 + buf[1] = 0x99 + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.GetHostChallengeResponse() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestSendAuthChallengeResponse tests the SendAuthChallengeResponse method +func TestSendAuthChallengeResponse(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + expectedResponse := uint64(9876543210) + + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Verify the response bytes are correct + assert.Equal(t, byte(0x06), buf[0]) + assert.Equal(t, byte(0x03), buf[1]) + actualResponse := binary.LittleEndian.Uint64(buf[2:]) + assert.Equal(t, expectedResponse, actualResponse) + }).Return(10, nil).Once() + + assert.NotPanics(t, func() { + ah.SendAuthChallengeResponse(expectedResponse) + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("success with zero response", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + + assert.NotPanics(t, func() { + ah.SendAuthChallengeResponse(0) + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("error sending feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.SendAuthChallengeResponse(12345) + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestGetAuthChallengeResult tests the GetAuthChallengeResult method +func TestGetAuthChallengeResult(t *testing.T) { + t.Run("success", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + expectedResult := uint16(65535) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], expectedResult) + }).Return(10, nil).Once() + + var result uint16 + assert.NotPanics(t, func() { + result = ah.GetAuthChallengeResult() + }) + + assert.Equal(t, expectedResult, result) + mockDevice.AssertExpectations(t) + }) + + t.Run("success with different timeout value", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + expectedResult := uint16(3600) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], expectedResult) + }).Return(10, nil).Once() + + var result uint16 + assert.NotPanics(t, func() { + result = ah.GetAuthChallengeResult() + }) + + assert.Equal(t, expectedResult, result) + mockDevice.AssertExpectations(t) + }) + + t.Run("error getting feature report panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.GetAuthChallengeResult() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("unexpected header panics", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Wrong header + buf[0] = 0x06 + buf[1] = 0x99 + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.GetAuthChallengeResult() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestAuthHandlerAuthenticate tests the full Authenticate method +func TestAuthHandlerAuthenticate(t *testing.T) { + t.Run("success full flow", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + keyboardChallenge := uint64(1234567890) + expectedReauthTimeout := uint16(65535) + + // Step 1: ResetAuthState - SendFeatureReport with default state + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + + // Step 2: GetKeyboardChallenge - GetFeatureReport with keyboard challenge header + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], keyboardChallenge) + }).Return(10, nil).Once() + + // Step 3: SendHostChallenge - SendFeatureReport with host challenge + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + + // Step 4: GetHostChallengeResponse - GetFeatureReport with host challenge response header + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + }).Return(10, nil).Once() + + // Step 5: SendAuthChallengeResponse - SendFeatureReport with calculated response + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + // Verify header + assert.Equal(t, byte(0x06), buf[0]) + assert.Equal(t, byte(0x03), buf[1]) + }).Return(10, nil).Once() + + // Step 6: GetAuthChallengeResult - GetFeatureReport with auth response header + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], expectedReauthTimeout) + }).Return(10, nil).Once() + + var reauthDuration time.Duration + assert.NotPanics(t, func() { + reauthDuration = ah.Authenticate() + }) + + // Verify the reauth duration is calculated correctly (timeout - 10 seconds) + expectedDuration := time.Duration(expectedReauthTimeout-10) * time.Second + assert.Equal(t, expectedDuration, reauthDuration) + mockDevice.AssertExpectations(t) + }) + + t.Run("success with different challenge values", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + keyboardChallenge := uint64(9999999999) + expectedReauthTimeout := uint16(3600) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], keyboardChallenge) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], expectedReauthTimeout) + }).Return(10, nil).Once() + + var reauthDuration time.Duration + assert.NotPanics(t, func() { + reauthDuration = ah.Authenticate() + }) + + expectedDuration := time.Duration(expectedReauthTimeout-10) * time.Second + assert.Equal(t, expectedDuration, reauthDuration) + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on ResetAuthState error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetKeyboardChallenge error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetKeyboardChallenge unexpected header", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x99 // Wrong header + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on SendHostChallenge error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetHostChallengeResponse error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetHostChallengeResponse unexpected header", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x99 // Wrong header + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on SendAuthChallengeResponse error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("send error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetAuthChallengeResult error", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) + + t.Run("panic on GetAuthChallengeResult unexpected header", func(t *testing.T) { + ah, mockDevice := setupAuthFixture(t) + + mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x00 + binary.LittleEndian.PutUint64(buf[2:], uint64(12345)) + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x02 + }).Return(10, nil).Once() + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x99 // Wrong header + }).Return(10, nil).Once() + + assert.Panics(t, func() { + ah.Authenticate() + }) + + mockDevice.AssertExpectations(t) + }) +} + +// TestAuthChallengeCalculation verifies the challenge response calculation is correct +func TestAuthChallengeCalculation(t *testing.T) { + t.Run("calculate challenge response", func(t *testing.T) { + // Test that the auth.CalculateChallengeResponse function works + challenge := uint64(1234567890) + response := auth.CalculateChallengeResponse(challenge) + + // The response should be different from the challenge + assert.NotEqual(t, challenge, response) + // The response should be non-zero + assert.NotZero(t, response) + }) + + t.Run("calculate challenge response with zero", func(t *testing.T) { + challenge := uint64(0) + response := auth.CalculateChallengeResponse(challenge) + + // Response should be non-zero even with zero challenge + assert.NotZero(t, response) + }) + + t.Run("calculate challenge response with max value", func(t *testing.T) { + challenge := uint64(18446744073709551615) // Max uint64 + response := auth.CalculateChallengeResponse(challenge) + + // Response should be non-zero + assert.NotZero(t, response) + }) +} + +// TestAuthHandlerInterface verifies AuthHandler implements AuthHandlerInt +func TestAuthHandlerInterface(t *testing.T) { + t.Run("AuthHandler implements AuthHandlerInt", func(t *testing.T) { + var _ AuthHandlerInt = (*AuthHandler)(nil) + }) +} From 2b53fc7fde1b4addbd7e3edf2bf8d6f542af2ffc Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 22:50:28 +1100 Subject: [PATCH 05/11] feat: add error handling for AuthHandler BREAKING CHANGE: AuthHandler functions now return errors instead of panicking --- .qwen/settings.json | 7 + README.md | 4 +- auth.go | 85 +++++++----- auth_test.go | 254 +++++++++++++++++++--------------- examples/keypress/main.go | 4 +- examples/lightshow/main.go | 4 +- examples/reset/main.go | 4 +- examples/volume_wheel/main.go | 4 +- speededitor.go | 16 ++- speededitor_test.go | 5 +- 10 files changed, 230 insertions(+), 157 deletions(-) create mode 100644 .qwen/settings.json diff --git a/.qwen/settings.json b/.qwen/settings.json new file mode 100644 index 0000000..502a9c0 --- /dev/null +++ b/.qwen/settings.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(go *)" + ] + } +} \ No newline at end of file diff --git a/README.md b/README.md index 4713224..6b0e438 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,9 @@ I re-implemented his authentication algorithm in Go, and exported the underlying When using the client, you just need to call `Authenticate` before sending / receiving any messages, and the handshake will be handled for you: ```go -client.Authenticate() +if err := client.Authenticate(); err != nil { + log.Fatal(err) +} ``` Finally, to receive messages from the Speed Editor, you can call `Poll`. This will start a loop which does a blocking read, waiting for either a keypress, battery report, or jog wheel movement from the device: diff --git a/auth.go b/auth.go index 3882e1c..97e28d8 100644 --- a/auth.go +++ b/auth.go @@ -27,13 +27,13 @@ var ( ) type AuthHandlerInt interface { - Authenticate() time.Duration - ResetAuthState() - GetKeyboardChallenge() uint64 - SendHostChallenge() - GetHostChallengeResponse() []byte - SendAuthChallengeResponse(response uint64) - GetAuthChallengeResult() uint16 + Authenticate() (time.Duration, error) + ResetAuthState() error + GetKeyboardChallenge() (uint64, error) + SendHostChallenge() error + GetHostChallengeResponse() ([]byte, error) + SendAuthChallengeResponse(response uint64) error + GetAuthChallengeResult() (uint16, error) } type AuthHandler struct { @@ -43,92 +43,111 @@ type AuthHandler struct { // Authenticate handles the entire handshake between the host and the Speed Editor. // // It returns the duration before the Speed Editor expects a reauth. -func (ah AuthHandler) Authenticate() time.Duration { - ah.ResetAuthState() - challenge := ah.GetKeyboardChallenge() +func (ah AuthHandler) Authenticate() (time.Duration, error) { + if err := ah.ResetAuthState(); err != nil { + return 0, fmt.Errorf("failed to reset auth state: %w", err) + } + + challenge, err := ah.GetKeyboardChallenge() + if err != nil { + return 0, fmt.Errorf("failed to get keyboard challenge: %w", err) + } + + if err := ah.SendHostChallenge(); err != nil { + return 0, fmt.Errorf("failed to send host challenge: %w", err) + } - ah.SendHostChallenge() - _ = ah.GetHostChallengeResponse() // We don't care about the response, since we don't care if it's a real Speed Editor + // We don't care about the response or error, since we don't care if it's a real Speed Editor + _, _ = ah.GetHostChallengeResponse() response := auth.CalculateChallengeResponse(challenge) - ah.SendAuthChallengeResponse(response) + if err := ah.SendAuthChallengeResponse(response); err != nil { + return 0, fmt.Errorf("failed to send auth challenge response: %w", err) + } + + reauthTimeout, err := ah.GetAuthChallengeResult() + if err != nil { + return 0, fmt.Errorf("failed to get auth challenge result: %w", err) + } - reauthTimeout = ah.GetAuthChallengeResult() - return time.Duration(reauthTimeout-10) * time.Second + return time.Duration(reauthTimeout-10) * time.Second, nil } -func (ah AuthHandler) ResetAuthState() { +func (ah AuthHandler) ResetAuthState() error { _, err := ah.device.SendFeatureReport(featureReportDefaultState) if err != nil { - panic(err.Error()) + return fmt.Errorf("failed to send feature report: %w", err) } + return nil } -func (ah AuthHandler) GetKeyboardChallenge() uint64 { +func (ah AuthHandler) GetKeyboardChallenge() (uint64, error) { // get keyboard challenge, store in a new copy of the byte array data := make([]byte, len(featureReportDefaultState)) copy(data, featureReportDefaultState) _, err := ah.device.GetFeatureReport(data) if err != nil { - panic(err.Error()) + return 0, fmt.Errorf("failed to get feature report: %w", err) } if !bytes.Equal(data[0:2], expectedKeyboardChallengeResponseHeader) { - panic(fmt.Sprintf("Unexpected auth response header: %v", data)) + return 0, fmt.Errorf("unexpected keyboard challenge response header: %v", data) } - return binary.LittleEndian.Uint64(data[2:]) + return binary.LittleEndian.Uint64(data[2:]), nil } // sendHostChallenge requests a challenge response from the device. // Presumably this step exists to confirm it's a real Speed Editor. -func (ah AuthHandler) SendHostChallenge() { +func (ah AuthHandler) SendHostChallenge() error { _, err := ah.device.SendFeatureReport(featureReportHostChallenge) if err != nil { - panic(err.Error()) + return fmt.Errorf("failed to send feature report: %w", err) } + return nil } -func (ah AuthHandler) GetHostChallengeResponse() []byte { +func (ah AuthHandler) GetHostChallengeResponse() ([]byte, error) { data := make([]byte, len(featureReportDefaultState)) copy(data, featureReportDefaultState) _, err := ah.device.GetFeatureReport(data) if err != nil { - panic(err.Error()) + return nil, fmt.Errorf("failed to get feature report: %w", err) } if !bytes.Equal(data[0:2], expectedHostChallengeResponseHeader) { - panic(fmt.Sprintf("Unexpected auth response header: %v", data)) + return nil, fmt.Errorf("unexpected host challenge response header: %v", data) } - return data + return data, nil } -func (ah AuthHandler) SendAuthChallengeResponse(response uint64) { +func (ah AuthHandler) SendAuthChallengeResponse(response uint64) error { responseBytes := make([]byte, len(authChallengeHeader)) copy(responseBytes, authChallengeHeader) responseBytes = binary.LittleEndian.AppendUint64(authChallengeHeader, response) _, err := ah.device.SendFeatureReport(responseBytes) if err != nil { - panic(err.Error()) + return fmt.Errorf("failed to send feature report: %w", err) } + return nil } -func (ah AuthHandler) GetAuthChallengeResult() uint16 { +func (ah AuthHandler) GetAuthChallengeResult() (uint16, error) { data := make([]byte, len(featureReportDefaultState)) copy(data, featureReportDefaultState) _, err := ah.device.GetFeatureReport(data) if err != nil { - panic(err.Error()) + return 0, fmt.Errorf("failed to get feature report: %w", err) } if !bytes.Equal(data[0:2], expectedAuthResponseHeader) { - panic(fmt.Sprintf("Unexpected auth response header: %v", data)) + return 0, fmt.Errorf("unexpected auth response header: %v", data) } - return binary.LittleEndian.Uint16(data[2:4]) + return binary.LittleEndian.Uint16(data[2:4]), nil } diff --git a/auth_test.go b/auth_test.go index ca718ac..c588b4b 100644 --- a/auth_test.go +++ b/auth_test.go @@ -9,6 +9,7 @@ import ( "github.com/JamesBalazs/speed-editor-client/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) // setupAuthFixture creates an AuthHandler with a mocked HID device @@ -29,22 +30,22 @@ func TestResetAuthState(t *testing.T) { mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() - assert.NotPanics(t, func() { - ah.ResetAuthState() - }) + err := ah.ResetAuthState() + require.NoError(t, err) mockDevice.AssertExpectations(t) }) - t.Run("error sending feature report panics", func(t *testing.T) { + t.Run("error sending feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.ResetAuthState() - }) + err := ah.ResetAuthState() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to send feature report") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) } @@ -62,28 +63,28 @@ func TestGetKeyboardChallenge(t *testing.T) { binary.LittleEndian.PutUint64(buf[2:], expectedChallenge) }).Return(10, nil).Once() - var challenge uint64 - assert.NotPanics(t, func() { - challenge = ah.GetKeyboardChallenge() - }) + challenge, err := ah.GetKeyboardChallenge() + require.NoError(t, err) assert.Equal(t, expectedChallenge, challenge) mockDevice.AssertExpectations(t) }) - t.Run("error getting feature report panics", func(t *testing.T) { + t.Run("error getting feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() - assert.Panics(t, func() { - ah.GetKeyboardChallenge() - }) + challenge, err := ah.GetKeyboardChallenge() + require.Error(t, err) + assert.Equal(t, uint64(0), challenge) + assert.Contains(t, err.Error(), "failed to get feature report") + assert.Contains(t, err.Error(), "get error") mockDevice.AssertExpectations(t) }) - t.Run("unexpected header panics", func(t *testing.T) { + t.Run("unexpected header returns error", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { @@ -93,14 +94,15 @@ func TestGetKeyboardChallenge(t *testing.T) { buf[1] = 0x99 }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.GetKeyboardChallenge() - }) + challenge, err := ah.GetKeyboardChallenge() + require.Error(t, err) + assert.Equal(t, uint64(0), challenge) + assert.Contains(t, err.Error(), "unexpected keyboard challenge response header") mockDevice.AssertExpectations(t) }) - t.Run("unexpected header with different wrong values panics", func(t *testing.T) { + t.Run("unexpected header with different wrong values returns error", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { @@ -110,10 +112,11 @@ func TestGetKeyboardChallenge(t *testing.T) { buf[1] = 0x00 }).Return(10, nil).Once() - assert.PanicsWithValue(t, "Unexpected auth response header: [0 0 0 0 0 0 0 0 0 0]", func() { - ah.GetKeyboardChallenge() - }) + challenge, err := ah.GetKeyboardChallenge() + require.Error(t, err) + assert.Equal(t, uint64(0), challenge) + assert.Contains(t, err.Error(), "unexpected keyboard challenge response header") mockDevice.AssertExpectations(t) }) } @@ -125,22 +128,22 @@ func TestSendHostChallenge(t *testing.T) { mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() - assert.NotPanics(t, func() { - ah.SendHostChallenge() - }) + err := ah.SendHostChallenge() + require.NoError(t, err) mockDevice.AssertExpectations(t) }) - t.Run("error sending feature report panics", func(t *testing.T) { + t.Run("error sending feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.SendHostChallenge() - }) + err := ah.SendHostChallenge() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to send feature report") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) } @@ -158,30 +161,30 @@ func TestGetHostChallengeResponse(t *testing.T) { buf[3] = 0x02 }).Return(10, nil).Once() - var response []byte - assert.NotPanics(t, func() { - response = ah.GetHostChallengeResponse() - }) + response, err := ah.GetHostChallengeResponse() + require.NoError(t, err) assert.Len(t, response, 10) assert.Equal(t, byte(0x06), response[0]) assert.Equal(t, byte(0x02), response[1]) mockDevice.AssertExpectations(t) }) - t.Run("error getting feature report panics", func(t *testing.T) { + t.Run("error getting feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() - assert.Panics(t, func() { - ah.GetHostChallengeResponse() - }) + response, err := ah.GetHostChallengeResponse() + require.Error(t, err) + assert.Nil(t, response) + assert.Contains(t, err.Error(), "failed to get feature report") + assert.Contains(t, err.Error(), "get error") mockDevice.AssertExpectations(t) }) - t.Run("unexpected header panics", func(t *testing.T) { + t.Run("unexpected header returns error", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { @@ -191,10 +194,11 @@ func TestGetHostChallengeResponse(t *testing.T) { buf[1] = 0x99 }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.GetHostChallengeResponse() - }) + response, err := ah.GetHostChallengeResponse() + require.Error(t, err) + assert.Nil(t, response) + assert.Contains(t, err.Error(), "unexpected host challenge response header") mockDevice.AssertExpectations(t) }) } @@ -214,10 +218,9 @@ func TestSendAuthChallengeResponse(t *testing.T) { assert.Equal(t, expectedResponse, actualResponse) }).Return(10, nil).Once() - assert.NotPanics(t, func() { - ah.SendAuthChallengeResponse(expectedResponse) - }) + err := ah.SendAuthChallengeResponse(expectedResponse) + require.NoError(t, err) mockDevice.AssertExpectations(t) }) @@ -226,22 +229,22 @@ func TestSendAuthChallengeResponse(t *testing.T) { mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() - assert.NotPanics(t, func() { - ah.SendAuthChallengeResponse(0) - }) + err := ah.SendAuthChallengeResponse(0) + require.NoError(t, err) mockDevice.AssertExpectations(t) }) - t.Run("error sending feature report panics", func(t *testing.T) { + t.Run("error sending feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.SendAuthChallengeResponse(12345) - }) + err := ah.SendAuthChallengeResponse(12345) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to send feature report") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) } @@ -259,11 +262,9 @@ func TestGetAuthChallengeResult(t *testing.T) { binary.LittleEndian.PutUint16(buf[2:4], expectedResult) }).Return(10, nil).Once() - var result uint16 - assert.NotPanics(t, func() { - result = ah.GetAuthChallengeResult() - }) + result, err := ah.GetAuthChallengeResult() + require.NoError(t, err) assert.Equal(t, expectedResult, result) mockDevice.AssertExpectations(t) }) @@ -279,28 +280,28 @@ func TestGetAuthChallengeResult(t *testing.T) { binary.LittleEndian.PutUint16(buf[2:4], expectedResult) }).Return(10, nil).Once() - var result uint16 - assert.NotPanics(t, func() { - result = ah.GetAuthChallengeResult() - }) + result, err := ah.GetAuthChallengeResult() + require.NoError(t, err) assert.Equal(t, expectedResult, result) mockDevice.AssertExpectations(t) }) - t.Run("error getting feature report panics", func(t *testing.T) { + t.Run("error getting feature report", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() - assert.Panics(t, func() { - ah.GetAuthChallengeResult() - }) + result, err := ah.GetAuthChallengeResult() + require.Error(t, err) + assert.Equal(t, uint16(0), result) + assert.Contains(t, err.Error(), "failed to get feature report") + assert.Contains(t, err.Error(), "get error") mockDevice.AssertExpectations(t) }) - t.Run("unexpected header panics", func(t *testing.T) { + t.Run("unexpected header returns error", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { @@ -310,10 +311,11 @@ func TestGetAuthChallengeResult(t *testing.T) { buf[1] = 0x99 }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.GetAuthChallengeResult() - }) + result, err := ah.GetAuthChallengeResult() + require.Error(t, err) + assert.Equal(t, uint16(0), result) + assert.Contains(t, err.Error(), "unexpected auth response header") mockDevice.AssertExpectations(t) }) } @@ -362,11 +364,9 @@ func TestAuthHandlerAuthenticate(t *testing.T) { binary.LittleEndian.PutUint16(buf[2:4], expectedReauthTimeout) }).Return(10, nil).Once() - var reauthDuration time.Duration - assert.NotPanics(t, func() { - reauthDuration = ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.NoError(t, err) // Verify the reauth duration is calculated correctly (timeout - 10 seconds) expectedDuration := time.Duration(expectedReauthTimeout-10) * time.Second assert.Equal(t, expectedDuration, reauthDuration) @@ -399,42 +399,44 @@ func TestAuthHandlerAuthenticate(t *testing.T) { binary.LittleEndian.PutUint16(buf[2:4], expectedReauthTimeout) }).Return(10, nil).Once() - var reauthDuration time.Duration - assert.NotPanics(t, func() { - reauthDuration = ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.NoError(t, err) expectedDuration := time.Duration(expectedReauthTimeout-10) * time.Second assert.Equal(t, expectedDuration, reauthDuration) mockDevice.AssertExpectations(t) }) - t.Run("panic on ResetAuthState error", func(t *testing.T) { + t.Run("error on ResetAuthState", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to reset auth state") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) - t.Run("panic on GetKeyboardChallenge error", func(t *testing.T) { + t.Run("error on GetKeyboardChallenge", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to get keyboard challenge") + assert.Contains(t, err.Error(), "get error") mockDevice.AssertExpectations(t) }) - t.Run("panic on GetKeyboardChallenge unexpected header", func(t *testing.T) { + t.Run("error on GetKeyboardChallenge unexpected header", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -444,14 +446,16 @@ func TestAuthHandlerAuthenticate(t *testing.T) { buf[1] = 0x99 // Wrong header }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to get keyboard challenge") + assert.Contains(t, err.Error(), "unexpected keyboard challenge response header") mockDevice.AssertExpectations(t) }) - t.Run("panic on SendHostChallenge error", func(t *testing.T) { + t.Run("error on SendHostChallenge", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -463,14 +467,16 @@ func TestAuthHandlerAuthenticate(t *testing.T) { }).Return(10, nil).Once() mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to send host challenge") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) - t.Run("panic on GetHostChallengeResponse error", func(t *testing.T) { + t.Run("error on GetHostChallengeResponse", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -482,15 +488,24 @@ func TestAuthHandlerAuthenticate(t *testing.T) { }).Return(10, nil).Once() mockDevice.On("SendFeatureReport", featureReportHostChallenge).Return(10, nil).Once() mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() + // GetHostChallengeResponse error is ignored, so Authenticate continues to SendAuthChallengeResponse + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], uint16(65535)) + }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + // GetHostChallengeResponse error is ignored in Authenticate + require.NoError(t, err) + assert.NotEqual(t, time.Duration(0), reauthDuration) mockDevice.AssertExpectations(t) }) - t.Run("panic on GetHostChallengeResponse unexpected header", func(t *testing.T) { + t.Run("error on GetHostChallengeResponse unexpected header", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -506,15 +521,24 @@ func TestAuthHandlerAuthenticate(t *testing.T) { buf[0] = 0x06 buf[1] = 0x99 // Wrong header }).Return(10, nil).Once() + // GetHostChallengeResponse error is ignored, so Authenticate continues to SendAuthChallengeResponse + mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() + mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Run(func(args mock.Arguments) { + buf := args.Get(0).([]byte) + buf[0] = 0x06 + buf[1] = 0x04 + binary.LittleEndian.PutUint16(buf[2:4], uint16(65535)) + }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + // GetHostChallengeResponse error is ignored in Authenticate + require.NoError(t, err) + assert.NotEqual(t, time.Duration(0), reauthDuration) mockDevice.AssertExpectations(t) }) - t.Run("panic on SendAuthChallengeResponse error", func(t *testing.T) { + t.Run("error on SendAuthChallengeResponse", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -532,14 +556,16 @@ func TestAuthHandlerAuthenticate(t *testing.T) { }).Return(10, nil).Once() mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("send error")).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to send auth challenge response") + assert.Contains(t, err.Error(), "send error") mockDevice.AssertExpectations(t) }) - t.Run("panic on GetAuthChallengeResult error", func(t *testing.T) { + t.Run("error on GetAuthChallengeResult", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -558,14 +584,16 @@ func TestAuthHandlerAuthenticate(t *testing.T) { mockDevice.On("SendFeatureReport", mock.AnythingOfType("[]uint8")).Return(10, nil).Once() mockDevice.On("GetFeatureReport", mock.AnythingOfType("[]uint8")).Return(0, errors.New("get error")).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to get auth challenge result") + assert.Contains(t, err.Error(), "get error") mockDevice.AssertExpectations(t) }) - t.Run("panic on GetAuthChallengeResult unexpected header", func(t *testing.T) { + t.Run("error on GetAuthChallengeResult unexpected header", func(t *testing.T) { ah, mockDevice := setupAuthFixture(t) mockDevice.On("SendFeatureReport", featureReportDefaultState).Return(10, nil).Once() @@ -588,10 +616,12 @@ func TestAuthHandlerAuthenticate(t *testing.T) { buf[1] = 0x99 // Wrong header }).Return(10, nil).Once() - assert.Panics(t, func() { - ah.Authenticate() - }) + reauthDuration, err := ah.Authenticate() + require.Error(t, err) + assert.Equal(t, time.Duration(0), reauthDuration) + assert.Contains(t, err.Error(), "failed to get auth challenge result") + assert.Contains(t, err.Error(), "unexpected auth response header") mockDevice.AssertExpectations(t) }) } diff --git a/examples/keypress/main.go b/examples/keypress/main.go index ccf762c..fe88c28 100644 --- a/examples/keypress/main.go +++ b/examples/keypress/main.go @@ -25,7 +25,9 @@ func main() { fmt.Printf("Manufacturer: %s\nProduct: %s\nSerial: %s\n", deviceInfo.MfrStr, deviceInfo.ProductStr, deviceInfo.SerialNbr) - client.Authenticate() + if err := client.Authenticate(); err != nil { + log.Fatal(err) + } client.SetKeyPressHandler(customKeyPressHandler) diff --git a/examples/lightshow/main.go b/examples/lightshow/main.go index 2f1082d..a8db09d 100644 --- a/examples/lightshow/main.go +++ b/examples/lightshow/main.go @@ -25,7 +25,9 @@ func main() { fmt.Printf("Manufacturer: %s\nProduct: %s\nSerial: %s\n", deviceInfo.MfrStr, deviceInfo.ProductStr, deviceInfo.SerialNbr) - client.Authenticate() + if err := client.Authenticate(); err != nil { + log.Fatal(err) + } keysByCol := keys.ByCol() keysByRow := keys.ByRow() diff --git a/examples/reset/main.go b/examples/reset/main.go index 7715bf4..039022d 100644 --- a/examples/reset/main.go +++ b/examples/reset/main.go @@ -23,7 +23,9 @@ func main() { fmt.Printf("Manufacturer: %s\nProduct: %s\nSerial: %s\n", deviceInfo.MfrStr, deviceInfo.ProductStr, deviceInfo.SerialNbr) - client.Authenticate() + if err := client.Authenticate(); err != nil { + log.Fatal(err) + } if err := client.SetLeds([]uint32{}); err != nil { log.Fatal(err) diff --git a/examples/volume_wheel/main.go b/examples/volume_wheel/main.go index 8c369ff..b5a82e0 100644 --- a/examples/volume_wheel/main.go +++ b/examples/volume_wheel/main.go @@ -32,7 +32,9 @@ func main() { fmt.Printf("Manufacturer: %s\nProduct: %s\nSerial: %s\n", deviceInfo.MfrStr, deviceInfo.ProductStr, deviceInfo.SerialNbr) - client.Authenticate() + if err := client.Authenticate(); err != nil { + log.Fatal(err) + } if err := client.SetJogMode(jogModes.ID_ABSOLUTE); err != nil { log.Fatal(err) diff --git a/speededitor.go b/speededitor.go index 81d10a5..9514f47 100644 --- a/speededitor.go +++ b/speededitor.go @@ -87,7 +87,7 @@ func NewClient() (SpeedEditorInt, error) { type SpeedEditorInt interface { // Authenticate does the initial handshake with the Speed Editor, // and re-auths periodically in the background when requested by the device. - Authenticate() + Authenticate() error // GetDeviceInfo returns the serial number, manufacturer string etc published // by the device via HID. This info is cached on init, so we don't have to @@ -167,18 +167,26 @@ func (se *SpeedEditor) initialize() error { return nil } -func (se SpeedEditor) Authenticate() { +func (se SpeedEditor) Authenticate() error { // Getting the initial reAuthSeconds synchronously. // Do not read or update this outside of the goroutine to avoid a data race. - reAuthSeconds := se.AuthHandler.Authenticate() + reAuthSeconds, err := se.AuthHandler.Authenticate() + if err != nil { + return fmt.Errorf("failed to authenticate: %w", err) + } go func() { for { time.Sleep(reAuthSeconds) - reAuthSeconds = se.AuthHandler.Authenticate() + reAuthSeconds, err = se.AuthHandler.Authenticate() + if err != nil { + fmt.Printf("failed to re-authenticate: %v\n", err) + } } }() + + return nil } func (se SpeedEditor) GetDeviceInfo() hid.DeviceInfo { diff --git a/speededitor_test.go b/speededitor_test.go index 219aa8d..118fd79 100644 --- a/speededitor_test.go +++ b/speededitor_test.go @@ -701,9 +701,8 @@ func TestAuthenticate(t *testing.T) { AuthHandler: AuthHandler{device: mockDevice}, } - assert.NotPanics(t, func() { - se.Authenticate() - }) + err := se.Authenticate() + require.NoError(t, err) time.Sleep(100 * time.Millisecond) }) From 8335d42788279d42793fac97f7e21591d612cab0 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 22:51:57 +1100 Subject: [PATCH 06/11] style: go fmt --- speededitor_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/speededitor_test.go b/speededitor_test.go index 118fd79..fae1345 100644 --- a/speededitor_test.go +++ b/speededitor_test.go @@ -84,7 +84,7 @@ func TestNewClient(t *testing.T) { // This test will succeed if a physical device is connected, // or fail with an error if no device is present client, err := NewClient() - + if err != nil { // No device connected - verify error message assert.Contains(t, err.Error(), "failed to create client") @@ -145,9 +145,9 @@ func TestInitialize(t *testing.T) { func TestGetDeviceInfo(t *testing.T) { t.Run("success returns cached device info", func(t *testing.T) { deviceInfo := &hid.DeviceInfo{ - MfrStr: "Test Manufacturer", + MfrStr: "Test Manufacturer", ProductStr: "Test Product", - SerialNbr: "TEST123", + SerialNbr: "TEST123", } se := &SpeedEditor{ From 9f5696831e298f93e6b8a306509e3824b87b6452 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 22:59:36 +1100 Subject: [PATCH 07/11] ci: add go test action --- .github/workflows/semantic-release.yaml | 1 + .github/workflows/test.yaml | 58 +++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 .github/workflows/test.yaml diff --git a/.github/workflows/semantic-release.yaml b/.github/workflows/semantic-release.yaml index 6fd5cdb..8e3305f 100644 --- a/.github/workflows/semantic-release.yaml +++ b/.github/workflows/semantic-release.yaml @@ -1,4 +1,5 @@ name: Release + on: push: branches: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..2ba668d --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,58 @@ +name: Go CI + +on: + push: + branches: + - "**" + pull_request: + branches: + - "**" + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.26" # Specify your desired Go version + + - name: Update package source + run: sudo apt-get update + + - name: Install systemd-devel + run: sudo apt-get install libudev-dev + + - name: Download dependencies + run: go mod download + + - name: Build + run: go build -v ./... + + - name: Test with the Go CLI + run: go test -v ./... + + - name: Generate coverage report + run: go test -coverprofile=coverage.out ./... + + - name: Check coverage + id: coverage + run: | + coverage=$(go tool cover -func=coverage.out | grep total | awk '{print substr($3, 1, length($3)-1)}') + echo "total_coverage=$coverage" >> $GITHUB_OUTPUT + echo "Coverage: $coverage%" + + - name: Fail if coverage is below threshold + run: | + total_coverage="${{ steps.coverage.outputs.total_coverage }}" + if (( $(echo "$total_coverage < 50" | bc -l) )); then + echo "Coverage ($total_coverage%) is below the threshold (50%)" + exit 1 + fi From 3d9181b23f054882c9b5c42c54cb1f04420f834c Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 23:35:35 +1100 Subject: [PATCH 08/11] fix: wrong copy order --- jog_modes/modes.go | 2 +- keys/keys.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jog_modes/modes.go b/jog_modes/modes.go index 56f3d63..7887ba6 100644 --- a/jog_modes/modes.go +++ b/jog_modes/modes.go @@ -30,7 +30,7 @@ var modes = []Mode{ // Get returns a new slice of all jog Modes each time it is called. func Get() []Mode { modesCopy := make([]Mode, len(modes)) - copy(modes, modesCopy) + copy(modesCopy, modes) return modesCopy } diff --git a/keys/keys.go b/keys/keys.go index e78e376..e9f0359 100644 --- a/keys/keys.go +++ b/keys/keys.go @@ -84,7 +84,7 @@ type Key struct { // Get returns a new slice of all Keys each time it is called. func Get() []Key { keysCopy := make([]Key, len(keys)) - copy(keys, keysCopy) + copy(keysCopy, keys) return keysCopy } From 6be6d107bde5a07e7191843a1d7406be058968d6 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 23:36:18 +1100 Subject: [PATCH 09/11] test: add getter tests --- .qwen/settings.json | 6 +- .qwen/settings.json.orig | 7 + jog_modes/modes_test.go | 278 ++++++++++++++++++++ keys/keys_test.go | 549 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 838 insertions(+), 2 deletions(-) create mode 100644 .qwen/settings.json.orig create mode 100644 jog_modes/modes_test.go create mode 100644 keys/keys_test.go diff --git a/.qwen/settings.json b/.qwen/settings.json index 502a9c0..404fae1 100644 --- a/.qwen/settings.json +++ b/.qwen/settings.json @@ -1,7 +1,9 @@ { "permissions": { "allow": [ - "Bash(go *)" + "Bash(go *)", + "Bash(sed *)" ] - } + }, + "$version": 3 } \ No newline at end of file diff --git a/.qwen/settings.json.orig b/.qwen/settings.json.orig new file mode 100644 index 0000000..502a9c0 --- /dev/null +++ b/.qwen/settings.json.orig @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(go *)" + ] + } +} \ No newline at end of file diff --git a/jog_modes/modes_test.go b/jog_modes/modes_test.go new file mode 100644 index 0000000..c7fbf65 --- /dev/null +++ b/jog_modes/modes_test.go @@ -0,0 +1,278 @@ +package jogModes + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestGet tests the Get function +func TestGet(t *testing.T) { + t.Run("returns all modes", func(t *testing.T) { + modes := Get() + + // Should return 4 modes + assert.Len(t, modes, 4) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + modes1 := Get() + modes2 := Get() + + // Should have equal content but be different underlying arrays + assert.Equal(t, modes1, modes2) + + // Modifying one should not affect the other + if len(modes1) > 0 { + originalName := modes1[0].Name + modes1[0].Name = "MODIFIED" + assert.NotEqual(t, modes1[0].Name, modes2[0].Name) + modes1[0].Name = originalName + } + }) + + t.Run("contains expected modes", func(t *testing.T) { + modes := Get() + + found := false + for _, mode := range modes { + if mode.Name == ABSOLUTE { + found = true + assert.Equal(t, ID_ABSOLUTE, mode.Id) + } + } + assert.True(t, found, "ABSOLUTE mode should be present") + }) + + t.Run("contains all mode types", func(t *testing.T) { + modes := Get() + + modeNames := make(map[string]bool) + for _, mode := range modes { + modeNames[mode.Name] = true + } + + assert.True(t, modeNames[RELATIVE], "RELATIVE mode should exist") + assert.True(t, modeNames[ABSOLUTE], "ABSOLUTE mode should exist") + assert.True(t, modeNames[RELATIVE_2], "RELATIVE_2 mode should exist") + assert.True(t, modeNames[ABSOLUTE_DEADZONE], "ABSOLUTE_DEADZONE mode should exist") + }) +} + +// TestByName tests the ByName function +func TestByName(t *testing.T) { + t.Run("returns map with all modes", func(t *testing.T) { + modesByName := ByName() + + assert.Len(t, modesByName, 4) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByName() + map2 := ByName() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup relative mode", func(t *testing.T) { + modesByName := ByName() + + mode, exists := modesByName[RELATIVE] + assert.True(t, exists) + assert.Equal(t, ID_RELATIVE, mode.Id) + assert.Equal(t, RELATIVE, mode.Name) + }) + + t.Run("lookup absolute mode", func(t *testing.T) { + modesByName := ByName() + + mode, exists := modesByName[ABSOLUTE] + assert.True(t, exists) + assert.Equal(t, ID_ABSOLUTE, mode.Id) + assert.Equal(t, ABSOLUTE, mode.Name) + }) + + t.Run("lookup relative_2 mode", func(t *testing.T) { + modesByName := ByName() + + mode, exists := modesByName[RELATIVE_2] + assert.True(t, exists) + assert.Equal(t, ID_RELATIVE_2, mode.Id) + assert.Equal(t, RELATIVE_2, mode.Name) + }) + + t.Run("lookup absolute_deadzone mode", func(t *testing.T) { + modesByName := ByName() + + mode, exists := modesByName[ABSOLUTE_DEADZONE] + assert.True(t, exists) + assert.Equal(t, ID_ABSOLUTE_DEADZONE, mode.Id) + assert.Equal(t, ABSOLUTE_DEADZONE, mode.Name) + }) + + t.Run("non-existent mode returns zero value", func(t *testing.T) { + modesByName := ByName() + + mode, exists := modesByName["NON_EXISTENT"] + assert.False(t, exists) + assert.Equal(t, Mode{}, mode) + }) +} + +// TestById tests the ById function +func TestById(t *testing.T) { + t.Run("returns map with all modes", func(t *testing.T) { + modesById := ById() + + assert.Len(t, modesById, 4) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ById() + map2 := ById() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup relative mode by id", func(t *testing.T) { + modesById := ById() + + mode, exists := modesById[ID_RELATIVE] + assert.True(t, exists) + assert.Equal(t, RELATIVE, mode.Name) + }) + + t.Run("lookup absolute mode by id", func(t *testing.T) { + modesById := ById() + + mode, exists := modesById[ID_ABSOLUTE] + assert.True(t, exists) + assert.Equal(t, ABSOLUTE, mode.Name) + }) + + t.Run("lookup relative_2 mode by id", func(t *testing.T) { + modesById := ById() + + mode, exists := modesById[ID_RELATIVE_2] + assert.True(t, exists) + assert.Equal(t, RELATIVE_2, mode.Name) + }) + + t.Run("lookup absolute_deadzone mode by id", func(t *testing.T) { + modesById := ById() + + mode, exists := modesById[ID_ABSOLUTE_DEADZONE] + assert.True(t, exists) + assert.Equal(t, ABSOLUTE_DEADZONE, mode.Name) + }) + + t.Run("non-existent id returns zero value", func(t *testing.T) { + modesById := ById() + + mode, exists := modesById[999] + assert.False(t, exists) + assert.Equal(t, Mode{}, mode) + }) +} + +// TestModeConstants tests the mode constants +func TestModeConstants(t *testing.T) { + t.Run("relative constant", func(t *testing.T) { + assert.Equal(t, "RELATIVE", RELATIVE) + }) + + t.Run("absolute constant", func(t *testing.T) { + assert.Equal(t, "ABSOLUTE", ABSOLUTE) + }) + + t.Run("relative_2 constant", func(t *testing.T) { + assert.Equal(t, "RELATIVE_2", RELATIVE_2) + }) + + t.Run("absolute_deadzone constant", func(t *testing.T) { + assert.Equal(t, "ABSOLUTE_DEADZONE", ABSOLUTE_DEADZONE) + }) +} + +// TestIdConstants tests the ID constants +func TestIdConstants(t *testing.T) { + t.Run("id relative is zero", func(t *testing.T) { + assert.Equal(t, 0, ID_RELATIVE) + }) + + t.Run("id absolute is one", func(t *testing.T) { + assert.Equal(t, 1, ID_ABSOLUTE) + }) + + t.Run("id relative_2 is two", func(t *testing.T) { + assert.Equal(t, 2, ID_RELATIVE_2) + }) + + t.Run("id absolute_deadzone is three", func(t *testing.T) { + assert.Equal(t, 3, ID_ABSOLUTE_DEADZONE) + }) + + t.Run("ids are sequential", func(t *testing.T) { + assert.Equal(t, ID_RELATIVE+1, ID_ABSOLUTE) + assert.Equal(t, ID_ABSOLUTE+1, ID_RELATIVE_2) + assert.Equal(t, ID_RELATIVE_2+1, ID_ABSOLUTE_DEADZONE) + }) +} + +// TestAbsoluteConstants tests the absolute value constants +func TestAbsoluteConstants(t *testing.T) { + t.Run("absolute max", func(t *testing.T) { + assert.Equal(t, 4096, ABSOLUTE_MAX) + }) + + t.Run("absolute min", func(t *testing.T) { + assert.Equal(t, -4096, ABSOLUTE_MIN) + }) + + t.Run("max and min are symmetric", func(t *testing.T) { + assert.Equal(t, ABSOLUTE_MAX, -ABSOLUTE_MIN) + }) +} + +// TestModeStruct tests the Mode struct +func TestModeStruct(t *testing.T) { + t.Run("can create mode with all fields", func(t *testing.T) { + mode := Mode{ + Id: 5, + Name: "CUSTOM_MODE", + } + + assert.Equal(t, int(5), mode.Id) + assert.Equal(t, "CUSTOM_MODE", mode.Name) + }) + + t.Run("zero value mode", func(t *testing.T) { + var mode Mode + + assert.Equal(t, int(0), mode.Id) + assert.Equal(t, "", mode.Name) + }) +} + +// TestModesSlice tests the internal modes slice +func TestModesSlice(t *testing.T) { + t.Run("modes are in correct order", func(t *testing.T) { + modes := Get() + + // Verify the order matches the iota definition + assert.Equal(t, ID_RELATIVE, modes[0].Id) + assert.Equal(t, ID_ABSOLUTE, modes[1].Id) + assert.Equal(t, ID_RELATIVE_2, modes[2].Id) + assert.Equal(t, ID_ABSOLUTE_DEADZONE, modes[3].Id) + }) + + t.Run("all modes have non-empty names", func(t *testing.T) { + modes := Get() + + for _, mode := range modes { + assert.NotEmpty(t, mode.Name, "Mode %d should have a name", mode.Id) + } + }) +} diff --git a/keys/keys_test.go b/keys/keys_test.go new file mode 100644 index 0000000..99ea748 --- /dev/null +++ b/keys/keys_test.go @@ -0,0 +1,549 @@ +package keys + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestGet tests the Get function +func TestGet(t *testing.T) { + t.Run("returns all keys", func(t *testing.T) { + keys := Get() + + // Should return 43 keys + assert.Len(t, keys, 43) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + keys1 := Get() + keys2 := Get() + + // Should have equal content but be different underlying arrays + assert.Equal(t, keys1, keys2) + + // Modifying one should not affect the other + if len(keys1) > 0 { + originalName := keys1[0].Name + keys1[0].Name = "MODIFIED" + assert.NotEqual(t, keys1[0].Name, keys2[0].Name) + keys1[0].Name = originalName + } + }) + + t.Run("contains expected keys", func(t *testing.T) { + keys := Get() + + // Check for some expected keys + found := false + for _, key := range keys { + if key.Name == CAM1 { + found = true + assert.Equal(t, "CAM1", key.Name) + assert.NotZero(t, key.Id) + assert.NotZero(t, key.Led) + assert.Equal(t, 4, key.Row) + assert.Equal(t, float32(3), key.Col) + } + } + assert.True(t, found, "CAM1 key should be present") + }) + + t.Run("contains jog keys", func(t *testing.T) { + keys := Get() + + jogKeys := []string{JOG, SHTL, SCRL} + for _, expectedName := range jogKeys { + found := false + for _, key := range keys { + if key.Name == expectedName { + found = true + assert.NotZero(t, key.JogLed) + } + } + assert.True(t, found, "%s key should be present", expectedName) + } + }) + + t.Run("contains cam keys", func(t *testing.T) { + keys := Get() + + camKeys := []string{CAM1, CAM2, CAM3, CAM4, CAM5, CAM6, CAM7, CAM8, CAM9} + for _, expectedName := range camKeys { + found := false + for _, key := range keys { + if key.Name == expectedName { + found = true + assert.NotZero(t, key.Led) + } + } + assert.True(t, found, "%s key should be present", expectedName) + } + }) +} + +// TestByName tests the ByName function +func TestByName(t *testing.T) { + t.Run("returns map with all keys", func(t *testing.T) { + keysByName := ByName() + + assert.Len(t, keysByName, 43) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByName() + map2 := ByName() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by name", func(t *testing.T) { + keysByName := ByName() + + key, exists := keysByName[CAM1] + assert.True(t, exists) + assert.Equal(t, "CAM1", key.Name) + assert.NotZero(t, key.Led) + + key, exists = keysByName[TRANS] + assert.True(t, exists) + assert.Equal(t, "TRANS", key.Name) + assert.NotZero(t, key.Led) + }) + + t.Run("lookup jog keys", func(t *testing.T) { + keysByName := ByName() + + key, exists := keysByName[JOG] + assert.True(t, exists) + assert.Equal(t, "JOG", key.Name) + assert.NotZero(t, key.JogLed) + + key, exists = keysByName[SHTL] + assert.True(t, exists) + assert.Equal(t, "SHTL", key.Name) + assert.NotZero(t, key.JogLed) + + key, exists = keysByName[SCRL] + assert.True(t, exists) + assert.Equal(t, "SCRL", key.Name) + assert.NotZero(t, key.JogLed) + }) + + t.Run("lookup wide keys", func(t *testing.T) { + keysByName := ByName() + + key, exists := keysByName[IN] + assert.True(t, exists) + assert.Equal(t, float32(1.5), key.Width) + + key, exists = keysByName[OUT] + assert.True(t, exists) + assert.Equal(t, float32(1.5), key.Width) + + key, exists = keysByName[SOURCE] + assert.True(t, exists) + assert.Equal(t, float32(1.5), key.Width) + + key, exists = keysByName[TIMELINE] + assert.True(t, exists) + assert.Equal(t, float32(1.5), key.Width) + + key, exists = keysByName[STOP_PLAY] + assert.True(t, exists) + assert.Equal(t, float32(4), key.Width) + }) + + t.Run("non-existent key returns zero value", func(t *testing.T) { + keysByName := ByName() + + key, exists := keysByName["NON_EXISTENT"] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestById tests the ById function +func TestById(t *testing.T) { + t.Run("returns map with all keys", func(t *testing.T) { + keysById := ById() + + assert.Len(t, keysById, 43) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ById() + map2 := ById() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by id", func(t *testing.T) { + keysById := ById() + + // Find a key by its ID + var cam1ID uint16 + for _, key := range Get() { + if key.Name == CAM1 { + cam1ID = key.Id + break + } + } + assert.NotZero(t, cam1ID) + + key, exists := keysById[cam1ID] + assert.True(t, exists) + assert.Equal(t, CAM1, key.Name) + }) + + t.Run("non-existent id returns zero value", func(t *testing.T) { + keysById := ById() + + key, exists := keysById[uint16(9999)] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestByLedId tests the ByLedId function +func TestByLedId(t *testing.T) { + t.Run("returns map with keys that have leds", func(t *testing.T) { + keysByLedId := ByLedId() + + // Only keys with non-zero LEDs will be in the map + // Many keys share LED_NONE (0), so they'll be overwritten + assert.Greater(t, len(keysByLedId), 1) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByLedId() + map2 := ByLedId() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by led id", func(t *testing.T) { + keysByLedId := ByLedId() + + // Find a key's LED ID + var cam1Led uint32 + for _, key := range Get() { + if key.Name == CAM1 { + cam1Led = key.Led + break + } + } + assert.NotZero(t, cam1Led) + + key, exists := keysByLedId[cam1Led] + assert.True(t, exists) + assert.Equal(t, CAM1, key.Name) + }) + + t.Run("non-existent led id returns zero value", func(t *testing.T) { + keysByLedId := ByLedId() + + key, exists := keysByLedId[uint32(9999)] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestByJogLedId tests the ByJogLedId function +func TestByJogLedId(t *testing.T) { + t.Run("returns map with jog leds", func(t *testing.T) { + keysByJogLedId := ByJogLedId() + + // Should have jog leds: LED_JOG, LED_SHTL, LED_SCRL, and LED_NONE + assert.GreaterOrEqual(t, len(keysByJogLedId), 3) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByJogLedId() + map2 := ByJogLedId() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by jog led id", func(t *testing.T) { + keysByJogLedId := ByJogLedId() + + // Find jog key's LED ID + var jogLed uint8 + for _, key := range Get() { + if key.Name == JOG { + jogLed = key.JogLed + break + } + } + assert.NotZero(t, jogLed) + + key, exists := keysByJogLedId[jogLed] + assert.True(t, exists) + assert.Equal(t, JOG, key.Name) + }) + + t.Run("non-existent jog led id returns zero value", func(t *testing.T) { + keysByJogLedId := ByJogLedId() + + key, exists := keysByJogLedId[uint8(99)] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestByText tests the ByText function +func TestByText(t *testing.T) { + t.Run("returns map with all keys", func(t *testing.T) { + keysByText := ByText() + + assert.Len(t, keysByText, 43) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByText() + map2 := ByText() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by text", func(t *testing.T) { + keysByText := ByText() + + key, exists := keysByText[TEXT_CAM1] + assert.True(t, exists) + assert.Equal(t, CAM1, key.Name) + + key, exists = keysByText[TEXT_TRANS] + assert.True(t, exists) + assert.Equal(t, TRANS, key.Name) + }) + + t.Run("non-existent text returns zero value", func(t *testing.T) { + keysByText := ByText() + + key, exists := keysByText["NON_EXISTENT_TEXT"] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestBySubText tests the BySubText function +func TestBySubText(t *testing.T) { + t.Run("returns map with keys that have subtext", func(t *testing.T) { + keysBySubText := BySubText() + + // Only keys with non-empty SubText will be in the map + // Many keys share empty SubText, so they'll be overwritten + assert.Greater(t, len(keysBySubText), 1) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := BySubText() + map2 := BySubText() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by subtext", func(t *testing.T) { + keysBySubText := BySubText() + + // Find a key with unique subtext and verify it maps correctly + // Note: Multiple keys may share the same subtext (e.g., "CLIP"), + // so we just verify the map returns a valid key for existing subtexts + for subtext, key := range keysBySubText { + if subtext != "" { + assert.NotEmpty(t, key.Name) + assert.Equal(t, subtext, key.SubText) + return + } + } + }) + + t.Run("non-existent subtext returns zero value", func(t *testing.T) { + keysBySubText := BySubText() + + key, exists := keysBySubText["NON_EXISTENT_SUBTEXT"] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestByCol tests the ByCol function +func TestByCol(t *testing.T) { + t.Run("returns nested map", func(t *testing.T) { + keysByCol := ByCol() + + // Should have multiple columns + assert.Greater(t, len(keysByCol), 0) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByCol() + map2 := ByCol() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by column and row", func(t *testing.T) { + keysByCol := ByCol() + + // CAM1 is at col 3, row 4 + col, exists := keysByCol[float32(3)] + assert.True(t, exists) + + key, exists := col[4] + assert.True(t, exists) + assert.Equal(t, CAM1, key.Name) + }) + + t.Run("lookup wide key", func(t *testing.T) { + keysByCol := ByCol() + + // IN is at col 0, row 2 + col, exists := keysByCol[float32(0)] + assert.True(t, exists) + + key, exists := col[2] + assert.True(t, exists) + assert.Equal(t, IN, key.Name) + }) + + t.Run("non-existent column returns empty map", func(t *testing.T) { + keysByCol := ByCol() + + col, exists := keysByCol[float32(999)] + assert.False(t, exists) + assert.Nil(t, col) + }) + + t.Run("non-existent row in existing column", func(t *testing.T) { + keysByCol := ByCol() + + col, exists := keysByCol[float32(0)] + assert.True(t, exists) + + key, exists := col[999] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) +} + +// TestByRow tests the ByRow function +func TestByRow(t *testing.T) { + t.Run("returns nested map", func(t *testing.T) { + keysByRow := ByRow() + + // Should have 6 rows (0-5) + assert.Len(t, keysByRow, 6) + }) + + t.Run("returns copy not reference", func(t *testing.T) { + map1 := ByRow() + map2 := ByRow() + + // Should have equal content but be different maps + assert.Equal(t, map1, map2) + }) + + t.Run("lookup by row and column", func(t *testing.T) { + keysByRow := ByRow() + + // CAM1 is at row 4, col 3 + row, exists := keysByRow[4] + assert.True(t, exists) + + key, exists := row[float32(3)] + assert.True(t, exists) + assert.Equal(t, CAM1, key.Name) + }) + + t.Run("lookup wide key by row", func(t *testing.T) { + keysByRow := ByRow() + + // IN is at row 2, col 0 + row, exists := keysByRow[2] + assert.True(t, exists) + + key, exists := row[float32(0)] + assert.True(t, exists) + assert.Equal(t, IN, key.Name) + }) + + t.Run("non-existent row returns empty map", func(t *testing.T) { + keysByRow := ByRow() + + row, exists := keysByRow[999] + assert.False(t, exists) + assert.Nil(t, row) + }) + + t.Run("non-existent column in existing row", func(t *testing.T) { + keysByRow := ByRow() + + row, exists := keysByRow[0] + assert.True(t, exists) + + key, exists := row[float32(999)] + assert.False(t, exists) + assert.Equal(t, Key{}, key) + }) + + t.Run("all rows have expected keys", func(t *testing.T) { + keysByRow := ByRow() + + // Row 0 should have multiple keys + assert.Greater(t, len(keysByRow[0]), 0) + // Row 5 should have keys including CUT, DIS, SMTH_CUT, STOP_PLAY + assert.Greater(t, len(keysByRow[5]), 0) + }) +} + +// TestNullKey tests the NullKey constant +func TestNullKey(t *testing.T) { + t.Run("has expected values", func(t *testing.T) { + assert.Equal(t, NONE, NullKey.Name) + assert.Equal(t, uint16(0), NullKey.Id) + assert.Equal(t, uint32(0), NullKey.Led) + assert.Equal(t, uint8(0), NullKey.JogLed) + assert.Equal(t, TEXT_NONE, NullKey.Text) + assert.Equal(t, SUBTEXT_NONE, NullKey.SubText) + assert.Equal(t, -1, NullKey.Row) + assert.Equal(t, float32(-1), NullKey.Col) + assert.Equal(t, float32(1), NullKey.Width) + }) +} + +// TestKeyStruct tests the Key struct +func TestKeyStruct(t *testing.T) { + t.Run("can create key with all fields", func(t *testing.T) { + key := Key{ + Name: "TEST", + Id: 123, + Led: 456, + JogLed: 78, + Text: "Test Text", + SubText: "Test Sub", + Row: 5, + Col: 3.5, + Width: 2.0, + } + + assert.Equal(t, "TEST", key.Name) + assert.Equal(t, uint16(123), key.Id) + assert.Equal(t, uint32(456), key.Led) + assert.Equal(t, uint8(78), key.JogLed) + assert.Equal(t, "Test Text", key.Text) + assert.Equal(t, "Test Sub", key.SubText) + assert.Equal(t, 5, key.Row) + assert.Equal(t, float32(3.5), key.Col) + assert.Equal(t, float32(2.0), key.Width) + }) +} From 90635e0fa7d77c1693ae10f2fdd8491715afe635 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Sun, 29 Mar 2026 23:39:37 +1100 Subject: [PATCH 10/11] ci: exclude examples from code coverage --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2ba668d..b44a3b5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -40,7 +40,7 @@ jobs: run: go test -v ./... - name: Generate coverage report - run: go test -coverprofile=coverage.out ./... + run: go test -coverprofile=coverage.out $(go list ./... | grep -v /examples/) - name: Check coverage id: coverage From 9cb69b47a73ccc30f03f200227436b640c02d872 Mon Sep 17 00:00:00 2001 From: James Balazs Date: Mon, 30 Mar 2026 00:01:41 +1100 Subject: [PATCH 11/11] test: more coverage --- .github/workflows/test.yaml | 4 +- .gitignore | 1 + auth/auth_test.go | 119 ++++++++++ input/input_test.go | 442 ++++++++++++++++++++++++++++++++++++ jog_modes/modes_test.go | 41 ---- keys/keys_test.go | 27 --- 6 files changed, 564 insertions(+), 70 deletions(-) create mode 100644 auth/auth_test.go create mode 100644 input/input_test.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b44a3b5..685e546 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -52,7 +52,7 @@ jobs: - name: Fail if coverage is below threshold run: | total_coverage="${{ steps.coverage.outputs.total_coverage }}" - if (( $(echo "$total_coverage < 50" | bc -l) )); then - echo "Coverage ($total_coverage%) is below the threshold (50%)" + if (( $(echo "$total_coverage < 80" | bc -l) )); then + echo "Coverage ($total_coverage%) is below the threshold (80%)" exit 1 fi diff --git a/.gitignore b/.gitignore index 9b57a6b..6bec34b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __debug_bin* *.exe node_modules +coverage.out diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 0000000..d3897b8 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,119 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestRol8 tests the rol8 function +func TestRol8(t *testing.T) { + t.Run("rotate left by 8 bits", func(t *testing.T) { + // 0x01 rotated left by 8 bits should move the high byte to low byte + result := rol8(0x01) + assert.Equal(t, uint64(0x0100000000000000), result) + }) + + t.Run("rotate 0x1234567890ABCDEF left by 8 bits", func(t *testing.T) { + // Left rotate by 8: moves highest byte to lowest position + result := rol8(0x1234567890ABCDEF) + assert.Equal(t, uint64(0xEF1234567890ABCD), result) + }) + + t.Run("rotate 0xFF left by 8 bits", func(t *testing.T) { + result := rol8(0xFF) + assert.Equal(t, uint64(0xFF00000000000000), result) + }) + + t.Run("rotate 0x00 returns 0x00", func(t *testing.T) { + result := rol8(0x00) + assert.Equal(t, uint64(0x00), result) + }) + + t.Run("rotate max uint64", func(t *testing.T) { + result := rol8(0xFFFFFFFFFFFFFFFF) + assert.Equal(t, uint64(0xFFFFFFFFFFFFFFFF), result) + }) +} + +// TestRol8n tests the rol8n function +func TestRol8n(t *testing.T) { + t.Run("rotate by 0 returns original value", func(t *testing.T) { + value := uint64(0x1234567890ABCDEF) + result := rol8n(value, 0) + assert.Equal(t, value, result) + }) + + t.Run("rotate by 1 applies single rotation", func(t *testing.T) { + value := uint64(0x1234567890ABCDEF) + result := rol8n(value, 1) + expected := rol8(value) + assert.Equal(t, expected, result) + }) + + t.Run("rotate by 3 applies three rotations", func(t *testing.T) { + value := uint64(0x1234567890ABCDEF) + result := rol8n(value, 3) + expected := rol8(rol8(rol8(value))) + assert.Equal(t, expected, result) + }) + + t.Run("rotate by 8 (full cycle) returns original", func(t *testing.T) { + // 8 rotations of 8 bits = 64 bits = full cycle + value := uint64(0x1234567890ABCDEF) + result := rol8n(value, 8) + assert.Equal(t, value, result) + }) + + t.Run("rotate 0x01 by 4", func(t *testing.T) { + result := rol8n(0x01, 4) + assert.NotZero(t, result) + assert.NotEqual(t, uint64(0x01), result) + }) +} + +// TestCalculateChallengeResponse tests the challenge response calculation +func TestCalculateChallengeResponse(t *testing.T) { + t.Run("known challenge-response pairs", func(t *testing.T) { + testCases := []struct { + challenge uint64 + response uint64 + }{ + {0x0000000000000000, 0x3ae1206f97c10bc8}, + {0x0000000000000001, 0x2b9ab32bebf244c6}, + {0x0000000000000002, 0x20a4fab8df9adf0a}, + {0x0000000000000003, 0x6df72d1b40aef698}, + {0x0000000000000004, 0x72226f051e66ab94}, + {0x0000000000000005, 0x3831a3c6032d6a42}, + {0x0000000000000006, 0xfd7ff81881352889}, + {0x0000000000000007, 0x751bf623f42e0ade}, + {0x0000000000000008, 0x3ae1206f97c10bc0}, + {0x0000000000000009, 0x2392b32bebf244c6}, + {0x000000000000000A, 0x20acfab8df9adf0a}, + {0x000000000000000B, 0x6df7251340aef698}, + {0x000000000000000C, 0x72226f0d1666ab94}, + {0x000000000000000D, 0x3831a3c60b256242}, + {0x000000000000000E, 0xfd7ff818813d2889}, + {0x000000000000000F, 0x751bf623f42e02de}, + {0xFFFFFFFFFFFFFFFF, 0x61a3f6474ff236c6}, + } + + for _, tc := range testCases { + t.Run(string(rune(tc.challenge)), func(t *testing.T) { + result := CalculateChallengeResponse(tc.challenge) + assert.Equal(t, tc.response, result, "Challenge 0x%016x should produce response 0x%016x", tc.challenge, tc.response) + }) + } + }) + + t.Run("response is deterministic", func(t *testing.T) { + challenges := []uint64{0, 1, 42, 100, 1000, 10000, 100000} + for _, challenge := range challenges { + t.Run(string(rune(challenge)), func(t *testing.T) { + result1 := CalculateChallengeResponse(challenge) + result2 := CalculateChallengeResponse(challenge) + assert.Equal(t, result1, result2) + }) + } + }) +} diff --git a/input/input_test.go b/input/input_test.go new file mode 100644 index 0000000..eaa142b --- /dev/null +++ b/input/input_test.go @@ -0,0 +1,442 @@ +package input + +import ( + "encoding/binary" + "testing" + + "github.com/JamesBalazs/speed-editor-client/keys" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewBatteryReport tests the NewBatteryReport function +func TestNewBatteryReport(t *testing.T) { + t.Run("success charging", func(t *testing.T) { + id := ReportBattery + payload := []byte{0x01, 0x80} // Charging, ~50% battery + + report, err := NewBatteryReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + assert.True(t, report.Charging) + assert.InDelta(t, 0.5, report.Battery, 0.01) + }) + + t.Run("success not charging", func(t *testing.T) { + id := ReportBattery + payload := []byte{0x00, 0xFF} // Not charging, 100% battery + + report, err := NewBatteryReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + assert.False(t, report.Charging) + assert.InDelta(t, 1.0, report.Battery, 0.01) + }) + + t.Run("success zero battery", func(t *testing.T) { + id := ReportBattery + payload := []byte{0x01, 0x00} // Charging, 0% battery + + report, err := NewBatteryReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + assert.True(t, report.Charging) + assert.Equal(t, float32(0), report.Battery) + }) + + t.Run("success partial battery", func(t *testing.T) { + id := ReportBattery + payload := []byte{0x00, 0x40} // Not charging, ~25% battery + + report, err := NewBatteryReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + assert.False(t, report.Charging) + assert.InDelta(t, 0.25, report.Battery, 0.01) + }) + + t.Run("error wrong report id", func(t *testing.T) { + id := byte(99) // Wrong ID + payload := []byte{0x01, 0x80} + + report, err := NewBatteryReport(byte(id), payload) + + require.Error(t, err) + assert.Contains(t, err.Error(), "malformed battery stats report id") + assert.Contains(t, err.Error(), "99") + assert.Equal(t, BatteryReport{}, report) + }) + + t.Run("error wrong report id with different value", func(t *testing.T) { + id := ReportKeyPress // Wrong ID (keypress instead of battery) + payload := []byte{0x01, 0x80} + + report, err := NewBatteryReport(byte(id), payload) + + require.Error(t, err) + assert.Contains(t, err.Error(), "malformed battery stats report id") + assert.Equal(t, BatteryReport{}, report) + }) +} + +// TestNewKeyPressReport tests the NewKeyPressReport function +func TestNewKeyPressReport(t *testing.T) { + t.Run("success single key", func(t *testing.T) { + id := ReportKeyPress + // Get a valid key ID from the keys package + keysByName := keys.ByName() + cam1Key := keysByName[keys.CAM1] + payload := make([]byte, 2) + binary.LittleEndian.PutUint16(payload, cam1Key.Id) + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + require.Len(t, report.Keys, 1) + assert.Equal(t, keys.CAM1, report.Keys[0].Name) + }) + + t.Run("success multiple keys", func(t *testing.T) { + id := ReportKeyPress + keysByName := keys.ByName() + + // Create payload with multiple keys + payload := make([]byte, 6) // 3 keys * 2 bytes each + binary.LittleEndian.PutUint16(payload[0:2], keysByName[keys.CAM1].Id) + binary.LittleEndian.PutUint16(payload[2:4], keysByName[keys.CAM2].Id) + binary.LittleEndian.PutUint16(payload[4:6], keysByName[keys.CAM3].Id) + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + require.Len(t, report.Keys, 3) + assert.Equal(t, keys.CAM1, report.Keys[0].Name) + assert.Equal(t, keys.CAM2, report.Keys[1].Name) + assert.Equal(t, keys.CAM3, report.Keys[2].Name) + }) + + t.Run("success no keys", func(t *testing.T) { + id := ReportKeyPress + payload := []byte{} // Empty payload + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + assert.Equal(t, byte(id), report.Id) + assert.Empty(t, report.Keys) + }) + + t.Run("success with jog keys", func(t *testing.T) { + id := ReportKeyPress + keysByName := keys.ByName() + + // Test with jog mode keys + payload := make([]byte, 6) + binary.LittleEndian.PutUint16(payload[0:2], keysByName[keys.JOG].Id) + binary.LittleEndian.PutUint16(payload[2:4], keysByName[keys.SHTL].Id) + binary.LittleEndian.PutUint16(payload[4:6], keysByName[keys.SCRL].Id) + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + require.Len(t, report.Keys, 3) + assert.Contains(t, report.Keys, keysByName[keys.JOG]) + assert.Contains(t, report.Keys, keysByName[keys.SHTL]) + assert.Contains(t, report.Keys, keysByName[keys.SCRL]) + }) + + t.Run("success ignores unknown key ids", func(t *testing.T) { + id := ReportKeyPress + keysByName := keys.ByName() + + // Mix of valid and invalid key IDs + payload := make([]byte, 4) + binary.LittleEndian.PutUint16(payload[0:2], keysByName[keys.CAM1].Id) + binary.LittleEndian.PutUint16(payload[2:4], uint16(9999)) // Invalid ID + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + require.Len(t, report.Keys, 1) // Only the valid key should be included + assert.Equal(t, keys.CAM1, report.Keys[0].Name) + }) + + t.Run("success handles even length payload with trailing zeros", func(t *testing.T) { + id := ReportKeyPress + keysByName := keys.ByName() + + // Even length payload with a zero key ID at the end + payload := make([]byte, 4) + binary.LittleEndian.PutUint16(payload[0:2], keysByName[keys.CAM1].Id) + binary.LittleEndian.PutUint16(payload[2:4], uint16(0)) // Zero/invalid ID + + report, err := NewKeyPressReport(byte(id), payload) + + require.NoError(t, err) + require.Len(t, report.Keys, 1) // Only the valid key should be included + assert.Equal(t, keys.CAM1, report.Keys[0].Name) + }) + + t.Run("error wrong report id", func(t *testing.T) { + id := byte(99) // Wrong ID + payload := []byte{0x01, 0x00} + + report, err := NewKeyPressReport(byte(id), payload) + + require.Error(t, err) + assert.Contains(t, err.Error(), "malformed keypress input report id") + assert.Contains(t, err.Error(), "99") + assert.Equal(t, KeyPressReport{}, report) + }) + + t.Run("error wrong report id battery", func(t *testing.T) { + id := ReportBattery // Wrong ID (battery instead of keypress) + payload := []byte{0x01, 0x00} + + report, err := NewKeyPressReport(byte(id), payload) + + require.Error(t, err) + assert.Contains(t, err.Error(), "malformed keypress input report id") + assert.Equal(t, KeyPressReport{}, report) + }) +} + +// TestNewJogReport tests the NewJogReport function +func TestNewJogReport(t *testing.T) { + t.Run("success relative mode", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x00 // RELATIVE mode + binary.LittleEndian.PutUint32(payload[1:5], uint32(100)) + payload[5] = 0x00 + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, "RELATIVE", report.Mode.Name) + assert.Equal(t, int32(100), report.Value) + assert.Equal(t, uint8(0), report.Unknown) + }) + + t.Run("success absolute mode", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x01 // ABSOLUTE mode + binary.LittleEndian.PutUint32(payload[1:5], uint32(2048)) + payload[5] = 0xFF + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, "ABSOLUTE", report.Mode.Name) + assert.Equal(t, int32(2048), report.Value) + assert.Equal(t, uint8(0xFF), report.Unknown) + }) + + t.Run("success relative_2 mode", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x02 // RELATIVE_2 mode + binary.LittleEndian.PutUint32(payload[1:5], uint32(500)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, "RELATIVE_2", report.Mode.Name) + assert.Equal(t, int32(500), report.Value) + }) + + t.Run("success absolute_deadzone mode", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x03 // ABSOLUTE_DEADZONE mode + binary.LittleEndian.PutUint32(payload[1:5], uint32(4096)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, "ABSOLUTE_DEADZONE", report.Mode.Name) + assert.Equal(t, int32(4096), report.Value) + }) + + t.Run("success negative value", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x01 // ABSOLUTE mode + binary.LittleEndian.PutUint32(payload[1:5], uint32(0xFFFFF000)) // Negative value + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Less(t, report.Value, int32(0)) + }) + + t.Run("success zero value", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x00 + binary.LittleEndian.PutUint32(payload[1:5], uint32(0)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, int32(0), report.Value) + }) + + t.Run("success max value", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x00 + binary.LittleEndian.PutUint32(payload[1:5], uint32(0xFFFFFFFF)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, int32(-1), report.Value) + }) + + t.Run("unknown mode id returns empty mode", func(t *testing.T) { + id := ReportJog + payload := make([]byte, 6) + payload[0] = 0x99 // Unknown mode ID + binary.LittleEndian.PutUint32(payload[1:5], uint32(100)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, "", report.Mode.Name) + assert.Equal(t, 0, report.Mode.Id) + assert.Equal(t, int32(100), report.Value) + }) + + t.Run("prints error for wrong report id but still returns report", func(t *testing.T) { + id := byte(99) // Wrong ID + payload := make([]byte, 6) + payload[0] = 0x00 + binary.LittleEndian.PutUint32(payload[1:5], uint32(100)) + + // Note: NewJogReport doesn't return an error, it just prints + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, int32(100), report.Value) + }) + + t.Run("prints error for battery report id", func(t *testing.T) { + id := ReportBattery // Wrong ID + payload := make([]byte, 6) + payload[0] = 0x00 + binary.LittleEndian.PutUint32(payload[1:5], uint32(100)) + + report := NewJogReport(byte(id), payload) + + assert.Equal(t, byte(id), report.Id) + assert.Equal(t, int32(100), report.Value) + }) +} + +// TestToReport tests the ReportBytes.ToReport method +func TestToReport(t *testing.T) { + t.Run("jog report", func(t *testing.T) { + data := make([]byte, 7) + data[0] = ReportJog + data[1] = 0x00 // RELATIVE mode + binary.LittleEndian.PutUint32(data[2:6], uint32(100)) + + report, err := ReportBytes(data).ToReport() + + require.NoError(t, err) + require.NotNil(t, report) + jogReport, ok := report.(JogReport) + require.True(t, ok) + assert.Equal(t, byte(ReportJog), jogReport.Id) + assert.Equal(t, int32(100), jogReport.Value) + }) + + t.Run("keypress report", func(t *testing.T) { + keysByName := keys.ByName() + data := make([]byte, 3) + data[0] = ReportKeyPress + binary.LittleEndian.PutUint16(data[1:3], keysByName[keys.CAM1].Id) + + report, err := ReportBytes(data).ToReport() + + require.NoError(t, err) + require.NotNil(t, report) + keyReport, ok := report.(KeyPressReport) + require.True(t, ok) + assert.Equal(t, byte(ReportKeyPress), keyReport.Id) + require.Len(t, keyReport.Keys, 1) + assert.Equal(t, keys.CAM1, keyReport.Keys[0].Name) + }) + + t.Run("battery report", func(t *testing.T) { + data := []byte{ReportBattery, 0x01, 0x80} + + report, err := ReportBytes(data).ToReport() + + require.NoError(t, err) + require.NotNil(t, report) + batteryReport, ok := report.(BatteryReport) + require.True(t, ok) + assert.Equal(t, byte(ReportBattery), batteryReport.Id) + assert.True(t, batteryReport.Charging) + assert.InDelta(t, 0.5, batteryReport.Battery, 0.01) + }) + + t.Run("unknown report id returns error", func(t *testing.T) { + data := []byte{99, 0x00, 0x00} + + report, err := ReportBytes(data).ToReport() + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown report id") + assert.Contains(t, err.Error(), "99") + assert.Nil(t, report) + }) + + t.Run("empty data panics", func(t *testing.T) { + data := []byte{} + + // ToReport doesn't handle empty data gracefully - it panics + assert.Panics(t, func() { + _, _ = ReportBytes(data).ToReport() + }) + }) + + t.Run("single byte panics", func(t *testing.T) { + data := []byte{ReportJog} + + // ToReport with single byte causes panic in NewJogReport due to slice access + assert.Panics(t, func() { + _, _ = ReportBytes(data).ToReport() + }) + }) +} + +// TestReportConstants tests the report ID constants +func TestReportConstants(t *testing.T) { + t.Run("report jog constant", func(t *testing.T) { + assert.Equal(t, byte(3), byte(ReportJog)) + }) + + t.Run("report keypress constant", func(t *testing.T) { + assert.Equal(t, byte(4), byte(ReportKeyPress)) + }) + + t.Run("report battery constant", func(t *testing.T) { + assert.Equal(t, byte(7), byte(ReportBattery)) + }) + + t.Run("report constants are unique", func(t *testing.T) { + assert.NotEqual(t, ReportJog, ReportKeyPress) + assert.NotEqual(t, ReportJog, ReportBattery) + assert.NotEqual(t, ReportKeyPress, ReportBattery) + }) +} \ No newline at end of file diff --git a/jog_modes/modes_test.go b/jog_modes/modes_test.go index c7fbf65..bf9e077 100644 --- a/jog_modes/modes_test.go +++ b/jog_modes/modes_test.go @@ -235,44 +235,3 @@ func TestAbsoluteConstants(t *testing.T) { assert.Equal(t, ABSOLUTE_MAX, -ABSOLUTE_MIN) }) } - -// TestModeStruct tests the Mode struct -func TestModeStruct(t *testing.T) { - t.Run("can create mode with all fields", func(t *testing.T) { - mode := Mode{ - Id: 5, - Name: "CUSTOM_MODE", - } - - assert.Equal(t, int(5), mode.Id) - assert.Equal(t, "CUSTOM_MODE", mode.Name) - }) - - t.Run("zero value mode", func(t *testing.T) { - var mode Mode - - assert.Equal(t, int(0), mode.Id) - assert.Equal(t, "", mode.Name) - }) -} - -// TestModesSlice tests the internal modes slice -func TestModesSlice(t *testing.T) { - t.Run("modes are in correct order", func(t *testing.T) { - modes := Get() - - // Verify the order matches the iota definition - assert.Equal(t, ID_RELATIVE, modes[0].Id) - assert.Equal(t, ID_ABSOLUTE, modes[1].Id) - assert.Equal(t, ID_RELATIVE_2, modes[2].Id) - assert.Equal(t, ID_ABSOLUTE_DEADZONE, modes[3].Id) - }) - - t.Run("all modes have non-empty names", func(t *testing.T) { - modes := Get() - - for _, mode := range modes { - assert.NotEmpty(t, mode.Name, "Mode %d should have a name", mode.Id) - } - }) -} diff --git a/keys/keys_test.go b/keys/keys_test.go index 99ea748..220a778 100644 --- a/keys/keys_test.go +++ b/keys/keys_test.go @@ -520,30 +520,3 @@ func TestNullKey(t *testing.T) { assert.Equal(t, float32(1), NullKey.Width) }) } - -// TestKeyStruct tests the Key struct -func TestKeyStruct(t *testing.T) { - t.Run("can create key with all fields", func(t *testing.T) { - key := Key{ - Name: "TEST", - Id: 123, - Led: 456, - JogLed: 78, - Text: "Test Text", - SubText: "Test Sub", - Row: 5, - Col: 3.5, - Width: 2.0, - } - - assert.Equal(t, "TEST", key.Name) - assert.Equal(t, uint16(123), key.Id) - assert.Equal(t, uint32(456), key.Led) - assert.Equal(t, uint8(78), key.JogLed) - assert.Equal(t, "Test Text", key.Text) - assert.Equal(t, "Test Sub", key.SubText) - assert.Equal(t, 5, key.Row) - assert.Equal(t, float32(3.5), key.Col) - assert.Equal(t, float32(2.0), key.Width) - }) -}