import * as THREE from 'three';
import { CSS2DObject } from 'three/examples/jsm/renderers/CSS2DRenderer';
import { globals } from './globals';

import * as TWEEN from '@tweenjs/tween.js';

import { Line2 } from 'three/examples/jsm/lines/Line2';
import { LineMaterial } from 'three/examples/jsm/lines/LineMaterial';
import { LineGeometry } from 'three/examples/jsm/lines/LineGeometry';

import pako from 'pako';
import { remove_label_from_op_and_return_to_pool } from "./label_utils"


// three.js render-layer indices. A camera only draws objects that share one of its enabled
// layers; presumably the main view and the minimap enable different subsets (TODO confirm).
export const LINE_OBJECTS_LAYER = 3
export const ACTVOL_OBJECTS_LAYER = 4
export const OP_NODES_OBJECTS_LAYER = 5
// NOTE(review): same index as OP_NODES_OBJECTS_LAYER, so op nodes and tensor nodes cannot be
// toggled independently via layers — confirm this is intentional.
export const TENSOR_NODES_OBJECTS_LAYER = 5
export const ACTGRID_LAYER = 6
export const ONLY_MINIMAP_LAYER = 7

// NOTE(review): captured once at module load; if globals.scene is assigned later this binding
// stays stale. Code in this part of the file uses globals.scene directly instead.
let scene = globals.scene
///////////////////////////////
// constants
///////////////////////////////

export function get_edge_color(brightness_factor) {
    // Dark grey edge base (0-255 RGB), scaled by brightness_factor, then normalized to
    // THREE.Color's 0-1 channel range. A new Color is returned on every call.
    const base_rgb = [33, 37, 41]
    const [r, g, b] = base_rgb.map(channel => (channel * brightness_factor) / 255)
    return new THREE.Color(r, g, b)
}

// Palette: colors are written as 0-255 RGB and divided by 255 into THREE.Color's 0-1 range.
export const node_color = new THREE.Color(...[22, 66, 91].map(d=>d/255))
export const node_color_outline = new THREE.Color(...[7, 32, 30].map(d=>d/255))



// Highlight color. Alternatives kept for quick switching.
let highlight_color = new THREE.Color(...[231, 111, 81].map(d => d/255)); // peach
// let highlight_color = new THREE.Color(...[116, 196, 118].map(d => d/255)); // green
// let highlight_color = new THREE.Color(...[199, 125, 255].map(d => d/255)); // lavender


// NOTE: node_highlight_color and plane_highlight_color (below) are the SAME Color instance as
// highlight_color — mutating one (e.g. .set()) changes all of them.
export const node_highlight_color = highlight_color

// const scene_background_color = new THREE.Color(...[248, 249, 250].map(d => d/255));
export const plane_color = new THREE.Color(...[248, 249, 250].map(d => d/255));
export const white_color = new THREE.Color(1,1,1);
export const plane_color_darker = new THREE.Color(...[228, 229, 230].map(d => d/255));
export const plane_outline_color = new THREE.Color(...[58, 124, 165].map(d=>d/255))

// Sampling resolution for curved edges (get_curve_pts passes this to curve.getPoints). Lowering it
// made a noticeable difference on a laptop; flat 2-point lines are cheaper still.
export const CURVE_N_PTS = 20 //50


export const MAX_SPHERE_SIZE = .32

// Shared unit geometries: a 12-segment flat disc (used as the node "sphere") and a unit square.
const sphere_geometry = new THREE.CircleGeometry(1, 12);
const square_geometry = new THREE.PlaneGeometry(1, 1);

// Layer for pickable objects (presumably what the raycaster tests against — confirm).
export const CLICKABLE_LAYER = 1
// Layout transition duration in milliseconds, and its easing curve.
export const TWEEN_MS = 1200 //800
export const TWEEN_EASE = TWEEN.Easing.Linear.None

export const plane_highlight_color = highlight_color

///////////////////////////////
// viz utils
///////////////////////////////

export function get_curve_pts(pt1, pt2, n_pts) {
    // Chordal Catmull-Rom through four control points: the two endpoints plus one point just
    // inside each end (nudged 5% of the x span and 2% of the z span toward the other end).
    // three.js getPoints(n) samples n + 1 points.
    const X_NUDGE = .05
    const Z_NUDGE = .02
    const dx = pt2.x - pt1.x
    const dz = pt2.z - pt1.z

    const start = new THREE.Vector3(pt1.x, pt1.y, pt1.z)
    const near_start = new THREE.Vector3(pt1.x + dx * X_NUDGE, pt1.y, pt1.z + dz * Z_NUDGE)
    const near_end = new THREE.Vector3(pt2.x - dx * X_NUDGE, pt2.y, pt2.z - dz * Z_NUDGE)
    const end = new THREE.Vector3(pt2.x, pt2.y, pt2.z)

    const curve = new THREE.CatmullRomCurve3([start, near_start, near_end, end])
    curve.curveType = 'chordal'
    return curve.getPoints(n_pts)
}
export function get_pts_for_flat_line(pt1, pt2) {
    // A straight segment only needs its two endpoints, copied into fresh Vector3s.
    return [pt1, pt2].map(({ x, y, z }) => new THREE.Vector3(x, y, z))
}
// // normal line is constant width in most browsers despite setting thickness
// export function get_line_from_pts(pts, linewidth, color) {
// 	const line_geometry = new THREE.BufferGeometry().setFromPoints(pts);
// 	const material = new THREE.LineBasicMaterial( { color: color, linewidth:linewidth } );
// 	const lineObject = new THREE.Line(line_geometry, material);
//     lineObject.layers.set(MINIMAP_OBJECTS_LAYER)

// 	return lineObject
// }

export function pts_to_positions(pts) {
    // LineGeometry.setPositions wants one flat [x0, y0, z0, x1, y1, z1, ...] array.
    return pts.flatMap(({ x, y, z }) => [x, y, z])
}

// Line2 supports line width, which we're finding to be very helpful for understanding
/**
 * Build a Line2 ("fat line") through `pts`. The result renders only on LINE_OBJECTS_LAYER.
 *
 * In globals.DEBUG mode the color is overridden: 2-point (permanently straight) lines are red
 * and multi-point (curve-capable) lines are blue, which makes each edge's point budget visible.
 *
 * @param {Array<{x:number,y:number,z:number}>} pts - ordered points along the line
 * @param {number} linewidth - passed straight to LineMaterial. With LineMaterial's default
 *   `worldUnits: false` this is a screen-space width in pixels, NOT world units. That mode also
 *   needs `material.resolution` to match the viewport. This function does not set it;
 *   presumably the render loop does (NOTE(review): confirm).
 * @param {*} color - any value LineMaterial accepts for `color`
 * @returns {Line2}
 */
export function get_line_from_pts(pts, linewidth, color) {
    const lineGeometry = new LineGeometry();
    lineGeometry.setPositions(pts_to_positions(pts));

    // Use a separate binding rather than reassigning the `color` parameter.
    const line_color = globals.DEBUG
        ? new THREE.Color(pts.length === 2 ? "red" : "blue")
        : color

    const material = new LineMaterial({
        color: line_color,
        linewidth: linewidth,
        dashed: false,
    });

    const lineObject = new Line2(lineGeometry, material);

    lineObject.layers.disableAll();
    lineObject.layers.enable(LINE_OBJECTS_LAYER);

    return lineObject;
}

// NOTE: Line2 doesn't flicker from frustum culling, but Line does — only after tweening, not on initial creation.

/*
devlog aug 30. bugs in transitions w line2 after expansion / collapse. Everything worked w Line, but Line2 a bit buggy.
Changed to use same method of creating new line in place of old one rather than reuse line obj itself, this matches
what we're doing on expansion. Note the weird thing is we need to shift y by at least something for it to work! otherwise
get a bug sometimes where it's not visible! sd1.4, the block after the mid block, expand first resnet, then expand the mid block
and the line going into resnet disappears! Is the small y shift triggering something that's needed? some difference btwn old and new?
aye. Also note we're still confused about when lines are getting more nodes added to them? need to be more clear on this. My brain hurts,
am tired, and this is a complex part of the code. Want this to be cleaner. Need to have good perf, no more pts than needed (i think),
but also need simplicity and cleanliness in our code. It currently seems to work, but i don't like this confusion and complexity
*/

/**
 * Polyline points for the edge n0 -> n1, in scene coordinates. Layout x maps to scene x,
 * layout y maps to scene z, and scene y is 0.
 *
 * Point budget: an edge that could ever become curved is given the full curve budget up front.
 * Later layout transitions can then move its points without changing the point count; updating
 * the point count worked with Line but not reliably with Line2. Only edges that can never leave
 * a flat row get 2 points. Many edges are permanently straight, and on a laptop that saving is
 * noticeable.
 *
 * Assuming get_curve_pts(a, b, k) returns k + 1 points (three.js getPoints semantics), every
 * curve-capable shape below has CURVE_N_PTS + 1 points:
 *   - flat now, may bend later:  get_curve_pts(.., CURVE_N_PTS)                -> CURVE_N_PTS + 1
 *   - elbow / pre-elbow:         2 flat + get_curve_pts(.., CURVE_N_PTS - 2)  -> CURVE_N_PTS + 1
 *   - short vertical hop:        get_curve_pts(.., CURVE_N_PTS)                -> CURVE_N_PTS + 1
 *
 * TODO: too many edges are made curve-capable. For example, edges entering a collapsed module
 * that connect at the base level could stay 2-point.
 *
 * @param {object} n0 - source node (x, y, y_unshifted, parent_op, draw_order_row, is_last_in_line)
 * @param {object} n1 - target node (same fields)
 * @returns {Array} points (whatever get_pts_for_flat_line / get_curve_pts return)
 */
export function get_edge_pts(n0, n1) {
    // Flatness is judged on unshifted y, presumably so a temporary shift doesn't change the
    // point budget; the points themselves use the current y.
    const same_y = n0.y_unshifted === n1.y_unshifted
    const same_row = n0.parent_op === n1.parent_op && n0.draw_order_row === n1.draw_order_row
    const elbow_x_dist = 2 // horizontal length of an elbow's curved part
    const x_dist = n1.x - n0.x
    const pt1 = { x: n0.x, y: 0, z: n0.y }
    const pt2 = { x: n1.x, y: 0, z: n1.y }

    if (same_y) {
        // Nodes in the same row never end up at different y, whatever the transition: 2 pts suffice.
        // Otherwise the edge is flat now but may bend after a layout change, so reserve the curve budget.
        return same_row ? get_pts_for_flat_line(pt1, pt2) : get_curve_pts(pt1, pt2, CURVE_N_PTS)
    }

    if (x_dist > elbow_x_dist) {
        if (n0.is_last_in_line) {
            // Normal elbow: run flat at n0's level, then curve into n1 over the last elbow_x_dist.
            // TODO: decide normal vs pre-elbow in layout_engine (previously also keyed on respath_dist).
            const elbow = { x: n1.x - elbow_x_dist, y: 0, z: n0.y }
            return get_pts_for_flat_line(pt1, elbow).concat(get_curve_pts(elbow, pt2, CURVE_N_PTS - 2))
        }
        // Pre-elbow: curve to n1's level right away, then run flat into n1.
        // TODO: pre-elbows also need to be reserved in layout_engine's occupancy blocking.
        const elbow = { x: n0.x + elbow_x_dist, y: 0, z: n1.y }
        return get_curve_pts(pt1, elbow, CURVE_N_PTS - 2).concat(get_pts_for_flat_line(elbow, pt2))
    }

    // Short horizontal span (x_dist <= elbow_x_dist): one curve covers the whole vertical change.
    return get_curve_pts(pt1, pt2, CURVE_N_PTS)
}

/**
 * Color-code a node by its role flags. Presumably the debug palette; the plain palette is
 * _get_node_color_non_debug.
 *
 * Rules are checked top to bottom and the first match wins, so their order encodes priority.
 * Each rule is a thunk, so later fields (e.g. n.dns) are only read if earlier rules didn't match.
 *
 * @param {object} n - node with optional boolean role flags and a `node_type` string
 * @returns {THREE.Color} a fresh Color; "grey" when no rule matches
 */
export function get_node_color(n) {
    const rules = [
        [() => n.conditioning_entering_respath, "green"],
        [() => n.remove_this_aux_output, "orange"],
        [() => n.is_conditioning_upstream, "aqua"],
        [() => n.is_conditioning, "blue"],
        [() => n.is_global_input, "orange"],
        [() => n.is_input && n.dns.length === 0, "black"], // input with no downstream nodes
        [() => n.is_output_global, "purple"],
        [() => n.is_input, "yellow"],
        [() => n.is_output, "red"],
        [() => n.node_type === "mod_out", "pink"],
        [() => n.node_type === "mod_in", "gold"],
    ]
    const match = rules.find(([test]) => test())
    return new THREE.Color(match ? match[1] : "grey")
}

export function _get_node_color_non_debug (n) {
    // Op nodes share the module-level node_color instance; tensor nodes get a fresh dark grey.
    if (!n.is_tensor_node) return node_color
    const [r, g, b] = [33, 37, 41].map(v => v / 255)
    return new THREE.Color(r, g, b)
}

