using System.Buffers;
using System.Net.WebSockets;

var builder = WebApplication.CreateBuilder(args);
var app = builder.Build();

// HttpContext.WebSockets (used by the /ws endpoint) only works once the WebSockets
// middleware is part of the pipeline; without it AcceptWebSocketAsync throws.
app.UseWebSockets();

app.MapGet("/ws", async (HttpContext context) =>
{
    // Size limit handed to both MessagePipe instances; how it is enforced is up to
    // MessagePipe (defined outside this file).
    const int MaxMessageSize = 1024 * 1024;

    // AcceptWebSocketAsync throws for a plain HTTP request, so reject those up front.
    if (!context.WebSockets.IsWebSocketRequest)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    // Tie all socket I/O to the lifetime of the HTTP request.
    var cancellationToken = context.RequestAborted;

    using var ws = await context.WebSockets.AcceptWebSocketAsync();
    var incomingMessages = new MessagePipe<WebSocketMessageType>(MaxMessageSize);
    var outgoingMessages = new MessagePipe<(WebSocketMessageType messageType, bool endOfMessage)>(MaxMessageSize);

    // Start both consumers before receiving. Keeping the tasks (instead of discarding
    // Task.Run results) lets the handler wait for them and observe their failures.
    var sending = SendLoopAsync();
    var processing = ProcessLoopAsync();

    try
    {
        await ReceiveLoopAsync();
    }
    finally
    {
        // Always signal end-of-input, even when ReceiveAsync threw (for example, the
        // client vanished without a close handshake), so the processing loop can finish.
        await incomingMessages.CompleteWriterAsync();
    }

    // Processing completes the outgoing writer when it finishes, which ends the send loop.
    await processing;
    await sending;

    // The client started the close handshake; acknowledge it after queued output is sent.
    if (ws.State == WebSocketState.CloseReceived)
    {
        await ws.CloseOutputAsync(
            ws.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
            ws.CloseStatusDescription,
            cancellationToken);
    }

    // Reads frames into incomingMessages, flushing one message per EndOfMessage frame,
    // until the client sends a Close frame.
    async Task ReceiveLoopAsync()
    {
        while (true)
        {
            var result = await ws.ReceiveAsync(incomingMessages.GetMemory(512), cancellationToken);

            if (result.MessageType == WebSocketMessageType.Close)
            {
                return;
            }

            incomingMessages.AdvanceWriter(result.Count);

            if (result.EndOfMessage)
            {
                await incomingMessages.FlushMessageAsync(result.MessageType);
            }
        }
    }

    // Hands each non-empty inbound message to ProcessMessageAsync until the incoming
    // pipe reports completion, then completes the outgoing writer.
    async Task ProcessLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await incomingMessages.ReadAsync();
                var result = message.Result;
                var buffer = result.Buffer;

                if (!buffer.IsEmpty)
                {
                    await ProcessMessageAsync(outgoingMessages, buffer, message.Metadata);
                }

                if (result.IsCompleted)
                {
                    break;
                }

                incomingMessages.AdvanceReader();
            }
        }
        finally
        {
            await incomingMessages.CompleteReaderAsync();
            // No further replies can be produced; let the send loop drain and exit.
            await outgoingMessages.CompleteWriterAsync();
        }
    }

    // Writes each outgoing message to the socket until the outgoing pipe completes.
    async Task SendLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await outgoingMessages.ReadAsync();
                var result = message.Result;
                var buffer = result.Buffer;

                // Completed with nothing left to send: the processing side is done.
                if (result.IsCompleted && buffer.IsEmpty)
                {
                    break;
                }

                var (messageType, endOfMessage) = message.Metadata;
                await SendBufferAsync(buffer, messageType, endOfMessage);

                if (result.IsCompleted)
                {
                    break;
                }

                outgoingMessages.AdvanceReader();
            }
        }
        finally
        {
            // Completing the reader lets a blocked producer observe that nobody is
            // consuming anymore (e.g. after SendAsync failed).
            await outgoingMessages.CompleteReaderAsync();
        }
    }

    // Sends one (possibly multi-segment) buffer as WebSocket frames.
    async Task SendBufferAsync(ReadOnlySequence<byte> buffer, WebSocketMessageType messageType, bool endOfMessage)
    {
        if (buffer.IsSingleSegment)
        {
            await ws.SendAsync(buffer.First, messageType, endOfMessage, cancellationToken);
            return;
        }

        // Stay one segment behind so the final SendAsync carries both the last bytes
        // and the caller's endOfMessage flag (no trailing zero-length frame).
        var position = buffer.Start;
        buffer.TryGet(ref position, out var previous);
        while (buffer.TryGet(ref position, out var segment))
        {
            await ws.SendAsync(previous, messageType, endOfMessage: false, cancellationToken);
            previous = segment;
        }

        await ws.SendAsync(previous, messageType, endOfMessage, cancellationToken);
    }
});

// Application hook for a single inbound WebSocket message.
//   outgoingMessages – pipe that replies may be written to (payload plus message type / end-of-message flag).
//   buffer           – the bytes of the inbound message.
//   message          – the WebSocket message type of the inbound message.
// Placeholder: performs no work and returns an already-completed task.
Task ProcessMessageAsync(MessagePipe<(WebSocketMessageType, bool)> outgoingMessages, ReadOnlySequence<byte> buffer, WebSocketMessageType message)
    => Task.CompletedTask; // TODO: process the message here

// Start the web host and block the entry point until it shuts down.
await app.RunAsync();