import { shaderMaterial } from '@react-three/drei';
import { extend } from '@react-three/fiber';
import {
  Color,
  ColorRepresentation,
  CustomBlending,
  DataTexture,
  OneFactor,
  Texture,
  UniformsLib,
  UniformsUtils,
  Vector2,
  Vector2Tuple,
} from 'three';

declare global {
  namespace JSX {
    interface IntrinsicElements {
      /**
       * `<gaussianSplatPersonMaterial />`: this material's uniform props on top
       * of everything the built-in `shaderMaterial` element accepts.
       */
      gaussianSplatPersonMaterial: GaussianSplatPersonMaterialType & JSX.IntrinsicElements['shaderMaterial'];
    }
  }
}

/**
 * Props accepted by the `<gaussianSplatPersonMaterial>` JSX element, one per
 * uniform exposed as a property on `GaussianSplatPersonMaterial`.
 */
export type GaussianSplatPersonMaterialType = {
  /** Reveal progress: splats are fully shown once their normalized z reaches `1 - transition`. */
  transition?: number;
  /** Animation clock that scrolls the noise used to jitter glowing splats. */
  time?: number;
  /** `[min, max]` of the splats' local z, remapped to `[0, 1]` to drive the reveal. */
  scanBounds?: Vector2Tuple;
  /** Noise texture whose red channel jitters splats inside the glow band. */
  noiseMap?: Texture;
  /** Tint applied inside the glow band. */
  glowColor?: ColorRepresentation;
  /** Brightness multiplier applied together with `glowColor` inside the glow band. */
  glowIntensity?: number;
  /** Widens the glow band trailing the reveal front. */
  glowLength?: number;
  /** Viewport size in pixels used to convert splat axes to clip space. */
  viewport?: Vector2 | Vector2Tuple;
  /** Focal length fed into the projection Jacobian. */
  focal?: number;
  /** Per-splat center (xyz) and scale (w), addressed by the `splatIndex` attribute. */
  centerAndScaleTexture?: Texture;
  /** Per-splat packed int16 covariance (xyz) and RGBA8 color (w). */
  covAndColorTexture?: Texture;
};

/**
 * Gaussian-splat material with a scan-style reveal transition.
 *
 * Each drawn instance is one splat. The vertex shader uses the `splatIndex`
 * attribute to fetch:
 * - the splat's center/scale from `centerAndScaleTexture`;
 * - its packed int16 3D covariance and RGBA8 color from `covAndColorTexture`.
 *
 * It then projects the covariance to screen space and stretches the quad
 * along the resulting eigenvectors.
 *
 * `transition` sweeps a fade front through the splats' local z, normalized by
 * `scanBounds`. Splats inside a trailing band are tinted (`glowColor`,
 * `glowIntensity`, width `glowLength`) and jittered with `noiseMap` over `time`.
 * Splats within ~0.3 units of the local y axis and above y = 0 are always shown.
 * The fade part is additionally attenuated by how far the model origin sits
 * from world x = 0.
 *
 * Fog uniforms are merged in `onInit` (see there); set `fog` on the material
 * to enable them.
 */