export function get_z_plane(op) {
    // Map nesting depth onto a z offset: depth 0 -> -10, depth 100 -> -1 (via interp).
    const depth_domain = [0, 100]
    const z_range = [-10, -1]
    return interp(op.depth, depth_domain, z_range)
}

export function get_color_from_depth(d) {
    // Blend from a mid grey (low depth) toward near-white (high depth). With at most 2 visible
    // levels the full [0, max] range is used; otherwise one level is trimmed from each end.
    const max_depth = globals.max_depth_visible
    const depth_range = max_depth <= 2 ? [0, max_depth] : [1, max_depth - 1]
    const shallow_rgb = [173, 181, 189].map(v => v / 255)
    const deep_rgb = [248, 249, 250].map(v => v / 255)
    const [r, g, b] = [0, 1, 2].map(ch => interp(d, depth_range, [shallow_rgb[ch], deep_rgb[ch]]))
    return new THREE.Color(r, g, b)
}

export function get_plane_color(op) {
    // A module plane's fill depends only on its nesting depth.
    return get_color_from_depth(op.depth)
}

let MIN_SPHERE_SIZE = .06

/**
 * Node glyph scale, growing with sqrt(parameter count) between MIN_SPHERE_SIZE and
 * MAX_SPHERE_SIZE relative to the largest visible op (globals.max_n_params_visible).
 *
 * NOTE(review): the value is sqrt(n_params + 1) but the domain's upper bound is
 * sqrt(max_n_params_visible). The largest op therefore lands slightly past MAX_SPHERE_SIZE unless
 * interp clamps — confirm.
 *
 * @param {object} op - may carry `n_params`; a missing, null or undefined value counts as 0
 * @returns {number} scale factor
 */
export function get_sphere_scale(op) {
    // `?? 0` replaces the old `"n_params" in op` check. That check let a present-but-null/undefined
    // n_params through, which made the scale NaN.
    const v = (op.n_params ?? 0) + 1 // +1 so parameter-free ops don't take sqrt(0)
    return interp(Math.sqrt(v), [0, Math.sqrt(globals.max_n_params_visible)], [MIN_SPHERE_SIZE, MAX_SPHERE_SIZE])
}

export function scale_sphere(sphere, op) {
    // Only x and y are scaled; the node glyph is a flat disc, so z is left as-is.
    const s = get_sphere_scale(op)
    Object.assign(sphere.scale, { x: s, y: s })
}

/**
 * Build a (hidden) CSS2D label for a module group.
 *
 * Text is the upper-cased name prefix. For names of exactly the form "prefix-suffix" it also
 * appends the suffix's last 4 characters and op.row_counter: "PREFIX-<suffix tail>-<row_counter>".
 *
 * Fixes:
 *   - The suffix tail now uses slice(-4). The old slice(len - 4, len) kept only the last
 *     character of 3-character suffixes.
 *   - The text is set via textContent, so any markup in op.name is shown literally instead of
 *     parsed as HTML.
 *
 * @param {object} op - group op with `name` and `row_counter`
 * @returns {CSS2DObject} label whose element starts with display: none
 */
export function get_group_label(op) {
    const div = document.createElement('div')
    div.className = 'group_label'

    const parts = op.name.split("-")
    let text = parts[0].toUpperCase()
    if (parts.length === 2) text += `-${parts[1].slice(-4)}-${op.row_counter}`
    div.textContent = text
    div.style.backgroundColor = 'transparent'

    const label = new CSS2DObject(div)
    label.element.style.display = 'none'
    return label
}

// Activation-volume face colors: one blue-grey base, scaled per face to fake shading.
// The front face is darkest (x0.2), the facing side medium (x0.5), and the top uses the base as-is.
let act_vol_base_color = [115, 147, 169].map(d=>d/255) // blue-grey
let actvol_front_color = new THREE.Color(...act_vol_base_color.map(d=>d*.2))
export const actvol_facing_color = new THREE.Color(...act_vol_base_color.map(d=>d*.5))
let actvol_top_color = new THREE.Color(...act_vol_base_color.map(d=>d*1.))


function get_actvol_materials() {
    // One material per box face, in three.js BoxGeometry group order (+x, -x, +y, -y, +z, -z).
    // The vivid placeholder colors presumably sit on faces the shear keeps hidden.
    // Each actvol gets its own materials so hover handlers can recolor one volume.
    // TODO: perf cost vs. sharing materials and swapping on hover?
    const face_colors = [
        actvol_front_color,  // Front
        0x00ff00,            // Green
        actvol_top_color,    // Top
        0xffff00,            // Yellow
        actvol_facing_color, // Facing
        0x00ffff,            // Cyan
    ]
    return face_colors.map(color => new THREE.MeshBasicMaterial({ color }))
}
function getRandomColor() {
    // Float in [0, 0xffffff); it is not rounded to an integer hex value.
    const MAX_HEX = 0xffffff
    return MAX_HEX * Math.random()
}
// Shear amounts used to fake a 3/4 view of activation volumes without rotating the camera.
let shear_to_show_top = -0.6
let shear_to_show_front = -0.6

// Matrix4.set takes its arguments in row-major order, so this maps
//   x' = x + shear_to_show_front * z,   y' = y + shear_to_show_top * z,   z' = z.
const shearMatrix = new THREE.Matrix4().set(
    1, 0, shear_to_show_front, 0, // fake 'rotate' to show front
    0, 1, shear_to_show_top,   0, // fake 'rotate' to show top
    0, 0, 1,   0, // No shear on Z-axis
    0, 0, 0,   1  // No change in perspective
);

////////////////////////////////
// actual z, mean centered and scaled by std

// our original one
// function _rgb_from_actual_z(z_score) {
//     let neutral = .3
//     let n = 220
//     let colors = [
//         [-12, [0, 0, 120]], // blue
//         [-6, [10, 10, n]], // blue
//         [-(neutral+.01), [150, 180, n]], // lightgreen

//         [-neutral, [n,n,n]], // neutral
//         [neutral, [n,n,n]], // neutral

//         [neutral+.01, [n, 180, 150]], // orange
//         [6, [n, 10, 10]], // red
//         [12, [120, 0, 0]], // red
//     ]
//     let range = colors.map(c=>c[0])

// 	let r = interp(z_score, range, colors.map(c=>c[1][0]))
// 	let g = interp(z_score, range, colors.map(c=>c[1][1]))
// 	let b = interp(z_score, range, colors.map(c=>c[1][2]))
//     return {r, g, b}
// }
/**
 * Diverging color ramp for true (mean-centered, std-scaled) z-scores.
 * Negative values go blue, 0 is neutral grey, and positive values go orange then red.
 *
 * Cleanup: removed the unused `neutral` local. Stop labels now describe the actual colors (the
 * former "lightgreen" stops are light blues).
 *
 * @param {number} z_score
 * @returns {{r:number, g:number, b:number}} 0-255 channel values as produced by interp
 */
function _rgb_from_actual_z(z_score) {
    const n = 220 // neutral grey level; also the brightest channel value in the ramp
    // [z, [r, g, b]] stops in ascending z. Behavior outside [-12, 12] is up to interp.
    const stops = [
        [-12, [0, 0, 120]],   // dark blue
        [-4, [10, 10, n]],    // blue
        [-2, [120, 150, n]],  // light blue
        [-1, [180, 200, n]],  // pale blue
        [0, [n, n, n]],       // neutral grey
        [1, [n, 200, 180]],   // pale orange
        [2, [n, 150, 120]],   // orange
        [4, [n, 10, 10]],     // red
        [12, [120, 0, 0]],    // dark red
    ]
    const domain = stops.map(([z]) => z)
    const channel = (i) => interp(z_score, domain, stops.map(([, rgb]) => rgb[i]))
    return { r: channel(0), g: channel(1), b: channel(2) }
}
// Precomputed ramp at 0.1 z resolution, keyed by round(z * 10) over [-120, 120].
let rgb_from_actual_z_lookup = Object.fromEntries(
    Array.from({ length: 241 }, (_, k) => k - 120)
        .map(tenths => [tenths, _rgb_from_actual_z(tenths / 10)])
)
export function rgb_from_actual_z(z_score) {
    // Snap to the nearest tenth and saturate to [-12, 12] before the lookup.
    // The returned object is shared by every caller hitting the same bucket; treat it as read-only.
    const key = clamp(Math.round(z_score * 10), -120, 120)
    return rgb_from_actual_z_lookup[key]
}

///////////////////////////////
// pseudo-z: just divided by std, not centered at zero

/**
 * Color ramp for "pseudo-z" values: divided by std but NOT mean-centered.
 * Any value above zero jumps straight to a visible warm color (the 0 -> 0.0001 stop), so
 * positive activity always shows.
 *
 * Cleanup: removed the unused `neutral_neg` local. Stop labels now describe the actual colors.
 *
 * @param {number} z_score - value in std units (not a true z-score: no mean subtraction)
 * @returns {{r:number, g:number, b:number}} 0-255 channel values as produced by interp
 */
function _rgb_from_z(z_score) {
    const n = 220 // neutral grey level; also the brightest channel value in the ramp
    // [z, [r, g, b]] stops in ascending z. Behavior outside [-12, 6] is up to interp.
    const stops = [
        [-12, [0, 0, 120]],     // dark blue
        [-6, [10, 10, n]],      // blue
        [-3, [150, 180, n]],    // light blue
        [0, [n, n, n]],         // neutral grey
        [.0001, [n, 180, 150]], // light orange: anything above zero is immediately warm
        [3, [n, 20, 20]],       // red
        [6, [180, 10, 10]],     // dark red
    ]
    const domain = stops.map(([z]) => z)
    const channel = (i) => interp(z_score, domain, stops.map(([, rgb]) => rgb[i]))
    return { r: channel(0), g: channel(1), b: channel(2) }
}
// Precomputed pseudo-z ramp at 0.1 resolution, keyed by round(z * 10) over [-120, 120].
let rgb_from_z_lookup = Object.fromEntries(
    Array.from({ length: 241 }, (_, k) => k - 120)
        .map(tenths => [tenths, _rgb_from_z(tenths / 10)])
)
export function rgb_from_z(z_score) {
    // Nearest tenth, saturated to [-12, 12]. The returned object is shared; treat it as read-only.
    const key = clamp(Math.round(z_score * 10), -120, 120)
    return rgb_from_z_lookup[key]
}

function _color_from_z(z_score) {
    // The pseudo-z ramp color, converted from 0-255 channels into a new THREE.Color (0-1).
    const rgb = rgb_from_z(z_score)
    return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255)
}
// At module load, precompute a THREE.Color for every pseudo-z bucket (keys round(z * 10) in
// [-120, 120]). color_from_z reads from this table. The console.time pair only logs the build time.
console.time("compile color from z lookup")
let color_from_z_lookup = {}

for (let i=-120; i<=120; i++) {
    color_from_z_lookup[i] = _color_from_z(i/10.)
}
console.timeEnd("compile color from z lookup")

function clamp(value, min, max) {
    // Saturate into [min, max]. The lower bound is applied first, so max wins if min > max.
    const at_least_min = Math.max(value, min)
    return Math.min(at_least_min, max)
}
function color_from_z(z_score) {
    // Nearest precomputed THREE.Color (0.1 resolution, saturating outside [-12, 12]).
    // Returns a shared instance from the lookup; copy it before mutating.
    const bucket = clamp(Math.round(z_score * 10), -120, 120)
    return color_from_z_lookup[bucket]
}

// Layout ids: a tensor's dim_types with the leading (batch) dim dropped, joined by "-". This
// matches how actgrid_value_getter and reshapeToGrid build an id from op.dim_types.
export const linear_dim_types_id = ["unknown", "spatial", "spatial", "features"].slice(1).join("-") // "spatial-spatial-features", e.g. sam2 hiera: off a linear but has a nice spatial shape
export const single_linear_dim_types_id = ["unknown", "features"].slice(1).join("-") // "features": single feature vector, e.g. an ImageNet classifier output

export const conv2d_dim_types_id = ["batch", "features", "spatial", "spatial"].slice(1).join("-") // "features-spatial-spatial": conv2d

// export const linear_spatial_inferred_dim_types_id = ["unknown", "spatial_inferred", "spatial", "features"].slice(1).join("-") 




function actgrid_value_getter(volume, op, c, y, x, preceding_cls_tokens) {
    // Read the activation for channel c at grid cell (y, x), whatever the tensor layout.
    // Returns undefined (and logs) for unrecognised layouts.
    const layout_id = op.dim_types.slice(1).join("-")
    const specs = op.enriched_tensor_specs

    if (specs.spatial_is_inferred) {
        // Token sequence (tokens x features) laid out row-major on an inferred [h, w] grid.
        const row_major_ix = y * specs.spatial[1] + x
        let token_ix = row_major_ix
        if (preceding_cls_tokens) {
            // Shift by one so the leading CLS token (index 0) lands in the cell just past the
            // last patch token. The shift assumes exactly one CLS token. This CLS move is
            // duplicated elsewhere in the codebase.
            token_ix = row_major_ix + 1
            if (token_ix === volume.length) token_ix = 0
        }
        // Cells beyond the sequence are padding.
        return token_ix < volume.length ? volume[token_ix][c] : PADDING_FILL_VALUE
    }
    if (layout_id === conv2d_dim_types_id) return volume[c][y][x]
    if (layout_id === linear_dim_types_id) return volume[y][x][c]
    if (layout_id === single_linear_dim_types_id) return volume[c]

    console.log("don't know what to do w these dim types", layout_id)
    return undefined
}

