#!/usr/bin/python3

# EA Spend - Effective Altruism spending calculator.
# Copyright (C) 2016-2017 Gordon Irlam
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

'''
Description:

    EA Spend computes the optimal variable rate of consumption for a
    portfolio growing at a variable rate in the presence of uncertain
    life expectancy. It optimizes utility from direct donations, from
    an estate, as a result of gaining knowledge by learning over time,
    as a result of experience based on prior donations, and self
    satisfaction.

    In addition to a textual summary, the script produces a number of
    output files that describe the scenario and the optimal strategy,
    and which are suitable for plotting with GNUPlot or other plotting
    programs.  The most important of these files is
    spend-<scenario>-consume_pct.csv.  The values in this file are
    age, mean donation pct, median donation pct, 2.5 percentile
    donation pct, and 97.5 percentile donation pct, all as a
    percentage of the total mean donation.

Version:

    0.3.

Performance:

    The script is computationally intensive.  A majority of the
    scenarios compute the optimal strategy in the presence of
    experience gained based on prior donations. This takes roughly
    consume_steps times longer than when experience isn't
    present. consume_steps is typically in the range 100-500.

    It is recommended that this script be run using the pypy Python
    just in time compiler for an approximately 9 fold performance gain
    over the standard cpython.

    Memory demands are significant, but in no way prohibitive. Running
    multiple scenarios only requires having sufficient RAM for the
    largest scenario.

    The time taken to run a single scenario is shown below.

        EC2 instance  wealth steps  consume steps  returns  scenario   interpreter     time     memory

          c4.large        500            500          10    knowledge     PyPy       8 minutes  700 Mb
          c4.large        500            500          10    knowledge  cpython 3.x  78 minutes  400 Mb

    The script currently comprises approximately 20 scenarios.
'''

from __future__ import division, print_function

from math import ceil, exp, floor, isnan, log, pi, sqrt
from random import lognormvariate, seed

# Set to True to print progress messages.
trace = False

# Set to False for a quick, lower resolution run.
accurate_run = True

# Grid resolution and Monte Carlo sample count, chosen by accurate_run:
#   wealth_steps  - wealth grid size when a consume_prev grid is also present.
#   consume_steps - consume_prev grid size. (500 produced a bad consume_pct plot for
#                   all_uncorr_10000; the cause is not known.)
#   num_evaluate  - number of random return sequences generated when evaluating a map.
wealth_steps, consume_steps, num_evaluate = (500, 200, 100000) if accurate_run else (100, 100, 20000)

year_steps = 1 # Number of periods per year.
wealth_alone_steps = 2000 # Wealth grid size when there is no consume_prev grid.
num_returns = 10 # Number of representative return values used per period.
out_of_range = 'extrapolate' # Out of range map lookups: 'truncate', 'extrapolate', or 'exception'.

def mean(s):
    '''Return the arithmetic mean of the sized sequence s.'''
    total = sum(s)
    return total / len(s)

def pctl(p, s):
    '''Return the element of s found a fraction p (0 to 1) of the way through its sorted order.

    The position is rounded to the nearest index; no interpolation is performed.
    '''
    ordered = sorted(s)
    position = int(round(p * (len(s) - 1)))
    return ordered[position]

def median(s):
    '''Return the middle element of s, as selected by pctl().'''
    return pctl(0.5, s)

def stdev(s):
    '''Return the population standard deviation of s (dividing by len(s), not len(s) - 1).'''
    center = mean(s)
    squared_deviations = [(x - center) ** 2 for x in s]
    return sqrt(mean(squared_deviations))

class ReturnsLogNormal(object):
    '''
    Lognormally distributed gross investment returns.

    We would like to simply call scipy.stats.lognorm.pdf, cdf, and ppf().

    Unfortunately pypy can't use scipy. Hence we need to compute it.

    The constructor converts am and sd (presumably the arithmetic mean and
    standard deviation of the annual gross return - confirm against callers)
    into log-space parameters mu and sigma. It then builds a discrete
    approximation of a single period's (1 / year_steps years) return:
    self.sample holds the representative return values and self.weight their
    probabilities, which sum to 1.

    Reads the module level globals num_returns and year_steps.
    '''

    def pdf(self, x):
        '''Lognormal probability density at x for one year (log-space mean mu, standard deviation sigma).'''
        return exp(- (log(x) - self.mu) ** 2 / (2 * self.sigma ** 2)) / (x * self.sigma * sqrt(2 * pi))

    def density(self, x, year):
        '''Probability density at x of the cumulative return over `year` years.

        Uses log-space mean mu * year and variance sigma ** 2 * year. year may be
        fractional, e.g. 1 / year_steps for a single period.
        '''
        return exp(- (log(x) - self.mu * year) ** 2 / (2 * year * self.sigma ** 2)) / (x * sqrt(year) * self.sigma * sqrt(2 * pi))

    def phi(self, x):
        '''Standard normal cumulative distribution function Phi(x).

        Computed as (1 + sign(x) * erf(|x| / sqrt(2))) / 2, with erf from the
        Abramowitz and Stegun 7.1.26 polynomial approximation.
        '''
        # constants
        a1 =  0.254829592
        a2 = -0.284496736
        a3 =  1.421413741
        a4 = -1.453152027
        a5 =  1.061405429
        p  =  0.3275911

        # Save the sign of x
        sign = 1
        if x < 0:
            sign = -1
        x = abs(x) / sqrt(2)

        # A&S formula 7.1.26
        t = 1 / (1 + p * x)
        y = 1 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(- x * x)

        return 0.5 * (1 + sign * y)

    def cdf(self, x):
        '''Lognormal cumulative distribution function at x for one year.'''

        return self.phi((log(x) - self.mu) / self.sigma)

    # def cummulative_density(self, x, year):
    #
    #     return self.phi((log(x) - self.mu * year) / (sqrt(year) * self.sigma))

    def normal_cdf_inverse(self, p):
        '''Approximate inverse of the standard normal CDF (the normal quantile function).

        Uses Abramowitz and Stegun formula 26.2.23. p must satisfy 0 < p < 1
        (checked with assert).
        '''

        def rational_approximation(t):

            # Abramowitz and Stegun formula 26.2.23.
            # The absolute value of the error should be less than 4.5e-4.
            c = [2.515517, 0.802853, 0.010328]
            d = [1.432788, 0.189269, 0.001308]
            numerator = (c[2] * t + c[1]) * t + c[0]
            denominator = ((d[2] * t + d[1]) * t + d[0]) * t + 1
            return t - numerator / denominator

        assert(0 < p < 1)

        # rational_approximation(sqrt(-2 * log(q))) approximates the quantile whose
        # upper tail probability is q, for q <= 0.5. Each branch below passes the
        # smaller tail probability and uses the symmetry of the normal distribution.
        if p < 0.5:
            # F^-1(p) = - G^-1(p)
            return - rational_approximation(sqrt(-2 * log(p)))
        else:
            # F^-1(p) = G^-1(1-p)
            return rational_approximation(sqrt(-2 * log(1 - p)))

    def cdf_inverse(self, p, year):
        '''Lognormal quantile p for one year.

        NOTE(review): the year argument is ignored. The result equals
        multiplicative_quantile(p, 1) whatever year is passed. Use
        multiplicative_quantile() when the horizon matters.
        '''
        return exp(self.mu + self.sigma * self.normal_cdf_inverse(p))

    def multiplicative_quantile(self, p, year):
        '''Quantile p of the cumulative gross return over `year` (possibly fractional) years.'''
        return exp(self.mu * year + self.sigma * sqrt(year) * self.normal_cdf_inverse(p))

    def __init__(self, am = 1.065, sd = 0.222, min_cdf = 1e-3):
        '''
        am, sd: arithmetic mean and standard deviation used to derive mu and sigma.
        min_cdf: tail probability excluded at each end of the discretized return range.
        '''
        self.gm = am / sqrt(1 + (sd / am) ** 2)
        self.mu = log(self.gm)
        self.sigma = sqrt(log(1 + (sd / am) ** 2))
        if self.sigma == 0 or num_returns == 1:
            # Degenerate case: a single return value of gm per year, scaled to one period.
            self.weight = (1, )
            self.sample = (self.gm ** (1 / year_steps), )
        else:
            # num_returns evenly spaced return values between the min_cdf and 1 - min_cdf
            # quantiles of one period's return, weighted by density * spacing.
            lo = self.multiplicative_quantile(min_cdf, 1 / year_steps)
            hi = self.multiplicative_quantile(1 - min_cdf, 1 / year_steps)
            self.weight = []
            self.sample = []
            step = (hi - lo) / (num_returns - 1)
            for i in range(num_returns):
                x = lo + i * step
                self.weight.append(self.density(x, 1 / year_steps) * step)
                self.sample.append(x)
            # Commented out since for high sigma, lo - step / 2, may be less than zero.
            #self.weight[0] += self.cummulative_density(lo - step / 2, 1 / year_steps)
            #self.weight[-1] += 1 - self.cummulative_density(hi + step / 2, 1 / year_steps)
            # Renormalize so the discrete probabilities sum to 1.
            sum_weight = sum(self.weight)
            self.weight = tuple(w / sum_weight for w in self.weight)
            self.sample = tuple(self.sample)

    def random(self):
        '''Draw one period's (1 / year_steps years) gross return using the random module.'''
        if self.sigma == 0:
            return self.gm ** (1 / year_steps) # Speedup.
        else:
            return lognormvariate(self.mu / year_steps, self.sigma / sqrt(year_steps))

