import React, { useEffect, useRef, useState } from 'react';

import * as utils from './utils';
import { globals } from './globals';
import { FormControl, FormControlLabel, Radio, RadioGroup, Typography } from '@mui/material';

const AttentionExplorer = ({ 
  attention_matrix_tids,
  inputImage,
  tensorTraceId
}) => {
  const CELL_SIZE = 64;
  const CONTROL_SIZE = 224;
  const HEADER_HEIGHT = 30;
  const ROW_LABEL_WIDTH = 40;
  const MAX_HEIGHT = 600;

  const [mode, setMode] = useState('query');
  const [hoveredId, setHoveredId] = useState(null);
  const [allLayersData, setAllLayersData] = useState(get_all_layers_data({ index: 0 }));
  const [pos, setPos] = useState(null); // to cache so can keep constant across Key<->Query toggles


  function get_all_layers_data(pos, mode) {
    return attention_matrix_tids.map(layer_tid => {
      let attn_matrix = globals.tensor_trace[layer_tid][0];
      return attn_matrix.map(head => (
        mode === 'query'
          ? head[pos.index].map(v => color_from_z_lookup[v])
          : head.map(row => row[pos.index]).map(v => color_from_z_lookup[v])
      ));
    });
  }

  // on inputImage change, update data. Need to wait for new trace to load, tensorTraceId
  useEffect(() => {
    if (pos) {
      console.log("changing image", pos, mode)
      setAllLayersData(get_all_layers_data(pos, mode));
    }
  }, [tensorTraceId]);

  const onQueryPosChange = (pos) => {
    setPos(pos)
    setAllLayersData(get_all_layers_data(pos, mode));
  };

  const handleModeChange = (event) => {
    setMode(event.target.value);
    if (pos) {
      console.log("changing mode and resetting data", mode, pos)
      setAllLayersData(get_all_layers_data(pos, event.target.value)); // have changed from Query to Key, so need to get data again
    }
  };

  return (
    <div style={{ display: 'flex', gap: '10px' }}>
      {/* Left side - fixed control image */}
      <div style={{
        display: 'flex',
        flexDirection: 'column',
        gap: '10px',
        flex: '0 0 auto'
      }}>
        {/* Title for Left Side */}
        <Typography variant="h6" sx={{ textAlign: 'center', fontWeight: 'bold' }}>
          {mode === 'query' ? 'Query' : 'Key'}
        </Typography>

        <div style={{ position: "relative" }}>
          <img src={`/data/trace_imgs/${inputImage}.png`}
              alt={"some alt text"}
              style={{ 
                width: `${CONTROL_SIZE}px`,
                height: `${CONTROL_SIZE}px`,
                pointerEvents: 'none', 
                zIndex: 100 
              }} 
          />
          <CanvasGridOverlay
            imageWidth={CONTROL_SIZE}
            imageHeight={CONTROL_SIZE}
            gridData={allLayersData[0][0]}
            onQueryPosChange={onQueryPosChange}
            isControlOverlay={true}
          />
        </div>
        
        {/* Simple Radio Button Toggle */}
        <FormControl component="fieldset">
          <RadioGroup
            aria-label="Activation Mode"
            name="activationMode"
            value={mode}
            onChange={handleModeChange}
            row
          >
            <FormControlLabel value="query" control={<Radio />} label="Query" />
            <FormControlLabel value="key" control={<Radio />} label="Key" />
          </RadioGroup>
        </FormControl>
      </div>

      {/* Vertical Divider */}
      <div style={{
        width: '1px',
        backgroundColor: '#ccc',
        margin: '0 5px'
      }} />

      {/* Right side - scrollable layer/head grid */}
      <div style={{ 
        flex: '1',
        maxHeight: `${MAX_HEIGHT}px`,
        display: 'flex',
        flexDirection: 'column'
      }}>
        {/* Title for Right Side */}
        <Typography variant="h6" sx={{ textAlign: 'center', fontWeight: 'bold', position: 'sticky', top: 0, backgroundColor: 'white', zIndex: 2 }}>
          {mode === 'query' ? 'Key' : 'Query'}
        </Typography>

        {/* Fixed Header */}
        <div style={{
          display: 'flex',
          marginLeft: `${ROW_LABEL_WIDTH}px`,
          height: `${HEADER_HEIGHT}px`,
          position: 'sticky',
          top: 30, // Adjusted to keep title visible
          backgroundColor: 'white',
          zIndex: 2
        }}>
          {allLayersData[0].map((_, head_ix) => (
            <div key={`head-title-${head_ix}`} style={{
              width: `${CELL_SIZE}px`,
              textAlign: 'center',
              fontSize: '14px',
              fontWeight: '500',
              padding: '0 5px',
              display: 'flex',
              alignItems: 'flex-end',
              justifyContent: 'center',
              marginBottom: '5px'
            }}>
              Head {head_ix}
            </div>
          ))}
        </div>

        {/* Scrollable Grid Container */}
        <div style={{
          overflowY: 'auto',
          overflowX: 'hidden'
        }}>
          {/* Grid of layers and heads */}
          <div style={{
            display: 'flex',
            flexDirection: 'column',
            gap: '10px'
          }}>
            {allLayersData.map((layer_data, layer_ix) => (
              <div key={`layer-${attention_matrix_tids[layer_ix]}`} style={{
                display: 'flex',
                alignItems: 'center',
                gap: '10px'
              }}>
                {/* Row Header (Layer label) - Adjusted Rotation */}
                <div style={{
                  width: `${ROW_LABEL_WIDTH}px`,
                  height: `${CELL_SIZE}px`,
                  fontSize: '14px',
                  fontWeight: '500',
                  display: 'flex',
                  alignItems: 'center',
                  justifyContent: 'center',
                  textAlign: 'center',
                  position: 'sticky',
                  left: 0,
                  backgroundColor: 'white',
                  zIndex: 1
                }}>
                  Layer {attention_matrix_tids[layer_ix].slice(-5)}
                </div>

                {/* Layer's heads */}
                {layer_data.map((head_data, head_ix) => (
                  <div 
                    key={`layer-${attention_matrix_tids[layer_ix]}-head-${head_ix}`} 
                    style={{ position: "relative",
                              backgroundColor: (layer_ix+'-'+head_ix)===hoveredId?'grey':'white',
                              outline: (layer_ix+'-'+head_ix)===hoveredId?'solid grey 3px':'none'
                           }}
                    onMouseOver={(event) => {
                      let channel_ix = head_ix
                      let tensor_id = attention_matrix_tids[layer_ix]
                      let featurespace = globals.featurespace[tensor_id]
                      setHoveredId(layer_ix+'-'+head_ix)
          
                      let top_5 = featurespace["top_5s"][channel_ix]
                      let bottom_5 = featurespace["bottom_5s"][channel_ix]
                      
                      
                      globals.setTooltipPosition({ left: event.clientX, top: event.clientY }); 
                      
                      let dataset = globals.nn.trace_metadata.dataset
                      let dataset_name = dataset=="imagenet"?"imagenet_val":dataset // dumb
                      let image_paths_pos = top_5.map(ix => {
                        return `/data/${dataset_name}/image_${ix}.png`
                      })
                      let image_paths_neg = bottom_5.map(ix => {
                        return `/data/${dataset_name}/image_${ix}.png`
                      })
                      let overlaysPath = `/data/featurespace_overlays/${globals.nn.trace_metadata.name}/${tensor_id}_${channel_ix}.json.gz`
                      globals.setFeatureTooltipObject({image_paths_pos, image_paths_neg, channel_ix, tensor_id, overlaysPath})

                      
                      console.log(`entering layer ${layer_ix}, Head: ${head_ix}`)

                    }}
                    onMouseLeave={() => { // TODO this is closing tooltips when mouse hits this
                      console.log(`leaving layer ${layer_ix}, Head: ${head_ix}`)
                      globals.setFeatureTooltipObject(null)
                      setHoveredId(null)
                    }}
                  >
                    <img src={`/data/trace_imgs/${inputImage}.png`}
                        alt={"some alt text"}
                        style={{ 
                          width: `${CELL_SIZE}px`,
                          height: `${CELL_SIZE}px`,
                          pointerEvents: 'none', 
                          zIndex: 100,
                          opacity: .6,
                          filter: 'saturate(0)'
                        }} 
                    />
                    <CanvasGridOverlay
                      imageWidth={CELL_SIZE}
                      imageHeight={CELL_SIZE}
                      gridData={head_data}
                      onQueryPosChange={null}
                      isControlOverlay={false}
                    />
                  </div>
                ))}
              </div>
            ))}
          </div>
        </div>
      </div>
    </div>
  );
};

