diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index 92be9b8..749f28e 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -134,6 +134,8 @@ func (r *Runner) runClientMode(ctx context.Context) (err error) { _ = c.Close() }() + initialRequest := payload.NewString("", "") + for i := 0; i < r.Ops; i++ { if i > 0 { logger.Infof("\n") @@ -153,7 +155,7 @@ func (r *Runner) runClientMode(ctx context.Context) (err error) { } else if r.Stream { err = r.execRequestStream(ctx, c, first) } else if r.Channel { - err = r.execRequestChannel(ctx, c, sendingPayloads) + err = r.execRequestChannel(ctx, c, initialRequest, sendingPayloads) } else if r.MetadataPush { err = r.execMetadataPush(ctx, c, first) } else { @@ -189,7 +191,7 @@ func (r *Runner) runServerMode(ctx context.Context) error { r.showPayload(message) return sendingPayloads })) - options = append(options, rsocket.RequestChannel(func(messages flux.Flux) flux.Flux { + options = append(options, rsocket.RequestChannel(func(initialRequest payload.Payload, messages flux.Flux) flux.Flux { messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) error { r.showPayload(input) return nil @@ -245,12 +247,12 @@ func (r *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send return } -func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error { +func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, initialRequest payload.Payload, send flux.Flux) error { var f flux.Flux if r.N < rx.RequestMax { - f = c.RequestChannel(send).Take(r.N) + f = c.RequestChannel(initialRequest, send).Take(r.N) } else { - f = c.RequestChannel(send) + f = c.RequestChannel(initialRequest, send) } return r.printFlux(ctx, f) } diff --git a/examples/echo/echo.go b/examples/echo/echo.go index 6e69410..0583e38 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -132,7 +132,7 @@ func responder() rsocket.RSocket { emitter.Complete() }) }), - rsocket.RequestChannel(func(payloads flux.Flux) flux.Flux { + rsocket.RequestChannel(func(initialRequest payload.Payload, payloads flux.Flux) flux.Flux { //return payloads.(flux.Flux) payloads. //LimitRate(1). diff --git a/examples/word_counter/main.go b/examples/word_counter/main.go index 575a46d..b4dfc15 100644 --- a/examples/word_counter/main.go +++ b/examples/word_counter/main.go @@ -31,7 +31,7 @@ func main() { func server(readyCh chan struct{}) { // create a handler that will be called when the server receives the RequestChannel frame (FrameTypeRequestChannel - 0x07) - requestChannelHandler := rsocket.RequestChannel(func(requests flux.Flux) flux.Flux { + requestChannelHandler := rsocket.RequestChannel(func(initialRequest payload.Payload, requests flux.Flux) flux.Flux { return flux.Create(func(ctx context.Context, s flux.Sink) { requests.DoOnNext(func(elem payload.Payload) error { // for each payload in a flux stream respond with a word count @@ -70,6 +70,7 @@ func client() { defer client.Close() // strings to count the words + initialRequest := payload.NewString("", "") sentences := []payload.Payload{ payload.NewString("", extension.TextPlain.String()), payload.NewString("qux", extension.TextPlain.String()), @@ -86,7 +87,7 @@ func client() { counter := 0 // register handler for RequestChannel - client.RequestChannel(f).DoOnNext(func(input payload.Payload) error { + client.RequestChannel(initialRequest, f).DoOnNext(func(input payload.Payload) error { // print word count fmt.Println(sentences[counter].DataUTF8(), ":", input.DataUTF8()) counter = counter + 1 diff --git a/internal/socket/abstract_socket.go b/internal/socket/abstract_socket.go index b1bc841..c14d39f 100644 --- a/internal/socket/abstract_socket.go +++ b/internal/socket/abstract_socket.go @@ -22,7 +22,7 @@ type AbstractRSocket struct { MP func(payload.Payload) RR func(payload.Payload) mono.Mono RS func(payload.Payload) flux.Flux - RC func(flux.Flux) flux.Flux + RC func(payload.Payload, flux.Flux) flux.Flux } // MetadataPush starts a request of MetadataPush. @@ -60,9 +60,9 @@ func (a AbstractRSocket) RequestStream(message payload.Payload) flux.Flux { } // RequestChannel starts a request of RequestChannel. -func (a AbstractRSocket) RequestChannel(messages flux.Flux) flux.Flux { +func (a AbstractRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux { if a.RC == nil { return flux.Error(errUnimplementedRequestChannel) } - return a.RC(messages) + return a.RC(initialRequest, messages) } diff --git a/internal/socket/abstract_socket_test.go b/internal/socket/abstract_socket_test.go index 141d0b5..58dabeb 100644 --- a/internal/socket/abstract_socket_test.go +++ b/internal/socket/abstract_socket_test.go @@ -86,12 +86,12 @@ func TestAbstractRSocket_RequestStream(t *testing.T) { func TestAbstractRSocket_RequestChannel(t *testing.T) { s := &socket.AbstractRSocket{ - RC: func(publisher flux.Flux) flux.Flux { + RC: func(initialRequest payload.Payload, publisher flux.Flux) flux.Flux { return flux.Clone(publisher) }, } var res []payload.Payload - _, err := s.RequestChannel(flux.Just(fakeRequest)). + _, err := s.RequestChannel(fakeRequest, flux.Just(fakeRequest)). DoOnNext(func(input payload.Payload) error { res = append(res, input) return nil @@ -101,6 +101,6 @@ func TestAbstractRSocket_RequestChannel(t *testing.T) { assert.Len(t, res, 1) assert.Equal(t, fakeRequest, res[0]) - _, err = emptyAbstractRSocket.RequestChannel(flux.Just(fakeRequest)).BlockFirst(context.Background()) + _, err = emptyAbstractRSocket.RequestChannel(fakeRequest, flux.Just(fakeRequest)).BlockFirst(context.Background()) assert.Error(t, err, "should return an error") } diff --git a/internal/socket/base_socket.go b/internal/socket/base_socket.go index cca1bb9..d242666 100644 --- a/internal/socket/base_socket.go +++ b/internal/socket/base_socket.go @@ -62,11 +62,11 @@ func (p *BaseSocket) RequestStream(message payload.Payload) flux.Flux { } // RequestChannel sends RequestChannel request. -func (p *BaseSocket) RequestChannel(messages flux.Flux) flux.Flux { +func (p *BaseSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux { if err := p.reqLease.allow(); err != nil { return flux.Error(err) } - return p.socket.RequestChannel(messages) + return p.socket.RequestChannel(initialRequest, messages) } // OnClose registers handler when socket closed. diff --git a/internal/socket/base_socket_test.go b/internal/socket/base_socket_test.go index 8a784ee..47afa4c 100644 --- a/internal/socket/base_socket_test.go +++ b/internal/socket/base_socket_test.go @@ -50,7 +50,7 @@ func TestBaseSocket(t *testing.T) { s.FireAndForget(fakeRequest) s.RequestResponse(fakeRequest) s.RequestStream(fakeRequest) - s.RequestChannel(flux.Just(fakeRequest)) + s.RequestChannel(fakeRequest, flux.Just(fakeRequest)) }) <-done diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 049e56e..efaf2cf 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -430,7 +430,7 @@ func (dc *DuplexConnection) killCallback(sid uint32) { } // RequestChannel start a request of RequestChannel. -func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) { +func (dc *DuplexConnection) RequestChannel(request payload.Payload, sending flux.Flux) (ret flux.Flux) { if dc.closed.Load() { ret = flux.Error(errSocketClosed) return @@ -481,9 +481,56 @@ func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) { return } + // First request - send the initial REQUEST_CHANNEL frame with the request payload, + // then subscribe to the sending flux for subsequent payloads. + releasable, isReleasable := request.(common.Releasable) + + if isReleasable { + releasable.IncRef() + } + + data := request.Data() + metadata, _ := request.Metadata() + + size := framing.CalcPayloadFrameSize(data, metadata) + 4 + if !dc.shouldSplit(size) { + toBeSent := framing.NewWriteableRequestChannelFrame(sid, n, data, metadata, core.FlagNext) + + if isReleasable { + toBeSent.HandleDone(func() { + releasable.Release() + }) + } + + if ok := dc.sendFrame(toBeSent); !ok { + dc.killCallback(sid) + return + } + } else { + dc.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { + var toBeSent core.WriteableFrame + if index == 0 { + toBeSent = framing.NewWriteableRequestChannelFrame(sid, n, result.Data, result.Metadata, result.Flag|core.FlagNext) + } else { + toBeSent = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) + } + + // Add release hook at last frame. + if !result.Flag.Check(core.FlagFollow) && isReleasable { + toBeSent.HandleDone(func() { + releasable.Release() + }) + } + + if ok := dc.sendFrame(toBeSent); !ok { + dc.killCallback(sid) + } + }) + } + + // Subscribe to sending flux for subsequent payloads sub := &requestChannelSubscriber{ sid: sid, - n: n, dc: dc, rcv: receiving, } @@ -613,7 +660,7 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay } logger.Errorf("handle request-channel failed: %+v\n", err) }() - resp = dc.responder.RequestChannel(receiving) + resp = dc.responder.RequestChannel(req, receiving) if resp == nil { err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) } @@ -643,8 +690,6 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay sending.SubscribeWith(dc.ctx, sub) }) - receivingProcessor.Next(req) - <-subscribed return nil diff --git a/internal/socket/subscriber_request_channel.go b/internal/socket/subscriber_request_channel.go index 346cbae..f64b3f8 100644 --- a/internal/socket/subscriber_request_channel.go +++ b/internal/socket/subscriber_request_channel.go @@ -6,7 +6,6 @@ import ( "github.com/jjeffcaii/reactor-go" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" - "github.com/rsocket/rsocket-go/internal/fragmentation" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" @@ -14,35 +13,13 @@ import ( ) type requestChannelSubscriber struct { - sid uint32 - n uint32 - dc *DuplexConnection - requested atomic.Bool - rcv flux.Processor + sid uint32 + dc *DuplexConnection + rcv flux.Processor } func (r *requestChannelSubscriber) OnNext(item payload.Payload) { - if !r.requested.CAS(false, true) { - r.dc.sendPayload(r.sid, item, core.FlagNext) - return - } - d := item.Data() - m, _ := item.Metadata() - size := framing.CalcPayloadFrameSize(d, m) + 4 - if !r.dc.shouldSplit(size) { - metadata, _ := item.Metadata() - r.dc.sendFrame(framing.NewWriteableRequestChannelFrame(r.sid, r.n, item.Data(), metadata, core.FlagNext)) - return - } - r.dc.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { - var f core.WriteableFrame - if index == 0 { - f = framing.NewWriteableRequestChannelFrame(r.sid, r.n, result.Data, result.Metadata, result.Flag|core.FlagNext) - } else { - f = framing.NewWriteablePayloadFrame(r.sid, result.Data, result.Metadata, result.Flag|core.FlagNext) - } - r.dc.sendFrame(f) - }) + r.dc.sendPayload(r.sid, item, core.FlagNext) } func (r *requestChannelSubscriber) OnError(err error) { diff --git a/internal/socket/types.go b/internal/socket/types.go index c8a36b4..2de1ec6 100644 --- a/internal/socket/types.go +++ b/internal/socket/types.go @@ -29,7 +29,7 @@ type Responder interface { // RequestStream request a completable stream. RequestStream(message payload.Payload) flux.Flux // RequestChannel request a completable stream in both directions. - RequestChannel(messages flux.Flux) flux.Flux + RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux } // ClientSocket represents a client-side socket. diff --git a/rsocket.go b/rsocket.go index cf43447..0e58358 100644 --- a/rsocket.go +++ b/rsocket.go @@ -56,7 +56,7 @@ type ( // RequestStream request a completable stream. RequestStream(message payload.Payload) flux.Flux // RequestChannel request a completable stream in both directions. - RequestChannel(messages flux.Flux) flux.Flux + RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux } // CloseableRSocket is RSocket which can be closed and handle close event. @@ -115,7 +115,7 @@ func RequestStream(fn func(request payload.Payload) (responses flux.Flux)) OptAb } // RequestChannel register request handler for RequestChannel. -func RequestChannel(fn func(requests flux.Flux) (responses flux.Flux)) OptAbstractSocket { +func RequestChannel(fn func(initialRequest payload.Payload, requests flux.Flux) (responses flux.Flux)) OptAbstractSocket { return func(opts *socket.AbstractRSocket) { opts.RC = fn } diff --git a/rsocket_example_test.go b/rsocket_example_test.go index f0dec4c..b96a4a9 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -85,7 +85,7 @@ func ExampleReceive() { s.Complete() }) }), - rsocket.RequestChannel(func(requests flux.Flux) flux.Flux { + rsocket.RequestChannel(func(initialRequest payload.Payload, requests flux.Flux) flux.Flux { return requests }), ), nil @@ -137,13 +137,14 @@ func ExampleConnect() { s.Request(1) })) // Simple RequestChannel. + initialPayload := payload.NewString("This is a RequestChannel initial message.", "") sendFlux := flux.Create(func(ctx context.Context, s flux.Sink) { for i := 0; i < 3; i++ { s.Next(payload.NewString(fmt.Sprintf("This is a RequestChannel message #%d.", i), "")) } s.Complete() }) - cli.RequestChannel(sendFlux). + cli.RequestChannel(initialPayload, sendFlux). DoOnNext(func(elem payload.Payload) error { log.Println("next element in channel:", elem) return nil diff --git a/rsocket_test.go b/rsocket_test.go index ec44af0..9308ec4 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -354,7 +354,7 @@ func testAll(t *testing.T, proto string, clientTp transport.ClientTransporter, s s.Complete() }) }), - RequestChannel(func(inputs flux.Flux) flux.Flux { + RequestChannel(func(initialRequest payload.Payload, inputs flux.Flux) flux.Flux { received := new(int32) inputs. DoOnNext(func(input payload.Payload) error { @@ -471,6 +471,7 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) { func testRequestChannel(ctx context.Context, cli Client, t *testing.T) { // RequestChannel + initialPayload := payload.NewString("This is a RequestChannel initial message.", "") send := flux.Create(func(ctx context.Context, s flux.Sink) { for i := 0; i < int(channelElements); i++ { s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i))) @@ -480,7 +481,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) { var seq int - _, err := cli.RequestChannel(send). + _, err := cli.RequestChannel(initialPayload, send). DoOnNext(func(elem payload.Payload) error { //fmt.Println(elem) m, _ := elem.MetadataUTF8() @@ -495,6 +496,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) { func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) { // RequestChannel + initialPayload := payload.NewString("This is a RequestChannel initial message.", "") send := flux.Create(func(ctx context.Context, s flux.Sink) { for i := 0; i < int(channelElements); i++ { s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i))) @@ -508,7 +510,7 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) { var su rx.Subscription - cli.RequestChannel(send). + cli.RequestChannel(initialPayload, send). DoFinally(func(s rx.SignalType) { assert.Equal(t, rx.SignalComplete, s, "bad signal type") close(done) @@ -599,7 +601,7 @@ func (d delayedRSocket) RequestStream(message payload.Payload) flux.Flux { panic("implement me") } -func (d delayedRSocket) RequestChannel(messages flux.Flux) flux.Flux { +func (d delayedRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux { panic("implement me") }