Skip to content

Commit 85a1a38

Browse files
committed
Update RequestChannel API and first frame handling
1 parent 2347d76 commit 85a1a38

10 files changed

Lines changed: 75 additions & 50 deletions

internal/socket/abstract_socket.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type AbstractRSocket struct {
2222
MP func(payload.Payload)
2323
RR func(payload.Payload) mono.Mono
2424
RS func(payload.Payload) flux.Flux
25-
RC func(flux.Flux) flux.Flux
25+
RC func(payload.Payload, flux.Flux) flux.Flux
2626
}
2727

2828
// MetadataPush starts a request of MetadataPush.
@@ -60,9 +60,9 @@ func (a AbstractRSocket) RequestStream(message payload.Payload) flux.Flux {
6060
}
6161

6262
// RequestChannel starts a request of RequestChannel.
63-
func (a AbstractRSocket) RequestChannel(messages flux.Flux) flux.Flux {
63+
func (a AbstractRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
6464
if a.RC == nil {
6565
return flux.Error(errUnimplementedRequestChannel)
6666
}
67-
return a.RC(messages)
67+
return a.RC(initialRequest, messages)
6868
}

internal/socket/abstract_socket_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ func TestAbstractRSocket_RequestStream(t *testing.T) {
8686

8787
func TestAbstractRSocket_RequestChannel(t *testing.T) {
8888
s := &socket.AbstractRSocket{
89-
RC: func(publisher flux.Flux) flux.Flux {
89+
RC: func(initialRequest payload.Payload, publisher flux.Flux) flux.Flux {
9090
return flux.Clone(publisher)
9191
},
9292
}
9393
var res []payload.Payload
94-
_, err := s.RequestChannel(flux.Just(fakeRequest)).
94+
_, err := s.RequestChannel(fakeRequest, flux.Just(fakeRequest)).
9595
DoOnNext(func(input payload.Payload) error {
9696
res = append(res, input)
9797
return nil
@@ -101,6 +101,6 @@ func TestAbstractRSocket_RequestChannel(t *testing.T) {
101101
assert.Len(t, res, 1)
102102
assert.Equal(t, fakeRequest, res[0])
103103

104-
_, err = emptyAbstractRSocket.RequestChannel(flux.Just(fakeRequest)).BlockFirst(context.Background())
104+
_, err = emptyAbstractRSocket.RequestChannel(fakeRequest, flux.Just(fakeRequest)).BlockFirst(context.Background())
105105
assert.Error(t, err, "should return an error")
106106
}

internal/socket/base_socket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ func (p *BaseSocket) RequestStream(message payload.Payload) flux.Flux {
6262
}
6363

6464
// RequestChannel sends RequestChannel request.
65-
func (p *BaseSocket) RequestChannel(messages flux.Flux) flux.Flux {
65+
func (p *BaseSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
6666
if err := p.reqLease.allow(); err != nil {
6767
return flux.Error(err)
6868
}
69-
return p.socket.RequestChannel(messages)
69+
return p.socket.RequestChannel(initialRequest, messages)
7070
}
7171

7272
// OnClose registers handler when socket closed.

internal/socket/base_socket_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func TestBaseSocket(t *testing.T) {
5050
s.FireAndForget(fakeRequest)
5151
s.RequestResponse(fakeRequest)
5252
s.RequestStream(fakeRequest)
53-
s.RequestChannel(flux.Just(fakeRequest))
53+
s.RequestChannel(fakeRequest, flux.Just(fakeRequest))
5454
})
5555

5656
<-done

internal/socket/duplex.go

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ func (dc *DuplexConnection) killCallback(sid uint32) {
430430
}
431431

432432
// RequestChannel start a request of RequestChannel.
433-
func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) {
433+
func (dc *DuplexConnection) RequestChannel(request payload.Payload, sending flux.Flux) (ret flux.Flux) {
434434
if dc.closed.Load() {
435435
ret = flux.Error(errSocketClosed)
436436
return
@@ -481,9 +481,56 @@ func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) {
481481
return
482482
}
483483

484+
// First request - send the initial REQUEST_CHANNEL frame with the request payload,
485+
// then subscribe to the sending flux for subsequent payloads.
486+
releasable, isReleasable := request.(common.Releasable)
487+
488+
if isReleasable {
489+
releasable.IncRef()
490+
}
491+
492+
data := request.Data()
493+
metadata, _ := request.Metadata()
494+
495+
size := framing.CalcPayloadFrameSize(data, metadata) + 4
496+
if !dc.shouldSplit(size) {
497+
toBeSent := framing.NewWriteableRequestChannelFrame(sid, n, data, metadata, core.FlagNext)
498+
499+
if isReleasable {
500+
toBeSent.HandleDone(func() {
501+
releasable.Release()
502+
})
503+
}
504+
505+
if ok := dc.sendFrame(toBeSent); !ok {
506+
dc.killCallback(sid)
507+
return
508+
}
509+
} else {
510+
dc.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) {
511+
var toBeSent core.WriteableFrame
512+
if index == 0 {
513+
toBeSent = framing.NewWriteableRequestChannelFrame(sid, n, result.Data, result.Metadata, result.Flag|core.FlagNext)
514+
} else {
515+
toBeSent = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext)
516+
}
517+
518+
// Add release hook at last frame.
519+
if !result.Flag.Check(core.FlagFollow) && isReleasable {
520+
toBeSent.HandleDone(func() {
521+
releasable.Release()
522+
})
523+
}
524+
525+
if ok := dc.sendFrame(toBeSent); !ok {
526+
dc.killCallback(sid)
527+
}
528+
})
529+
}
530+
531+
// Subscribe to sending flux for subsequent payloads
484532
sub := &requestChannelSubscriber{
485533
sid: sid,
486-
n: n,
487534
dc: dc,
488535
rcv: receiving,
489536
}
@@ -613,7 +660,7 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay
613660
}
614661
logger.Errorf("handle request-channel failed: %+v\n", err)
615662
}()
616-
resp = dc.responder.RequestChannel(receiving)
663+
resp = dc.responder.RequestChannel(req, receiving)
617664
if resp == nil {
618665
err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel)
619666
}
@@ -643,8 +690,6 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay
643690
sending.SubscribeWith(dc.ctx, sub)
644691
})
645692

