/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.interestrate;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import net.finmath.exception.CalculationException;
import net.finmath.marketdata.model.AnalyticModelInterface;
import net.finmath.marketdata.model.curves.DiscountCurveFromForwardCurve;
import net.finmath.marketdata.model.curves.DiscountCurveInterface;
import net.finmath.marketdata.model.curves.ForwardCurveInterface;
import net.finmath.montecarlo.interestrate.LIBORMarketModelInterface;
import net.finmath.montecarlo.interestrate.LIBORModelInterface;
import net.finmath.montecarlo.interestrate.modelplugins.ShortRateVolailityModelInterface;
import net.finmath.montecarlo.model.AbstractModel;
import net.finmath.montecarlo.process.AbstractProcessInterface;
import net.finmath.stochastic.RandomVariableInterface;
import net.finmath.time.TimeDiscretizationInterface;

/**
 * Hull-White short rate model exposed through {@link LIBORModelInterface}.
 * <p>
 * The simulated state is a single component x (initial value 0) with mean-reverting drift
 * (see {@link #getDrift}) and deterministic factor loading (see {@link #getFactorLoading}).
 * The short rate at process time index i is r(t_i) = x(t_i) + alpha(t_i), where alpha is the
 * continuously compounded rate implied by the forward curve over [t_i, t_{i+1}] plus an
 * adjustment chosen by the {@code "driftFormula"} property ({@code DISCRETE}, the default, or
 * {@code ANALYTIC}).
 * <p>
 * Mean reversion and volatility come from a {@link ShortRateVolailityModelInterface} and are
 * treated as piecewise constant on that model's time discretization: the value with index k
 * applies on [t_k, t_{k+1}).
 */
