#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <openssl/sha.h>  // For SHA-1 (link with -lcrypto)
#include <openssl/bio.h>  // For Base64 encoding
#include <openssl/evp.h>
#include <openssl/buffer.h>

#define PORT 8080
#define BUFFER_SIZE 1024
#define MAGIC_STRING "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// Compute the Sec-WebSocket-Accept header value for a client's
// Sec-WebSocket-Key: base64(SHA-1(client_key followed by MAGIC_STRING)).
//
// client_key: NUL-terminated key copied from the request header.
// Returns a heap-allocated, NUL-terminated string that the caller must
// free(), or NULL if client_key is NULL or an allocation/encoding step fails.
char* compute_accept_key(const char* client_key) {
    if (client_key == NULL) return NULL;

    // Concatenate on the heap so keys of any length hash correctly; a fixed
    // stack buffer would silently truncate long keys and yield a wrong hash.
    size_t key_len = strlen(client_key);
    size_t magic_len = strlen(MAGIC_STRING);
    char* combined = malloc(key_len + magic_len);
    if (combined == NULL) return NULL;
    memcpy(combined, client_key, key_len);
    memcpy(combined + key_len, MAGIC_STRING, magic_len);

    unsigned char sha1[SHA_DIGEST_LENGTH];
    SHA1((const unsigned char*)combined, key_len + magic_len, sha1);
    free(combined);

    BIO* b64 = BIO_new(BIO_f_base64());
    BIO* mem = BIO_new(BIO_s_mem());
    if (b64 == NULL || mem == NULL) {
        // Not chained yet, so free each one individually (NULL is a no-op).
        BIO_free_all(b64);
        BIO_free_all(mem);
        return NULL;
    }
    b64 = BIO_push(b64, mem);                    // bytes written to b64 land, encoded, in mem
    BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL);  // header value must not contain a newline

    char* accept_key = NULL;
    if (BIO_write(b64, sha1, SHA_DIGEST_LENGTH) == SHA_DIGEST_LENGTH && BIO_flush(b64) > 0) {
        BUF_MEM* buffer = NULL;
        BIO_get_mem_ptr(b64, &buffer);
        if (buffer != NULL) {
            accept_key = malloc(buffer->length + 1);
            if (accept_key != NULL) {
                memcpy(accept_key, buffer->data, buffer->length);
                accept_key[buffer->length] = '\0';
            }
        }
    }

    BIO_free_all(b64);  // frees both the base64 filter and the memory BIO
    return accept_key;
}

// Send the HTTP "101 Switching Protocols" response that completes the
// WebSocket opening handshake.
//
// client_fd:  connected client socket.
// client_key: the Sec-WebSocket-Key value extracted from the request.
// Failures (accept-key computation, formatting, socket write) are reported on
// stderr; the void signature gives callers no error indication.
void send_handshake(int client_fd, const char* client_key) {
    char* accept_key = compute_accept_key(client_key);
    if (accept_key == NULL) {
        fprintf(stderr, "Failed to compute Sec-WebSocket-Accept\n");
        return;
    }

    char response[BUFFER_SIZE];
    int len = snprintf(response, sizeof(response),
                       "HTTP/1.1 101 Switching Protocols\r\n"
                       "Upgrade: websocket\r\n"
                       "Connection: Upgrade\r\n"
                       "Sec-WebSocket-Accept: %s\r\n\r\n",
                       accept_key);
    free(accept_key);
    if (len < 0 || (size_t)len >= sizeof(response)) {
        fprintf(stderr, "Handshake response does not fit in buffer\n");
        return;
    }

    // write() may accept fewer bytes than requested; keep going until all sent.
    size_t sent = 0;
    while (sent < (size_t)len) {
        ssize_t n = write(client_fd, response + sent, (size_t)len - sent);
        if (n < 0 && errno == EINTR) continue;
        if (n <= 0) {
            perror("Handshake write failed");
            return;
        }
        sent += (size_t)n;
    }
}

