#    This file is part of EAP.
#
#    EAP is free software: you can redistribute it and/or modify
#    it under the terms of the GNU Lesser General Public License as
#    published by the Free Software Foundation, either version 3 of
#    the License, or (at your option) any later version.
#
#    EAP is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#    GNU Lesser General Public License for more details.
#
#    You should have received a copy of the GNU Lesser General Public
#    License along with EAP. If not, see <http://www.gnu.org/licenses/>.

import operator
import math
import random

from deap import algorithms
from deap import base
from deap import creator
from deap import tools
from deap import gp

# Problem-specific primitive functions (added to the primitive set below)
def safeDiv(left, right):
    """Protected division used as a GP primitive.

    Returns ``left / right``; if that raises ``ZeroDivisionError`` the
    result is 0 instead, so evolved expressions never crash on a zero
    denominator.
    """
    try:
        quotient = left / right
    except ZeroDivisionError:
        quotient = 0
    return quotient

# Primitive set named "MAIN" with one input argument (renamed to "x" below).
pset = gp.PrimitiveSet("MAIN", 1)
# Binary arithmetic primitives; safeDiv stands in for "/" so that a zero
# denominator yields 0 instead of raising.
pset.addPrimitive(operator.add, 2)
pset.addPrimitive(operator.sub, 2)
pset.addPrimitive(operator.mul, 2)
pset.addPrimitive(safeDiv, 2)
# Unary primitives.
pset.addPrimitive(operator.neg, 1)
pset.addPrimitive(math.cos, 1)
pset.addPrimitive(math.sin, 1)
# Ephemeral constant terminal whose value is produced by this function:
# an integer drawn from {-1, 0, 1} (randint bounds are inclusive).
pset.addEphemeralConstant(lambda: random.randint(-1,1))
# Refer to the single input as "x" instead of the default "ARG0".
pset.renameArguments({"ARG0" : "x"})

# Single-objective fitness with a negative weight: lower error is better.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
# Individuals are primitive trees. The pset is attached as a class attribute;
# presumably operators that build new subtrees (e.g. mutation) read it from
# the individual -- NOTE(review): confirm against the DEAP version in use.
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin, pset=pset)

toolbox = base.Toolbox()
# Random expression generator for the initial population; min_/max_ bound the
# size of the generated trees.
toolbox.register("expr", gp.genRamped, pset=pset, min_=1, max_=2)
# One individual = one generated expression wrapped in creator.Individual.
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
# population(n=...) returns a list of n freshly generated individuals.
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
# Turns a tree into a Python callable taking the pset's argument(s) (here x);
# used by evalSymbReg.
toolbox.register("lambdify", gp.lambdify, pset=pset)

def evalSymbReg(individual):
    """Evaluate an individual against the target x**4 + x**3 + x**2 + x.

    The tree is compiled with ``toolbox.lambdify`` and evaluated on the 20
    sample points -1.0, -0.9, ..., 0.9 (the upper bound of ``range`` is
    excluded).

    :returns: A 1-tuple holding the sum of squared errors. It is a tuple
        because the FitnessMin weights are a 1-tuple.
    """
    # Transform the tree expression into a callable function.
    func = toolbox.lambdify(expr=individual)
    # ``range`` instead of ``xrange``: identical here on Python 2, and the
    # only spelling that exists on Python 3.
    points = [i / 10. for i in range(-10, 10)]
    diff = sum((func(x) - (x**4 + x**3 + x**2 + x))**2 for x in points)
    return diff,

def selDoubleTournament(individuals, k, fitTournSize, sizeTournSize):
    """Select *k* individuals with a size-first double tournament.

    Every participant of a fitness tournament is itself the winner of a
    size (parsimony) tournament. This applies pressure towards smaller trees
    before fitness is compared.

    :param individuals: Individuals to select from. Each must expose a
        ``size`` attribute and a ``fitness`` attribute supporting ``>``.
    :param k: Number of individuals to select.
    :param fitTournSize: Number of size-tournament winners that compete in
        each fitness tournament.
    :param sizeTournSize: Parsimony pressure. When two aspirants differ in
        size, the smaller one wins with probability ``sizeTournSize / 2``,
        so meaningful values lie in [1, 2] (1 = no size preference,
        2 = the smaller always wins).
    :returns: A list of *k* selected individuals. These are references, not
        copies, and the same individual may appear more than once.
    """
    def _sizeTournament(individuals, tournamentSize):
        # Two aspirants drawn with replacement.
        aspirant1 = random.choice(individuals)
        aspirant2 = random.choice(individuals)
        s1, s2 = aspirant1.size, aspirant2.size

        if s1 == s2:
            # Equal sizes: no parsimony preference, pick either at random.
            return random.choice([aspirant1, aspirant2])
        if s1 < s2:
            smaller, larger = aspirant1, aspirant2
        else:
            smaller, larger = aspirant2, aspirant1
        return smaller if random.random() < tournamentSize / 2. else larger

    chosen = []
    for _ in range(k):
        # The first size-tournament winner is the provisional champion. Each of
        # the remaining fitTournSize - 1 winners replaces it when its fitness
        # compares greater (for DEAP fitness objects, greater means better).
        best = _sizeTournament(individuals, sizeTournSize)
        for _ in range(fitTournSize - 1):
            aspirant = _sizeTournament(individuals, sizeTournSize)
            if aspirant.fitness > best.fitness:
                best = aspirant
        chosen.append(best)
    return chosen
    

    
    
