# Test the previous algorithm against the samples loaded from the CSV below.

# Load the sample table from a CSV hosted in a GitHub gist. Its 'direction'
# and 'tau' columns are read row by row when the simulation is run.
# NOTE(review): `pd` is presumably pandas, imported elsewhere -- confirm.
actuals = pd.read_csv("https://gist.githubusercontent.com/csaid/a57c4ebaa1c7b0671cdc9692638ea4c4/raw/ad1709938834d7bc88b62ff0763733502eb6a329/shower_problem_tau_samples.csv")


# Step size: the amount added to an option's explored total and to the elapsed
# time on each simulation step, and the look-ahead interval used when scoring
# the two options.
DELTA = 0.1

def survival_function(t, lambda_=50., rho=1.5):
    """Return the Weibull survival probability ``exp(-(t / lambda_) ** rho)``.

    Args:
        t: Elapsed amount at which to evaluate the survival function.
        lambda_: Weibull scale parameter.
        rho: Weibull shape parameter.

    Returns:
        The value of ``exp(-(t / lambda_) ** rho)`` as computed by ``np.exp``.
    """
    # Simple Weibull model: rescale first, then apply the shape exponent.
    scaled = t / lambda_
    exponent = -(scaled ** rho)
    return np.exp(exponent)


def w(t1, t2):
    """Return the share of option 1 in the combined survival of both options.

    Computes ``S(t1) / (S(t1) + S(t2))`` where ``S`` is ``survival_function``
    evaluated with its default parameters.

    The original note describes this value as "equal to Pr(X = t1)".
    NOTE(review): it reads like the weight placed on option 1 after t1 and t2
    have been explored without success -- confirm the intended interpretation.

    Args:
        t1: Amount explored so far for option 1.
        t2: Amount explored so far for option 2.

    Returns:
        The ratio described above.
    """
    surv_1 = survival_function(t1)
    surv_2 = survival_function(t2)
    return surv_1 / (surv_1 + surv_2)


def determine_best_action(current_position, t1, t2):
    """Pick which option (1 or 2) to spend the next ``DELTA`` step on.

    Each option gets a score: its weight from ``w`` multiplied by
    ``1 - S(t + DELTA) / S(t)``, i.e. the chance, under the survival model,
    that the event occurs within the next ``DELTA`` given it has not occurred
    by ``t``. The score of the option we are *not* currently on is divided by
    ``max(t, 1)`` of that option before the comparison.
    NOTE(review): that divisor looks like a switching penalty that grows with
    the amount already explored on the other option -- confirm the rationale.

    Args:
        current_position: Option currently occupied; 1 selects the first
            comparison rule, any other value selects the second.
        t1: Amount explored so far for option 1.
        t2: Amount explored so far for option 2.

    Returns:
        1 if option 1 strictly wins the comparison, otherwise 2 (ties and
        NaN comparisons therefore resolve to 2).
    """
    weight_1 = w(t1, t2)
    step_chance_1 = 1 - survival_function(t1 + DELTA) / survival_function(t1)
    step_chance_2 = 1 - survival_function(t2 + DELTA) / survival_function(t2)

    p1 = weight_1 * step_chance_1
    p2 = (1 - weight_1) * step_chance_2

    if current_position == 1:
        # Staying on 1 is compared against a discounted score for 2.
        prefers_first = p1 > p2 / max(t2, 1)
    else:
        # Moving to 1 is discounted and compared against staying on 2.
        prefers_first = p1 / max(t1, 1) > p2

    return 1 if prefers_first else 2



def minimum_time_needed(actual_direction, actual_tau):
    """Simulate the ``determine_best_action`` policy and return elapsed time.

    The simulation starts on option 1 and advances in ``DELTA`` steps until
    the explored total for the true option reaches ``actual_tau``.

    Args:
        actual_direction: Identifies the true option. It is converted with
            ``int()`` and used directly as a 0-based index into the explored
            totals, whereas positions inside the loop are 1-based.
        actual_tau: Explored total the true option must reach to stop.

    Returns:
        The accumulated time when the stopping condition is first met. The
        loop has no other exit.

    NOTE(review): the initial step on option 1 is never added to the elapsed
    time, and a switch adds the new option's total *after* it has been
    incremented plus a further ``DELTA`` -- confirm this accounting is
    intended.
    """
    covered = [0.00, 0.00]
    elapsed = 0.00

    # Always begin on option 1 and explore one step there.
    position = 1
    covered[position - 1] += DELTA

    while True:
        last_position = position
        position = 1 if determine_best_action(position, *covered) == 1 else 2
        slot = position - 1

        covered[slot] += DELTA

        if position != last_position:
            # On a switch, charge the total already covered on the new option.
            elapsed += covered[slot]

        elapsed += DELTA

        if covered[int(actual_direction)] >= actual_tau:
            return elapsed

# Run the simulation once per row of `actuals` (axis=1 hands each row to the
# lambda), storing the returned time in a new 'time_spent' column.
actuals['time_spent'] = actuals.apply(lambda s: minimum_time_needed(s['direction'], s['tau']) , axis=1)
# Mean simulated time over all rows. As a bare expression its value is only
# shown when run interactively (e.g. as the last line of a notebook cell).
actuals['time_spent'].mean()