// claude
// Sentinel value for grid cells that hold no data: the gaps around channel tiles, and token cells
// past the end of the sequence. It is presumably chosen outside the activation range (noted
// elsewhere as -127..127). z_score_matrix_to_flattened_rgb_data_for_texture renders these cells
// fully transparent.
export const PADDING_FILL_VALUE = 200

// putting in fn bc layout engine needs this info to reserve spacing
export function get_padding_size_from_channel_width(width){
    // Gap, in cells, between channel tiles: 1 at width 1 rising to 3 at width 16 (interp decides
    // beyond that), rounded to whole cells. The layout engine calls this too, to reserve spacing.
    const raw_padding = interp(width, [1, 16], [1, 3])
    return Math.round(raw_padding)
}

/**
 * Tile every channel of an activation tensor into one 2D grid, ready for texturing.
 *
 * Channels are placed in a near-square arrangement (ceil(sqrt(channels)) tiles per row). Tiles
 * are separated, and framed, by `paddingSize` cells holding PADDING_FILL_VALUE.
 *
 * Supported layouts (op.dim_types minus the batch dim):
 *   - spatial_is_inferred: (tokens, features) on an inferred [h, w] grid
 *   - conv2d:              (features, h, w)
 *   - linear:              (h, w, features)
 *   - single linear:       (features,) -> 1x1 tiles
 *
 * @param {Array} input - nested array for one batch element
 * @param {object} op - provides dim_types and enriched_tensor_specs
 * @param {boolean} [preceding_cls_tokens] - forwarded to actgrid_value_getter (CLS-token reordering)
 * @returns {{gridData:number[][], channel_height:number, channel_width:number, channels:number, paddingSize:number}}
 * @throws {Error} for an unsupported layout. Previously this crashed later with an opaque
 *   "Invalid array length" RangeError.
 */
function reshapeToGrid(input, op, preceding_cls_tokens) {
    const op_dim_types_id = op.dim_types.slice(1).join("-")
    const specs = op.enriched_tensor_specs
    let channels, height, width
    if (specs.spatial_is_inferred) {
        height = specs.spatial[0]
        width = specs.spatial[1]
        channels = input[0].length
    } else if (op_dim_types_id === conv2d_dim_types_id) {
        channels = input.length
        height = input[0].length
        width = input[0][0].length
    } else if (op_dim_types_id === linear_dim_types_id) {
        height = input.length
        width = input[0].length
        channels = input[0][0].length
    } else if (op_dim_types_id === single_linear_dim_types_id) {
        channels = input.length
        height = 1
        width = 1
    } else {
        throw new Error(`reshapeToGrid: unsupported dim types "${op_dim_types_id}"`)
    }

    const gridWidth = Math.ceil(Math.sqrt(channels))
    const gridHeight = Math.ceil(channels / gridWidth)
    const paddingSize = get_padding_size_from_channel_width(width)

    const outputHeight = height * gridHeight + (gridHeight + 1) * paddingSize
    const outputWidth = width * gridWidth + (gridWidth + 1) * paddingSize
    const output = Array.from({ length: outputHeight }, () => Array(outputWidth).fill(PADDING_FILL_VALUE))

    for (let c = 0; c < channels; c++) {
        const gridRow = Math.floor(c / gridWidth)
        const gridCol = c % gridWidth
        // Top-left cell of this channel's tile (each tile is preceded by one padding band).
        const tileTop = gridRow * (height + paddingSize) + paddingSize
        const tileLeft = gridCol * (width + paddingSize) + paddingSize
        for (let y = 0; y < height; y++) {
            for (let x = 0; x < width; x++) {
                output[tileTop + y][tileLeft + x] = actgrid_value_getter(input, op, c, y, x, preceding_cls_tokens)
            }
        }
    }

    return {
        gridData: output, // outputHeight x outputWidth
        channel_height: height,
        channel_width: width,
        channels,
        paddingSize,
    }
}



/**
 * Convert an h x w grid of values into flattened RGBA bytes for a THREE.DataTexture.
 *
 * Cells equal to PADDING_FILL_VALUE become transparent black; every other cell gets an opaque ramp
 * color. Grid values are divided by 10 before the ramp lookup. The lookups re-multiply by 10
 * and round, so values are presumably z-scores x 10 (target range noted as -127..127).
 *
 * Fixes:
 *   - Alpha is written straight into the byte array. Previously `rgb.a = 255` mutated the
 *     objects shared from the ramp lookup tables.
 *   - An empty grid returns an empty array instead of throwing on gridData[0].
 *
 * @param {number[][]} gridData - rows of equal length
 * @param {boolean} mean_center - true: diverging ramp for mean-centered data (rgb_from_actual_z);
 *   false: pseudo-z ramp (rgb_from_z). Overlays are already mean-centered upstream if needed.
 * @returns {Uint8Array} length 4 * width * height, row-major RGBA
 */
export function z_score_matrix_to_flattened_rgb_data_for_texture(gridData, mean_center) {
    const height = gridData.length
    if (height === 0) return new Uint8Array(0)
    const width = gridData[0].length
    const color_fn = mean_center ? rgb_from_actual_z : rgb_from_z

    // Uint8Array starts zeroed, so skipped padding cells are already transparent black.
    const data = new Uint8Array(4 * width * height)
    for (let y = 0; y < height; y++) {
        const row = gridData[y]
        for (let x = 0; x < width; x++) {
            const value = row[x]
            if (value === PADDING_FILL_VALUE) continue
            const { r, g, b } = color_fn(value / 10)
            const idx = (y * width + x) * 4
            data[idx] = r
            data[idx + 1] = g
            data[idx + 2] = b
            data[idx + 3] = 255
        }
    }
    return data
}

function get_actvol(tensor_id) {
    // Only the first batch element of the captured trace is visualised.
    const per_batch = globals.tensor_trace[tensor_id]
    return per_batch[0]
}

// claude
// Recursively zip two identically nested arrays, applying fn(a, b) at the leaves.
// arr1 drives the nesting; arr2 is indexed in lockstep, so missing entries arrive as undefined.
const mapArrays = (arr1, arr2, fn) => {
    if (!Array.isArray(arr1)) return fn(arr1, arr2)
    return arr1.map((item, i) => mapArrays(item, arr2[i], fn))
}

// Same nesting as arr, with every leaf replaced by 0.
const zerosLike = arr => (Array.isArray(arr) ? arr.map(item => zerosLike(item)) : 0)

function swap_acts_w_actgrids_if_necessary(actvol, tensor_id) {
    // Unless the view colors by gradients, show activations unchanged.
    if (globals.acts_color_by !== "grads") return actvol
    // No gradients captured for this tensor: show an all-zero volume of the same shape.
    if (!(tensor_id in globals.grads)) return zerosLike(actvol)

    const grads = globals.grads[tensor_id]["grads"][0] // first batch element; originally noted as -1, 0, 1
    // act * |grad| keeps the activation's sign and weights it by gradient magnitude.
    // (Alternative tried: Math.abs(act) * grad.)
    return mapArrays(actvol, grads, (act, grad) => act * Math.abs(grad))
}

// Neutral (z = 0) ramp color: a shared instance from the color lookup, and the far-zoom endpoint
// in interp_color_from_zoom.
const neutral_actgrid_color = color_from_z(0)
// Background behind each actgrid texture; on_controls_end re-tints its color by camera zoom.
const actgridPlaneBackgroundMaterial = new THREE.MeshBasicMaterial({ color: actvol_facing_color, wireframe: false });
// Fixed neutral background, used only by the minimap-layer background plane.
const backgroundMaterialNeutral = new THREE.MeshBasicMaterial({ color: neutral_actgrid_color, wireframe: false });


// Hack to work around aliasing: the gaps between channel slices were aliasing strongly enough to be distracting.
// This looks like extra machinery, but I think it's simpler than, e.g., building separate LODs. We transition from a
// light color when zoomed out (less contrast with the channel slices) to a darker one when up close, which gives the
// contrast we actually want between channel slices.
function interp_color_from_zoom(z) {
    // Blend from the neutral grid color at zoom 12 (far) to the darker facing color at zoom 52
    // (close); interp decides what happens outside that range. Returns a new THREE.Color.
    const ZOOM_RANGE = [12, 52]
    const far = neutral_actgrid_color
    const close = actvol_facing_color
    const [r, g, b] = ['r', 'g', 'b'].map(ch => interp(z, ZOOM_RANGE, [far[ch], close[ch]]))
    return new THREE.Color(r, g, b)
}
export function on_controls_end(){
    // Re-tint the shared actgrid background for the current camera zoom (presumably wired to the
    // camera controls' "end" event). Replaces the material's Color with a fresh instance.
    actgridPlaneBackgroundMaterial.color = interp_color_from_zoom(globals.camera.zoom)
}



////////////////////////////////////////////////////
// highlight and selected channel planes
// highlight plane in abs coords, selected planes placed in coords relative to parent
// actgrid so that can transition w them
// coords here confusing zs and ys all tangled

function _get_highlight_plane(){
    // Unit square in the highlight color, rotated flat into the XZ plane and moved 1 unit along
    // +y (towards the camera). It starts hidden and detached; make_plane_visible adds it to the
    // scene on first use.
    const plane = new THREE.Mesh(
        new THREE.PlaneGeometry(1, 1),
        new THREE.MeshBasicMaterial({ color: highlight_color }),
    )
    plane.rotation.x = -Math.PI / 2
    plane.position.y += 1
    plane.userData.is_added_to_scene = false
    plane.visible = false
    return plane
}
//
export function make_channel_highlight_plane() {
    // Hover indicator: a standalone plane, positioned in absolute scene coordinates.
    globals.channel_slice_highlight_plane = _get_highlight_plane()

    // Selection indicator: a dark plane plus a down arrow, grouped so set_selected_feature_plane
    // can parent the group to an op mesh and it moves with the op during transitions.
    const selected_plane = _get_highlight_plane()
    selected_plane.material.color = new THREE.Color(.05, .05, .05)
    selected_plane.visible = true

    const arrow = get_selected_indicator()

    const group = new THREE.Group()
    group.add(selected_plane)
    group.add(arrow)
    group.the_plane = selected_plane
    group.the_arrow = arrow
    // Opposite of the op mesh group's rotation. TODO: simplify.
    group.rotation.x = Math.PI / 2

    globals.selected_indication_group = group
}

let selected_channel_rgb = [60,60,60]

function get_selected_indicator() {
    // Font Awesome down-arrow in the selected-channel grey, wrapped in a CSS2D label.
    const [r, g, b] = selected_channel_rgb
    const div = document.createElement('div')
    div.innerHTML = '<i class="fa-solid fa-arrow-down fa-2x"></i>'
    Object.assign(div.style, {
        backgroundColor: 'transparent',
        padding: '4px',
        color: `rgb(${r},${g},${b})`,
    })

    const label = new CSS2DObject(div)
    // Centered horizontally; the bottom of the icon aligns with the anchor point vertically.
    label.center.set(.5, 1)
    return label
}

export function make_plane_visible(highlight_plane){
    // Attach to the scene lazily, only once (tracked in userData), then show it.
    const already_added = Boolean(highlight_plane.userData.is_added_to_scene)
    if (!already_added) {
        globals.scene.add(highlight_plane)
        highlight_plane.userData.is_added_to_scene = true
    }
    highlight_plane.visible = true
}
export function make_plane_invisible(highlight_plane) {
    // A missing plane is a no-op (e.g. if it hasn't been created yet).
    if (!highlight_plane) return
    highlight_plane.visible = false
}

/**
 * Move the "selected channel" indicator (dark plane + arrow) onto the channel slice currently under
 * the hover highlight. The indicator is parented to op.mesh so it follows the op through
 * transitions. Also records op as globals.feature_sidebar_op.
 *
 * Precondition: globals.channel_slice_highlight_plane is currently positioned and scaled over the
 * hovered channel (this is called while hovering). Its position and scale are copied; a null
 * plane would throw at the position reads below.
 *
 * @param {object} op - op whose actgrid is shown; uses op.mesh and op.actgrid_group.the_plane.userData
 */
