import {DoseTypes} from "@/models/utils/constants";
import {AmwgSampler} from "@/plugins/mcmc.js";
import * as ss from 'simple-statistics';
import {DEBUG} from "@/main"
import {calculateAlpha, calculateBeta, calculateCompartment} from "@/models/utils/calculations_2compartment";

// Calculate the one-compartment concentration at a given time (dt).
// Model: each dose is treated as a bolus of `amount / vd` delivered at the END of its
// infusion, which then decays with first-order elimination k = cl / vd. The
// contributions of doses[0..numDoses] are summed (superposition), so every dose is
// decayed from its own infusion end to the required time.
// @param doses: Array of DOSE entries sorted ascending by datetime
// @param numDoses: index of the last dose to include (doses[0..numDoses]); a negative
//                  value means no dose precedes the required time and 0 is returned
// @param vd: vol dist
// @param cl: drug clearance
// @param dt: time at which concentration required, in hours after the end of the
//            infusion of doses[0] (the x values produced by getXYData)
// @return concentration at the required time
export function calculateConcentrationAtTime(doses, numDoses, vd, cl, dt) {
    if (numDoses < 0) {
        return 0;
    }
    const msPerHour = 1000 * 60 * 60;
    // Durations are minutes unless durationUnit is 'hours' (same convention as getXYData,
    // which computes the dt passed in here).
    const durationHours = (dose) => (dose.durationUnit === 'hours' ? dose.duration * 1 : dose.duration / 60);

    const k = cl / vd;
    const firstStart = doses[0].datetime;
    // Required time expressed in hours after the START of doses[0].
    const t = dt + durationHours(doses[0]);

    let c = 0;
    for (let i = 0; i <= numDoses; i++) {
        const dose = doses[i];
        const start = (dose.datetime - firstStart) / msPerHour;
        if (t < start) {
            // This dose had not started at the required time.
            continue;
        }
        // NOTE: if t falls inside this dose's infusion, the bolus approximation counts the
        // whole dose with no decay (elapsed is clamped to 0).
        const elapsed = Math.max(0, t - (start + durationHours(dose)));
        c += ((dose.amount * 1) / vd) * Math.exp(-k * elapsed);
    }
    return c;
}

// Calculate the two-compartment concentration at a given time (dt) by superposition of
// constant-rate infusions. For a dose D infused over T hours and observed `tau` hours
// after its infusion started, with tin = min(tau, T) and tpost = tau - tin:
//   C = (D / T) * [ A/alpha * (1 - e^(-alpha*tin)) * e^(-alpha*tpost)
//                 + B/beta  * (1 - e^(-beta*tin))  * e^(-beta*tpost) ]
// For a bolus (T = 0) the T -> 0 limit is used: C = D * (A*e^(-alpha*tau) + B*e^(-beta*tau)).
// alpha, beta, A and B come from calculateAlpha / calculateBeta / calculateCompartment.
// @param doses: Array of DOSE entries sorted ascending by datetime
// @param numDoses: index of the last dose to include (doses[0..numDoses]); a negative
//                  value means no dose precedes the required time and 0 is returned
// @param vc, vp, q, cl: model parameters forwarded to calculateBeta / calculateAlpha /
//                       calculateCompartment (presumably central volume, peripheral volume,
//                       inter-compartmental clearance and clearance — confirm there)
// @param dt: time at which concentration required, in hours after the end of the
//            infusion of doses[0] (the x values produced by getXYData)
// @return concentration at the required time
export function calculateConcentrationAtTime2Compartment(doses, numDoses, vc, vp, q, cl, dt) {
    if (numDoses < 0) {
        return 0;
    }
    const msPerHour = 1000 * 60 * 60;
    // Durations are minutes unless durationUnit is 'hours' (same convention as getXYData).
    const durationHours = (dose) => (dose.durationUnit === 'hours' ? dose.duration * 1 : dose.duration / 60);

    const beta = calculateBeta(q, vc, vp, cl)
    const alpha = calculateAlpha(q, vc, vp, cl, beta)
    const A = calculateCompartment('A', q, vc, vp, alpha, beta)
    const B = calculateCompartment('B', q, vc, vp, alpha, beta)

    // One exponential phase of a unit-rate infusion: infused for tin hours, then decaying for tpost hours.
    const phase = (coef, rate, tin, tpost) => (coef / rate) * (1 - Math.exp(-rate * tin)) * Math.exp(-rate * tpost);

    const firstStart = doses[0].datetime;
    // Required time expressed in hours after the START of doses[0].
    const t = dt + durationHours(doses[0]);

    let c = 0;
    for (let i = 0; i <= numDoses; i++) {
        const dose = doses[i];
        const amount = dose.amount * 1;
        const infusion = durationHours(dose);
        const sinceStart = t - (dose.datetime - firstStart) / msPerHour;
        if (sinceStart < 0) {
            // This dose had not started at the required time.
            continue;
        }
        if (infusion > 0) {
            // Handles both "still infusing" (tpost = 0) and "after infusion" cases.
            const tin = Math.min(sinceStart, infusion);
            const tpost = sinceStart - tin;
            c += (amount / infusion) * (phase(A, alpha, tin, tpost) + phase(B, beta, tin, tpost));
        } else {
            // Zero-length infusion: use the bolus limit instead of dividing by 0.
            c += amount * (A * Math.exp(-alpha * sinceStart) + B * Math.exp(-beta * sinceStart));
        }
    }
    return c;
}