646-
receivingProcessor.Next(req)
647-
648693
<-subscribed
649694

650695
return nil

internal/socket/subscriber_request_channel.go

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,20 @@ import (
66
"github.com/jjeffcaii/reactor-go"
77
"github.com/rsocket/rsocket-go/core"
88
"github.com/rsocket/rsocket-go/core/framing"
9-
"github.com/rsocket/rsocket-go/internal/fragmentation"
109
"github.com/rsocket/rsocket-go/payload"
1110
"github.com/rsocket/rsocket-go/rx"
1211
"github.com/rsocket/rsocket-go/rx/flux"
1312
"go.uber.org/atomic"
1413
)
1514

1615
type requestChannelSubscriber struct {
17-
sid uint32
18-
n uint32
19-
dc *DuplexConnection
20-
requested atomic.Bool
21-
rcv flux.Processor
16+
sid uint32
17+
dc *DuplexConnection
18+
rcv flux.Processor
2219
}
2320

2421
func (r *requestChannelSubscriber) OnNext(item payload.Payload) {
25-
if !r.requested.CAS(false, true) {
26-
r.dc.sendPayload(r.sid, item, core.FlagNext)
27-
return
28-
}
29-
d := item.Data()
30-
m, _ := item.Metadata()
31-
size := framing.CalcPayloadFrameSize(d, m) + 4
32-
if !r.dc.shouldSplit(size) {
33-
metadata, _ := item.Metadata()
34-
r.dc.sendFrame(framing.NewWriteableRequestChannelFrame(r.sid, r.n, item.Data(), metadata, core.FlagNext))
35-
return
36-
}
37-
r.dc.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) {
38-
var f core.WriteableFrame
39-
if index == 0 {
40-
f = framing.NewWriteableRequestChannelFrame(r.sid, r.n, result.Data, result.Metadata, result.Flag|core.FlagNext)
41-
} else {
42-
f = framing.NewWriteablePayloadFrame(r.sid, result.Data, result.Metadata, result.Flag|core.FlagNext)
43-
}
44-
r.dc.sendFrame(f)
45-
})
22+
r.dc.sendPayload(r.sid, item, core.FlagNext)
4623
}
4724