class ReturnsNone(ReturnsLogNormal):

    '''Riskless, zero growth returns: arithmetic mean 1 with zero standard deviation.'''

    def __init__(self):
        super(ReturnsNone, self).__init__(sd = 0, am = 1)

class Life(object):

    '''Base class for mortality models; the subclasses here each provide q(age).'''

class LifeFixed(Life):

    '''Fixed lifespan: q(age) is 0 until the last period before age dead, and 1 from then on.'''

    def __init__(self, dead = 100):
        self.dead = dead

    def q(self, age):
        '''Return 0 before the final period preceding age dead, otherwise 1.'''
        period = round(age * year_steps)
        last_period = self.dead * year_steps - 1
        return 0 if period < last_period else 1

class LifeGM(Life):

    '''
    Gompertz-Makeham mortality.

    Adjust m to change Gompertz-Makeham life expectancy.

    b from http://www.demographic-research.org/volumes/vol32/36/32-36.pdf table 5 average of USA male/female values.

    m chosen to match average male/female cohort mortality from Social Security Administration Actuarial Study No. 120 for 25 year old in 2016 of 80.7/84.6
    computed using https://www.aacalc.com/calculators/le .

    q(age) clamps the hazard alpha + exp((age - m) / b) / b to [0, 1], treats it as an
    annual death probability, and converts it to a per-period (1 / year_steps year) probability.
    '''

    def __init__(self, alpha = 0, m = 88.4, b = 10.6):
        self.alpha, self.m, self.b = alpha, m, b

    def q(self, age):
        annual = max(0, min(self.alpha + exp((age - self.m) / self.b) / self.b, 1))
        survive = 1 - annual
        return 1 - survive ** (1 / year_steps)

class UtilityBase(object):

    '''Common base class of the utility models; subclasses provide create_function().'''

class UtilityCrra(UtilityBase):

    '''Constant relative risk aversion utility, discounted by year and shifted so that u(0) == 0.'''

    def __init__(self, weight = 0.1, c_uncorr = 10000, magnify = 1, gamma = 2, r = 0):
        '''
        weight: overall scale (the utility is multiplied by weight / magnify).
        c_uncorr: amount added to magnify * c before the CRRA curve is applied
            (presumably consumption from uncorrelated sources - confirm).
        magnify: multiplier applied to c.
        gamma: curvature; 0 gives linear utility and 1 logarithmic utility.
        r: annual discount rate, applied as (1 + r) ** year.
        '''
        self.weight = weight
        self.c_uncorr = c_uncorr
        self.magnify = magnify
        self.gamma = gamma
        self.r = r
        self.linear = gamma == 0

    def create_function(self, year, consume_max):
        '''Return the single-argument utility function u(c) for the given year.

        consume_max is accepted for interface compatibility but is not used.
        '''
        scale = self.weight / (self.magnify * (1 + self.r) ** year)
        shift = [0] # u(0) baseline; filled in once the functional form has been chosen.

        def linear(c):
            return scale * c

        def logarithmic(c):
            return scale * log(self.c_uncorr + self.magnify * c) - shift[0]

        def power(c):
            exponent = 1 - self.gamma
            try:
                return scale * (self.c_uncorr + self.magnify * c) ** exponent / exponent - shift[0]
            except OverflowError:
                raise Exception('Floating point range exceeded. Try reducing gamma.')

        chosen = linear if self.gamma == 0 else (logarithmic if self.gamma == 1 else power)
        # Evaluated while shift[0] is still 0, so this is the unshifted u(0).
        shift[0] = chosen(0)
        return chosen

class UtilityDiscount(UtilityCrra):

    '''CRRA utility with discounting defaults: weight 1, c_uncorr 100000, gamma 2, r 0.03.'''

    def __init__(self, weight = 1, c_uncorr = 100000, gamma = 2, r = 0.03):
        super(UtilityDiscount, self).__init__(gamma = gamma, r = r, weight = weight, c_uncorr = c_uncorr)

class UtilityLinear(UtilityDiscount):

    '''Linear (gamma == 0) discounted utility.'''

    def __init__(self, weight = 1, r = 0.03):
        super(UtilityLinear, self).__init__(r = r, gamma = 0, weight = weight)

class UtilityNone(UtilityLinear):

    '''A utility component that is switched off: linear utility with zero weight.'''

    def __init__(self):
        super(UtilityNone, self).__init__(weight = 0)