export const GaussianSplatPersonMaterial = shaderMaterial(
  {
    viewport: new Vector2(1920, 1080), // Dummy. will be overwritten
    focal: 1000.0,
    centerAndScaleTexture: new DataTexture(),
    covAndColorTexture: new DataTexture(),

    transition: 0.0,
    time: 0.0,
    scanBounds: new Vector2(-10.0, 40),
    noiseMap: new Texture(),
    glowColor: new Color(1, 0.5, 0),
    glowIntensity: 2.5,
    glowLength: 0.5,
  },
  /*glsl*/ `
    precision highp sampler2D;
    precision highp usampler2D;

    varying vec4 vColor;
    varying vec3 vPosition;
    varying vec3 vCenter;
    varying float vScale;
    varying float vGlow;

    uniform vec2 viewport;
    uniform float focal;

    uniform float transition;
    uniform vec2 scanBounds;
    uniform sampler2D noiseMap;
    uniform float time;
    uniform float glowLength;

    attribute uint splatIndex;
    uniform sampler2D centerAndScaleTexture;
    uniform usampler2D covAndColorTexture;

    #include <fog_pars_vertex>

    // Splits a uint into its two sign-extended 16-bit halves, returned as (low, high).
    vec2 unpackInt16(in uint value) {
      int v = int(value);
      int v0 = v >> 16;
      int v1 = (v & 0xFFFF);
      if((v & 0x8000) != 0)
        v1 |= 0xFFFF0000;
      return vec2(float(v1), float(v0));
    }

    // Linearly remaps value from [min1, max1] to [min2, max2] (unclamped).
    float map(float value, float min1, float max1, float min2, float max2) {
      return min2 + (value - min1) * (max2 - min2) / (max1 - min1);
    }

    void main () {
      // Fetch this splat's center (xyz) and scale (w); texels are laid out row-major by splat index.
      ivec2 texSize = textureSize(centerAndScaleTexture, 0);
      ivec2 texPos = ivec2(splatIndex%uint(texSize.x), splatIndex/uint(texSize.x));
      vec4 centerAndScaleData = texelFetch(centerAndScaleTexture, texPos, 0);
      vec3 center = centerAndScaleData.xyz;

      // World-space position of the model origin (only its x is used below).
      vec4 zeroWorldPos = modelMatrix * vec4(0.0, 0.0, 0.0, 1.);

      // 1 for splats within ~0.3 units of the local y axis (fading out by 0.4) and above y = 0.
      float justPersonScale = smoothstep(0.6, .7, 1.0 - distance(center.xz, vec2(0.0)));
      justPersonScale *= step(0.0001, center.y);

      // Despite its name this only depends on |origin.x| in world space.
      float distanceToZero = clamp(1.0 - distance(zeroWorldPos.x, 0.0), 0.0, 1.0);
      // Local z normalized so scanBounds.x -> 0 and scanBounds.y -> 1.
      float mappedZ = map(center.z, scanBounds.x, scanBounds.y, 0.0, 1.0);

      // Reveal front: hidden below 0.9 - transition, fully shown from 1.0 - transition.
      float fadeTransition = smoothstep(
        0.9 - transition,
        1.0 - transition,
        mappedZ
      );
      fadeTransition *= smoothstep(0.5, 1.0, distanceToZero);

      // Glow band starting at the leading edge of the front; glowLength widens it.
      vGlow = smoothstep(
        0.9 - transition,
        0.95 - transition,
        mappedZ
      ) - smoothstep(
        0.9 - transition + (glowLength * 0.1),
        0.95 - transition + (glowLength * 0.1),
        mappedZ
      );
      vGlow *= (1.0 - justPersonScale);

      vScale = max(
        justPersonScale,
        fadeTransition
      );
      vScale = clamp(vScale, 0.0, 1.0);

      // The fragment shader multiplies alpha by vScale, so fully faded splats are
      // invisible anyway; cull them here instead of rasterising zero-alpha quads.
      if (vScale <= 0.0) {
        gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
        return;
      }

      // Jitter glowing splats with the (time-scrolled) noise texture.
      // NOTE(review): x and z sample (almost) the same noise coordinate — confirm this is intended.
      center.x += vGlow * 0.3 * (texture2D(noiseMap, vec2(time * 0.1) + center.xz * 0.1).x * 2.0 - 1.0);
      center.y += vGlow * 0.3 * (texture2D(noiseMap, vec2(time * 0.1) + center.zx * 0.2).x * 2.0 - 1.0);
      center.z += vGlow * 0.3 * (texture2D(noiseMap, vec2(time * 0.1) + center.xz * 0.1).x * 2.0 - 1.0);
      vCenter = center;

      float scale = centerAndScaleData.w;
      vec4 camspace = modelViewMatrix * vec4(center, 1.);

      vec4 pos2d = projectionMatrix * camspace;

      scale *= vScale;
      #ifdef USE_ALPHAHASH
        // Alpha-hashed splats are drawn twice as large — presumably to offset stochastic discard; confirm.
        scale *= 2.0;
      #endif

      // Cull splats behind the near plane or outside the frustum (with a 20% margin).
      float bounds = 1.2 * pos2d.w;
      if (pos2d.z < -pos2d.w || pos2d.x < -bounds || pos2d.x > bounds
        || pos2d.y < -bounds || pos2d.y > bounds) {
        gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
        return;
      }

      // Rebuild the symmetric 3D covariance from its six packed int16 entries, scaled by the splat scale.
      uvec4 covAndColorData = texelFetch(covAndColorTexture, texPos, 0);
      vec2 cov3D_M11_M12 = unpackInt16(covAndColorData.x) * scale;
      vec2 cov3D_M13_M22 = unpackInt16(covAndColorData.y) * scale;
      vec2 cov3D_M23_M33 = unpackInt16(covAndColorData.z) * scale;
      mat3 Vrk = mat3(
        cov3D_M11_M12.x, cov3D_M11_M12.y, cov3D_M13_M22.x,
        cov3D_M11_M12.y, cov3D_M13_M22.y, cov3D_M23_M33.x,
        cov3D_M13_M22.x, cov3D_M23_M33.x, cov3D_M23_M33.y
      );

      // Jacobian of the perspective projection at the splat center.
      mat3 J = mat3(
        focal / camspace.z, 0., -(focal * camspace.x) / (camspace.z * camspace.z),
        0., focal / camspace.z, -(focal * camspace.y) / (camspace.z * camspace.z),
        0., 0., 0.
      );

      // Project the 3D covariance into screen space.
      mat3 W = transpose(mat3(modelViewMatrix));
      mat3 T = W * J;
      mat3 cov = transpose(T) * Vrk * T;

      vec2 screenCenter = vec2(pos2d) / pos2d.w;

      // Eigen-decompose the 2D covariance (diagonal regularized by 0.3) to get the quad's axes.
      float diagonal1 = cov[0][0] + 0.3;
      float offDiagonal = cov[0][1];
      float diagonal2 = cov[1][1] + 0.3;

      float mid = 0.5 * (diagonal1 + diagonal2);
      float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal));
      float lambda1 = mid + radius;
      float lambda2 = max(mid - radius, 0.1);
      vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1));
      vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
      vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);

      // Color is packed RGBA8, red in the lowest byte.
      uint colorUint = covAndColorData.w;
      vColor = vec4(
        float(colorUint & uint(0xFF)) / 255.0,
        float((colorUint >> uint(8)) & uint(0xFF)) / 255.0,
        float((colorUint >> uint(16)) & uint(0xFF)) / 255.0,
        float(colorUint >> uint(24)) / 255.0
      );

      // position.xy is the quad-local offset; the fragment shader reuses it for the Gaussian falloff.
      vec3 transformed = position;
      vPosition = transformed.xyz;

      gl_Position = vec4(
        screenCenter
          + transformed.x * v2 / viewport * 2.0
          + transformed.y * v1 / viewport * 2.0, pos2d.z / pos2d.w, 1.0);

      #ifdef USE_FOG
        vFogDepth = -camspace.z;
      #endif
    }
    `,
  /*glsl*/ `
    varying vec4 vColor;
    varying vec3 vPosition;
    varying vec3 vCenter;
    varying float vScale;
    varying float vGlow;

    uniform float time;
    uniform vec3 glowColor;
    uniform float glowIntensity;

    #include <alphahash_pars_fragment>
    #include <fog_pars_fragment>

    void main () {
      // Gaussian falloff over the quad-local offset; fragments beyond |p| = 2 are discarded.
      float A = -dot(vPosition.xy, vPosition.xy);

      if (A < -4.0) discard;

      #ifdef USE_ALPHAHASH
        float B = exp(A) * (.5 + vColor.a * .5);
        B *= vScale;
        if ( B < getAlphaHashThreshold( vPosition + fract(time * 0.02) ) ) discard;
      #else
        float B = exp(A) * vColor.a;
        B *= vScale;
      #endif

      // Inside the glow band, blend toward the tinted and brightened color.
      gl_FragColor = vec4(vColor.rgb, B);
      gl_FragColor.rgb = mix(
        gl_FragColor.rgb,
        gl_FragColor.rgb * glowColor * glowIntensity,
        vGlow
      );

      #include <tonemapping_fragment>
      #include <colorspace_fragment>
      #include <fog_fragment>
      #include <premultiplied_alpha_fragment>
      #include <dithering_fragment>
    }
  `,
  (material) => {
    if (!material) return;
    // drei's shaderMaterial wraps every entry of the uniforms object as
    // `{ value: entry }`. Spreading UniformsLib.fog into that object would
    // double-wrap the fog uniforms (e.g. fogColor.value === { value: Color }),
    // which three's fog uniform refresh does not expect. So merge them
    // into the material's own uniforms directly.
    Object.assign(material.uniforms, UniformsUtils.clone(UniformsLib.fog));
    material.extensions.derivatives = true;
    material.blending = CustomBlending;
    material.blendSrcAlpha = OneFactor;
    material.depthTest = true;
    material.depthWrite = false;
    material.transparent = true;
  }
);

// Register the material with react-three-fiber so the JSX declaration's
// `<gaussianSplatPersonMaterial />` element resolves to this class.
extend({ GaussianSplatPersonMaterial: GaussianSplatPersonMaterial });
