using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEngine;
using UnityEngine.Playables;

namespace Needle.AnimationUtils
{
    /// <summary>
    /// Samples the playable graph of <see cref="target"/> over time and builds a line mesh that traces
    /// the world-space path of every animated transform in the target's hierarchy.
    /// </summary>
    public class VisualizeMotionTrajectories : MonoBehaviour
    {
        // Meshes default to a 16-bit index buffer, which can address at most this many vertices.
        private const int MaxVerticesFor16BitIndices = 65535;

        /// <summary>Animator whose playable graph is sampled; its hierarchy is searched for animated transforms.</summary>
        public Animator target;
        
        /// <summary>Number of samples taken per unit of animation time (the step is 1 / fps).</summary>
        public float fps = 15f;
        /// <summary>Multiplier applied to the distance between consecutive samples before <see cref="velocityPower"/>.</summary>
        public float velocityScale = 0.1f;
        /// <summary>Exponent applied to the scaled distance to shape the per-vertex opacity.</summary>
        public float velocityPower = 2f;
        /// <summary>Opacity added to every vertex so slow-moving segments stay faintly visible.</summary>
        public float minOpacity = 0.04f;
        /// <summary>End of the sampled time range; sampling starts at 0.</summary>
        public float length = 12f;
        
        // Builds a line-topology mesh with vertex colors and assigns it to this object's MeshFilter.
        // Requires a vertex-color capable shader to look nice, otherwise it will just be a single color.
        [ContextMenu("Build Mesh")]
        void BuildMesh()
        {
            // Check the Animator itself first: reading target.gameObject on a missing target throws.
            if (!target)
            {
                Debug.LogWarning("No target Animator assigned.", this);
                return;
            }

            // A non-positive fps yields an infinite or negative time step; a negative step never
            // reaches length, so the sampling loop would not terminate.
            if (fps <= 0f)
            {
                Debug.LogWarning("fps must be greater than zero.", this);
                return;
            }

            var go = target.gameObject;
            var graph = target.playableGraph;
            if (!graph.IsValid())
            {
                Debug.LogWarning("The target Animator has no valid PlayableGraph to sample.", this);
                return;
            }

            var trajectories = CollectAnimatedTransforms(go);
            if (trajectories.Count == 0)
            {
                Debug.LogWarning("No animated transforms found below the target.", this);
                return;
            }

            // Positions captured before sampling; they give each trajectory its color.
            var originalPos = trajectories.Keys.ToDictionary(x => x, x => x.position);
            var bounds = new Bounds(originalPos.Values.First(), Vector3.zero);
            foreach (var p in originalPos.Values)
                bounds.Encapsulate(p);

            // Debug.Log("Animated objects:\n" + string.Join("\n", trajectories.Keys));

            SampleTrajectories(graph, trajectories);

            if (!gameObject.TryGetComponent<MeshFilter>(out var mf))
                mf = gameObject.AddComponent<MeshFilter>();

            mf.sharedMesh = BuildLineMesh(trajectories, originalPos, bounds);
        }

        // Returns one empty sample list per transform under root that has at least one animatable
        // binding resolving to a Transform.
        static Dictionary<Transform, List<Vector3>> CollectAnimatedTransforms(GameObject root)
        {
            var result = new Dictionary<Transform, List<Vector3>>();
            foreach (var tr in root.GetComponentsInChildren<Transform>())
            {
                // TODO for Playable Director we need to collect all outputs and then iterate their children here
                var bindings = AnimationUtility.GetAnimatableBindings(tr.gameObject, root);
                foreach (var binding in bindings)
                {
                    if (AnimationUtility.GetAnimatedObject(root, binding) is Transform animated)
                    {
                        // Binding paths are transform-name paths, so identically named siblings can
                        // resolve to the same Transform; Dictionary.Add would throw on the repeat.
                        if (!result.ContainsKey(animated))
                            result.Add(animated, new List<Vector3>());
                        break;
                    }
                }
            }
            return result;
        }

        // Samples output 0 of the graph from t = 0 up to (but excluding) length in 1/fps steps, then
        // once exactly at length, appending every tracked transform's world position after each sample.
        void SampleTrajectories(PlayableGraph graph, Dictionary<Transform, List<Vector3>> trajectories)
        {
            const int outputIndex = 0;
            var len = length;
            var dt = 1f / fps;

            void Sample(float time)
            {
                AnimationMode.SamplePlayableGraph(graph, outputIndex, time);
                foreach (var kvp in trajectories)
                    kvp.Value.Add(kvp.Key.position);
            }

            var driver = ScriptableObject.CreateInstance<AnimationModeDriver>();
            AnimationMode.StartAnimationMode(driver);
            AnimationMode.BeginSampling();
            try
            {
                // Derive each sample time from the frame index rather than accumulating dt, which drifts.
                for (var frame = 0; frame * dt < len; frame++)
                    Sample(frame * dt);

                // last frame
                Sample(len);
            }
            finally
            {
                // Leave animation mode even if sampling throws; the driver is only needed for this
                // session, so destroy it instead of leaking a ScriptableObject per invocation.
                AnimationMode.EndSampling();
                AnimationMode.StopAnimationMode(driver);
                DestroyImmediate(driver);
            }
        }

        // Concatenates all trajectories into one Lines-topology mesh. Each trajectory is colored by its
        // captured position within bounds (x -> red, y -> green). Vertex alpha is the distance of the
        // adjacent segment times velocityScale, raised to velocityPower, plus minOpacity.
        Mesh BuildLineMesh(Dictionary<Transform, List<Vector3>> trajectories, Dictionary<Transform, Vector3> originalPos, Bounds bounds)
        {
            var vertices = trajectories.Values.SelectMany(x => x).ToList();
            var colors = new Color[vertices.Count];
            var indices = new List<int>();

            var start = 0;
            foreach (var kvp in trajectories)
            {
                var originalP = originalPos[kvp.Key];
                var xColor = Mathf.InverseLerp(bounds.min.x, bounds.max.x, originalP.x);
                var yColor = Mathf.InverseLerp(bounds.min.y, bounds.max.y, originalP.y);

                var list = kvp.Value;
                for (int j = 0; j < list.Count - 1; j++)
                {
                    indices.Add(start + j);
                    indices.Add(start + j + 1);
                    var mag = (list[j + 1] - list[j]).magnitude * velocityScale;
                    mag = Mathf.Pow(mag, velocityPower);
                    // Vertex j+1 takes the color of the segment that ends at it.
                    colors[start + j + 1] = new Color(xColor, yColor, 0, mag + minOpacity);
                }

                // The first vertex has no incoming segment; reuse the first segment's color instead of
                // leaving it at the default transparent black.
                if (list.Count > 1)
                    colors[start] = colors[start + 1];

                start += list.Count;
            }

            var mesh = new Mesh();
            // Switch to 32-bit indices before assigning index data when 16-bit cannot address every vertex.
            if (vertices.Count > MaxVerticesFor16BitIndices)
                mesh.indexFormat = UnityEngine.Rendering.IndexFormat.UInt32;
            mesh.SetVertices(vertices);
            mesh.SetColors(colors);
            mesh.SetIndices(indices, MeshTopology.Lines, 0);
            return mesh;
        }
    }
}