#!/usr/bin/env python
import sys
import os
import errno
from argparse import ArgumentParser
from collections import defaultdict
import pickle
import logging
import thread
import threading
from Queue import Queue, Empty
import time
import random
import json

from swift.common.storage_policy import POLICIES
from swift.common.ring import Ring
from swift.obj.diskfile import get_async_dir
from swift.common.utils import RateLimitedIterator, split_path

# fix monkey-patch lp bug #1380815
# Point the logging module's internal references to `threading` and `thread`
# at the modules imported by this file, and replace logging's module-level
# lock with a fresh threading.RLock().
# NOTE(review): presumably something imported above (e.g. eventlet pulled in
# via swift) monkey-patches these, which misbehaves when logging is used from
# the real OS threads this script starts; see Launchpad bug #1380815 -- confirm.
logging.threading = threading
logging.thread = thread
logging._lock = threading.RLock()


def _build_parser():
    """Return the command-line parser for this tool.

    The positional ``devices`` argument lists device roots (default
    ``/srv/node``); the options select the storage policy, cap and throttle
    how many asyncs are examined, size the worker pool and choose the
    report format.
    """
    p = ArgumentParser()
    p.add_argument('devices', nargs='*', default=['/srv/node'],
                   help='root of devices tree for node')
    p.add_argument('--policy-index', type=int, default=0,
                   help='the policy index')
    p.add_argument('--limit', type=int, default=None,
                   help='max number of asyncs to check per disk')
    p.add_argument('--updates-per-second', type=float, default=250.0,
                   help='max number of asyncs to check per second')
    p.add_argument('--top-stats', type=int, default=10,
                   help='display N top account & container')
    p.add_argument('--workers', type=int, default=24,
                   help='number of workers')
    p.add_argument('--verbose', action='store_true',
                   help='log at debug')
    p.add_argument('--swift-dir', default='/etc/swift',
                   help='y u no use /etc/swift')
    p.add_argument('--json', action='store_true',
                   help='dump raw json stats')
    return p


parser = _build_parser()


class AtomicStats(object):
    """Named integer counters that several threads can update safely.

    Counters live in ``self.stats`` (a ``defaultdict(int)``), so incrementing
    an unseen key starts it from zero. Updates go through ``incr`` under
    ``self.lock``; iteration yields a snapshot taken under the same lock.
    """

    def __init__(self):
        self.stats = defaultdict(int)
        self.lock = threading.RLock()

    def incr(self, key, amount=1):
        """Add *amount* (default 1) to the counter named *key*."""
        with self.lock:
            self.stats[key] += amount

    def __iter__(self):
        """Iterate over a point-in-time copy of the ``(key, value)`` pairs.

        The copy is taken under the lock so a reader never walks the live
        dict while other threads insert new keys through ``incr``. On
        Python 3, ``dict.items()`` is a live view, and iterating it across
        such an insert raises ``RuntimeError``.
        """
        with self.lock:
            snapshot = list(self.stats.items())
        return iter(snapshot)


# Process-wide counters: written by the consumer threads (via update_stats)
# and read by the stats display thread and by main() for the final report.
STATS = AtomicStats()


def handle_update(update_path, container_ring, args):
    """Load one pickled async pending file and summarize it.

    :param update_path: path of the pickled async pending file
    :param container_ring: ring whose ``get_nodes(account, container)``
                           returns ``(partition, nodes)``
    :param args: parsed command-line args; only ``args.verbose`` is read
    :returns: dict with ``op``, ``account``, ``container``, ``num_success``
              (length of the pending's ``successes`` list, 0 if absent) and
              ``bad_devs`` (``device`` of every node whose ``id`` is not in
              ``successes``)
    """
    # Pickle data is binary: text mode only happens to work on POSIX
    # Python 2 and fails on Python 3.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is acceptable only because these files come from the local
    # device tree -- never point this tool at untrusted data.
    with open(update_path, 'rb') as f:
        update_data = pickle.load(f)
    if args.verbose:
        logging.debug('Found %s\n%s', update_path,
                      json.dumps(update_data, indent=2))
    successes = update_data.get('successes', [])
    container_path = update_data.get('container_path')
    if container_path:
        # A non-empty container_path takes precedence over the
        # account/container keys.
        account, container = split_path('/' + container_path, minsegs=2)
    else:
        account = update_data['account']
        container = update_data['container']
    _part, nodes = container_ring.get_nodes(account, container)
    bad_devs = [node['device'] for node in nodes
                if node['id'] not in successes]
    # NOTE(review): only logged when exactly one node is missing the update,
    # although the ','.join suggests several devices were anticipated --
    # confirm whether `>= 1` was intended.
    if len(bad_devs) == 1:
        logging.debug('Notice %r waiting on update to %s',
                      update_path, ','.join(bad_devs))
    return {
        'op': update_data['op'],
        'account': account,
        'container': container,
        'num_success': len(successes),
        'bad_devs': bad_devs,
    }