public class HullWhiteModelWithShiftExtension
extends AbstractModel
implements LIBORModelInterface {
    private final DriftFormula driftFormula;
    private final TimeDiscretizationInterface liborPeriodDiscretization;
    private final AnalyticModelInterface curveModel;
    private final ForwardCurveInterface forwardRateCurve;
    private final DiscountCurveInterface discountCurve;
    // Discount curve implied by forwardRateCurve; used for the deterministic shift and in getA.
    private final DiscountCurveInterface discountCurveFromForwardCurve;
    // Unadjusted numeraire values keyed by process time index; only valid for numerairesProcess.
    private final ConcurrentHashMap<Integer, RandomVariableInterface> numeraires;
    private AbstractProcessInterface numerairesProcess = null;
    private final ShortRateVolailityModelInterface volatilityModel;

    /**
     * Creates the model.
     *
     * @param liborPeriodDiscretization tenor discretization used by the LIBOR accessors
     * @param analyticModel analytic model passed to {@code discountCurve} when evaluating discount factors
     * @param forwardRateCurve forward curve from which the deterministic shift of the short rate is derived
     * @param discountCurve if non-null, numeraires are rescaled so that the Monte-Carlo average of
     *        1/N(t) equals this curve's discount factor at t
     * @param volatilityModel provider of piecewise constant mean reversion and volatility
     * @param properties optional map; key {@code "driftFormula"} with a String value
     *        {@code "ANALYTIC"} or {@code "DISCRETE"} (default when absent)
     * @throws IllegalArgumentException if the "driftFormula" value is not one of the names above
     * @throws ClassCastException if the "driftFormula" value is not a String
     */
    public HullWhiteModelWithShiftExtension(TimeDiscretizationInterface liborPeriodDiscretization, AnalyticModelInterface analyticModel, ForwardCurveInterface forwardRateCurve, DiscountCurveInterface discountCurve, ShortRateVolailityModelInterface volatilityModel, Map<String, ?> properties) {
        this.liborPeriodDiscretization = liborPeriodDiscretization;
        this.curveModel = analyticModel;
        this.forwardRateCurve = forwardRateCurve;
        this.discountCurve = discountCurve;
        this.volatilityModel = volatilityModel;
        this.discountCurveFromForwardCurve = new DiscountCurveFromForwardCurve(forwardRateCurve);
        this.numeraires = new ConcurrentHashMap<Integer, RandomVariableInterface>();
        this.driftFormula = properties != null && properties.containsKey("driftFormula") ? DriftFormula.valueOf((String)properties.get("driftFormula")) : DriftFormula.DISCRETE;
    }

    /** The model has a single state component (the short rate deviation x). */
    @Override
    public int getNumberOfComponents() {
        return 1;
    }

    /** Identity: the state is simulated directly. */
    @Override
    public RandomVariableInterface applyStateSpaceTransform(int componentIndex, RandomVariableInterface randomVariable) {
        return randomVariable;
    }

    /** Identity: the state is simulated directly. */
    @Override
    public RandomVariableInterface applyStateSpaceTransformInverse(int componentIndex, RandomVariableInterface randomVariable) {
        return randomVariable;
    }

    /** Returns the initial state x(0) = 0; the deterministic part of the short rate lives in the shift. */
    @Override
    public RandomVariableInterface[] getInitialState() {
        RandomVariableInterface zero = this.getProcess().getStochasticDriver().getRandomVariableForConstant(0.0);
        return new RandomVariableInterface[]{zero};
    }

    /**
     * Returns the numeraire N(t) = exp(sum of r(t_i) * dt_i), using the short rate as piecewise
     * constant on the process time discretization.
     * <p>
     * For a time not on the process discretization, the numeraire at the previous grid point is
     * rolled forward with that grid point's short rate. If a discount curve was supplied, the
     * result is rescaled so that the average of 1/N(t) equals that curve's discount factor at t.
     *
     * @param time the evaluation time
     * @return the numeraire at {@code time}
     * @throws CalculationException if the process values cannot be obtained
     */
    @Override
    public RandomVariableInterface getNumeraire(double time) throws CalculationException {
        if (time == this.getTime(0)) {
            // The numeraire is normalized to one at the initial time.
            return this.getProcess().getStochasticDriver().getRandomVariableForConstant(1.0);
        }
        int timeIndex = this.getProcess().getTimeIndex(time);
        if (timeIndex < 0) {
            // Off-grid time: -timeIndex - 1 is the next grid point, so -timeIndex - 2 is the previous one.
            int previousTimeIndex = -timeIndex - 2;
            double previousTime = this.getProcess().getTime(previousTimeIndex);
            RandomVariableInterface rate = this.getShortRate(previousTimeIndex);
            RandomVariableInterface integratedRate = rate.mult(time - previousTime);
            return this.getNumeraire(previousTime).mult(integratedRate.exp());
        }
        // Cached values belong to one process instance; drop them when the process changes.
        if (this.getProcess() != this.numerairesProcess) {
            this.numeraires.clear();
            this.numerairesProcess = this.getProcess();
        }
        RandomVariableInterface numeraire = this.numeraires.get(timeIndex);
        if (numeraire == null) {
            // Accumulate the integrated short rate from the start, caching every intermediate numeraire.
            RandomVariableInterface integratedRate = this.getProcess().getStochasticDriver().getRandomVariableForConstant(0.0);
            for (int i = 0; i < timeIndex; ++i) {
                RandomVariableInterface rate = this.getShortRate(i);
                double dt = this.getProcess().getTimeDiscretization().getTimeStep(i);
                integratedRate = integratedRate.addProduct(rate, dt);
                numeraire = integratedRate.exp();
                this.numeraires.put(i + 1, numeraire);
            }
        }
        if (this.discountCurve != null) {
            // Rescale so that the Monte-Carlo average of 1/N(t) reproduces the discount curve.
            double deterministicNumeraireAdjustment = numeraire.invert().getAverage() / this.discountCurve.getDiscountFactor(this.curveModel, time);
            numeraire = numeraire.mult(deterministicNumeraireAdjustment);
        }
        return numeraire;
    }

    /**
     * Returns the drift -a_eff * x of the state, where a_eff = a * B(t, t + dt) / dt and a is the
     * mean reversion of the volatility-model interval containing t.
     * <p>
     * The effective mean reversion presumably makes an Euler step match the exact decay
     * exp(-a dt) of the state's mean; this depends on how the process applies the drift.
     */
    @Override
    public RandomVariableInterface[] getDrift(int timeIndex, RandomVariableInterface[] realizationAtTimeIndex, RandomVariableInterface[] realizationPredictor) {
        double time = this.getProcess().getTime(timeIndex);
        double timeNext = this.getProcess().getTime(timeIndex + 1);
        double meanReversion = this.volatilityModel.getMeanReversion(this.getVolatilityModelTimeIndex(time));
        double meanReversionEffective = meanReversion * this.getB(time, timeNext) / (timeNext - time);
        return new RandomVariableInterface[]{realizationAtTimeIndex[0].mult(-meanReversionEffective)};
    }

    /** Returns a constant random variable with the given value from the process' stochastic driver. */
    @Override
    public RandomVariableInterface getRandomVariableForConstant(double value) {
        return this.getProcess().getStochasticDriver().getRandomVariableForConstant(value);
    }

    /**
     * Returns the deterministic factor loading sigma * sqrt((1 - exp(-2 a dt)) / (2 a dt)), using
     * the volatility and mean reversion of the volatility-model interval containing t.
     * <p>
     * Squared and multiplied by dt, this equals sigma^2 (1 - exp(-2 a dt)) / (2 a). That is the
     * exact one-step variance of the state, presumably assuming the loading multiplies a Brownian
     * increment of variance dt.
     */
    @Override
    public RandomVariableInterface[] getFactorLoading(int timeIndex, int componentIndex, RandomVariableInterface[] realizationAtTimeIndex) {
        double time = this.getProcess().getTime(timeIndex);
        double timeNext = this.getProcess().getTime(timeIndex + 1);
        int timeIndexVolatility = this.getVolatilityModelTimeIndex(time);
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexVolatility);
        double scaling = Math.sqrt((1.0 - Math.exp(-2.0 * meanReversion * (timeNext - time))) / (2.0 * meanReversion * (timeNext - time)));
        double volatilityEffective = scaling * this.volatilityModel.getVolatility(timeIndexVolatility);
        RandomVariableInterface factorLoading = this.getProcess().getStochasticDriver().getRandomVariableForConstant(volatilityEffective);
        return new RandomVariableInterface[]{factorLoading};
    }

    /**
     * Returns the forward rate (P(t, S) / P(t, E) - 1) / (E - S) from the model's zero coupon bonds.
     * {@code time} must be a point of the process time discretization.
     */
    @Override
    public RandomVariableInterface getLIBOR(double time, double periodStart, double periodEnd) throws CalculationException {
        return this.getZeroCouponBond(time, periodStart).div(this.getZeroCouponBond(time, periodEnd)).sub(1.0).div(periodEnd - periodStart);
    }

    /** Returns the forward rate for the given tenor period, observed at the given process time index. */
    @Override
    public RandomVariableInterface getLIBOR(int timeIndex, int liborIndex) throws CalculationException {
        double time = this.getProcess().getTime(timeIndex);
        RandomVariableInterface bondToPeriodStart = this.getZeroCouponBond(time, this.getLiborPeriod(liborIndex));
        RandomVariableInterface bondToPeriodEnd = this.getZeroCouponBond(time, this.getLiborPeriod(liborIndex + 1));
        return bondToPeriodStart.div(bondToPeriodEnd).sub(1.0).div(this.getLiborPeriodDiscretization().getTimeStep(liborIndex));
    }

    @Override
    public TimeDiscretizationInterface getLiborPeriodDiscretization() {
        return this.liborPeriodDiscretization;
    }

    @Override
    public int getNumberOfLibors() {
        return this.liborPeriodDiscretization.getNumberOfTimeSteps();
    }

    @Override
    public double getLiborPeriod(int timeIndex) {
        return this.liborPeriodDiscretization.getTime(timeIndex);
    }

    @Override
    public int getLiborPeriodIndex(double time) {
        return this.liborPeriodDiscretization.getTimeIndex(time);
    }

    @Override
    public AnalyticModelInterface getAnalyticModel() {
        return this.curveModel;
    }

    @Override
    public DiscountCurveInterface getDiscountCurve() {
        return this.discountCurve;
    }

    @Override
    public ForwardCurveInterface getForwardRateCurve() {
        return this.forwardRateCurve;
    }

    /** Not supported by this model. */
    @Override
    public LIBORMarketModelInterface getCloneWithModifiedData(Map<String, Object> dataModified) throws CalculationException {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the short rate r(t_i) = x(t_i) + alpha(t_i) at the given process time index.
     * <p>
     * alpha is the forward-curve rate over [t_i, t_i + dt_i] plus either the discrete drift
     * adjustment ({@code DISCRETE}) or the analytic shift getDV(0, t_i) ({@code ANALYTIC}).
     */
    private RandomVariableInterface getShortRate(int timeIndex) throws CalculationException {
        double time = this.getProcess().getTime(timeIndex);
        RandomVariableInterface value = this.getProcess().getProcessValue(timeIndex, 0);
        double dt = this.getProcess().getTimeDiscretization().getTimeStep(timeIndex);
        double alpha = this.getForwardCurveRate(time, dt);
        if (this.driftFormula == DriftFormula.DISCRETE) {
            alpha += this.getIntegratedDriftAdjustment(timeIndex);
        } else if (this.driftFormula == DriftFormula.ANALYTIC) {
            alpha += this.getDV(0.0, time);
        }
        return value.add(alpha);
    }

    /**
     * Returns the zero coupon bond P(t, T) = A(t, T) * exp(-B(t, T) * r(t)).
     * {@code time} must be a point of the process time discretization.
     */
    private RandomVariableInterface getZeroCouponBond(double time, double maturity) throws CalculationException {
        int timeIndex = this.getProcess().getTimeIndex(time);
        RandomVariableInterface shortRate = this.getShortRate(timeIndex);
        double A = this.getA(time, maturity);
        double B = this.getB(time, maturity);
        return shortRate.mult(-B).exp().mult(A);
    }

    /**
     * Returns the discrete drift adjustment of the shift at the given process time index.
     * It is accumulated recursively over the process time steps up to {@code timeIndex}.
     */
    private double getIntegratedDriftAdjustment(int timeIndex) {
        double integratedDriftAdjustment = 0.0;
        for (int i = 1; i <= timeIndex; ++i) {
            double t = this.getProcess().getTime(i - 1);
            double t2 = this.getProcess().getTime(i);
            double meanReversion = this.volatilityModel.getMeanReversion(this.getVolatilityModelTimeIndex(t));
            // Expression kept in its original form to preserve the exact floating point result.
            integratedDriftAdjustment += this.getShortRateConditionalVariance(0.0, t) * this.getB(t, t2) / (t2 - t) * (t2 - t) - integratedDriftAdjustment * meanReversion * (t2 - t) * this.getB(t, t2) / (t2 - t);
        }
        return integratedDriftAdjustment;
    }

    /**
     * Returns the bond factor A(t, T) with
     * ln A = ln(P(T) / P(t)) + B(t, T) * f(t) - 0.5 * Var(0, t) * B(t, T)^2.
     * <p>
     * Here P is the discount curve implied by the forward curve, f(t) is that curve's rate over
     * the process time step starting at t, and Var is {@link #getShortRateConditionalVariance}.
     */
    private double getA(double time, double maturity) {
        int timeIndex = this.getProcess().getTimeIndex(time);
        double dt = this.getProcess().getTimeDiscretization().getTimeStep(timeIndex);
        double zeroRate = this.getForwardCurveRate(time, dt);
        double B = this.getB(time, maturity);
        double lnA = Math.log(this.discountCurveFromForwardCurve.getDiscountFactor(maturity) / this.discountCurveFromForwardCurve.getDiscountFactor(time)) + B * zeroRate - 0.5 * this.getShortRateConditionalVariance(0.0, time) * B * B;
        return Math.exp(lnA);
    }

    /** Returns the continuously compounded rate implied by the forward curve over [time, time + dt]. */
    private double getForwardCurveRate(double time, double dt) {
        return -Math.log(this.discountCurveFromForwardCurve.getDiscountFactor(time + dt) / this.discountCurveFromForwardCurve.getDiscountFactor(time)) / dt;
    }

    /**
     * Returns the index of the last volatility-model grid point less than or equal to {@code time},
     * i.e. the index of the piecewise constant parameter interval containing {@code time}.
     * The result is -1 for a time before the first grid point.
     * <p>
     * This assumes getTimeIndex returns (-insertionPoint - 1) for off-grid times, matching the
     * index arithmetic used throughout this class.
     */
    private int getVolatilityModelTimeIndex(double time) {
        int timeIndex = this.volatilityModel.getTimeDiscretization().getTimeIndex(time);
        if (timeIndex < 0) {
            timeIndex = -timeIndex - 2;
        }
        return timeIndex;
    }

    /**
     * Returns the interval index at which the piecewise integrals over [time, ...] start.
     * <p>
     * This must be the interval containing {@code time}, not the next one. Otherwise the first
     * sub-interval would span a grid point and use the wrong parameters. It is clamped at 0, so a
     * time before the first grid point uses the first interval's parameters.
     */
    private int getIntegrationStartIndex(double time) {
        return Math.max(this.getVolatilityModelTimeIndex(time), 0);
    }

    /** Returns the integrated mean reversion, i.e. the integral of a(s) ds over [time, maturity]. */
    private double getMRTime(double time, double maturity) {
        int timeIndexStart = this.getIntegrationStartIndex(time);
        int timeIndexEnd = this.getVolatilityModelTimeIndex(maturity);
        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            double timeNext = this.volatilityModel.getTimeDiscretization().getTime(timeIndex);
            double meanReversion = this.volatilityModel.getMeanReversion(timeIndex - 1);
            integral += meanReversion * (timeNext - timePrev);
            timePrev = timeNext;
        }
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexEnd);
        integral += meanReversion * (maturity - timePrev);
        return integral;
    }

    /**
     * Returns B(t, T), the integral of exp(-MR(s, T)) ds over [time, maturity], with MR from
     * {@link #getMRTime}. Each sub-interval uses the closed form for constant mean reversion.
     */
    private double getB(double time, double maturity) {
        int timeIndexStart = this.getIntegrationStartIndex(time);
        int timeIndexEnd = this.getVolatilityModelTimeIndex(maturity);
        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            double timeNext = this.volatilityModel.getTimeDiscretization().getTime(timeIndex);
            double meanReversion = this.volatilityModel.getMeanReversion(timeIndex - 1);
            integral += (Math.exp(-this.getMRTime(timeNext, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / meanReversion;
            timePrev = timeNext;
        }
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexEnd);
        integral += (Math.exp(-this.getMRTime(maturity, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / meanReversion;
        return integral;
    }

    /**
     * Returns the integral of sigma^2 / a^2 * (1 - exp(-MR(s, T)))^2 ds over [time, maturity],
     * with piecewise constant sigma and a. Returns 0 when time == maturity.
     * Not referenced elsewhere in this class.
     */
    private double getV(double time, double maturity) {
        if (time == maturity) {
            return 0.0;
        }
        int timeIndexStart = this.getIntegrationStartIndex(time);
        int timeIndexEnd = this.getVolatilityModelTimeIndex(maturity);
        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            double timeNext = this.volatilityModel.getTimeDiscretization().getTime(timeIndex);
            double meanReversion = this.volatilityModel.getMeanReversion(timeIndex - 1);
            double volatility = this.volatilityModel.getVolatility(timeIndex - 1);
            integral += volatility * volatility * (timeNext - timePrev) / (meanReversion * meanReversion);
            integral -= volatility * volatility * 2.0 * (Math.exp(-this.getMRTime(timeNext, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / (meanReversion * meanReversion * meanReversion);
            integral += volatility * volatility * (Math.exp(-2.0 * this.getMRTime(timeNext, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion * meanReversion);
            timePrev = timeNext;
        }
        double timeNext = maturity;
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexEnd);
        double volatility = this.volatilityModel.getVolatility(timeIndexEnd);
        integral += volatility * volatility * (timeNext - timePrev) / (meanReversion * meanReversion);
        integral -= volatility * volatility * 2.0 * (Math.exp(-this.getMRTime(timeNext, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / (meanReversion * meanReversion * meanReversion);
        integral += volatility * volatility * (Math.exp(-2.0 * this.getMRTime(timeNext, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion * meanReversion);
        return integral;
    }

    /**
     * Returns the integral of sigma^2 / a * (exp(-MR(s, T)) - exp(-2 MR(s, T))) ds over
     * [time, maturity], with piecewise constant sigma and a.
     * <p>
     * It is used as the analytic shift adjustment ({@code ANALYTIC} drift formula). Returns 0 when
     * time == maturity.
     */
    private double getDV(double time, double maturity) {
        if (time == maturity) {
            return 0.0;
        }
        int timeIndexStart = this.getIntegrationStartIndex(time);
        int timeIndexEnd = this.getVolatilityModelTimeIndex(maturity);
        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            double timeNext = this.volatilityModel.getTimeDiscretization().getTime(timeIndex);
            double meanReversion = this.volatilityModel.getMeanReversion(timeIndex - 1);
            double volatility = this.volatilityModel.getVolatility(timeIndex - 1);
            integral += volatility * volatility * (Math.exp(-this.getMRTime(timeNext, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / (meanReversion * meanReversion);
            integral -= volatility * volatility * (Math.exp(-2.0 * this.getMRTime(timeNext, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion);
            timePrev = timeNext;
        }
        double timeNext = maturity;
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexEnd);
        double volatility = this.volatilityModel.getVolatility(timeIndexEnd);
        integral += volatility * volatility * (Math.exp(-this.getMRTime(timeNext, maturity)) - Math.exp(-this.getMRTime(timePrev, maturity))) / (meanReversion * meanReversion);
        integral -= volatility * volatility * (Math.exp(-2.0 * this.getMRTime(timeNext, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion * meanReversion);
        return integral;
    }

    /**
     * Returns the conditional variance of the short rate at {@code maturity} given {@code time}.
     * This is the integral of sigma^2 * exp(-2 MR(s, T)) ds over [time, maturity], with piecewise
     * constant sigma and a.
     *
     * @param time start of the integration interval
     * @param maturity end of the integration interval
     * @return the integrated, mean-reversion-damped variance
     */
    public double getShortRateConditionalVariance(double time, double maturity) {
        int timeIndexStart = this.getIntegrationStartIndex(time);
        int timeIndexEnd = this.getVolatilityModelTimeIndex(maturity);
        double integral = 0.0;
        double timePrev = time;
        for (int timeIndex = timeIndexStart + 1; timeIndex <= timeIndexEnd; ++timeIndex) {
            double timeNext = this.volatilityModel.getTimeDiscretization().getTime(timeIndex);
            double meanReversion = this.volatilityModel.getMeanReversion(timeIndex - 1);
            double volatility = this.volatilityModel.getVolatility(timeIndex - 1);
            integral += volatility * volatility * (Math.exp(-2.0 * this.getMRTime(timeNext, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion);
            timePrev = timeNext;
        }
        double meanReversion = this.volatilityModel.getMeanReversion(timeIndexEnd);
        double volatility = this.volatilityModel.getVolatility(timeIndexEnd);
        integral += volatility * volatility * (Math.exp(-2.0 * this.getMRTime(maturity, maturity)) - Math.exp(-2.0 * this.getMRTime(timePrev, maturity))) / (2.0 * meanReversion);
        return integral;
    }

    /**
     * Returns Var(0, t) * B(t, T)^2, with Var from {@link #getShortRateConditionalVariance}.
     *
     * @param time the observation time t
     * @param maturity the bond maturity T
     * @return the integrated squared bond volatility
     */
    public double getIntegratedBondSquaredVolatility(double time, double maturity) {
        return this.getShortRateConditionalVariance(0.0, time) * this.getB(time, maturity) * this.getB(time, maturity);
    }

    /** Selects how the deterministic shift of the short rate is adjusted (see getShortRate). */
    private enum DriftFormula {
        ANALYTIC,
        DISCRETE
    }
}

