"""
1. Record an Off-Beat Jazz rehearsal session with a Sony A7C camera
2. Plug in an SD card to your Mac
3. Run this script, which will automatically segment the rehearsal session and render
   to lower-quality, shareable, Google Drive-able H265-encoded video files
4. Upload to Google Drive
5. Get better at playing Jazz

Requirements:
python3 -m pip install tqdm ffmpeg-python numpy matplotlib pedalboard platformdirs
(plus the `ffmpeg` command-line tool itself, available on your PATH)

@psobot 2023-09-07
"""

import os
import pickle
import argparse
import subprocess
import platformdirs
import hashlib
import inspect
from functools import wraps
from tqdm import tqdm
from io import BytesIO
from glob import glob
from typing import Iterable

import ffmpeg
import numpy as np
from pedalboard import PeakFilter
from pedalboard.io import AudioFile
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


# Glob pattern for the directory searched when no input directory is given.
# Presumably this is where the camera stores clips on its SD card once the
# card is mounted under /Volumes on macOS -- verify against your camera.
DEFAULT_GLOB_EXPR = "/Volumes/*/PRIVATE/M4ROOT/CLIP/"
# File-name patterns (matched with glob) of the video files to process.
EXTENSIONS = ["*.MP4", "*.MTS"]
# Application name and author handed to platformdirs to locate the cache directory.
APPNAME, APPAUTHOR = "autoclip", "psobot"

# Level of the normalized loudness profile at or below which a sample counts
# as quiet during the backwards search for a clip's start.
QUIET_THRESHOLD = 0.06
# How many profile samples before a detected start the backwards search may inspect.
SEARCH_DISTANCE_SECONDS = 60
# The backwards search stops once more than this many consecutive quiet samples are seen.
MIN_QUIET_SECONDS = 3
# Level of the smoothed, normalized loudness above which audio counts as part of a clip.
THRESHOLD = 0.35

def cache_on_disk(fun):
    """
    Decorator that memoizes the result of ``fun`` as a pickle file on disk.

    Results are stored in the per-user cache directory reported by
    ``platformdirs`` for this application. The cache file name combines a hash
    of the function's source code (so editing the function invalidates old
    entries) with a hash of the ``repr`` of the call arguments.

    A cache entry that is missing, unreadable, truncated or otherwise not a
    valid pickle is treated as a cache miss and is recomputed and rewritten.

    NOTE(review): the key is derived from the argument values only. If an
    argument is a file path, a *different* file later stored at the same path
    will still hit the old cache entry.
    """
    @wraps(fun)
    def inner(*args, **kwargs):
        cache_dir = platformdirs.user_cache_dir(APPNAME, APPAUTHOR)
        os.makedirs(cache_dir, exist_ok=True)
        fun_key = hashlib.md5(inspect.getsource(fun).encode("utf-8")).hexdigest()
        key = hashlib.md5(
            ''.join((repr(arg) for arg in (args, kwargs))).encode("utf-8")
        ).hexdigest()
        filename = os.path.join(cache_dir, fun_key + "." + key + ".pkl")
        try:
            with open(filename, "rb") as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            # Missing or damaged cache entry: fall through and recompute.
            pass
        result = fun(*args, **kwargs)

        # Write to a temporary file first and rename it into place, so that an
        # interrupted run can never leave a half-written cache entry behind.
        temp_filename = f"{filename}.{os.getpid()}.tmp"
        try:
            with open(temp_filename, "wb") as f:
                pickle.dump(result, f)
            os.replace(temp_filename, filename)
        except BaseException:
            try:
                os.remove(temp_filename)
            except OSError:
                pass
            raise
        return result

    return inner


def scan_for_files(directory: str | None = None) -> Iterable[str]:
    """
    Return the paths of all video files found in ``directory``.

    Each pattern in ``EXTENSIONS`` is globbed in turn and the matches are
    concatenated in that order. When ``directory`` is empty or ``None``, the
    default glob expression ``DEFAULT_GLOB_EXPR`` is searched instead.
    """
    search_root = directory if directory else DEFAULT_GLOB_EXPR
    return [
        match
        for pattern in EXTENSIONS
        for match in glob(os.path.join(search_root, pattern))
    ]