def staticLimitCrossover(ind1, ind2, heightLimit):
    """Apply crossover, then revert any child whose height exceeds the limit.

    Both individuals are modified in place. A child that ends up taller than
    *heightLimit* has its content restored from a copy of the individual it
    was derived from.

    :param ind1: First parent, modified in place into the first child.
    :param ind2: Second parent, modified in place into the second child.
    :param heightLimit: Maximum height allowed for a child.
    :returns: ``(ind1, ind2)``. Algorithms that reassign the result of
        ``toolbox.mate`` need this; callers that ignore the return value are
        unaffected.
    """
    # Store a backup of the original individuals before the in-place crossover.
    keepInd1, keepInd2 = toolbox.clone(ind1), toolbox.clone(ind2)

    # Mate the two individuals
    # If using STGP (like spambase), replace this line by gp.cxTypedOnePoint(ind1, ind2)
    gp.cxUniformOnePoint(ind1, ind2)

    # A child taller than the maximum allowed is replaced by its own parent.
    if ind1.height > heightLimit:
        ind1[:] = keepInd1
    if ind2.height > heightLimit:
        ind2[:] = keepInd2
    return ind1, ind2
    
def staticLimitMutation(individual, expr, heightLimit):
    """Apply uniform mutation, reverting it if the result is too tall.

    The individual is modified in place. If the mutated tree is taller than
    *heightLimit*, its content is restored from a copy taken beforehand.

    :param individual: Individual to mutate in place.
    :param expr: Subtree generator passed on to ``gp.mutUniform``.
    :param heightLimit: Maximum height allowed after mutation.
    :returns: ``(individual,)``. Algorithms that reassign the result of
        ``toolbox.mutate`` need this; callers that ignore the return value
        are unaffected.
    """
    # Store a backup of the original individual before the in-place mutation.
    keepInd = toolbox.clone(individual)

    # Mutate the individual
    # If using STGP (like spambase), replace this line by gp.mutTypedUniform(individual,expr)
    gp.mutUniform(individual, expr)

    # If the mutation made the individual taller than allowed, restore it.
    if individual.height > heightLimit:
        individual[:] = keepInd
    return individual,

toolbox.register("evaluate", evalSymbReg)
# Size-first double tournament: 5 size-tournament winners compete on fitness.
# With sizeTournSize=1.4, the smaller of two differently sized aspirants wins
# with probability 1.4 / 2 = 0.7.
toolbox.register("select", selDoubleTournament, fitTournSize=5, sizeTournSize=1.4)
# Crossover that reverts any child taller than 17.
toolbox.register("mate", staticLimitCrossover, heightLimit=17)
# Subtree generator for mutation. No pset is bound here; presumably
# gp.mutUniform supplies it (e.g. from the individual's pset attribute) --
# NOTE(review): confirm against the DEAP version in use.
toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
# Mutation that reverts any result taller than 17.
toolbox.register('mutate', staticLimitMutation, expr=toolbox.expr_mut, heightLimit=17)


def main():
    """Run the symbolic-regression evolution.

    :returns: ``(population, statistics, hall_of_fame)`` after the run.
    """
    # A fixed seed makes the run reproducible.
    random.seed(318)

    population = toolbox.population(n=300)
    hall_of_fame = tools.HallOfFame(1)

    # Per-generation statistics over the individuals' fitness values.
    fit_stats = tools.Statistics(lambda ind: ind.fitness.values)
    reducers = (("Avg", tools.mean), ("Std", tools.std), ("Min", min), ("Max", max))
    for label, reducer in reducers:
        fit_stats.register(label, reducer)

    # Positional arguments: crossover probability 0.5, mutation probability
    # 0.1, 40 generations, then the statistics object.
    algorithms.eaSimple(toolbox, population, 0.5, 0.1, 40, fit_stats,
                        halloffame=hall_of_fame)

    return population, fit_stats, hall_of_fame

# Run the evolution only when executed as a script, not when imported.
if __name__ == "__main__":
    main()