// Precomputed color table, built once at module load so rendering only does
// object lookups. Keys are the integers -120..120; each key k stands for the
// z-value k / 10 (i.e. -12.0..12.0 in steps of 0.1). Each entry is the color
// from utils.rgb_from_z, with its opacity set from |z| via utils.interp over
// [.8, 3] -> [0, .8] (presumably: faint near zero, stronger for large |z|).
const color_from_z_lookup = {};
for (let tenths = -120; tenths <= 120; tenths += 1) {
  const z = tenths / 10;
  const color = utils.rgb_from_z(z);
  color.opacity = utils.interp(Math.abs(z), [.8, 3], [0, .8]);
  color_from_z_lookup[tenths] = color;
}


const CanvasGridOverlay = ({ 
  imageWidth, 
  imageHeight, 
  gridData,
  onCellDeselect,
  onQueryPosChange,
  isControlOverlay
}) => {
  const canvasRef = useRef(null);
  const [cellSize, setCellSize] = useState({ width: 0, height: 0 });
  const [hoveredCell, setHoveredCell] = useState(null);
  const [selectedCell, setSelectedCell] = useState(null);
  
  // Calculate grid dimensions from data
  const gridSize = Math.floor(Math.sqrt(gridData.length));
  const extraTokens = gridData.length - (gridSize * gridSize);
  const numRows = gridSize + (extraTokens > 0 ? 1 : 0);
  const numCols = gridSize;

  // Adjust canvas height for extra tokens
  const adjustedHeight = extraTokens > 0 ? 
    (imageHeight * numRows) / gridSize : 
    imageHeight;
  
  const drawGrid = (ctx) => {
    const cellWidth = imageWidth / numCols;
    const cellHeight = adjustedHeight / numRows;
    
    ctx.clearRect(0, 0, imageWidth, adjustedHeight);
    
    gridData.forEach((cellData, index) => {

      const col = index % numCols;
      const row = Math.floor(index / numCols);
      const x = col * cellWidth;
      const y = row * cellHeight;
      
      if (cellData) { // was sometimes null though couldn't catch when
        if (isControlOverlay) {
          ctx.fillStyle = `rgba(0, 0, 0, 0)`;
        } else {
          ctx.fillStyle = `rgba(${cellData.r}, ${cellData.g}, ${cellData.b}, ${cellData.opacity})`;
        }
      }

      
      ctx.fillRect(x, y, cellWidth, cellHeight);
      
      const isSelected = selectedCell && selectedCell.col === col && selectedCell.row === row;
      const isHovered = !selectedCell && hoveredCell && 
                       hoveredCell.row === row && hoveredCell.col === col;
      
      if (isControlOverlay) {
        if (isSelected) {
          ctx.fillStyle = 'rgba(0, 255, 0, 1)';
          ctx.fillRect(x, y, cellWidth, cellHeight);
        } else if (isHovered) {
          ctx.fillStyle = 'rgba(0, 255, 0, 0.7)';
          ctx.fillRect(x, y, cellWidth, cellHeight);
        }
      }

      
      ctx.strokeStyle = 'rgba(0, 0, 0, 0.2)';
      ctx.strokeRect(x, y, cellWidth, cellHeight);
    });
  };
  
  useEffect(() => {
    const cellWidth = imageWidth / numCols;
    const cellHeight = adjustedHeight / numRows;
    setCellSize({ width: cellWidth, height: cellHeight });
    
    const canvas = canvasRef.current;
    const ctx = canvas.getContext('2d');
    drawGrid(ctx);
  }, [imageWidth, adjustedHeight, numRows, numCols, gridData, hoveredCell, selectedCell]);

  // useEffect(() => {
  //   if (selectedCell) {
  //     setSelectedCell(null);
  //     if (onCellDeselect) onCellDeselect();
  //   }
  // }, [gridData]);
  // otherwise we don't maintain clicked cell state
  
  const getCellFromCoords = (x, y) => {
    const col = Math.floor(x / cellSize.width);
    const row = Math.floor(y / cellSize.height);
    return { row, col };
  };
  
  const handleCanvasClick = (e) => {
    const canvas = canvasRef.current;
    const rect = canvas.getBoundingClientRect();
    const x = e.clientX - rect.left;
    const y = e.clientY - rect.top;
    const { row, col } = getCellFromCoords(x, y);
    
    const index = row * numCols + col;
    if (index < gridData.length) {
      if (selectedCell && selectedCell.col === col && selectedCell.row === row) {
        setSelectedCell(null);
        if (onCellDeselect) onCellDeselect();
      } else {
        if (onQueryPosChange) {
          onQueryPosChange({index})
        }
        setSelectedCell({ row, col });
        console.log("setting selected cell", selectedCell)
      }
      
      console.log(`Clicked cell: (col: ${col}, row: ${row}, index: ${index})`);
      console.log('Cell data:', gridData[index]);
    }
  };
  
  const handleMouseMove = (e) => {
    if (!isControlOverlay) return;
    if (!selectedCell) {
      const canvas = canvasRef.current;
      const rect = canvas.getBoundingClientRect();
      const x = e.clientX - rect.left;
      const y = e.clientY - rect.top;
      const { row, col } = getCellFromCoords(x, y);
      
      const index = row * numCols + col;
      if (index < gridData.length) {
        if (!hoveredCell || hoveredCell.row !== row || hoveredCell.col !== col) {
          setHoveredCell({ row, col });
          if (onQueryPosChange) {
            onQueryPosChange({index})
          }
        }
      }
    }
  };
  
  const handleMouseLeave = () => {
    // setHoveredCell(null);
  };
  
  return (
    <canvas
      ref={canvasRef}
      width={imageWidth}
      height={adjustedHeight}
      onClick={handleCanvasClick}
      onMouseMove={handleMouseMove}
      onMouseLeave={handleMouseLeave}
      className="absolute top-0 left-0"
      style={{ 
        pointerEvents: 'auto',
        cursor: 'pointer',
        position: 'absolute',
        top: 0,
        left: 0,
        height: `${adjustedHeight}px`
      }}
    />
  );
};

export default AttentionExplorer;