class KnowledgeFactor(object):

    '''
    Base class for cumulative knowledge multipliers.

    Subclasses define knowledge_factor(year), the multiplier for a single year;
    create_function() exposes the cached product of those yearly factors.
    '''

    def __init__(self):
        self.knowledge_factor_hash = {} # year -> cumulative knowledge multiplier.

    def create_function(self, year):
        '''Return a zero-argument function giving the cumulative knowledge multiplier at year.

        The multiplier is the product of knowledge_factor(i) over the whole years
        i < int(year), times knowledge_factor(year) ** (year % 1). Results are cached
        per year in knowledge_factor_hash.
        '''
        cache = self.knowledge_factor_hash
        if year not in cache:
            cumulative = 1
            for whole_year in range(int(year)):
                cumulative *= self.knowledge_factor(whole_year)
            cumulative *= self.knowledge_factor(year) ** (year % 1)
            cache[year] = cumulative
        return lambda: self.knowledge_factor_hash[year]

class KnowledgeFactorDecay(KnowledgeFactor):

    '''Yearly knowledge factor that starts at initial and decays toward 1 with half life half_life years.'''

    def __init__(self, initial = 1.2, half_life = 5):
        super(KnowledgeFactorDecay, self).__init__()
        self.initial = initial
        self.half_life = half_life

    def knowledge_factor(self, year):
        '''Return 1 + (initial - 1) * 2 ** (-year / half_life), computed with exp and log.'''
        excess = self.initial - 1
        return 1 + excess * exp(- year / self.half_life * log(2))

class KnowledgeNone(KnowledgeFactorDecay):

    '''No knowledge effect: initial is 1, so every yearly knowledge factor is 1.'''

    def __init__(self):
        super(KnowledgeNone, self).__init__(initial = 1)

class ExperiencePower(object):

    '''
    Common sense dictates that we want utility(c) * experience(c, c_prev) to be monotone in c.

    This is true for this class if utility is linear, but is not guaranteed to work if utility is not.

    The experience multiplier depends only on the ratio c / c_prev (and on the
    parameters fixed in __init__):
    - it is min_factor when c_prev == 0;
    - it is max_factor when c == 0 and c_prev != 0;
    - it is max_factor once (1 - min_factor) * (ratio_power - 1) reaches max_factor - 1;
    - otherwise it follows a power law in the ratio with exponent power, or a
      logarithmic form when power == -1.
    '''

    def __init__(self, min_factor = 0, max_factor = 1.2, capacity_increase = 0.3, power = -0.2):
        '''
        min_factor, max_factor: bounds described in the class docstring (min_factor <= max_factor).
        capacity_increase: growth in c relative to c_prev that is calibrated to give an
            experience multiplier of 1 (see the binary search below).
        power: exponent of the power law; must be <= 0.

        Reads the module level global year_steps.
        '''

        self.min_factor = min_factor
        self.max_factor = max_factor
        self.power = power

        assert(min_factor <= max_factor)
        assert(power <= 0) # No increasing returns to scale. Binary search would break.

        self.none = (min_factor == max_factor == 1)
        self.ratio = True # Multiplier depends only on c / c_prev; Utilities checks this (experience_ratio) to take a faster path.

        # Binary search to find self.growth so that experience(1 + capacity_increase, 1) = 1.
        # Each iteration sets self.growth, self.c_join and self.join_diff, which the
        # function returned by create_function() reads when it is evaluated below.
        lo = 0
        hi = 1 + capacity_increase
        for _ in range(100):
            mid = (lo + hi) / 2
            self.growth = mid ** (1 / year_steps)
            try:
                self.c_join = self.growth * ((max_factor - min_factor) / (1 - min_factor)) ** (1 / power)
            except ZeroDivisionError:
                # Raised when power == 0, min_factor == 1, or max_factor == min_factor: no join correction.
                self.c_join = 0
                self.join_diff = 0
            else:
                u1_join = max_factor * self.c_join
                try:
                    u2_join = ((max_factor - min_factor) / (1 + power) + min_factor) * self.c_join
                except ZeroDivisionError:
                    # power == -1: logarithmic form.
                    u2_join = ((max_factor - min_factor) * log(self.c_join) + min_factor) * self.c_join
                self.join_diff = u2_join - u1_join
            if self.create_function(0)((1 + capacity_increase) ** (1 / year_steps), 1) >= 1:
                hi = mid
            else:
                lo = mid

    def create_function(self, year):
        '''Return the experience function (c, c_prev) -> multiplier. year is not used.'''

        def experience_power(c, c_prev):

            try:
                c_ratio = c / c_prev
            except ZeroDivisionError:
                return self.min_factor
            try:
                c_ratio_power = (self.growth / c_ratio) ** - self.power # c = 0: max_factor; power = 0: 1.
            except ZeroDivisionError:
                return self.max_factor
            if (1 - self.min_factor) * (c_ratio_power - 1) >= (self.max_factor - 1):
                return self.max_factor
            else:
                return 1 + (1 - self.min_factor) * (c_ratio_power / (1 + self.power) - 1) - self.join_diff / c_ratio

        def experience_log(c, c_prev):

            # Variant used when power == -1, where 1 / (1 + power) would divide by zero.
            try:
                c_ratio = c / c_prev
            except ZeroDivisionError:
                return self.min_factor
            try:
                c_ratio_power = (self.growth / c_ratio) ** - self.power # c = 0: max_factor
            except ZeroDivisionError:
                return self.max_factor
            if (1 - self.min_factor) * (c_ratio_power - 1) >= (self.max_factor - 1):
                return self.max_factor
            else:
                return 1 + (1 - self.min_factor) * (c_ratio_power * log(c_ratio) - 1) - self.join_diff / c_ratio

        if self.power == -1:
            return experience_log
        else:
            return experience_power

class ExperienceNone(ExperiencePower):

    '''No experience effect: min_factor and max_factor are both 1, so self.none is True.'''

    def __init__(self):
        super(ExperienceNone, self).__init__(max_factor = 1, min_factor = 1)

class LookupUnknown(Exception):
    '''Raised when a Period lookup produces nan, e.g. because the grid range is too small.'''

