import { globals } from './globals';
import * as utils from './utils'

///////////////////////////////
// Layout engine
///////////////////////////////

export default function recompute_layout() {
    console.time("compute layout")
    let nn = globals.nn
    
    /////////////////////
    // freeriding
    globals.setFeatureTooltipObject(null)
    globals.setTooltipObject(null)

    globals.is_loading = true

    ////////////////////////////////////////////////////
    // getting flat array of visible nodes

    // Reset is_currently_visible_node to false on `op` and all of its descendants.
    // Walks the tree pre-order with an explicit stack instead of recursion;
    // draw_op() sets the flag back to true on the nodes that end up visible.
    function mark_not_visible(op) {
        const stack = [op]
        while (stack.length > 0) {
            const node = stack.pop()
            node.is_currently_visible_node = false
            // push children in reverse so they are popped in their original order
            const { children } = node
            for (let i = children.length - 1; i >= 0; i--) {
                stack.push(children[i])
            }
        }
    }
    mark_not_visible(nn)
    let draw_order = 0
    
    globals.ops_of_visible_planes = []
    globals.ops_of_visible_nodes = []
    // Depth-first walk that flattens the op tree into the two visible lists.
    // A collapsed op is a node: it gets the next draw_order_global value, is
    // appended to globals.ops_of_visible_nodes and flagged visible.
    // An expanded op is a plane: it is appended to globals.ops_of_visible_planes,
    // its children are sorted in place by draw_order, and each child is visited
    // in that order.
    function draw_op(op) {
        if (!op.collapsed) { // Planes
            globals.ops_of_visible_planes.push(op)
            op.children.sort((first, second) => first.draw_order - second.draw_order)
            for (const child of op.children) {
                draw_op(child)
            }
            return
        }
        // Nodes
        op.draw_order_global = draw_order++
        globals.ops_of_visible_nodes.push(op)
        op.is_currently_visible_node = true
    }
    nn.children.sort((a,b) => {return a.draw_order - b.draw_order})
    draw_op(nn)

    ////////////////
    // visible max depth
    let depths = globals.ops_of_visible_nodes.map(op => (isNaN(op.depth) ? 0 : op.depth))
    globals.max_depth_visible = Math.max(...depths)

    ////////////
    // visible max depth
    let n_params = globals.ops_of_visible_nodes.map(op => op.n_params).filter(op => op)

    n_params.sort((a,b)=>a-b)
    let n_params_at_upper_percentile = n_params[parseInt(n_params.length*.95)]
    globals.max_n_params_visible = n_params_at_upper_percentile //Math.max(...n_params)

    // not actually max, capping at 95th percentile so very large outliers don't destroy scale
    // TODO need to ensure scales are now always updating during all transitions
    ////////////////////////////////////////////////////


    ///////////////////////////////////////////
    // cache prev positions, node types for use in transitions

    // Snapshot each op's current position as prev_pos = {x, y} (y taken from
    // y_unshifted), and for tensor nodes the current display type, for use in
    // transitions. TODO NOTE on the first pass these source fields are unset,
    // so the snapshot holds undefined values.
    function cache_prev(op) {
        const { x, y_unshifted } = op
        op.prev_pos = { x, y: y_unshifted }
        if (op.is_tensor_node) op.tensor_node_prev_display_type = op.tensor_node_display_type
        for (const child of op.children) cache_prev(child)
    }
    cache_prev(nn)

    ///////////////////////////////////////////

    // decide tensor node display types: node, volume, grid

    // Whether a tensor node is eligible to be shown as an activation volume
    // (size limits are checked separately by actvol_is_too_big):
    //   - global inputs / global outputs: yes
    //   - extraneous io nodes: no
    //   - nodes with activations available: yes (gelu, relu etc)
    //   - module outputs: only while the dispatching module is collapsed
    //     (returns that module's `collapsed` value as-is)
    //   - fn outputs: only when created_by_fn is listed in
    //     utils.always_show_act_vol_fns or utils.show_act_vol_if_shape_changes
    //   - anything else: no
    function is_actvol_candidate(op){
        if (op.is_global_input || op.is_output_global) return true
        if (op.node_is_extraneous_io) return false
        if (op.activations_available) return true // gelu, relu etc

        switch (op.node_type) {
            case "mod_out":
                // don't show when mod is expanded
                return globals.nodes_lookup[op.from_module_nid].collapsed
            case "fn_out": {
                const fn_name = op.created_by_fn
                // TODO the shape-change list should only qualify when the shape actually changes. Will need actvol specs first
                return utils.always_show_act_vol_fns.includes(fn_name) ||
                    utils.show_act_vol_if_shape_changes.includes(fn_name)
            }
            default:
                return false
        }
    }
    // True when a would-be activation volume should NOT be drawn. That is only
    // the case when it has no activations AND its depth was capped at the
    // maximum (actvol_dims.depth_is_overflowed). Volumes with activations are
    // always allowed, capped or not; uncapped no-act volumes are allowed too.
    function actvol_is_too_big(op) {
        const { actvol_dims } = op.enriched_tensor_specs
        if (op.activations_available) return false
        return actvol_dims.depth_is_overflowed ? true : false
    }
    globals.setActivationsShowing(false)
    // Assign tensor_node_display_type ("node" | "volume" | "grid") to every
    // tensor node in the subtree rooted at `op`.
    // Expansion is only considered when globals.SHOW_ACTIVATION_VOLUMES is on and
    // the node has enriched tensor specs. Grid takes precedence, and choosing it
    // calls globals.setActivationsShowing(true). Otherwise the node becomes a
    // volume if it is a candidate and not too big; everything else is a "node".
    function mark_tensor_node_display_type(op) {
        if (op.is_tensor_node) {
            let display_type = "node"
            const can_expand = globals.SHOW_ACTIVATION_VOLUMES && op.has_enriched_tensor_specs
            if (can_expand && op.should_be_actgrid) { // actgrid takes precedence
                globals.setActivationsShowing(true)
                display_type = "grid"
            } else if (can_expand && is_actvol_candidate(op) && !actvol_is_too_big(op)) {
                display_type = "volume"
            }
            op.tensor_node_display_type = display_type
        }
        for (const child of op.children) mark_tensor_node_display_type(child)
    }
    mark_tensor_node_display_type(nn)   


    /////////////////////////////////
    // don't show the same tensor as expanded more than once
    // demotes from grid or volume -> node
    // this will work across multiple levels, whereas below only works within a row in a module
    let grids_and_volumes = globals.ops_of_visible_nodes.filter(op=>["volume", "grid"].includes(op.tensor_node_display_type))
    // only showing one actgrid or volume per tensor so as to avoid many in a row
    grids_and_volumes.sort((a,b)=> b.depth - a.depth) // show innermost, ie closest to creator op. These are visible ops, so this won't hide anything
    let shown_tids = {}
    grids_and_volumes.forEach(op => {
      if (shown_tids[op.tensor_id]) {
        // already shown as expanded. demote to node
        op.tensor_node_display_type = "node"
      } else {
        // not yet shown, remain grid or volume
        shown_tids[op.tensor_id] = true
      }
    })

    ////////////////////////
    // demote actvols to nodes when they're repetitive. 
    // Switches tensor_node_display_type "volume" -> "node"
    // this will hide them even when different tensor ids
    // Demote repetitive activation volumes to plain nodes ("volume" -> "node").
    // For every expanded (non-collapsed) op, its children are grouped by
    // draw_order_row. Within a row the volumes are ordered left to right. A
    // volume is demoted when the next volume in the row has identical actvol
    // width/height/depth and starts less than 6 units further right, unless the
    // earlier one is a global input or has activations. In a run of look-alikes
    // only the right-most one survives. Recurses into every expanded descendant.
    // This hides volumes even when they have different tensor ids.
    //
    // Ordering uses x_relative_original, the same coordinate the distance check
    // measures. x_relative is not usable here: this runs before reset_dims(), so
    // it still holds the previous layout's positions (possibly unset on the very
    // first layout, in which case the comparator yields NaN, the sort leaves the
    // children order untouched, and non-adjacent volumes get paired).
    function demote_excess_actvols(op) {
        if (op.collapsed) return

        // compile rows lookup
        const rows = {}
        op.children.forEach(c => {
            if (!(c.draw_order_row in rows)) {
                rows[c.draw_order_row] = {
                    "nodes":[],
                    "draw_order_row":c.draw_order_row,
                }
            }
            rows[c.draw_order_row].nodes.push(c)
        })

        // Prune actvols, revert to node when repetitive
        for (const rid in rows) {
            const row_actvols = rows[rid].nodes.filter(n => n.tensor_node_display_type==="volume")
            row_actvols.sort((a,b) => a.x_relative_original - b.x_relative_original)
            for (let i = 1; i<row_actvols.length; i++) {
                const prev_actvol = row_actvols[i-1]
                const this_actvol = row_actvols[i]
                const s0 = prev_actvol.enriched_tensor_specs.actvol_dims
                const s1 = this_actvol.enriched_tensor_specs.actvol_dims
                const no_dims_change = (s0.width==s1.width) && (s0.height==s1.height) && (s0.depth==s1.depth) // TODO should prob be eg depth_n_acts
                // non-negative now that both share the sort coordinate
                const pretty_close = (this_actvol.x_relative_original - prev_actvol.x_relative_original) < 6
                if (no_dims_change && !prev_actvol.is_global_input && pretty_close && !prev_actvol.activations_available) {
                    prev_actvol.tensor_node_display_type = "node" // revert to node
                }
            }
        }

        op.children.forEach(c => demote_excess_actvols(c))
    }
    demote_excess_actvols(nn)

    /////////////////////////////
    // determine bounding dims of actvols and grids now that we've determined actvol types.
    // Not making them yet but layout engine needs this for spacing

    // Record how much room each expanded tensor (activation volume or grid) will
    // need, so the layout pass can space things before the meshes are built.
    // Sets op.tensor_is_expanded on every tensor node. For volumes and grids it
    // also sets op.expanded_tensor_spans = {x_span, y_span_half}. Recurses into
    // all children.
    // NOTE the grid sizing has to stay in sync by hand with the actual grid
    // drawing fns (padding from utils.get_padding_size_from_channel_width,
    // cells scaled by globals.act_cube_size).
    function mark_bbs_for_expanded_tensors(op) { // volumes and grids count as expanded tensors
        if (op.is_tensor_node) {
            switch (op.tensor_node_display_type) {
                case "volume": {
                    const dims = op.enriched_tensor_specs.actvol_dims
                    const width_skew = dims.width*.15
                    op.expanded_tensor_spans = {
                        x_span: dims.depth + width_skew,
                        y_span_half: dims.height*.5 + width_skew,
                    }
                    op.tensor_is_expanded = true
                    break
                }
                case "grid": {
                    const specs = op.enriched_tensor_specs
                    const n_channels = specs.depth_n_acts
                    // channels are tiled into a near-square grid
                    const grid_cols = Math.ceil(Math.sqrt(n_channels))
                    const grid_rows = Math.ceil(n_channels / grid_cols)
                    const cell_w = specs.width_n_acts
                    const cell_h = specs.height_n_acts
                    const pad = utils.get_padding_size_from_channel_width(cell_w)
                    const cube = globals.act_cube_size
                    const edge_pad = pad*cube // a single extra padding
                    op.expanded_tensor_spans = {
                        x_span: grid_cols * (cell_w+pad) * cube + edge_pad,
                        // 'half' bc it extends from midway to the bottom, while actvols also extend upwards
                        y_span_half: grid_rows * (cell_h+pad) * cube + edge_pad,
                    }
                    op.tensor_is_expanded = true
                    break
                }
                default:
                    op.tensor_is_expanded = false
            }
        }
        for (const child of op.children) mark_bbs_for_expanded_tensors(child)
    }
    mark_bbs_for_expanded_tensors(nn)

    /////////////////////////////




    // Restore every op in the subtree to its x_relative_original /
    // y_relative_original position and start a fresh, empty history_js log,
    // ahead of the expansion-aware relative layout pass.
    function reset_dims(op) {
        Object.assign(op, {
            x_relative: op.x_relative_original,
            y_relative: op.y_relative_original,
            history_js: [],
        })
        for (const child of op.children) reset_dims(child)
    }
    reset_dims(nn)


    // Updating relative xy coords bc of expansions
    // just consider two levels at a time: an op and its children ops
    // the children ops x_relative and y_relative are defined in terms of their parent op's
    // frame of reference, but their h and w dims are actual values

    function update_op_h_w(op) {

        if (op.collapsed || op.children.length==0) {
            op.h = 0; op.w = 0
            return
        }

        op.children.forEach(c => update_op_h_w(c))
        // now each child has h w as a result of the expansion and arrangement of its children. 
        // Imagine all the boxes expanded, but not yet shifted around, so there is overlap. We now 
        // have to shift the boxes around to eliminate overlap. 
        // Relative coords haven't been updated based on the h w and expansions of the peer subops. 
        // That's what we're doing below.
        
        // relative X 


        // this is the same as is implemented in the fronted, first a normal pass for standard inputs which cling to the left,
        // then a costlier pass only for freshies, to allow them to not cling to left, to attach branch where required 
        // freshie pass is possible to use for all inputs, but is almost 100x more costly in terms of latency, i believe bc of all the 
        // frequent bringing forward of branches ie looping through all peers, though this hypths is not tested
        let P = 0 // 1e6 // don't think we actually need this to be so negative. Trying w zero and works...
        // TODO this needs attn, better way to fail when things don't work. Below we're setting x_relative to this value, so if
        // eg 1e6 that puts it way out can't see

        op.children.forEach(o => {
            o.x_nudge_traversed = false
            o.x_relative = -P
            o.y_relative = 0
            o.x_relative_fully_marked = false
        })

        let input_ops = op.children.filter(o => o.is_input); input_ops.sort((a,b)=>a.input_priority - b.input_priority)
        
        let input_ops_global = input_ops.filter(o => o.is_global_input)
        let input_ops_standard = input_ops.filter(o => !o.is_global_input)


        ///////////////////
        // old version. When no orig fns, applies to all except global inputs, none on first level bc those are all global
        // Push downstream peers rightward so each one starts clear of this op.
        // For each dn from utils.get_downstream_peer_nodes, the threshold is this
        // op's right edge (x_relative + w), plus the dn's x_span rounded to an int
        // when dn is an expanded tensor (occ uses an array, so ints for indices).
        // A dn at or left of the threshold is moved to threshold + gap, flagged
        // x_relative_fully_marked (read by the global-input pass that follows),
        // and recursed into. A dn already to the right is only recursed into if
        // it has not been traversed yet; without this check densenet took forever
        // re-treading the same ground.
        // The gap is 1 when every dn continues in the same row under the same
        // parent, otherwise 2.
        function nudge_forward_dns(op_whose_dns_to_nudge) {
            op_whose_dns_to_nudge.x_nudge_traversed = true

            const dns = utils.get_downstream_peer_nodes(op_whose_dns_to_nudge)
            const stays_in_row = dns.every(dn =>
                dn.draw_order_row === op_whose_dns_to_nudge.draw_order_row &&
                dn.parent_op === op_whose_dns_to_nudge.parent_op
            )
            const gap = stays_in_row ? 1 : 2

            for (const dn of dns) {
                let x_threshold = op_whose_dns_to_nudge.x_relative + op_whose_dns_to_nudge.w
                if (dn.tensor_is_expanded) {
                    // TODO NOTE can undo the round when occ is more flexible
                    x_threshold += Math.round(dn.expanded_tensor_spans.x_span)
                }

                if (dn.x_relative <= x_threshold) {
                    dn.x_relative = x_threshold + gap
                    dn.x_relative_fully_marked = true // needed for the global-input pass, not when used in isolation
                    dn.history_js.push(`JS x nudged forward by ${utils.nice_name(op_whose_dns_to_nudge)} ${x_threshold} ${dn.x_relative}`)
                    nudge_forward_dns(dn)
                } else if (!dn.x_nudge_traversed) {
                    nudge_forward_dns(dn)
                }
            }
        }
        input_ops_standard.forEach(o => nudge_forward_dns(o))
        ///////////////////

        ///////////////////
        // new version
        // marking forward then when hitting already marked, shifts entire branch up to meet it, clumps to the right
        // see backend for full comments, as this is same as there. Only applies to global inputs.
        // Used only for global inputs. Walks forward from op_whose_dns_to_nudge
        // through its downstream peers within op.children:
        //   - a dn already pinned (x_relative_fully_marked) is not moved. Instead,
        //     the slack between it and this op's right edge (minus the gap and,
        //     for an expanded tensor dn, its rounded x_span) is pushed onto
        //     to_shifts. The caller shifts the whole input group by the minimum
        //     of those values.
        //   - any other dn at or left of the threshold is pushed to
        //     threshold + gap, recorded in nodes_in_this_input_group, and
        //     recursed into; a dn already to the right is recursed into only if
        //     it has not been traversed.
        // to_shifts and nodes_in_this_input_group belong to the enclosing scope
        // and are reset for each global input before this is called.
        function mark_next_x_pos(op_whose_dns_to_nudge) {
            const dns = utils.get_downstream_peer_nodes(op_whose_dns_to_nudge, op.children)

            // gap is 1 only when every dn continues within the same row and parent
            const is_same_row = dns.every(dn =>
                dn.draw_order_row === op_whose_dns_to_nudge.draw_order_row &&
                dn.parent_op === op_whose_dns_to_nudge.parent_op
            )
            const x_amount = is_same_row ? 1 : 2

            dns.forEach(dn => {
                if (dn.x_relative_fully_marked) {
                    let to_shift = dn.x_relative - op_whose_dns_to_nudge.x_relative
                    to_shift -= x_amount
                    to_shift -= op_whose_dns_to_nudge.w
                    if (dn.tensor_is_expanded){
                        // The span belongs to dn, as in the threshold branch below. The
                        // enclosing `op` is the expanded parent being laid out, which
                        // normally has no expanded_tensor_spans (only tensor nodes shown
                        // as a volume/grid get one), so reading it there threw.
                        const actvol_x_span = dn.expanded_tensor_spans.x_span
                        to_shift -= Math.round(actvol_x_span)
                    }
                    to_shifts.push(to_shift)
                } else {
                    let x_threshold = op_whose_dns_to_nudge.x_relative + op_whose_dns_to_nudge.w
                    if (dn.tensor_is_expanded) {
                        const actvol_x_span = dn.expanded_tensor_spans.x_span
                        x_threshold += Math.round(actvol_x_span)
                        // occ uses array, so int for ix. Can undo the round when occ is more flexible TODO NOTE NOTE
                    }
                    if (dn.x_relative <= x_threshold) {
                        dn.x_relative = x_threshold + x_amount
                        nodes_in_this_input_group.push(dn)
                        dn.history_js.push("JS x nudged forward by "+utils.nice_name(op_whose_dns_to_nudge)+" "+x_threshold+" "+dn.x_relative)
                        mark_next_x_pos(dn)
                    } else if (!dn.x_nudge_traversed) { // NOTE densenet was giving bug here, taking forever, not recursed but so many retreading. This prevents from retreading ground unless necessary
                        mark_next_x_pos(dn)
                    }
                }
            })
        }
        let to_shifts, nodes_in_this_input_group
        input_ops_global.forEach(o => {
            to_shifts = []; nodes_in_this_input_group = [o] // to be filled out in mark_next_x_pos
            // when stranded input global op, eg after pruning sd 1.5 we got this, single op, the gets pushed way far over 
            // then way stretched out. Will need to deal w this better
            o.x_relative = -P
            mark_next_x_pos(o)
            let to_shift = to_shifts.length==0 ? P : Math.min(...to_shifts) // on first pass through module will be no diffs, so align to zero 
            
            nodes_in_this_input_group.forEach(n => n.marked_for_this_round = false)
            nodes_in_this_input_group.forEach(n => {
                if (!n.marked_for_this_round) {
                    n.x_relative += to_shift
                    n.x_relative_fully_marked = true
                    n.marked_for_this_round = true // same node may be pushed in multiple times
                }
            })
        })
        ////////////////////////


        ////
        // shift all to start at zero
        let input_min_x = Math.min(...input_ops.map(o=>o.x_relative))
        op.children.forEach(c => {
            c.x_relative -= input_min_x
        })
        // make sure all standard inputs are on the left. Global inputs remain where they were
        input_ops.forEach(o => {
            if (!o.is_global_input) {
                o.x_relative = 0
            }
        })
        


        // Moving all output nodes to the right edge of expanded box
        // but only if they're not a module, ie if they're one of the single output ops we created ourselves
        let max_x = Math.max(...op.children.map(o => o.x_relative))
        let output_nodes = op.children.filter(o => o.is_output)
        output_nodes = output_nodes.filter(o => o.node_type=="output") // ie not modules. ie the nodes we created manually.
        output_nodes.forEach(o => o.x_relative = max_x)

        //////
        // Keep extension / branch nodes one behind their target downstream node, as that is their purpose
        op.children.filter(o=>["extension", "elbow"].includes(o.node_type)).forEach(branch_node => {
            if (!branch_node.pre_elbow) {
                let branch_dns = branch_node.dn_ids.map(nid => globals.nodes_lookup[nid])
                let min_x = Math.min(...branch_dns.map(o => o.x_relative))
                branch_node.x_relative = min_x - 2
                branch_node.history_js.push("moving extension to stay one less than dn")
            } else { // is pre elbow
                let nid = branch_node.uns[0] // will only be one
                let un = globals.nodes_lookup[nid] 
                branch_node.x_relative = un.x_relative + 2
                branch_node.history_js.push("moving pre-elbow to stay one more than un") 
            }
        }) 
        // TODO do this each time we nudge a node forward. It should bring it's preceding elbow / ext with it
        // so that other nodes can also respond




        ////////////////////////////////////////////////
        // Relative y

        // NOTE ensure these all use ints, for occ grid

        // compile rows lookup
        let rows = {}
        op.children.forEach(c => {
            if (!(c.draw_order_row in rows)) {
                rows[c.draw_order_row] = {
                    "nodes":[],
                    "draw_order_row":c.draw_order_row,
                }
            };
            rows[c.draw_order_row].nodes.push(c)
        })
        // calc information for each row
        for (let rid in rows) {
            let row = rows[rid]
            row.nodes.sort((a,b) => a.x_relative - b.x_relative)
            let first_op = row.nodes[0]

            // no, can't do back to uns and down to dns bc then get double counting, overlap eg in sd 1.4 midblock. We've been doing
            // trace to dn for awhile and i'm good w that, don't use this trace back to un then.
            let uns = utils.get_upstream_peer_nodes(first_op) // all the way from dispatching node, relevent in eg stylegan when it is input
            if (false){ //(uns.length==1 && uns[0].node_type=="input") { 
                // Total hack for now. Only affects stylegan as far as i know. TODO NOTE NOTE don't know if this affects non inputs. See notes below
                // let un_max_x = Math.max(...uns.map(un => un.x_relative)) // when will there be multiple? should this be min? NOTE NOTE pay attn
                // console.log(un_max_x, first_op.x_relative)
                uns.sort((a,b) => a.x_relative - b.x_relative) // when will there be multiple? should this be min? NOTE NOTE pay attn
                let max_un = uns[uns.length-1]
                row.starts_at_x = max_un.x_relative + max_un.w + 1 // one after the right edge of un, width is zero when collapsed
            } else {
                row.starts_at_x = first_op.x_relative
            }
            // row.starts_at_x = un_max_x + 1
            // // dunno about this. Doing it for case of stylegan where dispatching node is input, doesn't have fn out etc to move upwards
            // // so row was starting way far down, at the elbow, but want it to start earlier bc of occupancy
            // This is not quite right, bc we're also extending forward, so get double counting when extend both. Coat mini has this, leading
            // to extraneous up shift. Restricting to only un==input for now to affect only stylegan
            // will the move the elbows in JS fix this? we need each row to always start right after the dispatching node and end 
            // right before the terminating node. 

            // row.starts_at_x = first_op.x_relative


            let last_op = row.nodes[row.nodes.length-1]
            let dns = utils.get_downstream_peer_nodes(last_op) // all the way till terminating node, ie not just row itself
            let dn_max_x = Math.max(...dns.map(dn => dn.x_relative))
            last_op.is_last_in_row = true

            let until = last_op.x_relative + last_op.w
            row.ends_at_x = Math.max(until, dn_max_x-2) //NOTE this 2 

            row.y_relative = 0 //row.nodes[0].y_relative // all have same y_relative

            // 
            row.is_only_tensors = true
            row.nodes.forEach(n => {
                if (["function", "module"].includes(n.node_type)) { // can remove the elbow, ext check bc next run making those tensors
                    row.is_only_tensors = false
                }
                n.n_peer_row_nodes = row.nodes.length
            })
            row.nodes.forEach(n=>n.row=row)

            // 
            let expanded_tensor_nodes = row.nodes.filter(n => n.tensor_is_expanded)

            let row_actvol_hheights = expanded_tensor_nodes.map(n => n.expanded_tensor_spans.hheight) // TODO this needs to be better
            if (row_actvol_hheights.length>0){
                row.actvol_hheight = Math.max(...row_actvol_hheights)
            } else {
                row.actvol_hheight = 0
            }
            row.nodes.forEach(n=>n.row_actvol_hheight=row.actvol_hheight)

            //
            let important_ops = ["conv2d", "linear", "max_pool2d", "cat", "mean", "interpolate", 
                        "avg_pool2d", "adaptive_avg_pool2d", "adaptive_avg_pool1d", "matmul", "bmm"]
            let is_primary_row = false
            row.nodes.forEach(n => {
                if (n.n_params >0 || n.tensor_is_expanded || (important_ops.includes(n.name)) ||
                    (n.is_respath_row) // this should be if is respath row, some models have multiple    
                ) {
                    is_primary_row = true
                    return
                }
            })

            let base_pad = is_primary_row ? 1. : .2
            row.pad = base_pad //Math.max(row.actvol_hheight, base_pad)

            

        }
        // Set y_relative for row and all nodes in row. should only increment up
        // Move a whole row to new_y_value by setting y_relative on every node in
        // the row and on the row record itself. Intended to only move rows up;
        // this is not enforced here, and the callers check `<` before calling.
        function set_row_y(row, new_y_value) {
            for (const node of row.nodes) {
                node.y_relative = new_y_value
            }
            row.y_relative = new_y_value
        }

        ////////
        let row_queue = [] // not actually using this as a queue, ie not ever adding back to the end, just cycling through
        Object.keys(rows).forEach(rid => row_queue.push(rows[rid])) // rows dict in list form
        row_queue.sort((a,b)=>{
            return a.draw_order_row - b.draw_order_row
        })

        let occupancy = new Array(3000).fill(-1)
        // Raise the occupancy height of every column from `from` through `until`
        // (inclusive) to at least `value`. An empty range (from > until) does nothing.
        // NOTE(review): columns are array indices, so callers should pass ints.
        // A non-integer `from` steps through non-index keys (e.g. 1.5, 2.5, ...),
        // and columns at or past the array length read as undefined, so
        // Math.max stores NaN there. Confirm callers stay within range.
        function block_occupancy(from, until, value) { // NOTE must be int
            let col = from
            while (col <= until) {
                occupancy[col] = Math.max(occupancy[col], value)
                col++
            }
        }

        while (row_queue.length > 0) {
            

            // moving other rows up in response to this row
            // this row's y_relative has already been set
            let row = row_queue.shift() // take at ix 0 and shift rest one to the left
            
            //////////////////
            // block occupancy for row line

            block_occupancy(row.starts_at_x, row.ends_at_x, row.y_relative+row.pad)

            // Block occ for all expanded ops in the row
            row.nodes.forEach(o => {
                 if (o.tensor_is_expanded) { // actvol or actgrid
                    let actvol_spans = o.expanded_tensor_spans // TODO make this better
                    block_occupancy(Math.floor(o.x_relative-actvol_spans.x_span), o.x_relative, o.y_relative+actvol_spans.y_span_half) 
                    // BUG REPORT didn't have the o.y_relative as base, so wasn't correctly blocking occupancy
                } else if (!o.collapsed) { // expanded box within the row
                    let top = o.y_relative + o.h + 1. //.5 //.6 NOTE NOTE this hardcoding
                    let right = o.x_relative + o.w
                    block_occupancy(o.x_relative, right, top)
                }
            })

            ///////////////////
            // Shifting input rows up
            let queue_row_ids = row_queue.map(r => r.draw_order_row)

            row.nodes.forEach(o => {
                if (!o.collapsed) { // expanded box within the row. bring its input nodes up to its frame of reference
                    let c_sub_inputs = o.children.filter(cc => cc.is_input)
                    c_sub_inputs.forEach(cc => { // y_relative_grandparent is the subops y_relative value in the current frame of reference
                        cc.x_relative_grandparent = o.x_relative + cc.x_relative
                        cc.y_relative_grandparent = o.y_relative + cc.y_relative
                    })
                    c_sub_inputs.sort((a,b) => a.y_relative_grandparent - b.y_relative_grandparent) 
                    c_sub_inputs.forEach(input_node => {
                        let uns = utils.get_upstream_nodes_from_group(input_node, op.children) // the upstream op back in the current peer op group
                        if (uns.length==1) {

                            let id_of_incoming_row = uns[0].draw_order_row
                            let incoming_row = rows[id_of_incoming_row]

                            if (incoming_row.y_relative < input_node.y_relative_grandparent) { 
                                if (queue_row_ids.includes(id_of_incoming_row)) { // if the incoming row is not yet fixed
                                    // if the incoming row is below the target height of the input node of the expanded box
                                    set_row_y(incoming_row, input_node.y_relative_grandparent)
                                    incoming_row.has_been_moved_up_w_expanding_box = true
                                }
                            }
                        }
                    })
                }
            })
            // similar to above, but if is collapsed. Now w variable row heights these benefit in same way as do expanded boxes.
            // it's nice for things to align. Can also do for outgoing. Should consolidate this functionality as it's all very similar
            let first_node_in_row = row.nodes[0]
            // if (first_node_in_row.node_type==="module") { // can also do for non-modules
            if (true) { // can also do for non-modules
                let uns = utils.get_upstream_nodes_from_group(first_node_in_row, op.children)
                uns.forEach(un => {
                    if (un.is_last_in_row) {
                        let id_of_incoming_row = un.draw_order_row
                        let incoming_row = rows[id_of_incoming_row]
        
                        if (incoming_row.y_relative < first_node_in_row.y_relative) { 
                            if (queue_row_ids.includes(id_of_incoming_row)) { // if the incoming row is not yet fixed
                                // if the incoming row is below the target height of the input node of the expanded box
                                set_row_y(incoming_row, first_node_in_row.y_relative)
                                incoming_row.has_been_moved_up_w_expanding_box = true
                            }
                        } 
                    }
                })
            }
            
            // OUTPUTS
            row.nodes.forEach(o => { // TODO if we like this, refactor into one fn, only diff is get_uns and .is_output. otherwise identical to above
                if (!o.collapsed) { // expanded box within the row. bring its input nodes up to its frame of reference
                    let c_sub_inputs = o.children.filter(cc => cc.is_output)
                    c_sub_inputs.forEach(cc => { // y_relative_grandparent is the subops y_relative value in the current frame of reference
                        cc.x_relative_grandparent = o.x_relative + cc.x_relative
                        cc.y_relative_grandparent = o.y_relative + cc.y_relative
                    })
                    c_sub_inputs.sort((a,b) => a.y_relative_grandparent - b.y_relative_grandparent) 
                    c_sub_inputs.forEach(input_node => {
                        let uns = utils.get_downstream_nodes_from_group(input_node, op.children) // the upstream op back in the current peer op group
                        if (uns.length==1) {

                            let id_of_incoming_row = uns[0].draw_order_row
                            let incoming_row = rows[id_of_incoming_row]

                            if (incoming_row.y_relative < input_node.y_relative_grandparent) { 
                                if (queue_row_ids.includes(id_of_incoming_row)) { // if the incoming row is not yet fixed
                                    // if the incoming row is below the target height of the input node of the expanded box
                                    set_row_y(incoming_row, input_node.y_relative_grandparent)
                                    incoming_row.has_been_moved_up_w_expanding_box = true
                                }
                            }
                        }
                    })
                }
            })

            //////////////////////////////////
            // move remaining rows up to evade occupancy
            row_queue.forEach(_row => {
                // console.log(_row.nodes[0].node_id, _row.nodes[0].name, _row.starts_at_x, _row.ends_at_x)
                let occ = Math.max(...occupancy.slice(_row.starts_at_x, _row.ends_at_x+1))
                let new_row_y =  occ+_row.pad
                
                // TODO needs to account for h when eg expanded and tensors inside

                let actvols = _row.nodes.filter(n=>n.tensor_node_display_type==="volume") // TODO upgrade // NOTE this doesn't include grids bc those don't extend upwards like volumes do
                actvols.forEach(o => {

                    let actvol_spans = o.expanded_tensor_spans
                    let s = Math.floor(o.x_relative-actvol_spans.x_span)

                    let actvol_occ = Math.max(...occupancy.slice(s, o.x_relative+1))
                    
                    new_row_y = Math.max(new_row_y, actvol_occ+actvol_spans.y_span_half)
                })

                if (_row.has_been_moved_up_w_expanding_box && _row.y_relative>new_row_y) { 
                    // if moved up w expanding box and is higher than new value, let it stay. 
                } else {
                    set_row_y(_row, new_row_y)
                }
            })
    
        }

        ///////////////////////////
        if (!globals.DEBUG) {
            let input_nodes_can_be_removed = true
            op.children.forEach(o => {
                if (o.is_output) { 
                    
                    let uns = utils.get_upstream_peer_nodes(o)
                    if ((uns.length==1) && uns[0].node_is_extraneous_io) {
                        o.x_relative -= 2 //1.8 //.9
                        uns[0].x_relative -= 1 //.9
                    } else {
                        o.x_relative -= 1 //1.8 //.9
                    }
                } 
                else if (o.is_input) {

                    let dns = utils.get_downstream_peer_nodes(o) 

                    if (dns.length>1 || o.is_global_input) {
                        input_nodes_can_be_removed = false
                    } else if (dns.length == 1) {
                        if (dns[0].y_relative != o.y_relative) {
                            input_nodes_can_be_removed = false
                        }
                    }
                }
            })

            if (input_nodes_can_be_removed) {
                op.children.forEach(o => {
                    if (!o.is_input) {
                        o.x_relative -= 1 //.9
                    }
                })
            }
        }

        // Now that all children ops have their dims, and have been shifted bc of expansions, we can ascertain
        // the dimensions of the parent op
        op.w = Math.max(...op.children.map(c => c.x_relative+c.w))

        // op.h = Math.max(...op.children.map(c => c.y_relative+c.h))
        op.h = Math.max(...occupancy)

        // console.log("op hw", op.name, op.node_id, op.h, op.w)
    }
    update_op_h_w(nn)
    


    // Get absolute coords from nested relative coords
    function set_op_children_absolute_coords(op) {
        // Each child stores its position as an offset (x_relative / y_relative) from its parent.
        // Walk the tree top-down, turning those offsets into absolute x / y by adding the
        // parent's (already absolute) position; the caller seeds the root's x / y.
        for (const child of op.children) {
            child.x = op.x + child.x_relative
            child.y = op.y + child.y_relative
            // child now has absolute coords, so its own children can be resolved against it
            set_op_children_absolute_coords(child)
        }
    }
    nn.x = 0; nn.y = 0
    set_op_children_absolute_coords(nn)

    // debugging
    function random_shift(op) {
        // Records the laid-out y in op.y_unshifted, then nudges op.y upward by a small random
        // amount, and does the same for every descendant.
        // NOTE: the author observed that without this tiny jitter some lines occasionally fail to
        // display (noticed after moving to Line2, and after tweens); root cause unknown.
        // A larger but smoother (autocorrelated) jitter was also considered for a hand-drawn look.
        op.y_unshifted = op.y
        const jitter = Math.random() * .02
        op.y += jitter
        if (globals.DEBUG) {
            // debug mode stacks a much larger extra jitter (up to .2) on top
            op.y += Math.random() * .2
        }
        for (const child of op.children) {
            random_shift(child)
        }
    }
    random_shift(nn)

    // Mark plane specs
    // let PLANE_BUFFER = {top:.05, bottom:.15, left:.05, right:.05}
    let PLANE_BUFFER = {top:.1, bottom:.15, left:.1, right:.1}
    function mark_plane_specs(op) {
        // Computes op.plane_info = {min_x, max_x, min_y, max_y}: the bounding box of everything
        // drawn inside op, padded by PLANE_BUFFER on every side.
        // - Collapsed children (drawn as single nodes) get a padded box around their x / y.
        // - Expanded children are planes themselves: recurse first, then use their padded box,
        //   so each level of nesting adds another PLANE_BUFFER margin.
        // Uses absolute c.x / c.y, so it must run after absolute coords have been assigned.

        // Padded box around one collapsed node.
        // NOTE: the same box is used for every display type. Grid tensors were once considered
        // for a taller box via expanded_tensor_spans.y_span_half, but that was never enabled.
        // Reading expanded_tensor_spans here only crashed the layout when a grid node lacked it.
        const node_plane_info = (c) => ({
            min_x: c.x - PLANE_BUFFER.left,
            max_x: c.x + PLANE_BUFFER.right,
            min_y: c.y - PLANE_BUFFER.top,
            max_y: c.y + PLANE_BUFFER.bottom,
        })

        op.children.forEach(c => {
            if (c.collapsed) {
                c.plane_info = node_plane_info(c)
            } else {
                mark_plane_specs(c)
            }
        })

        // Finite values of one plane_info field across the children. Non-finite values (NaN from
        // upstream layout bugs, +/-Infinity from empty sub-planes) are dropped, so one broken
        // node does not blow up the whole bounding box and the network can still be viewed.
        const finite_child_values = (key) =>
            op.children.map(c => c.plane_info[key]).filter(v => isFinite(v))

        // With no finite children, Math.min() / Math.max() yield +/-Infinity; a parent plane
        // filters those out again.
        op.plane_info = {
            min_x: Math.min(...finite_child_values("min_x")) - PLANE_BUFFER.left,
            max_x: Math.max(...finite_child_values("max_x")) + PLANE_BUFFER.right,
            min_y: Math.min(...finite_child_values("min_y")) - PLANE_BUFFER.top,
            max_y: Math.max(...finite_child_values("max_y")) + PLANE_BUFFER.bottom,
        }
    }
    mark_plane_specs(nn)





    /////////////////////////////
    // Set scene min max coords
    let xs = globals.ops_of_visible_nodes.map(op => op.x).filter(v => isFinite(v))
    let ys = globals.ops_of_visible_nodes.map(op => op.y).filter(v => isFinite(v)) // in case of bugs, this allows us to at least view nn
    globals.scene_bb = {
        x_min: Math.min(...xs), 
        x_max: Math.max(...xs), 
        y_min: Math.min(...ys),
        y_max: Math.max(...ys), 
    }
    globals.scene_bb.hheight = (globals.scene_bb.y_max - globals.scene_bb.y_min) / 2
    globals.scene_bb.hwidth = (globals.scene_bb.x_max - globals.scene_bb.x_min) / 2

    //////////////
    // a bit confusing bc we have another fn in label_utils that calculates if node is onscreen before the transition, 
    // while this one if after the transition. currently using only to decide when to draw actgrids.
    update_future_onscreen_status()



    console.timeEnd("compute layout")
}

