import React, { useLayoutEffect, useRef, useState } from 'react'
import { timeFormat } from 'd3-time-format'
import { select, event as d3_event } from 'd3-selection'
import { scaleTime, scaleLinear } from 'd3-scale'
import { axisBottom, axisLeft } from 'd3-axis'
import { drag } from 'd3-drag'
import { line } from 'd3-shape'
import { easeLinear } from 'd3-ease'
import { csv } from 'd3-fetch'
import 'd3-svg-annotation'

import styles from './TeleworkabilityVisualization.module.scss'
import Button from 'components/common/Button'

// Local "d3" namespace assembled from the individually imported d3 modules,
// so the chart code below can keep the familiar `d3.*` call style.
const d3 = {
  // d3-time-format, d3-selection
  timeFormat, select,
  // d3-scale, d3-axis
  scaleTime, scaleLinear, axisBottom, axisLeft,
  // d3-shape, d3-ease, d3-drag, d3-fetch
  line, easeLinear, drag, csv
}

// URL prefix the sector CSV files are fetched from.
const data_root = '/datasets/teleworkability-visualization/'

const TeleworkabilityVisualization = () =>{
  const refVis = useRef(null)
  const refPlayButton = useRef(null)
  const refBase = useRef(null)
  const refFunctions = useRef({})

  const [selected, setSelected] = useState('nonmedical')
  const [playing, setPlaying] = useState(false)

  useLayoutEffect(() => {
    var formatDate = d3.timeFormat("%Y-%m-%d")
    var formatDate_long = d3.timeFormat("%a, \n%B %-d")
    var formatDate_ticks = d3.timeFormat("%B %-d")

    var startDate = new Date("2020-03-08")
    var endDate = new Date("2020-06-29")

    var margin = { top: 60, right: 50, bottom: 0, left: 50 }
    var width = 500 - margin.left - margin.right
    var height = 500 - margin.top - margin.bottom

    var vertical_pad = 60

    var accent_color = '#C15353'
    var main_color = 'darkgray'
    var dot_label_color = 'black'
    var scatterdot_r = 4

    var hide_y_axis = false
    var hide_lin_reg = false

    // Define the div for the tooltip
    var div = d3.select(refBase.current)
      .append("div")
      .attr("class", "tooltip")
      .style("opacity", 0)

    var svg = d3.select(refVis.current)
      .append("svg")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)
    svg.append("text")
      .attr("x", 60)
      .attr("y", 0 + (margin.top + 40))
      .attr("text-anchor", "left")
      .attr("font-weight", 300)
      // .text("Insured Unemployment Rate vs Foot Traffic in ")
      .text(function () { return  (!hide_y_axis ? "Insured Unemployment Rate vs Foot Traffic in " : "Foot Traffc in ") } )
      .attr("font-size", "12px")
      .append("tspan")
      .attr("font-weight", 700)
      .text("Low Telework States")
      .attr("font-size", "12px")

    var svg_HIGH = d3.select(refVis.current)
      .append("svg")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)
        .attr("transform", "translate(0," + (vertical_pad) + ")")

    svg_HIGH.append("text")
      .attr("x", 60)
      .attr("y", 0 + margin.top - 15)
      .attr("text-anchor", "left")
      .attr("font-weight", 300)
      // .text("Insured Unemployment Rate vs Foot Traffic in ")
      .text(function () { return  (!hide_y_axis ? "Insured Unemployment Rate vs Foot Traffic in " : "Foot Traffc in ") } )
      .attr("font-size", "12px")
      .append("tspan")
      .attr("font-weight", 700)
      .text("High Telework States")
      .attr("font-size", "12px")

    ////////// slider //////////
    var axisLabelColor = '#929292'
    // var moving = false
    var currentValue = 0
    var targetValue = width

    var x_slider = d3.scaleTime()
      .domain([startDate, endDate])
      .range([0, targetValue])
      .clamp(true)

    var red_rect_w = 267

    var max_x
    var min_x
    if (hide_y_axis){
      max_x = 1.5
      min_x = 0
    } else {
      max_x = 1.5
      min_x = 0
    }
    // Add X axis
    var x_dots = d3.scaleLinear()
      .domain([min_x, max_x])
      .range([ 0, width ])

    svg.append("g")
      .attr("transform", "translate(" + margin.left + "," + (height) + ")")
      .attr("class", "axisGray")
      .call(d3.axisBottom(x_dots))

    svg_HIGH.append("g")
      .attr("transform", "translate(" + margin.left + "," + (height-vertical_pad) + ")")
      .attr("class", "axisGray")
      .call(d3.axisBottom(x_dots))

    // text label for the x axis
    svg.append("text")
      .attr("transform",
            "translate(" + (width/2 + 40) + " ," +
                          (height + 40) + ")")
      .style("text-anchor", "middle")
      .style("fill", axisLabelColor)
      .text("Foot traffic index")

    svg_HIGH.append("text")
      .attr("transform",
            "translate(" + (width/2 + 40) + " ," +
                          (height - 20) + ")")
      .style("text-anchor", "middle")
      .style("fill", axisLabelColor)
      .text("Foot traffic index")

    // Add Y axis
    var y = d3.scaleLinear()
      .domain([30, 0])
      .range([ vertical_pad, height-vertical_pad])

    svg.append("g")
      .attr("transform", "translate(" + margin.left + "," + (vertical_pad) + ")")
      .attr("class", "axisGray")
      .call(d3.axisLeft(y))

    svg_HIGH.append("g")
      .attr("transform", "translate(" + margin.left + ", 0)")
      .attr("class", "axisGray")
      .call(d3.axisLeft(y))

    // text label for the y axis
    svg.append("text")
      .attr("transform", "rotate(-90)")
      .attr("y", 0)
      .attr("x", 0 - (height/1.57))
      .attr("dy", "1em")
      .style("text-anchor", "middle")
      // .text("Insured unemployment rate")
      .text("IUR")
      .style("fill", axisLabelColor)

    svg_HIGH.append("text")
      .attr("transform", "rotate(-90)")
      .attr("y", 0)
      .attr("x", 0 - (height/2))
      .attr("dy", "1em")
      .style("text-anchor", "middle")
      // .text("Insured unemployment rate")
      .text("IUR")
      .style("fill", axisLabelColor)

    // rectangle annotation
    svg.append("rect")
      .attr("x", x_dots(min_x) + 50)
      .attr("y", y(0) - 260)
      .attr("width", red_rect_w)
      .attr("height", 320)
      .attr("opacity", 0.06)
      .style("fill", "red")

    svg_HIGH.append("rect")
      .attr("x", x_dots(min_x) + 50)
      .attr("y", y(0) - 320)
      .attr("width", red_rect_w)
      .attr("height", 320)
      .attr("opacity", 0.06)
      .style("fill", "red")

    var slider = svg.append("g")
      .attr("class", "slider")
      .attr("transform", "translate(" + margin.left + "," + margin.top/1.2 + ")")

    slider.append("line")
      .attr("class", "track")
      .attr("x1", x_slider.range()[0])
      .attr("x2", x_slider.range()[1])
    .select(function() { return this.parentNode.appendChild(this.cloneNode(true)) })
      .attr("class", "track-inset")
    .select(function() { return this.parentNode.appendChild(this.cloneNode(true)) })
      .attr("class", "track-overlay")
      .call(d3.drag()
          .on("start.interrupt", function() { slider.interrupt() })
          .on("start drag", function() {
            currentValue = d3_event.x
            update(x_slider.invert(currentValue))
          })
      )

    slider.insert("g", ".track-overlay")
      .attr("class", "ticks")
      .attr("transform", "translate(0," + -18 + ")")
    .selectAll("text")
      .data(x_slider.ticks(5.5))
      .enter()
      .append("text")
      .attr("x", x_slider)
      .attr("y", 10)
      .attr("font-size", "8px")
      .style("fill", "black")
      .attr("text-anchor", "middle")
      .text(function(d) { return formatDate_ticks(d) })

    var handle = slider.insert("circle", ".track-overlay")
      .attr("class", "handle")
      .attr("r", 6)

    var label = slider.append("text")
      .attr("class", "label")
      .attr("text-anchor", "middle")
      //.text(formatDate(startDate))
      .text(formatDate_long(startDate))
      .style("fill", "black")
      .attr("transform", "translate("+15+", "+ (-25) + ")")

    // gridlines in x axis function
    function make_x_gridlines () {
      return d3.axisBottom(x_dots)
        .ticks(5)
    }

    // gridlines in y axis function
    function make_y_gridlines () {
      return d3.axisLeft(y)
        .ticks(5)
    }

    // add the X gridlines
    svg.append("g")
      .attr("class", "grid")
      .attr("transform", "translate(" + (margin.left) + "," + (height) + ")")
      .call(make_x_gridlines()
        .tickSize(-height+vertical_pad*2)
        .tickFormat("")
      )
    svg_HIGH.append("g")
      .attr("class", "grid")
      .attr("transform", "translate(" + margin.left + "," + (height-vertical_pad) + ")")
      .call(make_x_gridlines()
        .tickSize(-height+vertical_pad*2)
        .tickFormat("")
      )

    // add the Y gridlines
    svg.append("g")
      .attr("class", "grid")
      .attr("transform", "translate(" + (margin.left) + "," + (60) + ")")
      .call(make_y_gridlines()
        .tickSize(-width)
        .tickFormat("")
      )
    svg_HIGH.append("g")
      .attr("class", "grid")
      .attr("transform", "translate(" + margin.left + ", 0)")
      .call(make_y_gridlines()
        .tickSize(-width)
        .tickFormat("")
      )

    ////////// plot //////////

    var dataset
    // var stay_at_home = []
    // var state_opacity = 1

    var ploty = 0

    var plot = svg.append("g")
      .attr("class", "plot")
      .attr("transform", "translate(" + margin.left + "," + vertical_pad + ")")

    var plot_HIGH = svg_HIGH.append("g")
      .attr("class", "plot")
      .attr("id", "plot_HIGH")
      .attr("transform", "translate(" + margin.left + "," + (ploty) + ")")
    // } // makeChart end

    function lin_regression(dataset){
      // regression line
      let n = dataset.length
      let x_mean = 0
      let x_total = 0
      let term1 = 0
      let y_mean = 0
      let y_total = 0
      let term2 = 0

      let x_reg = dataset.map(x=>x.FT_change)
      let y_reg = dataset.map(x=>x.IU_rate)

      // calculate mean x and y
      for(var i = 0; i < n; i++) {
        x_total = x_total + x_reg[i]
        y_total += y_reg[i]
      }

      x_mean = x_total / n
      y_mean = y_total / n

      // calculate coefficients
      let xr = 0
      let yr = 0
      for (i = 0; i < x_reg.length; i++) {
        xr = x_reg[i] - x_mean
        yr = y_reg[i] - y_mean
        term1 += xr * yr
        term2 += xr * xr
      }
      let b1 = term1 / term2
      let b0 = y_mean - (b1 * x_mean)
      // perform regression
      let yhat = []
      // fit line using coeffs
      for (i = 0; i < x_reg.length; i++) {
        yhat.push(b0 + (x_reg[i] * b1))
      }

      var data_reg = []
      for (i = 0; i < y_reg.length; i++) {
        data_reg.push({
            "yhat": yhat[i],
            "y": y_reg[i],
            "x": x_reg[i]
        })
      }
      return data_reg
    }

    function filter_data(_date, data){
      // Returns data for a certain date
      // var dates = [...new Set(data.map(x=>x.date))]
      // var date_ind = dates.indexOf(_date) + 1
      var data_filt = data.filter(x => x.date == _date) // eslint-disable-line eqeqeq
      var data_obj = data_filt.map( function (x) {
        return  {
          'telework': +x.mean_teleworkable,
          'TW_bin' : x.TW_bin,
          'FT_change': +x.rel_visits_dowg,
          'IU_rate' : +x.insured_unemployment_rate_DOL,
          'state': x.stabr,
          'date' : x.date,
          'key' : x.stabr
        }
      })
      return data_obj
    }

    function key (d) { return d.key }

    var regressionLine = d3.line()
      .x(function (d) {
        return x_dots(d.x)
      })
      .y(function (d) {
        return y(d.yhat)
      })

    var timer
    var isPlaying = false
    function new_csv (data) {
      dataset = data
      formatDate = d3.timeFormat("%Y-%m-%d")
      var cur_date = formatDate(x_slider.invert(currentValue))
      var data_filt = filter_data(cur_date, dataset)

      drawPlot(data_filt)

      var data_reg = lin_regression(data_filt.filter(x => x.TW_bin === 'low'))
      var data_reg_HIGH = lin_regression(data_filt.filter(x => x.TW_bin === 'high'))
      if (!hide_lin_reg){
      svg.append("path")
        .datum(data_reg)
        .attr("class", "line")
        .attr("transform", "translate(" + (margin.left) + "," + vertical_pad + ")")
        .attr("d", regressionLine)

      svg_HIGH.append("path")
        .datum(data_reg_HIGH)
        .attr("class", "line")
        .attr("transform", "translate(" + (margin.left) + ", 0)")
        .attr("d", regressionLine)
      }

      function play () {
        clearInterval(timer)
        if (isPlaying) {
          isPlaying = false
          setPlaying(false)
        } else {
          timer = setInterval(step, 200)
          isPlaying = true
          setPlaying(true)
        }
      }

      refFunctions.current.play = play
    }

    refFunctions.current.new_csv = new_csv

    function step () {
      update(x_slider.invert(currentValue))

      currentValue = parseInt(currentValue + Math.floor(targetValue/50))
      if (currentValue > targetValue) {
        currentValue = 0
        clearInterval(timer)
        isPlaying = false
        setPlaying(false)
      }
    }

    function drawPlot (dataset) {
      var dots = plot.selectAll(".scatterdot")
        .data(dataset.filter(x => x.TW_bin === 'low'), key)

      var dots_HIGH = plot_HIGH.selectAll(".scatterdot")
        .data(dataset.filter(x => x.TW_bin === 'high'), key)

      svg.selectAll(".dodo").exit().remove()
      svg_HIGH.selectAll(".dodo").exit().remove()

      svg.selectAll(".dodo")
        .data(dataset.filter(x => x.TW_bin === 'low'), key)
      .enter().append("text")
        .attr("class", "dodo")
          .attr("x", function (d, i) { return x_dots(d.FT_change)+45 } )
          .attr("y", function (d) { return y(d.IU_rate) } )
          .style("font-size",  "8px")
          .text(function(d) { return d.state })
          .attr("transform", "translate(0," + vertical_pad + ")")

      svg_HIGH.selectAll(".dodo")
        .data(dataset.filter(x => x.TW_bin === 'high'), key)
      .enter().append("text")
        .attr("class", "dodo")
          .attr("x", function (d, i) { return x_dots(d.FT_change)+45 } )
          .attr("y", function (d) { return y(d.IU_rate) } )
          .style("font-size",  "8px")
          .text(function(d) { return d.state })

      // if filtered dataset has more circles than already existing, transition new ones in
      dots.enter()
        .append("circle")
          .attr("class", function (d) { return ("scatterdot "+d.state) })
          .attr("r", scatterdot_r)
          .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : main_color} )
          .style("opacity", 0.8)
          .attr("cx", function (d, i) { return x_dots(d.FT_change) } )
          .attr("cy", function (d, i) { return hide_y_axis ? y(0) : y(d.IU_rate) } )
        .on("mouseover", function(d) {
          d3.select(this)
            .attr("r", scatterdot_r*3)
          plot.selectAll(".scatterdot")
            .style("opacity", .2)
            .attr("opacity", .2)
          plot.selectAll(".scatterdot."+d.key)
            .style("opacity", 1)
            .attr("opacity", 1)
          svg.selectAll(".dodo")
            .style("opacity", .2)
            .attr("opacity", .2)
          svg.selectAll(".dodo."+d.key)
            .style("opacity", 1)
            .attr("opacity", 1)
          div.transition()
            .duration(50)
            .style("opacity", .95)
          div.html(d.state)  // div .html
            .style("left", (d3_event.pageX) + "px")
            .style("top", (d3_event.pageY - 15) + "px")
        })
        .on("mouseout", function(d) {
          d3.select(this).transition(300)
            .attr("r", scatterdot_r)
          plot.selectAll(".scatterdot")
            .style("opacity", 1)
            .attr("opacity", 1)
          svg.selectAll(".dodo")
            .style("opacity", 1)
            .attr("opacity", 1)
          div.transition()
            .duration(800)
            .style("opacity", 0)
        })

      //update all circles to new positions
      dots.transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("cy", function (d) { return y(d.IU_rate) } )
        .attr("cx", function (d) { return x_dots(d.FT_change) } )
        .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : main_color} )

      svg.selectAll(".dodo").transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("x", function (d) { return x_dots(d.FT_change)+45 } )
        .attr("y", function (d) { return y(d.IU_rate)-7 } )
        .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : dot_label_color} )

      // HIGH tw
      dots_HIGH.enter()
        .append("circle")
          .attr("class", function (d) { return ("scatterdot "+d.state) })
          .attr("r", scatterdot_r)
          .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : main_color} )
          .style("opacity", 0.8)
          .attr("cx", function (d, i) { return x_dots(d.FT_change) } )
          .attr("cy", function (d, i) { return hide_y_axis ? y(0) : y(d.IU_rate) } )
        .on("mouseover", function(d) {
          d3.select(this)
            .attr("r", scatterdot_r*3)
          plot_HIGH.selectAll(".scatterdot")
            .style("opacity", .2)
            .attr("opacity", .2)
          plot_HIGH.selectAll(".scatterdot." + d.key)
            .style("opacity", 1)
            .attr("opacity", 1)
          svg_HIGH.selectAll(".dodo")
            .style("opacity", .2)
            .attr("opacity", .2)
          svg_HIGH.selectAll(".dodo." + d.key)
            .style("opacity", 1)
            .attr("opacity", 1)
          div.transition()
            .duration(50)
            .style("opacity", .95)
          div.html(d.state) // div .html
            .style("left", (d3_event.pageX) + "px")
            .style("top", (d3_event.pageY - 15) + "px")
        })
        .on("mouseout", function(d) {
          d3.select(this).transition(300)
            .attr("r", scatterdot_r)
          plot_HIGH.selectAll(".scatterdot")
            .style("opacity", 1)
            .attr("opacity", 1)
          svg_HIGH.selectAll(".dodo")
            .style("opacity", 1)
            .attr("opacity", 1)
          div.transition()
            .duration(800)
            .style("opacity", 0)
        })

      //update all circles to new positions
      dots_HIGH.transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("cx", function (d) { return x_dots(d.FT_change) } )
        .attr("cy", function (d) { return y(d.IU_rate) } )
        .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : main_color} )

      svg_HIGH.selectAll(".dodo").transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("x", function (d) { return x_dots(d.FT_change)+45 } )
        .attr("y", function (d) { return y(d.IU_rate)-7 } )
        .style('fill', function (d, i) { return d.FT_change<1 ? accent_color : dot_label_color} )

      var data_reg = lin_regression(dataset.filter(x => x.TW_bin === 'low'))
      var data_reg_HIGH = lin_regression(dataset.filter(x => x.TW_bin === 'high'))

      svg.selectAll('.line')
        .datum(data_reg)
        .transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("d", regressionLine)
        .attr("transform", "translate(" + (margin.left) + "," + vertical_pad + ")")

      svg_HIGH.selectAll('.line')
        .datum(data_reg_HIGH)
        .transition()
        .duration(200)
        .ease(d3.easeLinear)
        .attr("d", regressionLine)

        dots.exit().remove()
        dots_HIGH.exit().remove()
        svg.selectAll(".dodo").exit().remove()
        svg_HIGH.selectAll(".dodo").exit().remove()

      if (hide_y_axis) {
        dots_HIGH
        .transition()
        .duration(200)
        .ease(d3.easeLinear)
          .attr("cy", function (d, i) { return y(0)} )
        dots
        .transition()
        .duration(200)
        .ease(d3.easeLinear)
          .attr("cy", function (d, i) { return y(0)} )
        svg.selectAll(".dodo")
        .transition()
          .attr("y", function (d) { return y(0)-7 } )
        svg_HIGH.selectAll(".dodo")
        .transition()
          .attr("y", function (d) { return y(0)-7 } )
      }
    }

    function update (h) {
      // update position and text of label according to slider scale
      handle.attr("cx", x_slider(h))
      label
        .attr("x", x_slider(h))
        //.text(formatDate(h))
        .text(formatDate_long(h))
      var d = formatDate(h).toString()

      // filter data set and redraw plot
      var new_data = filter_data(d, dataset)
      if (new_data.length > 0){
        drawPlot(new_data)
      }
    }

    // Start with entertainment
    d3.csv(data_root+"foot_traffic_NONMEDICAL_insured_unemployment_rate_TW_binned.csv").then(function (data) {
      new_csv(data)
    })
  }, [])

  async function entertainment_csv () {
    const data = await d3.csv(data_root + 'foot_traffic_ENTERTAINMENT_insured_unemployment_rate_TW_binned.csv')
    refFunctions.current.new_csv(data)
  }
  async function retail_csv () {
    const data = await d3.csv(data_root + 'foot_traffic_RETAIL_insured_unemployment_rate_TW_binned.csv')
    refFunctions.current.new_csv(data)
  }
  async function restaurant_csv () {
    const data = await d3.csv(data_root + 'foot_traffic_RESTAURANT_insured_unemployment_rate_TW_binned.csv')
    refFunctions.current.new_csv(data)
  }
  async function nonmedical_csv () {
    const data = await d3.csv(data_root + 'foot_traffic_NONMEDICAL_insured_unemployment_rate_TW_binned.csv')
    refFunctions.current.new_csv(data)
  }

  return <div className={styles.base} ref={refBase}>
    <h1>Explore how foot traffic impacts unemployment across the course of the pandemic by industry</h1>
    <h3>Sources: Department of Labor, SafeGraph, U.S. Bureau of Labor Statistics</h3>
    <div>
      <p>Filter foot traffic by sector</p>
    </div>
    <div id="mobility-buttons">
      <Button id="nonmedical-button"
        buttonStyle={selected=== 'nonmedical' ? 'filled' : 'bordered'}
        onClick={() => { nonmedical_csv(); setSelected('nonmedical') }}>General non-medical foot traffic</Button>
      <Button id="entertainment-button"
        buttonStyle={selected=== 'entertainment' ? 'filled' : 'bordered'}
        onClick={() => { entertainment_csv(); setSelected('entertainment') }}>Foot traffic to Entertainment</Button>
      <Button id="retail-button"
        buttonStyle={selected=== 'retail' ? 'filled' : 'bordered'}
        onClick={() => { retail_csv(); setSelected('retail') }}>Foot traffic to Retail</Button>
      <Button id="restaurant-button"
        buttonStyle={selected=== 'restaurant' ? 'filled' : 'bordered'}
        onClick={() => { restaurant_csv(); setSelected('restaurant') }}>Foot traffic to Restaurant</Button>
      <Button id="play-button" shape="square" ref={refPlayButton}
        onClick={() => { refFunctions.current.play() }}>{playing ? '❚❚' : '▶'}</Button>
    </div>
    <br />
    {/* <p id='title'> Foot Traffic vs Insured Employment Rate in <b>High Telework Areas</b> </p> */}
    <div ref={refVis} id="vis" />
  </div>
}

export default TeleworkabilityVisualization
