import React, { useEffect, useRef, useState } from 'react';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
import * as utils from './utils'
import { globals } from './globals';
import pako from 'pako';
import { Line2 } from 'three/examples/jsm/lines/Line2';
import { LineMaterial } from 'three/examples/jsm/lines/LineMaterial';
import { LineGeometry } from 'three/examples/jsm/lines/LineGeometry';

// Module-level three.js objects shared by every mount of the modal.
// They are created once at import time and mutated (never reassigned) later,
// with the exception of `modalControls`, which is assigned when the modal mounts.
const pointer = new THREE.Vector2();
const modalScene = new THREE.Scene();

// Orthographic camera whose frustum is centred on the origin and spans the
// visual viewport. (An earlier revision used a PerspectiveCamera with a
// 30 degree field of view and the same near/far planes.)
const { width: modalViewportWidth, height: modalViewportHeight } = window.visualViewport;
const modalCamera = new THREE.OrthographicCamera(
    modalViewportWidth / -2,   // left
    modalViewportWidth / 2,    // right
    modalViewportHeight / 2,   // top
    modalViewportHeight / -2,  // bottom
    0.1,                       // near
    1000                       // far
);
modalCamera.zoom = 15;
// zoom only takes effect once the projection matrix is rebuilt
modalCamera.updateProjectionMatrix();

const modalRaycaster = new THREE.Raycaster();
const modalRenderer = new THREE.WebGLRenderer({ antialias: true });

// OrbitControls instance; created inside the component effect
let modalControls;


///////////////////////////////////////////////////
// actgrid tooltip

// Side length (in cells) of the square cursor tooltip; must be odd, 1 is allowed.
const actgrid_tooltip_dim = 1;
/**
 * Build an n x n matrix in which every cell holds the same value.
 * Each row is a distinct array, so rows can be mutated independently.
 *
 * @param {number} n - number of rows and columns
 * @param {*} v - value placed in every cell
 * @returns {Array<Array<*>>} the filled square matrix
 */
function get_tooltip_matrix(n, v){
  const make_row = () => Array(n).fill(v)
  return Array.from(Array(n), make_row)
}
// Zero-filled placeholder matrix for the cursor tooltip, converted by the
// utils helper into the flat buffer that backs the tooltip texture.
const actgrid_tooltip_data = get_tooltip_matrix(actgrid_tooltip_dim, 0);
const tooltip_placeholder_rgb = utils.z_score_matrix_to_flattened_rgb_data_for_texture(actgrid_tooltip_data);

// Data texture (RGBA format) with one texel per tooltip cell
const actgrid_tooltip_texture = new THREE.DataTexture(
  tooltip_placeholder_rgb,
  actgrid_tooltip_dim,
  actgrid_tooltip_dim,
  THREE.RGBAFormat
);
actgrid_tooltip_texture.flipY = true;
// flag the texture so the placeholder data gets uploaded
actgrid_tooltip_texture.needsUpdate = true;
/**
 * Overwrite the pixel buffer of an existing data texture with the colours
 * produced for `data`, then flag the texture for re-upload.
 *
 * @param {Array<Array<number>>} data - matrix handed to the utils colour conversion
 * @param {object} texture - texture whose `image.data` typed array is overwritten in place
 */
const updateActgridTooltipTextureData = (data, texture) => {
  const pixel_buffer = texture.image.data
  pixel_buffer.set(utils.z_score_matrix_to_flattened_rgb_data_for_texture(data))
  texture.needsUpdate = true
}
// The tooltip quad is drawn larger than the cells it represents
const actgrid_tooltip_multiplier = 10;
const actgrid_tooltip_side = actgrid_tooltip_dim*globals.act_cube_size*actgrid_tooltip_multiplier;

// Square plane textured with the tooltip data texture
const actgrid_tooltip_geometry = new THREE.PlaneGeometry(actgrid_tooltip_side, actgrid_tooltip_side);
const actgrid_tooltip_material = new THREE.MeshBasicMaterial({
  map: actgrid_tooltip_texture,
  transparent: true,
});

const actgrid_tooltip_mesh = new THREE.Mesh(actgrid_tooltip_geometry, actgrid_tooltip_material);
// rotate a quarter turn about x and raise the quad to y = 10
actgrid_tooltip_mesh.rotation.x = -Math.PI / 2;
actgrid_tooltip_mesh.position.y = 10;

// Colour of the connector lines drawn by the weight overlays
const lineColor = new THREE.Color('grey');
// 
/**
 * Build one weight overlay per upstream channel.
 *
 * Each overlay is a THREE.Group containing
 *   - four connector lines running from the corners of the magnified kernel
 *     plane down to the corners of a kernel-sized square at z = 0, and
 *   - a sub-group holding the textured kernel plane plus a solid backing box,
 *     lifted by `spacer` along the group's local z axis.
 * The overlays are only created here; they are not added to any scene.
 *
 * @param {number} kernel_dim - side length of the (square) kernel in cells
 * @param {number} n_upstream_channels - number of overlays to create
 * @param {number} upstreamChannelSize - upstream channel side length; drives the
 *   magnification, spacer and padding via `utils.interp`
 * @returns {Array<{mesh: THREE.Group, weights_texture: THREE.DataTexture}>}
 *   one entry per channel; write new pixel data into `weights_texture` to
 *   change what the overlay shows
 */
function get_actgrid_overlays_pool(kernel_dim, n_upstream_channels, upstreamChannelSize) {
    // first two output values are both 10 on purpose ("trying to see better")
    let overlay_multiplier = utils.interp((upstreamChannelSize), [12,100,500], [10, 10, 20])
    overlay_multiplier /= kernel_dim

    const spacer = utils.interp((upstreamChannelSize), [12,100], [.3, 1])
    const padder = utils.interp((upstreamChannelSize), [12,100], [.01, .03])

    console.log("overlay multiplier", overlay_multiplier)
    console.log("spacer", spacer)

    // Everything below is identical for every channel, so compute it once.
    const overlay_side = kernel_dim*globals.act_cube_size*overlay_multiplier
    const b = overlay_side/2                          // half side of the magnified plane
    const a = (kernel_dim*globals.act_cube_size) / 2  // half side of the kernel footprint
    const z_start = spacer-.01  // just below the magnified plane
    const z_end = 0+.01         // just above z = 0

    // one [x, y, z, x, y, z] segment per corner: start on the plane, end on the footprint
    const connector_positions = [
        [-b, b, z_start, -a, a, z_end],
        [b, b, z_start, a, a, z_end],
        [b, -b, z_start, a, -a, z_end],
        [-b, -b, z_start, -a, -a, z_end],
    ]

    // Solid box sitting just behind the textured plane
    const make_background_mesh = () => {
        const background_geometry = new THREE.BoxGeometry(overlay_side+padder, overlay_side+padder, padder*1.8);
        const background_material = new THREE.MeshBasicMaterial({ color:utils.actvol_facing_color });
        const background_mesh = new THREE.Mesh(background_geometry, background_material)
        background_mesh.position.z -= padder
        return background_mesh
    }

    const actgrid_overlays_pool = []
    for (let i=0; i<n_upstream_channels; i++) {

        // weights: every overlay owns its own texture so it can be updated independently
        const flattened_placeholder_rgb = utils.z_score_matrix_to_flattened_rgb_data_for_texture(get_tooltip_matrix(kernel_dim, 0))
        const weights_texture = new THREE.DataTexture(flattened_placeholder_rgb, kernel_dim, kernel_dim, THREE.RGBAFormat);
        weights_texture.flipY = true;
        weights_texture.needsUpdate = true;
        const weights_material = new THREE.MeshBasicMaterial({ map: weights_texture, transparent:true });

        const plane_geometry = new THREE.PlaneGeometry(overlay_side, overlay_side);
        const weights_plane = new THREE.Mesh(plane_geometry, weights_material);

        const weights_mesh = new THREE.Group()
        weights_mesh.add(weights_plane)
        weights_mesh.add(make_background_mesh())
        weights_mesh.position.z += spacer

        const group = new THREE.Group()

        // NOTE(review): per the three.js docs LineMaterial sizes are screen-space
        // unless `worldUnits` is enabled — confirm .6 is the intended thickness.
        const material = new LineMaterial({
            color: lineColor,
            linewidth: .6,
        });
        // for...of instead of for...in: iterate the segments, not string keys
        for (const segment of connector_positions) {
            const lineGeometry = new LineGeometry();
            lineGeometry.setPositions(segment);
            group.add(new Line2(lineGeometry, material))
        }

        group.add(weights_mesh)

        group.rotation.y = Math.PI / 2

        actgrid_overlays_pool.push({'mesh':group, weights_texture})
    }
    return actgrid_overlays_pool
}


