import os
import sys
import argparse
import numpy as np
from PIL import Image, ImageDraw
import cv2
import time
import json

# Make sure that caffe is on the python path before importing it.
# caffe_root is a relative path, so the entry added to sys.path depends on the
# directory the script is launched from (presumably three levels below the
# Caffe checkout root -- confirm against the repository layout).
caffe_root = '../../..'
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe

class CaffeDetection:
    """Caffe-based (Cascade) R-CNN detector.

    The constructor loads the network and selects the blob names to read for
    every output stage; ``detect`` runs one image through the network and
    returns, per stage, a list of ``[cls_id, x, y, w, h, score]`` rows in
    original-image pixel coordinates.
    """

    def __init__(self, gpu_id, model_def, model_weights,
                 cascade=0, FPN=0, use_soft_nms=0):
        """Load the network and configure detection parameters.

        Args:
            gpu_id: GPU device index; a negative value selects CPU mode.
            model_def: path of the deploy prototxt.
            model_weights: path of the trained caffemodel.
            cascade: > 0 reads the five cascade outputs
                (1st, 2nd, 3rd, 2nd_avg, 3rd_avg) instead of only the 1st.
            FPN: > 0 selects the FPN proposal blob names and input sizes.
            use_soft_nms: > 0 uses ``cpu_soft_nms`` instead of hard NMS.
        """
        if gpu_id < 0:
            caffe.set_mode_cpu()
        else:
            caffe.set_device(gpu_id)
            caffe.set_mode_gpu()

        # Load the net in the test phase for inference.  Input preprocessing
        # (resize, mean subtraction, HWC -> CHW) is done by hand in detect().
        self.net = caffe.Net(model_def,      # defines the structure of the model
                             model_weights,  # contains the trained weights
                             caffe.TEST)     # use test mode (e.g., don't perform dropout)

        self.use_soft_nms = use_soft_nms > 0
        self.cascade = cascade > 0
        self.FPN = FPN > 0
        print("%s %s" % (cascade, FPN))

        proposal_prefix = 'proposals_to_all' if self.FPN else 'proposals'
        if not self.cascade:
            # baseline model: a single detection stage
            self.proposal_blob_names = [proposal_prefix]
            self.bbox_blob_names = ['output_bbox_1st']
            self.cls_prob_blob_names = ['cls_prob_1st']
            self.output_names = ['1st']
        else:
            # cascade-rcnn model: the "_avg" outputs reuse the 2nd/3rd stage
            # boxes and proposals but read the averaged class probabilities.
            self.proposal_blob_names = [proposal_prefix,
                                        proposal_prefix + '_2nd',
                                        proposal_prefix + '_3rd',
                                        proposal_prefix + '_2nd',
                                        proposal_prefix + '_3rd']
            self.bbox_blob_names = ['output_bbox_1st', 'output_bbox_2nd', 'output_bbox_3rd',
                                    'output_bbox_2nd', 'output_bbox_3rd']
            self.cls_prob_blob_names = ['cls_prob_1st', 'cls_prob_2nd', 'cls_prob_3rd',
                                        'cls_prob_2nd_avg', 'cls_prob_3rd_avg']
            self.output_names = ['1st', '2nd', '3rd', '2nd_avg', '3rd_avg']

        self.num_outputs = len(self.proposal_blob_names)
        assert(self.num_outputs == len(self.bbox_blob_names))
        assert(self.num_outputs == len(self.cls_prob_blob_names))
        assert(self.num_outputs == len(self.output_names))

        # detection configuration
        #self.det_thr = 0.001 # threshold for testing
        self.det_thr = 0.3      # score threshold for demo
        self.max_per_img = 100  # max number of detections per image
        self.nms_thresh = 0.5   # IoU threshold of the hard NMS
        # Use the normalised flag so the input size always agrees with the
        # blob names chosen above (a negative FPN value used to disagree).
        if self.FPN:
            self.shortSize = 800
            self.longSize = 1312
        else:
            self.shortSize = 600
            self.longSize = 1000

        self.PIXEL_MEANS = np.array([104, 117, 123], dtype=np.uint8)
        self.num_cls = 80

    def detect(self, image_file):
        """Run the detector on one image file.

        Args:
            image_file: path of an image readable by ``cv2.imread``.

        Returns:
            A list with one entry per configured output stage (see
            ``self.output_names``).  Each entry is a list of
            ``[cls_id, x, y, w, h, score]`` rows, with ``cls_id`` in
            ``1..num_cls`` (0 is background) and at most ``max_per_img`` rows.

        Raises:
            IOError: if the image cannot be read.
        """
        image = cv2.imread(image_file)  # BGR, default is cv2.IMREAD_COLOR 3-channel
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("cannot read image file: %r" % (image_file,))
        orgH, orgW = image.shape[:2]

        # Scale the short side to shortSize, cap each side at longSize, then
        # snap both sides to a multiple of 32.  floor(x + 0.5) rounds halves
        # up on every Python version (the builtin round() differs between
        # Python 2 and Python 3 on .5 values).
        rzRatio = self.shortSize / float(min(orgH, orgW))
        imgH = min(rzRatio * orgH, self.longSize)
        imgW = min(rzRatio * orgW, self.longSize)
        resized_h = int(np.floor(imgH / 32.0 + 0.5)) * 32
        resized_w = int(np.floor(imgW / 32.0 + 0.5)) * 32
        # Height and width are scaled independently, so keep both ratios.
        hwRatios = [resized_h / float(orgH), resized_w / float(orgW)]

        image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
        image = image.astype('float32') - self.PIXEL_MEANS.astype('float32')
        transformed_image = np.transpose(image, (2, 0, 1))  # H W C -> C H W

        # set net to batch size of 1 and load the image
        self.net.blobs['data'].reshape(1, 3, resized_h, resized_w)
        self.net.blobs['data'].data[...] = transformed_image.astype(np.float32, copy=False)

        start = time.time()
        self.net.forward()
        end = time.time()
        cost_millis = int((end - start) * 1000)
        print("detection cost ms:  %d" % cost_millis)

        detect_final_boxes = []
        for nn in range(self.num_outputs):
            # Copy because the rows are rescaled and clipped in place.
            boxes = self.net.blobs[self.bbox_blob_names[nn]].data.copy()
            print("%s %s" % (self.bbox_blob_names[nn], boxes.shape))
            boxes = boxes[:, :, 0, 0]
            # Columns 1..4 are x1, y1, x2, y2 in resized-image coordinates;
            # map them back to the original image.
            boxes[:, 1] /= hwRatios[1]
            boxes[:, 3] /= hwRatios[1]
            boxes[:, 2] /= hwRatios[0]
            boxes[:, 4] /= hwRatios[0]

            # Clip boxes to the image borders.
            # NOTE(review): the upper bounds are orgW/orgH rather than
            # orgW-1/orgH-1 -- kept as in the original; confirm intent.
            boxes[:, 1] = np.maximum(0, boxes[:, 1])
            boxes[:, 2] = np.maximum(0, boxes[:, 2])
            boxes[:, 3] = np.minimum(orgW, boxes[:, 3])
            boxes[:, 4] = np.minimum(orgH, boxes[:, 4])
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + 1  # w
            boxes[:, 4] = boxes[:, 4] - boxes[:, 2] + 1  # h
            output_bboxs = boxes[:, 1:]

            cls_prob = self.net.blobs[self.cls_prob_blob_names[nn]].data
            cls_prob = cls_prob.reshape((-1, self.num_cls + 1))

            # Proposals are only used to drop rows whose proposal is degenerate.
            proposals = self.net.blobs[self.proposal_blob_names[nn]].data.copy()
            proposals = proposals[:, 1:]
            proposals[:, 2] = proposals[:, 2] - proposals[:, 0] + 1  # w
            proposals[:, 3] = proposals[:, 3] - proposals[:, 1] + 1  # h
            valid = np.where((proposals[:, 2] > 0) & (proposals[:, 3] > 0))[0]
            output_bboxs = output_bboxs[valid, :]
            cls_prob = cls_prob[valid, :]

            detect_boxes = []
            for cls_id in range(1, self.num_cls + 1):  # 0 is background
                prob = cls_prob[:, cls_id][:, np.newaxis]
                bbset = np.hstack([output_bboxs, prob])  # x, y, w, h, score
                if self.det_thr > 0:
                    confident = np.where(prob >= self.det_thr)[0]
                    bbset = bbset[confident, :]

                if self.use_soft_nms:
                    # Reorders bbset and decays its scores in place; `keep`
                    # indexes the modified array.
                    keep = self.cpu_soft_nms(bbset, sigma=0.5, Nt=0.30,
                                             threshold=0.01, method=1)
                else:
                    keep = self.cpu_nms_single_cls(bbset, self.nms_thresh)
                if len(keep) == 0:
                    continue
                bbset = bbset[keep, :]
                cls_ids = np.array([cls_id] * len(bbset))[:, np.newaxis]
                detect_boxes.extend(np.hstack([cls_ids, bbset]).tolist())

            print("detected box num:  %d" % len(detect_boxes))
            detect_boxes = self._keep_top_scoring(np.asarray(detect_boxes))
            detect_final_boxes.append(detect_boxes.tolist())

        return detect_final_boxes

    def _keep_top_scoring(self, detect_boxes):
        """Keep at most ``self.max_per_img`` rows with the highest score.

        ``detect_boxes`` is an array of ``[cls_id, x, y, w, h, score]`` rows
        (or an empty array).  Surviving rows keep their original relative
        order.  A non-positive ``max_per_img`` disables the limit.
        """
        if self.max_per_img <= 0 or len(detect_boxes) <= self.max_per_img:
            return detect_boxes
        # Stable sort on the negated score: highest scores first, ties keep
        # their original order.
        top = np.argsort(-detect_boxes[:, 5], kind='mergesort')[:self.max_per_img]
        return detect_boxes[np.sort(top), :]

    def cpu_nms_single_cls(self, dets, thresh):
        """Pure Python hard NMS for one class.

        Args:
            dets: array of ``[x, y, w, h, score]`` rows.
            thresh: boxes whose IoU with a kept box exceeds this are dropped.

        Returns:
            List of kept row indices into ``dets``, highest score first.
        """
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        widths = dets[:, 2]
        heights = dets[:, 3]
        scores = dets[:, 4]

        x2 = x1 + widths - 1
        y2 = y1 + heights - 1
        areas = widths * heights
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            best = order[0]
            rest = order[1:]
            keep.append(best)

            # Intersection of the best box with every remaining box.
            xx1 = np.maximum(x1[best], x1[rest])
            yy1 = np.maximum(y1[best], y1[rest])
            xx2 = np.minimum(x2[best], x2[rest])
            yy2 = np.minimum(y2[best], y2[rest])
            inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
            ovr = inter / (areas[best] + areas[rest] - inter)

            # +1 because `ovr` is indexed relative to order[1:].
            order = order[np.where(ovr <= thresh)[0] + 1]

        return keep

    def cpu_soft_nms(self, boxes, sigma=0.5, Nt=0.3, threshold=0.001, method=0):
        """Soft-NMS on ``[x, y, w, h, score]`` rows, modifying ``boxes`` in place.

        Rows are reordered by selection order, scores of overlapping boxes are
        decayed, and boxes whose score drops below ``threshold`` are moved out
        of the kept prefix.  The geometry columns stay in x, y, w, h format.

        Args:
            boxes: float array of ``[x, y, w, h, score]`` rows (modified).
            sigma: bandwidth of the gaussian decay (``method == 2``).
            Nt: IoU threshold for the linear decay and the hard-NMS fallback.
            threshold: minimum score a decayed box needs to survive.
            method: 1 = linear decay, 2 = gaussian decay, else hard NMS.

        Returns:
            ``list(range(N))`` where ``N`` is the number of surviving boxes;
            the survivors are the first ``N`` rows of the modified ``boxes``.
        """
        N = boxes.shape[0]
        if N == 0:
            return []

        # Work on corner coordinates (x1, y1, x2, y2) throughout; mixing them
        # with width/height columns produced wrong overlaps and corrupted the
        # caller's array.  Converted back before returning.
        boxes[:, 2] += boxes[:, 0] - 1
        boxes[:, 3] += boxes[:, 1] - 1

        i = 0
        while i < N:
            # Move the highest-scoring remaining box to position i
            # (argmax returns the first maximum, like a strict '<' scan).
            maxpos = i + int(np.argmax(boxes[i:N, 4]))
            if maxpos != i:
                boxes[[i, maxpos]] = boxes[[maxpos, i]]

            tx1, ty1, tx2, ty2 = boxes[i, 0], boxes[i, 1], boxes[i, 2], boxes[i, 3]
            t_area = (tx2 - tx1 + 1) * (ty2 - ty1 + 1)

            pos = i + 1
            # N shrinks whenever a box falls below the score threshold.
            while pos < N:
                x1, y1, x2, y2 = boxes[pos, 0], boxes[pos, 1], boxes[pos, 2], boxes[pos, 3]
                iw = min(tx2, x2) - max(tx1, x1) + 1
                ih = min(ty2, y2) - max(ty1, y1) + 1
                if iw > 0 and ih > 0:
                    area = (x2 - x1 + 1) * (y2 - y1 + 1)
                    ua = float(t_area + area - iw * ih)
                    ov = iw * ih / ua  # IoU between the selected box and this box

                    if method == 1:    # linear
                        weight = 1 - ov if ov > Nt else 1
                    elif method == 2:  # gaussian
                        weight = np.exp(-(ov * ov) / sigma)
                    else:              # original NMS
                        weight = 0 if ov > Nt else 1

                    boxes[pos, 4] = weight * boxes[pos, 4]

                    # Discard a box whose score fell below the threshold by
                    # overwriting it with the last live box, then re-examine
                    # the same position.
                    if boxes[pos, 4] < threshold:
                        boxes[pos, :5] = boxes[N - 1, :5]
                        N = N - 1
                        pos = pos - 1

                pos = pos + 1
            i += 1

        # Back to x, y, w, h for the caller.
        boxes[:, 2] -= boxes[:, 0] - 1
        boxes[:, 3] -= boxes[:, 1] - 1

        return list(range(N))

