#pragma comment(lib, "user32")
#pragma comment(lib, "d3d11")
#pragma comment(lib, "d3dcompiler")

///////////////////////////////////////////////////////////////////////////////////////////////////

#include <windows.h>
#include <d3d11.h>
#include <d3dcompiler.h>

///////////////////////////////////////////////////////////////////////////////////////////////////

int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine, int nShowCmd)
{
    WNDCLASSA wndclass = { 0, DefWindowProcA, 0, 0, 0, 0, 0, 0, 0, "d7" };

    RegisterClassA(&wndclass);

    HWND window = CreateWindowExA(0, "d7", 0, 0x91000000, 0, 0, 0, 0, 0, 0, 0, 0);

    ///////////////////////////////////////////////////////////////////////////////////////////////

    D3D_FEATURE_LEVEL featurelevels[] = { D3D_FEATURE_LEVEL_11_0 };

    DXGI_SWAP_CHAIN_DESC swapchaindesc = {};
    swapchaindesc.BufferDesc.Format = DXGI_FORMAT_B8G8R8A8_UNORM; // non-srgb for simplicity here. see other minimal gists for srgb setup
    swapchaindesc.SampleDesc.Count  = 1;
    swapchaindesc.BufferUsage       = DXGI_USAGE_RENDER_TARGET_OUTPUT;
    swapchaindesc.BufferCount       = 2;
    swapchaindesc.OutputWindow      = window;
    swapchaindesc.Windowed          = TRUE;
    swapchaindesc.SwapEffect        = DXGI_SWAP_EFFECT_FLIP_DISCARD;

    IDXGISwapChain* swapchain;

    ID3D11Device* device;
    ID3D11DeviceContext* devicecontext;

    D3D11CreateDeviceAndSwapChain(nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT, featurelevels, ARRAYSIZE(featurelevels), D3D11_SDK_VERSION, &swapchaindesc, &swapchain, &device, nullptr, &devicecontext);

    swapchain->GetDesc(&swapchaindesc); // get actual dimensions (see lines 94, 98)

    ///////////////////////////////////////////////////////////////////////////////////////////////

    ID3D11Texture2D* framebuffer;

    swapchain->GetBuffer(0, __uuidof(ID3D11Texture2D), (void**)&framebuffer); // get the swapchain's frame buffer

    ID3D11RenderTargetView* framebufferRTV;

    device->CreateRenderTargetView(framebuffer, nullptr, &framebufferRTV); // make a render target [view] from it

    FLOAT clearcolor[4] = { 0.1725f, 0.1725f, 0.1725f, 1.0f }; // RGBA

    ///////////////////////////////////////////////////////////////////////////////////////////////

    ID3DBlob* vertexshaderCSO;

    D3DCompileFromFile(L"gpu.hlsl", 0, 0, "VsMain", "vs_5_0", 0, 0, &vertexshaderCSO, 0);

    ID3D11VertexShader* vertexshader;

    device->CreateVertexShader(vertexshaderCSO->GetBufferPointer(), vertexshaderCSO->GetBufferSize(), 0, &vertexshader);

    D3D11_INPUT_ELEMENT_DESC inputelementdesc[] = // maps to vertexdesc struct in gpu.hlsl via semantic names ("POS", "COL")
    {
        { "POS", 0, DXGI_FORMAT_R32G32_FLOAT,    0,                            0, D3D11_INPUT_PER_VERTEX_DATA, 0 }, // float2 position (x, y)
        { "COL", 0, DXGI_FORMAT_R32G32B32_FLOAT, 0, D3D11_APPEND_ALIGNED_ELEMENT, D3D11_INPUT_PER_VERTEX_DATA, 0 }, // float3 color (r, g, b)
    };

    ID3D11InputLayout* inputlayout;

    device->CreateInputLayout(inputelementdesc, ARRAYSIZE(inputelementdesc), vertexshaderCSO->GetBufferPointer(), vertexshaderCSO->GetBufferSize(), &inputlayout);

    ///////////////////////////////////////////////////////////////////////////////////////////////

    ID3DBlob* pixelshaderCSO;

    D3DCompileFromFile(L"gpu.hlsl", 0, 0, "PsMain", "ps_5_0", 0, 0, &pixelshaderCSO, 0);

    ID3D11PixelShader* pixelshader;

    device->CreatePixelShader(pixelshaderCSO->GetBufferPointer(), pixelshaderCSO->GetBufferSize(), 0, &pixelshader);

    ///////////////////////////////////////////////////////////////////////////////////////////////

    D3D11_RASTERIZER_DESC rasterizerdesc = { D3D11_FILL_SOLID, D3D11_CULL_NONE }; // CULL_NONE to be agnostic of triangle winding order

    ID3D11RasterizerState* rasterizerstate;

    device->CreateRasterizerState(&rasterizerdesc, &rasterizerstate);

    D3D11_VIEWPORT viewport = { 0, 0, (float)swapchaindesc.BufferDesc.Width, (float)swapchaindesc.BufferDesc.Height, 0, 1 };

    ///////////////////////////////////////////////////////////////////////////////////////////////

    float constants[2] = { 2.0f / swapchaindesc.BufferDesc.Width, -2.0f / swapchaindesc.BufferDesc.Height }; // precalc for simple screen coordinate transform in shader (instead of full-on projection matrix)

    D3D11_BUFFER_DESC constantbufferdesc = {};
    constantbufferdesc.ByteWidth = sizeof(constants) + 0xf & 0xfffffff0; // constant buffer size must be multiple of 16
    constantbufferdesc.Usage     = D3D11_USAGE_IMMUTABLE; // never updated
    constantbufferdesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;

    D3D11_SUBRESOURCE_DATA constantbufferSRD = { constants };

    ID3D11Buffer* constantbuffer;

    device->CreateBuffer(&constantbufferdesc, &constantbufferSRD, &constantbuffer);

    ///////////////////////////////////////////////////////////////////////////////////////////////

    #define MAX_VERTICES 1024 // arbitrary limit

    struct vertexdesc { float x, y, r, g, b; }; // float2 position, float3 color

    ///////////////////////////////////////////////////////////////////////////////////////////////

    D3D11_BUFFER_DESC vertexbufferdesc = {};
    vertexbufferdesc.ByteWidth      = sizeof(vertexdesc) * MAX_VERTICES;
    vertexbufferdesc.Usage          = D3D11_USAGE_DYNAMIC; // updated every frame
    vertexbufferdesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
    vertexbufferdesc.BindFlags      = D3D11_BIND_VERTEX_BUFFER;

    ID3D11Buffer* vertexbuffer;

    device->CreateBuffer(&vertexbufferdesc, nullptr, &vertexbuffer);

    UINT stride = sizeof(vertexdesc);
    UINT offset = 0;

    ///////////////////////////////////////////////////////////////////////////////////////////////

    while (true)
    {
        MSG msg;

        while (PeekMessageA(&msg, nullptr, 0, 0, PM_REMOVE))
        {
            if (msg.message == WM_KEYDOWN) return 0; // PRESS ANY KEY TO EXIT
            DispatchMessageA(&msg);
        }

        ///////////////////////////////////////////////////////////////////////////////////////////

        D3D11_MAPPED_SUBRESOURCE vertexbufferMSR;

        devicecontext->Map(vertexbuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &vertexbufferMSR);
        {
            vertexdesc* vertex = (vertexdesc*)vertexbufferMSR.pData;

            vertex[0] = { 150, 100,    1.0f, 0.0f, 0.0f }; // vertex x, y,    r, g, b
            vertex[1] = { 200, 250,    0.0f, 1.0f, 0.0f };
            vertex[2] = { 100, 200,    0.0f, 0.0f, 1.0f };
        }
        devicecontext->Unmap(vertexbuffer, 0);

        ///////////////////////////////////////////////////////////////////////////////////////////

        devicecontext->ClearRenderTargetView(framebufferRTV, clearcolor);

        devicecontext->IASetPrimitiveTopology(D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
        devicecontext->IASetInputLayout(inputlayout);
        devicecontext->IASetVertexBuffers(0, 1, &vertexbuffer, &stride, &offset);

        devicecontext->VSSetShader(vertexshader, nullptr, 0);
        devicecontext->VSSetConstantBuffers(0, 1, &constantbuffer);

        devicecontext->RSSetViewports(1, &viewport);
        devicecontext->RSSetState(rasterizerstate);

        devicecontext->PSSetShader(pixelshader, nullptr, 0);

        devicecontext->OMSetRenderTargets(1, &framebufferRTV, nullptr);

        ///////////////////////////////////////////////////////////////////////////////////////////

        devicecontext->Draw(3, 0); // draw 3 vertices

        ///////////////////////////////////////////////////////////////////////////////////////////

        swapchain->Present(1, 0);
    }
}