class Element:

    '''
    One (wealth, consume_prev) grid point of a Period.

    The constructor immediately searches for the consume_fraction that maximizes
    total_utility(), storing it in self.consume_fraction and the maximum in
    self.utility.
    '''

    def total_utility(self, consume_fraction):
        '''Compute utility of consuming consume_fraction in this interval and the optimal consumption levels in the remaining intervals.

        The remaining-intervals part is the sum, over the return samples of
        current_period.map.returns, of weight * next_period.lookup(...). The
        lookup uses the post-return wealth and this period's consumption as the
        next consume_prev. Raises LookupUnknown if such a lookup yields nan.
        '''

        consume = consume_fraction * self.wealth
        wealth = self.wealth - consume
        u_current = self.current_period.utilities.u(wealth, consume, self.consume_prev)
        u_next = 0
        if self.next_period:
            returns = self.current_period.map.returns
            for i in range(len(returns.sample)):
                wealth_next = wealth * returns.sample[i]
                u_sample = self.next_period.lookup(wealth_next, consume, 'utility')
                u_next += returns.weight[i] * u_sample
        u = u_current + u_next

        return u

    # Because we are limited to interpolating on a finite grid total_utility() isn't guaranteed to be monotone.
    # In particular total_utility() contains a very small amount of noise, and is also very flat near the maximum.
    # This makes optimization difficult. We make sure search_tolerance and search_delta are large so as to avoid getting stuck needlessly.
    # Changing search tolerance from 1e-4 to 1e-3 results in a 1 - 3e-6 factor decline in expected utility, which is more than acceptable.
    # However, we set them to 1e-4 since this produces smoother donation amount plots.

    search_tolerance = 1e-4
    search_expand = True # Bracket the maximum by stepping outward from hint, rather than searching all of [0, 1].
    search_delta = 0.49 * search_tolerance # Ensure expand search meets search tolerance if hint is best.
    search_expansion = 2 # Factor by which the outward step grows on each iteration.

    def search(self, f, hint):
        '''Possibly starting at hint search for maximum of f on [0, 1].

        With search_expand set, a bracket [a, b] is first found by stepping left
        from hint with growing steps. If the very first left step does not improve
        on f(hint), steps are taken to the right instead. Without search_expand the
        bracket is all of [0, 1]. The bracket is then narrowed by golden section
        search until it is narrower than search_tolerance.

        Returns (found, f(found)). hint itself is returned when f(hint) is at least
        as large as f at both final bracket ends.
        '''

        if self.search_expand:

            # Expand search around hint.
            # Step left while f keeps increasing; b trails two points behind a,
            # so a maximum lies within [a, b] when the loop stops.
            f_hint = f(hint)
            a_prev = hint
            f_a_prev = f_hint
            b = None
            delta = self.search_delta
            while True:
                a = max(0, hint - delta)
                if hint == 0:
                    f_a = f_hint
                else:
                    f_a = f(a)
                if a == 0 or f_a < f_a_prev:
                    break
                b = a_prev
                f_b = f_a_prev
                a_prev = a
                f_a_prev = f_a
                delta *= self.search_expansion
            if b is None:
                # The first left step did not improve: expand to the right instead.
                # There is no b == 1 exit; once b reaches 1 the next iteration
                # re-evaluates f(1), finds it equal to f_b_prev, and stops.
                b_prev = hint
                f_b_prev = f_hint
                delta = self.search_delta
                while True:
                    b = min(1, hint + delta)
                    if hint == 1:
                        f_b = f_hint
                    else:
                        f_b = f(b)
                    if f_b <= f_b_prev:
                        break
                    a = b_prev
                    f_a = f_b_prev
                    b_prev = b
                    f_b_prev = f_b
                    delta *= self.search_expansion

        else:

            a = 0
            f_a = f(a)
            b = 1
            f_b = f(b)

        # Golden segment search for maximum of f on [a, b].
        # NOTE(review): both f(c) and f(d) are evaluated on every iteration. A
        # textbook golden section search reuses one of them from the previous
        # iteration, which would halve the number of f evaluations.
        g = (1 + sqrt(5)) / 2
        while (b - a) >= self.search_tolerance:
            c = b - (b - a) / g
            d = a + (b - a) / g
            f_c = f(c)
            f_d = f(d)
            if f_c >= f_d:
                b = d
                f_b = f_d
            else:
                a = c
                f_a = f_c
        if self.search_expand and f_hint >= max(f_a, f_b):
            found = hint
            f_found = f_hint
        elif f_a > f_b:
            found = a
            f_found = f_a
        else:
            found = b
            f_found = f_b

        return found, f_found

    def __init__(self, wealth, consume_prev, current_period, next_period, hint):
        '''Store the grid point and immediately optimize it, starting the search at hint.'''

        self.wealth = wealth
        self.consume_prev = consume_prev
        self.current_period = current_period
        self.next_period = next_period

        self.check(hint, False)

    def check(self, hint, recheck = True):
        '''Optimize consume_fraction, starting the search from hint.

        When recheck is false a search is always performed. When recheck is true
        a search is performed only if hint differs from the current
        consume_fraction by at least search_tolerance and total_utility(hint)
        exceeds the current utility.

        If current_period.utilities.alive is not positive, utility is set to 0
        and consume_fraction to 1. If a lookup raises LookupUnknown, both are set
        to nan.

        Returns True only when a search completed.
        '''

        if self.current_period.utilities.alive > 0:

            try:

                if not recheck or abs(hint - self.consume_fraction) >= self.search_tolerance and self.total_utility(hint) > self.utility:

                    def f(cf):
                        return self.total_utility(cf)

                    # Debugging aid: set to True to dump f over [0, 1] to a CSV file on rechecks.
                    report_recheck = False
                    if recheck and report_recheck:
                        with open(self.current_period.map.prefix + self.current_period.map.name + '-utility-' +
                                  str(self.current_period.age) + '-' + str(self.wealth) + '-' + str(self.consume_prev) + '.csv',
                                  'w') as out:
                            for i in range(1000 + 1):
                                consume_fraction = i / 1000
                                out.write('%g,%g\n' % (consume_fraction, f(consume_fraction)))

                    self.consume_fraction, self.utility = self.search(f, hint)

                    return True

            except LookupUnknown:

                self.utility = self.consume_fraction = float('nan')

        else:

            self.utility = 0
            self.consume_fraction = 1

        return False

class Utilities:

    '''
    The utility functions in effect for a single period, combined by u().
    '''

    def __init__(self, year, alive, q, consume_max, estate_years, estate_capacity_increase, donate, estate, satisfaction, knowledge, experience):
        '''
        year: passed to the create_function() methods of the component objects.
        alive: multiplier applied to this period's utility (presumably the probability
            of being alive in this period - confirm with caller).
        q: extra multiplier applied to the estate terms (presumably the probability of
            dying this period - confirm with caller).
        consume_max: passed through to the utility create_function() methods.
        estate_years: number of annual donations the estate is spread over.
        estate_capacity_increase: growth rate of those annual estate donations.
        '''

        self.year = year
        self.alive = alive
        self.q = q
        self.estate_years = estate_years
        self.estate_capacity_increase = estate_capacity_increase
        self.donate = donate.create_function(year, consume_max)
        self.estate = estate.create_function(year, consume_max)
        self.satisfaction = satisfaction.create_function(year, consume_max)
        self.knowledge = knowledge.create_function(year)()
        self.experience = experience.create_function(year)

        self.consume_prev = not experience.none
        self.experience_ratio = experience.ratio
        if self.experience_ratio:
            self.experience_estate_init = self.experience(1, 0)
            self.experience_estate = self.experience(1 + self.estate_capacity_increase, 1)
        self.estate_linear = estate.linear
        if self.estate_linear:
            self.estate_factor = self.estate(1)

        # wealth_estate_factor = sum over i < estate_years of (1 + estate_capacity_increase) ** i,
        # so wealth / wealth_estate_factor is the first of estate_years geometrically growing
        # donations that together total wealth. It depends only on constructor arguments, so it is
        # computed once here instead of on every call to u(), which is the optimizer's hot path.
        try:
            self.wealth_estate_factor = ((1 + self.estate_capacity_increase) ** self.estate_years - 1) / self.estate_capacity_increase
        except ZeroDivisionError:
            # No growth: the geometric sum is simply the number of years.
            self.wealth_estate_factor = self.estate_years

    def u(self, wealth, c, c_prev, return_components = False):
        '''Return this period's utility.

        c: amount consumed (donated) this period; c_prev: the previous period's amount.
        Both are per period and are scaled by year_steps to annual rates.
        wealth: wealth remaining, valued as an estate spread over estate_years.
        If return_components is true, a dict of the individual terms is returned
        instead of their sum.
        '''

        c *= year_steps
        c_prev *= year_steps
        weight = self.alive / year_steps
        knowledge = self.knowledge
        experience_donate = self.experience(c, c_prev)
        donate = weight * knowledge * self.donate(c)
        donate_experience = (experience_donate - 1) * donate
        wealth_estate_factor = self.wealth_estate_factor
        wealth_estate = wealth / wealth_estate_factor # First annual estate donation.
        if self.estate_linear and self.experience_ratio:
            # Closed form of the loop below. It is valid when estate utility is linear and
            # experience depends only on the ratio of successive donations. 60% speedup.
            estate = self.estate_factor * wealth
            estate_experience = self.experience_estate_init - 1
            estate_experience += (wealth_estate_factor - 1) * (self.experience_estate - 1)
            estate_experience *= self.estate_factor * wealth_estate
        else:
            estate_experience = 0
            estate = 0
            wealth_estate_prev = 0
            for i in range(self.estate_years):
                estate_annual = self.estate(wealth_estate)
                estate += estate_annual
                estate_experience += estate_annual * (self.experience(wealth_estate, wealth_estate_prev) - 1)
                wealth_estate_prev = wealth_estate
                wealth_estate *= 1 + self.estate_capacity_increase
        estate *= weight * knowledge * self.q
        estate_experience *= weight * knowledge * self.q
        satisfaction = weight * self.satisfaction(c)

        if return_components:
            return {
                'donate': donate,
                'donate_experience': donate_experience,
                'estate': estate,
                'estate_experience': estate_experience,
                'satisfaction': satisfaction,
            }
        else:
            return donate + donate_experience + estate + estate_experience + satisfaction