def consumer(q, args, ring):
    """Worker loop: summarize async pending paths taken from *q* into STATS.

    Runs until it takes the ``None`` sentinel off the queue. Every path
    taken bumps the ``count`` counter, even if it then fails to load.

    A failure on one path is counted instead of killing the thread. Such
    failures include a file removed between being listed and being read,
    or an unreadable pickle. A dead consumer would stop draining the
    bounded work queue, leaving producers and the shutdown sentinels
    blocked forever.
    """
    while True:
        update_path = q.get()
        if update_path is None:
            return
        STATS.incr('count')
        try:
            update_data = handle_update(update_path, ring, args)
        except Exception as e:
            if getattr(e, 'errno', None) == errno.ENOENT:
                # Listed earlier but gone now; nothing left to inspect.
                logging.debug('%s disappeared before it could be read',
                              update_path)
                STATS.incr('missing')
            else:
                logging.exception('Unable to process %s', update_path)
                STATS.incr('errors')
            continue
        update_stats(STATS, update_data)


def update_stats(stats, update):
    """Record one summarized async pending in *stats*.

    Bumps by one, in this order: ``op_<op>``, ``acct_<account>``,
    ``cont_<account>/<container>``, ``success_<num_success>``, then
    ``dev_<device>`` for each entry of ``update['bad_devs']``.
    """
    def counter_keys():
        # Generated lazily so each key is built right before its incr call.
        yield 'op_%s' % update['op']
        yield 'acct_%s' % update['account']
        yield 'cont_%s/%s' % (update['account'], update['container'])
        yield 'success_%s' % update['num_success']
        for device in update['bad_devs']:
            yield 'dev_%s' % device

    for counter in counter_keys():
        stats.incr(counter)


def _display_stats(stats, args):
    accounts = []
    containers = []
    success_counts = []
    ops = []
    devs = []
    logging.info('=' * 50)
    for k, v in stats:
        if k.startswith('acct_'):
            accounts.append((v, k[5:]))
        elif k.startswith('cont_'):
            containers.append((v, k[5:]))
        elif k.startswith('success_'):
            success_counts.append((k, v))
        elif k.startswith('op_'):
            ops.append((k[3:], v))
        elif k.startswith('dev_'):
            devs.append((v, k[4:]))
        else:
            logging.info('%-9s: %s', k, v)
    for k, v in ops:
        logging.info('%-9s: %s' % (k, v))
    success_counts.sort()
    for k, v in success_counts:
        logging.info('%s: %s', k, v)
    logging.info('-' * 50)
    accounts.sort(reverse=True)
    for v, k in accounts[:args.top_stats]:
        logging.info('%s: %s', k, v)
    containers.sort(reverse=True)
    for v, k in containers[:args.top_stats]:
        logging.info('%s: %s', k, v)
    devs.sort(reverse=True)
    for v, k in devs[:args.top_stats]:
        logging.info('%s: %s', k, v)


def display_stats(q, args):
    """Log the global STATS about once a second until told to stop.

    Stops as soon as an item can be taken from *q* without blocking; the
    check happens before each report, so a pre-filled *q* means no report.
    """
    def kill_requested():
        try:
            q.get(block=False)
        except Empty:
            return False
        return True

    while not kill_requested():
        _display_stats(STATS, args)
        time.sleep(1.0)