/**
 * Remove every overlay in the pool from the modal scene, release the GPU
 * resources it owns, and empty the pool array in place.
 *
 * Pool entries hold a THREE.Group in `mesh`, which has no geometry or material
 * of its own, so the group is traversed and each descendant's geometry and
 * material(s) are disposed instead. The texture is stored under
 * `weights_texture`; the older `overlay_texture` key is still honoured if present.
 *
 * @param {Array<{mesh: object, weights_texture?: object, overlay_texture?: object}>} actgrid_overlays_pool
 *   pool to tear down; its length is 0 afterwards
 */
function remove_actgrid_overlays_pool(actgrid_overlays_pool) {
  for (const overlay of actgrid_overlays_pool) {
    modalScene.remove(overlay.mesh);

    overlay.mesh.traverse((child) => {
      if (child.geometry) {
        child.geometry.dispose();
      }
      if (child.material) {
        // a mesh may carry a single material or an array of them
        const materials = Array.isArray(child.material) ? child.material : [child.material];
        for (const material of materials) {
          material.dispose();
        }
      }
    });

    // disposing a material does not dispose its textures
    const texture = overlay.weights_texture || overlay.overlay_texture;
    if (texture) {
      texture.dispose();
    }
  }
  actgrid_overlays_pool.length = 0;
}

/**
 * Cut a cropSize x cropSize window out of a 2D matrix, zero-filling any part
 * of the window that falls outside the matrix.
 *
 * For odd `cropSize` the window is centred on (centerX, centerY). For even
 * `cropSize` there is no exact centre; the window then starts
 * `floor(cropSize / 2)` cells before the centre on each axis.
 *
 * @param {Array<Array<number>>} matrix - source rows; width is taken from the first row
 * @param {number} centerX - column index of the window centre
 * @param {number} centerY - row index of the window centre
 * @param {number} cropSize - side length of the returned window
 * @returns {Array<Array<number>>} cropSize rows of cropSize values
 */
function getTooltipCrop(matrix, centerX, centerY, cropSize) {
  const radius = Math.floor(cropSize / 2);
  const sourceHeight = matrix.length;
  // guard the width lookup so an empty matrix yields an all-zero crop
  const sourceWidth = sourceHeight > 0 ? matrix[0].length : 0;

  const result = [];
  // Iterate over the result's own indices so we can never write outside it,
  // whatever the parity of cropSize.
  for (let row = 0; row < cropSize; row++) {
    const sourceY = centerY - radius + row;
    const cropRow = new Array(cropSize).fill(0);

    if (sourceY >= 0 && sourceY < sourceHeight) {
      for (let col = 0; col < cropSize; col++) {
        const sourceX = centerX - radius + col;
        if (sourceX >= 0 && sourceX < sourceWidth) {
          cropRow[col] = matrix[sourceY][sourceX];
        }
      }
    }
    result.push(cropRow);
  }

  return result;
}

