Created
October 8, 2024 15:48
-
-
Save p-i-/02e26d89444efe51ca6d16ed28a1fa92 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(.venv) | |
pi@πlocal ~/code/2024/kit/kit/Gin master | |
> git diff | |
diff --git a/modules/gin_network/network/gin_asyncwebsocket.cpp b/modules/gin_network/network/gin_asyncwebsocket.cpp | |
index 15be54251e..f3a8f057ac 100644 | |
--- a/modules/gin_network/network/gin_asyncwebsocket.cpp | |
+++ b/modules/gin_network/network/gin_asyncwebsocket.cpp | |
@@ -6,8 +6,8 @@ | |
==============================================================================*/ | |
//============================================================================== | |
-AsyncWebsocket::AsyncWebsocket (const juce::URL url_) | |
- : Thread ("websocket"), url (url_) | |
+AsyncWebsocket::AsyncWebsocket(const juce::URL url_, const juce::StringPairArray& headers) | |
+ : Thread("websocket"), url(url_), customHeaders(headers) | |
{ | |
} | |
@@ -46,20 +46,20 @@ void AsyncWebsocket::process() | |
using MM = juce::MessageManager; | |
juce::WeakReference<AsyncWebsocket> weakThis = this; | |
- ws->dispatch ([this, weakThis] (const juce::MemoryBlock message, bool isBinary) | |
+ ws->dispatch([this, weakThis](const juce::MemoryBlock message, bool isBinary) | |
{ | |
auto messageCopy = message; | |
- MM::callAsync ([this, weakThis, messageCopy, isBinary] | |
+ MM::callAsync([this, weakThis, messageCopy, isBinary] | |
{ | |
if (weakThis != nullptr) | |
{ | |
- // if we are receiving data we don't need to ping | |
+ // If we are receiving data, we don't need to ping | |
lastPing = juce::Time::getMillisecondCounterHiRes() / 1000; | |
if (isBinary && onBinary) | |
- onBinary (messageCopy); | |
- else if (! isBinary && onText) | |
- onText (juce::String::fromUTF8 ((char*)messageCopy.getData(), int (messageCopy.getSize()))); | |
+ onBinary(messageCopy); | |
+ else if (!isBinary && onText) | |
+ onText(juce::String::fromUTF8((char*)messageCopy.getData(), int(messageCopy.getSize()))); | |
} | |
}); | |
}); | |
@@ -68,15 +68,15 @@ void AsyncWebsocket::process() | |
//============================================================================== | |
auto processOutgoingData = [this](std::unique_ptr<WebSocket>& ws) | |
{ | |
- juce::ScopedLock sl (lock); | |
+ juce::ScopedLock sl(lock); | |
for (auto& data : outgoingQueue) | |
{ | |
if (data.type == pingMsg) | |
ws->sendPing(); | |
else if (data.type == binaryMsg) | |
- ws->sendBinary (data.data); | |
+ ws->sendBinary(data.data); | |
else if (data.type == textMsg) | |
- ws->send (data.text); | |
+ ws->send(data.text); | |
else | |
jassertfalse; | |
} | |
@@ -87,18 +87,20 @@ void AsyncWebsocket::process() | |
using MM = juce::MessageManager; | |
juce::WeakReference<AsyncWebsocket> weakThis = this; | |
- auto ws = std::unique_ptr<WebSocket> (WebSocket::fromURL (url.toString (true).toStdString())); | |
+ // Use custom headers when creating WebSocket connection | |
+ auto ws = std::unique_ptr<WebSocket>(WebSocket::fromURL(url.toString(true), {}, customHeaders)); | |
+ | |
if (ws != nullptr) | |
{ | |
- MM::callAsync ([this, weakThis] | |
+ MM::callAsync([this, weakThis] | |
{ | |
- if (weakThis != nullptr && onConnect) | |
- onConnect(); | |
+ if (weakThis != nullptr && onConnect) | |
+ onConnect(); | |
}); | |
- while (! threadShouldExit()) | |
+ while (!threadShouldExit()) | |
{ | |
- ws->poll (50); | |
+ ws->poll(50); | |
if (ws->getReadyState() == WebSocket::CLOSED) | |
break; | |
@@ -109,25 +111,23 @@ void AsyncWebsocket::process() | |
lastPing = now; | |
} | |
- processIncomingData (ws); | |
- processOutgoingData (ws); | |
+ processIncomingData(ws); | |
+ processOutgoingData(ws); | |
} | |
if (ws->getReadyState() != WebSocket::CLOSED) | |
{ | |
ws->close(); | |
- ws->poll (0); | |
+ ws->poll(0); | |
} | |
} | |
- MM::callAsync ([this, weakThis] | |
+ MM::callAsync([this, weakThis] | |
{ | |
if (weakThis != nullptr && onDisconnect) | |
{ | |
- // The disconnect callback may cause the websocket to get deleted | |
- // Dangerous to use the lambda that is a member function, so make a copy | |
- auto safeCallback = onDisconnect; | |
- safeCallback (); | |
+ auto safeCallback = onDisconnect; | |
+ safeCallback(); | |
} | |
}); | |
} | |
diff --git a/modules/gin_network/network/gin_asyncwebsocket.h b/modules/gin_network/network/gin_asyncwebsocket.h | |
index 5d846a8e7a..3d10b81d41 100644 | |
--- a/modules/gin_network/network/gin_asyncwebsocket.h | |
+++ b/modules/gin_network/network/gin_asyncwebsocket.h | |
@@ -15,7 +15,7 @@ class AsyncWebsocket : public juce::Thread | |
{ | |
public: | |
//============================================================================== | |
- AsyncWebsocket (const juce::URL url); | |
+ AsyncWebsocket(const juce::URL url, const juce::StringPairArray& customHeaders = {}); | |
~AsyncWebsocket() override; | |
//============================================================================== | |
@@ -41,6 +41,7 @@ private: | |
void process(); | |
juce::URL url; | |
+ juce::StringPairArray customHeaders; | |
//============================================================================== | |
enum MessageType | |
diff --git a/modules/gin_network/network/gin_websocket.cpp b/modules/gin_network/network/gin_websocket.cpp | |
index b4f71403ff..bc22d2184e 100644 | |
--- a/modules/gin_network/network/gin_websocket.cpp | |
+++ b/modules/gin_network/network/gin_websocket.cpp | |
@@ -495,6 +495,180 @@ WebSocket* WebSocket::fromURL (const juce::String& url, bool useMask, const juce | |
return new WebSocket (std::move (socket), useMask); | |
} | |
+/** Opens a WebSocket connection to a ws:// or wss:// URL, sending extra HTTP
+    headers with the opening handshake.
+
+    @param url            ws:// or wss:// URL; must be shorter than 512 bytes as UTF-8.
+                          When no port is given, 443 is used for wss:// and 80 for ws://.
+    @param origin         Value for the Origin header; omitted when empty. Must be
+                          shorter than 200 bytes as UTF-8.
+    @param customHeaders  Additional request headers, sent after the standard ones.
+                          An empty key, or a CR/LF in a key or value, is rejected.
+
+    @returns a new WebSocket (masking enabled) that the caller owns, or nullptr if
+             the URL or headers are invalid, the connection fails, or the server
+             does not answer with HTTP status 101.
+*/
+WebSocket* WebSocket::fromURL (const juce::String& url, const juce::String& origin, const juce::StringPairArray& customHeaders)
+{
+    char path[512] = { 0 };
+    char host[512] = { 0 };
+    int port = 0;
+
+    // The limits are checked in UTF-8 bytes (not characters) because the parsed
+    // pieces are written into the fixed-size byte buffers above.
+    if (url.getNumBytesAsUTF8() >= 512)
+    {
+        fprintf (stderr, "ERROR: URL size limit exceeded: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+    if (origin.getNumBytesAsUTF8() >= 200)
+    {
+        fprintf (stderr, "ERROR: Origin size limit exceeded: %s\n", origin.toRawUTF8());
+        return nullptr;
+    }
+
+    // Work out the scheme first, then parse "host[:port][/path]" from what follows it
+    const char* const rawUrl = url.toRawUTF8();
+    const bool secure = strncmp (rawUrl, "wss://", 6) == 0;
+
+    if (! secure && strncmp (rawUrl, "ws://", 5) != 0)
+    {
+        fprintf (stderr, "ERROR: Could not parse WebSocket URL: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    const char* const authority = rawUrl + (secure ? 6 : 5);
+    const int defaultPort = secure ? 443 : 80;
+
+    // Field widths keep sscanf inside host[] / path[] whatever the input looks like
+    if (sscanf (authority, "%511[^:/]:%d/%511s", host, &port, path) == 3)
+    {
+        // host, explicit port and path were all present
+    }
+    else if (sscanf (authority, "%511[^:/]/%511s", host, path) == 2)
+    {
+        port = defaultPort;
+    }
+    else if (sscanf (authority, "%511[^:/]:%d", host, &port) == 2)
+    {
+        path[0] = '\0';
+    }
+    else if (sscanf (authority, "%511[^:/]", host) == 1)
+    {
+        port = defaultPort;
+        path[0] = '\0';
+    }
+    else
+    {
+        fprintf (stderr, "ERROR: Could not parse WebSocket URL: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    // Reject headers that would break, or inject extra lines into, the request
+    const juce::StringArray& headerKeys = customHeaders.getAllKeys();
+    const juce::StringArray& headerValues = customHeaders.getAllValues();
+
+    for (int index = 0; index < headerKeys.size(); ++index)
+    {
+        if (headerKeys[index].isEmpty()
+            || headerKeys[index].containsAnyOf ("\r\n")
+            || headerValues[index].containsAnyOf ("\r\n"))
+        {
+            fprintf (stderr, "ERROR: Invalid custom header: %s\n", headerKeys[index].toRawUTF8());
+            return nullptr;
+        }
+    }
+
+    // Create secure or non-secure socket
+    auto socket = std::make_unique<gin::SecureStreamingSocket> (secure);
+    if (! socket->connect (host, port))
+    {
+        fprintf (stderr, "Unable to connect to %s:%d\n", host, port);
+        return nullptr;
+    }
+
+    const int sockfd = socket->getRawSocketHandle();
+
+    // Build the whole handshake request in a growable string: a fixed-size line
+    // buffer would silently truncate long header values (e.g. auth tokens) and
+    // drop their terminating CRLF.
+    juce::String request;
+    request << "GET /" << juce::String::fromUTF8 (path) << " HTTP/1.1\r\n";
+
+    request << "Host: " << juce::String::fromUTF8 (host);
+    if (port != 80)
+        request << ":" << port;
+    request << "\r\n";
+
+    request << "Upgrade: websocket\r\n";
+    request << "Connection: Upgrade\r\n";
+
+    if (origin.isNotEmpty())
+        request << "Origin: " << origin << "\r\n";
+
+    // NOTE(review): this key is a fixed value and the server's Sec-WebSocket-Accept
+    // reply is not checked below; RFC 6455 asks for a random nonce per connection.
+    request << "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n";
+    request << "Sec-WebSocket-Version: 13\r\n";
+
+    for (int index = 0; index < headerKeys.size(); ++index)
+        request << headerKeys[index] << ": " << headerValues[index] << "\r\n";
+
+    // A blank line ends the header section
+    request << "\r\n";
+
+    // Assumes write() reports failure with a negative value, as juce::StreamingSocket does
+    if (socket->write (request.toRawUTF8(), static_cast<int> (request.getNumBytesAsUTF8())) < 0)
+    {
+        fprintf (stderr, "ERROR: Failed to send handshake to: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    char line[1024] = { 0 };
+
+    // Reads one response line into `line`, a byte at a time, stopping once a line
+    // terminator has been seen or 1023 bytes have been read. Returns the number of
+    // bytes read, or -1 if the socket closed or failed (assumes read() returns <= 0
+    // in those cases).
+    auto readLine = [&socket, &line]() -> int
+    {
+        int count = 0;
+
+        for (; count < 2 || (count < 1023 && line[count - 2] != '\r' && line[count - 1] != '\n'); ++count)
+            if (socket->read (line + count, 1, true) <= 0)
+                return -1;
+
+        line[count] = 0;
+        return count;
+    };
+
+    // Status line: anything other than "101 Switching Protocols" is a failure
+    const int statusLineLength = readLine();
+
+    if (statusLineLength < 0)
+        return nullptr;
+
+    if (statusLineLength == 1023)
+    {
+        fprintf (stderr, "ERROR: Got invalid status line connecting to: %s\n", url.toRawUTF8());
+        return nullptr;
+    }
+
+    int status = 0;
+    if (sscanf (line, "HTTP/1.1 %d", &status) != 1 || status != 101)
+    {
+        fprintf (stderr, "ERROR: Got bad status connecting to %s: %s", url.toRawUTF8(), line);
+        return nullptr;
+    }
+
+    // Skip the remaining response headers, up to and including the blank line.
+    // Their contents are not inspected.
+    for (;;)
+    {
+        if (readLine() < 0)
+            return nullptr;
+
+        if (line[0] == '\r' && line[1] == '\n')
+            break;
+    }
+
+    // Disable Nagle's algorithm
+    int flag = 1;
+    setsockopt (sockfd, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*> (&flag), sizeof (flag));
+
+    // Ownership of the socket moves into the WebSocket; masking is always enabled here
+    return new WebSocket (std::move (socket), true);
+}
WebSocket* WebSocket::fromURL (const juce::String& url, const juce::String& origin) | |
{ | |
return WebSocket::fromURL (url, true, origin); | |
diff --git a/modules/gin_network/network/gin_websocket.h b/modules/gin_network/network/gin_websocket.h | |
index c70643813a..6849f59f4a 100644 | |
--- a/modules/gin_network/network/gin_websocket.h | |
+++ b/modules/gin_network/network/gin_websocket.h | |
@@ -44,6 +44,8 @@ class WebSocket | |
static WebSocket* fromURLNoMask (const juce::String& url, const juce::String& origin = {}); | |
static WebSocket* fromURL (const juce::String& url, bool useMask, const juce::String& origin); | |
+ static WebSocket* fromURL (const juce::String& url, const juce::String& origin, const juce::StringPairArray& customHeaders); | |
+ | |
~WebSocket(); | |
bool readIncoming(); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment