import React, { useEffect, useRef, useState } from 'react';
import { Card, CardContent, Typography, Box } from '@mui/material';
import pako from 'pako';
import * as utils from './utils';
import HistogramChart from './histogram'
import { Tabs, Tab } from '@mui/material';
import { FixedSizeGrid as Grid } from 'react-window';
import { X, Info } from 'lucide-react';
import FeatureDescriptionTextEditor from './FeatureDescriptionTextEditor';
import { globals } from './globals';

/**
 * Convert an overlay value into an opaque CSS colour string.
 * The value is scaled down by 10 and mapped to RGB with utils.rgb_from_z.
 * @param {number} v - overlay value.
 * @returns {string} colour in the form `rgba(r, g, b, 1)`.
 */
function color_lookup(v) {
  const rgb = utils.rgb_from_z(v / 10);
  return `rgba(${rgb.r}, ${rgb.g}, ${rgb.b}, 1)`;
}

// Width (px) assigned to every overlay canvas; matches the 80px example thumbnails shown in
// FeatureSidebar's grid. Canvas heights are derived from it per layout.
const img_dim = 80

// Overlay drawing functions; these are also used by the feature tooltip.
/**
 * Draw one dataset example's activation map onto `canvas`.
 *
 * The values drawn are overlayData[key][index], where key is "actmaps" (tab 0), "actmaps_mid"
 * (tab 1) or "actmaps_neg" (tab 2). The canvas width is always reset to img_dim; the layout
 * depends on overlayData.fn_type:
 *  - "conv2d": a 2-D array of rows stretched over an img_dim x img_dim square.
 *  - "linear": a 2-D array is drawn the same way; a scalar or 1-D array is treated as tokens laid
 *    out row-major on a grid floor(sqrt(n)) cells wide, with as many rows as needed.
 * Any other fn_type draws nothing.
 *
 * @param {HTMLCanvasElement} canvas - target canvas (its width/height are overwritten).
 * @param {object} overlayData - parsed overlays JSON for the feature.
 * @param {number} index - which dataset example to draw.
 * @param {number} tab_value - 0, 1 or 2, selecting the overlay list as described above.
 * @param {object|null} [capture_config] - if `preceding_cls_tokens` is truthy, the first token of a
 *   1-D linear array is moved to the last position before drawing.
 */
const drawActmap = (canvas, overlayData, index, tab_value, capture_config = null) => {
  const tab_lookup = {
    0: "actmaps",
    1: "actmaps_mid",
    2: "actmaps_neg",
  };
  let overlayDataArray = overlayData[tab_lookup[tab_value]][index];
  const fn_type = overlayData['fn_type'];

  const ctx = canvas.getContext('2d');
  canvas.width = img_dim;

  // Paint a row-major 2-D array of values so it fills an img_dim x img_dim square.
  const drawMatrix = (rows) => {
    canvas.height = img_dim;
    if (rows.length === 0) return; // nothing to draw, and rows[0] would be undefined
    const cellWidth = canvas.width / rows[0].length;
    const cellHeight = canvas.height / rows.length;
    rows.forEach((row, y) => {
      row.forEach((value, x) => {
        ctx.fillStyle = color_lookup(value);
        ctx.fillRect(x * cellWidth, y * cellHeight, cellWidth, cellHeight);
      });
    });
  };

  // Paint a 1-D token array row-major on a grid floor(sqrt(n)) cells wide.
  const drawTokenGrid = (tokens) => {
    const tokenCount = tokens.length;
    if (tokenCount === 0) {
      canvas.height = 0;
      return;
    }
    const gridSize = Math.floor(Math.sqrt(tokenCount));
    // Use ceil rather than "+1 row if anything is left over": when more than gridSize tokens
    // remain after the square, they need more than one extra row or they are drawn off-canvas.
    const numRows = Math.ceil(tokenCount / gridSize);

    canvas.height = (img_dim * numRows) / gridSize;
    const cellWidth = canvas.width / gridSize;
    const cellHeight = canvas.height / numRows;

    // Move the CLS token to the last position. NOTE: one of three places where this is done (TODO consolidate).
    const ordered = capture_config?.preceding_cls_tokens
      ? [...tokens.slice(1), tokens[0]]
      : tokens;

    ordered.forEach((value, i) => {
      const x = i % gridSize;
      const y = Math.floor(i / gridSize);
      ctx.fillStyle = color_lookup(value);
      ctx.fillRect(x * cellWidth, y * cellHeight, cellWidth, cellHeight);
    });
  };

  if (fn_type === "conv2d") {
    drawMatrix(overlayDataArray);
  } else if (fn_type === "linear") {
    if (typeof overlayDataArray === 'number') {
      overlayDataArray = [overlayDataArray];
    }
    // A non-empty array whose first element is itself indexable is treated as a 2-D matrix.
    const isMatrix = overlayDataArray[0] !== undefined && overlayDataArray[0][0] !== undefined;
    if (isMatrix) {
      drawMatrix(overlayDataArray);
    } else {
      drawTokenGrid(overlayDataArray);
    }
  }
};
/**
 * Draw the Q/K match points of one dataset example for an attention head onto `canvas`.
 *
 * Tokens sit on a grid overlayData.grid_size cells wide, with extra rows for
 * overlayData.extra_tokens. Points are [y, x] grid coordinates; an x beyond the grid width wraps
 * to the next row. For each pair i a line is drawn from K point i to Q point i, with chevrons
 * pointing towards Q. Then every Q point is drawn as a green dot and every K point as a magenta dot.
 *
 * @param {HTMLCanvasElement} canvas - width is set to img_dim; height is chosen so cells are square.
 * @param {object} overlayData - reads grid_size, extra_tokens and q_pts/k_pts (q_pts_neg/k_pts_neg on tab 2).
 * @param {number} index - which dataset example to draw.
 * @param {number} tab_value - 2 selects the *_neg point lists; any other value the regular ones.
 */
function drawAttnMatrix(canvas, overlayData, index, tab_value) {
  const ctx = canvas.getContext('2d');
  canvas.width = img_dim;

  const useNeg = tab_value === 2;
  const q_pts = overlayData[useNeg ? "q_pts_neg" : "q_pts"][index];
  const k_pts = overlayData[useNeg ? "k_pts_neg" : "k_pts"][index];

  const gridSize = overlayData["grid_size"];
  const extraTokens = overlayData["extra_tokens"];
  // One extra row per started group of gridSize leftover tokens (identical to "+1" when extraTokens <= gridSize).
  const numRows = gridSize + (extraTokens > 0 ? Math.ceil(extraTokens / gridSize) : 0);

  canvas.height = (img_dim * numRows) / gridSize;

  const cellWidth = canvas.width / gridSize;
  const cellHeight = canvas.height / numRows;
  const pointSize = Math.min(cellWidth, cellHeight) / 2;

  // Pixel centre of the cell for a [y, x] point, wrapping x >= gridSize onto following rows.
  const cellCenter = ([y, x]) => {
    const linearIndex = y * gridSize + x;
    const row = Math.floor(linearIndex / gridSize);
    const col = linearIndex % gridSize;
    return [(col + 0.5) * cellWidth, (row + 0.5) * cellHeight];
  };

  const chevronSpacing = 15;
  const chevronSize = 2.5;

  ctx.strokeStyle = 'rgba(0, 0, 0, 0.3)';
  ctx.lineWidth = 1;

  q_pts.forEach((qPoint, i) => {
    const [endX, endY] = cellCenter(qPoint);
    const [startX, startY] = cellCenter(k_pts[i]);

    ctx.beginPath();
    ctx.moveTo(startX, startY);
    ctx.lineTo(endX, endY);
    ctx.stroke();

    const dx = endX - startX;
    const dy = endY - startY;
    const length = Math.sqrt(dx * dx + dy * dy);
    // Q and K in the same cell: there is no direction, and the unit vector would be NaN.
    if (length === 0) return;
    const unitX = dx / length;
    const unitY = dy / length;

    const numChevrons = Math.max(2, Math.floor(length / chevronSpacing));
    for (let j = 1; j <= numChevrons; j++) {
      const t = j / (numChevrons + 1);
      const x = startX + t * dx;
      const y = startY + t * dy;

      // Two arms swept back from the tip (x, y), so the chevron points from K towards Q.
      ctx.beginPath();
      ctx.moveTo(
        x - chevronSize * (unitX + unitY),
        y - chevronSize * (unitY - unitX)
      );
      ctx.lineTo(x, y);
      ctx.lineTo(
        x - chevronSize * (unitX - unitY),
        y - chevronSize * (unitY + unitX)
      );
      ctx.stroke();
    }
  });

  // Filled dot at the centre of each point's cell.
  const drawPoints = (points, color) => {
    ctx.fillStyle = color;
    points.forEach((point) => {
      const [cx, cy] = cellCenter(point);
      ctx.beginPath();
      ctx.arc(cx, cy, pointSize, 0, 2 * Math.PI);
      ctx.fill();
    });
  };

  drawPoints(q_pts, 'rgba(0, 255, 0, 0.7)');
  drawPoints(k_pts, 'rgba(255, 0, 255, 0.7)');
}
/**
 * Render the overlay for one dataset example onto `canvas`.
 *
 * Overlays whose fn_type is "attnMatrix" are drawn as Q/K point pairs (drawAttnMatrix, which does
 * not take capture_config); every other fn_type is drawn as an activation map (drawActmap).
 *
 * @param {HTMLCanvasElement} canvas - canvas to draw into.
 * @param {object} overlayData - parsed overlays JSON for the feature.
 * @param {number} index - which dataset example to draw.
 * @param {number} tab_value - selected tab (0 = Max, 1 = mid percentiles, 2 = Min).
 * @param {object} [capture_config] - forwarded to drawActmap only.
 */