def to_hhmmssms(num_seconds: float | None) -> str:
    """
    Format a value in seconds as FFMPEG's preferred time input format, hh:mm:ss.ms.
    """
    if num_seconds is None:
        return "None"
    ms = int(1000 * (num_seconds % 1))
    seconds = int(num_seconds % 60)
    minutes = int((num_seconds // 60) % 60)
    hours = int(num_seconds // 3600)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{ms:03d}"



def segment_with_backwards_search(
    profile: np.ndarray,
    min_quiet_seconds: float = MIN_QUIET_SECONDS,
    quiet_threshold: float = QUIET_THRESHOLD,
    search_distance_seconds: float = SEARCH_DISTANCE_SECONDS,
) -> Iterable[tuple[float, float]]:
    """
    Yield ``(start, end)`` segments of ``profile`` with their starts moved earlier.

    Segments come from ``segment_loudness_profile``. For each one, the profile
    is scanned backwards from the detected start, for at most
    ``search_distance_seconds`` samples, looking for a run of more than
    ``min_quiet_seconds`` consecutive samples at or below ``quiet_threshold``.
    If such a run is found the segment starts at the earliest sample of that
    run; otherwise the detected start is kept.

    Start positions are clamped to be non-negative, so that they are always
    valid offsets into the recording.

    As a side effect, ``profile`` is plotted on the current matplotlib axes.
    """
    plt.plot(profile, label="profile")
    last_index = len(profile) - 1
    for start, end in segment_loudness_profile(profile):
        # The detected start is shifted back by half the smoothing window and
        # can therefore fall before the beginning of the profile.
        start = max(0, start)
        search_from = min(int(start), last_index)
        search_until = max(0, int(start - search_distance_seconds))

        # Search back from "start" until a small minimum is found:
        consecutive_seconds = 0
        for i in range(search_from, search_until, -1):
            if profile[i] > quiet_threshold:
                consecutive_seconds = 0
            else:
                consecutive_seconds += 1
            if consecutive_seconds > min_quiet_seconds:
                yield i, end
                break
        else:
            # No sufficiently long quiet stretch was found: keep the start as detected.
            yield start, end

def segment_loudness_profile(
    profile: np.ndarray,
    window_size: int = 30,
    threshold: float = THRESHOLD,
    min_length: int = 60,
    down_time: float = 60,
) -> Iterable[tuple[float, float]]:
    """
    Yield ``(start, end)`` positions of the loud stretches in ``profile``.

    The profile is smoothed by summing over a sliding window of
    ``window_size`` samples and normalized to a peak of 1. A segment opens
    when the smoothed value rises above ``threshold`` and closes once the
    value has stayed below ``threshold`` for more than ``down_time`` samples.
    Segments shorter than ``min_length`` samples are discarded.

    Positions are expressed in profile samples and are shifted back to
    compensate for the delay introduced by the smoothing window. Starts are
    clamped to 0. A segment that is still open when the profile ends (for
    example because the recording was stopped shortly after the music) is
    yielded as well, with its end clamped to the length of the profile.

    An empty or all-zero profile yields nothing.

    As a side effect, the smoothed curve is plotted on the current matplotlib axes.
    """
    if len(profile) == 0:
        return

    smoothed = np.convolve(profile, np.ones(window_size))
    peak = np.amax(smoothed)
    if peak > 0:
        # Skip normalization for silent input rather than dividing by zero.
        smoothed /= peak
    plt.plot(smoothed, label="smoothed")

    buffer = window_size / 2
    start_index = None
    end_index = None
    last_end = None
    for index, value in enumerate(smoothed):
        if last_end is not None and index - buffer < last_end:
            continue
        if value > threshold:
            if start_index is None:
                start_index = index
            # Loud again: any pending end of the current segment is cancelled.
            end_index = None
        elif value < threshold and start_index is not None:
            if end_index is None:
                end_index = index
            if index - end_index > down_time:
                if end_index - start_index >= min_length:
                    last_end = end_index - buffer / 2
                    yield (max(0, start_index - buffer), last_end)
                start_index = None
                end_index = None

    # Flush a segment that was still open when the profile ran out.
    if start_index is not None:
        if end_index is None:
            end_index = len(smoothed)
        if end_index - start_index >= min_length:
            yield (
                max(0, start_index - buffer),
                min(end_index - buffer / 2, len(profile)),
            )
    

@cache_on_disk
def measure_loudness_profile(filename: str) -> np.ndarray:
    """
    Return a loudness curve for the audio of ``filename``, one value per second.

    The audio is transcoded by FFMPEG to a small in-memory MP3, a strong bass
    boost is applied, and the peak absolute sample value of every one-second
    chunk is recorded. The curve is normalized to a maximum of 1 unless the
    audio is completely silent or shorter than one second.

    Raises:
        RuntimeError: if FFMPEG exits with a non-zero status.
    """
    # Have FFMPEG decode to low-bitrate, low-sample-rate MP3 in-memory
    # so we can take a rough loudness profile.
    #
    # On my M1 MacBook Air, this runs about 90x real-time and requires
    # about 14MB of memory per hour of video.
    #
    # The duration is only used to size the progress bar, so streams that do
    # not report one are skipped instead of raising KeyError.
    durations = [
        float(stream["duration"])
        for stream in ffmpeg.probe(filename)["streams"]
        if "duration" in stream
    ]
    bitrate = 32000
    bytes_per_second = bitrate / 8
    expected_mp3_size = max(durations) * bytes_per_second if durations else None
    process = (
        ffmpeg.input(filename)
            .audio
            .output("pipe:", format="mp3", audio_bitrate=bitrate, ar=8000)
            # stderr is only read after stdout is exhausted, so keep FFMPEG
            # from filling the stderr pipe with progress output and stalling.
            .global_args("-loglevel", "error", "-nostats")
            .run_async(pipe_stdout=True, pipe_stderr=True)
    )

    buf = BytesIO()
    title = f"Scanning {os.path.basename(filename)}..."
    with tqdm(desc=title, unit='B', unit_scale=True, total=expected_mp3_size) as pbar:
        while chunk := process.stdout.read(1024 * 16):
            pbar.update(len(chunk))
            buf.write(chunk)
    error_output = process.stderr.read()
    returncode = process.wait()
    if returncode != 0:
        # Fail loudly so that a partial profile is never computed and cached.
        message = error_output.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            f"ffmpeg exited with status {returncode} while decoding {filename!r}: {message}"
        )

    # Read in one-second chunks and take the loudness profile:
    with AudioFile(buf) as f:
        # Boost the bass frequencies to make energy detection easier:
        bass_boost = PeakFilter(cutoff_frequency_hz=100, gain_db=40, q=4)
        num_seconds = int(f.duration)
        loudness_per_second = np.zeros(num_seconds)
        title = f"Measuring loudness of {os.path.basename(filename)}..."
        for i in tqdm(range(num_seconds), desc=title, total=num_seconds):
            boosted = bass_boost(f.read(f.samplerate), f.samplerate)
            loudness_per_second[i] = np.amax(np.abs(boosted))

    # Normalize the loudness curve (nothing to normalize for silent or empty audio):
    peak = np.amax(loudness_per_second) if num_seconds else 0
    if peak > 0:
        loudness_per_second /= peak
    return loudness_per_second


def identify_clips(filename: str) -> Iterable[tuple[float, float]]:
    """
    Yield the ``(start, end)`` boundaries of the clips detected in ``filename``.

    The per-second loudness profile of the file is measured first and then
    segmented with ``segment_with_backwards_search``.
    """
    profile = measure_loudness_profile(filename)
    for clip in segment_with_backwards_search(profile):
        yield clip
        


def render_clip(
    filename: str,
    segment: tuple[float, float],
    output_filename: str,
    draft: bool = False
) -> list[str]:
    """
    Build the FFMPEG arguments that render one segment of ``filename``.

    ``segment`` is a ``(start, end)`` pair; the input is seeked to ``start``
    and ``end - start`` is passed as the output duration. The arguments (without the
    leading ``ffmpeg`` executable name) are returned; nothing is executed.

    With ``draft`` set, a fast low-quality libx264/AAC configuration is used.
    NOTE: only the audio stream is passed to the draft output at present.
    Otherwise the video is scaled with ``scale=1080x1920`` and encoded with
    libx265 (tagged ``hvc1``) alongside mono AAC audio.
    """
    clip_start = segment[0]
    clip_length = segment[1] - clip_start
    source = ffmpeg.input(filename, ss=clip_start)

    if draft:
        streams = [source.audio]
        options = {
            'pix_fmt': 'yuv420p',
            'crf': 35,
            't': clip_length,
            'preset': "ultrafast",
            'ac': 1,
            'r': 5,
            'c:v': 'libx264',
            'c:a': 'aac',
            'b:a': 96000,
        }
    else:
        streams = [source.audio, source.video.filter("scale", "1080x1920")]
        options = {
            'pix_fmt': 'yuv420p',
            'crf': 30,
            't': clip_length,
            'ac': 1,
            'c:v': 'libx265',
            'c:a': 'aac',
            'b:a': 256000,
            'tag:v': "hvc1",
        }

    output = ffmpeg.output(*streams, output_filename, **options)
    return output.overwrite_output().get_args()


def main():
    """
    Command-line entry point.

    For every video file found, detect the clips, optionally show the loudness
    graph, and then either print or (with ``--run``) execute one FFMPEG
    command per clip. Extra arguments given with ``--ffmpeg-args`` are inserted
    directly before the output file name so that they apply to the output.
    """
    import shlex

    parser = argparse.ArgumentParser(
        description="Automatically slice and transcode band rehearsal videos based on audio."
    )
    parser.add_argument(
        "--input-directory",
        help=(
            "The input directory to search for .MP4 or .MTS files. "
            f"If not provided, all files matching {DEFAULT_GLOB_EXPR} will be used."
        )
    )
    parser.add_argument(
        "--output-directory",
        help="The directory to write the rendered clips to.",
        default=".",
    )
    parser.add_argument(
        "--ffmpeg-args",
        help="A sequence of video encoding args to pass to FFMPEG, passed as a single string.",
        default="",
    )
    parser.add_argument(
        "--draft",
        action="store_true",
        help="If passed, render low-quality outputs for testing."
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="If passed, actually call FFMPEG instead of just printing commands."
    )
    parser.add_argument(
        "--graph",
        action="store_true",
        help="If passed, graph the loudness contour that will be used."
    )
    args = parser.parse_args()
    extra_ffmpeg_args = shlex.split(args.ffmpeg_args)

    filenames = list(scan_for_files(args.input_directory))
    if not filenames:
        print("No video files found.")
    if args.run:
        os.makedirs(args.output_directory, exist_ok=True)

    for filename in filenames:
        plt.clf()
        segments = list(identify_clips(filename))
        plt.gca().xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: to_hhmmssms(x)))
        for segment in segments:
            print(
                f"Rendering {filename!r} from {to_hhmmssms(segment[0])} "
                f"to {to_hhmmssms(segment[1])}..."
            )
            plt.axvspan(segment[0], segment[1], alpha=0.25, color='red')
        plt.axhline(THRESHOLD, color="green")
        plt.legend()
        if args.graph:
            plt.show()
        for i, segment in enumerate(segments):
            output_filename = os.path.join(
                args.output_directory, f"{os.path.basename(filename)}-{i}.mp4"
            )
            command = ["ffmpeg"] + render_clip(filename, segment, output_filename, args.draft)
            if extra_ffmpeg_args:
                # FFMPEG applies options to the next file named on the command
                # line, so the extra options must precede the output file name.
                try:
                    position = len(command) - 1 - command[::-1].index(output_filename)
                except ValueError:
                    position = len(command)
                command[position:position] = extra_ffmpeg_args
            if args.run:
                returncode = subprocess.run(command).returncode
                if returncode != 0:
                    print(f"ffmpeg exited with status {returncode} for {output_filename!r}")
            else:
                # Quote the arguments so the printed command can be pasted into a shell.
                print(shlex.join(command))


# Run the command-line interface only when executed as a script, not when imported.
if __name__ == "__main__":
    main()