export function set_selected_feature_plane(op){
    // Don't really like how all this works, w highlight plane and selected planes using different approaches, and w selected
    // plane requiring highlight plane to be active to get pos from it. Strange and backwards.

    let highlight_plane = globals.channel_slice_highlight_plane // this must be currently active, it should be bc we're hovered when calling this
    make_plane_invisible(highlight_plane) // so selected plane shows immediately

    let selected_indication_group = globals.selected_indication_group
    // On first use this adds the group to the scene; op.mesh.add() below then reparents it
    // (three.js add() removes an object from its previous parent).
    make_plane_visible(selected_indication_group)

    // op used so we can add to group, so it stays with group
    let target_absolute_x = highlight_plane.position.x
    let target_absolute_z = highlight_plane.position.z

    let x_relative_to_op = target_absolute_x - op.mesh.position.x
    let z_relative_to_op = target_absolute_z - op.mesh.position.z


    // clone position and scale from highlight plane
    // but set in relative terms to parent op mesh.
    // The y/z swaps and sign flips below presumably undo op.mesh's rotation about x (the group
    // itself is rotated +PI/2 in make_channel_highlight_plane) — NOTE(review): confirm.
    selected_indication_group.position.x = x_relative_to_op // highlight_plane.position.x
    selected_indication_group.position.y = -z_relative_to_op //highlight_plane.position.y
    selected_indication_group.position.z = -highlight_plane.position.y - .03 // wth this hardcoded (small extra offset, presumably to avoid z-fighting)

    // move arrow up to top of channel
    let channel_height = op.actgrid_group.the_plane.userData.channel_height * globals.act_cube_size
    selected_indication_group.the_arrow.position.z = -channel_height/2

    op.mesh.add(selected_indication_group)
    
    // p is an optional enlargement of the selected plane relative to the hover plane (currently none).
    let p = .0
    selected_indication_group.the_plane.scale.x = highlight_plane.scale.x + p
    selected_indication_group.the_plane.scale.y = highlight_plane.scale.y + p
    selected_indication_group.the_plane.scale.z = highlight_plane.scale.z + p

    globals.feature_sidebar_op = op
}


/**
 * Build the "actgrid" for a tensor node. Every channel of the node's activation volume is tiled
 * into one 2D grid, color-mapped into an RGBA DataTexture, and shown on a plane.
 *
 * The returned Group holds:
 *   - the textured plane (ACTVOL + CLICKABLE layers). Its userData carries grid metadata, the
 *     source op, and an update_texture_data() hook for refreshing after the trace changes;
 *   - a background plane just behind it (shared actgridPlaneBackgroundMaterial);
 *   - a neutral background plane that renders only on the minimap layer.
 * Side effects: sets op.actgrid_group and globals.plane_to_node_lookup[op.tensor_id] = op. The
 * lookup exists because one actgrid is often associated with several nodes.
 *
 * Bug fix: update_texture_data previously called reshapeToGrid without `preceding_cls_tokens`.
 * After a trace reload, models with a leading CLS token were then laid out shifted by one token.
 * Both the initial build and the refresh now go through build_texture_data, so they always agree.
 *
 * @param {object} op - tensor node (uses tensor_id, dim_types, enriched_tensor_specs, ...)
 * @returns {THREE.Group}
 */
export function createGridOfChannelSlicesWithDataTexture(op) {
    const tensor_id = op.tensor_id

    const featurespace = globals.most_activated_ixs[tensor_id]
    // TODO: preceding_cls_tokens belongs with the activation data itself, since it matters even
    // without the microscope. Currently it is only recorded in the microscope capture config.
    const preceding_cls_tokens = featurespace.capture_config["preceding_cls_tokens"]
    const mean_center = featurespace.capture_config["mean_center"]

    // Current activations (or act*|grad| when coloring by grads) -> tiled grid -> RGBA bytes.
    function build_texture_data() {
        const actvol = swap_acts_w_actgrids_if_necessary(get_actvol(tensor_id), tensor_id)
        const gridDataPackage = reshapeToGrid(actvol, op, preceding_cls_tokens)
        const data = z_score_matrix_to_flattened_rgb_data_for_texture(gridDataPackage.gridData, mean_center)
        return { gridDataPackage, data }
    }

    const { gridDataPackage, data } = build_texture_data()
    const gridData = gridDataPackage.gridData
    const grid_height = gridData.length
    const grid_width = gridData[0].length

    // One texel per grid cell.
    const texture = new THREE.DataTexture(data, grid_width, grid_height, THREE.RGBAFormat);
    texture.flipY = true;
    // (texture.minFilter = THREE.LinearFilter made no visible difference.)
    texture.needsUpdate = true;

    const h = grid_height * globals.act_cube_size
    const w = grid_width * globals.act_cube_size

    const geometry = new THREE.PlaneGeometry(w, h); // NOTE: (width, height)
    // transparent so the alpha-0 padding cells reveal the background plane.
    const material = new THREE.MeshBasicMaterial({ map: texture, transparent: true });
    const plane = new THREE.Mesh(geometry, material);

    const backgroundPlane = new THREE.Mesh(geometry, actgridPlaneBackgroundMaterial);
    const backgroundPlaneNeutral = new THREE.Mesh(geometry, backgroundMaterialNeutral); // only for minimap

    plane.layers.disableAll()
    plane.layers.enable(ACTVOL_OBJECTS_LAYER)
    plane.layers.enable(CLICKABLE_LAYER)

    backgroundPlaneNeutral.layers.disableAll()
    backgroundPlaneNeutral.layers.enable(ONLY_MINIMAP_LAYER)

    plane.userData.grid_width = grid_width
    plane.userData.grid_height = grid_height
    plane.userData.channels = gridDataPackage.channels
    plane.userData.channel_width = gridDataPackage.channel_width
    plane.userData.channel_height = gridDataPackage.channel_height
    plane.userData.gridData = gridData
    plane.userData.paddingSize = gridDataPackage.paddingSize
    plane.userData.actual_op = op

    function update_texture_data(){
        // Call whenever the activation data changes, e.g. after loading a new trace. The texture
        // keeps its size: a larger grid makes TypedArray#set throw, a smaller one leaves stale texels.
        const { gridDataPackage, data } = build_texture_data()
        texture.image.data.set(data);
        texture.needsUpdate = true;
        plane.userData.gridData = gridDataPackage.gridData
    }
    plane.userData.update_texture_data = update_texture_data

    // Shift the planes so the grid's top-right corner sits at the group origin.
    plane.position.x -= w/2
    plane.position.y -= h/2

    backgroundPlane.position.x -= w/2
    backgroundPlane.position.y -= h/2
    backgroundPlane.position.z -= .02 // slightly back from main plane
    backgroundPlaneNeutral.position.x -= w/2
    backgroundPlaneNeutral.position.y -= h/2

    plane.is_actgrid_plane = true

    const group = new THREE.Group();
    group.add(plane);
    group.add(backgroundPlane);
    group.add(backgroundPlaneNeutral);

    group.rotation.x = -Math.PI / 2;
    group.position.y += 1 // Shift towards the camera so it doesn't overlap with edges

    op.actgrid_group = group
    group.the_plane = plane

    globals.plane_to_node_lookup[tensor_id] = op

    return group;
}


/**
 * Build the "activation volume" view of a tensor node: a sheared box whose size comes from
 * `n.enriched_tensor_specs.actvol_dims`.
 *
 * @param {object} n - tensor node carrying `enriched_tensor_specs.actvol_dims` ({depth, height, width}).
 * @returns {THREE.Group} group holding the single box mesh; the mesh is tagged with `actual_node = n`.
 */
export function get_activation_volume(n){
    const materials = get_actvol_materials()
    const { depth, height, width } = n.enriched_tensor_specs.actvol_dims

    // The 0.5 on width is an empirically estimated scale, not a derived one.
    const geometry = new THREE.BoxGeometry(depth, height, width * .5)
    // Put the origin on the box's right face so the volume ends where the tensor node used to sit.
    geometry.translate(-depth / 2, 0, 0)
    // Shear the box so its front face is visible.
    geometry.applyMatrix4(shearMatrix)

    const mesh = new THREE.Mesh(geometry, materials)
    mesh.layers.disableAll()
    mesh.layers.enable(ACTVOL_OBJECTS_LAYER)
    mesh.layers.enable(CLICKABLE_LAYER)
    mesh.is_actvol_mesh = true

    mesh.rotation.x = -Math.PI / 2 // lay flat, facing up
    mesh.position.y += .1 // lift toward the camera so the volume is drawn above edges

    const group = new THREE.Group()
    group.add(mesh)
    mesh.actual_node = n // hover / click handlers resolve the node through this
    return group
}

/**
 * Build the scene object for an op according to `op.tensor_node_display_type`:
 * "volume" -> activation volume, "grid" -> data-texture activation grid, anything else -> node marker.
 *
 * @param {object} op - graph node.
 * @returns {THREE.Object3D} freshly built mesh or group.
 */
export function get_mesh_for_op(op) {
    switch (op.tensor_node_display_type) {
        case "volume":
            return get_activation_volume(op)
        case "grid":
            // Attention matrices and ordinary channels currently share one renderer. The branch is kept
            // so they can diverge again (attention heads can also be displayed flattened).
            if (op.enriched_tensor_specs.is_attn_matrices) {
                return createGridOfChannelSlicesWithDataTexture(op)
            }
            return createGridOfChannelSlicesWithDataTexture(op)
        default:
            return get_sphere_group(op)
    }
}



/**
 * Build the marker group for a node drawn as a dot.
 *
 * The group contains:
 *  - the visible marker: a circle (sphere_geometry) for "function"/"module" nodes, otherwise a square
 *    (square_geometry), coloured by get_node_color(n) and placed only on its node layer;
 *  - an invisible (opacity 0) circle 3x the marker's scale, capped at MAX_SPHERE_SIZE, enabled on
 *    CLICKABLE_LAYER as the hover/click target (it links back via `smaller_sphere`);
 *  - for "module" nodes only, a slightly larger outline circle just below the marker
 *    (linked from the hit target via `outline_sphere`).
 * Every child gets `actual_node = n` and `visible = true`.
 *
 * @param {object} n - graph node.
 * @returns {THREE.Group}
 */
export function get_sphere_group(n){
    const color = get_node_color(n)
    const is_op_node = n.node_type=="function" || n.node_type=="module"

    const marker = new THREE.Mesh(
        is_op_node ? sphere_geometry : square_geometry,
        new THREE.MeshBasicMaterial({ color: color })
    )
    marker.layers.disableAll()
    // The marker itself is not clickable; the larger hit target below handles that.
    marker.layers.enable(is_op_node ? OP_NODES_OBJECTS_LAYER : TENSOR_NODES_OBJECTS_LAYER)
    marker.rotation.x = -Math.PI / 2 // lay flat, facing up
    marker.position.y += .1 // lift toward the camera so it is drawn above edges
    scale_sphere(marker, n)

    const group = new THREE.Group()
    group.add(marker)

    // Invisible, enlarged hit target; its colour is irrelevant. The cap keeps already-large nodes from
    // getting an oversized click area.
    const hit_target = new THREE.Mesh(
        sphere_geometry,
        new THREE.MeshBasicMaterial({ color: color, transparent: true, opacity: 0 })
    )
    hit_target.rotation.x = -Math.PI / 2
    hit_target.scale.setScalar(Math.min(marker.scale.x * 3, MAX_SPHERE_SIZE))
    hit_target.layers.enable(CLICKABLE_LAYER)
    hit_target.smaller_sphere = marker
    group.add(hit_target)

    if (n.node_type=="module") {
        const outline = new THREE.Mesh(sphere_geometry, new THREE.MeshBasicMaterial({ color: node_color_outline }))
        outline.rotation.x = -Math.PI / 2
        outline.position.y += .09 // just beneath the marker
        outline.scale.setScalar(marker.scale.x + .05) // slightly bigger than the marker
        group.add(outline)
        hit_target.outline_sphere = outline
    }

    for (const child of group.children) {
        child.actual_node = n // required for onHover, click events
        child.visible = true
    }
    return group
}

/**
 * Remove from the DOM every element that carries the CSS class `class_name`.
 *
 * @param {string} class_name
 */
export function remove_dom_el_bv_class_name (class_name) {
    // getElementsByClassName returns a *live* HTMLCollection that shrinks as nodes leave the DOM,
    // so repeatedly detaching its current first entry drains it.
    const live_matches = document.getElementsByClassName(class_name)
    for (let el = live_matches[0]; el !== undefined; el = live_matches[0]) {
        el.parentNode.removeChild(el)
    }
}

/**
 * Tear down everything in `sceneToClear`.
 *
 * Steps, in order:
 *  1. For every mesh in the tree, dispose its geometry, its material(s), and any texture-like value
 *     hanging directly off a material.
 *  2. Detach all children and call `clear()`.
 *  3. Remove the "group_label" and "label" DOM elements.
 *
 * @param {THREE.Object3D} sceneToClear
 */
export function clear_scene(sceneToClear) {
    // Textures are recognised duck-typed, by having a `minFilter` property.
    const disposeMaterialAndTextures = (mat) => {
        mat.dispose()
        Object.keys(mat).forEach(prop => {
            const candidate = mat[prop]
            if (candidate && typeof candidate === 'object' && 'minFilter' in candidate) {
                candidate.dispose()
            }
        })
    }

    sceneToClear.traverse(node => {
        if (!node.isMesh) return
        if (node.geometry) node.geometry.dispose()
        const { material } = node
        if (!material) return
        // A mesh holds either a single material or an array of them.
        const materials = material.isMaterial ? [material] : material
        for (const m of materials) disposeMaterialAndTextures(m)
    })

    while (sceneToClear.children.length > 0) {
        sceneToClear.remove(sceneToClear.children[0])
    }
    sceneToClear.clear()

    for (const cls of ["group_label", "label"]) remove_dom_el_bv_class_name(cls)
}

