import * as reg from "d3-regression";
import * as Plot from "@observablehq/plot";
import * as d3 from "d3";
import type { ChannelTransform, ChannelValue, Data, Transformed } from "@observablehq/plot";
import * as stats from "../../jstat";

/** Regression families recognised by `plotRegression`; any other value falls back to a linear fit. */
type RegressionType = 'lin' | 'quad' | 'poly' | 'pow' | 'exp' | 'log' | 'loess';

/**
 * Options object accepted by `plotRegression`.
 *
 * Contains an optional regression `type` plus arbitrary additional keys. These include the
 * `x`/`y` channels, `bandwidth`, `order`, and any other Plot mark options.
 */
interface IInput {
    type?: RegressionType;
    [key: string]: any;
}

/**
 * Creates a fresh d3-regression generator for the requested regression family.
 * A missing or unrecognised `type` falls back to a linear regression.
 */
function createRegressor(type: RegressionType | undefined) {
    switch (type) {
        case "quad":
            return reg.regressionQuad();
        case "poly":
            return reg.regressionPoly();
        case "pow":
            return reg.regressionPow();
        case "exp":
            return reg.regressionExp();
        case "log":
            return reg.regressionLog();
        case "loess":
            return reg.regressionLoess();
        case "lin":
        default:
            return reg.regressionLinear();
    }
}

/**
 * Plot transform that replaces a mark's data with fitted regression curves.
 *
 * For every facet, the facet's row indices are grouped by `z`, falling back to `stroke`.
 * The regression generator is then run on each group. The transformed data are the
 * concatenated generator outputs, and the new facets index into that array.
 *
 * NOTE(review): `x` and `y` are consumed here and not forwarded in the returned options. The
 * mark therefore presumably relies on its default accessors to read the generated points;
 * confirm for the mark in use.
 *
 * @param options.x - channel for the independent variable
 * @param options.y - channel for the dependent variable
 * @param options.type - regression family; defaults to linear
 * @param options.bandwidth - passed to the generator's `bandwidth` setter, if it has one
 * @param options.order - passed to the generator's `order` setter, if it has one
 * @returns the remaining options with the regression transform composed in
 */
export const plotRegression = function <T>({ x, y, type, bandwidth, order, ...options }: T & IInput): Transformed<T> {
    const regressor = createRegressor(type);
    // Not every generator exposes these setters, so probe before calling.
    if (bandwidth && regressor.bandwidth) regressor.bandwidth(bandwidth);
    if (order && regressor.order) regressor.order(order);

    // Simplified version of Plot's maybeZ: group by z, otherwise by stroke.
    const z = options.z || options.stroke;
    return Plot.transform(options, function (data: any[], facets: any[]) {
        const X = Plot.valueof(data, x)!;
        const Y = Plot.valueof(data, y)!;
        const Z = Plot.valueof(data, z)!;
        // The generator receives row indices; these accessors map an index to its x/y value.
        regressor.x((i: number) => X[i]).y((i: number) => Y[i]);

        const points: any[] = [];
        const regFacets: number[][] = [];
        for (const facet of facets) {
            const regFacet: number[] = [];
            const groups = Z ? d3.group(facet, ((i: number) => Z[i]) as any).values() : [facet];
            for (const I of groups) {
                // Named `fitted` (not `reg`) so the d3-regression namespace import is not shadowed.
                const fitted = regressor(I);
                for (const d of fitted) {
                    const j = points.push(d) - 1;
                    // Tag each fitted point with its group's z value, so the mark reading the same
                    // channel name keeps series apart. Assumes z is a field name; TODO confirm
                    // for accessor functions.
                    if (z) d[z] = Z[I[0]];
                    regFacet.push(j);
                }
            }
            regFacets.push(regFacet);
        }

        return { data: points, facets: regFacets };
    }) as Transformed<T>;
};

/**
 * Derives a human-readable name for the channel stored at `options[channelKey]`.
 *
 * - A string channel (a field name) is returned as-is.
 * - An object channel with a non-empty `label` yields that label.
 * - Anything else (accessor functions, arrays, null, unlabeled objects) falls back to `channelKey`.
 */
export function getChannelName<T>(options: T, channelKey: keyof T & string): string {
    const channel = options[channelKey] as unknown as ChannelValue;
    if (typeof channel === "string") return channel;
    if (typeof channel !== "object") return channelKey;
    // typeof null === "object", so keep the optional chain.
    const label = (channel as ChannelTransform | null)?.label;
    // An empty label deliberately falls back to the key.
    return label || channelKey;
}