class Period:

    '''
    The optimal strategy and expected utility for one period, tabulated on a grid.

    Rows are wealth values: row i in 0..w_steps holds wealth_min * wealth_scale ** i
    (row w_steps holds exactly wealth_max), and row w_steps + 1 holds wealth == 0.
    Columns are consume_prev values: column j in 0..c_steps holds
    consume_min * consume_scale ** j (column c_steps holds exactly consume_max). When
    there is no consume_prev dimension, c_steps is 0 and the single column uses
    consume_prev == 0.

    After construction only the tuples self.utility and self.consume_fraction are
    kept; lookup() interpolates within them.
    '''

    def __init__(self, map, utilities, age, wealth_min, wealth_max, consume_min, consume_max, next_period):
        '''Optimize every grid element for this period.

        next_period is the Period for the following period, or a false value for
        the final period. Its lookup() supplies the continuation utility.
        '''

        self.map = map
        self.utilities = utilities
        self.age = age
        self.wealth_min = wealth_min
        self.wealth_max = wealth_max
        self.consume_min = consume_min
        self.consume_max = consume_max

        if utilities.consume_prev:
            self.w_steps = wealth_steps
            self.c_steps = consume_steps
        else:
            self.w_steps = wealth_alone_steps
            self.c_steps = 0
        self.wealth_scale = (wealth_max / wealth_min) ** (1 / self.w_steps)
        self.log_wealth_scale = log(self.wealth_scale)
        try:
            self.consume_scale = (consume_max / consume_min) ** (1 / self.c_steps)
            self.log_consume_scale = log(self.consume_scale)
        except ZeroDivisionError:
            # c_steps == 0: there is no consume_prev grid.
            self.consume_scale = 0
            self.log_consume_scale = 0

        element = []
        for i in range(self.w_steps + 2):
            element_row = []
            for j in range(self.c_steps + 1):
                if 0 < i < self.w_steps + 1:
                    # Warm start from the optimum at the next lower wealth in the same column.
                    hint = element[i - 1][j].consume_fraction
                    if isnan(hint):
                        hint = 1
                else:
                    hint = 1
                if i < self.w_steps:
                    wealth = wealth_min * self.wealth_scale ** i
                elif i == self.w_steps:
                    wealth = wealth_max # Avoid fp out of range.
                else:
                    wealth = 0 # Store wealth == 0 at index self.w_steps + 1. Helps produce good strategy plots where wealth close to zero.
                if self.c_steps == 0:
                    consume_prev = 0
                elif j < self.c_steps:
                    consume_prev = consume_min * self.consume_scale ** j
                else:
                    consume_prev = consume_max
                e = Element(wealth, consume_prev, self, next_period, hint)
                element_row.append(e)
            element.append(element_row)

        def check_neighbour(s, i, j, di, dj):
            '''Re-optimize element[i][j] using neighbour element[i + di][j + dj]'s consume_fraction as the hint.'''
            ni = i + di
            nj = j + dj
            # Only neighbours on the geometric grid qualify. Negative indexes would silently
            # wrap around to the last row or column instead of raising IndexError, and
            # row w_steps + 1 is the wealth == 0 row, not the neighbour above row w_steps.
            if not (0 <= ni <= self.w_steps and 0 <= nj <= self.c_steps):
                return
            hint = element[ni][nj].consume_fraction
            if isnan(hint):
                hint = 1
            e = element[i][j]
            old_utility = e.utility
            old_consume_fraction = e.consume_fraction
            if e.check(hint):
                try:
                    improve = (e.utility - old_utility) / old_utility
                except ZeroDivisionError:
                    improve = float('inf')
                print('Improved:', s, e.wealth, e.consume_prev, abs(e.consume_fraction - old_consume_fraction), improve)

        smooth_consume_fraction = False
        if smooth_consume_fraction:
            # Attempt to fix non-monotonicity in consume fraction, by attempting a few other hints.
            # This makes plots of consume fraction look smoother, but it doesn't appreciably alter the expected or realized utility.
            for i in range(self.w_steps + 1):
                for j in range(self.c_steps + 1):
                    check_neighbour('wealth below', i, j, -1, 0)
                    check_neighbour('wealth above', i, j, 1, 0)
                    check_neighbour('c_prev below', i, j, 0, -1)
                    check_neighbour('c_prev above', i, j, 0, 1)

        # Conserve RAM by allowing element to be deleted; just save what we need in an array.
        self.utility = tuple(tuple(e.utility for e in element_wealth) for element_wealth in element)
        self.consume_fraction = tuple(tuple(e.consume_fraction for e in element_wealth) for element_wealth in element)

    def lookup(self, wealth, consume_prev, what):
        '''Interpolate the stored grid at (wealth, consume_prev).

        what is 'utility' or 'consume_fraction'; any other value returns None.
        Interpolation is linear in wealth and in consume_prev between neighbouring
        grid points:
        - Wealth below wealth_min interpolates toward the wealth == 0 row.
        - consume_prev below consume_min uses the consume_min column.
        - Points above the grid are handled according to the module level
          out_of_range setting.

        Raises LookupUnknown if the interpolated value is nan. After
        conserve_ram(), a 'utility' lookup returns nan instead.
        '''

        try:
            windex = log(wealth / self.wealth_min) / self.log_wealth_scale
        except ValueError:
            windex = -1
        if windex < 0:
            windex0 = self.w_steps + 1 # Location for wealth == 0.
            windex1 = 0
            wremain = wealth / self.wealth_min
        elif windex >= self.w_steps:
            windex0 = self.w_steps - 1
            windex1 = self.w_steps
            if wealth <= self.wealth_max: # Handles fp rounding errors.
                wremain = 1
            elif out_of_range == 'truncate':
                wremain = 1
            elif out_of_range == 'extrapolate':
                # Linear extrapolation through the last two wealth grid points.
                wremain = 1 + (wealth / self.wealth_max - 1) / (1 - 1 / self.wealth_scale)
            else:
                wremain = float('nan')
        else:
            windex0 = int(floor(windex))
            windex1 = windex0 + 1
            # Fractional position measured linearly in wealth between the geometric grid points.
            wremain = (self.wealth_scale ** (windex % 1) - 1) / (self.wealth_scale - 1)

        try:
            cindex = log(consume_prev / self.consume_min) / self.log_consume_scale
        except (ValueError, ZeroDivisionError):
            cindex = -1
        if cindex < 0:
            cindex0 = 0
            cindex1 = cindex0
            cremain = 0
        elif cindex >= self.c_steps:
            cindex0 = self.c_steps - 1
            cindex1 = self.c_steps
            if consume_prev <= self.consume_max:
                cremain = 1
            elif out_of_range == 'truncate':
                cremain = 1
            elif out_of_range == 'extrapolate':
                cremain = 1 + (consume_prev / self.consume_max - 1) / (1 - 1 / self.consume_scale)
            else:
                cremain = float('nan')
        else:
            cindex0 = int(floor(cindex))
            cindex1 = cindex0 + 1
            cremain = (self.consume_scale ** (cindex % 1) - 1) / (self.consume_scale - 1)

        if what == 'utility':

            if self.utility:
                utility0 = (1 - cremain) * self.utility[windex0][cindex0] + cremain * self.utility[windex0][cindex1]
                utility1 = (1 - cremain) * self.utility[windex1][cindex0] + cremain * self.utility[windex1][cindex1]
                utility = (1 - wremain) * utility0 + wremain * utility1
                if isnan(utility):
                    raise LookupUnknown('Need to increase wealth_max.')
            else:
                utility = float('nan')

            return utility

        if what == 'consume_fraction':

            if windex0 == self.w_steps + 1 and wealth > 0:
                # Using consume_fraction rather than consume allows for more accurate calculations, except at wealth == 0, for which consume_fraction is ill-defined.
                windex0 = 0
            consume_fraction0 = (1 - cremain) * self.consume_fraction[windex0][cindex0] + cremain * self.consume_fraction[windex0][cindex1]
            consume_fraction1 = (1 - cremain) * self.consume_fraction[windex1][cindex0] + cremain * self.consume_fraction[windex1][cindex1]
            consume_fraction = (1 - wremain) * consume_fraction0 + wremain * consume_fraction1
            # Test for nan before clamping: max(0, nan) returns 0, which would silently hide the failure.
            if isnan(consume_fraction):
                raise LookupUnknown('Need to increase consume_max.')
            consume_fraction = max(0, min(consume_fraction, 1))

            return consume_fraction

    def conserve_ram(self):
        '''Drop the utility table once it is no longer needed; later 'utility' lookups return nan.'''

        self.utility = None