/**
 * Size and target centre of an op's expanded plane, derived from its `plane_info` bounds.
 *
 * @param {object} op - op with `plane_info.{min_x, max_x, min_y, max_y}`.
 * @returns {[number, number, {x: number, y: number, z: number}]} `[h, w, target_pos]`.
 *   `target_pos.y` is get_z_plane(op). `x`/`z` are the centre of the bounds; the plane's y axis maps to
 *   scene z.
 */
export function get_plane_specs(op){
    const { min_x, max_x, min_y, max_y } = op.plane_info
    const w = max_x - min_x
    const h = max_y - min_y
    const target_pos = {
        x: min_x + w/2,
        y: get_z_plane(op), // based on depth
        z: min_y + h/2,
    }
    return [h, w, target_pos]
}

/**
 * Detach an op's node mesh from the scene and forget it.
 * If the op currently has a mesh, its label is first returned to the label pool.
 *
 * @param {object} op - graph node; `op.mesh` is `undefined` afterwards.
 */
export function remove_sphere(op) {
    const has_mesh = op.mesh != undefined
    if (has_mesh) remove_label_from_op_and_return_to_pool(op)
    scene.remove(op.mesh)
    op.mesh = undefined
}

/**
 * Animate `mesh` shrinking to zero scale while it moves to `target_position`; both tweens use TWEEN_MS
 * and TWEEN_EASE.
 *
 * Despite the name, nothing is removed here. `onComplete` is called (with no arguments) when the
 * position tween finishes, and is expected to do the removal.
 *
 * @param {THREE.Object3D} mesh
 * @param {{x: number, y: number, z: number}} target_position
 * @param {Function} onComplete
 */
export function scale_to_zero_and_shift_to_location_then_remove(mesh, target_position, onComplete) {
    const shrink = new TWEEN.Tween(mesh.scale)
        .to({x:0, y:0, z:0}, TWEEN_MS)
        .easing(TWEEN_EASE)
    shrink.start()

    const move = new TWEEN.Tween(mesh.position)
        .to(target_position, TWEEN_MS)
        .easing(TWEEN_EASE)
        .onComplete(() => onComplete())
    move.start()
}

/**
 * Remove a CSS2D label from the scene and detach its DOM element.
 *
 * Does nothing unless the label exists, has an `element`, and currently has a `parent`. Removal goes
 * through `scene.remove`, so it only detaches the label when the scene is its direct parent
 * (NOTE(review): presumably always the case for these labels).
 *
 * @param {CSS2DObject|undefined} label
 */
function deleteCSS2DLabel(label) {
    const removable = Boolean(label && label.element && label.parent)
    if (!removable) return

    scene.remove(label)

    const dom_parent = label.element.parentNode
    if (dom_parent) dom_parent.removeChild(label.element)
}

/**
 * Animate away every mesh belonging to `op` and its descendants, collapsing them toward `target_position`.
 *
 * - Node meshes (`op.mesh`) shrink and move via scale_to_zero_and_shift_to_location_then_remove, then
 *   are removed with remove_sphere.
 * - Expanded planes (`op.expanded_plane_mesh` plus `op.expanded_plane_background_mesh`) lose their label
 *   immediately. They then shrink and move, and are removed from the scene once the animation ends.
 *
 * The meshes are captured when the animation starts. If the op is redrawn before the tween finishes,
 * only the stale meshes are removed and the op's new meshes are left alone. The plane completion
 * handler runs once, not once per plane.
 *
 * @param {object} op - graph node; recursion follows `op.children`.
 * @param {{x: number, y: number, z: number}} target_position - where the meshes collapse to.
 */
export function remove_all_meshes(op, target_position) {
    if (op.mesh != undefined) { // node
        const mesh = op.mesh
        scale_to_zero_and_shift_to_location_then_remove(mesh, target_position, () => {
            if (op.mesh === mesh) {
                remove_sphere(op)
            } else {
                // op received a new mesh mid-animation; drop only the one we were animating
                scene.remove(mesh)
            }
        })

    } else if (op.expanded_plane_mesh != undefined) { // plane

        // remove label immediately
        deleteCSS2DLabel(op.expanded_plane_label)
        op.expanded_plane_label = undefined

        const plane_mesh = op.expanded_plane_mesh
        const background_mesh = op.expanded_plane_background_mesh
        const planes = [plane_mesh, background_mesh].filter(p => p != undefined)

        const on_planes_collapsed = () => {
            planes.forEach(p => scene.remove(p))
            // Only clear the op's references if they still point at the meshes we animated away.
            if (op.expanded_plane_mesh === plane_mesh) op.expanded_plane_mesh = undefined
            if (op.expanded_plane_background_mesh === background_mesh) op.expanded_plane_background_mesh = undefined
        }

        planes.forEach((plane, i) => {
            new TWEEN.Tween(plane.position)
                .to(target_position, TWEEN_MS)
                .easing(TWEEN_EASE)
                .start();

            const shrink = new TWEEN.Tween(plane.scale)
                .to({x:0, y:0, z:0}, TWEEN_MS)
                .easing(TWEEN_EASE)
            // All tweens share one duration, so a single completion handler covers every plane.
            if (i === 0) shrink.onComplete(on_planes_collapsed)
            shrink.start();
        })
    }
    op.children.forEach(c => {remove_all_meshes(c, target_position)})
}

/**
 * Short, human-readable id for an op.
 * Format: first 10 characters of its name, a dash, then the last 4 characters of its node_id
 * (the whole node_id when it is shorter than that).
 *
 * @param {{name: string, node_id: string}} op
 * @returns {string}
 */
export function nice_name(op) {
    // slice(-4) rather than slice(len - 4, len): for ids shorter than 4 characters a negative start
    // wraps around and returns only a fragment (e.g. "abc" -> "c").
    return `${op.name.slice(0, 10)}-${op.node_id.slice(-4)}`
}

///////////////////////////////
// utils
///////////////////////////////

/**
 * Visible half-extent of the main camera's view (camera `right`/`top` divided by `zoom`) and the
 * camera's x/z position.
 *
 * @returns {[number, number, number, number]} `[half_width, half_height, camera_x, camera_z]`
 */
export function get_main_window_position() {
    const { camera } = globals
    const { zoom } = camera
    return [camera.right / zoom, camera.top / zoom, camera.position.x, camera.position.z]
}

/**
 * Set `node[attr] = value` on `op` and every descendant (via `children`), visiting in pre-order.
 *
 * @param {object} op - root of the subtree.
 * @param {string} attr - property name to set.
 * @param {*} value
 */
export function mark_attr(op, attr, value) {
    const pending = [op]
    while (pending.length > 0) {
        const node = pending.pop()
        node[attr] = value
        // Push children in reverse so they are visited in their original order, as the recursive form would.
        for (let i = node.children.length - 1; i >= 0; i--) pending.push(node.children[i])
    }
}

/**
 * Piecewise-linear interpolation, like numpy.interp.
 *
 * Points before the first breakpoint return `values[0]`; points at or after the last one return the
 * last value.
 *
 * @param {number} xPoint - query point.
 * @param {number[]} breakpoints - ascending x coordinates.
 * @param {number[]} values - y value at each breakpoint.
 * @returns {number}
 */
export function interp(xPoint, breakpoints, values) {
    // Index of the first breakpoint strictly greater than xPoint.
    let hi = 0
    while (hi < breakpoints.length && !(breakpoints[hi] > xPoint)) hi++

    if (hi === breakpoints.length) return values[values.length - 1] // beyond the range
    if (hi === 0) return values[0] // before the range

    const lo = hi - 1
    const fraction = (xPoint - breakpoints[lo]) / (breakpoints[hi] - breakpoints[lo])
    return values[lo] + fraction * (values[hi] - values[lo])
}

// edges
/**
 * Resolve one of an op's neighbour-id lists to node objects.
 *
 * Ids with no entry in globals.nodes_lookup are skipped; removed aux outputs produced many of these.
 *
 * @param {object} op
 * @param {string} uns_or_dns - name of the id list to read ("uns" or "dns"; presumably upstream /
 *   downstream neighbours).
 * @returns {object[]}
 */
export function get_ns(op, uns_or_dns) {
    const lookup = globals.nodes_lookup
    const found = []
    for (const nid of op[uns_or_dns]) {
        const node = lookup[nid]
        if (node != undefined) found.push(node)
    }
    return found
}
// Profiling note: the get-nodes helpers below dominate timing.
/**
 * Downstream neighbours of `base_op` that share its parent op.
 *
 * Parents are compared by identity, not by name: sibling modules can share a name (e.g. several
 * "Sequential"s), and a name comparison wrongly treated those as peers.
 *
 * @param {object} base_op
 * @returns {object[]}
 */
export function get_downstream_peer_nodes(base_op) {
    const parent = base_op.parent_op
    return get_ns(base_op, "dns").filter(dn => dn.parent_op == parent)
}
/**
 * Upstream neighbours of `base_op` that share its parent op (compared by identity, not name).
 *
 * @param {object} base_op
 * @returns {object[]}
 */
export function get_upstream_peer_nodes(base_op) {
    const parent = base_op.parent_op
    return get_ns(base_op, "uns").filter(un => un.parent_op == parent)
}

/** Members of `ops` whose node_id appears in `base_op.dns`, in `ops` order. */
export function get_downstream_nodes_from_group(base_op, ops) {
    const downstream_ids = base_op.dns
    return ops.filter(candidate => downstream_ids.includes(candidate.node_id))
}
/** Members of `ops` whose node_id appears in `base_op.uns`, in `ops` order. */
export function get_upstream_nodes_from_group(base_op, ops) {
    const upstream_ids = base_op.uns
    return ops.filter(candidate => upstream_ids.includes(candidate.node_id))
}

// TODO consolidate these
/**
 * Collapse every not-yet-collapsed module named `family` in the subtree rooted at `op`.
 * Each module whose state changed is appended to `to_remove_container`, in pre-order.
 */
export function mark_all_mods_of_family_as_collapsed(op, family, to_remove_container){
    const should_collapse = op.node_type=="module" && op.name==family && !op.collapsed
    if (should_collapse) {
        op.collapsed = true
        to_remove_container.push(op)
    }
    op.children.forEach(child => mark_all_mods_of_family_as_collapsed(child, family, to_remove_container))
}
/**
 * Expand every collapsed module named `family` in the subtree rooted at `op`.
 * Each module whose state changed is appended to `to_expand_container`, in pre-order.
 */
export function mark_all_mods_of_family_as_expanded(op, family, to_expand_container){
    const should_expand = op.node_type=="module" && op.name==family && op.collapsed
    if (should_expand) {
        op.collapsed = false
        to_expand_container.push(op)
    }
    op.children.forEach(child => mark_all_mods_of_family_as_expanded(child, family, to_expand_container))
}


/**
 * Set module collapse state from nesting depth, starting at globals.nn.
 *
 * Modules with `depth >= level` are collapsed; shallower modules are expanded. Only modules are
 * descended into; children of non-module nodes are not visited.
 *
 * @param {number} level - first depth that should be collapsed.
 * @returns {[object[], object[]]} `[newly_collapsed, newly_expanded]`: the modules whose state changed.
 */
export function mark_all_mods_past_depth_as_collapsed(level){
    const newly_collapsed = []
    const newly_expanded = []

    const visit = (node) => {
        if (node.node_type != "module") return
        const should_be_collapsed = node.depth >= level
        if (should_be_collapsed && !node.collapsed) {
            node.collapsed = true
            newly_collapsed.push(node)
        } else if (node.depth < level && node.collapsed) {
            node.collapsed = false
            newly_expanded.push(node)
        }
        node.children.forEach(child => visit(child))
    }
    visit(globals.nn)

    return [newly_collapsed, newly_expanded]
}

// stats, tooltips number formatting
/**
 * Format a parameter count compactly with one decimal: "b" (billions), "m" (millions),
 * "k" (thousands), or the bare number. E.g. 1500 -> "1.5k", 12 -> "12.0".
 *
 * Each unit starts half a display step early, so a value that would round to "1000.0" of the smaller
 * unit is shown in the larger one instead (999,999 -> "1.0m", not "1000.0k").
 *
 * @param {number} num
 * @returns {string}
 */
export function formatNumParams(num) {
    if (num >= 1e9 - 5e4) {          // 1e9 minus half of 0.1m
      return (num / 1e9).toFixed(1) + 'b';
    } else if (num >= 1e6 - 50) {    // 1e6 minus half of 0.1k
      return (num / 1e6).toFixed(1) + 'm';
    } else if (num >= 1e3 - 0.05) {  // 1e3 minus half of 0.1
      return (num / 1e3).toFixed(1) + 'k';
    } else {
      return num.toFixed(1);
    }
}
/**
 * Format a byte count with binary units: "GB", "MB" and "KB" with one decimal, or whole "bytes".
 *
 * Each unit starts half a display step early, so a value that would round to "1024" of the smaller
 * unit is shown in the larger one instead (1,048,575 bytes -> "1.0 MB", not "1024.0 KB").
 *
 * @param {number} numBytes
 * @returns {string}
 */