def bbox2cocoVec(image_id, results, catIds):
    """Convert the last detection stage of one image into COCO result rows.

    Args:
        image_id: COCO image id stored in every output row.
        results: per-stage detection lists; only the last stage
            (``results[-1]``) is used.  Each detection is
            ``[cls_id, x, y, w, h, score]`` with a 1-based ``cls_id``
            (0 is reserved for background).
        catIds: 0-based list mapping class ``k`` (1-based) to its COCO
            category id at index ``k - 1``.

    Returns:
        List of ``[image_id, x, y, w, h, score, category_id]`` rows with
        ``x`` and ``y`` rounded to the nearest integer.

    Raises:
        IndexError: if a ``cls_id`` is outside ``1..len(catIds)``.
    """
    bbox_list = []
    for item in results[-1]:  # the last (e.g. 3rd_avg) result
        cls_id = int(item[0])
        if not 1 <= cls_id <= len(catIds):
            raise IndexError("class id %d outside 1..%d" % (cls_id, len(catIds)))
        # cls_id is 1-based while catIds is a 0-based list; indexing with
        # cls_id directly shifts every category by one and overflows on the
        # last class.
        cat_id = catIds[cls_id - 1]
        xmin = round(item[1])
        ymin = round(item[2])
        bbox_list.append([image_id, xmin, ymin, item[3], item[4], item[5], cat_id])
    return bbox_list

