import React from 'react';
import * as d3 from "d3";

import './Models.scss';

class NeuralNet extends React.Component {

    constructor(props) {
        super(props)
        var [nodes, features, outputs, nodeSize, netsize, featuresListWidth, outListWidth] = calculations(this.props.structure, this.props.featureNames, this.props.outputNames, this.props.width, this.props.height, true)
        
        this.nodes = nodes
        this.features = features
        this.outputs = outputs
        this.nodeSize = nodeSize
        this.calcProps = {
            netsize: netsize,
            featuresListWidth: featuresListWidth,
            outListWidth: outListWidth
        }

        this.svg_name = "#nnSVG"
        this.state = {
            drawn: false
        }
    }

    componentDidMount() {
        if (this.props.mlpLinks != null) {
            this.updateNodes()
            this.updateOutputs()
            this.drawNN()            
        }
    }

    componentDidUpdate(prevProps, prevState) {
        if ((this.props.width !== prevProps.width)) {
            this.updateNodes()
            this.updateOutputs()
        }
        if ((this.props.mlpLinks != null) && (!this.state.drawn)) {
            this.drawNN()
        }
        else if ((this.props.mlpLinks == null) && (prevProps.mlpLinks != null)) {
            d3.select(this.svg_name).selectAll("*").remove()
            this.setState({
                drawn: false,
            })
        }
        else if ((prevProps.mlpLinks !== this.props.mlpLinks) && this.state.drawn) {
            d3.select(this.svg_name)
                .selectAll(".link")
                .data(this.props.mlpLinks)
                .style("stroke-width", d => Math.abs(d.value) * 5)
                .attr("class", d => (d.value >= 0) ? 'link positive' : 'link negative')
                .select("title")
                .text(d => "Weight is  " + d.value)
        }
        else if ((prevProps.currentHighlight !== this.props.currentHighlight) && this.state.drawn) {
            var currentHighlight = this.props.currentHighlight

            d3.select(this.svg_name).selectAll(".node")
                .attr("class", d => (d.layer === currentHighlight) ? "node neuron highlight" : "node neuron")
        }
    }

    updateNodes() {
        var xdist = (this.props.width - this.calcProps.featuresListWidth - this.calcProps.outListWidth*2) / Object.keys(this.calcProps.netsize).length
        var featuresListWidth = this.calcProps.featuresListWidth
        this.nodes.map(function (d) {
            d["x"] = (d.layer - 0.5) * xdist + featuresListWidth
            return null
        });
    }

    updateOutputs() {
        var xdist = (this.props.width - this.calcProps.featuresListWidth - this.calcProps.outListWidth*2) / Object.keys(this.calcProps.netsize).length
        var featuresListWidth = this.calcProps.featuresListWidth
        var outListWidth = this.calcProps.outListWidth
        this.outputs.map(function (d) {
            d["x"] = featuresListWidth + (d.layer) * xdist
            return null
        });
    }
    

    handleHover(d){
        if (d.layer === this.props.currentHighlight && this.props.onHover != null){
            this.props.onHover(d.lidx - 1)
        }
    }

    handleMouseOut(d){
        if (this.props.onHover != null){
            this.props.onHover(null)
        }
    }

    drawNN() {
        this.setState({
            drawn: true,
        })
        const svg = d3.select(this.svg_name)
        var viewBox = "0 0 " + this.props.width.toString() + " " + this.props.height.toString()

        const update = svg.attr("preserveAspectRatio", "xMinYMin meet")
            .attr("viewBox", viewBox)
            .classed("svg-content", true);

        var verticalLink = d3.linkVertical()
            .x(d => d.x)
            .y(d => d.y)

        var nodes = this.nodes

        update.append("g")
            .selectAll(".link")
            .data(this.props.mlpLinks)
            .enter().append("path")
            .attr("class", "link")
            .attr("id", d => "link" + d.source.toString() + "-" + d.target.toString())
            .attr('d', function (d) {
                let datum = {
                    source: {
                        x: nodes[d.source].x,
                        y: nodes[d.source].y
                    },
                    target: {
                        x: nodes[d.target].x,
                        y: nodes[d.target].y
                    }
                };
                return verticalLink(datum, 0)
            })
            .attr("fill", "none")
            .style("stroke-width", d => Math.abs(d.value) * 5)
            .attr("class", d => (d.value >= 0) ? 'link positive' : 'link negative')
            .append("svg:title")
            .text(d => "Weight is " + d.value)

        // draw nodes
        var currentHighlight = this.props.currentHighlight
        update.selectAll(".node")
            .data(this.nodes)
            .enter().append("g")
            .attr("transform", d => "translate(" + d.x + "," + d.y + ")")
            .append("circle")
            .attr("r", this.nodeSize)
            // .style("fill", d => (d.layer === currentHighlight) ? "rgb(108, 181, 217)" : color[d.layer - 1])
            .attr("class", d => (d.layer === currentHighlight) ? "node neuron highlight" : "node neuron")
            .on("mouseover", d => this.handleHover(d))
            .on("mouseout", d => this.handleMouseOut(d))

        update.selectAll(".feat")
            .data(this.features)
            .enter()
            .append("g")
            .attr("transform", d => "translate(" + d.x + "," + d.y + ")")
            .append("text")
            .attr('class', 'label')
            .attr('text-anchor', 'end')
            .text(d => d.name)

        update.selectAll(".out")
            .data(this.outputs)
            .enter()
            .append("g")
            .attr("transform", d => "translate(" + d.x + "," + d.y + ")")
            .append("text")
            .attr('class', 'label')
            .attr('text-anchor', 'start')
            .text(d => d.name)

        // Remove old D3 elements
        update.exit()
            .remove();
    }

    render() {
        return (
            <React.Fragment>
                <svg id="nnSVG" width={this.props.width} height={this.props.height}></svg>
            </React.Fragment>
        )
    }
}



/**
 * Compute the static layout of the network diagram.
 *
 * @param {number[]} structure - neuron count of each layer, input layer first.
 * @param {string[]} featureNames - labels for the first-layer neurons, by position.
 * @param {string[]} outputNames - labels for the last-layer neurons, by position.
 * @param {number} width - drawing width.
 * @param {number} height - drawing height.
 * @param {boolean} [calcProps=false] - also return the values needed to
 *   recompute horizontal positions later.
 * @returns {Array} `[nodes, features, outputs, nodeSize]`, followed by
 *   `netsize, featuresListWidth, outListWidth` when `calcProps` is true.
 *   `nodes`, `features` and `outputs` hold objects with 1-based `layer` and
 *   `lidx` plus computed `x`/`y` (features and outputs also carry `name`);
 *   `netsize` maps the 1-based layer number to its neuron count.
 */
function calculations(structure, featureNames, outputNames, width, height, calcProps = false) {
    // Reserve horizontal room for the widest input and output label (0 if there are none).
    const featuresListWidth = widestLabelWidth(featureNames)
    const outListWidth = widestLabelWidth(outputNames)

    // build network structure
    const netsize = {}
    const nodes = []
    const features = []
    const outputs = []
    const lastLayer = structure.length - 1
    structure.forEach((layerNodes, layer) => {
        netsize[layer + 1] = layerNodes

        for (let n = 0; n < layerNodes; n++) {
            if (layer === 0) {
                features.push({ 'layer': layer + 1, 'lidx': n + 1, 'name': featureNames[n] })
            }
            nodes.push({ 'layer': layer + 1, 'lidx': n + 1 })
            if (layer === lastLayer) {
                outputs.push({ 'layer': layer + 1, 'lidx': n + 1, 'name': outputNames[n] })
            }
        }
    })

    const layerCount = Object.keys(netsize).length
    const largestLayerSize = Math.max(...Object.keys(netsize).map(i => netsize[i]))

    // Neuron radius scales with the vertical room per neuron in the tallest layer.
    const nodeSize = (height / largestLayerSize) * 0.35

    const xdist = (width - featuresListWidth - outListWidth * 2) / layerCount
    const ydist = height / largestLayerSize
    // Vertical slot of an element; shorter layers are centred against the tallest one.
    const yOf = d => (((d.lidx - 0.5) + ((largestLayerSize - netsize[d.layer]) / 2)) * ydist)

    nodes.forEach(d => {
        d.y = yOf(d)
        d.x = (d.layer - 0.5) * xdist + featuresListWidth
    })
    features.forEach(d => {
        d.y = yOf(d)
        d.x = featuresListWidth
    })
    outputs.forEach(d => {
        d.y = yOf(d)
        d.x = featuresListWidth + d.layer * xdist
    })

    if (calcProps) {
        return [nodes, features, outputs, nodeSize, netsize, featuresListWidth, outListWidth]
    }
    return [nodes, features, outputs, nodeSize]
}

/**
 * Measured width of the widest entry in `names`, upper-cased in 12px
 * sans-serif. Measures every label rather than the one with the most
 * characters, since character count does not determine rendered width.
 * Returns 0 for an empty list.
 */
function widestLabelWidth(names) {
    return names.reduce(
        (widest, name) => Math.max(widest, getWidthOfText(String(name).toUpperCase(), 'sans-serif', '12px')),
        0
    )
}


/**
 * Measure how wide `txt` renders in the given font.
 *
 * One off-screen canvas and its 2D context are created on first use and
 * cached on the function itself (`getWidthOfText.c` / `getWidthOfText.ctx`),
 * so later calls only set the font and measure.
 *
 * @param {string} txt - text to measure.
 * @param {string} fontname - font family, e.g. 'sans-serif'.
 * @param {string} fontsize - CSS font size, e.g. '12px'.
 * @returns {number} the `width` reported by `measureText`.
 */
function getWidthOfText(txt, fontname, fontsize) {
    if (typeof getWidthOfText.c === 'undefined') {
        const canvas = document.createElement('canvas');
        getWidthOfText.c = canvas;
        getWidthOfText.ctx = canvas.getContext('2d');
    }
    const { ctx } = getWidthOfText;
    ctx.font = `${fontsize} ${fontname}`;
    const metrics = ctx.measureText(txt);
    return metrics.width;
}

export default NeuralNet;