export function formatMemorySize(numBytes) {
    const ONE_KB = 1024;
    const ONE_MB = 1024 * ONE_KB;
    const ONE_GB = 1024 * ONE_MB;

    if (numBytes >= ONE_GB - 0.05 * ONE_MB) {        // minus half of 0.1 MB
        return (numBytes / ONE_GB).toFixed(1) + ' GB';
    } else if (numBytes >= ONE_MB - 0.05 * ONE_KB) { // minus half of 0.1 KB
        return (numBytes / ONE_MB).toFixed(1) + ' MB';
    } else if (numBytes >= ONE_KB - 0.5) {           // minus half a byte (bytes are rounded)
        return (numBytes / ONE_KB).toFixed(1) + ' KB';
    } else {
        return Math.round(numBytes) + ' bytes';
    }
}

/**
 * Format a duration given in milliseconds.
 *
 * Output is seconds with one decimal ("1.5s"), whole milliseconds ("12ms"), or whole microseconds
 * ("250µs"). Each unit starts half a display step early, so values that would round to "1000ms" or
 * "1000µs" are promoted to the next unit instead.
 *
 * @param {number} ms
 * @returns {string}
 */
export function formatLatency(ms) {
    if (ms >= 999.5) {
      // would round to 1000ms: show as seconds with one decimal
      return (ms / 1000).toFixed(1) + 's';
    } else if (ms >= 0.9995) {
      // would round to 1000µs or more: show whole milliseconds
      return Math.round(ms) + 'ms';
    } else {
      // below a millisecond: show whole microseconds
      return Math.round(ms * 1000) + 'µs';
    }
}
  
/**
 * Download the current view state as a gzipped JSON file named after the trace.
 *
 * The file records which planes are visible. Each op in globals.ops_of_visible_planes is reduced to
 * its `name`, `collapsed` and `mod_identifier`.
 */
export function save_current_state() {
    const summarize = (op) => ({ name: op.name, collapsed: op.collapsed, mod_identifier: op.mod_identifier })
    const expanded_ops = globals.ops_of_visible_planes.map(summarize)

    const trace_name = globals.nn.trace_metadata.name
    saveCompressedJSON({ name: trace_name, expanded_ops: expanded_ops }, trace_name)
}

/**
 * Apply saved view settings.
 *
 * Every node in globals.nodes_lookup is collapsed first. Then the root `nn` and each module listed in
 * `saved_settings.expanded_ops` are expanded. Modules are looked up by `mod_identifier` in
 * globals.modules_lookup_by_identifier; entries not found there (including the root) are skipped.
 *
 * @param {object} nn - root of the graph.
 * @param {{expanded_ops: {mod_identifier: string}[]}} saved_settings
 */
export function load_saved_settings(nn, saved_settings) {
    for (const op of Object.values(globals.nodes_lookup)) op.collapsed = true

    nn.collapsed = false // root
    saved_settings.expanded_ops.forEach(({ mod_identifier }) => {
        const mod = globals.modules_lookup_by_identifier[mod_identifier] // root not in there
        if (mod) mod.collapsed = false
    })
}

/**
 * Gzip `jsonObject` with pako and trigger a browser download named
 * `darkspark_defaults_<trace_name>.json.gz`.
 *
 * The temporary object URL is revoked afterwards so the Blob can be garbage-collected.
 *
 * @param {object} jsonObject - JSON-serialisable settings.
 * @param {string} trace_name - used in the download filename.
 */
function saveCompressedJSON(jsonObject, trace_name) {
    const filename = `darkspark_defaults_${trace_name}.json.gz`

    const compressed = pako.gzip(JSON.stringify(jsonObject))
    const blob = new Blob([compressed], { type: 'application/gzip' })
    const url = URL.createObjectURL(blob)

    const link = document.createElement('a');
    link.href = url;
    link.download = filename;
    document.body.appendChild(link); // Needed for Firefox
    link.click();
    document.body.removeChild(link); // Clean up

    // Object URLs pin their Blob in memory until revoked. Revoke after a delay rather than
    // immediately so the browser has time to start reading the blob for the download.
    setTimeout(() => URL.revokeObjectURL(url), 40000)
}


/**
 * Render the minimap without its viewport overlay (`globals.minimap_window_plane`) and download the
 * result as `darkspark_thumbnail_<trace name>.png`.
 *
 * The overlay is always made visible again and the minimap re-rendered, even if rendering or
 * `toDataURL` throws (e.g. a SecurityError on a tainted canvas). A failed export therefore never leaves
 * the overlay hidden.
 *
 * @param {THREE.WebGLRenderer} renderer - renderer whose canvas is captured.
 * @param {THREE.Camera} camera - camera the minimap is rendered from.
 */
export function saveMinimapAsImage(renderer, camera) {
    globals.minimap_window_plane.visible = false
    try {
        // Render the current scene from the camera's perspective, without the overlay
        renderer.render(scene, camera);
        const imgData = renderer.domElement.toDataURL("image/png");

        // Temporary link element to trigger the download
        const link = document.createElement('a');
        link.href = imgData;
        const trace_name = globals.nn.trace_metadata.name
        link.download = `darkspark_thumbnail_${trace_name}.png`;
        link.click();
    } finally {
        globals.minimap_window_plane.visible = true
        renderer.render(scene, camera);
    }
}

// Hugging Face library display names, prefixed with the hugging-face emoji (U+1F917).
export const transformers_str_w_emoji = `\u{1F917} Transformers`;
export const diffusers_str_w_emoji = `\u{1F917} Diffusers`;

// Baseline help text (presumably shown when no more specific help applies).
export const base_help_text = "Left-click and drag to pan scene. Scroll to zoom."

/////////////////////////////////////////////
// color by


/////////////

////////////////// Simple by op type
// Op names drawn grey, like tensor nodes. A trailing "*" is a prefix wildcard ("reshape*" matches
// "reshape", "reshape_as", ...); other entries must match exactly.
let tensor_ops = ["reshape*", "cat", "__getitem__"]

/**
 * Whether `name` matches one of the `tensor_ops` patterns.
 *
 * @param {string} name
 * @returns {boolean}
 */
function _is_tensor_op_name(name) {
    if (typeof name !== "string") return false
    return tensor_ops.some(pattern => (pattern.endsWith("*")
        ? name.startsWith(pattern.slice(0, -1))
        : name === pattern))
}

/**
 * Colour for the "none" colour-by mode: grey for tensor nodes and tensor-manipulation ops,
 * the default node colour otherwise.
 *
 * @param {object} op
 * @returns {THREE.Color}
 */
function op_type_to_color(op) {
    // Array.includes would compare "reshape*" literally, so reshape ops never matched the wildcard.
    if (op.is_tensor_node || _is_tensor_op_name(op.name)) {
        return new THREE.Color('grey')
    } else {
        return node_color
    }
}

// Categorical colour-by modes: maps a `globals.nodes_color_by` value to a per-op colour function.
// Continuous modes (latency, params, memory) are not listed here; update_node_colors builds those on the fly.
const color_by_lookup = {
    none: op_type_to_color,
    debug: get_node_color,
}


// Colour-by attributes with numeric values; these get a continuous colour scale.
let continuous_color_bys = ["latency", "n_params", "incremental_memory_usage", "max_memory_allocated"]

/**
 * Map `v` from [min, max] onto [0, 1].
 * A degenerate range (min === max, e.g. one visible node or identical values) maps to 0 instead of
 * producing NaN.
 *
 * @param {number} v
 * @param {number} min
 * @param {number} max
 * @returns {number}
 */
function _normalize_to_unit(v, min, max) {
    const range = max - min
    if (!(range > 0)) return 0
    return (v - min) / range
}

/**
 * Recolour the markers of the visible nodes according to `globals.nodes_color_by`.
 *
 * - Continuous attributes (see continuous_color_bys) are scaled to the range of the currently visible
 *   nodes and drawn as Color(l, .3, .3). Nodes lacking the attribute are grey.
 * - Other attributes use a categorical function from color_by_lookup.
 * - Expanded grid tensors are not recoloured. Instead they refresh their texture when
 *   `globals.actgrids_need_texture_update` is set; that flag is reset at the end.
 */
export function update_node_colors() {
    const colorby_attr = globals.nodes_color_by
    let color_by_fn

    if (continuous_color_bys.includes(colorby_attr)) {
        // Single pass for min/max: spreading a large array into Math.min/Math.max can overflow the stack.
        let _min = Infinity
        let _max = -Infinity
        globals.ops_of_visible_nodes.forEach(op => {
            if (colorby_attr in op) {
                const v = op[colorby_attr]
                if (v < _min) _min = v
                if (v > _max) _max = v
            }
        })
        color_by_fn = (op) => {
            if (colorby_attr in op) {
                const l = _normalize_to_unit(op[colorby_attr], _min, _max)
                return new THREE.Color(l, .3, .3)
            }
            return new THREE.Color('grey')
        }
    } else { // categorical colorBy, get from lookup
        color_by_fn = color_by_lookup[colorby_attr]
    }

    globals.ops_of_visible_nodes.forEach(op => {
        if (!("mesh" in op)) return
        if (op.tensor_is_expanded) {
            // Volumes keep their own materials; grids only need a texture refresh when flagged.
            if (op.tensor_node_display_type==="grid" && globals.actgrids_need_texture_update) {
                op.actgrid_group.the_plane.userData.update_texture_data()
            }
        } else if (op.mesh) {
            // remove_sphere leaves `op.mesh = undefined` with the key still present, so guard before use.
            const node = op.mesh.children[0]
            node.material.color = color_by_fn(op)
        }
    })
    globals.actgrids_need_texture_update = false
}

// Activation display modes that are always offered.
let base_acts_display_options = [
    {label:'nodes', value:'collapsed', tooltip:"collapse all activation tensors into nodes"},
    {label: 'volumes', value: 'volumes', tooltip:"expand activation tensors into volumes when shape information is known"},
]

// The same modes plus "full activations"; shares the base option objects.
let acts_display_options_w_acts = [
    ...base_acts_display_options,
    {label:"full activations", value:'expanded', tooltip:"show actual activation values where available"},
]


/**
 * Fetch the gzipped JSON model-architecture graph at `model_path`, decompress it with pako and store
 * the parsed object in `globals.nn`.
 *
 * @param {string} model_path - URL of the `.json.gz` graph.
 * @throws {Error} when the HTTP response is not ok. This is checked before decompressing, so a 404 page
 *   does not surface as an opaque decompression error.
 */
export async function load_model(model_path){
    const response = await fetch(model_path)
    if (!response.ok) {
        throw new Error(`Failed to load model graph from ${model_path}: ${response.status} ${response.statusText}`)
    }
    const arrayBuffer = await response.arrayBuffer();

    const decompressed = pako.ungzip(new Uint8Array(arrayBuffer), { to: 'string' });
    globals.nn = JSON.parse(decompressed);
}

/**
 * Load feature descriptions for the current model from `data/feature_descriptions/<model name>.json`
 * into `globals.feature_descriptions`.
 *
 * Descriptions are optional. A network failure, a non-ok response or an unparsable body all fall back
 * to `{}`.
 */
export async function load_feature_descriptions(){
    const metadata = globals.nn.trace_metadata
    const descriptions_path = get_file_path(`data/feature_descriptions/${metadata.name}.json`)
    console.log(descriptions_path)

    let feature_descriptions
    try {
        // fetch is inside the try so a network failure falls back just like a missing file
        const response = await fetch(descriptions_path)
        if (!response.ok) throw new Error(`HTTP ${response.status}`)
        feature_descriptions = await response.json() // not compressed for now
        console.log("feature descriptions DO exist for this model")
    } catch(e) {
        if (metadata.name==="dino-v2-large-microscope") alert("Feature descriptions don't exist for dino?"); // TODO remove
        console.log("feature descriptions don't yet exist")
        feature_descriptions = {}
    }

    globals.feature_descriptions = feature_descriptions
}

/**
 * Store `text` as the description body for channel `channel_ix` of tensor `tensor_id` in
 * globals.feature_descriptions. Missing levels are created.
 *
 * The full description map for the current model is then POSTed to the local server at
 * http://127.0.0.1:5000/save-feature-descriptions. The request is not awaited; its outcome is only logged.
 *
 * @param {string} tensor_id
 * @param {number|string} channel_ix
 * @param {string} text
 */
export function save_description_text(tensor_id, channel_ix, text) {
    const all_descriptions = globals.feature_descriptions
    if (!(tensor_id in all_descriptions)) all_descriptions[tensor_id] = {}
    const per_tensor = all_descriptions[tensor_id]
    if (!(channel_ix in per_tensor)) per_tensor[channel_ix] = {}
    per_tensor[channel_ix]["body"] = text

    const payload = {
        model_name: globals.nn.trace_metadata.name,
        feature_descriptions: globals.feature_descriptions,
    }

    fetch('http://127.0.0.1:5000/save-feature-descriptions', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(payload),
    })
        .then(response => response.json())
        .then(data => { console.log('Success:', data) })
        .catch(error => console.error('Error:', error));
}