def demo(args):
    """Detect objects in ``args.image_file`` and save an annotated copy.

    Builds a ``CaffeDetection`` from ``args.gpu_id``, ``args.model_def``,
    ``args.model_weights``, ``args.cascade`` and ``args.FPN``, runs it on the
    image, draws the boxes of the last output stage with their class id,
    prints one ``[cls_id, xmin, ymin, xmax, ymax, score]`` line per box and
    writes the result to ``detect_result.jpg`` in the current directory.
    """
    detection = CaffeDetection(args.gpu_id,
                               args.model_def, args.model_weights,
                               cascade=args.cascade, FPN=args.FPN)
    results = detection.detect(args.image_file)

    # Convert to RGB so the colour tuples below and the JPEG save also work
    # for grayscale, palette or RGBA inputs.
    img = Image.open(args.image_file).convert('RGB')
    draw = ImageDraw.Draw(img)
    for item in results[-1]:  # the last (e.g. 3rd_avg) result
        cls_id = int(item[0])
        # Detections are x, y, w, h; the drawing API wants two corners.
        xmin = int(round(item[1]))
        ymin = int(round(item[2]))
        xmax = int(round(item[1] + item[3] - 1))
        ymax = int(round(item[2] + item[4] - 1))
        draw.rectangle([xmin, ymin, xmax, ymax], outline=(255, 0, 0))
        draw.text([xmin, ymin], str(cls_id), (0, 0, 255))
        # Single-argument print call: identical output on Python 2 and 3.
        print([cls_id, xmin, ymin, xmax, ymax, round(item[-1] * 1000) / 1000.0])

    img.save('detect_result.jpg')

