-
-
Save mmozeiko/c0dfcc8fec527a90a02145d2cc0bfb6d to your computer and use it in GitHub Desktop.
| #define WIN32_LEAN_AND_MEAN | |
| #include <winsock2.h> | |
| #include <windows.h> | |
| #define SECURITY_WIN32 | |
| #include <security.h> | |
| #include <schannel.h> | |
| #include <shlwapi.h> | |
| #include <assert.h> | |
| #include <stdio.h> | |
| #pragma comment (lib, "ws2_32.lib") | |
| #pragma comment (lib, "secur32.lib") | |
| #pragma comment (lib, "shlwapi.lib") | |
// Upper bound for one TLS record as received from the wire:
// 16 KiB of payload plus room for record header, MAC and padding (most likely generous).
#define TLS_MAX_PACKET_SIZE (16384+512)

// State of one client TLS connection running over a blocking TCP socket.
typedef struct {
    SOCKET sock;                          // connected TCP socket
    CredHandle handle;                    // schannel credentials handle (from AcquireCredentialsHandle)
    CtxtHandle context;                   // schannel security context (from InitializeSecurityContext)
    SecPkgContext_StreamSizes sizes;      // header / trailer / maximum message sizes of the stream
    int received;                         // ciphertext byte count currently held in incoming
    int used;                             // ciphertext bytes that made up the packet decrypted right now
    int available;                        // decrypted bytes not yet handed to the caller
    char* decrypted;                      // position inside incoming where decrypted bytes start (decrypted in place)
    char incoming[TLS_MAX_PACKET_SIZE];   // receive buffer for ciphertext
} tls_socket;
// Opens a TCP connection to hostname:port and performs a TLS 1.2 client handshake
// with schannel; the server certificate is validated automatically against hostname.
// On success *s owns the connected socket, credentials and security context and
// tls_disconnect must be called to release them.
// Returns 0 on success or negative value on error. On error every resource acquired
// here (Winsock, socket, credentials, security context) has already been released.
static int tls_connect(tls_socket* s, const char* hostname, unsigned short port)
{
    // initialize windows sockets
    WSADATA wsadata;
    if (WSAStartup(MAKEWORD(2, 2), &wsadata) != 0)
    {
        return -1;
    }

    // create TCP IPv4 socket
    s->sock = socket(AF_INET, SOCK_STREAM, 0);
    if (s->sock == INVALID_SOCKET)
    {
        WSACleanup();
        return -1;
    }

    char sport[64];
    wnsprintfA(sport, sizeof(sport), "%u", port);

    // connect to server
    if (!WSAConnectByNameA(s->sock, hostname, sport, NULL, NULL, NULL, NULL, NULL, NULL))
    {
        closesocket(s->sock);
        WSACleanup();
        return -1;
    }

    // initialize schannel
    {
        SCHANNEL_CRED cred =
        {
            .dwVersion = SCHANNEL_CRED_VERSION,
            .dwFlags = SCH_USE_STRONG_CRYPTO          // use only strong crypto algorithms
                     | SCH_CRED_AUTO_CRED_VALIDATION  // automatically validate server certificate
                     | SCH_CRED_NO_DEFAULT_CREDS,     // no client certificate authentication
            .grbitEnabledProtocols = SP_PROT_TLS1_2,  // allow only TLS v1.2
        };
        if (AcquireCredentialsHandleA(NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &cred, NULL, NULL, &s->handle, NULL) != SEC_E_OK)
        {
            closesocket(s->sock);
            WSACleanup();
            return -1;
        }
    }

    s->received = s->used = s->available = 0;
    s->decrypted = NULL;

    // perform tls handshake
    // 1) call InitializeSecurityContext to create/update schannel context
    // 2) send any token it produced to the server
    // 3) when it returns SEC_E_OK - tls handshake completed
    // 4) when it returns SEC_I_INCOMPLETE_CREDENTIALS - server requests client certificate (not supported here)
    // 5) when it returns SEC_I_CONTINUE_NEEDED - more handshake messages are needed
    // 6) when it returns SEC_E_INCOMPLETE_MESSAGE - partial message is kept, need to read more data from server
    // 7) read data from server (unless leftover data can be processed first) and go to step 1
    CtxtHandle* context = NULL;
    int have_context = 0; // becomes 1 once schannel has created s->context
    int result = 0;
    for (;;)
    {
        SecBuffer inbuffers[2] = { 0 };
        inbuffers[0].BufferType = SECBUFFER_TOKEN;
        inbuffers[0].pvBuffer = s->incoming;
        inbuffers[0].cbBuffer = s->received;
        inbuffers[1].BufferType = SECBUFFER_EMPTY;

        SecBuffer outbuffers[1] = { 0 };
        outbuffers[0].BufferType = SECBUFFER_TOKEN;

        SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
        SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };

        DWORD flags = ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;
        SECURITY_STATUS sec = InitializeSecurityContextA(
            &s->handle,
            context,
            context ? NULL : (SEC_CHAR*)hostname,
            flags,
            0,
            0,
            context ? &indesc : NULL,
            0,
            context ? NULL : &s->context,
            &outdesc,
            &flags,
            NULL);

        // after first call to InitializeSecurityContext context is available and should be reused for next calls
        context = &s->context;
        if (!FAILED(sec))
        {
            // NOTE(review): assumes a failing first call does not create a context that needs deleting
            have_context = 1;
        }

        if (sec == SEC_E_INCOMPLETE_MESSAGE)
        {
            // nothing was consumed (second buffer is SECBUFFER_MISSING): keep the partial message
            // so the bytes read next are appended to it
        }
        else if (inbuffers[1].BufferType == SECBUFFER_EXTRA)
        {
            // bytes after the processed message belong to the next message, move them to the front
            MoveMemory(s->incoming, s->incoming + (s->received - inbuffers[1].cbBuffer), inbuffers[1].cbBuffer);
            s->received = inbuffers[1].cbBuffer;
        }
        else
        {
            s->received = 0;
        }

        // send the token produced by schannel; its memory was allocated by schannel (ISC_REQ_ALLOCATE_MEMORY)
        int unsent = 0;
        if (sec == SEC_E_OK || sec == SEC_I_CONTINUE_NEEDED)
        {
            const char* buffer = outbuffers[0].pvBuffer;
            int size = (int)outbuffers[0].cbBuffer;
            while (size != 0)
            {
                int d = send(s->sock, buffer, size, 0);
                if (d <= 0)
                {
                    break;
                }
                size -= d;
                buffer += d;
            }
            unsent = size;
        }
        if (outbuffers[0].pvBuffer != NULL)
        {
            // free any returned token, also the ones returned together with an error status
            FreeContextBuffer(outbuffers[0].pvBuffer);
        }
        if (unsent != 0)
        {
            // failed to fully send data to server
            result = -1;
            break;
        }

        if (sec == SEC_E_OK)
        {
            // tls handshake completed, any leftover bytes in incoming are application data for tls_read
            break;
        }
        else if (sec == SEC_I_INCOMPLETE_CREDENTIALS)
        {
            // server asked for client certificate, not supported here
            result = -1;
            break;
        }
        else if (sec != SEC_I_CONTINUE_NEEDED && sec != SEC_E_INCOMPLETE_MESSAGE)
        {
            // SEC_E_CERT_EXPIRED - certificate expired or revoked
            // SEC_E_WRONG_PRINCIPAL - bad hostname
            // SEC_E_UNTRUSTED_ROOT - cannot verify CA chain
            // SEC_E_ILLEGAL_MESSAGE / SEC_E_ALGORITHM_MISMATCH - cannot negotiate crypto algorithms
            result = -1;
            break;
        }

        if (sec == SEC_I_CONTINUE_NEEDED && s->received != 0)
        {
            // leftover bytes may already contain the next message, process them before blocking in recv
            continue;
        }

        // read more data from server when possible
        if (s->received == sizeof(s->incoming))
        {
            // server is sending too much data instead of proper handshake?
            result = -1;
            break;
        }

        int r = recv(s->sock, s->incoming + s->received, (int)(sizeof(s->incoming) - s->received), 0);
        if (r <= 0)
        {
            // r == 0: server disconnected socket before the handshake completed
            // r < 0: socket error
            result = -1;
            break;
        }
        s->received += r;
    }

    // header/trailer/maximum message sizes are required by tls_write and tls_read
    if (result == 0 && QueryContextAttributes(context, SECPKG_ATTR_STREAM_SIZES, &s->sizes) != SEC_E_OK)
    {
        result = -1;
    }

    if (result != 0)
    {
        if (have_context)
        {
            DeleteSecurityContext(&s->context);
        }
        FreeCredentialsHandle(&s->handle);
        closesocket(s->sock);
        WSACleanup();
        return result;
    }

    return 0;
}
// Sends a TLS close_notify to the server (best effort), shuts down the socket and
// releases the security context, credentials, socket and Winsock reference taken
// by tls_connect. Call this after a successful tls_connect even if tls_write/tls_read
// returned an error.
static void tls_disconnect(tls_socket* s)
{
    // ask schannel to prepare the shutdown message
    DWORD type = SCHANNEL_SHUTDOWN;

    SecBuffer inbuffers[1];
    inbuffers[0].BufferType = SECBUFFER_TOKEN;
    inbuffers[0].pvBuffer = &type;
    inbuffers[0].cbBuffer = sizeof(type);

    SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
    if (ApplyControlToken(&s->context, &indesc) == SEC_E_OK)
    {
        // start from an empty output buffer, schannel allocates the token (ISC_REQ_ALLOCATE_MEMORY)
        SecBuffer outbuffers[1] = { 0 };
        outbuffers[0].BufferType = SECBUFFER_TOKEN;

        SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };
        DWORD flags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;

        // no input token is needed to build the shutdown message
        if (InitializeSecurityContextA(&s->handle, &s->context, NULL, flags, 0, 0, NULL, 0, NULL, &outdesc, &flags, NULL) == SEC_E_OK)
        {
            const char* buffer = outbuffers[0].pvBuffer;
            int size = (int)outbuffers[0].cbBuffer;
            while (size != 0)
            {
                int d = send(s->sock, buffer, size, 0);
                if (d <= 0)
                {
                    // ignore any failures socket will be closed anyway
                    break;
                }
                buffer += d;
                size -= d;
            }
        }
        if (outbuffers[0].pvBuffer != NULL)
        {
            FreeContextBuffer(outbuffers[0].pvBuffer);
        }
    }

    shutdown(s->sock, SD_BOTH);

    DeleteSecurityContext(&s->context);
    FreeCredentialsHandle(&s->handle);
    closesocket(s->sock);
    WSACleanup();
}
// Encrypts size bytes from buffer and sends them over the connection. Data is split
// into TLS records of at most sizes.cbMaximumMessage plaintext bytes each, and the
// call blocks until every record has been handed to the socket.
// Returns 0 on success or negative value on error (a negative size is an error).
static int tls_write(tls_socket* s, const void* buffer, int size)
{
    if (size < 0)
    {
        // a negative size would otherwise be promoted to a huge unsigned length
        return -1;
    }

    const char* src = buffer;
    const int max_message = (int)s->sizes.cbMaximumMessage;
    while (size != 0)
    {
        int use = size < max_message ? size : max_message;

        char wbuffer[TLS_MAX_PACKET_SIZE];
        // one encrypted record (header + payload + trailer) must fit into wbuffer
        assert(s->sizes.cbHeader + s->sizes.cbMaximumMessage + s->sizes.cbTrailer <= sizeof(wbuffer));

        SecBuffer buffers[3];
        buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
        buffers[0].pvBuffer = wbuffer;
        buffers[0].cbBuffer = s->sizes.cbHeader;
        buffers[1].BufferType = SECBUFFER_DATA;
        buffers[1].pvBuffer = wbuffer + s->sizes.cbHeader;
        buffers[1].cbBuffer = use;
        buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
        buffers[2].pvBuffer = wbuffer + s->sizes.cbHeader + use;
        buffers[2].cbBuffer = s->sizes.cbTrailer;

        // EncryptMessage encrypts in place, so plaintext goes between header and trailer
        CopyMemory(buffers[1].pvBuffer, src, use);

        SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };
        SECURITY_STATUS sec = EncryptMessage(&s->context, 0, &desc, 0);
        if (sec != SEC_E_OK)
        {
            // this should not happen, but just in case check it
            return -1;
        }

        int total = buffers[0].cbBuffer + buffers[1].cbBuffer + buffers[2].cbBuffer;
        int sent = 0;
        while (sent != total)
        {
            int d = send(s->sock, wbuffer + sent, total - sent, 0);
            if (d <= 0)
            {
                // error sending data to socket, or server disconnected
                return -1;
            }
            sent += d;
        }

        src += use;
        size -= use;
    }

    return 0;
}
// Blocking read: waits for data and copies up to size bytes of decrypted application
// data into buffer.
// Returns the number of bytes copied (<= size) on success. Returns 0 when the TCP
// socket was disconnected before any data was copied, or when size is 0. When the
// server closes the TLS session, returns the bytes copied so far (0 if none).
// Returns a negative value on error (a negative size is an error).
static int tls_read(tls_socket* s, void* buffer, int size)
{
    if (size < 0)
    {
        // a negative size would make the copy length negative
        return -1;
    }

    char* dst = buffer;
    int result = 0;

    while (size != 0)
    {
        if (s->decrypted)
        {
            // if there is decrypted data available, then use it as much as possible
            int use = size < s->available ? size : s->available;
            CopyMemory(dst, s->decrypted, use);
            dst += use;
            size -= use;
            result += use;

            if (use == s->available)
            {
                // all decrypted data is used, remove ciphertext from incoming buffer so next time it starts from beginning
                MoveMemory(s->incoming, s->incoming + s->used, s->received - s->used);
                s->received -= s->used;
                s->used = 0;
                s->available = 0;
                s->decrypted = NULL;
            }
            else
            {
                s->available -= use;
                s->decrypted += use;
            }
        }
        else
        {
            // if any ciphertext data available then try to decrypt it
            if (s->received != 0)
            {
                SecBuffer buffers[4];
                assert(s->sizes.cBuffers == ARRAYSIZE(buffers));

                buffers[0].BufferType = SECBUFFER_DATA;
                buffers[0].pvBuffer = s->incoming;
                buffers[0].cbBuffer = s->received;
                buffers[1].BufferType = SECBUFFER_EMPTY;
                buffers[2].BufferType = SECBUFFER_EMPTY;
                buffers[3].BufferType = SECBUFFER_EMPTY;

                SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };

                SECURITY_STATUS sec = DecryptMessage(&s->context, &desc, 0, NULL);
                if (sec == SEC_E_OK)
                {
                    assert(buffers[0].BufferType == SECBUFFER_STREAM_HEADER);
                    assert(buffers[1].BufferType == SECBUFFER_DATA);
                    assert(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER);

                    s->decrypted = buffers[1].pvBuffer;
                    s->available = buffers[1].cbBuffer;
                    // bytes reported as SECBUFFER_EXTRA belong to the next packet and are not "used" yet
                    s->used = s->received - (buffers[3].BufferType == SECBUFFER_EXTRA ? buffers[3].cbBuffer : 0);

                    // data is now decrypted, go back to beginning of loop to copy memory to output buffer
                    continue;
                }
                else if (sec == SEC_I_CONTEXT_EXPIRED)
                {
                    // server closed TLS connection (but socket is still open)
                    s->received = 0;
                    return result;
                }
                else if (sec == SEC_I_RENEGOTIATE)
                {
                    // server wants to renegotiate TLS connection, not implemented here
                    return -1;
                }
                else if (sec != SEC_E_INCOMPLETE_MESSAGE)
                {
                    // some other schannel or TLS protocol error
                    return -1;
                }
                // otherwise sec == SEC_E_INCOMPLETE_MESSAGE which means need to read more data
            }
            // otherwise not enough data received to decrypt

            if (result != 0)
            {
                // some data is already copied to output buffer, so return that before blocking with recv
                break;
            }

            if (s->received == sizeof(s->incoming))
            {
                // server is sending too much garbage data instead of proper TLS packet
                return -1;
            }

            // wait for more ciphertext data from server
            int r = recv(s->sock, s->incoming + s->received, (int)(sizeof(s->incoming) - s->received), 0);
            if (r == 0)
            {
                // server disconnected socket
                return 0;
            }
            else if (r < 0)
            {
                // error receiving data from socket
                result = -1;
                break;
            }
            s->received += r;
        }
    }

    return result;
}
// Demo: requests path from hostname over HTTPS (port 443) and writes the raw HTTP
// response (status line, headers and body) to response.txt.
// Returns 0 on success, -1 on connection, request, file or receive errors.
int main()
{
    // Other hosts that are useful for checking certificate validation failures:
    // badssl.com, expired.badssl.com, wrong.host.badssl.com,
    // self-signed.badssl.com, untrusted-root.badssl.com
    const char* hostname = "www.google.com";
    const char* path = "/";

    tls_socket s;
    if (tls_connect(&s, hostname, 443) != 0)
    {
        printf("Error connecting to %s\n", hostname);
        return -1;
    }
    printf("Connected!\n");

    // send request; snprintf bounds the write and reports the length it needed
    char req[1024];
    int len = snprintf(req, sizeof(req), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", path, hostname);
    if (len < 0 || len >= (int)sizeof(req))
    {
        // request does not fit into req
        tls_disconnect(&s);
        return -1;
    }
    if (tls_write(&s, req, len) != 0)
    {
        tls_disconnect(&s);
        return -1;
    }

    // write response to file
    FILE* f = fopen("response.txt", "wb");
    if (f == NULL)
    {
        printf("Error opening response.txt\n");
        tls_disconnect(&s);
        return -1;
    }

    int status = 0;
    int received = 0;
    for (;;)
    {
        char buf[65536];
        int r = tls_read(&s, buf, sizeof(buf));
        if (r < 0)
        {
            printf("Error receiving data\n");
            status = -1;
            break;
        }
        else if (r == 0)
        {
            printf("Socket disconnected\n");
            break;
        }
        else
        {
            if (fwrite(buf, 1, r, f) != (size_t)r)
            {
                printf("Error writing response.txt\n");
                status = -1;
                break;
            }
            fflush(f);
            received += r;
        }
    }
    fclose(f);

    printf("Received %d bytes\n", received);

    tls_disconnect(&s);
    return status;
}
Thanks for bringing it up anyway, @never-unsealed.
I am trying to use SChannel and stumbled upon this gist, but I have a question:
SECURITY_STATUS sec = InitializeSecurityContextA(
&s->handle,
context,
context ? NULL : (SEC_CHAR*)hostname,
flags,
0,
0,
context ? &indesc : NULL,
0,
context ? NULL : &s->context,
&outdesc,
&flags,
NULL);
//.....
int r = recv(s->sock, s->incoming + s->received, sizeof(s->incoming) - s->received, 0);
if (r == 0)
{
// server disconnected socket
return 0;
}
In this snippet (or on this line), if the server disconnects we can no longer send or receive any data, so shouldn't tls_connect return -1 in this case?
Sorry if this is a dumb question; I am learning.
Do you have plans to update your example to support TLS 1.3?
That is, switch to SCH_CREDENTIALS and handle the renegotiation status (SEC_I_RENEGOTIATE) returned by DecryptMessage.
John
Here are the basic 1.3 changes https://gist.github.com/mlt/694a4db9875d1c9f848204654dd1b636/revisions#diff-599e5710bf78734b1a90a4538e52d8f80efb77c89b1a777967f235d75bd5357f
These SO posts helped greatly https://stackoverflow.com/a/78833887/673826 and https://stackoverflow.com/a/78393548/673826
Thanks a lot, that helped me write a custom TLS/SSL socket for my bot.
Since they updated the docs for all callers, I assumed the change also affected user mode (UM), but you're right: if even curl and Qt do it that way, it's likely just a problem with the kernel-mode (KM) implementation...