/**
 * Fetch the table-of-contents entry for the current model (`data/toc/<model name>.json`), store it in
 * `globals.toc_entry`, and configure activation-related state from it.
 *
 * - If the model has a featurespace:
 *   - flag each drawable op whose tensor has saved activations (`op.activations_available`);
 *   - collect attention-matrix tensor ids in `globals.attn_matrix_tids`;
 *   - when traces exist, register them and select a default input image / trace.
 * - Offer the "full activations" display option only when the model has both a featurespace and traces.
 *
 * Resolves once all of the above is done. Fetch and JSON errors reject the returned promise.
 */
export async function load_toc() {
    const metadata = globals.nn.trace_metadata
    const toc_path = get_file_path(`data/toc/${metadata.name}.json`)
    const response = await fetch(toc_path) // pause execution until this is done loading
    const toc_entry = await response.json()

    globals.toc_entry = toc_entry
    console.log("loaded toc entry", toc_entry)

    const has_traces = Boolean(toc_entry.has_featurespace && toc_entry.trace_ids)
    if (toc_entry.has_featurespace) {
        _register_available_activations(toc_entry.layers_with_activations)
        if (toc_entry.trace_ids) _select_default_trace(toc_entry.trace_ids)
    }
    _set_acts_display_options(has_traces ? acts_display_options_w_acts : base_acts_display_options)
}

/** Flag every node whose tensor has saved activations and collect attention-matrix tensor ids. */
function _register_available_activations(tensors_w_acts) {
    console.time("register available acts")
    // Set lookup instead of Array.includes per node. This could also move to a mouseover check,
    // but the loop is cheap.
    const has_acts = new Set(tensors_w_acts)
    globals.attn_matrix_tids = []
    Object.values(globals.nodes_lookup).forEach(op => { // marking all nodes, not just those currently shown
        op.activations_available = false // default can't show actgrid
        // should_draw excludes the extra debug nodes, which would otherwise all be expanded too.
        // If activations were saved for a tensor, its dim-type layout is assumed to be supported.
        if (op.should_draw && has_acts.has(op.tensor_id)) {
            op.activations_available = true
            // Identifying attention matrices by their creating op is a stopgap. Ideally a structure of
            // unique tensors (as in the backend) would carry this information.
            if (attn_matrix_op_types.includes(op.created_by_fn)) {
                globals.attn_matrix_tids.push(op.tensor_id)
            }
        }
    })
    console.timeEnd("register available acts")
}

/** Register the available traces and select the default one (a per-dataset sample if present, else the first). */
function _select_default_trace(trace_ids) {
    globals.setTracedImgsList(trace_ids)

    const default_img = globals.nn.trace_metadata.dataset==="imagenet" ? "zebra_00" : "mustache2"
    const trace_id = trace_ids.includes(default_img) ? default_img : trace_ids[0]
    globals.setInputImage(trace_id)

    globals.setTensorTraceId(trace_id)
    globals.current_trace = trace_id
}

/** Push the activation display options to the UI and mirror their values in globals. */
function _set_acts_display_options(options) {
    globals.setActsDisplayOptions(options)
    globals.acts_display_options = options.map(e => e.value)
}

// Dim-type layouts presumably supported for actgrid/volume drawing. The inferred-spatial linear layout
// (linear_spatial_inferred_dim_types_id) is deliberately not listed here — TODO handle it another way.
let can_do_dim_types = [
    linear_dim_types_id,
    single_linear_dim_types_id,
    conv2d_dim_types_id,
]

// Ops whose output tensors are treated as attention matrices.
let attn_matrix_op_types = ["matmul", "softmax"]


///////////////////////////////////////////////////////////////

// general enriched tensor specs, used by both actgrid and volumes
// static, won't change after compiling for first time
// Per-model rules for drawing sequence-shaped tensors (one spatial dim) as a 2D grid. A tensor is mapped
// only when its sequence length equals `seq_len` exactly. The grid may hold more cells than tokens
// (e.g. 257 tokens -> 17x16).
const _INFERRED_SPATIAL_BY_MODEL = new Map([
    ["vit_base_patch16_siglip_224-microscope", { seq_len: 196, spatial: [14, 14] }],
    ["dino-v2-large-microscope", { seq_len: 257, spatial: [17, 16] }],
    ["dino-v2-giant-microscope", { seq_len: 257, spatial: [17, 16] }],
    ["vit_base_patch16_clip_224-microscope", { seq_len: 197, spatial: [15, 14] }],
    ["vit_base_patch32_224-microscope", { seq_len: 50, spatial: [8, 7] }],
    ["depth-anything-v2-microscope", { seq_len: 257, spatial: [17, 16] }],
])

// Models whose sequence tensors become a square grid whenever the length is a perfect square.
const _SQUARE_SPATIAL_MODELS = new Set(["segformer-b0-finetuned-ade-microscope"])

/**
 * Interpret a sequence length as a 2D spatial grid for the current model (globals.nn.trace_metadata.name).
 *
 * @param {number} seq_len
 * @returns {number[]|undefined} a fresh `[h, w]` array, or undefined when no rule applies.
 */
function _infer_spatial_from_sequence(seq_len) {
    const model_name = globals.nn.trace_metadata.name
    const rule = _INFERRED_SPATIAL_BY_MODEL.get(model_name)
    if (rule) {
        return seq_len === rule.seq_len ? [...rule.spatial] : undefined
    }
    if (_SQUARE_SPATIAL_MODELS.has(model_name)) {
        const side = Math.sqrt(seq_len)
        return Number.isInteger(side) ? [side, side] : undefined
    }
    return undefined
}

/**
 * Base drawing specs for a tensor: its feature and spatial extents, read from `op.dim_types` /
 * `op.shape`. Only "features" and "spatial" dims are kept. Returns undefined for layouts that are not drawn.
 *
 * Accepted layouts:
 *  - 1 feature + 2 spatial: standard volume, returned as-is.
 *  - 0 feature + 2 spatial: one channel implied (e.g. a depth output).
 *  - 1 feature + 0 spatial: a single feature vector, shown as 1x1 spatial, but only when
 *    `op.activations_available` (as plain volumes these look too long and thin).
 *  - 1 feature + 1 spatial: a sequence, drawn only when a per-model rule maps its length to a 2D grid;
 *    such specs get `spatial_is_inferred = true`.
 *
 * @param {object} op - tensor node.
 * @returns {{features: number[], spatial: number[], spatial_is_inferred?: boolean}|undefined}
 */
function _get_enriched_tensor_specs(op) {
    const specs = {
        'features':[],
        'spatial':[]
    }
    if ("dim_types" in op) {
        op.dim_types.forEach((d, i) => {
            // own-property check, so a dim type such as "toString" cannot resolve to Object.prototype
            if (Object.prototype.hasOwnProperty.call(specs, d)) {
                specs[d].push(op.shape[i])
            }
        })
    }

    const n_features = specs.features.length
    const n_spatial = specs.spatial.length

    if (n_features === 1 && n_spatial === 2) { // normal standard volume
        return specs
    }
    if (n_features === 0 && n_spatial === 2) { // one channel is implied, e.g. depth output
        specs.features.push(1)
        return specs
    }
    if (n_features === 1 && n_spatial === 0) { // single feature vector, add ones for spatial to show
        if (!op.activations_available) return undefined
        specs.spatial.push(1, 1)
        return specs
    }
    if (n_features === 1 && n_spatial === 1) { // sequence (e.g. text tokens or a flattened image)
        const inferred = _infer_spatial_from_sequence(specs.spatial[0])
        if (inferred === undefined) return undefined
        specs.spatial = inferred
        specs.spatial_is_inferred = true
        return specs
    }
    return undefined
}

// actvol dimensions. used only by actvol
// static, won't change after getting for first time
/**
 * Scene-space dimensions of a tensor's activation volume. Used only by actvols; the values are static
 * once computed.
 *
 * - height/width: n_acts * globals.actvol_base_scalar + MIN_SPATIAL, capped at MAX_SPATIAL.
 * - depth: n_channels * scalar / 4, capped at MAX_DEPTH. `depth_is_overflowed` is set to true only when
 *   the uncapped depth actually exceeded the cap; a depth of exactly MAX_DEPTH is not flagged.
 *
 * @param {{height_n_acts: number, width_n_acts: number, depth_n_acts: number}} enriched_tensor_specs
 * @returns {{height: number, width: number, depth: number, depth_is_overflowed?: boolean}}
 */
function get_actvol_dims(enriched_tensor_specs) {
    const channels_scalar = globals.actvol_base_scalar
    const spatial_scalar = channels_scalar * 1 // same scale as channels, kept separate so it can be tuned
    const MIN_SPATIAL = .04
    const MAX_SPATIAL = 10
    const MAX_DEPTH = 16

    const actvol_dims = {}
    // TODO indicate spatial overflow too
    actvol_dims.height = Math.min(enriched_tensor_specs.height_n_acts * spatial_scalar + MIN_SPATIAL, MAX_SPATIAL)
    actvol_dims.width = Math.min(enriched_tensor_specs.width_n_acts * spatial_scalar + MIN_SPATIAL, MAX_SPATIAL)

    // Channels are shrunk 4x relative to spatial. This is a visual compromise; ideally both would share
    // one scale since they are equivalent.
    const depth = enriched_tensor_specs.depth_n_acts * channels_scalar / 4
    actvol_dims.depth = Math.min(depth, MAX_DEPTH)
    if (depth > MAX_DEPTH) actvol_dims.depth_is_overflowed = true;

    return actvol_dims
}

// Builds the specs used by both actgrids and actvols for a tensor op.
// Starts from `_get_enriched_tensor_specs(op)`; if that yields nothing (falsy), returns
// undefined. Otherwise it mutates and returns the same specs object with:
//   height_n_acts / width_n_acts  <- spatial[0] / spatial[1]
//   depth_n_acts                  <- features[0]
//   actvol_dims                   <- dimensions used when drawn as an actvol
//   is_attn_matrices = true       <- only for ops created by matmul or softmax
function get_enriched_tensor_specs(op){
    const specs = _get_enriched_tensor_specs(op)
    if (!specs) return undefined

    Object.assign(specs, {
        height_n_acts: specs.spatial[0],
        width_n_acts: specs.spatial[1],
        depth_n_acts: specs.features[0],
    })
    specs.actvol_dims = get_actvol_dims(specs)

    const fn = op.created_by_fn
    if (fn === "matmul" || fn === "softmax") {
        specs.is_attn_matrices = true
    }
    return specs
}

// enrich with specs to be used for actvols and actgrids
// For simplicity, just determine all specs for all tensor ops though many (most) won't ever be used. 
// Could do this on demand but not spending much perf here anyways (i don't think)
// These are constant, won't change. Just do once on load nn
export function add_tensor_specs() {
    // Walks every op under globals.nn (must already be set) and, for each tensor node
    // whose specs can be determined, attaches `enriched_tensor_specs` and sets
    // `has_enriched_tensor_specs = true`. Ops without determinable specs are left untouched.
    // Explicit-stack pre-order walk: children are pushed in reverse so they are
    // visited in the same order a recursive walk would visit them.
    const pending = [globals.nn]
    while (pending.length > 0) {
        const node = pending.pop()

        if (node.is_tensor_node) {
            const specs = get_enriched_tensor_specs(node)
            if (specs !== undefined) { // enough information to draw actgrids or volumes
                node.enriched_tensor_specs = specs
                node.has_enriched_tensor_specs = true
            }
        }

        for (let i = node.children.length - 1; i >= 0; i--) {
            pending.push(node.children[i])
        }
    }
}

// lets the spinner start. otherwise doesn't start. Wrap any fn in this to start spinner beforehand.
// don't like this. If could just set the effing spinner that would be nice. 
// Wraps `fn` so that calling the wrapper first turns the "thinking" spinner on and sets
// the help text to `thinkingAbout`, then runs `fn` with the same arguments ~1ms later.
// The deferral yields to the browser so the spinner can render before `fn` blocks.
// The wrapper returns undefined; `fn`'s return value is discarded.
export function thinkingFn(fn, thinkingAbout){
    return (...args) => {
        globals.setIsThinking(true)
        globals.setHelpInformation(thinkingAbout)
        setTimeout(() => { fn(...args) }, 1)
    }
}

