import React from 'react';
import Chart from 'chart.js/auto';
import { Scatter } from "react-chartjs-2";
import regression from 'regression';
import 'bootstrap/dist/js/bootstrap.bundle.min';
import 'bootstrap/dist/css/bootstrap.min.css';
import { datasets } from './datasets.js';
import './CostFunctionBuilder.css';
import ManualCostFunctionEntry from './ManualCostFunctionEntry';

const TableFormatter = (props) => {
    const { value } = props;
    var result;
    if (props.rendertype == 'percent') {
        result = value + '%'
    } else if (props.rendertype == 'dollars') {
        result = '$' + value;
    } else {
        result = value;
    }
    return (
        <React.Fragment>
            <span>{result}</span>
        </React.Fragment>
    );
};

/**
 * Builds a Chart.js tick-label callback for the given formatter type.
 *
 * The returned callback converts the tick value to a number and renders it
 * as a string; any extra arguments Chart.js passes (index, ticks) are ignored.
 *
 * @param {string} type - 'percent' appends '%', 'dollars' prepends '$';
 *     any other value (including 'number' or undefined) yields the plain number.
 * @returns {function(*): string} the tick-label callback.
 */
function chartFormatter(type) {
    if (type === 'percent') {
        return (value) => String(Number(value)) + '%';
    }
    if (type === 'dollars') {
        return (value) => '$' + Number(value);
    }
    return (value) => String(Number(value));
}

export default class CostFunctionBuilder extends React.Component {

    constructor(props) {
        super(props);

        if (!props.layerId) {
            throw new Error("This component requires a layerId");
        }

        var layerIndex = -1;
        for (var i=0; i<datasets.length; i++) {
            if (datasets[i].id == props.layerId) {
                layerIndex = i;
            }
        }
        if (layerIndex == -1) {
            throw new Error("No layer " + props.layerId);
        }

        this.state = {
            layerId: props.layerId,
            layerIndex: layerIndex,
        }
        this.state = {
            ...this.state,
            ...this.regressionUpdate(layerIndex)
        }
    }

    chartOptions() {
        return {
            maintainAspectRatio: false,
            scales: {
                x: {
                    type: 'linear',
                    position: 'bottom',
                    ticks: {
                        callback: chartFormatter(datasets[this.state.layerIndex].formatter)
                    }
                },
                y: {
                    ticks: {
                        callback: function(value, index, values) {
                            return '$' + value;
                        }
                    }
                }
            }
        };
    }
    
    render() {
        return (
            <div class="cfb-manual">
                <div class="cfb-entry">
                    <ManualCostFunctionEntry questions={datasets[this.state.layerIndex].questions}
                                             onChange={() => {this.afterChange()}}
                    />
                </div>
                <div class="cfb-chart">
                    <Scatter width={'100%'}
                             height={'100%'}
                             options={this.chartOptions()} data={ {
                                 datasets: [ {
                                     label: 'Input Points',
                                     data: datasets[this.state.layerIndex].questions,
                                 }, {
                                     label: 'Cost Function',
                                     data: this.state.fn,
                                     type: 'line',
                                     backgroundColor: 'red',
                                     borderColor: 'red',
                                     fill: 'none',
                                     pointRadius: 0,
                                     //cubicInterpolationMode: 'monotone',
                                     lineTension: 0,
                                 },
                                           ] } } />
                </div>
                <div class="cfb-map" style={{'background-image': 'url("https://tiled-a.places.jackohare.com/tile/' + encodeURIComponent(this.state.equation) + '")'}}>
                </div>
                <div class="cfb-equation">{this.state.equation + " ; R²=" + this.state.r2}</div>
            </div>
        );
    }

    getFn(points) {
        var newFn = [];
        for (var i=0; i<points.length; i++) {
            newFn.push({x: points[i][0], y: points[i][1]});
        }
        return newFn;
    }