class Map:
    '''Optimal strategy for one scenario, solved from the oldest age back to the youngest.

    self.map holds one Period per time step (year_steps steps per year) from
    age `start` to age `limit` inclusive, youngest first. evaluate() then
    simulates that strategy over num_evaluate random paths and reports the
    realized utility.
    '''

    def report_fn(self, name, f, lo, hi):
        '''Write f and its numerical derivative on [lo, hi] for plotting.

        Writes 1001 evenly spaced rows "x,f(x),f'(x)" to
        <prefix><map name>-<name>.csv. The derivative is a forward difference
        with step 1e-6 * (hi - lo). Where the forward point would lie beyond
        hi, a backward difference is used instead. The derivative is written
        as NaN whenever the difference quotient raises ZeroDivisionError
        (for example when lo == hi).
        '''
        steps = 1000
        delta = 1e-6 * (hi - lo)

        with open(self.prefix + self.name + '-' + name + '.csv', 'w') as out:
            for i in range(steps + 1):
                x = lo + i / steps * (hi - lo)
                fx = f(x)
                try:
                    if x + delta <= hi:
                        prime = (f(x + delta) - fx) / delta
                    else:
                        # Don't evaluate outside range: use a backward difference.
                        prime = (fx - f(x - delta)) / delta
                except ZeroDivisionError:
                    prime = float('nan')
                out.write('%g,%g,%g\n' % (x, fx, prime))

    def __init__(self, prefix = 'spend-', name = 'default', start = 0, limit = 120, wealth_init = 0, consume_init = 0,
                 returns = ReturnsNone(), life = LifeFixed(), estate_years = 1, estate_capacity_increase = 0,
                 donate = UtilityNone(), estate = UtilityNone(), satisfaction = UtilityNone(), knowledge = KnowledgeNone(), experience = ExperienceNone(),
                 wealth_range_init = 1000, consume_range_init = 100000,
                 map_quantile = 0.99999, wealth_min = None, wealth_max = float('inf'), consume_min = None, consume_max = float('inf'), wealth_report_max = None):
        '''Build the strategy map for one scenario and write its report files.

        prefix, name: output files are named <prefix><name>-<what>.csv.
        start, limit: first and last age covered. Time is discretized into
            year_steps steps per year.
        wealth_init, consume_init: starting wealth, and the consumption used
            as the "previous" consumption at age start.
        returns: returns model. This method uses .gm, .sample and
            .multiplicative_quantile(); evaluate() also uses .random().
        life: mortality model. life.q(age) is the per-step probability of
            death, compounded over year_steps for the annual 'q' report.
        estate_years, estate_capacity_increase, donate, estate, satisfaction,
            knowledge, experience: utility components, passed to Utilities.
        wealth_range_init, consume_range_init: when wealth_min / consume_min
            are not given, the grid minimum is wealth_init divided by these
            (or 1 when wealth_init is 0).
        map_quantile: quantile passed to returns.multiplicative_quantile()
            to bound each period's wealth grid.
        wealth_max, consume_max: hard caps on the wealth and consumption grids.
        wealth_report_max: upper wealth limit for the utility report files.
            Defaults to the map_quantile wealth at age limit.
        '''
        self.prefix = prefix
        self.name = name
        self.start = start
        self.limit = limit
        self.wealth_init = wealth_init
        self.consume_init = consume_init
        self.returns = returns
        self.life = life
        self.estate_years = estate_years
        self.estate_capacity_increase = estate_capacity_increase
        self.donate = donate
        self.estate = estate
        self.satisfaction = satisfaction
        self.knowledge = knowledge
        self.experience = experience

        print('Generating map:', name)

        if wealth_min is not None:
            min_wealth = wealth_min
        else:
            min_wealth = wealth_init / wealth_range_init if wealth_init != 0 else 1
        if consume_min is not None:
            min_consume = consume_min
        else:
            min_consume = wealth_init / consume_range_init if wealth_init != 0 else 1

        print('Geometric mean return: %.2f%%' % ((returns.gm - 1) * 100))

        # Precompute from youngest age to oldest: survival probability and the
        # largest wealth represented in each period's grid.
        alive = [1]
        max_wealth_period = [wealth_init]
        wealth_max_generatable = wealth_init
        for i in range((limit - start) * year_steps):
            year = (i + 1) / year_steps
            prev_age = start + i / year_steps
            alive.append(alive[-1] * (1 - life.q(prev_age)))
            # Compounds the last returns sample every step; presumably that is
            # the largest possible return (confirm sample is sorted).
            wealth_max_generatable *= returns.sample[-1]
            max_wealth_period.append(min(wealth_init * returns.multiplicative_quantile(map_quantile, year), wealth_max_generatable, wealth_max))
        print('Total life expectancy: %.1f' % (start + (sum(alive[1:]) + 0.5) / year_steps))

        # Report functions for plotting.
        self.report_fn('q', lambda age: 1 - (1 - life.q(age)) ** year_steps, start, limit)
        def l(age):
            # Survival probability at age, linearly interpolated between steps.
            i = (age - start) * year_steps
            return (1 - (i % 1)) * alive[int(floor(i))] + (i % 1) * alive[int(ceil(i))]
        self.report_fn('l', l, start, limit)
        max_wealth_report = wealth_report_max if wealth_report_max is not None else wealth_init * returns.multiplicative_quantile(map_quantile, limit - start)
        utilities = Utilities(0, 1, 0, max_wealth_report, estate_years, estate_capacity_increase, donate, estate, satisfaction, knowledge, experience)
        self.report_fn('donate', utilities.donate, 0, max_wealth_report)
        self.report_fn('estate', utilities.estate, 0, max_wealth_report)
        self.report_fn('satisfaction', utilities.satisfaction, 0, max_wealth_report)
        try:
            self.report_fn('knowledge_factor', knowledge.knowledge_factor, 0, limit - start)
        except AttributeError:
            pass  # Only some knowledge models provide knowledge_factor.
        self.report_fn('knowledge', lambda y: knowledge.create_function(y)(), 0, limit - start)
        self.report_fn('experience', lambda c: utilities.experience(c, 1), 0, 30)

        # Solve from oldest age to youngest; each Period is built from the one after it.
        self.map = []
        future_period = None
        for i in range((limit - start) * year_steps, -1, -1):
            year = i / year_steps
            age = start + year
            current_alive = alive[i]
            death = life.q(age) if age < limit else 1
            max_wealth = max_wealth_period[i]
            # Consumption carried into this period was a fraction (<= 1) of
            # last period's wealth, so bound it by last period's maximum wealth.
            # NOTE(review): for i == 0, index i - 1 == -1 selects the *last*
            # entry of max_wealth_period. The former `except IndexError:
            # max_consume = consume_init` fallback therefore never fired, and
            # that behavior is kept here. Confirm whether consume_init was
            # intended for the first period.
            max_consume = min(max_wealth_period[i - 1], consume_max)
            utilities = Utilities(year, current_alive, death, max_wealth, estate_years, estate_capacity_increase, donate, estate, satisfaction, knowledge, experience)
            period = Period(self, utilities, age, min_wealth, max_wealth, min_consume, max_consume, future_period)

            report_ages = False  # Debug switch: dump each period's lookup grid.
            if report_ages:
                with open(self.prefix + self.name + '-age-' + str(age) + '.csv', 'w') as f:
                    for wi in range(100):
                        wealth = wi / 100 * max_wealth
                        for ci in range(100):
                            consume_prev = ci / 100 * max_consume
                            utility = period.lookup(wealth, consume_prev, 'utility')
                            consume_fraction = period.lookup(wealth, consume_prev, 'consume_fraction')
                            f.write('%g,%g,%g,%g\n' % (wealth, consume_prev, utility, consume_fraction))
                        f.write('\n')

            self.map.insert(0, period)
            if trace:
                wealth_sample = wealth_init
                consume_sample = wealth_init / 10
                utility = period.lookup(wealth_sample, consume_sample, 'utility')
                consume_fraction = period.lookup(wealth_sample, consume_sample, 'consume_fraction')
                print(age, '-', consume_fraction, utility)
            if future_period:
                # The older period has now been consumed by Period(); drop its
                # utility table to save memory.
                future_period.conserve_ram()
            future_period = period

        print()

    def evaluate(self, name = None, start = None, limit = None, wealth_init = None, consume_init = None,
                 returns = None, life = None, estate_years = None, estate_capacity_increase = None,
                 donate = None, estate = None, satisfaction = None, knowledge = None, experience = None, **kwargs):
        '''Simulate this map's strategy and report the realized utility.

        Any argument left as None takes the value this Map was built with.
        A map built for one scenario can therefore be evaluated under another
        scenario's returns, mortality or utility functions. Extra keyword
        arguments, such as Map-only keys like wealth_report_max in a scenario
        dict, are accepted and ignored.

        num_evaluate paths are simulated from age start to age limit with a
        fixed random seed. Per-step mean, median and 2.5 / 97.5 percentile
        wealth, consumption and utility are written to
        <prefix><fullname>-{wealth,consume,utility,consume_pct}.csv, and a
        summary is printed.

        Asserts self.start <= start <= limit <= self.limit.
        '''
        same_map = (name is None or name == self.name)
        if same_map:
            fullname = self.name
        else:
            fullname = self.name + '-using-' + name
        if start is None:
            start = self.start
        if limit is None:
            limit = self.limit
        assert self.start <= start <= limit <= self.limit
        if wealth_init is None:
            wealth_init = self.wealth_init
        if consume_init is None:
            consume_init = self.consume_init
        if returns is None:
            returns = self.returns
        if life is None:
            life = self.life
        if estate_years is None:
            estate_years = self.estate_years
        if estate_capacity_increase is None:
            estate_capacity_increase = self.estate_capacity_increase
        if donate is None:
            donate = self.donate
        if estate is None:
            estate = self.estate
        if satisfaction is None:
            satisfaction = self.satisfaction
        if knowledge is None:
            knowledge = self.knowledge
        if experience is None:
            experience = self.experience

        print('Evaluating map:', fullname)

        seed(0)  # Reproducible simulation.

        confidence_interval = 0.95
        low = (1 - confidence_interval) / 2
        high = 1 - low

        alive = 1
        wealths = [wealth_init] * num_evaluate
        consume_prevs = [consume_init] * num_evaluate
        utilities = [0] * num_evaluate
        utility_components = {}
        means = {'wealth': [], 'consume': [], 'utility': []}
        medians = {'wealth': [], 'consume': [], 'utility': []}
        pctl_low = {'wealth': [], 'consume': [], 'utility': []}
        pctl_high = {'wealth': [], 'consume': [], 'utility': []}
        for period in self.map[(start - self.start) * year_steps : (limit - self.start) * year_steps + 1]:
            year = period.age - start
            death = life.q(period.age) if period.age < limit else 1
            utility_fn = Utilities(year, alive, death, period.wealth_max, estate_years, estate_capacity_increase, donate, estate, satisfaction, knowledge, experience)
            wealth_values = []
            consume_values = []
            utility_values = []
            for i, (wealth, consume_prev) in enumerate(zip(wealths, consume_prevs)):
                consume_fraction = period.lookup(wealth, consume_prev, 'consume_fraction')
                consume = consume_fraction * wealth
                wealth_values.append(wealth)
                consume_values.append(consume)
                wealth -= consume
                utility = utility_fn.u(wealth, consume, consume_prev, return_components = True)
                utility_sum = sum(utility.values())
                utility_values.append(utility_sum)
                wealths[i] = wealth * returns.random()
                consume_prevs[i] = consume
                utilities[i] += utility_sum
                for key, value in utility.items():
                    try:
                        utility_components[key] += value / num_evaluate
                    except KeyError:
                        utility_components[key] = value / num_evaluate
            for key, values in (('wealth', wealth_values), ('consume', consume_values), ('utility', utility_values)):
                means[key].append(mean(values))
                medians[key].append(pctl(0.5, values))
                pctl_low[key].append(pctl(low, values))
                pctl_high[key].append(pctl(high, values))
            alive *= 1 - death
        mean_utility = mean(utilities)
        median_utility = median(utilities)
        stdev_utility = stdev(utilities)
        stderr_utility = stdev_utility / sqrt(num_evaluate)

        def strategy(name, what = None, scale = 1):
            # Write age, mean, median, low and high percentile of `what`, each divided by scale.
            if not what:
                what = name
            with open(self.prefix + fullname + '-' + name + '.csv', 'w') as f:
                for i in range(len(means[what])):
                    age = start + i / year_steps
                    f.write('%g,%g,%g,%g,%g\n' % (age, means[what][i] / scale, medians[what][i] / scale, pctl_low[what][i] / scale, pctl_high[what][i] / scale))

        strategy('wealth')
        strategy('consume')
        strategy('utility')
        scale = sum(means['consume'])
        strategy('consume_pct', 'consume', scale if scale > 0 else 1)

        period = self.map[(start - self.start) * year_steps]
        expected = period.lookup(wealth_init, consume_init, 'utility')
        # Scaled by year_steps, presumably to annualize the per-step fraction.
        consume_fraction = period.lookup(wealth_init, consume_init, 'consume_fraction') * year_steps
        consume = consume_fraction * wealth_init
        print('Initial consume rate: %g (%.1f%%)' % (consume, consume_fraction * 100))
        if same_map:
            try:
                gain = (expected - mean_utility) / mean_utility
            except ZeroDivisionError:
                # mean_utility is 0 here, so the numerator reduces to expected.
                gain = float('inf') if expected != 0 else 0
            print('Mean utility expected: %g (diff. %f%%)' % (expected, gain * 100))
        try:
            stderr_fract = stderr_utility / mean_utility
        except ZeroDivisionError:
            stderr_fract = float('inf') if stderr_utility != 0 else 0
        print('Mean utility realized: %g +/- %g%%' % (mean_utility, stderr_fract * 100))
        for key in sorted(utility_components.keys()):
            try:
                u_ratio = utility_components[key] / mean_utility
            except ZeroDivisionError:
                u_ratio = float('inf') if utility_components[key] != 0 else 0
            print('  %s: %g (%.1f%%)' % (key, utility_components[key], u_ratio * 100))
        print('Median utility realized: %g' % median_utility)

        print()