4825
func (r *requestChannelSubscriber) OnError(err error) {

internal/socket/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type Responder interface {
2929
// RequestStream request a completable stream.
3030
RequestStream(message payload.Payload) flux.Flux
3131
// RequestChannel request a completable stream in both directions.
32-
RequestChannel(messages flux.Flux) flux.Flux
32+
RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux
3333
}
3434

3535
// ClientSocket represents a client-side socket.

rsocket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type (
5656
// RequestStream request a completable stream.
5757
RequestStream(message payload.Payload) flux.Flux
5858
// RequestChannel request a completable stream in both directions.
59-
RequestChannel(messages flux.Flux) flux.Flux
59+
RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux
6060
}
6161

6262
// CloseableRSocket is RSocket which can be closed and handle close event.
@@ -115,7 +115,7 @@ func RequestStream(fn func(request payload.Payload) (responses flux.Flux)) OptAb
115115
}
116116

117117
// RequestChannel register request handler for RequestChannel.
118-
func RequestChannel(fn func(requests flux.Flux) (responses flux.Flux)) OptAbstractSocket {
118+
func RequestChannel(fn func(initialRequest payload.Payload, requests flux.Flux) (responses flux.Flux)) OptAbstractSocket {
119119
return func(opts *socket.AbstractRSocket) {
120120
opts.RC = fn
121121
}

rsocket_example_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func ExampleReceive() {
8585
s.Complete()
8686
})
8787
}),
88-
rsocket.RequestChannel(func(requests flux.Flux) flux.Flux {
88+
rsocket.RequestChannel(func(initialRequest payload.Payload, requests flux.Flux) flux.Flux {
8989
return requests
9090
}),
9191
), nil
@@ -137,13 +137,14 @@ func ExampleConnect() {
137137
s.Request(1)
138138
}))
139139
// Simple RequestChannel.
140+
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
140141
sendFlux := flux.Create(func(ctx context.Context, s flux.Sink) {
141142
for i := 0; i < 3; i++ {
142143
s.Next(payload.NewString(fmt.Sprintf("This is a RequestChannel message #%d.", i), ""))
143144
}
144145
s.Complete()
145146
})
146-
cli.RequestChannel(sendFlux).
147+
cli.RequestChannel(initialPayload, sendFlux).
147148
DoOnNext(func(elem payload.Payload) error {
148149
log.Println("next element in channel:", elem)
149150
return nil

rsocket_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ func testAll(t *testing.T, proto string, clientTp transport.ClientTransporter, s
354354
s.Complete()
355355
})
356356
}),
357-
RequestChannel(func(inputs flux.Flux) flux.Flux {
357+
RequestChannel(func(initialRequest payload.Payload, inputs flux.Flux) flux.Flux {
358358
received := new(int32)
359359
inputs.
360360
DoOnNext(func(input payload.Payload) error {
@@ -471,6 +471,7 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) {
471471

472472
func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {
473473
// RequestChannel
474+
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
474475
send := flux.Create(func(ctx context.Context, s flux.Sink) {
475476
for i := 0; i < int(channelElements); i++ {
476477
s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i)))
@@ -480,7 +481,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {
480481

481482
var seq int
482483

483-
_, err := cli.RequestChannel(send).
484+
_, err := cli.RequestChannel(initialPayload, send).
484485
DoOnNext(func(elem payload.Payload) error {
485486
//fmt.Println(elem)
486487
m, _ := elem.MetadataUTF8()
@@ -495,6 +496,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {
495496

496497
func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) {
497498
// RequestChannel
499+
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
498500
send := flux.Create(func(ctx context.Context, s flux.Sink) {
499501
for i := 0; i < int(channelElements); i++ {
500502
s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i)))
@@ -508,7 +510,7 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) {
508510

509511
var su rx.Subscription
510512

511-
cli.RequestChannel(send).
513+
cli.RequestChannel(initialPayload, send).
512514
DoFinally(func(s rx.SignalType) {
513515
assert.Equal(t, rx.SignalComplete, s, "bad signal type")
514516
close(done)
@@ -599,7 +601,7 @@ func (d delayedRSocket) RequestStream(message payload.Payload) flux.Flux {
599601
panic("implement me")
600602
}
601603

602-
func (d delayedRSocket) RequestChannel(messages flux.Flux) flux.Flux {
604+
func (d delayedRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
603605
panic("implement me")
604606
}
605607

0 commit comments

Comments
 (0)