Skip to content

Instantly share code, notes, and snippets.

@p-i-
Created October 8, 2024 15:48
Show Gist options
  • Save p-i-/02e26d89444efe51ca6d16ed28a1fa92 to your computer and use it in GitHub Desktop.
Save p-i-/02e26d89444efe51ca6d16ed28a1fa92 to your computer and use it in GitHub Desktop.
(.venv)
pi@πlocal ~/code/2024/kit/kit/Gin master
> git diff
diff --git a/modules/gin_network/network/gin_asyncwebsocket.cpp b/modules/gin_network/network/gin_asyncwebsocket.cpp
index 15be54251e..f3a8f057ac 100644
--- a/modules/gin_network/network/gin_asyncwebsocket.cpp
+++ b/modules/gin_network/network/gin_asyncwebsocket.cpp
@@ -6,8 +6,8 @@
==============================================================================*/
//==============================================================================
-AsyncWebsocket::AsyncWebsocket (const juce::URL url_)
- : Thread ("websocket"), url (url_)
+AsyncWebsocket::AsyncWebsocket(const juce::URL url_, const juce::StringPairArray& headers)
+ : Thread("websocket"), url(url_), customHeaders(headers)
{
}
@@ -46,20 +46,20 @@ void AsyncWebsocket::process()
using MM = juce::MessageManager;
juce::WeakReference<AsyncWebsocket> weakThis = this;
- ws->dispatch ([this, weakThis] (const juce::MemoryBlock message, bool isBinary)
+ ws->dispatch([this, weakThis](const juce::MemoryBlock message, bool isBinary)
{
auto messageCopy = message;
- MM::callAsync ([this, weakThis, messageCopy, isBinary]
+ MM::callAsync([this, weakThis, messageCopy, isBinary]
{
if (weakThis != nullptr)
{
- // if we are receiving data we don't need to ping
+ // If we are receiving data, we don't need to ping
lastPing = juce::Time::getMillisecondCounterHiRes() / 1000;
if (isBinary && onBinary)
- onBinary (messageCopy);
- else if (! isBinary && onText)
- onText (juce::String::fromUTF8 ((char*)messageCopy.getData(), int (messageCopy.getSize())));
+ onBinary(messageCopy);
+ else if (!isBinary && onText)
+ onText(juce::String::fromUTF8((char*)messageCopy.getData(), int(messageCopy.getSize())));
}
});
});
@@ -68,15 +68,15 @@ void AsyncWebsocket::process()
//==============================================================================
auto processOutgoingData = [this](std::unique_ptr<WebSocket>& ws)
{
- juce::ScopedLock sl (lock);
+ juce::ScopedLock sl(lock);
for (auto& data : outgoingQueue)
{
if (data.type == pingMsg)
ws->sendPing();
else if (data.type == binaryMsg)
- ws->sendBinary (data.data);
+ ws->sendBinary(data.data);
else if (data.type == textMsg)
- ws->send (data.text);
+ ws->send(data.text);
else
jassertfalse;
}
@@ -87,18 +87,20 @@ void AsyncWebsocket::process()
using MM = juce::MessageManager;
juce::WeakReference<AsyncWebsocket> weakThis = this;
- auto ws = std::unique_ptr<WebSocket> (WebSocket::fromURL (url.toString (true).toStdString()));
+ // Use custom headers when creating WebSocket connection
+ auto ws = std::unique_ptr<WebSocket>(WebSocket::fromURL(url.toString(true), {}, customHeaders));
+
if (ws != nullptr)
{
- MM::callAsync ([this, weakThis]
+ MM::callAsync([this, weakThis]
{
- if (weakThis != nullptr && onConnect)
- onConnect();
+ if (weakThis != nullptr && onConnect)
+ onConnect();
});
- while (! threadShouldExit())
+ while (!threadShouldExit())
{
- ws->poll (50);
+ ws->poll(50);
if (ws->getReadyState() == WebSocket::CLOSED)
break;
@@ -109,25 +111,23 @@ void AsyncWebsocket::process()
lastPing = now;
}
- processIncomingData (ws);
- processOutgoingData (ws);
+ processIncomingData(ws);
+ processOutgoingData(ws);
}
if (ws->getReadyState() != WebSocket::CLOSED)
{
ws->close();
- ws->poll (0);
+ ws->poll(0);
}
}
- MM::callAsync ([this, weakThis]
+ MM::callAsync([this, weakThis]
{
if (weakThis != nullptr && onDisconnect)
{
- // The disconnect callback may cause the websocket to get deleted
- // Dangerous to use the lambda that is a member function, so make a copy
- auto safeCallback = onDisconnect;
- safeCallback ();
+ auto safeCallback = onDisconnect;
+ safeCallback();
}
});
}
diff --git a/modules/gin_network/network/gin_asyncwebsocket.h b/modules/gin_network/network/gin_asyncwebsocket.h
index 5d846a8e7a..3d10b81d41 100644
--- a/modules/gin_network/network/gin_asyncwebsocket.h
+++ b/modules/gin_network/network/gin_asyncwebsocket.h
@@ -15,7 +15,7 @@ class AsyncWebsocket : public juce::Thread
{
public:
//==============================================================================
- AsyncWebsocket (const juce::URL url);
+ AsyncWebsocket(const juce::URL url, const juce::StringPairArray& customHeaders = {});
~AsyncWebsocket() override;
//==============================================================================
@@ -41,6 +41,7 @@ private:
void process();
juce::URL url;
+ juce::StringPairArray customHeaders;
//==============================================================================
enum MessageType
diff --git a/modules/gin_network/network/gin_websocket.cpp b/modules/gin_network/network/gin_websocket.cpp
index b4f71403ff..bc22d2184e 100644
--- a/modules/gin_network/network/gin_websocket.cpp
+++ b/modules/gin_network/network/gin_websocket.cpp
@@ -495,6 +495,180 @@ WebSocket* WebSocket::fromURL (const juce::String& url, bool useMask, const juce
return new WebSocket (std::move (socket), useMask);
}
+/**
+    Opens a WebSocket connection to `url`, sending `customHeaders` in the opening handshake.
+
+    The URL must use the ws:// or wss:// scheme; when no port is given it defaults to 80
+    or 443 respectively. The handshake is completed before the function returns, and the
+    returned object is constructed with masking enabled.
+
+    @param url            ws:// or wss:// address; must be shorter than 512 bytes as UTF-8
+    @param origin         Origin header value, omitted when empty; under 200 bytes, no CR/LF
+    @param customHeaders  extra request headers; names must be non-empty and neither names
+                          nor values may contain CR or LF
+    @returns a heap-allocated WebSocket (the caller takes ownership), or nullptr if an argument
+             is rejected, the connection fails, or the server does not reply with status 101
+*/
+WebSocket* WebSocket::fromURL(const juce::String& url, const juce::String& origin, const juce::StringPairArray& customHeaders)
+{
+    char path[512] = {0};
+    char host[512] = {0};
+    int port = 0;
+    bool secure = false;
+
+    // Limits are measured in UTF-8 bytes rather than characters, because sscanf below
+    // copies raw bytes into the fixed-size host/path buffers.
+    if (url.getNumBytesAsUTF8() >= sizeof(path))
+    {
+        fprintf(stderr, "ERROR: URL size limit exceeded: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+    if (origin.getNumBytesAsUTF8() >= 200)
+    {
+        fprintf(stderr, "ERROR: Origin size limit exceeded: %s\n", origin.toRawUTF8());
+        return nullptr;
+    }
+
+    // A CR or LF inside a header would end it early and let caller-supplied text be read
+    // as further headers, so such input is rejected before connecting.
+    const juce::StringArray& headerKeys = customHeaders.getAllKeys();
+    const juce::StringArray& headerValues = customHeaders.getAllValues();
+
+    if (origin.containsAnyOf("\r\n"))
+    {
+        fprintf(stderr, "ERROR: Origin contains a line break: %s\n", origin.toRawUTF8());
+        return nullptr;
+    }
+    for (int h = 0; h < headerKeys.size(); ++h)
+    {
+        if (headerKeys[h].isEmpty() || headerKeys[h].containsAnyOf("\r\n") || headerValues[h].containsAnyOf("\r\n"))
+        {
+            fprintf(stderr, "ERROR: Invalid custom header: %s\n", headerKeys[h].toRawUTF8());
+            return nullptr;
+        }
+    }
+
+    // Try the most specific URL shape first: host:port/path, host/path, host:port, host.
+    const char* rawUrl = url.toRawUTF8();
+    if (sscanf(rawUrl, "wss://%[^:/]:%d/%s", host, &port, path) == 3)
+    {
+        secure = true;
+    }
+    else if (sscanf(rawUrl, "wss://%[^:/]/%s", host, path) == 2)
+    {
+        secure = true;
+        port = 443;
+    }
+    else if (sscanf(rawUrl, "wss://%[^:/]:%d", host, &port) == 2)
+    {
+        secure = true;
+        path[0] = '\0';
+    }
+    else if (sscanf(rawUrl, "wss://%[^:/]", host) == 1)
+    {
+        secure = true;
+        port = 443;
+        path[0] = '\0';
+    }
+    else if (sscanf(rawUrl, "ws://%[^:/]:%d/%s", host, &port, path) == 3)
+    {
+        // non-secure ws:// with explicit port and path: nothing to default
+    }
+    else if (sscanf(rawUrl, "ws://%[^:/]/%s", host, path) == 2)
+    {
+        port = 80;
+    }
+    else if (sscanf(rawUrl, "ws://%[^:/]:%d", host, &port) == 2)
+    {
+        path[0] = '\0';
+    }
+    else if (sscanf(rawUrl, "ws://%[^:/]", host) == 1)
+    {
+        port = 80;
+        path[0] = '\0';
+    }
+    else
+    {
+        fprintf(stderr, "ERROR: Could not parse WebSocket URL: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    // Create secure or non-secure socket
+    auto socket = std::make_unique<gin::SecureStreamingSocket>(secure);
+    if (!socket->connect(host, port))
+    {
+        fprintf(stderr, "Unable to connect to %s:%d\n", host, port);
+        return nullptr;
+    }
+
+    // Build the whole opening handshake in a growable string so long header values
+    // (e.g. bearer tokens) are not truncated by a fixed-size line buffer.
+    std::string request = "GET /" + std::string(path) + " HTTP/1.1\r\n";
+    request += "Host: " + std::string(host);
+    if (port != 80)
+        request += ":" + std::to_string(port);
+    request += "\r\n";
+    request += "Upgrade: websocket\r\n";
+    request += "Connection: Upgrade\r\n";
+    if (origin.isNotEmpty())
+        request += "Origin: " + std::string(origin.toRawUTF8()) + "\r\n";
+    // NOTE(review): fixed handshake key; RFC 6455 asks for a fresh random nonce per connection.
+    request += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n";
+    request += "Sec-WebSocket-Version: 13\r\n";
+    for (int h = 0; h < headerKeys.size(); ++h)
+        request += std::string(headerKeys[h].toRawUTF8()) + ": " + headerValues[h].toRawUTF8() + "\r\n";
+    request += "\r\n";
+
+    if (socket->write(request.data(), static_cast<int>(request.size())) < 0)
+    {
+        fprintf(stderr, "ERROR: Failed to send handshake to: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    char line[1024] = {0};
+    int length = 0;
+
+    // Reads one response line a byte at a time into `line` (NUL-terminated, at most 1023
+    // bytes). Returns false when a read yields no byte; testing `<= 0` instead of `== 0`
+    // keeps a negative (error) return from being treated as received data.
+    auto readLine = [&]() -> bool
+    {
+        for (length = 0; length < 2 || (length < 1023 && line[length - 2] != '\r' && line[length - 1] != '\n'); ++length)
+            if (socket->read(line + length, 1, true) <= 0)
+                return false;
+        line[length] = 0;
+        return true;
+    };
+
+    if (!readLine())
+        return nullptr;
+
+    if (length == 1023)
+    {
+        fprintf(stderr, "ERROR: Got invalid status line connecting to: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    int status = 0;
+    if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101)
+    {
+        fprintf(stderr, "ERROR: Got bad status connecting to %s: %s", url.toRawUTF8(), line);
+        return nullptr;
+    }
+
+    // Skip the remaining response headers up to the blank line that ends them.
+    do
+    {
+        if (!readLine())
+            return nullptr;
+    }
+    while (!(line[0] == '\r' && line[1] == '\n'));
+
+    // Disable Nagle's algorithm
+    int flag = 1;
+    setsockopt(socket->getRawSocketHandle(), IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&flag), sizeof(flag));
+    return new WebSocket(std::move(socket), true);
+}
WebSocket* WebSocket::fromURL (const juce::String& url, const juce::String& origin)
{
return WebSocket::fromURL (url, true, origin);
diff --git a/modules/gin_network/network/gin_websocket.h b/modules/gin_network/network/gin_websocket.h
index c70643813a..6849f59f4a 100644
--- a/modules/gin_network/network/gin_websocket.h
+++ b/modules/gin_network/network/gin_websocket.h
@@ -44,6 +44,8 @@ class WebSocket
static WebSocket* fromURLNoMask (const juce::String& url, const juce::String& origin = {});
static WebSocket* fromURL (const juce::String& url, bool useMask, const juce::String& origin);
+ static WebSocket* fromURL (const juce::String& url, const juce::String& origin, const juce::StringPairArray& customHeaders);
+
~WebSocket();
bool readIncoming();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment