import SwiftUI
import RealityKit

/// A SwiftUI view that shows a square plane in a `RealityView` whose surface ripples outward
/// from its center. The geometry is recomputed on the CPU on every timer tick and written
/// directly into a `LowLevelMesh`. A point light is attached above the plane.
struct PureSwiftWavyPlaneView: View {
    /// Current wave phase in radians; the timer keeps it in `[0, 2π)`.
    @State private var phase: Float = 0.0
    /// Mesh created by `createPlaneMesh()` and rewritten by `updateMesh`; `nil` until the scene is built.
    @State private var mesh: LowLevelMesh?
    /// Animation timer; non-nil between `startTimer()` and `stopTimer()`.
    @State private var timer: Timer?
    /// Number of vertices along each edge of the square grid (must be at least 2).
    let resolution = 250

    /// Number of triangle-list indices the grid needs: two triangles (6 indices) per cell.
    private var gridIndexCount: Int {
        (resolution - 1) * (resolution - 1) * 6
    }

    var body: some View {
        RealityView { content in
            // The make closure cannot propagate errors. Report a failure (fatal in debug builds)
            // instead of unconditionally crashing the app with `try!`.
            do {
                let planeEntity = try getPlaneEntity()
                let lightEntity = try getLightEntity()
                planeEntity.addChild(lightEntity)
                content.add(planeEntity)
            } catch {
                assertionFailure("Failed to build the wavy plane scene: \(error)")
            }
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts the ~120 Hz animation timer, replacing any timer that is already running.
    private func startTimer() {
        // `onAppear` can fire again without an intervening `onDisappear`. Invalidate first
        // so two timers never drive the mesh at the same time.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: 1/120.0, repeats: true) { _ in
            // sin/cos are 2π-periodic, so wrapping leaves the animation unchanged while keeping
            // `phase` small enough that Float precision never degrades over long sessions.
            phase = (phase + 0.1).truncatingRemainder(dividingBy: 2 * Float.pi)
            updateMesh(phase: phase)
        }
    }

    /// Stops and releases the animation timer, if any.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Builds a white point light positioned 0.125 along +z from its parent's origin,
    /// with intensity 10000 and attenuation radius 0.25.
    ///
    /// Nothing in the body currently throws.
    func getLightEntity() throws -> Entity {
        let entity = Entity()
        let pointLightComponent = PointLightComponent(
            cgColor: .init(red: 1, green: 1, blue: 1, alpha: 1),
            intensity: 10000,
            attenuationRadius: 0.25
        )
        entity.components.set(pointLightComponent)
        entity.position = .init(x: 0, y: 0, z: 0.125)
        return entity
    }

    /// Builds the entity that displays the wavy plane.
    ///
    /// The function creates the mesh (already populated with the wave at the current phase)
    /// and wraps it in a `MeshResource`. It then applies a blue, non-metallic, zero-roughness
    /// physically based material with face culling disabled, so both sides of the plane render.
    ///
    /// - Throws: Any error from creating the `LowLevelMesh` or the `MeshResource`.
    func getPlaneEntity() throws -> Entity {
        let mesh = try createPlaneMesh()
        // This function already throws, so propagate the error rather than crashing with `try!`.
        let resource = try MeshResource(from: mesh)
        var material = PhysicallyBasedMaterial()
        material.baseColor.tint = .init(red: 0.0625, green: 0.125, blue: 1.0, alpha: 1.0)
        material.faceCulling = .none
        material.metallic = 0.0
        material.roughness = 0.0

        let modelComponent = ModelComponent(mesh: resource, materials: [material])

        let entity = Entity()
        entity.components.set(modelComponent)

        return entity
    }

    /// Allocates a `LowLevelMesh` sized for a `resolution × resolution` vertex grid and fills it
    /// with the wave at the current `phase`. It also stores the mesh for later `updateMesh` calls.
    ///
    /// - Returns: The populated mesh.
    /// - Throws: Any error from `LowLevelMesh(descriptor:)`.
    func createPlaneMesh() throws -> LowLevelMesh {
        var desc = MyVertexWithNormal.descriptor
        desc.vertexCapacity = resolution * resolution
        desc.indexCapacity = gridIndexCount

        let mesh = try LowLevelMesh(descriptor: desc)
        // Populate now so the first rendered frame shows the wave, not a part-less empty mesh
        // waiting for the first timer tick. Pass the mesh explicitly instead of relying on
        // reading it back from @State.
        updateMesh(phase: phase, in: mesh)
        self.mesh = mesh
        return mesh
    }

    /// Rewrites the mesh's vertices, indices, and single part for the radial wave
    /// `z = amplitude * sin(frequency * r - phase)`, where `r` is the distance from the plane's center.
    ///
    /// The plane spans 0.4 × 0.4 in x/y, centered on the origin.
    ///
    /// - Parameters:
    ///   - amplitude: Peak displacement along z.
    ///   - frequency: Radial frequency, in radians per unit of `r`.
    ///   - phase: Phase offset in radians; increasing it moves the crests outward.
    ///   - target: Mesh to write into. When `nil`, the mesh stored by `createPlaneMesh()` is used.
    ///     The call does nothing if neither is available.
    func updateMesh(amplitude: Float = 0.1, frequency: Float = 60.0, phase: Float = 0.0, in target: LowLevelMesh? = nil) {
        guard let mesh = target ?? self.mesh else { return }
        let size: Float = 0.4

        writeWaveVertices(into: mesh, size: size, amplitude: amplitude, frequency: frequency, phase: phase)
        writeGridIndices(into: mesh)

        // |sin| ≤ 1, so z stays within ±|amplitude|. `abs` keeps min ≤ max for a negative amplitude.
        let maxZ = abs(amplitude)
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: gridIndexCount,
                topology: .triangle,
                bounds: BoundingBox(min: [-size / 2, -size / 2, -maxZ], max: [size / 2, size / 2, maxZ])
            )
        ])
    }

    /// Writes one displaced, normal-carrying vertex per grid point into vertex buffer 0,
    /// in row-major order (index = y * resolution + x).
    private func writeWaveVertices(into mesh: LowLevelMesh, size: Float, amplitude: Float, frequency: Float, phase: Float) {
        mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
            let vertices = rawBytes.bindMemory(to: MyVertexWithNormal.self)
            let cellsPerEdge = Float(resolution - 1)

            for y in 0..<resolution {
                for x in 0..<resolution {
                    let xPos = Float(x) / cellsPerEdge * size - size / 2
                    let yPos = Float(y) / cellsPerEdge * size - size / 2
                    let distanceFromCenter = sqrt(xPos * xPos + yPos * yPos)

                    let angle = frequency * distanceFromCenter - phase
                    let z = amplitude * sin(angle)
                    let dzdr = amplitude * frequency * cos(angle)

                    // For z = f(x, y) the surface normal is ±(-∂z/∂x, -∂z/∂y, 1). The triangles
                    // written by `writeGridIndices` have a right-hand-rule normal pointing
                    // toward -z, so use the (∂z/∂x, ∂z/∂y, -1) orientation. ∂z/∂x = dz/dr · x/r.
                    // At r == 0 the direction is undefined, so fall back to the flat normal
                    // rather than dividing by zero.
                    let normal: SIMD3<Float>
                    if distanceFromCenter > 0 {
                        let dzdx = dzdr * xPos / distanceFromCenter
                        let dzdy = dzdr * yPos / distanceFromCenter
                        normal = simd_normalize(SIMD3<Float>(dzdx, dzdy, -1))
                    } else {
                        normal = SIMD3<Float>(0, 0, -1)
                    }

                    vertices[y * resolution + x] = MyVertexWithNormal(
                        position: SIMD3<Float>(xPos, yPos, z),
                        normal: normal
                    )
                }
            }
        }
    }

    /// Writes the triangle-list indices for the grid: two triangles per cell, as
    /// (topLeft, bottomLeft, topRight) and (topRight, bottomLeft, bottomRight).
    ///
    /// NOTE(review): this pattern never changes between frames. If `LowLevelMesh` preserves
    /// index-buffer contents across calls, it could be written once at creation — confirm before hoisting.
    private func writeGridIndices(into mesh: LowLevelMesh) {
        mesh.withUnsafeMutableIndices { rawIndices in
            let indices = rawIndices.bindMemory(to: UInt32.self)
            var index = 0

            for y in 0..<(resolution - 1) {
                for x in 0..<(resolution - 1) {
                    let topLeft = UInt32(y * resolution + x)
                    let topRight = topLeft + 1
                    let bottomLeft = UInt32((y + 1) * resolution + x)
                    let bottomRight = bottomLeft + 1

                    indices[index] = topLeft
                    indices[index + 1] = bottomLeft
                    indices[index + 2] = topRight

                    indices[index + 3] = topRight
                    indices[index + 4] = bottomLeft
                    indices[index + 5] = bottomRight

                    index += 6
                }
            }
        }
    }
}

// Xcode canvas preview of `PureSwiftWavyPlaneView`.
#Preview {
    PureSwiftWavyPlaneView()
}