// For each drug level (CONC entry), find the index among the DOSE entries of the last
// dose given at or before the level was taken. The result is the `numDoses` argument for
// calculateConcentrationAtTime / calculateConcentrationAtTime2Compartment.
// @param doses: DOSE and CONC entries; DOSE entries assumed sorted ascending by datetime
// @return: one index per CONC entry, in the order the CONC entries appear in `doses`;
//          -1 when no dose precedes the level (including when there are no doses at all)
export function findMeasuredDoseNums(doses) {
    const dosesSorted = doses.filter((d) => d.type === DoseTypes.DOSE)
    const dosesMeasured = doses.filter((d) => d.type === DoseTypes.CONC)
    return dosesMeasured.map((m) => {
        // Compare each dose directly with the measurement time rather than via an offset
        // from dosesSorted[0], which does not exist when there are no doses.
        const dosesAfter = dosesSorted.filter((d) => d.datetime - m.datetime > 0).length
        return dosesSorted.length - dosesAfter - 1
    })
}

// Log-density helpers, taken from distributions.js
// Constant term 0.5 * log(2 * pi) of the normal log-density, evaluated once at module load.
const NORM_HALF_LOG_2PI = 0.5 * Math.log(2 * Math.PI);

// Log of the normal probability density N(mean, sd^2) evaluated at x.
// @param x: point at which to evaluate
// @param mean: distribution mean
// @param sd: standard deviation
// @return log-density (a plain number; no validation of sd is performed)
export function norm(x, mean, sd) {
    const twoVariance = 2 * sd * sd;
    const squaredDeviation = Math.pow(x - mean, 2);
    return -NORM_HALF_LOG_2PI - Math.log(sd) - squaredDeviation / twoVariance;
}

// Log of the continuous uniform density on [min, max] at x:
// -Infinity when x lies outside the interval, log(1 / (max - min)) otherwise.
export function unif(x, min, max) {
    if (x < min || x > max) {
        return -Infinity;
    }
    return Math.log(1 / (max - min));
}