    bestFitRegression(vals, forcePositive) {
        var regressionOptions = {order: 2, precision: 5};
        var linearResult = regression.linear(vals, regressionOptions);
        var polynomialResult = regression.polynomial(vals, regressionOptions);
        var exponentialResult = regression.exponential(vals, regressionOptions);
        var powerResult = regression.power(vals, regressionOptions);
        var logResult = regression.logarithmic(vals, regressionOptions);

        var results = [linearResult, polynomialResult, exponentialResult, powerResult, logResult];
        
        if (forcePositive) {
            for (var i=0; i<results.length; i++) {
                for (var j=0; j<results[i].points.length; j++) {
                    // Special case hack to avoid negative cost functions
                    if (results[i].points[j][1] < 0) {
                        results[i].r2 = 0;
                    }
                }
            }
        }
        
        var result = linearResult;
        result.type = 'linear';
        if (result.equation[0] == 0) {
            result.subtype = 'flat';
            result.string = '' + result.equation[1];
        } else {
            result.subtype = 'notflat';
            if (result.equation[1] == 0) {
                if (result.equation[0] == 1) {
                    result.string = '𝕩';
                } else {
                    result.string = result.equation[0] + '*𝕩';
                }
            } else {
                if (result.equation[0] == 1) {
                    result.string = result.equation[1] + '+' + '𝕩';
                } else {
                    result.string = result.equation[1] + '+' + result.equation[0] + '*𝕩';
                }
            }
        }
        if (exponentialResult.r2 > result.r2) {
            result = exponentialResult;
            result.type = 'exponential';
            result.string = result.equation[0] + '*exp(' + result.equation[1] + '*𝕩)';
        }
        if (powerResult.r2 > result.r2) {
            result = powerResult;
            result.type = 'power';
            result.string = result.equation[0] + '*𝕩^' + result.equation[1];
        }
        if (logResult.r2 > result.r2) {
            result = logResult;
            result.type = 'log';
            result.string = result.equation[0] + '+' + result.equation[1] + '*ln(𝕩)';
        }
        if (polynomialResult.r2 > result.r2) {
            result = polynomialResult;
            result.type = 'polynomial';
            result.string = '';
            for (var i=0; i<result.equation.length; i++) {
                if (i != 0) {
                    result.string += '+';
                }
                result.string += result.equation[i];
                if (i+2 == result.equation.length) {
                    result.string += '*𝕩';
                } else if (i+1 == result.equation.length) {
                    // Nothing for last one - it's just a constant
                } else {
                    result.string += '*𝕩^' + (result.equation.length-i-1);
                }
            }
        }
        
        return result;
    }

    penalty(result) {
        if (result == null) {
            return 0;
        }
        if (result.type=='polynomial') {
            return 6;
        } else if (result.type=='exponential') {
            return 5;
        } else if (result.type=='log') {
            return 4;
        } else if (result.type=='power') {
            return 3;
        } else if (result.subtype=='notflat') {
            return 2;
        } else {
            return 1;
        }
    }

    functionString(str, variableName) {
        return str.replace(/𝕩/g, variableName);
    }