// Extract the Sec-WebSocket-Key header value from a NUL-terminated HTTP
// request.
//
// Spaces and tabs around the value are trimmed, so both "Key: v" and "Key:v"
// are accepted. Returns a heap-allocated string that the caller must free(),
// or NULL if request is NULL, the header is absent, its line is not
// CRLF-terminated, the value is empty, or allocation fails.
// NOTE(review): the header-name match is case-sensitive, whereas HTTP field
// names are case-insensitive; clients that send e.g. "sec-websocket-key" are
// rejected.
char* extract_key(const char* request) {
    static const char header[] = "Sec-WebSocket-Key:";
    if (request == NULL) return NULL;

    const char* start = strstr(request, header);
    if (start == NULL) return NULL;
    start += sizeof(header) - 1;  // skip the header name and colon

    const char* end = strstr(start, "\r\n");
    if (end == NULL) return NULL;

    // Trim optional whitespace surrounding the field value.
    while (start < end && (*start == ' ' || *start == '\t')) start++;
    while (end > start && (end[-1] == ' ' || end[-1] == '\t')) end--;

    size_t len = (size_t)(end - start);
    if (len == 0) return NULL;

    char* key = malloc(len + 1);
    if (key == NULL) return NULL;
    memcpy(key, start, len);
    key[len] = '\0';
    return key;
}

// Unmask a client payload in place: byte i is XORed with mask_key[i % 4]
// (the 4-byte key cycles over the payload). Applying it twice restores the
// input. Nothing is touched when payload_len <= 0.
void decode_payload(char* payload, unsigned char* mask_key, int payload_len) {
    int idx = 0;
    while (idx < payload_len) {
        payload[idx] ^= mask_key[idx & 3];  // idx & 3 == idx % 4 for idx >= 0
        ++idx;
    }
}

// Read exactly len bytes from fd, retrying on short reads and EINTR.
// Returns 0 on success, -1 on EOF or error.
static int ws_read_exact(int fd, void* buf, size_t len) {
    unsigned char* p = buf;
    while (len > 0) {
        ssize_t n = read(fd, p, len);
        if (n < 0 && errno == EINTR) continue;
        if (n <= 0) return -1;
        p += n;
        len -= (size_t)n;
    }
    return 0;
}

// Write all len bytes of buf to fd, retrying on short writes and EINTR.
// Returns 0 on success, -1 on error.
static int ws_write_all(int fd, const void* buf, size_t len) {
    const unsigned char* p = buf;
    while (len > 0) {
        ssize_t n = write(fd, p, len);
        if (n < 0 && errno == EINTR) continue;
        if (n <= 0) return -1;
        p += n;
        len -= (size_t)n;
    }
    return 0;
}

// Read the HTTP upgrade request into buf (always NUL-terminated) until the
// blank line that ends the headers. Returns 0 on success, -1 on EOF, error,
// or if the headers do not fit in cap - 1 bytes. cap must be at least 2.
static int ws_read_request(int fd, char* buf, size_t cap) {
    size_t used = 0;
    buf[0] = '\0';
    while (used < cap - 1) {
        ssize_t n = read(fd, buf + used, cap - 1 - used);
        if (n < 0 && errno == EINTR) continue;
        if (n <= 0) return -1;
        used += (size_t)n;
        buf[used] = '\0';
        if (strstr(buf, "\r\n\r\n") != NULL) return 0;
    }
    return -1;
}

// Send one unfragmented (FIN=1), unmasked frame; servers must not mask
// (RFC 6455 section 5.1). The length is encoded in 7, 16 or 64 bits as needed.
// Returns 0 on success, -1 on write error.
static int ws_send_frame(int fd, int opcode, const void* payload, size_t len) {
    unsigned char header[10];
    size_t header_len = 0;
    header[header_len++] = (unsigned char)(0x80 | (opcode & 0x0F));
    if (len <= 125) {
        header[header_len++] = (unsigned char)len;
    } else if (len <= 0xFFFF) {
        header[header_len++] = 126;
        header[header_len++] = (unsigned char)(len >> 8);
        header[header_len++] = (unsigned char)(len & 0xFF);
    } else {
        header[header_len++] = 127;
        for (int shift = 56; shift >= 0; shift -= 8) {
            header[header_len++] = (unsigned char)((uint64_t)len >> shift);
        }
    }
    if (ws_write_all(fd, header, header_len) < 0) return -1;
    if (len > 0 && ws_write_all(fd, payload, len) < 0) return -1;
    return 0;
}

// Send a Close frame carrying a 2-byte big-endian status code (best effort).
static void ws_send_close(int fd, unsigned status) {
    unsigned char body[2] = { (unsigned char)(status >> 8), (unsigned char)(status & 0xFF) };
    (void)ws_send_frame(fd, 0x8, body, sizeof(body));
}