def test_coco(args):
    """Run the detector over COCO val2017 and dump the results as JSON.

    Reads annotations from ``coco/annotations/instances_val2017.json`` and
    images from ``coco/images/val2017`` (both relative to the current
    directory), runs detection on every image in ascending image-id order and
    writes the COCO-format result list to ``args.out_file``.
    """
    # local import
    from pycocotools.coco import COCO
    # COCO category id of each of the 80 classes, in class order.
    coco_catIds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
              23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
              46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
              65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    dataDir = 'coco'
    dataType = 'val2017'
    annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataType)
    image_base_path = '{}/images/{}'.format(dataDir, dataType)

    # initialize COCO api for instance annotations
    cocoGT = COCO(annFile)
    imgIds = cocoGT.getImgIds()

    # model define
    detection = CaffeDetection(args.gpu_id,
                               args.model_def, args.model_weights,
                               cascade=args.cascade, FPN=args.FPN)

    res_list = []
    for i, imgId in enumerate(sorted(imgIds), 1):
        img = cocoGT.loadImgs(imgId)[0]
        img_path = image_base_path + '/' + img['file_name']

        # inference
        results = detection.detect(img_path)
        # extend() appends in place; `res_list = res_list + ...` copied the
        # whole accumulated list for every image.
        res_list.extend(bbox2cocoVec(imgId, results, coco_catIds))
        if i % 100 == 0:
            print('--------------- ' + str(i) + ' ---------------')

    # reshape(-1, 7) keeps the array two-dimensional even when nothing was
    # detected; loadNumpyAnnotations checks for seven columns.
    res_array = np.asarray(res_list).reshape(-1, 7)
    with open(args.out_file, 'w') as f:
        json.dump(cocoGT.loadNumpyAnnotations(res_array), f)

def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options (all optional): --gpu_id, --model_def, --cascade, --FPN,
    --model_weights, --image_file and --out_file.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        ('--gpu_id', dict(type=int, default=0, help='gpu id')),
        ('--model_def', dict(default='models/deploy.prototxt')),
        ('--cascade', dict(default=0, type=int)),
        ('--FPN', dict(default=0, type=int)),
        ('--model_weights', dict(default='models/models_iter_120000.caffemodel')),
        ('--image_file', dict(default='')),
        ('--out_file', dict(default='cascadercnn_coco_result.json')),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == '__main__':
    # A non-empty --image_file selects the single-image demo; otherwise the
    # COCO evaluation is run.
    args = parse_args()
    entry_point = test_coco if args.image_file == '' else demo
    entry_point(args)