export function drawOverlay(canvas, overlayData, index, tab_value, capture_config){
  const isAttention = overlayData["fn_type"] === "attnMatrix";
  if (!isAttention) {
    drawActmap(canvas, overlayData, index, tab_value, capture_config);
    return;
  }
  drawAttnMatrix(canvas, overlayData, index, tab_value);
}

const FeatureSidebar = ({ 
                      featureTooltipObject, 
                      onCloseFeatureSidebar
                         }) => {
  const [overlayData, setOverlayData] = useState(null);

  const [tensorId, setTensorId] = useState(null);
  const [channelIx, setChannelIx] = useState(null);

  const [currentTabGrid, setCurrentTabGrid] = useState(null);


  const [tabIndex, setTabIndex] = React.useState(0);

  const gridRef = useRef(null); // used to grab ref and scroll to top on new feature

  const sidebar_width = 600

  let is_mean_centered = featureTooltipObject.capture_config.mean_center
  let isAttnHead = overlayData?.fn_type === "attnMatrix"
  let title = `${featureTooltipObject.node_id} — ${isAttnHead ? "attn head " : "channel "}${featureTooltipObject.channel_ix}`

  let info_tooltip_text = `Histogram shows raw activation values
  `
  if (!is_mean_centered) {
    info_tooltip_text += `
    Obs fire rate: The percentage of dataset examples that caused this channel to register at least one positive activation

    Overall fire rate: The percentage of all activations (across batch and spatial positions) that were positive.
    `
  }
  if (isAttnHead) {
    info_tooltip_text += `
    Dataset examples show max and min QK matches. E.g. when viewing max QK matches, this attn head will be moving information from the K locations into the Q locations.
    `
  } else {
    info_tooltip_text += `
    Dataset examples show max, min, and 95th percentile activations. A single obs may have both max and min activations.
    `
    if (!is_mean_centered) {
      info_tooltip_text += `
      Histogram and overlays are colored after scaling raw activation values by channel std. Coloration is not centered at the channel mean bc these values are or will be thresholded at zero.
      `
    } else {
      info_tooltip_text += `
      Histogram and overlays are colored after fully normalizing activation values (v-channel_mean)/channel_std. Coloration is with respect to the population, e.g. red means "more activated than other obs", though the raw value itself may be negative.
      `
    }
  }
  let n_obs = featureTooltipObject.capture_config.n_obs
  if (n_obs) n_obs = (n_obs/1000).toFixed(0)+"k";
  info_tooltip_text += `
  Analysis created using sample of ${n_obs} observations from ${globals.nn.trace_metadata.dataset} dataset.`




  const handleTabChange = (event, newTabValue) => {
    setTabIndex(newTabValue);
  };
  // console.log("remake") //TODO to remind us of how often being remade, which is too much currently

  // load overlays data. This is equivalent to on-load of new channel
  useEffect(() => {
    setTensorId(featureTooltipObject.tensor_id)
    setChannelIx(featureTooltipObject.channel_ix)
    setTabIndex(0) // always back to Max

    // Reset grid scroll position
    if (gridRef.current) {
      gridRef.current.scrollTo({ scrollTop: 0, scrollLeft: 0 });
    }

    fetch(featureTooltipObject.overlaysPath)
      .then(response => response.arrayBuffer())
      .then(arrayBuffer => {
        const uint8Array = new Uint8Array(arrayBuffer);
        const decompressed = pako.ungzip(uint8Array, { to: 'string' });
        const overlays = JSON.parse(decompressed);
        setOverlayData(overlays);
      });
  }, [featureTooltipObject.overlaysPath]);

  const percentiles = featureTooltipObject.percentiles // cdf

  let capture_config = featureTooltipObject.capture_config
  
  // kind of a hack, prob should get from data itself on the backend, but this is fast and close enough
  let fire_rate_estimate = (1 - utils.interp(0, featureTooltipObject.hist_bins, percentiles))*100 
  
  const getImagesAndOverlays = (imgPaths) => {
    console.log("get image overlays called from sidebar") //TODO this was to remind us how often this is being called

    const gridColumnCount = 6
  
    const Cell = ({ columnIndex, rowIndex, style }) => {
      const index = rowIndex * gridColumnCount + columnIndex;
      if (index >= imgPaths.length) return null;
  
      let src = imgPaths[index]
  
      return (
        <div style={{ ...style, padding: '4px' }}>

          <div style={{
            position: 'relative', 
            // height: '200px', 
            display: 'flex', 
            flexDirection:'column',
            // alignItems: 'top', 
            gap: '0px' 
          }}>
            <img src={src} alt={src} style={{ width: '80px', height: '80px' }} />
            <div style={{ position: 'relative', height: '80px', width: '80px' }}>
             <img
               src={src}
               alt={src}
               style={{
                 width: '80px',
                 height: '80px',
                 opacity: 0.25,
                 filter: 'saturate(0)',
                 position: 'absolute',
                 top: 0,
                 left: 0,
               }}
             />
              <canvas
                ref={(el) => {
                  if (el){ // why is this null after repeated re-renders? 
                    drawOverlay(el, overlayData, index, tabIndex, capture_config);
                  }

                }}
                style={{ 
                        pointerEvents: 'none' 
                      }}
              />
           </div>
          </div>

        </div>
      );
    };
    let grid = (
      <Grid
        ref={gridRef} // used to grab ref and scroll to top on new feature
        columnCount={gridColumnCount}
        rowCount={Math.ceil(imgPaths.length / gridColumnCount)}
        columnWidth={92}
        rowHeight={180}
        height={document.documentElement.clientHeight - 320} // clean these up. Top half will be constant, bottom tabs portion will take up the rest
        width={sidebar_width-20}
      >
        {Cell}
      </Grid>
    );
    return grid
  };

  // make grid
  useEffect(() => {
    if (!overlayData) return;
    let l = {0: featureTooltipObject.image_paths_pos,
              1: featureTooltipObject.image_paths_mid,
              2: featureTooltipObject.image_paths_neg}
    let grid = getImagesAndOverlays(l[tabIndex])
    setCurrentTabGrid(grid)
  }, [overlayData, tabIndex]);

  let has_mids = !isAttnHead && featureTooltipObject.image_paths_mid.length > 0
  let has_negs = featureTooltipObject.image_paths_neg.length > 0

  return (
    <div 
    className='featureSidebar' 
    style={{
      position: 'fixed',
      right: 0,
      top: 0,
      width: `${sidebar_width}px`,
      height: '100vh',
      display: 'flex',
      flexDirection: 'column',
      backgroundColor: 'white',
      boxShadow: '0 10px 15px -3px rgba(0, 0, 0, 0.1)'
    }}>

      {/* Top half of sidebar */}
      
      {/* Title */}
      <div style={{
        padding: '10px 14px',
        textAlign: 'left',
        // borderBottom: '1px solid #e5e7eb',
        position:'relative'
      }}>
        <button
          onClick={onCloseFeatureSidebar}
          style={{
            position: 'absolute',
            top: '0px',
            right: '2px',
            padding: '6px',
            borderRadius: '8px',
            cursor: 'pointer',
            backgroundColor: 'transparent',
            border: 'none'
          }}
          onMouseEnter={e => e.currentTarget.style.backgroundColor = '#f3f4f6'}
          onMouseLeave={e => e.currentTarget.style.backgroundColor = 'transparent'}
          title="Close feature sidebar"
        >
          <X size={20} />
        </button>
        <h1 style={{ fontSize: '18px', margin: 0 }}>
          {/* using node_id here bc it's shorter, though tensor_id would be correct */}
          {title}
        </h1>
      </div>

      <div style={{
        margin: '12px',
        border: '2px solid #e5e7eb',
        borderRadius: '4px'
      }}>
        <div style={{
          border: '1px solid #e5e7eb',
          borderRadius: '4px'
        }}>

          <div style={{ display: 'flex', position:'relative', height: '176px' }}>
            {/* Left side container */}
            <div style={{ flex: 1, border: '1px solid #e5e7eb' }}>
              {/* Description */}
              <div style={{ 
                padding: '12px',
                height: '110px',
              }}>
                <FeatureDescriptionTextEditor
                  tensorId={tensorId}
                  channelIx={channelIx}
                >
                </FeatureDescriptionTextEditor>
              </div>

              {/* Stats panel */}
              {!is_mean_centered && (<div style={{ borderTop: '1px solid #e5e7eb' }}>
                <div style={{ display: 'flex' }}>
                  {/* First stat */}
                  <div style={{ 
                    flex: 1,
                    textAlign: 'center',
                    padding: '4px',
                    borderRight: '1px solid #e5e7eb',
                  }}>
                    <div style={{ 
                      fontSize: '12px',
                      color: '#6b7280'
                    }}>Obs fire rate</div>
                    <div style={{ 
                      fontSize: '18px',
                      fontWeight: 'bold'
                    }}>{(featureTooltipObject.fired_rate * 100).toFixed(1)}%</div>
                  </div>
                  {/* Second stat */}
                  <div style={{ 
                    flex: 1,
                    textAlign: 'center',
                    padding: '4px',
                    borderRight: '1px solid #e5e7eb',
                  }}>
                    <div style={{ 
                      fontSize: '12px',
                      color: '#6b7280'
                    }}>Overall fire rate</div>
                    <div style={{ 
                      fontSize: '18px',
                      fontWeight: 'bold'
                    }}>{fire_rate_estimate.toFixed(2)}%</div>
                  </div>
                </div>
              </div>)}
            </div>

            {/* Histogram */}
            <div style={{
              width: '300px',
              // height: '100px',
              // padding: '12px',
              // borderLeft: '1px solid #e5e7eb',
              position: 'relative'
            }}>
              <p style={{
                position: 'absolute',
                top: '0px',
                right: '12px',
                width: '80px',
                fontSize: '12px'
              }}>
              </p>
              <HistogramChart 
                percentile_values={percentiles}
                hist_bins={featureTooltipObject.hist_bins}
                channel_mean={featureTooltipObject.channel_mean}
                channel_std={featureTooltipObject.channel_std}
                isAttnHead={isAttnHead}
                capture_config={capture_config}
              />
            </div>
            <div style={{
                position:'absolute',
                top:'0px',
                right:'2px',
              }}>
                {InfoTooltip(info_tooltip_text)}
            </div>
          </div>
        </div>
      </div>
      
      {/* Tabs Section */}
      {overlayData && (
        <div>
          <Box sx={{ borderBottom: 1, borderColor: 'divider', position:'relative' }}>
            <Tabs 
                value={tabIndex} 
                onChange={handleTabChange} 
                aria-label="Feature tabs"
                sx={{
                  '& .MuiTab-root.Mui-selected': {
                    color: 'black'
                  },
                  '& .MuiTabs-indicator': {
                    backgroundColor: 'black'
                  }
                }}>
              <Tab label="Max" />
              <Tab label="92% - 98%" sx={{ display: has_mids ? 'default' : 'none' }} />
              <Tab label="Min" sx={{ display: has_negs ? 'default' : 'none' }} />
            </Tabs>
            {/* Legend for attnMatrix */}
              {isAttnHead && (
                <div style={{
                  display: 'flex',
                  justifyContent: 'center',
                  gap: '16px',
                  marginBottom: '8px',
                  padding: '8px',
                  position: 'absolute',
                  top: '12px',
                  right: '14px'
                }}>
                  <div style={{ display: 'flex', alignItems: 'center', gap: '4px' }}>
                    <div style={{ 
                      width: '12px', 
                      height: '12px', 
                      backgroundColor: 'rgba(0, 255, 0, 0.7)', 
                      borderRadius: '50%' 
                    }}/>
                    <span style={{ fontSize: '12px' }}>Query (Q)</span>
                  </div>
                  <div style={{ display: 'flex', alignItems: 'center', gap: '4px' }}>
                    <div style={{ 
                      width: '12px', 
                      height: '12px', 
                      backgroundColor: 'rgba(255, 0, 255, 0.7)', 
                      borderRadius: '50%' 
                    }}/>
                    <span style={{ fontSize: '12px' }}>Key (K)</span>
                  </div>
                </div>
              )}
          </Box>

          {/* Tab Panels */}
          <Box sx={{ flex: 1, padding: 2}}>
            {currentTabGrid}
          </Box>
        </div>
      )}

    </div>
  );
};


function InfoTooltip(text) {
  const [tooltipPosition, setTooltipPosition] = useState({ horizontal: 'center', vertical: 'top' });
  const tooltipRef = useRef(null);
  
  const handleMouseEnter = (e) => {
    const icon = e.currentTarget;
    const tooltip = icon.parentElement.querySelector('.tooltip');
    const iconRect = icon.getBoundingClientRect();
    const tooltipRect = tooltip.getBoundingClientRect();
    
    // Check horizontal overflow
    const rightOverflow = iconRect.left + (tooltipRect.width / 2) > window.innerWidth;
    const leftOverflow = iconRect.left - (tooltipRect.width / 2) < 0;
    
    // Check vertical overflow
    const topOverflow = iconRect.top - tooltipRect.height - 8 < 0;
    
    let newPosition = {
      horizontal: 'center',
      vertical: 'top'
    };
    
    if (rightOverflow) {
      newPosition.horizontal = 'right';
    } else if (leftOverflow) {
      newPosition.horizontal = 'left';
    }
    
    if (topOverflow) {
      newPosition.vertical = 'bottom';
    }
    
    setTooltipPosition(newPosition);
    icon.style.opacity = '0.6';
    tooltip.style.visibility = 'visible';
  };

  const getTooltipStyles = () => {
    const baseStyles = {
      visibility: 'hidden',
      position: 'absolute',
      backgroundColor: 'rgba(0, 0, 0, 0.8)',
      color: 'white',
      padding: '4px 8px',
      borderRadius: '4px',
      fontSize: '12px',
      width: '180px',
      whiteSpace: 'normal',
      zIndex: 1000,
      textAlign: 'left',
      whiteSpace: 'pre-line'
    };

    // Horizontal positioning
    if (tooltipPosition.horizontal === 'center') {
      baseStyles.left = '50%';
      baseStyles.transform = 'translateX(-50%)';
    } else if (tooltipPosition.horizontal === 'right') {
      baseStyles.right = '0';
      baseStyles.transform = 'none';
    } else {
      baseStyles.left = '0';
      baseStyles.transform = 'none';
    }

    // Vertical positioning
    if (tooltipPosition.vertical === 'top') {
      baseStyles.bottom = '100%';
      baseStyles.marginBottom = '4px';
    } else {
      baseStyles.top = '100%';
      baseStyles.marginTop = '4px';
    }

    return baseStyles;
  };

  return (
    <span>
      <Info
        size={16}
        onMouseEnter={handleMouseEnter}
        onMouseLeave={(e) => {
          e.currentTarget.style.opacity = '1';
          e.currentTarget.parentElement.querySelector('.tooltip').style.visibility = 'hidden';
        }}
      />
      <div 
        ref={tooltipRef}
        className="tooltip"
        style={getTooltipStyles()}
      >
        {text}
      </div>
    </span>
  );
}


// The sidebar component is this module's default export; drawOverlay is also exported by name.
export default FeatureSidebar;