# Scenario definitions. Each scenario is a dict of keyword arguments for
# Map() and Map.evaluate(). Derived scenarios copy a parent with
# dict(parent, key = value, ...) and override only what changes.

# Baseline scenario; the chain below layers components onto it one at a time.
initial = dict(name = 'initial', start = 25, limit = 105, wealth_init = 200000, consume_init = 0,
               returns = ReturnsLogNormal(), life = LifeGM(), estate_years = 5, estate_capacity_increase = 0.3,
               donate = UtilityLinear(), estate = UtilityNone(), satisfaction = UtilityNone(),
               knowledge = KnowledgeNone(), experience = ExperienceNone(),
               wealth_report_max = 100000)
donate = dict(initial, name = 'donate', donate = UtilityDiscount())
estate = dict(donate, name = 'estate', estate = UtilityDiscount(weight = 0.8))
experience = dict(estate, name = 'experience', experience = ExperiencePower())
knowledge = dict(experience, name = 'knowledge', knowledge = KnowledgeFactorDecay())
# NOTE: this module-level name shadows the builtin all(). It is kept for
# compatibility with the __main__ driver and any importers.
all = dict(knowledge, name = 'all')

# Single-parameter variations on the full scenario.
all_r_0_1 = dict(all, name = 'r-0.1',
                 donate = UtilityDiscount(r = 0.1),
                 estate = UtilityDiscount(weight = 0.8, r = 0.1))
