Skip to content

Commit d214e5c

Browse files
authored
Merge pull request #1 from cristipufu/feature/sse_support
Add SSE support
2 parents be9a1ac + c29d1a0 commit d214e5c

10 files changed

Lines changed: 355 additions & 1 deletion

File tree

src/Tunnelite.Sdk/HttpTunnel/HttpConnection.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,8 @@ public class WsConnection
1414
public Guid RequestId { get; set; }
1515
public string Path { get; set; }
1616
}
17+
18+
public class SseConnection : HttpConnection
19+
{
20+
public string Content { get; set; }
21+
}

src/Tunnelite.Sdk/HttpTunnel/HttpTunnelClient.cs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ public HttpTunnelClient(HttpTunnelRequest tunnel, LogLevel? logLevel)
7373
return Task.CompletedTask;
7474
});
7575

76+
Connection.On<SseConnection>("NewSseConnection", (sseConnection) =>
77+
{
78+
LogRequest?.Invoke("SSE", sseConnection.Path);
79+
80+
_ = TunnelSseConnectionAsync(sseConnection);
81+
82+
return Task.CompletedTask;
83+
});
84+
7685
Connection.Reconnected += async connectionId =>
7786
{
7887
_currentTunnel = await RegisterTunnelAsync(tunnel);
@@ -260,6 +269,84 @@ private async Task StreamOutgoingWsAsync(WebSocket localWebSocket, WsConnection
260269
}
261270
}
262271

272+
private async Task TunnelSseConnectionAsync(SseConnection sseConnection)
273+
{
274+
using var cts = new CancellationTokenSource();
275+
276+
try
277+
{
278+
// Send the request to the local server
279+
using var request = new HttpRequestMessage(new HttpMethod(sseConnection.Method), sseConnection.Path);
280+
281+
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
282+
283+
request.Content = new StringContent(sseConnection.Content);
284+
if (sseConnection.ContentType != null)
285+
{
286+
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(sseConnection.ContentType);
287+
}
288+
289+
using var response = await LocalHttpClient.SendAsync(
290+
request,
291+
HttpCompletionOption.ResponseHeadersRead,
292+
cts.Token);
293+
294+
response.EnsureSuccessStatusCode();
295+
296+
// Stream the SSE data from local server to the public tunnel
297+
var outgoingTask = StreamOutgoingSseAsync(response, sseConnection, cts.Token);
298+
299+
// Wait for the streaming to complete
300+
await outgoingTask;
301+
}
302+
catch (Exception ex)
303+
{
304+
LogFailedRequest?.Invoke("SSE", sseConnection.Path);
305+
LogException?.Invoke(ex);
306+
}
307+
finally
308+
{
309+
await cts.CancelAsync();
310+
311+
Log?.Invoke($"[SSE] Connection {sseConnection.RequestId} closed.");
312+
}
313+
}
314+
315+
private async Task StreamOutgoingSseAsync(HttpResponseMessage response, SseConnection sseConnection, CancellationToken cancellationToken)
316+
{
317+
await Connection.InvokeAsync(
318+
"StreamOutgoingSseAsync",
319+
StreamLocalSseAsync(response, sseConnection, cancellationToken),
320+
sseConnection,
321+
cancellationToken: cancellationToken);
322+
}
323+
324+
private async IAsyncEnumerable<ReadOnlyMemory<byte>> StreamLocalSseAsync(
325+
HttpResponseMessage response,
326+
SseConnection sseConnection,
327+
[EnumeratorCancellation] CancellationToken cancellationToken)
328+
{
329+
const int chunkSize = 32 * 1024;
330+
byte[] buffer = ArrayPool<byte>.Shared.Rent(chunkSize);
331+
332+
try
333+
{
334+
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
335+
336+
int bytesRead;
337+
while ((bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken)) > 0)
338+
{
339+
yield return new ReadOnlyMemory<byte>(buffer, 0, bytesRead);
340+
}
341+
}
342+
finally
343+
{
344+
Log?.Invoke($"[SSE] Reading data from connection {sseConnection.RequestId} finished.");
345+
346+
ArrayPool<byte>.Shared.Return(buffer);
347+
}
348+
}
349+
263350
private async Task<HttpTunnelResponse?> RegisterTunnelAsync(HttpTunnelRequest tunnel)
264351
{
265352
tunnel.Subdomain = _currentTunnel?.Subdomain;

src/Tunnelite.Server/Extensions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Tunnelite.Server.HttpTunnel;
2+
using Tunnelite.Server.SseTunnel;
23
using Tunnelite.Server.TcpTunnel;
34
using Tunnelite.Server.WsTunnel;
45

@@ -10,6 +11,7 @@ public static void AddHttpTunneling(this WebApplicationBuilder builder)
1011
{
1112
builder.Services.AddSingleton<HttpTunnelStore>();
1213
builder.Services.AddSingleton<HttpRequestsQueue>();
14+
builder.Services.AddSingleton<SseRequestsQueue>();
1315
builder.Services.AddSingleton<WsRequestsQueue>();
1416
}
1517

src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
using Microsoft.AspNetCore.SignalR;
22
using System.Buffers;
33
using System.Net.WebSockets;
4+
using Tunnelite.Server.SseTunnel;
45
using Tunnelite.Server.WsTunnel;
56

67
namespace Tunnelite.Server.HttpTunnel;
78

8-
public class HttpTunnelHub(HttpTunnelStore httpTunnelStore, WsRequestsQueue wsRequestsQueue, ILogger<HttpTunnelHub> logger) : Hub
9+
public class HttpTunnelHub(HttpTunnelStore httpTunnelStore, WsRequestsQueue wsRequestsQueue, SseRequestsQueue sseRequestsQueue, ILogger<HttpTunnelHub> logger) : Hub
910
{
1011
private readonly HttpTunnelStore _httpTunnelStore = httpTunnelStore;
1112
private readonly WsRequestsQueue _wsRequestsQueue = wsRequestsQueue;
13+
private readonly SseRequestsQueue _sseRequestsQueue = sseRequestsQueue;
14+
1215
private readonly ILogger _logger = logger;
1316

1417
public override Task OnConnectedAsync()
@@ -112,6 +115,46 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera
112115
}
113116
}
114117