// these will be the ops AFTER the update of their positions. similar to the one in label_utils
// For every visible node (globals.ops_of_visible_nodes), sets:
//   op.will_be_onscreen       - true when the node's NEW position lies strictly inside the main
//                               window rectangle
//   op.x_dist_from_cam_center - |op.x - cx|, horizontal distance from the window centre
// It reads the already-updated op.x / op.y, i.e. positions AFTER the layout transition. A
// similar check in label_utils works on positions from before the transition.
function update_future_onscreen_status() {
    // assumed: [half width, half height, centre x, centre "z" used as the vertical centre]
    const [h_width, h_height, cx, cz] = utils.get_main_window_position()

    // Multipliers on the half-extents. 1 means exactly the window; values > 1 would widen the
    // region so labels are placed before they scroll into view.
    const bh = 1
    const bv = 1

    const bounds = {
        left: cx - h_width * bh,
        right: cx + h_width * bh,
        top: cz + h_height * bv,
        bottom: cz - h_height * bv,
    }

    for (const op of globals.ops_of_visible_nodes) {
        const inside_horizontally = op.x > bounds.left && op.x < bounds.right
        const inside_vertically = op.y > bounds.bottom && op.y < bounds.top
        op.will_be_onscreen = inside_horizontally && inside_vertically
        op.x_dist_from_cam_center = Math.abs(op.x - cx)
    }
}