// Build the (x, y) data set used by the Bayesian fits.
//   x[i]: hours from the END of the first dose's infusion to the i-th drug level
//   y[i]: the i-th drug level's `amount` (measured concentration), as stored
// @param doses : array of Dose with at least one dose and drug level - SORTED ascending by
//                datetime by the caller (sorting was removed from here because it is slow)
// @return: {x, y} with one entry per drug level, in the order the levels appear in `doses`,
//          OR null if there is no dose or no drug level, or if any drug level is not
//          strictly after the end of the first infusion (an error is logged in that case)
export const getXYData = (doses) => {
    const givenDoses = doses.filter((d) => d.type === DoseTypes.DOSE)
    const levels = doses.filter((d) => d.type === DoseTypes.CONC)
    if (givenDoses.length === 0 || levels.length === 0) {
        return null;
    }

    const first = givenDoses[0];
    const startTime = new Date(first.datetime);
    // Durations are minutes unless durationUnit is 'hours'.
    const infusionHours = first.durationUnit === 'hours' ? first.duration * 1 : first.duration / 60;

    const x = [];
    const y = [];
    for (const level of levels) {
        const hoursAfterInfusion = ((level.datetime - startTime) / (1000 * 60 * 60)) - infusionHours;
        // NaN also fails `> 0`, so unusable datetimes are rejected here as well.
        if (!(hoursAfterInfusion > 0)) {
            console.error('Error in getXYData values');
            return null;
        }
        x.push(hoursAfterInfusion);
        y.push(level.amount);
    }
    return {x, y};
}

// Run a Bayesian MCMC fit (AmwgSampler) of the one-compartment model to the measured
// concentrations. cl, vd and the residual sd (sigma) are sampled.
// @params pop_params : population parameters; reads cl, vd (prior means and sampler start
//                      values) and cl_sd, vd_sd (prior standard deviations)
// @params doses : all doses and measurements for patient, sorted ascending by datetime
//                 (as getXYData requires)
// @returns pop_params extended with vd_est, cl_est (mean of the per-run posterior means,
//          rounded to 2 dp), kel_est = cl_est / vd_est and halflife_est = 0.693 / kel_est
//          (both rounded to 4 dp); or pop_params unchanged when getXYData yields no data
export const runBayesianFit = (pop_params, doses) => {
    // Check for usable data first; nothing below is meaningful without it.
    const data = getXYData(doses);
    if (!data) {
        return pop_params
    }

    const dosesSorted = doses.filter((d) => d.type === DoseTypes.DOSE)
    // measured[i]: index of the last dose given before data point i
    const measured = findMeasuredDoseNums(doses)

    const params = {
        cl: {type: "real", init: pop_params.cl},
        vd: {type: "real", init: pop_params.vd},
        sigma: {type: "real", lower: 0}
    };

    const sample_iterations = 10;
    const samples_per_iteration = 10000;
    const samples_to_burn = 100;
    console.log(`Running Bayesian fit with ${samples_per_iteration} samples per iteration`)

    const options = {
        thin: 10,
        target_accept_rate: 0.6
    };

    const log_post = (state, data) => {
        let lp = 0;

        // Priors
        lp += norm(state.cl, pop_params.cl, pop_params.cl_sd);
        lp += norm(state.vd, pop_params.vd, pop_params.vd_sd);
        lp += unif(state.sigma, 0, 1);

        // Likelihood
        for (let i = 0; i < data.y.length; i++) {
            const mu = calculateConcentrationAtTime(dosesSorted, measured[i], state.vd, state.cl, data.x[i]);
            if (DEBUG) {
                console.log(`mu=${mu}`)
            }
            lp += norm(data.y[i], mu, state.sigma);
        }
        if (isNaN(lp)) {
            console.error('Error - log_post is NaN', data)
            // An undefined posterior must be rejected. 0 is an arbitrary finite score the
            // sampler could accept; -Infinity (as unif uses outside its support) cannot be.
            return -Infinity;
        }

        return lp;
    };

    const arr_cl = [];
    const arr_vd = [];
    for (let samp_i = 0; samp_i < sample_iterations; samp_i++) {
        // Fresh sampler per run; the estimate is the mean of the per-run posterior means.
        const sampler = new AmwgSampler(params, log_post, data, options);
        sampler.burn(samples_to_burn);
        const samples = sampler.sample(samples_per_iteration);
        arr_cl.push(ss.mean(samples["cl"]));
        arr_vd.push(ss.mean(samples["vd"]));
    }
    const vd_est = parseFloat(ss.mean(arr_vd).toFixed(2));
    const cl_est = parseFloat(ss.mean(arr_cl).toFixed(2));
    const kel_est = cl_est / vd_est;
    const halflife_est = 0.693 / kel_est;

    return {
        ...pop_params,
        vd_est,
        cl_est,
        kel_est: parseFloat(kel_est.toFixed(4)),
        halflife_est: parseFloat(halflife_est.toFixed(4))
    };
}