def iter_update_paths(device_path, args):
    """Yield paths of async pending files found on one device.

    Scans the async dir of policy ``args.policy_index`` under
    *device_path*. Only subdirectories whose names parse as hexadecimal
    are read. The suffix directories, and the files within each, are
    visited in shuffled order.

    :param device_path: path of one device directory
    :param args: parsed args; reads ``policy_index`` and ``limit`` (at most
                 ``limit`` paths are yielded; a falsy limit means no cap)
    """
    policy = POLICIES[args.policy_index]
    async_path = os.path.join(device_path, get_async_dir(policy))
    try:
        suffixes = os.listdir(async_path)
    except OSError as e:
        # No async dir, or device_path is not a directory (device roots
        # can contain stray files): nothing to scan here.
        if e.errno in (errno.ENOENT, errno.ENOTDIR):
            return
        raise
    random.shuffle(suffixes)
    num_updates = 0
    for suffix in suffixes:
        try:
            int(suffix, 16)
        except ValueError:
            continue
        suffix_path = os.path.join(async_path, suffix)
        try:
            updates = os.listdir(suffix_path)
        except OSError as e:
            # The suffix may have been removed since async_path was listed,
            # or may not be a directory at all.
            if e.errno in (errno.ENOENT, errno.ENOTDIR):
                continue
            raise
        random.shuffle(updates)
        for update in updates:
            # Check before counting so exactly `limit` paths are yielded
            # (the old post-increment check stopped one short).
            if args.limit and num_updates >= args.limit:
                return
            num_updates += 1
            yield os.path.join(suffix_path, update)


def feed_queue(q, device_dir, args):
    """Producer: put each async pending path under *device_dir* onto *q*.

    Paths come from ``iter_update_paths`` wrapped in ``RateLimitedIterator``
    with ``args.updates_per_second`` as the rate; ``q.put`` is called with
    its default (blocking) behaviour.
    """
    throttled_paths = RateLimitedIterator(
        iter_update_paths(device_dir, args),
        args.updates_per_second)
    for path in throttled_paths:
        q.put(path)


def main():
    """Scan async pendings on every device and report aggregate stats.

    Starts a stats-display thread, ``args.workers`` consumer threads and one
    feeder thread per entry of each device root. It then waits for the
    feeders, stops the consumers and the display thread, and finally logs
    the stats (or writes them as JSON to stdout with ``--json``).
    """
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO)

    container_ring = Ring(os.path.join(args.swift_dir, 'container.ring.gz'))

    stats_kill_q = Queue(1)
    stats_worker = threading.Thread(target=display_stats,
                                    args=(stats_kill_q, args))
    stats_worker.start()

    q = Queue(1000)
    workers = []
    feeders = []
    try:
        try:
            for _ in range(args.workers):
                t = threading.Thread(target=consumer,
                                     args=(q, args, container_ring))
                t.start()
                workers.append(t)
            for device_root in args.devices:
                for device_dir in os.listdir(device_root):
                    device_path = os.path.join(device_root, device_dir)
                    u = threading.Thread(target=feed_queue,
                                         args=(q, device_path, args))
                    # If setup fails part-way, consumers are stopped below
                    # while feeders may still be producing. As daemons they
                    # cannot then keep the process alive blocked on a full
                    # queue. On the normal path they are joined anyway.
                    u.daemon = True
                    u.start()
                    feeders.append(u)
            for u in feeders:
                u.join()
        finally:
            logging.info('queue finished')
            # One None sentinel per consumer; each exits on the first it takes.
            for _ in workers:
                q.put(None)
            for t in workers:
                t.join()
        logging.info('workers finished')
    finally:
        # The display thread is non-daemon and loops until signalled, so it
        # must be stopped even when the scan raised, or the process never
        # exits.
        stats_kill_q.put(None)
        stats_worker.join()

    if args.json:
        json.dump(STATS.stats, sys.stdout)
    else:
        _display_stats(STATS, args)


if __name__ == "__main__":
    # main() has no return statement, so sys.exit(None) exits with status 0
    # unless an exception propagates.
    sys.exit(main())