    bestFitMultipartRegression(vals, variableName, forcePositive) {
        var bestR2 = 0;
        var bestOverfitPenalty = 999999;
        var bestResult = null;
        vals = [...vals]; // clone array
        //vals = expandInputPoints(vals, 8);
        vals.sort(function(a,b) {return a[0]-b[0]});
        
        if (forcePositive) {
            for (var i=0; i<vals.length; i++) {
                if (vals[i][1] < 0) {
                    vals[i][1] = 0;
                }
            }
        }
        
        for (var i=0; i<=vals.length; i++) {
            for (var j=i; j<=vals.length; j++) {
                var vals1 = [];
                var vals2 = [];
                var vals3 = [];
                var pivot1 = i==vals.length?Infinity:vals[i][0];
                var pivot2 = j==vals.length?Infinity:vals[j][0];
                
                for (var k=0; k<vals.length; k++) {
                    var val = vals[k][0];
                    if (val == pivot1 && val == pivot2) {
                        vals1.push(vals[k]);
                        vals2.push(vals[k]);
                        vals3.push(vals[k]);
                    } else if (val == pivot1) {
                        vals1.push(vals[k]);
                        vals2.push(vals[k]);
                    } else if (val == pivot2) {
                        vals2.push(vals[k]);
                        vals3.push(vals[k]);
                    } else if (val < pivot1) {
                        vals1.push(vals[k]);
                    } else if (val > pivot1 && val < pivot2) {
                        vals2.push(vals[k]);
                    } else if (val > pivot2) {
                        vals3.push(vals[k]);
                    }
                }
                
                if (pivot1 == pivot2) {
                    vals2 = [];
                }
                
                var result1 = vals1.length > 0 ? this.bestFitRegression(vals1, forcePositive) : null;
                var result2 = (vals2.length > 0 && pivot1 != pivot2) ? this.bestFitRegression(vals2, forcePositive) : null;
                var result3 = vals3.length > 0 ? this.bestFitRegression(vals3, forcePositive) : null;
                
                var ssres = this.SSres(vals1, result1, forcePositive) + this.SSres(vals2, result2, forcePositive) + this.SSres(vals3, result3, forcePositive);
                var sstot = this.SStot(vals1, result1, forcePositive) + this.SStot(vals2, result2, forcePositive) + this.SStot(vals3, result3, forcePositive);
                var r2 = 1.0 - (ssres / sstot);
                if (sstot == 0 && ssres == 0) {
                    r2 = 1.0;
                }
                
                var overfitPenalty = this.penalty(result1) + this.penalty(result2) + this.penalty(result3);
                
                if (bestResult == null || r2 > bestR2 || (r2 == bestR2 && overfitPenalty < bestOverfitPenalty)) {
                    bestR2 = r2;
                    bestOverfitPenalty = overfitPenalty;
                    bestResult = {
                        parts: [],
                        r2: r2,
                    };
                    if (vals2.length == 0 && vals3.length == 0) {
                        bestResult.parts.push({
                            result: result1
                        });
                        bestResult.points = result1.points;
                        bestResult.string = this.functionString(result1.string, variableName);
                    } else if (vals2.length == 0) {
                        bestResult.parts.push({
                            condition: variableName + ' <= ' + pivot1,
                            result: result1
                        });
                        bestResult.parts.push({
                            condition: variableName + ' > ' + pivot1,
                            result: result3
                        });
                        bestResult.points = result1.points.concat([pivot1[0]+0.00000001, NaN]).concat(result3.points); // TODO: +0.00000001 isn't that robust...
                        bestResult.string = 'ifelse( ' + bestResult.parts[0].condition + ' , ' + this.functionString(result1.string, variableName) + ' , ' + this.functionString(result3.string, variableName) + ' )';
                        bestResult.pivot = pivot1;
                    } else {
                        bestResult.parts.push({
                            condition: variableName +' <= ' + pivot1,
                            result: result1
                        });
                        bestResult.parts.push({
                            condition: 'and(' + variableName +' > ' + pivot1 + ', ' + variableName + ' <= ' + pivot2 + ')',
                            result: result2
                        });
                        bestResult.parts.push({
                            condition: variableName +' > ' + pivot2,
                            result: result3
                        });
                        bestResult.points = result1.points.concat([pivot1[0]+0.00000001, NaN]).concat(result2.points).concat([pivot2[0]+0.00000001, NaN]).concat(result3.points);
                        bestResult.string = 'ifelse(' + bestResult.parts[0].condition + ', ' + this.functionString(result1.string, variableName) + ', ifelse(' + bestResult.parts[1].condition + ', ' + this.functionString(result2.string, variableName) + ', ' + this.functionString(result3.string, variableName) + '))';
                        bestResult.pivot1 = pivot1;
                        bestResult.pivot2 = pivot2;
                    }
                }
            }
        }
        if (forcePositive) {
            var needMax = false;
            bestResult.hasMax = false;
            for (var i=0; i<bestResult.points.length; i++) {
                if (bestResult.points[i][1] < 0) {
                    needMax = true;
                    bestResult.points[i][1] = 0;
                }
            }
            if (needMax) {
                bestResult.string = "max(0, " + bestResult.string + ")";
                bestResult.hasMax = true;
            }
        }
        return bestResult;
    }
    