// Run a Bayesian MCMC fit (AmwgSampler) of the two-compartment model to the measured
// concentrations. Only vc, vp and the residual sd (sigma) are sampled; q and cl are held
// at their population values.
// @params pop_params : population parameters; reads vc, vp (prior means and sampler start
//                      values), vc_sd, vp_sd (prior standard deviations), q and cl
// @params doses : all doses and measurements for patient, sorted ascending by datetime
//                 (as getXYData requires)
// @returns pop_params extended with vc_est, vp_est (mean of the per-run posterior means,
//          rounded to 2 dp) and alpha, beta, a_est, b_est derived from them; or pop_params
//          unchanged when getXYData yields no data
export const runBayesianFit2 = (pop_params, doses) => {
    // Check for usable data first; nothing below is meaningful without it.
    const data = getXYData(doses);
    if (!data) {
        return pop_params
    }

    const dosesSorted = doses.filter((d) => d.type === DoseTypes.DOSE)
    // measured[i]: index of the last dose given before data point i
    const measured = findMeasuredDoseNums(doses)

    const params = {
        vc: {type: "real", init: pop_params.vc},
        vp: {type: "real", init: pop_params.vp},
        sigma: {type: "real", lower: 0}
    };

    const q = pop_params.q
    const cl = pop_params.cl

    const sample_iterations = 5;
    const samples_per_iteration = 10000;
    const samples_to_burn = 100;
    console.log(`Running Bayesian fit 2 with ${samples_per_iteration} samples per iteration`)

    const options = {
        thin: 10,
        target_accept_rate: 0.6
    };

    const log_post = (state, data) => {
        let lp = 0;

        // Priors
        lp += norm(state.vc, pop_params.vc, pop_params.vc_sd);
        lp += norm(state.vp, pop_params.vp, pop_params.vp_sd);
        lp += unif(state.sigma, 0, 1);

        // Likelihood
        for (let i = 0; i < data.y.length; i++) {
            const mu = calculateConcentrationAtTime2Compartment(dosesSorted, measured[i], state.vc, state.vp, q, cl, data.x[i]);
            if (DEBUG) {
                console.log(`mu=${mu}`)
            }
            lp += norm(data.y[i], mu, state.sigma);
        }
        if (isNaN(lp)) {
            console.error('Error - log_post is NaN', data)
            // An undefined posterior must be rejected. 0 is an arbitrary finite score the
            // sampler could accept; -Infinity (as unif uses outside its support) cannot be.
            return -Infinity;
        }

        return lp;
    };

    const arr_vc = [];
    const arr_vp = [];
    for (let samp_i = 0; samp_i < sample_iterations; samp_i++) {
        // Fresh sampler per run; the estimate is the mean of the per-run posterior means.
        const sampler = new AmwgSampler(params, log_post, data, options);
        sampler.burn(samples_to_burn);
        const samples = sampler.sample(samples_per_iteration);
        arr_vc.push(ss.mean(samples["vc"]));
        arr_vp.push(ss.mean(samples["vp"]));
    }
    const vc_est = parseFloat(ss.mean(arr_vc).toFixed(2));
    const vp_est = parseFloat(ss.mean(arr_vp).toFixed(2));
    const beta = calculateBeta(q, vc_est, vp_est, cl)
    const alpha = calculateAlpha(q, vc_est, vp_est, cl, beta)
    const a_est = calculateCompartment('A', q, vc_est, vp_est, alpha, beta)
    const b_est = calculateCompartment('B', q, vc_est, vp_est, alpha, beta)

    return {
        ...pop_params,
        vc_est,
        vp_est,
        a_est,
        b_est,
        alpha,
        beta
    };
}