estate_satisfaction_r_sub = dict(estate, name = 'estate-satisfaction-r-sub',
                                 satisfaction = UtilityCrra(r = 0.040))
all_experience_extreme = dict(all, name = 'experience-extreme',
                              experience = ExperiencePower(max_factor = 4, capacity_increase = 200, power = -0.8))
all_estate_weight_0_5 = dict(all, name = 'estate-weight-0.5',
                             estate = UtilityDiscount(weight = 0.5))
all_returns_lo = dict(all, name = 'returns-lo',
                      returns = ReturnsLogNormal(am = 1.029, sd = 0.105))
all_returns_gm = dict(all, name = 'returns-gm',
                      returns = ReturnsLogNormal(am = ReturnsLogNormal().gm, sd = 0))
all_uncorr_10000 = dict(all, name = 'uncorr-10000',
                        donate = UtilityCrra(weight = 1, c_uncorr = 10000, gamma = 2, r = 0.03),
                        estate = UtilityCrra(weight = 0.8, c_uncorr = 10000, gamma = 2, r = 0.03))
all_gamma_1 = dict(all, name = 'gamma-1',
                   donate = UtilityCrra(weight = 1, c_uncorr = 100000, gamma = 1, r = 0.03),
                   estate = UtilityCrra(weight = 0.8, c_uncorr = 100000, gamma = 1, r = 0.03))

satisfaction = dict(all, name = 'satisfaction', satisfaction = UtilityCrra())
all_returns_lo_r_0_1 = dict(all_returns_lo, name = 'returns-lo-r-0.1',
                            donate = UtilityDiscount(r = 0.1),
                            estate = UtilityDiscount(weight = 0.8, r = 0.1))

if __name__ == '__main__':

    # Each entry is (scenario used to build a Map, further scenarios to
    # evaluate against that same Map after its own evaluation). Runs happen
    # in this order, and each Map is released before the next is built.
    runs = (
        (initial, ()),
        (donate, ()),
        (estate, (estate_satisfaction_r_sub,)),
        (experience, ()),
        (knowledge, ()),

        (all_r_0_1, ()),
        (all_estate_weight_0_5, ()),
        (all_experience_extreme, ()),
        (all_returns_lo, ()),
        (all_returns_gm, ()),
        (all_uncorr_10000, ()),
        (all_gamma_1, ()),

        (satisfaction, ()),
        (estate_satisfaction_r_sub, (all, all_r_0_1)),
        (all_returns_lo, ()),
        (all_returns_lo_r_0_1, ()),
    )

    for build, alternatives in runs:
        m = Map(**build)
        m.evaluate()
        for alternative in alternatives:
            m.evaluate(**alternative)
        m = None # Conserve RAM.
