import RealityKit
import SwiftUI

/// A cloud of small, additively blended orange spheres that drift toward random
/// targets while the whole cluster tumbles about all three axes.
struct FireballCirclesAddBlendView: View {
    /// Parent of every sphere; assigned once the `RealityView` has built its content.
    @State var rootEntity: Entity?
    /// Current drift target for each sphere entity (keyed by the child entity).
    @State var sphereTargets: [Entity: SIMD3<Float>] = [:]
    /// Accumulated per-axis rotation angles (radians) applied to `rootEntity`.
    @State private var rotationAngles: SIMD3<Float> = [0, 0, 0]
    /// Drives the animation ticks while the view is on screen.
    @State private var modulationTimer: Timer?
    /// Media time (seconds) of the previous rotation update.
    @State private var lastRotationUpdateTime = CACurrentMediaTime()

    var body: some View {
        RealityView { content in
            // Build the sphere cloud once, then seed a drift target for every sphere.
            let rootEntity = await FireballCirclesAddBlendView.createRootEntity()
            content.add(rootEntity)
            initializeTargets(for: rootEntity)
            self.rootEntity = rootEntity
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts the ~120 Hz animation timer, replacing any timer that is already running.
    private func startTimer() {
        // `onAppear` can fire more than once; never leave an orphaned timer ticking.
        modulationTimer?.invalidate()
        // Restart the rotation clock so time spent off screen isn't applied as one big jump.
        lastRotationUpdateTime = CACurrentMediaTime()
        modulationTimer = Timer.scheduledTimer(withTimeInterval: 1/120.0, repeats: true) { _ in
            moveChildEntities()
            rotateRootEntity()
        }
    }

    /// Stops and releases the animation timer.
    private func stopTimer() {
        modulationTimer?.invalidate()
        modulationTimer = nil
    }

    /// Moves every sphere a fixed distance toward its target and picks a new
    /// random target once the sphere gets close enough.
    private func moveChildEntities() {
        guard let rootEntity else { return }
        let movementSpeed: Float = 0.00025

        for child in rootEntity.children {
            guard let target = sphereTargets[child] else { continue }
            let step = Self.moveToward(child.position, target: target, maxStep: movementSpeed)
            child.position = step.position
            if step.arrived {
                sphereTargets[child] = Self.generateRandomPosition(bound: 0.025)
            }
        }
    }

    /// Moves `position` at most `maxStep` toward `target`.
    ///
    /// When the target is already within `maxStep`, the position snaps onto it. This
    /// avoids overshooting, and avoids normalizing a zero-length direction (which
    /// would produce NaN).
    /// - Parameters:
    ///   - position: Current position.
    ///   - target: Destination position.
    ///   - maxStep: Largest distance that may be travelled in this step.
    /// - Returns: The new position, and whether it now lies within `maxStep` of `target`.
    nonisolated static func moveToward(
        _ position: SIMD3<Float>,
        target: SIMD3<Float>,
        maxStep: Float
    ) -> (position: SIMD3<Float>, arrived: Bool) {
        let offset = target - position
        let remaining = length(offset)
        guard remaining > maxStep else { return (target, true) }
        let next = position + offset / remaining * maxStep
        return (next, distance(next, target) < maxStep)
    }

    /// Advances the root entity's tumbling rotation by the time elapsed since the previous call.
    private func rotateRootEntity() {
        let currentTime = CACurrentMediaTime()
        let frameDuration = currentTime - lastRotationUpdateTime
        lastRotationUpdateTime = currentTime

        // Rotate about all axes at different rates (rad/s) for a wonky rotation effect.
        let scaleFactor = 1.0
        var angles = rotationAngles
        angles.x += Float(frameDuration * 1.5 * scaleFactor)
        angles.y += Float(frameDuration * 0.75 * scaleFactor)
        angles.z += Float(frameDuration * 0.5 * scaleFactor)

        // Wrap into a single turn. Each rotation is unchanged (q(θ - 2π) = -q(θ)), but the stored
        // angles stay small, so Float precision doesn't degrade the longer the view stays visible.
        let fullTurn = 2 * Float.pi
        angles.x = angles.x.truncatingRemainder(dividingBy: fullTurn)
        angles.y = angles.y.truncatingRemainder(dividingBy: fullTurn)
        angles.z = angles.z.truncatingRemainder(dividingBy: fullTurn)
        rotationAngles = angles

        let rotationX = simd_quatf(angle: angles.x, axis: [1, 0, 0])
        let rotationY = simd_quatf(angle: angles.y, axis: [0, 1, 0])
        let rotationZ = simd_quatf(angle: angles.z, axis: [0, 0, 1])
        rootEntity?.transform.rotation = rotationX * rotationY * rotationZ
    }

    /// Builds the root entity holding 400 sphere children scattered within ±0.1 on each axis.
    static func createRootEntity() async -> Entity {
        let rootEntity = Entity()

        // One model component is shared by every sphere, so the mesh and material are built once.
        let radius: Float = 0.01
        let sphereMesh: MeshResource
        do {
            sphereMesh = try MeshResource.generateSpecificSphere(radius: radius, latitudeBands: 6, longitudeBands: 10)
        } catch {
            // Fall back to RealityKit's built-in sphere rather than crashing if the
            // custom low-level mesh cannot be created.
            sphereMesh = await MeshResource.generateSphere(radius: radius)
        }
        let material = await generateAddMaterial(color: .orange)
        let modelComponent = ModelComponent(mesh: sphereMesh, materials: [material])

        // Create many spheres.
        let entityCount = 400
        for _ in 0..<entityCount {
            let entity = Entity()
            entity.position = generateRandomPosition(bound: 0.1)
            entity.components.set(modelComponent)
            rootEntity.addChild(entity)
        }

        return rootEntity
    }

    /// Creates an unlit material with additive blending, tinted `color` at 7.5% opacity.
    static func generateAddMaterial(color: UIColor) async -> UnlitMaterial {
        var descriptor = UnlitMaterial.Program.Descriptor()
        descriptor.blendMode = .add
        let prog = await UnlitMaterial.Program(descriptor: descriptor)
        var material = UnlitMaterial(program: prog)
        material.color = UnlitMaterial.BaseColor(tint: color)
        material.blending = .transparent(opacity: 0.075)
        return material
    }

    /// Seeds a random drift target (within ±0.025 per axis) for every child of `rootEntity`.
    private func initializeTargets(for rootEntity: Entity) {
        for child in rootEntity.children {
            sphereTargets[child] = Self.generateRandomPosition(bound: 0.025)
        }
    }

    /// Returns a point whose components are each drawn uniformly from `-bound...bound`.
    nonisolated static func generateRandomPosition(bound: Float) -> SIMD3<Float> {
        let x = Float.random(in: -bound...bound)
        let y = Float.random(in: -bound...bound)
        let z = Float.random(in: -bound...bound)
        return SIMD3<Float>(x, y, z)
    }
}

// Xcode canvas preview of `FireballCirclesAddBlendView`, constructed with its default state.
#Preview {
    FireballCirclesAddBlendView()
}

extension MeshResource {
    /// Builds a UV-sphere mesh of the given radius directly from `LowLevelMesh` buffers.
    ///
    /// The sphere is cut into `latitudeBands` rings from the +y pole to the -y pole and
    /// `longitudeBands` segments around the y axis. Every ring stores `longitudeBands + 1`
    /// vertices, so the seam column is duplicated. Each ring/segment cell becomes two triangles.
    ///
    /// - Parameters:
    ///   - radius: Sphere radius in local mesh units.
    ///   - latitudeBands: Number of rings between the poles. NOTE(review): not validated; 0 divides by zero.
    ///   - longitudeBands: Number of segments around the y axis. NOTE(review): not validated either.
    /// - Returns: A `MeshResource` wrapping the generated low-level mesh.
    /// - Throws: Errors from `LowLevelMesh(descriptor:)` or `MeshResource(from:)`.
    static func generateSpecificSphere(radius: Float, latitudeBands: Int = 10, longitudeBands: Int = 10) throws -> MeshResource {
        let verticesPerRing = longitudeBands + 1
        let totalVertices = (latitudeBands + 1) * verticesPerRing
        let totalIndices = latitudeBands * longitudeBands * 6

        var layout = MyVertexWithNormal.descriptor
        layout.vertexCapacity = totalVertices
        layout.indexCapacity = totalIndices
        let lowLevelMesh = try LowLevelMesh(descriptor: layout)

        // Vertex buffer: rings run from theta = 0 (+y pole) to theta = pi (-y pole).
        lowLevelMesh.withUnsafeMutableBytes(bufferIndex: 0) { buffer in
            let vertexBuffer = buffer.bindMemory(to: MyVertexWithNormal.self)
            var slot = 0
            for ring in 0...latitudeBands {
                let theta = Float(ring) * Float.pi / Float(latitudeBands)
                let (ringSin, ringCos) = (sin(theta), cos(theta))
                for segment in 0...longitudeBands {
                    let phi = Float(segment) * 2 * Float.pi / Float(longitudeBands)
                    let unit = SIMD3<Float>(cos(phi) * ringSin, ringCos, sin(phi) * ringSin)
                    // Normals are negated (point inward). NOTE(review): confirm this is intended
                    // together with the triangle winding used for the index buffer below.
                    vertexBuffer[slot] = MyVertexWithNormal(position: unit * radius, normal: -unit.normalized())
                    slot += 1
                }
            }
        }

        // Index buffer: two triangles for every ring/segment cell, stored row by row.
        lowLevelMesh.withUnsafeMutableIndices { buffer in
            let indexBuffer = buffer.bindMemory(to: UInt32.self)
            for ring in 0..<latitudeBands {
                for segment in 0..<longitudeBands {
                    let upper = ring * verticesPerRing + segment   // vertex on this ring
                    let lower = upper + verticesPerRing            // same column on the next ring
                    let base = (ring * longitudeBands + segment) * 6

                    indexBuffer[base] = UInt32(upper)
                    indexBuffer[base + 1] = UInt32(lower)
                    indexBuffer[base + 2] = UInt32(upper + 1)

                    indexBuffer[base + 3] = UInt32(lower)
                    indexBuffer[base + 4] = UInt32(lower + 1)
                    indexBuffer[base + 5] = UInt32(upper + 1)
                }
            }
        }

        // The whole sphere fits inside a cube of half-extent `radius`.
        let halfExtent = SIMD3<Float>(repeating: radius)
        lowLevelMesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: totalIndices,
                topology: .triangle,
                bounds: BoundingBox(min: -halfExtent, max: halfExtent)
            )
        ])

        return try MeshResource(from: lowLevelMesh)
    }
}

/// Interleaved vertex (position followed by normal) used to fill `LowLevelMesh` vertex buffers.
struct MyVertexWithNormal {
    /// Vertex position in local mesh space.
    var position: SIMD3<Float> = .zero
    /// Vertex normal.
    var normal: SIMD3<Float> = .zero

    /// Attribute table locating each field inside one vertex.
    // `let`, not `var`: the table is never reassigned, and a mutable static is shared global
    // state that Swift 6 strict concurrency rejects.
    static let vertexAttributes: [LowLevelMesh.Attribute] = [
        .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
        .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!),
    ]

    /// Single interleaved buffer (index 0) whose stride is one whole vertex.
    static let vertexLayouts: [LowLevelMesh.Layout] = [
        .init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)
    ]

    /// Mesh descriptor using this vertex layout and 32-bit indices; capacities are left for the caller to set.
    static var descriptor: LowLevelMesh.Descriptor {
        var desc = LowLevelMesh.Descriptor()
        desc.vertexAttributes = MyVertexWithNormal.vertexAttributes
        desc.vertexLayouts = MyVertexWithNormal.vertexLayouts
        desc.indexType = .uint32
        return desc
    }
}