// Fetches a gzip-compressed JSON file and returns the parsed value.
// Throws if the HTTP status is not 2xx. Without that check a missing file (eg a 404
// error page) would reach pako and fail with an opaque decompression error.
async function fetch_and_decompress_gzipped_json(url) {
    const response = await fetch(url)
    if (!response.ok) {
        throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`)
    }
    const arrayBuffer = await response.arrayBuffer()
    const decompressed = pako.ungzip(new Uint8Array(arrayBuffer), { to: 'string' })
    return JSON.parse(decompressed)
}

/**
 * Ensures the activation trace of `tensor_id` for the current model and trace is in
 * `globals.tensor_trace`, then calls `onComplete()`.
 * If it is already cached, `onComplete` is called immediately. Otherwise the file
 * `data/tensor_traces/<model>/<current_trace>/<tensor_id>.json.gz` is fetched,
 * decompressed, stored, and then `onComplete` is called.
 * Used eg when a tensor is clicked open.
 *
 * @param {string|number} tensor_id
 * @param {Function} onComplete - called with no arguments; its return value is not awaited
 * @returns {Promise<void>} rejects if the fetch fails or returns a non-2xx status
 *   (in which case nothing is cached and onComplete is not called)
 */
export async function load_trace_layer(tensor_id, onComplete){
    if (tensor_id in globals.tensor_trace) {
        onComplete()
        return
    }
    // read model/trace before awaiting so the path reflects the state at call time
    const model_name = globals.nn.trace_metadata.name
    const current_trace = globals.current_trace
    const trace_layer_path = get_file_path(`data/tensor_traces/${model_name}/${current_trace}/${tensor_id}.json.gz`)

    const trace_layer = await fetch_and_decompress_gzipped_json(trace_layer_path)
    globals.tensor_trace[tensor_id] = trace_layer
    onComplete()
}

/**
 * Ensures the most-activated-indices data ("featurespace") of `tensor_id` for the
 * current model is in `globals.most_activated_ixs`.
 * If it is not cached, the file `data/most_activated_ixs/<model>/<tensor_id>.json.gz` is
 * fetched and decompressed. The result is stored and
 * `globals.most_activated_ixs_are_loaded` is set to true.
 *
 * @param {string|number} tensor_id
 * @returns {Promise<void>} rejects if the fetch fails or returns a non-2xx status
 */
export async function load_most_activated_ixs(tensor_id) {
    if (tensor_id in globals.most_activated_ixs) return

    const model_name = globals.nn.trace_metadata.name
    const most_activated_ixs_path = get_file_path(`data/most_activated_ixs/${model_name}/${tensor_id}.json.gz`)

    const most_activated_ixs = await fetch_and_decompress_gzipped_json(most_activated_ixs_path)
    globals.most_activated_ixs[tensor_id] = most_activated_ixs
    globals.most_activated_ixs_are_loaded = true
}
// Ops (by created_by_fn) whose outputs are always candidates to be drawn as an actvol.
// TODO this is brittle. Has to be anything that might have an actgrid
export const always_show_act_vol_fns = [
    "conv2d", "conv_transpose2d", "linear", "max_pool2d", "cat", "mean", "interpolate",
    "avg_pool2d", "adaptive_avg_pool2d", "adaptive_avg_pool1d",
    "add", // i like to see it coming to the respath, but maybe take out
]
// Ops whose outputs are actvol candidates only when they change the tensor's shape
// (the shape check itself is not implemented everywhere yet, see mark_extraneous_io).
export const show_act_vol_if_shape_changes = ["__getitem__", "chunk", "split", "unfold", "stack"]

/**
 * One-time preparation of the op tree in `globals.nn`, run right after a model loads.
 *
 * In order, it:
 *  1. copies `x_relative` into `x_relative_original` and sets `y_relative_original` to 0,
 *  2. sets `parent_op` on every child,
 *  3. rebuilds `globals.nodes_lookup` (node_id -> op, every node),
 *  4. rebuilds `globals.modules_lookup_by_identifier` (mod_identifier -> module op);
 *     this walk only descends through module nodes,
 *  5. resolves each op's `uns` node ids into `op.upstream_nodes`,
 *  6. sets `globals.max_depth` from `op.depth`, not descending into collapsed ops,
 *  7. flags an fn_out/mod_out node as `node_is_extraneous_io` when it is the sole upstream
 *     peer of an output node and was not produced by an actvol-worthy fn,
 *  8. sets `op.should_draw` (true/false) on every node in `globals.nodes_lookup`.
 * Step 8 reads the flags from step 7 and iterates the lookup from step 3, so order matters.
 */
export function on_load_model_prep() {
    const nn = globals.nn

    // 1. snapshot layout positions
    function copy_dims(op) {
        op.x_relative_original = op.x_relative
        // The backend no longer sends y_relative; layout uses row order only. This must be
        // a real number: an earlier NaN here went unnoticed until it broke depth changes.
        op.y_relative_original = 0
        op.children.forEach(c => copy_dims(c))
    }
    copy_dims(nn)

    // 2. parent links, for convenience
    function mark_parentage(op) {
        op.children.forEach(c => {
            c.parent_op = op
            mark_parentage(c)
        })
    }
    mark_parentage(nn)

    // 3. node_id -> op for modules and ops
    globals.nodes_lookup = {}
    function add_to_nodes_lookup(op) {
        globals.nodes_lookup[op["node_id"]] = op
        op.children.forEach(c => add_to_nodes_lookup(c))
    }
    add_to_nodes_lookup(nn)

    // 4. mod_identifier -> module op
    globals.modules_lookup_by_identifier = {}
    function add_to_mods_lookup(op) {
        if (op.node_type !== "module") return
        if (op.mod_identifier in globals.modules_lookup_by_identifier) {
            console.log("duplicate mod identifier already in lookup? XXXXXXXXXXXXXXXXXXXXX shouldn't happen", op.mod_identifier)
            console.log("node_ids", globals.modules_lookup_by_identifier[op.mod_identifier]["node_id"], op["node_id"])
        } // TODO perf
        globals.modules_lookup_by_identifier[op.mod_identifier] = op
        op.children.forEach(c => add_to_mods_lookup(c))
    }
    add_to_mods_lookup(nn)

    // 5. resolve upstream node ids to node objects, for convenience
    function link_upstream_nodes(op) {
        op.upstream_nodes = op.uns.map(nid => globals.nodes_lookup[nid])
        op.children.forEach(c => link_upstream_nodes(c))
    }
    link_upstream_nodes(nn)

    // 6. max depth of the visible (non-collapsed) tree, used for scales
    globals.max_depth = 0
    function set_max_depth(op) {
        globals.max_depth = Math.max(globals.max_depth, op.depth || 0)
        if (!op.collapsed) {
            op.children.forEach(c => set_max_depth(c))
        }
    }
    set_max_depth(nn)

    // 7. mark extraneous fn_outs / mod_outs.
    // An fn_out/mod_out feeding straight into an output node stacks up with that output,
    // so it doesn't need to be shown. Marked here, before actvols are assigned, because
    // these nodes disappear anyway.
    const IO_NODE_TYPES = ["fn_out", "mod_out"]
    console.time("mark extraneous fn_outs")
    function mark_extraneous_io(op) {
        op.children.forEach(o => {
            if (o.is_output) {
                const uns = get_upstream_peer_nodes(o)
                if (uns.length === 1 && IO_NODE_TYPES.includes(uns[0].node_type)) {
                    const source = uns[0]
                    // keep fn_outs coming from important ops
                    const is_important = always_show_act_vol_fns.includes(source.created_by_fn) ||
                        show_act_vol_if_shape_changes.includes(source.created_by_fn) // TODO will need to check for shape changes here
                    if (!is_important) {
                        source.node_is_extraneous_io = true
                    }
                }
            }
            mark_extraneous_io(o)
        })
    }
    mark_extraneous_io(nn)
    console.timeEnd("mark extraneous fn_outs")

    // 8. should_draw: structural layout nodes (eg elbows, inputs) are hidden unless
    // globals.DEBUG is on; extraneous io nodes are hidden too.
    const DRAWN_NODE_TYPES = ["function", "module", "fn_out", "mod_out"]
    for (const op of Object.values(globals.nodes_lookup)) {
        const is_drawn_kind = DRAWN_NODE_TYPES.includes(op.node_type) || op.is_global_input
        op.should_draw = Boolean((is_drawn_kind && !op.node_is_extraneous_io) || globals.DEBUG)
    }
}

let bucketName = "darkspark-83550.firebasestorage.app"
let DATA_VERSION = 'preview'

// True when the app is served from localhost. In that case data files are fetched
// from the dev server's root instead of Firebase Storage.
function is_local_dev_host() {
    return window.location.hostname === 'localhost'
}

// Firebase Storage download URL for `object_path` in `bucketName`. The whole object path
// is percent-encoded as one segment (slashes become %2F), as this URL form requires.
function firebase_storage_download_url(object_path) {
    return `https://firebasestorage.googleapis.com/v0/b/${bucketName}/o/${encodeURIComponent(object_path)}?alt=media`
}

// URL for a data file (eg "data/tensor_traces/..."). Locally: "/<path>".
// Remotely: the storage object "<DATA_VERSION>/<path>".
export function get_file_path(path) {
    if (is_local_dev_host()) return `/${path}`
    return firebase_storage_download_url(`${DATA_VERSION}/${path}`)
}

// URL for a dataset image (eg "<dataset>/image_<ix>.png"), stored under "datasets/".
// Unlike get_file_path, no DATA_VERSION prefix is applied remotely.
export function get_image_path(path) {
    const dataset_path = `datasets/${path}`
    if (is_local_dev_host()) return `/${dataset_path}`
    return firebase_storage_download_url(dataset_path)
}

/**
 * Compiles the object behind the feature tooltip / feature sidebar for one channel of
 * `op`'s tensor. The data comes from `globals.most_activated_ixs[op.tensor_id]`, which must
 * already be loaded (see load_most_activated_ixs).
 * TODO should prob take tensor_id and channel_ix instead: the sidebar belongs to a
 * tensor-channel regardless of which node shows it.
 *
 * @param {object} op - tensor op; reads tensor_id, node_id, created_by_fn
 * @param {number} channel_ix
 * @returns {object} {image_paths_pos, image_paths_mid, image_paths_neg, channel_ix, tensor_id,
 *   overlaysPath, showNegImgs, percentiles, hist_bins, fired_rate, channel_mean, channel_std,
 *   node_id, capture_config}
 */
export function get_feature_object(op, channel_ix) {
    const tensor_id = op.tensor_id
    const featurespace = globals.most_activated_ixs[tensor_id]

    // Zero-thresholded activations have no meaningful negative examples.
    const is_zero_thresh = ["gelu", "relu", "silu"].includes(op.created_by_fn) // TODO should come in on config
    const showNegImgs = !is_zero_thresh

    const dataset_name = globals.nn.trace_metadata.dataset
    const to_image_paths = (ixs) => ixs.map(ix => get_image_path(`${dataset_name}/image_${ix}.png`))

    // Per-channel scalar stats can arrive as a bare number instead of an array. This happens
    // eg for a single-channel layer (depthanything's last layer), where numpy squeezed the
    // channel axis. Treat a bare number as a one-element array.
    const scalar_stat_at_channel = (key) => {
        const stat = featurespace[key]
        return (typeof stat === 'number' ? [stat] : stat)[channel_ix]
    }

    const percentiles = featurespace["percentiles"][channel_ix]
    const hist_bins = featurespace["hist_bins"][channel_ix]
    const fired_rate = scalar_stat_at_channel("fired_rate")
    const channel_mean = scalar_stat_at_channel("channel_means")
    const channel_std = scalar_stat_at_channel("channel_stds")

    const image_paths_pos = to_image_paths(featurespace["top_5s"][channel_ix])
    const image_paths_mid = to_image_paths(featurespace["mids"][channel_ix])
    // bottom_5s is only read when negatives are shown
    const image_paths_neg = showNegImgs ? to_image_paths(featurespace["bottom_5s"][channel_ix]) : []

    const overlaysPath = get_file_path(`data/featurespace_overlays/${globals.nn.trace_metadata.name}/${tensor_id}_${channel_ix}.json.gz`)

    const node_id = op.node_id // this shouldn't be needed. Remove after change to tensor_id
    const capture_config = featurespace.capture_config

    return {image_paths_pos, image_paths_mid, image_paths_neg, channel_ix,
      tensor_id, overlaysPath, showNegImgs, percentiles, hist_bins, fired_rate,
      channel_mean, channel_std, node_id, capture_config}
}

// Moves and resizes `globals.channel_slice_highlight_plane` so it frames channel
// `channel_ix` of `op`'s actgrid, then makes it visible.
// Channels are laid out row-major on a grid with ceil(sqrt(channels)) columns; each cell
// is channel_width x channel_height grid units plus paddingSize. The highlight extends
// ceil(paddingSize/2) grid units past the channel on every side, and it sits just below
// the actgrid (y - .01), between the main plane and the background plane.
export function set_highlight_plane_at_channel(op, channel_ix) {
    const group = op.actgrid_group
    const layout = group.the_plane.userData
    const cube = globals.act_cube_size

    const columns = Math.ceil(Math.sqrt(layout.channels))
    const cell_w = layout.channel_width
    const cell_h = layout.channel_height
    const pad = layout.paddingSize
    const margin = Math.ceil(pad / 2) // grid units the highlight extends beyond the channel

    const col = channel_ix % columns
    const row = Math.floor(channel_ix / columns)

    // Walk from the actgrid's top-left to the highlight's top-left:
    // + offset of the channel's cell (incl. padding), + leading padding, - margin.
    let left = group.position.x - (layout.grid_width * cube)
    left += col * (cell_w + pad) * cube
    left += pad * cube
    left -= margin * cube

    let top = group.position.z
    top += row * (cell_h + pad) * cube
    top += pad * cube
    top -= margin * cube

    const span_h = (cell_h + margin * 2) * cube
    const span_w = (cell_w + margin * 2) * cube

    const highlight = globals.channel_slice_highlight_plane
    // position is assigned as the plane's center
    highlight.position.x = left + span_w / 2
    highlight.position.z = top + span_h / 2
    highlight.position.y = group.position.y - .01

    highlight.scale.x = span_w
    highlight.scale.y = span_h

    make_plane_visible(highlight)
}