    expandInputPoints(vals, n) {
        var output = [];
        for (var j=0; j<vals.length*n-(n-1); j++) {
            if (j%n==0) {
                output.push(vals[j/n]);
            } else {
                var x = vals[Math.floor(j/n)][0] + (j%n) * ((vals[Math.floor(j/n)+1][0] - vals[Math.floor(j/n)][0]) / n);
                var y = vals[Math.floor(j/n)][1] + (j%n) * ((vals[Math.floor(j/n)+1][1] - vals[Math.floor(j/n)][1]) / n);
                output.push([x, y]);
            }
        }
        return output;
    }
    
    expandFnPoints(result, n, forcePositive) {
        var output = [];
        for (var i=0; i<result.parts.length; i++) {
            var part = result.parts[i];
            var points = part.result.points;
            for (var j=0; j<points.length*n-(n-1); j++) {
                if (j%n==0) {
                    output.push(points[j/n]);
                } else {
                    var x = points[Math.floor(j/n)][0] + (j%n) * ((points[Math.floor(j/n)+1][0] - points[Math.floor(j/n)][0]) / n);
                    var y = part.result.predict(x)[1];
                    if (forcePositive && y < 0) {
                        y = 0;
                    }
                    output.push([x, y]);
                }
            }
        }
        return output;
    }


    regressionUpdate(layerIndex) {
        var vals = new Array(datasets[layerIndex].questions.length);
        for (var i=0; i<datasets[layerIndex].questions.length; i++) {
            vals[i] = [datasets[layerIndex].questions[i]['x'], datasets[layerIndex].questions[i]['y']];
        }
        
        var forcePositive = false;
        var result = this.bestFitMultipartRegression(vals, this.state.layerId, forcePositive);
        var fnPoints = this.expandFnPoints(result, 8, forcePositive);

        var newFn = this.getFn(fnPoints);
        return {equation: result.string, r2: result.r2, fn: newFn};
    }

    
    SSres(inputData, regressionResult, forcePositive) {
        if (inputData == null) {
            return 0;
        }
        var sum = 0;
        for (var i=0; i<inputData.length; i++) {
            var x = inputData[i][0];
            var y = inputData[i][1];
            if (forcePositive && y < 0) {
                y = 0;
            }
            var prediction = regressionResult.predict(x)[1];
            sum += Math.pow(y - prediction, 2);
            
            // Special case hack to avoid negative cost functions
            //if (prediction < 0) {
            //    return Infinity;
            //}
        }
        return sum;
    }

    
    SStot(inputData, regressionResult, forcePositive) {
        if (inputData == null) {
            return 0;
        }
        var sum = 0;
        for (var i=0; i<inputData.length; i++) {
            var y = inputData[i][1];
            if (forcePositive && y < 0) {
                y = 0;
            }
            sum += y;
        }
        var mean = 1.0*sum / inputData.length;
        sum = 0;
        for (var i=0; i<inputData.length; i++) {
            var y = inputData[i][1];
            if (forcePositive && y < 0) {
                y = 0;
            }
            sum += Math.pow(y - mean, 2);
        }
        
        return sum;
    }

    afterChange() {
        var newState = this.regressionUpdate(this.state.layerIndex);
        this.setState(newState);
        if (this.props.onEquationChange) {
            this.props.onEquationChange(newState.equation);
        }
    }
    
}