118+
public async Task StreamOutgoingSseAsync(SseConnection sseConnection, IAsyncEnumerable<ReadOnlyMemory<byte>> stream)
119+
{
120+
var clientId = GetClientId(Context);
121+
122+
var context = _sseRequestsQueue.GetHttpContext(clientId, sseConnection.RequestId);
123+
124+
if (context == null)
125+
{
126+
return;
127+
}
128+
129+
try
130+
{
131+
await foreach (var chunk in stream.WithCancellation(Context.ConnectionAborted))
132+
{
133+
await context.Response.Body.WriteAsync(chunk, Context.ConnectionAborted);
134+
await context.Response.Body.FlushAsync(Context.ConnectionAborted);
135+
}
136+
}
137+
catch (OperationCanceledException)
138+
{
139+
// ignore
140+
}
141+
catch (Exception ex) when (ex.Message == "Stream canceled by client.")
142+
{
143+
// ignore
144+
}
145+
catch (Exception ex)
146+
{
147+
_logger.LogError(ex, "An unexpected error occurred while streaming SSE data for {RequestId}", sseConnection.RequestId);
148+
}
149+
finally
150+
{
151+
_logger.LogInformation("Done writing.. SSE Connection {RequestId}", sseConnection.RequestId);
152+
153+
// Complete the SSE request
154+
await _sseRequestsQueue.CompleteAsync(clientId, sseConnection.RequestId);
155+
}
156+
}
157+
115158
public override async Task OnDisconnectedAsync(Exception? exception)
116159
{
117160
var clientId = GetClientId(Context);
@@ -124,6 +167,7 @@ public override async Task OnDisconnectedAsync(Exception? exception)
124167
}
125168