// Single-client WebSocket echo server. It accepts one connection on PORT,
// performs the opening handshake, and echoes unfragmented text frames back.
// It stops when the peer sends Close, disconnects, or violates the protocol.
int main() {
    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    // socket() returns -1 on failure; 0 is a valid descriptor.
    if (server_fd < 0) {
        perror("Socket failed");
        exit(EXIT_FAILURE);
    }

    // Allow immediate restarts while a previous socket lingers in TIME_WAIT.
    int reuse = 1;
    if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
        perror("setsockopt(SO_REUSEADDR) failed");  // non-fatal
    }

    struct sockaddr_in address;
    memset(&address, 0, sizeof(address));
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = htonl(INADDR_ANY);
    address.sin_port = htons(PORT);
    if (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) < 0) {
        perror("Bind failed");
        close(server_fd);
        exit(EXIT_FAILURE);
    }
    if (listen(server_fd, 3) < 0) {
        perror("Listen failed");
        close(server_fd);
        exit(EXIT_FAILURE);
    }
    printf("Server listening on port %d\n", PORT);

    socklen_t addrlen = sizeof(address);
    int client_fd = accept(server_fd, (struct sockaddr*)&address, &addrlen);
    if (client_fd < 0) {
        perror("Accept failed");
        close(server_fd);
        exit(EXIT_FAILURE);
    }

    // The request may arrive in several segments and, with cookies, can exceed
    // BUFFER_SIZE, so read it into a larger, always-terminated buffer.
    char request[8 * BUFFER_SIZE];
    char* client_key = NULL;
    if (ws_read_request(client_fd, request, sizeof(request)) == 0) {
        client_key = extract_key(request);
    }
    if (client_key == NULL) {
        printf("Invalid WebSocket request\n");
        close(client_fd);
        close(server_fd);
        return 1;
    }
    send_handshake(client_fd, client_key);
    free(client_key);

    // Frame loop: read one complete frame at a time (RFC 6455 section 5.2).
    char payload[BUFFER_SIZE];
    int running = 1;
    while (running) {
        unsigned char hdr[2];
        if (ws_read_exact(client_fd, hdr, sizeof(hdr)) < 0) break;  // peer went away

        int fin = (hdr[0] & 0x80) != 0;
        int opcode = hdr[0] & 0x0F;
        int masked = (hdr[1] & 0x80) != 0;
        uint64_t payload_len = hdr[1] & 0x7F;

        // 126 / 127 mean a 16-bit / 64-bit big-endian length follows.
        if (payload_len == 126 || payload_len == 127) {
            unsigned char ext[8];
            size_t ext_len = (payload_len == 126) ? 2 : 8;
            if (ws_read_exact(client_fd, ext, ext_len) < 0) break;
            payload_len = 0;
            for (size_t i = 0; i < ext_len; i++) {
                payload_len = (payload_len << 8) | ext[i];
            }
        }

        // A server MUST close the connection on unmasked client frames.
        if (!masked) {
            printf("Unmasked client frame; closing\n");
            ws_send_close(client_fd, 1002);
            break;
        }
        // Control frames (opcode >= 0x8) must be final and carry <= 125 bytes.
        if ((opcode & 0x08) && (!fin || payload_len > 125)) {
            printf("Malformed control frame; closing\n");
            ws_send_close(client_fd, 1002);
            break;
        }
        if (payload_len > sizeof(payload)) {
            printf("Frame too large (%llu bytes); closing\n", (unsigned long long)payload_len);
            ws_send_close(client_fd, 1009);
            break;
        }

        unsigned char mask_key[4];
        size_t len = (size_t)payload_len;
        if (ws_read_exact(client_fd, mask_key, sizeof(mask_key)) < 0) break;
        if (ws_read_exact(client_fd, payload, len) < 0) break;
        decode_payload(payload, mask_key, (int)len);

        switch (opcode) {
        case 0x1:  // text
            if (!fin) {
                printf("Unsupported or fragmented frame\n");
                break;
            }
            printf("Received: %.*s\n", (int)len, payload);
            if (ws_send_frame(client_fd, 0x1, payload, len) < 0) running = 0;
            break;
        case 0x8:  // close: echo the status code (if any) back, then stop
            (void)ws_send_frame(client_fd, 0x8, payload, len >= 2 ? 2 : 0);
            running = 0;
            break;
        case 0x9:  // ping: answer with a pong carrying the same payload
            if (ws_send_frame(client_fd, 0xA, payload, len) < 0) running = 0;
            break;
        case 0xA:  // unsolicited pong: ignore
            break;
        default:   // continuation, binary, or reserved opcodes
            printf("Unsupported or fragmented frame\n");
            break;
        }
    }

    close(client_fd);
    close(server_fd);
    return 0;
}