// Set to true when a weights fetch is started and back to false once the
// decoded weights have been stored; nothing in this file reads it yet.
let is_loading_weights = false
// When true, mouse movement no longer repositions the weight overlays.
// Toggled from the keyboard inside the component: Ctrl+L pins, Ctrl+U unpins.
let modal_is_pinned = false
const Conv2dModal = ({ 
    op_nid,
    onClose,
    setFeatureTooltipObject
}) => {
  const modalSceneRef = useRef(null);
  const modalRendererRef = useRef(null);

  useEffect(() => {

    // target conv op
    let nid = op_nid
    let op = globals.nodes_lookup[nid]

    let input_tensor_node = globals.nodes_lookup[op.uns[0]]
    let output_tensor_node = globals.nodes_lookup[op.dns[0]]

    let input_actgrid = utils.createGridOfChannelSlicesWithDataTexture(input_tensor_node)
    let output_actgrid = utils.createGridOfChannelSlicesWithDataTexture(output_tensor_node)

    let d_input = input_actgrid.the_plane.userData
    let d_output = output_actgrid.the_plane.userData

    d_input.layer_type = "input"
    d_output.layer_type = "output"

    // Modal scene setup

    modalCamera.layers.enable(utils.ACTGRID_LAYER)
    modalCamera.layers.enable(utils.ACTVOL_OBJECTS_LAYER)

    modalRenderer.setSize(window.visualViewport.width, window.visualViewport.height);   
    modalSceneRef.current.appendChild(modalRenderer.domElement);
    modalRendererRef.current = modalRenderer;

    // controls
    modalControls = new OrbitControls(modalCamera, modalRenderer.domElement);
    
    input_actgrid.rotation.x = 0
    output_actgrid.rotation.x = 0

    input_actgrid.rotation.y = Math.PI / 2
    output_actgrid.rotation.y = Math.PI / 2

    input_actgrid.position.z -= (d_input.grid_width*globals.act_cube_size/2) 
    output_actgrid.position.z -= (d_output.grid_width*globals.act_cube_size/2) 
    input_actgrid.position.y += (d_input.grid_height*globals.act_cube_size/2) 
    output_actgrid.position.y += (d_output.grid_height*globals.act_cube_size/2) 

    let max_height = Math.max(d_input.grid_height, d_output.grid_height)
    // let spacer = utils.interp((max_height), [60, 700, 1000], [2, 6, 8])
    let spacer = utils.interp((max_height), [60, 700, 1000], [4, 12, 24])
    console.log('spacer', spacer)

    input_actgrid.position.x -= spacer
    output_actgrid.position.x += spacer

    modalScene.add(input_actgrid)
    modalScene.add(output_actgrid)

    modalScene.add(actgrid_tooltip_mesh)

    ///////////////////////////
    let kernel_dim
    if (op.name==="conv2d"){
        kernel_dim = op.fn_metadata.kernel_size // eg "(3, 3)"
        kernel_dim = kernel_dim.replace(/[()]/g, '').split(',').map(num => parseInt(num.trim()));
        kernel_dim = kernel_dim[0] // TODO, this assumes square, which is not always the case
    } else if (op.name==="linear") {
        kernel_dim = 1
    }

    console.log("kernel dim", kernel_dim)

    let BATCH_IX = 0
    let input_actvol = globals.tensor_trace[input_tensor_node.tensor_id][BATCH_IX] // conv (32, 112, 112); 
    console.log("input actvol", input_actvol) // 257 x 384 (spatial, channels) for linear; (channels, spatial, spatial) for conv
    let n_upstream_channels
    if (op.name==="conv2d") {
        n_upstream_channels = input_actvol.length
    } else if (op.name==="linear") { 
        n_upstream_channels = input_actvol[0].length // TODO this isn't safe, as linear can be arbitary shape, w features last
    }

    let d = input_actgrid.the_plane.userData
    let base_pos_x = input_actgrid.position.x //
    let base_pos_y = input_actgrid.position.y
    let base_pos_z = input_actgrid.position.z

    const upstreamChannelSize = d.channel_width; // height and width always the same for now
    console.log("upstreamChannelSize", upstreamChannelSize)

    const paddingSize = 3; // TODO should take from userData also

    let actgrid_overlays_pool = get_actgrid_overlays_pool(kernel_dim, n_upstream_channels, upstreamChannelSize)
    
    modalScene.background = new THREE.Color(...[248, 249, 250].map(d => d/255))

    modalCamera.position.z = 60;
    modalCamera.position.y = -5;
    modalCamera.rotation.y = Math.PI / 10;

    const multiplyArray = (arr, scalar) => {
        return arr.map(item => 
          Array.isArray(item) ? multiplyArray(item, scalar) : item * scalar
        );
      };

    // load the weights from server if not already stored
    if (!(op_nid in globals.model_weights)){
        is_loading_weights = true // TODO need to check this later 
        let model_name = globals.nn.trace_metadata.name

        fetch(`/data/weights/${model_name}/${op_nid}.json.gz`)
            .then(response => response.arrayBuffer())
            .then(arrayBuffer => {
                const uint8Array = new Uint8Array(arrayBuffer);
                const decompressed = pako.ungzip(uint8Array, { to: 'string' });
                let weights = JSON.parse(decompressed);
                console.log("loaded weights", weights)
                for (let k in weights) {
                    weights[k] = multiplyArray(weights[k], 2) // quick hack to adjust color scale on the weights
                }
                globals.model_weights[op_nid] = weights
                is_loading_weights = false
            })
    } else {
        console.log("weights already loaded")
    }

    modalRaycaster.layers.disableAll()
    modalRaycaster.layers.enable(utils.CLICKABLE_LAYER)

    function togglePinned(is_pinned) {
        modal_is_pinned = is_pinned
        console.log("modal_is_pinned", modal_is_pinned)
    }
    window.addEventListener('keydown', function(event) {
        if (event.ctrlKey && event.key === 'l') {
            event.preventDefault(); 
            togglePinned(true)
        }
        if (event.ctrlKey && event.key === 'u') {
            event.preventDefault(); 
            togglePinned(false)
        }
      });


    function modalOnMouseMove(event) {
      pointer.x = ((event.pageX) / (window.visualViewport.width)) * 2 - 1;
      pointer.y = -(event.pageY / window.visualViewport.height) * 2 + 1;
      modalRaycaster.setFromCamera(pointer, modalCamera);

      const intersects = modalRaycaster.intersectObjects(modalScene.children, true);
  
      if (intersects.length > 0) {
        let obj = intersects[0].object
        if (obj.is_actgrid_plane) {
            let plane = intersects[0].object
            let intx_pt = intersects[0].point
            const uv = intersects[0].uv;

            const x = Math.floor(uv.x * plane.userData.grid_width);
            const y = plane.userData.grid_height - Math.floor(uv.y * plane.userData.grid_height);
            
            // Calculate channel information
            let channels = plane.userData.channels
            const gridSize = Math.ceil(Math.sqrt(channels));
            const channelSize = plane.userData.channel_width; // height and width equal always for now
            const paddingSizeDownstream = 3;
            
            const gridX = Math.floor((x - paddingSizeDownstream) / (channelSize + paddingSizeDownstream));
            const gridY = Math.floor((y - paddingSizeDownstream) / (channelSize + paddingSizeDownstream));
            
            const localX = (x - paddingSizeDownstream) % (channelSize + paddingSizeDownstream);
            const localY = (y - paddingSizeDownstream) % (channelSize + paddingSizeDownstream);
            
            const channel_ix = gridY * gridSize + gridX;

            let data = plane.userData.gridData
            
            if (y<data.length && x<data[0].length) {
                let v = data[y][x]
                if (v===utils.PADDING_FILL_VALUE) {
                    // tensor background
                    // clear_highlight_on_prev_intersected()
                    setFeatureTooltipObject(null)

                    // setTooltipPosition({ left: currentMouseCoords.x, //screen_coords.clientX, 
                    //                       top: currentMouseCoords.y-10, //screen_coords.clientY 
                    //                     });
                    // setTooltipObject(op);
                    // hovered_op = op
                    // setHelpInformation(`Double-click to close activations grid. Right-click for more options`)
                    
                    console.log("tensor background")
                } else {
                    // channel slice
                    console.log('channel', channel_ix)

                    // // the small tooltip over cursor
                    // let tooltip_data = getTooltipCrop(data, x, y, actgrid_tooltip_dim)
                    // updateActgridTooltipTextureData(tooltip_data, actgrid_tooltip_texture)
                    // // actgrid_tooltip_mesh.position.x = intx_pt.x // isn't positioned right
                    // // actgrid_tooltip_mesh.position.z = intx_pt.z
                    
                    //////////////////
                    // the upstream layer

                    const sumAbsKernels = (arr) => {
                        return arr.map(kernel => 
                          kernel.reduce((sum, row) => 
                            sum + row.reduce((rowSum, val) => rowSum + Math.abs(val), 0)
                          , 0)
                        );
                      };
                    const nthPercentile = (arr, n) => {
                        const sorted = [...arr].sort((a, b) => a - b);
                        const index = Math.ceil((n / 100) * sorted.length) - 1;
                        return sorted[index];
                    };
                    // let percentile_thresh = utils.interp(n_upstream_channels, [3,32,128], [1,60,90])
                    let percentile_thresh = utils.interp(n_upstream_channels, [3,32,128,1024], [1,80,95,99])
                    console.log('n_upstream_channels', n_upstream_channels)

                    let s = Math.ceil(Math.sqrt(d.channels)) // the size, s x s,  of the upstream grid
                    
                    // only update overlays if not pinned
                    if (!modal_is_pinned) {
                        if (op_nid in globals.model_weights) {
                            // weights are loaded and ready
                            let weights = globals.model_weights[op_nid] // needs to have been put there TODO fix conv for this case
                            weights = weights[1] // just taking weights for now, bias is ix 2

                            console.log(weights)

                            let channel_is_important
                            if (op.name==="conv2d") {
                                let incoming_channels_strengths = sumAbsKernels(weights[channel_ix]) 
                                // weights is conv: shape (output_channels, input_channels, kernel_height, kernel_width)
                                // linear: (output_channels, input_channels)
    
                                let t = nthPercentile(incoming_channels_strengths, percentile_thresh)
                                channel_is_important = incoming_channels_strengths.map(s=>s>=t)
                            } else if (op.name==="linear") {
                                let w = weights[channel_ix].map(v => Math.abs(v))
                                let t = nthPercentile(w, percentile_thresh)
                                channel_is_important = w.map(s=>s>=t)
                            }
                            console.log("channel is important", channel_is_important)

    
                            // place overlays on upstream layer
                            for (let c=0; c<n_upstream_channels; c++) {
                                const channel_x = (c%s) * (upstreamChannelSize + paddingSize) * globals.act_cube_size
                                const channel_y = Math.floor(c/s) * (upstreamChannelSize + paddingSize) * globals.act_cube_size
                                
                                // TODO TODO we need to ensure we're completely correctly aligned here
                                let r = upstreamChannelSize / channelSize
                                let upstreamLocalX = Math.floor(localX*r)
                                let upstreamLocalY = Math.floor(localY*r)
    
                                // let channel_tooltip_crop = getTooltipCrop(input_actvol[c], upstreamLocalX, upstreamLocalY, kernel_dim)
                                
                                let px = base_pos_x
                                let py = base_pos_y - channel_y - (upstreamLocalY*globals.act_cube_size)
                                let top_left = (base_pos_z + d.grid_width * globals.act_cube_size)
                                let pz = top_left - channel_x - (upstreamLocalX*globals.act_cube_size)
            
                                let kernel = weights[channel_ix][c] // weights is shape (output_channels, input_channels, kernel_height, kernel_width)
                                if (op.name=="linear"){
                                    kernel = [[kernel]]
                                }
                                let o = actgrid_overlays_pool[c]
                                o.mesh.position.x = px 
                                o.mesh.position.y = py
                                o.mesh.position.z = pz
                                updateActgridTooltipTextureData(kernel, o.weights_texture)  
                                // updateActgridTooltipTextureData(channel_tooltip_crop, o.mag_texture)

                                if (channel_is_important[c]) { // it's these that slow things down. invisible still just as slow. 
                                    // o.mesh.visible = true
                                    modalScene.add(o.mesh)

                                } else {
                                    // o.mesh.visible = false
                                    modalScene.remove(o.mesh)
                                }
    
                            }
    
                        }
                    }


                    /////
                    // feature tooltip
                    let plane_op = plane.userData.actual_op
                    let tensor_id = plane_op.tensor_id // can be input or output tensor
                    let featurespace = globals.featurespace[tensor_id]
                    let top_5 = featurespace["top_5s"][channel_ix]
                    let bottom_5 = featurespace["bottom_5s"][channel_ix]
                    let dataset = globals.nn.trace_metadata.dataset
                    let dataset_name = dataset//=="imagenet"?"imagenet_val":dataset // dumb
                    let image_paths_pos = top_5.map(ix => {
                      return `/data/${dataset_name}/image_${ix}.png`
                    })
                    let image_paths_neg = bottom_5.map(ix => {
                      return `/data/${dataset_name}/image_${ix}.png`
                    })
                    let overlaysPath = `/data/featurespace_overlays/${globals.nn.trace_metadata.name}/${tensor_id}_${channel_ix}.json.gz`
                
                    setFeatureTooltipObject({image_paths_pos, image_paths_neg, channel_ix, tensor_id, overlaysPath})// NOTE just turning off for dev of mouseover mag

                    
                }
            }
        }
      }



    //             }
            
    //         }
    //         /////////////////////              
    //         /////////////////////////////


    }
    window.addEventListener('mousemove', modalOnMouseMove, false);
    
    // Animation loop
    const animateModal = () => {
      requestAnimationFrame(animateModal);
      modalRenderer.render(modalScene, modalCamera);
      modalControls.update()
    };
    animateModal();

    // Cleanup
    return () => {
      window.removeEventListener('mousemove', modalOnMouseMove, false);
      if (modalRendererRef.current) {
        modalRendererRef.current.dispose();
      }
    };
  }, []);
  const closeModal = ()=>{
    // cleanup scene here TODO ensure this is thorough
    utils.clear_scene(modalScene)
    onClose()
  }

  return (
    <div style={modalStyles.overlay}>
      <div style={modalStyles.modal}>
          <button onClick={closeModal} style={modalStyles.closeButton}>
            ✕
          </button>
        <div ref={modalSceneRef} style={modalStyles.sceneContainer} />
      </div>
    </div>
  );
};

// Dimensions shared by the modal panel and the scene container
const FILL_PARENT = { width: '100%', height: '100%' };

// Inline styles for the modal markup, keyed by the element they are applied to
const modalStyles = {
    // dimmed backdrop pinned to all four edges, centring its content
    overlay: {
        position: 'fixed',
        top: 0,
        left: 0,
        right: 0,
        bottom: 0,
        backgroundColor: 'rgba(0, 0, 0, 0.5)',
        display: 'flex',
        alignItems: 'center',
        justifyContent: 'center',
        zIndex: 5,
    },
    modal: { backgroundColor: 'white', ...FILL_PARENT },
    // borderless, background-less close button
    closeButton: {
        border: 'none',
        background: 'none',
        fontSize: '20px',
        cursor: 'pointer',
    },
    // element the renderer canvas is appended to
    sceneContainer: { ...FILL_PARENT },
};


export default Conv2dModal;