/**
 * Samples a least-squares line and its confidence band across the x-extent of the observations `I`.
 *
 * Adapted from Observable Plot's linear regression mark. Plot samples in screen (pixel) space,
 * where a `precision` of a few pixels is sensible. Here sampling happens in data space, so a
 * fixed step is inappropriate: it collapses to one point for narrow extents and explodes for
 * wide ones. By default the step is therefore a fixed fraction of the x-extent.
 *
 * @param I - indices into `X`/`Y` selecting the observations
 * @param X - independent-variable values
 * @param Y - dependent-variable values
 * @param options.ci - confidence level of the band (default 0.95)
 * @param options.precision - sampling step in x units (default: 1/100 of the x-extent)
 * @returns samples `{ x, y, y1, y2 }`, sorted by x, with both extent endpoints always included.
 *   `y` is the fitted value, and `y1`/`y2` are the band edges for sides -1/+1. Which edge is
 *   upper presumably depends on the sign of `stats.qt` for p < 0.5; confirm. Returns an empty
 *   array when `I` is empty.
 */
export function linearRegressionBand(
    I: number[],
    X: number[],
    Y: number[],
    { ci = 0.95, precision }: { ci?: number; precision?: number } = {},
) {
    const samples = 100;
    const [x1, x2] = d3.extent(I, (i) => X[i]);
    // No observations: nothing to sample.
    if (x1 === undefined || x2 === undefined) return [];

    const f = linearRegressionF(I, X, Y);
    const g = confidenceIntervalF(I, X, Y, (1 - ci) / 2, f);

    const step = precision ?? (x2 - x1) / samples;
    // Stop half a step short of x2 so the explicitly appended x2 is never a near-duplicate.
    // A non-positive or NaN step produces no interior samples instead of a degenerate range.
    const xs = step > 0 ? d3.range(x1, x2 - step / 2, step) : [];
    // Always anchor the band at both ends, even when the extent is narrower than half a step.
    if (xs.length === 0) xs.push(x1);
    if (x2 > x1) xs.push(x2);
    return xs.map((x) => ({ x, y: f(x), y1: g(x, -1), y2: g(x, +1) }));
}

/**
 * Ordinary least-squares fit of `Y` against `X`, restricted to the observations indexed by `I`.
 *
 * Adapted from Observable Plot's linear regression mark. Uses centered (two-pass) sums
 * rather than the raw-moment formula `n·Σxy − Σx·Σy`. The raw form cancels catastrophically
 * when x values share a large offset, e.g. epoch-millisecond dates.
 *
 * @param I - indices into `X`/`Y` selecting the observations to fit
 * @param X - independent-variable values
 * @param Y - dependent-variable values
 * @returns the fitted line `x => slope * x + intercept`. The result is not meaningful with
 *   fewer than two distinct x values, and evaluates to NaN when `I` is empty or has one element.
 */
export function linearRegressionF(I: number[], X: number[], Y: number[]) {
    const n = I.length;

    // Pass 1: means.
    let meanX = 0;
    let meanY = 0;
    for (const i of I) {
        meanX += X[i];
        meanY += Y[i];
    }
    meanX /= n;
    meanY /= n;

    // Pass 2: centered cross-product and x sum of squares.
    let sxy = 0;
    let sxx = 0;
    for (const i of I) {
        const dx = X[i] - meanX;
        sxy += dx * (Y[i] - meanY);
        sxx += dx * dx;
    }

    const slope = sxy / sxx;
    const intercept = meanY - slope * meanX;
    return (x: number) => slope * x + intercept;
}

/**
 * Builds a pointwise confidence-interval evaluator around a fitted line `f`.
 *
 * Adapted from Observable Plot's linear regression mark.
 *
 * @param I - indices into `X`/`Y` of the observations the line was fitted to
 * @param X - independent-variable values
 * @param Y - dependent-variable values
 * @param p - probability handed to `stats.qt`, using `I.length - 2` degrees of freedom
 * @param f - the fitted line, e.g. from `linearRegressionF`
 * @returns `(x, k) => f(x) + k * t * se(x)`. Here `se(x)` is the standard error of the fitted
 *   mean at `x`, and `k` selects the side of the interval (callers pass -1 / +1). With fewer
 *   than three observations the degrees of freedom are non-positive; results are then
 *   presumably NaN.
 */
export function confidenceIntervalF(I: number[], X: number[], Y: number[], p: number, f: (x: number) => number) {
    const meanX = d3.sum(I, (i: number) => X[i]) / I.length;

    let sxx = 0; // sum of squared x deviations from meanX
    let sse = 0; // sum of squared residuals against f
    for (const idx of I) {
        const xv = X[idx];
        sxx += (xv - meanX) ** 2;
        sse += (Y[idx] - f(xv)) ** 2;
    }

    const residualSd = Math.sqrt(sse / (I.length - 2));
    const tValue = stats.qt(p, I.length - 2);

    return (x: number, k: number) => {
        const fitted = f(x);
        const stdErr = residualSd * Math.sqrt(1 / I.length + (x - meanX) ** 2 / sxx);
        return fitted + k * tValue * stdErr;
    };
}