126169
await _wsRequestsQueue.CompleteAsync(clientId);
170+
await _sseRequestsQueue.CompleteAsync(clientId);
127171

128172
await base.OnDisconnectedAsync(exception);
129173
}

src/Tunnelite.Server/Program.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Tunnelite.Server;
22
using Tunnelite.Server.HttpTunnel;
3+
using Tunnelite.Server.SseTunnel;
34
using Tunnelite.Server.TcpTunnel;
45
using Tunnelite.Server.WsTunnel;
56

@@ -21,6 +22,8 @@
2122

2223
app.UseWsTunneling();
2324

25+
app.UsSseTunneling();
26+
2427
app.UseHttpTunneling();
2528

2629
app.UseTcpTunneling();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
namespace Tunnelite.Server.SseTunnel;
2+
3+
public static class SseAppExtensions
4+
{
5+
public static void UsSseTunneling(this WebApplication app)
6+
{
7+
app.UseMiddleware<SseTunnelMiddleware>();
8+
}
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#nullable disable
2+
using Tunnelite.Server.HttpTunnel;
3+
4+
namespace Tunnelite.Server.SseTunnel;
5+
6+
public class SseConnection : HttpConnection
7+
{
8+
public string Content { get; set; }
9+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Tunnelite.Server.SseTunnel;
2+
3+
public class SseDeferredRequest
4+
{
5+
public HttpContext? HttpContext { get; set; }
6+
public Guid RequestId { get; set; }
7+
public TaskCompletionSource? TaskCompletionSource { get; set; }
8+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using System.Collections.Concurrent;
2+
namespace Tunnelite.Server.SseTunnel;
3+
4+
public class SseRequestsQueue
5+
{
6+
// client, [requestId, SseDeferredRequest]
7+
private readonly ConcurrentDictionary<Guid, ConcurrentDictionary<Guid, SseDeferredRequest>> PendingRequests = new();
8+
9+
public Task WaitForCompletionAsync(Guid clientId, Guid requestId, HttpContext context)
10+
{
11+
SseDeferredRequest request = new()
12+
{
13+
HttpContext = context,
14+
RequestId = requestId,
15+
TaskCompletionSource = new TaskCompletionSource(),
16+
};
17+
18+
PendingRequests.AddOrUpdate(
19+
clientId,
20+
_ => new ConcurrentDictionary<Guid, SseDeferredRequest> { [requestId] = request },
21+
(_, requests) =>
22+
{
23+
requests[requestId] = request;
24+
return requests;
25+
});
26+
27+
return request.TaskCompletionSource.Task;
28+
}
29+
30+
public virtual HttpContext? GetHttpContext(Guid clientId, Guid requestId)
31+
{
32+
if (!PendingRequests.TryGetValue(clientId, out var requests))
33+
{
34+
return null;
35+
}
36+
37+
requests.TryGetValue(requestId, out var request);
38+
39+
return request?.HttpContext;
40+
}
41+
42+
public virtual Task CompleteAsync(Guid clientId, Guid requestId)
43+
{
44+
if (!PendingRequests.TryGetValue(clientId, out var requests))
45+
{
46+
return Task.CompletedTask;
47+
}
48+
49+
if (!requests.TryRemove(requestId, out var request))
50+
{
51+
return Task.CompletedTask;
52+
}
53+
54+
if (!request.TaskCompletionSource!.Task.IsCompleted)
55+
{
56+
// Try to complete the task
57+
if (request.TaskCompletionSource?.TrySetResult() == false)
58+
{
59+
// The request was canceled
60+
}
61+
}
62+
else
63+
{
64+
// The request was canceled while pending
65+
}
66+
67+
return Task.CompletedTask;
68+
}
69+
70+
public virtual async Task CompleteAsync(Guid clientId)
71+
{
72+
if (!PendingRequests.TryRemove(clientId, out var requests))
73+
{
74+
return;
75+
}
76+
77+
foreach (var request in requests)
78+
{
79+
await CompleteAsync(clientId, request.Key);
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)